├── matcha ├── __init__.py ├── data │ ├── __init__.py │ ├── components │ │ └── __init__.py │ └── text_mel_datamodule.py ├── onnx │ ├── __init__.py │ ├── export.py │ └── infer.py ├── VERSION ├── hifigan │ ├── __init__.py │ ├── env.py │ ├── config.py │ ├── LICENSE │ ├── xutils.py │ ├── denoiser.py │ ├── README.md │ ├── meldataset.py │ └── models.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── test.py │ │ ├── vits_posterior.py │ │ ├── flow_matching.py │ │ ├── commons.py │ │ ├── vits_modules.py │ │ └── transformer.py │ ├── baselightningmodule.py │ └── matcha_tts.py ├── utils │ ├── monotonic_align │ │ ├── setup.py │ │ ├── __init__.py │ │ └── core.pyx │ ├── monotonic_align_vits │ │ ├── setup.py │ │ ├── __init__.py │ │ └── core.pyx │ ├── __init__.py │ ├── pylogger.py │ ├── logging_utils.py │ ├── instantiators.py │ ├── audio.py │ ├── model.py │ ├── rich_utils.py │ ├── generate_data_statistics.py │ └── utils.py ├── text │ ├── symbols.py │ ├── __init__.py │ ├── numbers.py │ └── cleaners.py └── train.py ├── notebooks └── .gitkeep ├── configs ├── local │ └── .gitkeep ├── callbacks │ ├── none.yaml │ ├── default.yaml │ ├── rich_progress_bar.yaml │ ├── model_summary.yaml │ └── model_checkpoint.yaml ├── model │ ├── cfm │ │ └── default.yaml │ ├── optimizer │ │ └── adam.yaml │ ├── decoder │ │ └── default.yaml │ ├── matcha.yaml │ └── encoder │ │ └── default.yaml ├── trainer │ ├── cpu.yaml │ ├── gpu.yaml │ ├── mps.yaml │ ├── ddp_sim.yaml │ ├── ddp.yaml │ └── default.yaml ├── __init__.py ├── debug │ ├── fdr.yaml │ ├── overfit.yaml │ ├── limit.yaml │ ├── profiler.yaml │ └── default.yaml ├── logger │ ├── many_loggers.yaml │ ├── csv.yaml │ ├── tensorboard.yaml │ ├── neptune.yaml │ ├── mlflow.yaml │ ├── comet.yaml │ ├── wandb.yaml │ └── aim.yaml ├── extras │ └── default.yaml ├── experiment │ ├── ljspeech.yaml │ ├── multispeaker.yaml │ └── ljspeech_min_memory.yaml ├── eval.yaml ├── data │ ├── vctk.yaml │ └── ljspeech.yaml ├── hydra │ └── default.yaml ├── paths │ └── default.yaml ├── train.yaml └── hparams_search │ └── mnist_optuna.yaml ├── data ├── image.png ├── .project-root ├── scripts └── schedule.sh ├── .env.example ├── MANIFEST.in ├── .github ├── codecov.yml ├── dependabot.yml ├── PULL_REQUEST_TEMPLATE.md └── release-drafter.yml ├── requirements.txt ├── LICENSE ├── pyproject.toml ├── Makefile ├── setup.py ├── .pre-commit-config.yaml ├── .gitignore └── README.md /matcha/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /matcha/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /matcha/onnx/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /matcha/VERSION: -------------------------------------------------------------------------------- 1 | 0.0.4 2 | -------------------------------------------------------------------------------- /matcha/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /matcha/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /matcha/data/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /matcha/models/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data: -------------------------------------------------------------------------------- 1 | /home/smehta/Projects/Speech-Backbones/Grad-TTS/data -------------------------------------------------------------------------------- /image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/Matcha-TTS-2/HEAD/image.png -------------------------------------------------------------------------------- /configs/model/cfm/default.yaml: -------------------------------------------------------------------------------- 1 | name: CFM 2 | solver: euler 3 | sigma_min: 1e-4 4 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /configs/model/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | _partial_: true 3 | lr: 1e-4 4 | weight_decay: 0.0 5 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - 
model_checkpoint.yaml 3 | - model_summary.yaml 4 | - rich_progress_bar.yaml 5 | - _self_ 6 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: [0,1] 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /configs/model/decoder/default.yaml: -------------------------------------------------------------------------------- 1 | channels: [256, 256] 2 | dropout: 0.05 3 | attention_head_dim: 64 4 | n_blocks: 1 5 | num_mid_blocks: 2 6 | num_heads: 2 7 | act_fn: snakebeta 8 | -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet 5 | - csv 6 | # - mlflow 7 | # - neptune 8 | - tensorboard 9 | - wandb 10 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /matcha/models/components/test.py: -------------------------------------------------------------------------------- 1 | from matcha.hifigan.meldataset import mel_spectrogram 2 | import torch 3 | 4 | audio = torch.randn(2,1, 1000) 5 | mels = mel_spectrogram(audio, 1024, 80, 22050, 256, 1024, 0, 8000, center=False) 6 | print(mels.shape) -------------------------------------------------------------------------------- /scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python matcha/train.py trainer.max_epochs=5 logger=csv 6 | 7 | python matcha/train.py trainer.max_epochs=10 logger=csv 8 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align/setup.py: 
-------------------------------------------------------------------------------- 1 | # from distutils.core import setup 2 | # from Cython.Build import cythonize 3 | # import numpy 4 | 5 | # setup(name='monotonic_align', 6 | # ext_modules=cythonize("core.pyx"), 7 | # include_dirs=[numpy.get_include()]) 8 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align_vits/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup( 6 | name='monotonic_align', 7 | ext_modules=cythonize("core.pyx"), 8 | include_dirs=[numpy.get_include()] 9 | ) 10 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 3 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 1 10 | # profiler: "simple" 11 | profiler: "advanced" 12 | # profiler: "pytorch" 13 | accelerator: gpu 14 | 15 | limit_train_batches: 0.02 16 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | 
-------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /matcha/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from matcha.utils.logging_utils import log_hyperparameters 3 | from matcha.utils.pylogger import get_pylogger 4 | from matcha.utils.rich_utils import enforce_tags, print_config_tree 5 | from matcha.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # example of file for storing private and user specific environment variables, like keys or system paths 2 | # rename it to ".env" (excluded from version control by default) 3 | # .env is loaded by train.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | MY_VAR="/home/user/my/system/path" 7 | -------------------------------------------------------------------------------- /configs/model/matcha.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - encoder: default.yaml 4 | - decoder: default.yaml 5 | - cfm: default.yaml 6 | - optimizer: adam.yaml 7 | 8 | _target_: matcha.models.matcha_tts.MatchaTTS 9 | n_vocab: 178 10 | n_spks: ${data.n_spks} 11 | spk_emb_dim: 64 12 | n_feats: 80 13 | data_statistics: ${data.data_statistics} 14 | out_size: null # Must be divisible by 4 15 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE.txt 3 | include requirements.*.txt 4 | include *.cff 5 | include requirements.txt 6 | include matcha/VERSION 7 | recursive-include matcha *.json 8 | recursive-include matcha *.html 9 | recursive-include matcha *.png 10 | recursive-include matcha *.md 11 | recursive-include matcha *.py 12 | recursive-include matcha *.pyx 13 | recursive-exclude tests * 14 | prune tests* 15 | -------------------------------------------------------------------------------- /configs/experiment/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this 
experiment run: 4 | # python train.py experiment=ljspeech 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech 15 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: ljspeech # choose datamodule with `test_dataloader()` for evaluation 6 | - model: matcha 7 | - logger: null 8 | - trainer: default 9 | - paths: default 10 | - extras: default 11 | - hydra: default 12 | 13 | task_name: "eval" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for evaluation 18 | ckpt_path: ??? 19 | -------------------------------------------------------------------------------- /configs/experiment/multispeaker.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: vctk.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["multispeaker"] 13 | 14 | run_name: multispeaker 15 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | # measures overall project coverage 4 | project: 5 | default: 6 | threshold: 100% # how much decrease in coverage is needed to not consider success 7 | 8 | # measures PR or single commit coverage 9 | patch: 10 | default: 11 | threshold: 100% # how much decrease in coverage is needed to not consider success 12 | 13 | 14 | # project: off 15 | # patch: off 16 | -------------------------------------------------------------------------------- /configs/data/vctk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ljspeech 3 | - _self_ 4 | 5 | _target_: matcha.data.text_mel_datamodule.TextMelDataModule 6 | name: vctk 7 | train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt 8 | valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt 9 | batch_size: 32 10 | add_blank: True 11 | n_spks: 109 12 | data_statistics: # Computed for vctk dataset 13 | mel_mean: -6.630575 14 | mel_std: 2.482914 15 | -------------------------------------------------------------------------------- /configs/experiment/ljspeech_min_memory.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 
2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=ljspeech_min_memory 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech_min 15 | 16 | 17 | model: 18 | out_size: 172 19 | -------------------------------------------------------------------------------- /matcha/hifigan/env.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import os 4 | import shutil 5 | 6 | 7 | class AttrDict(dict): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.__dict__ = self 11 | 12 | 13 | def build_env(config, config_name, path): 14 | t_path = os.path.join(path, config_name) 15 | if config != t_path: 16 | os.makedirs(path, exist_ok=True) 17 | shutil.copyfile(config, os.path.join(path, config_name)) 18 | -------------------------------------------------------------------------------- /configs/model/encoder/default.yaml: -------------------------------------------------------------------------------- 1 | encoder_type: RoPE Encoder 2 | encoder_params: 3 | n_feats: ${model.n_feats} 4 | n_channels: 192 5 | filter_channels: 768 6 | filter_channels_dp: 256 7 | n_heads: 2 8 | n_layers: 6 9 | kernel_size: 3 10 | p_dropout: 0.1 11 | spk_emb_dim: 64 12 | n_spks: 1 13 | prenet: true 14 | 15 | duration_predictor_params: 16 | filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp} 17 | kernel_size: 3 18 | p_dropout: ${model.encoder.encoder_params.p_dropout} 19 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | max_epochs: -1 6 | 7 | accelerator: gpu 8 | devices: [0] 9 | 10 | # mixed precision for extra speed-up 11 | precision: 16-mixed 12 | 13 | # perform a validation loop every N training epochs 14 | check_val_every_n_epoch: 1 15 | 16 | # set True to ensure deterministic results 17 | # makes training slower but gives more reproducibility than just setting seeds 18 | deterministic: False 19 | 20 | gradient_clip_val: 5.0 21 | -------------------------------------------------------------------------------- /matcha/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Defines the set of symbols used in text input to the model. 
4 | """ 5 | _pad = "_" 6 | _punctuation = ';:,.!?¡¿—…"«»“” ' 7 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 8 | _letters_ipa = ( 9 | "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | ) 11 | 12 | 13 | # Export all symbols: 14 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 15 | 16 | # Special symbol ids 17 | SPACE_ID = symbols.index(" ") 18 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "lightning-hydra-template" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/data/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | _target_: matcha.data.text_mel_datamodule.TextMelDataModule 2 | name: ljspeech 3 | train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt 4 | valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt 5 | batch_size: 32 6 | num_workers: 20 7 | pin_memory: True 8 | cleaners: [english_cleaners2] 9 | add_blank: True 10 | n_spks: 1 11 | n_fft: 1024 12 | n_feats: 80 13 | sample_rate: 22050 14 | hop_length: 256 15 | win_length: 1024 16 | f_min: 0 17 | f_max: 8000 18 | data_statistics: # Computed for ljspeech dataset 19 | mel_mean: -5.536622 20 | mel_std: 2.116101 21 | seed: ${seed} 22 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log 20 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." 
if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align_vits/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .monotonic_align.core import maximum_path_c 4 | 5 | 6 | def maximum_path(neg_cent, mask): 7 | """ Cython optimized version. 8 | neg_cent: [b, t_t, t_s] 9 | mask: [b, t_t, t_s] 10 | """ 11 | device = neg_cent.device 12 | dtype = neg_cent.dtype 13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) 14 | path = np.zeros(neg_cent.shape, dtype=np.int32) 15 | 16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) 17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) 18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 19 | return torch.from_numpy(path).to(device=device, dtype=dtype) 20 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from matcha.utils.monotonic_align.core import maximum_path_c 5 | 6 | 7 | def maximum_path(value, mask): 8 | """Cython optimised version. 9 | value: [b, t_x, t_y] 10 | mask: [b, t_x, t_y] 11 | """ 12 | value = value * mask 13 | device = value.device 14 | dtype = value.dtype 15 | value = value.data.cpu().numpy().astype(np.float32) 16 | path = np.zeros_like(value).astype(np.int32) 17 | mask = mask.data.cpu().numpy() 18 | 19 | t_x_max = mask.sum(1)[:, 0].astype(np.int32) 20 | t_y_max = mask.sum(2)[:, 0].astype(np.int32) 21 | maximum_path_c(path, value, t_x_max, t_y_max) 22 | return torch.from_numpy(path).to(device=device, dtype=dtype) 23 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | target-branch: "dev" 11 | schedule: 12 | interval: "daily" 13 | ignore: 14 | - dependency-name: "pytorch-lightning" 15 | update-types: ["version-update:semver-patch"] 16 | - dependency-name: "torchmetrics" 17 | update-types: ["version-update:semver-patch"] 18 | -------------------------------------------------------------------------------- /matcha/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name: str = __name__) -> logging.Logger: 7 | """Initializes a multi-GPU-friendly python command line logger. 8 | 9 | :param name: The name of the logger, defaults to ``__name__``. 10 | 11 | :return: A logger object. 12 | """ 13 | logger = logging.getLogger(name) 14 | 15 | # this ensures all logging levels get marked with the rank zero decorator 16 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 17 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 18 | for level in logging_levels: 19 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 20 | 21 | return logger 22 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | 9 | 10 | Fixes #\ 11 | 12 | ## Before submitting 13 | 14 | - [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? 15 | - [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? 16 | - [ ] Did you list all the **breaking changes** introduced by this pull request? 17 | - [ ] Did you **test your PR locally** with `pytest` command? 18 | - [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command? 19 | 20 | ## Did you have fun? 
21 | 22 | Make sure you had fun coding 🙃 23 | -------------------------------------------------------------------------------- /matcha/hifigan/config.py: -------------------------------------------------------------------------------- 1 | v1 = { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0004, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | "upsample_rates": [8, 8, 2, 2], 11 | "upsample_kernel_sizes": [16, 16, 4, 4], 12 | "upsample_initial_channel": 512, 13 | "resblock_kernel_sizes": [3, 7, 11], 14 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 15 | "resblock_initial_channel": 256, 16 | "segment_size": 8192, 17 | "num_mels": 80, 18 | "num_freq": 1025, 19 | "n_fft": 1024, 20 | "hop_size": 256, 21 | "win_size": 1024, 22 | "sampling_rate": 22050, 23 | "fmin": 0, 24 | "fmax": 8000, 25 | "fmax_loss": None, 26 | "num_workers": 4, 27 | "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, 28 | } 29 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "v$RESOLVED_VERSION" 2 | tag-template: "v$RESOLVED_VERSION" 3 | 4 | categories: 5 | - title: "🚀 Features" 6 | labels: 7 | - "feature" 8 | - "enhancement" 9 | - title: "🐛 Bug Fixes" 10 | labels: 11 | - "fix" 12 | - "bugfix" 13 | - "bug" 14 | - title: "🧹 Maintenance" 15 | labels: 16 | - "maintenance" 17 | - "dependencies" 18 | - "refactoring" 19 | - "cosmetic" 20 | - "chore" 21 | - title: "📝️ Documentation" 22 | labels: 23 | - "documentation" 24 | - "docs" 25 | 26 | change-template: "- $TITLE @$AUTHOR (#$NUMBER)" 27 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions 28 | 29 | version-resolver: 30 | major: 31 | labels: 32 | - "major" 33 | minor: 34 | labels: 35 | - "minor" 36 | patch: 37 | labels: 38 | - "patch" 39 | default: patch 40 | 41 | template: | 42 | ## Changes 43 | 44 | $CHANGES 45 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | # callbacks: null 11 | # logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch>=2.0.0 3 | 
torchvision>=0.15.0 4 | lightning>=2.0.0 5 | torchmetrics>=0.11.4 6 | 7 | # --------- hydra --------- # 8 | hydra-core==1.3.2 9 | hydra-colorlog==1.2.0 10 | hydra-optuna-sweeper==1.2.0 11 | 12 | # --------- loggers --------- # 13 | # wandb 14 | # neptune-client 15 | # mlflow 16 | # comet-ml 17 | # aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 18 | 19 | # --------- others --------- # 20 | rootutils # standardizing the project root setup 21 | pre-commit # hooks for applying linters on commit 22 | rich # beautiful text formatting in terminal 23 | pytest # tests 24 | # sh # for running bash commands in some tests (linux/macos only) 25 | phonemizer # phonemization of text 26 | tensorboard 27 | librosa 28 | Cython 29 | numpy 30 | einops 31 | inflect 32 | Unidecode 33 | scipy 34 | torchaudio 35 | matplotlib 36 | pandas 37 | conformer==0.3.2 38 | diffusers==0.21.3 39 | notebook 40 | ipywidgets 41 | gradio 42 | gdown 43 | wget 44 | seaborn 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 p0p 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /matcha/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "cython==0.29.35", "numpy==1.24.3", "packaging"] 3 | 4 | [tool.black] 5 | line-length = 120 6 | target-version = ['py310'] 7 | exclude = ''' 8 | 9 | ( 10 | /( 11 | \.eggs # exclude a few common directories in the 12 | | \.git # root of the project 13 | | \.hg 14 | | \.mypy_cache 15 | | \.tox 16 | | \.venv 17 | | _build 18 | | buck-out 19 | | build 20 | | dist 21 | )/ 22 | | foo.py # also separately exclude a file named foo.py in 23 | # the root of the project 24 | ) 25 | ''' 26 | 27 | [tool.pytest.ini_options] 28 | addopts = [ 29 | "--color=yes", 30 | "--durations=0", 31 | "--strict-markers", 32 | "--doctest-modules", 33 | ] 34 | filterwarnings = [ 35 | "ignore::DeprecationWarning", 36 | "ignore::UserWarning", 37 | ] 38 | log_cli = "True" 39 | markers = [ 40 | "slow: slow tests", 41 | ] 42 | minversion = "6.0" 43 | testpaths = "tests/" 44 | 45 | [tool.coverage.report] 46 | exclude_lines = [ 47 | "pragma: nocover", 48 | "raise NotImplementedError", 49 | "raise NotImplementedError()", 50 | "if __name__ == .__main__.:", 51 | ] 52 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: ${paths.output_dir}/checkpoints # directory to save the model file 6 | filename: checkpoint_{epoch:03d} # checkpoint filename 7 | monitor: epoch # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 10 # save k best models (determined by above metric) 11 | mode: "max" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: 100 # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . 
| grep -E ".pytest_cache" | xargs rm -rf 10 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | create-package: ## Create wheel and tar gz 17 | rm -rf dist/ 18 | python setup.py bdist_wheel --plat-name=manylinux1_x86_64 19 | python setup.py sdist 20 | python -m twine upload dist/* --verbose --skip-existing 21 | 22 | format: ## Run pre-commit hooks 23 | pre-commit run -a 24 | 25 | sync: ## Merge changes from main branch to your current branch 26 | git pull 27 | git pull origin main 28 | 29 | test: ## Run not slow tests 30 | pytest -k "not slow" 31 | 32 | test-full: ## Run all tests 33 | pytest 34 | 35 | train-ljspeech: ## Train the model 36 | python matcha/train.py experiment=ljspeech 37 | 38 | train-ljspeech-min: ## Train the model with minimum memory 39 | python matcha/train.py experiment=ljspeech_min_memory 40 | 41 | start_app: ## Start the app 42 | python matcha/app.py 43 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align_vits/core.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from cython.parallel import prange 3 | 4 | 5 | @cython.boundscheck(False) 6 | @cython.wraparound(False) 7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: 8 | cdef int x 9 | cdef int y 10 | cdef float v_prev 11 | cdef float v_cur 12 | cdef float tmp 13 | cdef int index = t_x - 1 14 | 15 | for y in range(t_y): 16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 17 | if x == y: 18 | v_cur = max_neg_val 19 | else: 20 | v_cur = value[y-1, x] 21 | if x == 0: 22 | if y == 0: 23 | v_prev = 0. 24 | else: 25 | v_prev = max_neg_val 26 | else: 27 | v_prev = value[y-1, x-1] 28 | value[y, x] += max(v_prev, v_cur) 29 | 30 | for y in range(t_y - 1, -1, -1): 31 | path[y, index] = 1 32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): 33 | index = index - 1 34 | 35 | 36 | @cython.boundscheck(False) 37 | @cython.wraparound(False) 38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: 39 | cdef int b = paths.shape[0] 40 | cdef int i 41 | for i in prange(b, nogil=True): 42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) 43 | -------------------------------------------------------------------------------- /configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group runs under experiment name 15 | experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 
22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | cimport cython 4 | cimport numpy as np 5 | 6 | from cython.parallel import prange 7 | 8 | 9 | @cython.boundscheck(False) 10 | @cython.wraparound(False) 11 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: 12 | cdef int x 13 | cdef int y 14 | cdef float v_prev 15 | cdef float v_cur 16 | cdef float tmp 17 | cdef int index = t_x - 1 18 | 19 | for y in range(t_y): 20 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 21 | if x == y: 22 | v_cur = max_neg_val 23 | else: 24 | v_cur = value[x, y-1] 25 | if x == 0: 26 | if y == 0: 27 | v_prev = 0. 28 | else: 29 | v_prev = max_neg_val 30 | else: 31 | v_prev = value[x-1, y-1] 32 | value[x, y] = max(v_cur, v_prev) + value[x, y] 33 | 34 | for y in range(t_y - 1, -1, -1): 35 | path[index, y] = 1 36 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): 37 | index = index - 1 38 | 39 | 40 | @cython.boundscheck(False) 41 | @cython.wraparound(False) 42 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: 43 | cdef int b = values.shape[0] 44 | 45 | cdef int i 46 | for i in prange(b, nogil=True): 47 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) 48 | -------------------------------------------------------------------------------- /matcha/hifigan/xutils.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import glob 4 | import os 5 | 6 | import matplotlib 7 | import torch 8 | from torch.nn.utils import weight_norm 9 | 10 | matplotlib.use("Agg") 11 | import matplotlib.pylab as plt 12 | 13 | 14 | def plot_spectrogram(spectrogram): 15 | fig, ax = plt.subplots(figsize=(10, 2)) 16 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 17 | plt.colorbar(im, ax=ax) 18 | 19 | fig.canvas.draw() 20 | plt.close() 21 | 22 | return fig 23 | 24 | 25 | def init_weights(m, mean=0.0, std=0.01): 26 | classname = m.__class__.__name__ 27 | if classname.find("Conv") != -1: 28 | m.weight.data.normal_(mean, std) 29 | 30 | 31 | def apply_weight_norm(m): 32 | classname = m.__class__.__name__ 33 | if classname.find("Conv") != -1: 34 | weight_norm(m) 35 | 36 | 37 | def get_padding(kernel_size, dilation=1): 38 | return int((kernel_size * dilation - dilation) / 2) 39 | 40 | 41 | def load_checkpoint(filepath, device): 42 | assert os.path.isfile(filepath) 43 | print(f"Loading '{filepath}'") 44 | checkpoint_dict = torch.load(filepath, map_location=device) 45 | print("Complete.") 46 | return checkpoint_dict 47 | 48 | 49 | def save_checkpoint(filepath, obj): 50 | print(f"Saving checkpoint to {filepath}") 51 | torch.save(obj, filepath) 52 | print("Complete.") 53 | 54 | 55 | def scan_checkpoint(cp_dir, prefix): 56 | pattern = 
os.path.join(cp_dir, prefix + "????????") 57 | cp_list = glob.glob(pattern) 58 | if len(cp_list) == 0: 59 | return None 60 | return sorted(cp_list)[-1] 61 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | 4 | import numpy 5 | from Cython.Build import cythonize 6 | from setuptools import Extension, find_packages, setup 7 | 8 | exts = [ 9 | Extension( 10 | name="matcha.utils.monotonic_align.core", 11 | sources=["matcha/utils/monotonic_align/core.pyx"], 12 | ) 13 | ] 14 | 15 | with open("README.md", encoding="utf-8") as readme_file: 16 | README = readme_file.read() 17 | 18 | cwd = os.path.dirname(os.path.abspath(__file__)) 19 | with open(os.path.join(cwd, "matcha", "VERSION")) as fin: 20 | version = fin.read().strip() 21 | 22 | setup( 23 | name="matcha-tts", 24 | version=version, 25 | description="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching", 26 | long_description=README, 27 | long_description_content_type="text/markdown", 28 | author="Shivam Mehta", 29 | author_email="shivam.mehta25@gmail.com", 30 | url="https://shivammehta25.github.io/Matcha-TTS", 31 | install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))], 32 | include_dirs=[numpy.get_include()], 33 | include_package_data=True, 34 | packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]), 35 | # use this to customize global commands available in the terminal after installing the package 36 | entry_points={ 37 | "console_scripts": [ 38 | "matcha-data-stats=matcha.utils.generate_data_statistics:main", 39 | "matcha-tts=matcha.cli:cli", 40 | "matcha-tts-app=matcha.app:main", 41 | ] 42 | }, 43 | ext_modules=cythonize(exts, language_level=3), 44 | python_requires=">=3.9.0", 45 | ) 46 | -------------------------------------------------------------------------------- /matcha/models/components/vits_posterior.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | import matcha.models.components.vits_modules as modules 5 | import matcha.models.components.commons as commons 6 | 7 | class PosteriorEncoder(nn.Module): 8 | 9 | def __init__(self, 10 | in_channels, 11 | out_channels, 12 | hidden_channels, 13 | kernel_size, 14 | dilation_rate, 15 | n_layers, 16 | gin_channels=0): 17 | super().__init__() 18 | self.in_channels = in_channels 19 | self.out_channels = out_channels 20 | self.hidden_channels = hidden_channels 21 | self.kernel_size = kernel_size 22 | self.dilation_rate = dilation_rate 23 | self.n_layers = n_layers 24 | self.gin_channels = gin_channels 25 | 26 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 27 | self.enc = modules.WN(hidden_channels, 28 | kernel_size, 29 | dilation_rate, 30 | n_layers, 31 | gin_channels=gin_channels) 32 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 33 | 34 | def forward(self, x, x_lengths, g=None): 35 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 36 | 1).to(x.dtype) 37 | x = self.pre(x) * x_mask 38 | x = self.enc(x, x_mask, g=g) 39 | stats = self.proj(x) * x_mask 40 | # m, logs = torch.split(stats, self.out_channels, dim=1) 41 | # z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask 42 | # z = m * x_mask 43 | return stats, x_mask 44 | -------------------------------------------------------------------------------- 
/configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: ljspeech 8 | - model: matcha 9 | - callbacks: default 10 | - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | run_name: ??? 34 | 35 | # tags to help you identify your experiments 36 | # you can overwrite this in experiment configs 37 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 38 | tags: ["dev"] 39 | 40 | # set False to skip model training 41 | train: True 42 | 43 | # evaluate on test set, using best model weights achieved during training 44 | # lightning chooses best weights based on the metric specified in checkpoint callback 45 | test: True 46 | 47 | # simply provide checkpoint path to resume training 48 | ckpt_path: null 49 | 50 | # seed for random number generators in pytorch, numpy and python.random 51 | seed: 1234 52 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.10 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | # - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-toml 16 | - id: check-case-conflict 17 | - id: check-added-large-files 18 | 19 | # python code formatting 20 | - repo: https://github.com/psf/black 21 | rev: 23.9.1 22 | hooks: 23 | - id: black 24 | args: [--line-length, "120"] 25 | 26 | # python import sorting 27 | - repo: https://github.com/PyCQA/isort 28 | rev: 5.12.0 29 | hooks: 30 | - id: isort 31 | args: ["--profile", "black", "--filter-files"] 32 | 33 | # python upgrading syntax to newer version 34 | - repo: https://github.com/asottile/pyupgrade 35 | rev: v3.14.0 36 | hooks: 37 | - id: pyupgrade 38 | args: [--py38-plus] 39 | 40 | # python check (PEP8), programming errors and code complexity 41 | - repo: https://github.com/PyCQA/flake8 42 | rev: 6.1.0 43 | hooks: 44 | - id: flake8 45 | args: 46 | [ 47 | "--max-line-length", "120", 48 | "--extend-ignore", 49 | "E203,E402,E501,F401,F841,RST2,RST301", 50 | "--exclude", 51 | "logs/*,data/*,matcha/hifigan/*", 52 | ] 53 | additional_dependencies: [flake8-rst-docstrings==0.3.0] 54 | 55 | # pylint 56 | - repo: https://github.com/pycqa/pylint 57 | rev: v3.0.0 58 | hooks: 59 | - id: 
pylint 60 | -------------------------------------------------------------------------------- /matcha/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from matcha.text import cleaners 3 | from matcha.text.symbols import symbols 4 | 5 | # Mappings from symbol to numeric ID and vice versa: 6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension 8 | 9 | 10 | def text_to_sequence(text, cleaner_names): 11 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 12 | Args: 13 | text: string to convert to a sequence 14 | cleaner_names: names of the cleaner functions to run the text through 15 | Returns: 16 | List of integers corresponding to the symbols in the text 17 | """ 18 | sequence = [] 19 | 20 | clean_text = _clean_text(text, cleaner_names) 21 | for symbol in clean_text: 22 | symbol_id = _symbol_to_id[symbol] 23 | sequence += [symbol_id] 24 | return sequence 25 | 26 | 27 | def cleaned_text_to_sequence(cleaned_text): 28 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 29 | Args: 30 | text: string to convert to a sequence 31 | Returns: 32 | List of integers corresponding to the symbols in the text 33 | """ 34 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] 35 | return sequence 36 | 37 | 38 | def sequence_to_text(sequence): 39 | """Converts a sequence of IDs back to a string""" 40 | result = "" 41 | for symbol_id in sequence: 42 | s = _id_to_symbol[symbol_id] 43 | result += s 44 | return result 45 | 46 | 47 | def _clean_text(text, cleaner_names): 48 | for name in cleaner_names: 49 | cleaner = getattr(cleaners, name) 50 | if not cleaner: 51 | raise Exception("Unknown cleaner: %s" % name) 52 | text = cleaner(text) 53 | return text 54 | -------------------------------------------------------------------------------- /matcha/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from matcha.utils import pylogger 7 | 8 | log = pylogger.get_pylogger(__name__) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! 
Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) 39 | 40 | hparams["data"] = cfg["data"] 41 | hparams["trainer"] = cfg["trainer"] 42 | 43 | hparams["callbacks"] = cfg.get("callbacks") 44 | hparams["extras"] = cfg.get("extras") 45 | 46 | hparams["task_name"] = cfg.get("task_name") 47 | hparams["tags"] = cfg.get("tags") 48 | hparams["ckpt_path"] = cfg.get("ckpt_path") 49 | hparams["seed"] = cfg.get("seed") 50 | 51 | # send hparams to all loggers 52 | for logger in trainer.loggers: 53 | logger.log_hyperparams(hparams) 54 | -------------------------------------------------------------------------------- /configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | data.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /matcha/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from matcha.utils import pylogger 9 | 10 | log = pylogger.get_pylogger(__name__) 11 | 12 | 13 | def 
instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /matcha/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | import inflect 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return f"{dollars} {dollar_unit}, {cents} {cent_unit}" 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return f"{dollars} {dollar_unit}" 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return f"{cents} {cent_unit}" 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 
100) + " hundred" 58 | else: 59 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") 60 | else: 61 | return _inflect.number_to_words(num, andword="") 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r"\1 pounds", text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /matcha/utils/audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | def load_wav(full_path): 11 | sampling_rate, data = read(full_path) 12 | return data, sampling_rate 13 | 14 | 15 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 17 | 18 | 19 | def dynamic_range_decompression(x, C=1): 20 | return np.exp(x) / C 21 | 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | 27 | def dynamic_range_decompression_torch(x, C=1): 28 | return torch.exp(x) / C 29 | 30 | 31 | def spectral_normalize_torch(magnitudes): 32 | output = dynamic_range_compression_torch(magnitudes) 33 | return output 34 | 35 | 36 | def spectral_de_normalize_torch(magnitudes): 37 | output = dynamic_range_decompression_torch(magnitudes) 38 | return output 39 | 40 | 41 | mel_basis = {} 42 | hann_window = {} 43 | 44 | 45 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 46 | if torch.min(y) < -1.0: 47 | print("min value is ", torch.min(y)) 48 | if torch.max(y) > 1.0: 49 | print("max value is ", torch.max(y)) 50 | 51 | global mel_basis, hann_window # pylint: disable=global-statement 52 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 53 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 54 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 55 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 56 | 57 | y = torch.nn.functional.pad( 58 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 59 | ) 60 | y = y.squeeze(1) 61 | 62 | spec = torch.view_as_real( 63 | torch.stft( 64 | y, 65 | n_fft, 66 | hop_length=hop_size, 67 | win_length=win_size, 68 | window=hann_window[str(y.device)], 69 | center=center, 70 | pad_mode="reflect", 71 | normalized=False, 72 | onesided=True, 73 | return_complex=True, 74 | ) 75 | ) 76 | 77 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 78 | 79 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 80 | spec = spectral_normalize_torch(spec) 81 | 82 | return spec 83 | -------------------------------------------------------------------------------- /matcha/hifigan/denoiser.py: -------------------------------------------------------------------------------- 1 | # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py 2 | 3 | """Waveglow style denoiser can be used to 
remove the artifacts from the HiFiGAN generated audio.""" 4 | import torch 5 | 6 | 7 | class Denoiser(torch.nn.Module): 8 | """Removes model bias from audio produced with waveglow""" 9 | 10 | def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): 11 | super().__init__() 12 | self.filter_length = filter_length 13 | self.hop_length = int(filter_length / n_overlap) 14 | self.win_length = win_length 15 | 16 | dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device 17 | self.device = device 18 | if mode == "zeros": 19 | mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) 20 | elif mode == "normal": 21 | mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) 22 | else: 23 | raise Exception(f"Mode {mode} if not supported") 24 | 25 | def stft_fn(audio, n_fft, hop_length, win_length, window): 26 | spec = torch.stft( 27 | audio, 28 | n_fft=n_fft, 29 | hop_length=hop_length, 30 | win_length=win_length, 31 | window=window, 32 | return_complex=True, 33 | ) 34 | spec = torch.view_as_real(spec) 35 | return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) 36 | 37 | self.stft = lambda x: stft_fn( 38 | audio=x, 39 | n_fft=self.filter_length, 40 | hop_length=self.hop_length, 41 | win_length=self.win_length, 42 | window=torch.hann_window(self.win_length, device=device), 43 | ) 44 | self.istft = lambda x, y: torch.istft( 45 | torch.complex(x * torch.cos(y), x * torch.sin(y)), 46 | n_fft=self.filter_length, 47 | hop_length=self.hop_length, 48 | win_length=self.win_length, 49 | window=torch.hann_window(self.win_length, device=device), 50 | ) 51 | 52 | with torch.no_grad(): 53 | bias_audio = vocoder(mel_input).float().squeeze(0) 54 | bias_spec, _ = self.stft(bias_audio) 55 | 56 | self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) 57 | 58 | @torch.inference_mode() 59 | def forward(self, audio, strength=0.0005): 60 | audio_spec, audio_angles = self.stft(audio) 61 | audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength 62 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) 63 | audio_denoised = self.istft(audio_spec_denoised, audio_angles) 64 | return audio_denoised 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # Lightning-Hydra-Template 148 | configs/local/default.yaml 149 | /data/ 150 | /logs/ 151 | .env 152 | 153 | # Aim logging 154 | .aim 155 | 156 | # Cython complied files 157 | matcha/utils/monotonic_align/core.c 158 | 159 | # Ignoring hifigan checkpoint 160 | generator_v1 161 | g_02500000 162 | gradio_cached_examples/ 163 | synth_output/ 164 | -------------------------------------------------------------------------------- /matcha/utils/model.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def sequence_mask(length, max_length=None): 8 | if max_length is None: 9 | max_length = length.max() 10 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 11 | return x.unsqueeze(0) < length.unsqueeze(1) 12 | 13 | 14 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 15 | factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) 16 | length = (length / factor).ceil() * factor 17 | if not torch.onnx.is_in_onnx_export(): 18 | return length.int().item() 19 | else: 20 | return length 21 | 22 | 23 | def convert_pad_shape(pad_shape): 24 | inverted_shape = pad_shape[::-1] 25 | pad_shape = [item for sublist in inverted_shape for item in sublist] 26 | return pad_shape 27 | 28 | 29 | def generate_path(duration, mask): 30 | device = 
duration.device 31 | 32 | b, t_x, t_y = mask.shape 33 | cum_duration = torch.cumsum(duration, 1) 34 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 35 | 36 | cum_duration_flat = cum_duration.view(b * t_x) 37 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 38 | path = path.view(b, t_x, t_y) 39 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 40 | path = path * mask 41 | return path 42 | 43 | 44 | def duration_loss(logw, logw_, lengths): 45 | loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) 46 | return loss 47 | 48 | 49 | def normalize(data, mu, std): 50 | if not isinstance(mu, (float, int)): 51 | if isinstance(mu, list): 52 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 53 | elif isinstance(mu, torch.Tensor): 54 | mu = mu.to(data.device) 55 | elif isinstance(mu, np.ndarray): 56 | mu = torch.from_numpy(mu).to(data.device) 57 | mu = mu.unsqueeze(-1) 58 | 59 | if not isinstance(std, (float, int)): 60 | if isinstance(std, list): 61 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 62 | elif isinstance(std, torch.Tensor): 63 | std = std.to(data.device) 64 | elif isinstance(std, np.ndarray): 65 | std = torch.from_numpy(std).to(data.device) 66 | std = std.unsqueeze(-1) 67 | 68 | return (data - mu) / std 69 | 70 | 71 | def denormalize(data, mu, std): 72 | if not isinstance(mu, float): 73 | if isinstance(mu, list): 74 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 75 | elif isinstance(mu, torch.Tensor): 76 | mu = mu.to(data.device) 77 | elif isinstance(mu, np.ndarray): 78 | mu = torch.from_numpy(mu).to(data.device) 79 | mu = mu.unsqueeze(-1) 80 | 81 | if not isinstance(std, float): 82 | if isinstance(std, list): 83 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 84 | elif isinstance(std, torch.Tensor): 85 | std = std.to(data.device) 86 | elif isinstance(std, np.ndarray): 87 | std = torch.from_numpy(std).to(data.device) 88 | std = std.unsqueeze(-1) 89 | 90 | return data * std + mu 91 | -------------------------------------------------------------------------------- /matcha/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Cleaners are transformations that run over the input text at both training and eval time. 4 | 5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 7 | 1. "english_cleaners" for English text 8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 11 | the symbols in symbols.py to match your data). 
12 | """ 13 | 14 | import logging 15 | import re 16 | 17 | import phonemizer 18 | from unidecode import unidecode 19 | 20 | # To avoid excessive logging we set the log level of the phonemizer package to Critical 21 | critical_logger = logging.getLogger("phonemizer") 22 | critical_logger.setLevel(logging.CRITICAL) 23 | 24 | # Intializing the phonemizer globally significantly reduces the speed 25 | # now the phonemizer is not initialising at every call 26 | # Might be less flexible, but it is much-much faster 27 | global_phonemizer = phonemizer.backend.EspeakBackend( 28 | language="en-us", 29 | preserve_punctuation=True, 30 | with_stress=True, 31 | language_switch="remove-flags", 32 | logger=critical_logger, 33 | ) 34 | 35 | 36 | # Regular expression matching whitespace: 37 | _whitespace_re = re.compile(r"\s+") 38 | 39 | # List of (regular expression, replacement) pairs for abbreviations: 40 | _abbreviations = [ 41 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) 42 | for x in [ 43 | ("mrs", "misess"), 44 | ("mr", "mister"), 45 | ("dr", "doctor"), 46 | ("st", "saint"), 47 | ("co", "company"), 48 | ("jr", "junior"), 49 | ("maj", "major"), 50 | ("gen", "general"), 51 | ("drs", "doctors"), 52 | ("rev", "reverend"), 53 | ("lt", "lieutenant"), 54 | ("hon", "honorable"), 55 | ("sgt", "sergeant"), 56 | ("capt", "captain"), 57 | ("esq", "esquire"), 58 | ("ltd", "limited"), 59 | ("col", "colonel"), 60 | ("ft", "fort"), 61 | ] 62 | ] 63 | 64 | 65 | def expand_abbreviations(text): 66 | for regex, replacement in _abbreviations: 67 | text = re.sub(regex, replacement, text) 68 | return text 69 | 70 | 71 | def lowercase(text): 72 | return text.lower() 73 | 74 | 75 | def collapse_whitespace(text): 76 | return re.sub(_whitespace_re, " ", text) 77 | 78 | 79 | def convert_to_ascii(text): 80 | return unidecode(text) 81 | 82 | 83 | def basic_cleaners(text): 84 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 85 | text = lowercase(text) 86 | text = collapse_whitespace(text) 87 | return text 88 | 89 | 90 | def transliteration_cleaners(text): 91 | """Pipeline for non-English text that transliterates to ASCII.""" 92 | text = convert_to_ascii(text) 93 | text = lowercase(text) 94 | text = collapse_whitespace(text) 95 | return text 96 | 97 | 98 | def english_cleaners2(text): 99 | """Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress""" 100 | text = convert_to_ascii(text) 101 | text = lowercase(text) 102 | text = expand_abbreviations(text) 103 | phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] 104 | phonemes = collapse_whitespace(phonemes) 105 | return phonemes 106 | -------------------------------------------------------------------------------- /matcha/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from matcha.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | _ = ( 48 | queue.append(field) 49 | if field in cfg 50 | else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") 51 | ) 52 | 53 | # add all the other fields to queue (not specified in `print_order`) 54 | for field in cfg: 55 | if field not in queue: 56 | queue.append(field) 57 | 58 | # generate config tree from queue 59 | for field in queue: 60 | branch = tree.add(field, style=style, guide_style=style) 61 | 62 | config_group = cfg[field] 63 | if isinstance(config_group, DictConfig): 64 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 65 | else: 66 | branch_content = str(config_group) 67 | 68 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 69 | 70 | # print config tree 71 | rich.print(tree) 72 | 73 | # save config tree to file 74 | if save_to_file: 75 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 76 | rich.print(tree, file=file) 77 | 78 | 79 | @rank_zero_only 80 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 81 | """Prompts user to input tags from command line if no tags are provided in config. 82 | 83 | :param cfg: A DictConfig composed by Hydra. 84 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 85 | """ 86 | if not cfg.get("tags"): 87 | if "id" in HydraConfig().cfg.hydra.job: 88 | raise ValueError("Specify tags before launching a multirun!") 89 | 90 | log.warning("No tags provided in config. 
Prompting user to input tags...") 91 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 92 | tags = [t.strip() for t in tags.split(",") if t != ""] 93 | 94 | with open_dict(cfg): 95 | cfg.tags = tags 96 | 97 | log.info(f"Tags: {cfg.tags}") 98 | 99 | if save_to_file: 100 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 101 | rich.print(cfg.tags, file=file) 102 | -------------------------------------------------------------------------------- /matcha/utils/generate_data_statistics.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it 3 | when needed. 4 | 5 | Parameters from hparam.py will be used 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | from pathlib import Path 12 | 13 | import rootutils 14 | import torch 15 | from hydra import compose, initialize 16 | from omegaconf import open_dict 17 | from tqdm.auto import tqdm 18 | 19 | from matcha.data.text_mel_datamodule import TextMelDataModule 20 | from matcha.utils.logging_utils import pylogger 21 | 22 | log = pylogger.get_pylogger(__name__) 23 | 24 | 25 | def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): 26 | """Generate data mean and standard deviation helpful in data normalisation 27 | 28 | Args: 29 | data_loader (torch.utils.data.Dataloader): _description_ 30 | out_channels (int): mel spectrogram channels 31 | """ 32 | total_mel_sum = 0 33 | total_mel_sq_sum = 0 34 | total_mel_len = 0 35 | 36 | for batch in tqdm(data_loader, leave=False): 37 | mels = batch["y"] 38 | mel_lengths = batch["y_lengths"] 39 | 40 | total_mel_len += torch.sum(mel_lengths) 41 | total_mel_sum += torch.sum(mels) 42 | total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) 43 | 44 | data_mean = total_mel_sum / (total_mel_len * out_channels) 45 | data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) 46 | 47 | return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser() 52 | 53 | parser.add_argument( 54 | "-i", 55 | "--input-config", 56 | type=str, 57 | default="vctk.yaml", 58 | help="The name of the yaml config file under configs/data", 59 | ) 60 | 61 | parser.add_argument( 62 | "-b", 63 | "--batch-size", 64 | type=int, 65 | default="256", 66 | help="Can have increased batch size for faster computation", 67 | ) 68 | 69 | parser.add_argument( 70 | "-f", 71 | "--force", 72 | action="store_true", 73 | default=False, 74 | required=False, 75 | help="force overwrite the file", 76 | ) 77 | args = parser.parse_args() 78 | output_file = Path(args.input_config).with_suffix(".json") 79 | 80 | if os.path.exists(output_file) and not args.force: 81 | print("File already exists. 
Use -f to force overwrite") 82 | sys.exit(1) 83 | 84 | with initialize(version_base="1.3", config_path="../../configs/data"): 85 | cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) 86 | 87 | root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") 88 | 89 | with open_dict(cfg): 90 | del cfg["hydra"] 91 | del cfg["_target_"] 92 | cfg["data_statistics"] = None 93 | cfg["seed"] = 1234 94 | cfg["batch_size"] = args.batch_size 95 | cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) 96 | cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) 97 | 98 | text_mel_datamodule = TextMelDataModule(**cfg) 99 | text_mel_datamodule.setup() 100 | data_loader = text_mel_datamodule.train_dataloader() 101 | log.info("Dataloader loaded! Now computing stats...") 102 | params = compute_data_statistics(data_loader, cfg["n_feats"]) 103 | print(params) 104 | json.dump( 105 | params, 106 | open(output_file, "w"), 107 | ) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /matcha/train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import hydra 4 | import lightning as L 5 | import rootutils 6 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 7 | from lightning.pytorch.loggers import Logger 8 | from omegaconf import DictConfig 9 | 10 | from matcha import utils 11 | 12 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 13 | # ------------------------------------------------------------------------------------ # 14 | # the setup_root above is equivalent to: 15 | # - adding project root dir to PYTHONPATH 16 | # (so you don't need to force user to install project as a package) 17 | # (necessary before importing any local modules e.g. `from src import utils`) 18 | # - setting up PROJECT_ROOT environment variable 19 | # (which is used as a base for paths in "configs/paths/default.yaml") 20 | # (this way all filepaths are the same no matter where you run the code) 21 | # - loading environment variables from ".env" in root dir 22 | # 23 | # you can remove it if you: 24 | # 1. either install project as a package or move entry files to project root dir 25 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 26 | # 27 | # more info: https://github.com/ashleve/rootutils 28 | # ------------------------------------------------------------------------------------ # 29 | 30 | 31 | log = utils.get_pylogger(__name__) 32 | 33 | 34 | @utils.task_wrapper 35 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 36 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 37 | training. 38 | 39 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 40 | failure. Useful for multiruns, saving info about the crash, etc. 41 | 42 | :param cfg: A DictConfig configuration composed by Hydra. 43 | :return: A tuple with metrics and dict with all instantiated objects. 
44 | """ 45 | # set seed for random number generators in pytorch, numpy and python.random 46 | if cfg.get("seed"): 47 | L.seed_everything(cfg.seed, workers=True) 48 | 49 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access 50 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 51 | 52 | log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access 53 | model: LightningModule = hydra.utils.instantiate(cfg.model) 54 | 55 | log.info("Instantiating callbacks...") 56 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 57 | 58 | log.info("Instantiating loggers...") 59 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) 60 | 61 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access 62 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 63 | 64 | object_dict = { 65 | "cfg": cfg, 66 | "datamodule": datamodule, 67 | "model": model, 68 | "callbacks": callbacks, 69 | "logger": logger, 70 | "trainer": trainer, 71 | } 72 | 73 | if logger: 74 | log.info("Logging hyperparameters!") 75 | utils.log_hyperparameters(object_dict) 76 | 77 | if cfg.get("train"): 78 | log.info("Starting training!") 79 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 80 | 81 | train_metrics = trainer.callback_metrics 82 | 83 | if cfg.get("test"): 84 | log.info("Starting testing!") 85 | ckpt_path = trainer.checkpoint_callback.best_model_path 86 | if ckpt_path == "": 87 | log.warning("Best ckpt not found! Using current weights for testing...") 88 | ckpt_path = None 89 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 90 | log.info(f"Best ckpt path: {ckpt_path}") 91 | 92 | test_metrics = trainer.callback_metrics 93 | 94 | # merge train and test metrics 95 | metric_dict = {**train_metrics, **test_metrics} 96 | 97 | return metric_dict, object_dict 98 | 99 | 100 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 101 | def main(cfg: DictConfig) -> Optional[float]: 102 | """Main entry point for training. 103 | 104 | :param cfg: DictConfig configuration composed by Hydra. 105 | :return: Optional[float] with optimized metric value. 106 | """ 107 | # apply extra utilities 108 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
109 | utils.extras(cfg) 110 | 111 | # train the model 112 | metric_dict, _ = train(cfg) 113 | 114 | # safely retrieve metric value for hydra-based hyperparameter optimization 115 | metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) 116 | 117 | # return optimized metric 118 | return metric_value 119 | 120 | 121 | if __name__ == "__main__": 122 | main() # pylint: disable=no-value-for-parameter 123 | -------------------------------------------------------------------------------- /matcha/models/components/flow_matching.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from matcha.models.components.decoder import Decoder 7 | from matcha.utils.pylogger import get_pylogger 8 | from matcha.hifigan.meldataset import mel_spectrogram 9 | 10 | log = get_pylogger(__name__) 11 | 12 | 13 | class BASECFM(torch.nn.Module, ABC): 14 | def __init__( 15 | self, 16 | n_feats, 17 | cfm_params, 18 | n_spks=1, 19 | spk_emb_dim=128, 20 | ): 21 | super().__init__() 22 | self.n_feats = n_feats 23 | self.n_spks = n_spks 24 | self.spk_emb_dim = spk_emb_dim 25 | self.solver = cfm_params.solver 26 | if hasattr(cfm_params, "sigma_min"): 27 | self.sigma_min = cfm_params.sigma_min 28 | else: 29 | self.sigma_min = 1e-4 30 | 31 | self.estimator = None 32 | 33 | @torch.inference_mode() 34 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, training=False): 35 | """Forward diffusion 36 | 37 | Args: 38 | mu (torch.Tensor): output of encoder 39 | shape: (batch_size, n_feats, mel_timesteps) 40 | mask (torch.Tensor): output_mask 41 | shape: (batch_size, 1, mel_timesteps) 42 | n_timesteps (int): number of diffusion steps 43 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 44 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 45 | shape: (batch_size, spk_emb_dim) 46 | cond: Not used but kept for future purposes 47 | 48 | Returns: 49 | sample: generated mel-spectrogram 50 | shape: (batch_size, n_feats, mel_timesteps) 51 | """ 52 | z = torch.randn_like(mu) * temperature 53 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) 54 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, training=training) 55 | 56 | def solve_euler(self, x, t_span, mu, mask, spks, cond, training=False): 57 | """ 58 | Fixed euler solver for ODEs. 59 | Args: 60 | x (torch.Tensor): random noise 61 | t_span (torch.Tensor): n_timesteps interpolated 62 | shape: (n_timesteps + 1,) 63 | mu (torch.Tensor): output of encoder 64 | shape: (batch_size, n_feats, mel_timesteps) 65 | mask (torch.Tensor): output_mask 66 | shape: (batch_size, 1, mel_timesteps) 67 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 
68 | shape: (batch_size, spk_emb_dim) 69 | cond: Not used but kept for future purposes 70 | """ 71 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 72 | 73 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 74 | # Or in future might add like a return_all_steps flag 75 | sol = [] 76 | 77 | steps = 1 78 | while steps <= len(t_span) - 1: 79 | dphi_dt = self.estimator(x, mask, mu, t, spks, cond, training=training) 80 | 81 | x = x + dt * dphi_dt 82 | t = t + dt 83 | sol.append(x) 84 | if steps < len(t_span) - 1: 85 | dt = t_span[steps + 1] - t 86 | steps += 1 87 | 88 | return sol[-1] 89 | 90 | def compute_loss(self, x1, mask, mu, spks=None, cond=None, training=True): 91 | """Computes diffusion loss 92 | 93 | Args: 94 | x1 (torch.Tensor): Target 95 | shape: (batch_size, n_feats, mel_timesteps) 96 | mask (torch.Tensor): target mask 97 | shape: (batch_size, 1, mel_timesteps) 98 | mu (torch.Tensor): output of encoder 99 | shape: (batch_size, n_feats, mel_timesteps) 100 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 101 | shape: (batch_size, spk_emb_dim) 102 | 103 | Returns: 104 | loss: conditional flow matching loss 105 | y: conditional flow 106 | shape: (batch_size, n_feats, mel_timesteps) 107 | """ 108 | b, _, t = mu.shape 109 | 110 | # random timestep 111 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 112 | # sample noise p(x_0) 113 | z = torch.randn_like(x1) 114 | 115 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 116 | u = x1 - (1 - self.sigma_min) * z 117 | # y = u * t + z 118 | 119 | estimator_out = self.estimator(y, mask, mu, t.squeeze(), spks, training=training) 120 | 121 | loss = F.mse_loss(estimator_out, u, reduction="sum") / ( 122 | torch.sum(mask) * u.shape[1] 123 | ) 124 | return loss, y 125 | 126 | 127 | class CFM(BASECFM): 128 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): 129 | super().__init__( 130 | n_feats=in_channels, 131 | cfm_params=cfm_params, 132 | n_spks=n_spks, 133 | spk_emb_dim=spk_emb_dim, 134 | ) 135 | 136 | in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) 137 | # Just change the architecture of the estimator here 138 | self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) 139 | 140 | -------------------------------------------------------------------------------- /matcha/hifigan/README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis 2 | 3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae 4 | 5 | In our [paper](https://arxiv.org/abs/2010.05646), 6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository. 8 | 9 | **Abstract :** 10 | Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. 11 | Although such methods improve the sampling efficiency and memory usage, 12 | their sample quality has not yet reached that of autoregressive and flow-based generative models. 13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. 14 | As speech audio consists of sinusoidal signals with various periods, 15 | we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. 16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method 17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than 18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen 19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times 20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart. 21 | 22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. 23 | 24 | ## Pre-requisites 25 | 26 | 1. Python >= 3.6 27 | 2. Clone this repository. 28 | 3. Install python requirements. Please refer [requirements.txt](requirements.txt) 29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). 30 | And move all wav files to `LJSpeech-1.1/wavs` 31 | 32 | ## Training 33 | 34 | ``` 35 | python train.py --config config_v1.json 36 | ``` 37 | 38 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
39 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
40 | You can change the path by adding the `--checkpoint_path` option. 41 | 42 | Validation loss during training with the V1 generator.
43 | ![validation loss](./validation_loss.png) 44 | 45 | ## Pretrained Model 46 | 47 | You can also use the pretrained models we provide.
48 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
49 | Details of each folder are as follows: 50 | 51 | | Folder Name | Generator | Dataset | Fine-Tuned | 52 | | ------------ | --------- | --------- | ------------------------------------------------------ | 53 | | LJ_V1 | V1 | LJSpeech | No | 54 | | LJ_V2 | V2 | LJSpeech | No | 55 | | LJ_V3 | V3 | LJSpeech | No | 56 | | LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 57 | | LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 58 | | LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 59 | | VCTK_V1 | V1 | VCTK | No | 60 | | VCTK_V2 | V2 | VCTK | No | 61 | | VCTK_V3 | V3 | VCTK | No | 62 | | UNIVERSAL_V1 | V1 | Universal | No | 63 | 64 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. 65 | 66 | ## Fine-Tuning 67 | 68 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
69 | The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
70 | Example: 71 | ` Audio File : LJ001-0001.wav 72 | Mel-Spectrogram File : LJ001-0001.npy` 73 | 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
74 | 3. Run the following command. 75 | ``` 76 | python train.py --fine_tuning True --config config_v1.json 77 | ``` 78 | For other command line options, please refer to the training section. 79 | 80 | ## Inference from wav file 81 | 82 | 1. Make a `test_files` directory and copy wav files into the directory. 83 | 2. Run the following command. 84 | ` python inference.py --checkpoint_file [generator checkpoint file path]` 85 | Generated wav files are saved in `generated_files` by default.
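For example, assuming you have downloaded the `UNIVERSAL_V1` folder linked above (the checkpoint file name below is the one shipped with that distribution; adjust it to whatever you actually downloaded):

```
python inference.py --checkpoint_file UNIVERSAL_V1/g_02500000
```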
86 | You can change the path by adding the `--output_dir` option. 87 | 88 | ## Inference for end-to-end speech synthesis 89 | 90 | 1. Make a `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
91 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), 92 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. 93 | 2. Run the following command. 94 | ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` 95 | Generated wav files are saved in `generated_files_from_mel` by default.
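If you would rather call the vocoder from Python than through `inference_e2e.py`, the minimal sketch below performs the same mel-to-waveform step. It assumes the HiFi-GAN copy vendored in this repository under `matcha/hifigan` (its `Generator`, `AttrDict`, `v1` config and `Denoiser`); the checkpoint and file names are placeholders, and the mel is expected as an 80-band spectrogram of shape `(80, T)`.

```
import numpy as np
import soundfile as sf
import torch

from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load a trained generator checkpoint (file name is a placeholder).
vocoder = Generator(AttrDict(v1)).to(device)
state = torch.load("UNIVERSAL_V1/g_02500000", map_location=device)
vocoder.load_state_dict(state["generator"])
vocoder.eval()
vocoder.remove_weight_norm()
denoiser = Denoiser(vocoder, mode="zeros")

# Vocode a single saved mel-spectrogram and write the waveform to disk.
mel = torch.from_numpy(np.load("test_mel_files/LJ001-0001.npy")).float().unsqueeze(0).to(device)
with torch.no_grad():
    wav = vocoder(mel)                                # (1, 1, T * 256) for the v1 config
    wav = denoiser(wav.squeeze(1), strength=0.00025)  # light bias removal
sf.write("LJ001-0001_generated.wav", wav.cpu().squeeze().numpy(), 22050, "PCM_24")
```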
96 | You can change the path by adding `--output_dir` option. 97 | 98 | ## Acknowledgements 99 | 100 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) 101 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. 102 | -------------------------------------------------------------------------------- /matcha/onnx/export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from lightning import LightningModule 8 | 9 | from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder 10 | 11 | DEFAULT_OPSET = 15 12 | 13 | SEED = 1234 14 | random.seed(SEED) 15 | np.random.seed(SEED) 16 | torch.manual_seed(SEED) 17 | torch.cuda.manual_seed(SEED) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | 21 | 22 | class MatchaWithVocoder(LightningModule): 23 | def __init__(self, matcha, vocoder): 24 | super().__init__() 25 | self.matcha = matcha 26 | self.vocoder = vocoder 27 | 28 | def forward(self, x, x_lengths, scales, spks=None): 29 | mel, mel_lengths = self.matcha(x, x_lengths, scales, spks) 30 | wavs = self.vocoder(mel).clamp(-1, 1) 31 | lengths = mel_lengths * 256 32 | return wavs.squeeze(1), lengths 33 | 34 | 35 | def get_exportable_module(matcha, vocoder, n_timesteps): 36 | """ 37 | Return an appropriate `LighteningModule` and output-node names 38 | based on whether the vocoder is embedded in the final graph 39 | """ 40 | 41 | def onnx_forward_func(x, x_lengths, scales, spks=None): 42 | """ 43 | Custom forward function for accepting 44 | scaler parameters as tensors 45 | """ 46 | # Extract scaler parameters from tensors 47 | temperature = scales[0] 48 | length_scale = scales[1] 49 | output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale) 50 | return output["mel"], output["mel_lengths"] 51 | 52 | # Monkey-patch Matcha's forward function 53 | matcha.forward = onnx_forward_func 54 | 55 | if vocoder is None: 56 | model, output_names = matcha, ["mel", "mel_lengths"] 57 | else: 58 | model = MatchaWithVocoder(matcha, vocoder) 59 | output_names = ["wav", "wav_lengths"] 60 | return model, output_names 61 | 62 | 63 | def get_inputs(is_multi_speaker): 64 | """ 65 | Create dummy inputs for tracing 66 | """ 67 | dummy_input_length = 50 68 | x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long) 69 | x_lengths = torch.LongTensor([dummy_input_length]) 70 | 71 | # Scales 72 | temperature = 0.667 73 | length_scale = 1.0 74 | scales = torch.Tensor([temperature, length_scale]) 75 | 76 | model_inputs = [x, x_lengths, scales] 77 | input_names = [ 78 | "x", 79 | "x_lengths", 80 | "scales", 81 | ] 82 | 83 | if is_multi_speaker: 84 | spks = torch.LongTensor([1]) 85 | model_inputs.append(spks) 86 | input_names.append("spks") 87 | 88 | return tuple(model_inputs), input_names 89 | 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX") 93 | 94 | parser.add_argument( 95 | "checkpoint_path", 96 | type=str, 97 | help="Path to the model checkpoint", 98 | ) 99 | parser.add_argument("output", type=str, help="Path to output `.onnx` file") 100 | parser.add_argument( 101 | "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)" 102 | ) 103 | parser.add_argument( 104 | "--vocoder-name", 105 | type=str, 
106 | choices=list(VOCODER_URLS.keys()), 107 | default=None, 108 | help="Name of the vocoder to embed in the ONNX graph", 109 | ) 110 | parser.add_argument( 111 | "--vocoder-checkpoint-path", 112 | type=str, 113 | default=None, 114 | help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience", 115 | ) 116 | parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15") 117 | 118 | args = parser.parse_args() 119 | 120 | print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}") 121 | print(f"Setting n_timesteps to {args.n_timesteps}") 122 | 123 | checkpoint_path = Path(args.checkpoint_path) 124 | matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu") 125 | 126 | if args.vocoder_name or args.vocoder_checkpoint_path: 127 | assert ( 128 | args.vocoder_name and args.vocoder_checkpoint_path 129 | ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph." 130 | vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu") 131 | else: 132 | vocoder = None 133 | 134 | is_multi_speaker = matcha.n_spks > 1 135 | 136 | dummy_input, input_names = get_inputs(is_multi_speaker) 137 | model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps) 138 | 139 | # Set dynamic shape for inputs/outputs 140 | dynamic_axes = { 141 | "x": {0: "batch_size", 1: "time"}, 142 | "x_lengths": {0: "batch_size"}, 143 | } 144 | 145 | if vocoder is None: 146 | dynamic_axes.update( 147 | { 148 | "mel": {0: "batch_size", 2: "time"}, 149 | "mel_lengths": {0: "batch_size"}, 150 | } 151 | ) 152 | else: 153 | print("Embedding the vocoder in the ONNX graph") 154 | dynamic_axes.update( 155 | { 156 | "wav": {0: "batch_size", 1: "time"}, 157 | "wav_lengths": {0: "batch_size"}, 158 | } 159 | ) 160 | 161 | if is_multi_speaker: 162 | dynamic_axes["spks"] = {0: "batch_size"} 163 | 164 | # Create the output directory (if not exists) 165 | Path(args.output).parent.mkdir(parents=True, exist_ok=True) 166 | 167 | model.to_onnx( 168 | args.output, 169 | dummy_input, 170 | input_names=input_names, 171 | output_names=output_names, 172 | dynamic_axes=dynamic_axes, 173 | opset_version=args.opset, 174 | export_params=True, 175 | do_constant_folding=True, 176 | ) 177 | print(f"[🍵] ONNX model exported to {args.output}") 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /matcha/models/components/commons.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/jaywalnut310/vits 2 | import math 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def init_weights(m, mean=0.0, std=0.01): 8 | classname = m.__class__.__name__ 9 | if classname.find("Conv") != -1: 10 | m.weight.data.normal_(mean, std) 11 | 12 | 13 | def get_padding(kernel_size, dilation=1): 14 | return int((kernel_size * dilation - dilation) / 2) 15 | 16 | 17 | def convert_pad_shape(pad_shape): 18 | l = pad_shape[::-1] 19 | pad_shape = [item for sublist in l for item in sublist] 20 | return pad_shape 21 | 22 | 23 | def intersperse(lst, item): 24 | result = [item] * (len(lst) * 2 + 1) 25 | result[1::2] = lst 26 | return result 27 | 28 | 29 | def kl_divergence(m_p, logs_p, m_q, logs_q): 30 | """KL(P||Q)""" 31 | kl = (logs_q - logs_p) - 0.5 32 | kl += ( 33 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 34 | ) 35 | return kl 
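# Illustrative usage notes (added commentary; not part of the original module):
#
#   >>> intersperse([5, 3, 7], 0)     # how blank tokens get inserted between symbol IDs
#   [0, 5, 0, 3, 0, 7, 0]
#
#   >>> # kl_divergence() is the element-wise KL(P||Q) between two diagonal Gaussians
#   >>> # parameterised by means and log standard deviations; summing over the feature
#   >>> # axis gives the usual per-frame KL term.
#   >>> kl_divergence(torch.zeros(1), torch.zeros(1), torch.ones(1), torch.zeros(1))
#   tensor([0.5000])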
36 | 37 | 38 | def rand_gumbel(shape): 39 | """Sample from the Gumbel distribution, protect from overflows.""" 40 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 41 | return -torch.log(-torch.log(uniform_samples)) 42 | 43 | 44 | def rand_gumbel_like(x): 45 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 46 | return g 47 | 48 | 49 | def slice_segments(x, ids_str, segment_size=4): 50 | ret = torch.zeros_like(x[:, :, :segment_size]) 51 | for i in range(x.size(0)): 52 | idx_str = ids_str[i] 53 | idx_end = idx_str + segment_size 54 | ret[i] = x[i, :, idx_str:idx_end] 55 | return ret 56 | 57 | 58 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 59 | b, d, t = x.size() 60 | if x_lengths is None: 61 | x_lengths = t 62 | ids_str_max = x_lengths - segment_size + 1 63 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 64 | ids_str = torch.max(torch.zeros(ids_str.size()).to(ids_str.device), ids_str).to( 65 | dtype=torch.long 66 | ) 67 | ret = slice_segments(x, ids_str, segment_size) 68 | return ret, ids_str 69 | 70 | 71 | def rand_slice_segments_for_cat(x, x_lengths=None, segment_size=4): 72 | b, d, t = x.size() 73 | if x_lengths is None: 74 | x_lengths = t 75 | ids_str_max = x_lengths - segment_size + 1 76 | ids_str = torch.rand([b // 2]).to(device=x.device) 77 | ids_str = (torch.cat([ids_str, ids_str], dim=0) * ids_str_max).to(dtype=torch.long) 78 | ids_str = torch.max(torch.zeros(ids_str.size()).to(ids_str.device), ids_str).to( 79 | dtype=torch.long 80 | ) 81 | ret = slice_segments(x, ids_str, segment_size) 82 | return ret, ids_str 83 | 84 | 85 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 86 | position = torch.arange(length, dtype=torch.float) 87 | num_timescales = channels // 2 88 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 89 | num_timescales - 1 90 | ) 91 | inv_timescales = min_timescale * torch.exp( 92 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 93 | ) 94 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 95 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 96 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 97 | signal = signal.view(1, channels, length) 98 | return signal 99 | 100 | 101 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 102 | b, channels, length = x.size() 103 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 104 | return x + signal.to(dtype=x.dtype, device=x.device) 105 | 106 | 107 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 108 | b, channels, length = x.size() 109 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 110 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 111 | 112 | 113 | def subsequent_mask(length): 114 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 115 | return mask 116 | 117 | 118 | @torch.jit.script 119 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 120 | n_channels_int = n_channels[0] 121 | in_act = input_a + input_b 122 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 123 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 124 | acts = t_act * s_act 125 | return acts 126 | 127 | 128 | def convert_pad_shape(pad_shape): 129 | l = pad_shape[::-1] 130 | pad_shape = [item for sublist in l for item in sublist] 131 | return pad_shape 132 | 133 | 134 
| def shift_1d(x): 135 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 136 | return x 137 | 138 | 139 | def sequence_mask(length, max_length=None): 140 | if max_length is None: 141 | max_length = length.max() 142 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 143 | return x.unsqueeze(0) < length.unsqueeze(1) 144 | 145 | 146 | def generate_path(duration, mask): 147 | """ 148 | duration: [b, 1, t_x] 149 | mask: [b, 1, t_y, t_x] 150 | """ 151 | device = duration.device 152 | 153 | b, _, t_y, t_x = mask.shape 154 | cum_duration = torch.cumsum(duration, -1) 155 | 156 | cum_duration_flat = cum_duration.view(b * t_x) 157 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 158 | path = path.view(b, t_x, t_y) 159 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 160 | path = path.unsqueeze(1).transpose(2, 3) * mask 161 | return path 162 | 163 | 164 | def clip_grad_value_(parameters, clip_value, norm_type=2): 165 | if isinstance(parameters, torch.Tensor): 166 | parameters = [parameters] 167 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 168 | norm_type = float(norm_type) 169 | if clip_value is not None: 170 | clip_value = float(clip_value) 171 | 172 | total_norm = 0 173 | for p in parameters: 174 | param_norm = p.grad.data.norm(norm_type) 175 | total_norm += param_norm.item() ** norm_type 176 | if clip_value is not None: 177 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 178 | total_norm = total_norm ** (1.0 / norm_type) 179 | return total_norm 180 | -------------------------------------------------------------------------------- /matcha/onnx/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from pathlib import Path 5 | from time import perf_counter 6 | 7 | import numpy as np 8 | import onnxruntime as ort 9 | import soundfile as sf 10 | import torch 11 | 12 | from matcha.cli import plot_spectrogram_to_numpy, process_text 13 | 14 | 15 | def validate_args(args): 16 | assert ( 17 | args.text or args.file 18 | ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms." 
19 | assert args.temperature >= 0, "Sampling temperature cannot be negative" 20 | assert args.speaking_rate > 0, "Speaking rate must be greater than 0" 21 | return args 22 | 23 | 24 | def write_wavs(model, inputs, output_dir, external_vocoder=None): 25 | if external_vocoder is None: 26 | print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly") 27 | t0 = perf_counter() 28 | wavs, wav_lengths = model.run(None, inputs) 29 | infer_secs = perf_counter() - t0 30 | mel_infer_secs = vocoder_infer_secs = None 31 | else: 32 | print("[🍵] Generating mel using Matcha") 33 | mel_t0 = perf_counter() 34 | mels, mel_lengths = model.run(None, inputs) 35 | mel_infer_secs = perf_counter() - mel_t0 36 | print("Generating waveform from mel using external vocoder") 37 | vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels} 38 | vocoder_t0 = perf_counter() 39 | wavs = external_vocoder.run(None, vocoder_inputs)[0] 40 | vocoder_infer_secs = perf_counter() - vocoder_t0 41 | wavs = wavs.squeeze(1) 42 | wav_lengths = mel_lengths * 256 43 | infer_secs = mel_infer_secs + vocoder_infer_secs 44 | 45 | output_dir = Path(output_dir) 46 | output_dir.mkdir(parents=True, exist_ok=True) 47 | for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)): 48 | output_filename = output_dir.joinpath(f"output_{i + 1}.wav") 49 | audio = wav[:wav_length] 50 | print(f"Writing audio to {output_filename}") 51 | sf.write(output_filename, audio, 22050, "PCM_24") 52 | 53 | wav_secs = wav_lengths.sum() / 22050 54 | print(f"Inference seconds: {infer_secs}") 55 | print(f"Generated wav seconds: {wav_secs}") 56 | rtf = infer_secs / wav_secs 57 | if mel_infer_secs is not None: 58 | mel_rtf = mel_infer_secs / wav_secs 59 | print(f"Matcha RTF: {mel_rtf}") 60 | if vocoder_infer_secs is not None: 61 | vocoder_rtf = vocoder_infer_secs / wav_secs 62 | print(f"Vocoder RTF: {vocoder_rtf}") 63 | print(f"Overall RTF: {rtf}") 64 | 65 | 66 | def write_mels(model, inputs, output_dir): 67 | t0 = perf_counter() 68 | mels, mel_lengths = model.run(None, inputs) 69 | infer_secs = perf_counter() - t0 70 | 71 | output_dir = Path(output_dir) 72 | output_dir.mkdir(parents=True, exist_ok=True) 73 | for i, mel in enumerate(mels): 74 | output_stem = output_dir.joinpath(f"output_{i + 1}") 75 | plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png")) 76 | np.save(output_stem.with_suffix(".npy"), mel) 77 | 78 | wav_secs = (mel_lengths * 256).sum() / 22050 79 | print(f"Inference seconds: {infer_secs}") 80 | print(f"Generated wav seconds: {wav_secs}") 81 | rtf = infer_secs / wav_secs 82 | print(f"RTF: {rtf}") 83 | 84 | 85 | def main(): 86 | parser = argparse.ArgumentParser( 87 | description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching" 88 | ) 89 | parser.add_argument( 90 | "model", 91 | type=str, 92 | help="ONNX model to use", 93 | ) 94 | parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)") 95 | parser.add_argument("--text", type=str, default=None, help="Text to synthesize") 96 | parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") 97 | parser.add_argument("--spk", type=int, default=None, help="Speaker ID") 98 | parser.add_argument( 99 | "--temperature", 100 | type=float, 101 | default=0.667, 102 | help="Variance of the x0 noise (default: 0.667)", 103 | ) 104 | parser.add_argument( 105 | "--speaking-rate", 106 | type=float, 107 | default=1.0, 108 | help="change the speaking rate, a higher value means
slower speaking rate (default: 1.0)", 109 | ) 110 | parser.add_argument("--gpu", action="store_true", help="Use GPU for inference (default: CPU)") 111 | parser.add_argument( 112 | "--output-dir", 113 | type=str, 114 | default=os.getcwd(), 115 | help="Output folder to save results (default: current dir)", 116 | ) 117 | 118 | args = parser.parse_args() 119 | args = validate_args(args) 120 | 121 | if args.gpu: 122 | providers = ["CUDAExecutionProvider"] 123 | else: 124 | providers = ["CPUExecutionProvider"] 125 | model = ort.InferenceSession(args.model, providers=providers) 126 | 127 | model_inputs = model.get_inputs() 128 | model_outputs = list(model.get_outputs()) 129 | 130 | if args.text: 131 | text_lines = args.text.splitlines() 132 | else: 133 | with open(args.file, encoding="utf-8") as file: 134 | text_lines = file.read().splitlines() 135 | 136 | processed_lines = [process_text(0, line, "cpu") for line in text_lines] 137 | x = [line["x"].squeeze() for line in processed_lines] 138 | # Pad 139 | x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) 140 | x = x.detach().cpu().numpy() 141 | x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64) 142 | inputs = { 143 | "x": x, 144 | "x_lengths": x_lengths, 145 | "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32), 146 | } 147 | is_multi_speaker = len(model_inputs) == 4 148 | if is_multi_speaker: 149 | if args.spk is None: 150 | args.spk = 0 151 | warn = "[!] Speaker ID not provided! Using speaker ID 0" 152 | warnings.warn(warn, UserWarning) 153 | inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64) 154 | 155 | has_vocoder_embedded = model_outputs[0].name == "wav" 156 | if has_vocoder_embedded: 157 | write_wavs(model, inputs, args.output_dir) 158 | elif args.vocoder: 159 | external_vocoder = ort.InferenceSession(args.vocoder, providers=providers) 160 | write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder) 161 | else: 162 | warn = "[!] No vocoder is embedded in the graph and no external vocoder was provided.
The mel output will be written as numpy arrays to `*.npy` files in the output directory" 163 | warnings.warn(warn, UserWarning) 164 | write_mels(model, inputs, args.output_dir) 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /matcha/models/components/vits_modules.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/jaywalnut310/vits 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from matcha.models.components import commons 7 | 8 | LRELU_SLOPE = 0.1 9 | 10 | class LayerNorm(nn.Module): 11 | def __init__(self, channels, eps=1e-5): 12 | super().__init__() 13 | self.channels = channels 14 | self.eps = eps 15 | 16 | self.gamma = nn.Parameter(torch.ones(channels)) 17 | self.beta = nn.Parameter(torch.zeros(channels)) 18 | 19 | def forward(self, x): 20 | x = x.transpose(1, -1) 21 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 22 | return x.transpose(1, -1) 23 | 24 | 25 | class ConvReluNorm(nn.Module): 26 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 27 | super().__init__() 28 | self.in_channels = in_channels 29 | self.hidden_channels = hidden_channels 30 | self.out_channels = out_channels 31 | self.kernel_size = kernel_size 32 | self.n_layers = n_layers 33 | self.p_dropout = p_dropout 34 | assert n_layers > 1, "Number of layers should be larger than 1." 35 | 36 | self.conv_layers = nn.ModuleList() 37 | self.norm_layers = nn.ModuleList() 38 | self.conv_layers.append( 39 | nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2) 40 | ) 41 | self.norm_layers.append(LayerNorm(hidden_channels)) 42 | self.relu_drop = nn.Sequential( 43 | nn.ReLU(), 44 | nn.Dropout(p_dropout)) 45 | for _ in range(n_layers-1): 46 | self.conv_layers.append(nn.Conv1d( 47 | hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2) 48 | ) 49 | self.norm_layers.append(LayerNorm(hidden_channels)) 50 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 51 | self.proj.weight.data.zero_() 52 | self.proj.bias.data.zero_() 53 | 54 | def forward(self, x, x_mask): 55 | x_org = x 56 | for i in range(self.n_layers): 57 | x = self.conv_layers[i](x * x_mask) 58 | x = self.norm_layers[i](x) 59 | x = self.relu_drop(x) 60 | x = x_org + self.proj(x) 61 | return x * x_mask 62 | 63 | 64 | class DDSConv(nn.Module): 65 | """Dilated and Depth-Separable Convolution""" 66 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 67 | super().__init__() 68 | self.channels = channels 69 | self.kernel_size = kernel_size 70 | self.n_layers = n_layers 71 | self.p_dropout = p_dropout 72 | 73 | self.drop = nn.Dropout(p_dropout) 74 | self.convs_sep = nn.ModuleList() 75 | self.convs_1x1 = nn.ModuleList() 76 | self.norms_1 = nn.ModuleList() 77 | self.norms_2 = nn.ModuleList() 78 | for i in range(n_layers): 79 | dilation = kernel_size ** i 80 | padding = (kernel_size * dilation - dilation) // 2 81 | self.convs_sep.append( 82 | nn.Conv1d( 83 | channels, 84 | channels, 85 | kernel_size, 86 | groups=channels, 87 | dilation=dilation, 88 | padding=padding 89 | ) 90 | ) 91 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 92 | self.norms_1.append(LayerNorm(channels)) 93 | self.norms_2.append(LayerNorm(channels)) 94 | 95 | def forward(self, x, x_mask, g=None): 96 | if g is not None: 97 | x = x + g 98 | for i in range(self.n_layers):
99 | y = self.convs_sep[i](x * x_mask) 100 | y = self.norms_1[i](y) 101 | y = F.gelu(y) 102 | y = self.convs_1x1[i](y) 103 | y = self.norms_2[i](y) 104 | y = F.gelu(y) 105 | y = self.drop(y) 106 | x = x + y 107 | return x * x_mask 108 | 109 | 110 | class WN(torch.nn.Module): 111 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 112 | super(WN, self).__init__() 113 | assert(kernel_size % 2 == 1) 114 | self.hidden_channels = hidden_channels 115 | self.kernel_size = kernel_size, 116 | self.dilation_rate = dilation_rate 117 | self.n_layers = n_layers 118 | self.gin_channels = gin_channels 119 | self.p_dropout = p_dropout 120 | 121 | self.in_layers = torch.nn.ModuleList() 122 | self.res_skip_layers = torch.nn.ModuleList() 123 | self.drop = nn.Dropout(p_dropout) 124 | 125 | if gin_channels != 0: 126 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) 127 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 128 | 129 | for i in range(n_layers): 130 | dilation = dilation_rate ** i 131 | padding = int((kernel_size * dilation - dilation) / 2) 132 | in_layer = torch.nn.Conv1d( 133 | hidden_channels, 134 | 2*hidden_channels, 135 | kernel_size, 136 | dilation=dilation, 137 | padding=padding 138 | ) 139 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 140 | self.in_layers.append(in_layer) 141 | 142 | # last one is not necessary 143 | if i < n_layers - 1: 144 | res_skip_channels = 2 * hidden_channels 145 | else: 146 | res_skip_channels = hidden_channels 147 | 148 | res_skip_layer = torch.nn.Conv1d( 149 | hidden_channels, res_skip_channels, 1 150 | ) 151 | res_skip_layer = torch.nn.utils.weight_norm( 152 | res_skip_layer, name='weight' 153 | ) 154 | self.res_skip_layers.append(res_skip_layer) 155 | 156 | def forward(self, x, x_mask, g=None, **kwargs): 157 | output = torch.zeros_like(x) 158 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 159 | 160 | if g is not None: 161 | g = self.cond_layer(g) 162 | 163 | for i in range(self.n_layers): 164 | x_in = self.in_layers[i](x) 165 | if g is not None: 166 | cond_offset = i * 2 * self.hidden_channels 167 | g_l = g[:, cond_offset:cond_offset+2*self.hidden_channels, :] 168 | else: 169 | g_l = torch.zeros_like(x_in) 170 | 171 | acts = commons.fused_add_tanh_sigmoid_multiply( 172 | x_in, 173 | g_l, 174 | n_channels_tensor 175 | ) 176 | acts = self.drop(acts) 177 | 178 | res_skip_acts = self.res_skip_layers[i](acts) 179 | if i < self.n_layers - 1: 180 | res_acts = res_skip_acts[:, :self.hidden_channels, :] 181 | x = (x + res_acts) * x_mask 182 | output = output + res_skip_acts[:, self.hidden_channels:, :] 183 | else: 184 | output = output + res_skip_acts 185 | return output * x_mask 186 | 187 | def remove_weight_norm(self): 188 | if self.gin_channels != 0: 189 | torch.nn.utils.remove_weight_norm(self.cond_layer) 190 | for l in self.in_layers: 191 | torch.nn.utils.remove_weight_norm(l) 192 | for l in self.res_skip_layers: 193 | torch.nn.utils.remove_weight_norm(l) 194 | 195 | -------------------------------------------------------------------------------- /matcha/hifigan/meldataset.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import math 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | from librosa.filters import mel as librosa_mel_fn 11 | from librosa.util import normalize 12 | from 
scipy.io.wavfile import read 13 | 14 | MAX_WAV_VALUE = 32768.0 15 | 16 | 17 | def load_wav(full_path): 18 | sampling_rate, data = read(full_path) 19 | return data, sampling_rate 20 | 21 | 22 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 23 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 24 | 25 | 26 | def dynamic_range_decompression(x, C=1): 27 | return np.exp(x) / C 28 | 29 | 30 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 31 | return torch.log(torch.clamp(x, min=clip_val) * C) 32 | 33 | 34 | def dynamic_range_decompression_torch(x, C=1): 35 | return torch.exp(x) / C 36 | 37 | 38 | def spectral_normalize_torch(magnitudes): 39 | output = dynamic_range_compression_torch(magnitudes) 40 | return output 41 | 42 | 43 | def spectral_de_normalize_torch(magnitudes): 44 | output = dynamic_range_decompression_torch(magnitudes) 45 | return output 46 | 47 | 48 | mel_basis = {} 49 | hann_window = {} 50 | 51 | 52 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 53 | if torch.min(y) < -1.0: 54 | print("min value is ", torch.min(y)) 55 | if torch.max(y) > 1.0: 56 | print("max value is ", torch.max(y)) 57 | 58 | global mel_basis, hann_window # pylint: disable=global-statement 59 | if fmax not in mel_basis: 60 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 61 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 62 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 63 | 64 | y = torch.nn.functional.pad( 65 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 66 | ) 67 | y = y.squeeze(1) 68 | 69 | spec = torch.view_as_real( 70 | torch.stft( 71 | y, 72 | n_fft, 73 | hop_length=hop_size, 74 | win_length=win_size, 75 | window=hann_window[str(y.device)], 76 | center=center, 77 | pad_mode="reflect", 78 | normalized=False, 79 | onesided=True, 80 | return_complex=True, 81 | ) 82 | ) 83 | 84 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 85 | 86 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 87 | spec = spectral_normalize_torch(spec) 88 | 89 | return spec 90 | 91 | 92 | def get_dataset_filelist(a): 93 | with open(a.input_training_file, encoding="utf-8") as fi: 94 | training_files = [ 95 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 96 | ] 97 | 98 | with open(a.input_validation_file, encoding="utf-8") as fi: 99 | validation_files = [ 100 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 101 | ] 102 | return training_files, validation_files 103 | 104 | 105 | class MelDataset(torch.utils.data.Dataset): 106 | def __init__( 107 | self, 108 | training_files, 109 | segment_size, 110 | n_fft, 111 | num_mels, 112 | hop_size, 113 | win_size, 114 | sampling_rate, 115 | fmin, 116 | fmax, 117 | split=True, 118 | shuffle=True, 119 | n_cache_reuse=1, 120 | device=None, 121 | fmax_loss=None, 122 | fine_tuning=False, 123 | base_mels_path=None, 124 | ): 125 | self.audio_files = training_files 126 | random.seed(1234) 127 | if shuffle: 128 | random.shuffle(self.audio_files) 129 | self.segment_size = segment_size 130 | self.sampling_rate = sampling_rate 131 | self.split = split 132 | self.n_fft = n_fft 133 | self.num_mels = num_mels 134 | self.hop_size = hop_size 135 | self.win_size = win_size 136 | self.fmin = fmin 137 | self.fmax = fmax 138 | self.fmax_loss = fmax_loss 139 | self.cached_wav 
= None 140 | self.n_cache_reuse = n_cache_reuse 141 | self._cache_ref_count = 0 142 | self.device = device 143 | self.fine_tuning = fine_tuning 144 | self.base_mels_path = base_mels_path 145 | 146 | def __getitem__(self, index): 147 | filename = self.audio_files[index] 148 | if self._cache_ref_count == 0: 149 | audio, sampling_rate = load_wav(filename) 150 | audio = audio / MAX_WAV_VALUE 151 | if not self.fine_tuning: 152 | audio = normalize(audio) * 0.95 153 | self.cached_wav = audio 154 | if sampling_rate != self.sampling_rate: 155 | raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") 156 | self._cache_ref_count = self.n_cache_reuse 157 | else: 158 | audio = self.cached_wav 159 | self._cache_ref_count -= 1 160 | 161 | audio = torch.FloatTensor(audio) 162 | audio = audio.unsqueeze(0) 163 | 164 | if not self.fine_tuning: 165 | if self.split: 166 | if audio.size(1) >= self.segment_size: 167 | max_audio_start = audio.size(1) - self.segment_size 168 | audio_start = random.randint(0, max_audio_start) 169 | audio = audio[:, audio_start : audio_start + self.segment_size] 170 | else: 171 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 172 | 173 | mel = mel_spectrogram( 174 | audio, 175 | self.n_fft, 176 | self.num_mels, 177 | self.sampling_rate, 178 | self.hop_size, 179 | self.win_size, 180 | self.fmin, 181 | self.fmax, 182 | center=False, 183 | ) 184 | else: 185 | mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) 186 | mel = torch.from_numpy(mel) 187 | 188 | if len(mel.shape) < 3: 189 | mel = mel.unsqueeze(0) 190 | 191 | if self.split: 192 | frames_per_seg = math.ceil(self.segment_size / self.hop_size) 193 | 194 | if audio.size(1) >= self.segment_size: 195 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) 196 | mel = mel[:, :, mel_start : mel_start + frames_per_seg] 197 | audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] 198 | else: 199 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") 200 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 201 | 202 | mel_loss = mel_spectrogram( 203 | audio, 204 | self.n_fft, 205 | self.num_mels, 206 | self.sampling_rate, 207 | self.hop_size, 208 | self.win_size, 209 | self.fmin, 210 | self.fmax_loss, 211 | center=False, 212 | ) 213 | 214 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 215 | 216 | def __len__(self): 217 | return len(self.audio_files) 218 | -------------------------------------------------------------------------------- /matcha/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | from importlib.util import find_spec 5 | from pathlib import Path 6 | from typing import Any, Callable, Dict, Tuple 7 | 8 | import gdown 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | import wget 13 | from omegaconf import DictConfig 14 | 15 | from matcha.utils import pylogger, rich_utils 16 | 17 | log = pylogger.get_pylogger(__name__) 18 | 19 | 20 | def extras(cfg: DictConfig) -> None: 21 | """Applies optional utilities before the task is started. 22 | 23 | Utilities: 24 | - Ignoring python warnings 25 | - Setting tags from command line 26 | - Rich config printing 27 | 28 | :param cfg: A DictConfig object containing the config tree. 
29 | """ 30 | # return if no `extras` config 31 | if not cfg.get("extras"): 32 | log.warning("Extras config not found! ") 33 | return 34 | 35 | # disable python warnings 36 | if cfg.extras.get("ignore_warnings"): 37 | log.info("Disabling python warnings! ") 38 | warnings.filterwarnings("ignore") 39 | 40 | # prompt user to input tags from command line if none are provided in the config 41 | if cfg.extras.get("enforce_tags"): 42 | log.info("Enforcing tags! ") 43 | rich_utils.enforce_tags(cfg, save_to_file=True) 44 | 45 | # pretty print config tree using Rich library 46 | if cfg.extras.get("print_config"): 47 | log.info("Printing config tree with Rich! ") 48 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 49 | 50 | 51 | def task_wrapper(task_func: Callable) -> Callable: 52 | """Optional decorator that controls the failure behavior when executing the task function. 53 | 54 | This wrapper can be used to: 55 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 56 | - save the exception to a `.log` file 57 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 58 | - etc. (adjust depending on your needs) 59 | 60 | Example: 61 | ``` 62 | @utils.task_wrapper 63 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 64 | ... 65 | return metric_dict, object_dict 66 | ``` 67 | 68 | :param task_func: The task function to be wrapped. 69 | 70 | :return: The wrapped task function. 71 | """ 72 | 73 | def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 74 | # execute the task 75 | try: 76 | metric_dict, object_dict = task_func(cfg=cfg) 77 | 78 | # things to do if exception occurs 79 | except Exception as ex: 80 | # save exception to `.log` file 81 | log.exception("") 82 | 83 | # some hyperparameter combinations might be invalid or cause out-of-memory errors 84 | # so when using hparam search plugins like Optuna, you might want to disable 85 | # raising the below exception to avoid multirun failure 86 | raise ex 87 | 88 | # things to always do after either success or exception 89 | finally: 90 | # display output dir path in terminal 91 | log.info(f"Output dir: {cfg.paths.output_dir}") 92 | 93 | # always close wandb run (even if exception occurs so multirun won't fail) 94 | if find_spec("wandb"): # check if wandb is installed 95 | import wandb 96 | 97 | if wandb.run: 98 | log.info("Closing wandb!") 99 | wandb.finish() 100 | 101 | return metric_dict, object_dict 102 | 103 | return wrap 104 | 105 | 106 | def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: 107 | """Safely retrieves value of the metric logged in LightningModule. 108 | 109 | :param metric_dict: A dict containing metric values. 110 | :param metric_name: The name of the metric to retrieve. 111 | :return: The value of the metric. 112 | """ 113 | if not metric_name: 114 | log.info("Metric name is None! Skipping metric value retrieval...") 115 | return None 116 | 117 | if metric_name not in metric_dict: 118 | raise Exception( 119 | f"Metric value not found! \n" 120 | "Make sure metric name logged in LightningModule is correct!\n" 121 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 122 | ) 123 | 124 | metric_value = metric_dict[metric_name].item() 125 | log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") 126 | 127 | return metric_value 128 | 129 | 130 | def intersperse(lst, item): 131 | # Adds blank symbol 132 | result = [item] * (len(lst) * 2 + 1) 133 | result[1::2] = lst 134 | return result 135 | 136 | 137 | def save_figure_to_numpy(fig): 138 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 139 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 140 | return data 141 | 142 | 143 | def plot_tensor(tensor): 144 | plt.style.use("default") 145 | fig, ax = plt.subplots(figsize=(12, 3)) 146 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 147 | plt.colorbar(im, ax=ax) 148 | plt.tight_layout() 149 | fig.canvas.draw() 150 | data = save_figure_to_numpy(fig) 151 | plt.close() 152 | return data 153 | 154 | 155 | def save_plot(tensor, savepath): 156 | plt.style.use("default") 157 | fig, ax = plt.subplots(figsize=(12, 3)) 158 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 159 | plt.colorbar(im, ax=ax) 160 | plt.tight_layout() 161 | fig.canvas.draw() 162 | plt.savefig(savepath) 163 | plt.close() 164 | 165 | 166 | def to_numpy(tensor): 167 | if isinstance(tensor, np.ndarray): 168 | return tensor 169 | elif isinstance(tensor, torch.Tensor): 170 | return tensor.detach().cpu().numpy() 171 | elif isinstance(tensor, list): 172 | return np.array(tensor) 173 | else: 174 | raise TypeError("Unsupported type for conversion to numpy array") 175 | 176 | 177 | def get_user_data_dir(appname="matcha_tts"): 178 | """ 179 | Args: 180 | appname (str): Name of application 181 | 182 | Returns: 183 | Path: path to user data directory 184 | """ 185 | 186 | MATCHA_HOME = os.environ.get("MATCHA_HOME") 187 | if MATCHA_HOME is not None: 188 | ans = Path(MATCHA_HOME).expanduser().resolve(strict=False) 189 | elif sys.platform == "win32": 190 | import winreg # pylint: disable=import-outside-toplevel 191 | 192 | key = winreg.OpenKey( 193 | winreg.HKEY_CURRENT_USER, 194 | r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", 195 | ) 196 | dir_, _ = winreg.QueryValueEx(key, "Local AppData") 197 | ans = Path(dir_).resolve(strict=False) 198 | elif sys.platform == "darwin": 199 | ans = Path("~/Library/Application Support/").expanduser() 200 | else: 201 | ans = Path.home().joinpath(".local/share") 202 | 203 | final_path = ans.joinpath(appname) 204 | final_path.mkdir(parents=True, exist_ok=True) 205 | return final_path 206 | 207 | 208 | def assert_model_downloaded(checkpoint_path, url, use_wget=False): 209 | if Path(checkpoint_path).exists(): 210 | log.debug(f"[+] Model already present at {checkpoint_path}!") 211 | return 212 | log.info(f"[-] Model not found at {checkpoint_path}! 
Will download it") 213 | checkpoint_path = str(checkpoint_path) 214 | if not use_wget: 215 | gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) 216 | else: 217 | wget.download(url=url, out=checkpoint_path) 218 | -------------------------------------------------------------------------------- /matcha/data/text_mel_datamodule.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | import torchaudio as ta 6 | from lightning import LightningDataModule 7 | from torch.utils.data.dataloader import DataLoader 8 | 9 | from matcha.text import text_to_sequence 10 | from matcha.utils.audio import mel_spectrogram 11 | from matcha.utils.model import fix_len_compatibility, normalize 12 | from matcha.utils.utils import intersperse 13 | 14 | 15 | def parse_filelist(filelist_path, split_char="|"): 16 | with open(filelist_path, encoding="utf-8") as f: 17 | filepaths_and_text = [line.strip().split(split_char) for line in f] 18 | return filepaths_and_text 19 | 20 | 21 | class TextMelDataModule(LightningDataModule): 22 | def __init__( # pylint: disable=unused-argument 23 | self, 24 | name, 25 | train_filelist_path, 26 | valid_filelist_path, 27 | batch_size, 28 | num_workers, 29 | pin_memory, 30 | cleaners, 31 | add_blank, 32 | n_spks, 33 | n_fft, 34 | n_feats, 35 | sample_rate, 36 | hop_length, 37 | win_length, 38 | f_min, 39 | f_max, 40 | data_statistics, 41 | seed, 42 | ): 43 | super().__init__() 44 | 45 | # this line allows to access init params with 'self.hparams' attribute 46 | # also ensures init params will be stored in ckpt 47 | self.save_hyperparameters(logger=False) 48 | 49 | def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument 50 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 51 | 52 | This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be 53 | careful not to execute things like random split twice! 
54 | """ 55 | # load and split datasets only if not loaded already 56 | 57 | self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init 58 | self.hparams.train_filelist_path, 59 | self.hparams.n_spks, 60 | self.hparams.cleaners, 61 | self.hparams.add_blank, 62 | self.hparams.n_fft, 63 | self.hparams.n_feats, 64 | self.hparams.sample_rate, 65 | self.hparams.hop_length, 66 | self.hparams.win_length, 67 | self.hparams.f_min, 68 | self.hparams.f_max, 69 | self.hparams.data_statistics, 70 | self.hparams.seed, 71 | ) 72 | self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init 73 | self.hparams.valid_filelist_path, 74 | self.hparams.n_spks, 75 | self.hparams.cleaners, 76 | self.hparams.add_blank, 77 | self.hparams.n_fft, 78 | self.hparams.n_feats, 79 | self.hparams.sample_rate, 80 | self.hparams.hop_length, 81 | self.hparams.win_length, 82 | self.hparams.f_min, 83 | self.hparams.f_max, 84 | self.hparams.data_statistics, 85 | self.hparams.seed, 86 | ) 87 | 88 | def train_dataloader(self): 89 | return DataLoader( 90 | dataset=self.trainset, 91 | batch_size=self.hparams.batch_size, 92 | num_workers=self.hparams.num_workers, 93 | pin_memory=self.hparams.pin_memory, 94 | shuffle=True, 95 | collate_fn=TextMelBatchCollate(self.hparams.n_spks), 96 | ) 97 | 98 | def val_dataloader(self): 99 | return DataLoader( 100 | dataset=self.validset, 101 | batch_size=self.hparams.batch_size, 102 | num_workers=self.hparams.num_workers, 103 | pin_memory=self.hparams.pin_memory, 104 | shuffle=False, 105 | collate_fn=TextMelBatchCollate(self.hparams.n_spks), 106 | ) 107 | 108 | def teardown(self, stage: Optional[str] = None): 109 | """Clean up after fit or test.""" 110 | pass # pylint: disable=unnecessary-pass 111 | 112 | def state_dict(self): # pylint: disable=no-self-use 113 | """Extra things to save to checkpoint.""" 114 | return {} 115 | 116 | def load_state_dict(self, state_dict: Dict[str, Any]): 117 | """Things to do when loading checkpoint.""" 118 | pass # pylint: disable=unnecessary-pass 119 | 120 | 121 | class TextMelDataset(torch.utils.data.Dataset): 122 | def __init__( 123 | self, 124 | filelist_path, 125 | n_spks, 126 | cleaners, 127 | add_blank=True, 128 | n_fft=1024, 129 | n_mels=80, 130 | sample_rate=22050, 131 | hop_length=256, 132 | win_length=1024, 133 | f_min=0.0, 134 | f_max=8000, 135 | data_parameters=None, 136 | seed=None, 137 | ): 138 | self.filepaths_and_text = parse_filelist(filelist_path) 139 | self.n_spks = n_spks 140 | self.cleaners = cleaners 141 | self.add_blank = add_blank 142 | self.n_fft = n_fft 143 | self.n_mels = n_mels 144 | self.sample_rate = sample_rate 145 | self.hop_length = hop_length 146 | self.win_length = win_length 147 | self.f_min = f_min 148 | self.f_max = f_max 149 | if data_parameters is not None: 150 | self.data_parameters = data_parameters 151 | else: 152 | self.data_parameters = {"mel_mean": 0, "mel_std": 1} 153 | random.seed(seed) 154 | random.shuffle(self.filepaths_and_text) 155 | 156 | def get_datapoint(self, filepath_and_text): 157 | if self.n_spks > 1: 158 | filepath, spk, text = ( 159 | filepath_and_text[0], 160 | int(filepath_and_text[1]), 161 | filepath_and_text[2], 162 | ) 163 | else: 164 | filepath, text = filepath_and_text[0], filepath_and_text[1] 165 | spk = None 166 | 167 | text = self.get_text(text, add_blank=self.add_blank) 168 | mel, audio = self.get_mel(filepath) 169 | 170 | return {"x": text, "y": mel, "spk": spk, "wav":audio} 171 | 172 | def get_mel(self, filepath): 173 | audio, sr = ta.load(filepath) 
174 | assert sr == self.sample_rate 175 | mel = mel_spectrogram( 176 | audio, 177 | self.n_fft, 178 | self.n_mels, 179 | self.sample_rate, 180 | self.hop_length, 181 | self.win_length, 182 | self.f_min, 183 | self.f_max, 184 | center=False, 185 | ).squeeze() 186 | mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"]) 187 | return mel, audio 188 | 189 | def get_text(self, text, add_blank=True): 190 | text_norm = text_to_sequence(text, self.cleaners) 191 | if self.add_blank: 192 | text_norm = intersperse(text_norm, 0) 193 | text_norm = torch.IntTensor(text_norm) 194 | return text_norm 195 | 196 | def __getitem__(self, index): 197 | datapoint = self.get_datapoint(self.filepaths_and_text[index]) 198 | return datapoint 199 | 200 | def __len__(self): 201 | return len(self.filepaths_and_text) 202 | 203 | 204 | class TextMelBatchCollate: 205 | def __init__(self, n_spks): 206 | self.n_spks = n_spks 207 | 208 | def __call__(self, batch): 209 | B = len(batch) 210 | y_max_length = max([item["y"].shape[-1] for item in batch]) 211 | y_max_length = fix_len_compatibility(y_max_length) 212 | x_max_length = max([item["x"].shape[-1] for item in batch]) 213 | wav_max_length = y_max_length * 256 #hoplength times mel frame numbers = wav size TODO remove hard code 214 | n_feats = batch[0]["y"].shape[-2] 215 | 216 | y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) 217 | x = torch.zeros((B, x_max_length), dtype=torch.long) 218 | wav = torch.zeros((B, 1, wav_max_length), dtype=torch.float32) 219 | y_lengths, x_lengths = [], [] 220 | wav_lengths = [] 221 | spks = [] 222 | for i, item in enumerate(batch): 223 | y_, x_ = item["y"], item["x"] 224 | wav_ = item["wav"] 225 | y_lengths.append(y_.shape[-1]) 226 | x_lengths.append(x_.shape[-1]) 227 | wav_lengths.append(wav_.shape[-1]) 228 | y[i, :, : y_.shape[-1]] = y_ 229 | x[i, : x_.shape[-1]] = x_ 230 | wav[i, :, : wav_.shape[-1]] = wav_ 231 | spks.append(item["spk"]) 232 | 233 | y_lengths = torch.tensor(y_lengths, dtype=torch.long) 234 | x_lengths = torch.tensor(x_lengths, dtype=torch.long) 235 | wav_lengths = torch.tensor(wav_lengths, dtype=torch.long) 236 | spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None 237 | 238 | return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks, "wav":wav, "wav_lengths":wav_lengths} 239 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ## TODO (edit readme and add more info about the project; currently it is a copy of the original Matcha-TTS readme) 4 | ## WIP: Matcha-TTS-2 5 | ## 🍵 Matcha-TTS-2: A fast E2E TTS architecture with conditional flow matching (not fast enough for training yet xD) 6 | 7 | - [x] Added experimental E2E TTS support; running some small training runs to verify the results. Expect this model to be updated completely by the end of November '23. 8 | - [x] If anybody is willing to help me understand CFM quicker, that would be great. I have a few questions. Thank you! 9 | - [x] I am trying two things: (1) the CFM decoder's intermediate output gives a mel; feed that mel to HiFi-GAN, compare the mel of HiFi-GAN's output with the real mel, and also use a prior loss to force the text encoder to stay close to the decoder's output (which again is the mel). (2) A more "learnable", flexible model with freedom at the decoder output: no prior restrictions, just a final output mel (to make HiFi-GAN robust to noise, we add a small sigma to the decoder output while training). At the moment the code implements (2), and I think it is better; see the sketch below the badges. 10 | ![image](https://github.com/p0p4k/Matcha-TTS-2/assets/8834712/560de995-8fbb-4155-8d1f-c8ed4200ddd6) 11 | Output from (2) (still training): 12 | ![image](https://github.com/p0p4k/Matcha-TTS-2/assets/8834712/1ce7bef4-96f4-4ffd-8041-a9d2cacfd85e) 13 | 14 | ### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/) 15 | 16 | [![python](https://img.shields.io/badge/-Python_3.10-blue?logo=python&logoColor=white)](https://www.python.org/downloads/release/python-3100/) 17 | [![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) 18 | [![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/) 19 | [![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/) 20 | [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) 21 | [![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) 22 | 23 |
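A rough sketch of idea (2), as it is currently wired up in `matcha/models/matcha_tts.py` (simplified; the dummy tensors stand in for the posterior-encoder latent and the ground-truth mel, and the real training step also re-normalises the predicted mel with the dataset statistics and weights this loss by 45):

```python
import torch
import torch.nn.functional as F

from matcha.hifigan.config import v1
from matcha.hifigan.env import AttrDict
from matcha.hifigan.meldataset import mel_spectrogram
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.models.components import commons

hifigan = HiFiGAN(AttrDict(v1))

z = torch.randn(1, 80, 64)          # decoder-side "mel-like" latent (random stand-in here)
y = torch.randn(1, 80, 64)          # ground-truth (normalised) mel for the same utterance
y_lengths = torch.tensor([64])

z = z + 1e-4 * torch.randn_like(z)  # small sigma keeps HiFi-GAN robust to decoder noise
seg = 8192 // 256                   # 32 mel frames per training slice
z_slice, ids = commons.rand_slice_segments(z, y_lengths, segment_size=seg)

wav_hat = hifigan(z_slice)          # [1, 1, seg * 256] waveform for the slice
mel_hat = mel_spectrogram(wav_hat.squeeze(1), 1024, 80, 22050, 256, 1024, 0, 8000, center=False)

# compare in the mel domain against the matching slice of the real mel
mel_loss = F.l1_loss(commons.slice_segments(y, ids, segment_size=seg), mel_hat)
```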

24 | 25 |

26 | 27 |
28 | 29 | > This is the official code implementation of 🍵 Matcha-TTS. 30 | 31 | We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses [conditional flow matching](https://arxiv.org/abs/2210.02747) (similar to [rectified flows](https://arxiv.org/abs/2209.03003)) to speed up ODE-based speech synthesis. Our method: 32 | 33 | - Is probabilistic 34 | - Has compact memory footprint 35 | - Sounds highly natural 36 | - Is very fast to synthesise from 37 | 38 | Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS) and read [our arXiv preprint](https://arxiv.org/abs/2309.03199) for more details. 39 | 40 | [Pre-trained models](https://drive.google.com/drive/folders/17C_gYgEHOxI5ZypcfE_k1piKCtyR0isJ?usp=sharing) will be automatically downloaded with the CLI or gradio interface. 41 | 42 | [Try 🍵 Matcha-TTS on HuggingFace 🤗 spaces!](https://huggingface.co/spaces/shivammehta25/Matcha-TTS) 43 | 44 | ## Watch the teaser 45 | 46 | [![Watch the video](https://img.youtube.com/vi/xmvJkz3bqw0/hqdefault.jpg)](https://youtu.be/xmvJkz3bqw0) 47 | 48 | ## Installation 49 | 50 | 1. Create an environment (suggested but optional) 51 | 52 | ``` 53 | conda create -n matcha-tts python=3.10 -y 54 | conda activate matcha-tts 55 | ``` 56 | 57 | 2. Install Matcha TTS using pip or from source 58 | 59 | ```bash 60 | pip install matcha-tts 61 | ``` 62 | 63 | from source 64 | 65 | ```bash 66 | pip install git+https://github.com/shivammehta25/Matcha-TTS.git 67 | cd Matcha-TTS 68 | pip install -e . 69 | ``` 70 | 71 | 3. Run CLI / gradio app / jupyter notebook 72 | 73 | ```bash 74 | # This will download the required models 75 | matcha-tts --text "<INPUT TEXT>" 76 | ``` 77 | 78 | or 79 | 80 | ```bash 81 | matcha-tts-app 82 | ``` 83 | 84 | or open `synthesis.ipynb` on jupyter notebook 85 | 86 | ### CLI Arguments 87 | 88 | - To synthesise from given text, run: 89 | 90 | ```bash 91 | matcha-tts --text "<INPUT TEXT>" 92 | ``` 93 | 94 | - To synthesise from a file, run: 95 | 96 | ```bash 97 | matcha-tts --file <PATH TO FILE> 98 | ``` 99 | 100 | - To batch synthesise from a file, run: 101 | 102 | ```bash 103 | matcha-tts --file <PATH TO FILE> --batched 104 | ``` 105 | 106 | Additional arguments 107 | 108 | - Speaking rate 109 | 110 | ```bash 111 | matcha-tts --text "<INPUT TEXT>" --speaking_rate 1.0 112 | ``` 113 | 114 | - Sampling temperature 115 | 116 | ```bash 117 | matcha-tts --text "<INPUT TEXT>" --temperature 0.667 118 | ``` 119 | 120 | - Euler ODE solver steps 121 | 122 | ```bash 123 | matcha-tts --text "<INPUT TEXT>" --steps 10 124 | ``` 125 | 126 | ## Train with your own dataset 127 | 128 | Let's assume we are training with LJ Speech. 129 | 130 | 1. Download the dataset from [here](https://keithito.com/LJ-Speech-Dataset/), extract it to `data/LJSpeech-1.1`, and prepare the file lists to point to the extracted data like for [item 5 in the setup of the NVIDIA Tacotron 2 repo](https://github.com/NVIDIA/tacotron2#setup). 131 | 132 | 2. Clone and enter the Matcha-TTS repository 133 | 134 | ```bash 135 | git clone https://github.com/shivammehta25/Matcha-TTS.git 136 | cd Matcha-TTS 137 | ``` 138 | 139 | 3. Install the package from source 140 | 141 | ```bash 142 | pip install -e . 143 | ``` 144 | 145 | 4. Go to `configs/data/ljspeech.yaml` and change the following entries to point to the paths of your train and validation filelists: 146 | 147 | ```yaml 148 | train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt 149 | valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt 150 | ``` 151 | 152 | 5.
Generate normalisation statistics using the dataset configuration yaml file 153 | 154 | ```bash 155 | matcha-data-stats -i ljspeech.yaml 156 | # Output: 157 | #{'mel_mean': -5.53662231756592, 'mel_std': 2.1161014277038574} 158 | ``` 159 | 160 | Update these values in `configs/data/ljspeech.yaml` under the `data_statistics` key. 161 | 162 | ```yaml 163 | data_statistics: # Computed for ljspeech dataset 164 | mel_mean: -5.536622 165 | mel_std: 2.116101 166 | ``` 167 | 168 | 169 | 170 | 6. Run the training script 171 | 172 | ```bash 173 | make train-ljspeech 174 | ``` 175 | 176 | or 177 | 178 | ```bash 179 | python matcha/train.py experiment=ljspeech 180 | ``` 181 | 182 | - for a minimum memory run 183 | 184 | ```bash 185 | python matcha/train.py experiment=ljspeech_min_memory 186 | ``` 187 | 188 | - for multi-gpu training, run 189 | 190 | ```bash 191 | python matcha/train.py experiment=ljspeech trainer.devices=[0,1] 192 | ``` 193 | 194 | 7. Synthesise from the custom trained model 195 | 196 | ```bash 197 | matcha-tts --text "<INPUT TEXT>" --checkpoint_path <PATH TO CHECKPOINT> 198 | ``` 199 | 200 | ## ONNX support 201 | 202 | > Special thanks to [@mush42](https://github.com/mush42) for implementing ONNX export and inference support. 203 | 204 | It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph. 205 | 206 | ### ONNX export 207 | 208 | To export a checkpoint to ONNX, first install ONNX with 209 | 210 | ```bash 211 | pip install onnx 212 | ``` 213 | 214 | then run the following: 215 | 216 | ```bash 217 | python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5 218 | ``` 219 | 220 | Optionally, the ONNX exporter accepts **vocoder-name** and **vocoder-checkpoint** arguments. This enables you to embed the vocoder in the exported graph and generate waveforms in a single run (similar to end-to-end TTS systems). 221 | 222 | **Note** that `n_timesteps` is treated as a hyper-parameter rather than a model input. This means you should specify it during export (not during inference). If not specified, `n_timesteps` is set to **5**. 223 | 224 | **Important**: for now, torch>=2.1.0 is needed for export since the `scaled_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release. 225 | 226 | ### ONNX Inference 227 | 228 | To run inference on the exported model, first install `onnxruntime` using 229 | 230 | ```bash 231 | pip install onnxruntime 232 | pip install onnxruntime-gpu # for GPU inference 233 | ``` 234 | 235 | then use the following: 236 | 237 | ```bash 238 | python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs 239 | ``` 240 | 241 | You can also control synthesis parameters: 242 | 243 | ```bash 244 | python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --temperature 0.4 --speaking-rate 0.9 --spk 0 245 | ``` 246 | 247 | To run inference on **GPU**, make sure to install **onnxruntime-gpu** package, and then pass `--gpu` to the inference command: 248 | 249 | ```bash 250 | python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --gpu 251 | ``` 252 | 253 | If you exported only Matcha to ONNX, this will write the mel-spectrograms as image plots and `numpy` arrays to the output directory. 254 | If you embedded the vocoder in the exported graph, this will write `.wav` audio files to the output directory.
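For reference, the CLI above is a thin wrapper around `onnxruntime`. A minimal sketch of calling the exported acoustic-model-only graph directly is shown below; the input names `x`, `x_lengths` and `scales` follow `matcha/onnx/infer.py`, while `model.onnx` and the sample sentence are placeholders:

```python
import numpy as np
import onnxruntime as ort

from matcha.cli import process_text  # same text frontend the ONNX CLI uses

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

processed = process_text(0, "Hey there!", "cpu")
inputs = {
    "x": processed["x"].cpu().numpy(),                                   # (1, n_tokens) phoneme ids
    "x_lengths": np.array([processed["x_lengths"].item()], dtype=np.int64),
    "scales": np.array([0.667, 1.0], dtype=np.float32),                  # [temperature, speaking rate]
}
mels, mel_lengths = session.run(None, inputs)                            # (1, 80, T) mel and its length
```

The multi-speaker export additionally expects an int64 `spks` input, exactly as handled in `matcha/onnx/infer.py`.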
255 | 256 | If you exported only Matcha to ONNX, and you want to run a full TTS pipeline, you can pass a path to a vocoder model in `ONNX` format: 257 | 258 | ```bash 259 | python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vocoder hifigan.small.onnx 260 | ``` 261 | 262 | This will write `.wav` audio files to the output directory. 263 | 264 | ## Citation information 265 | 266 | If you use our code or otherwise find this work useful, please cite our paper: 267 | 268 | ```text 269 | @article{mehta2023matcha, 270 | title={Matcha-TTS: A fast TTS architecture with conditional flow matching}, 271 | author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje}, 272 | journal={arXiv preprint arXiv:2309.03199}, 273 | year={2023} 274 | } 275 | ``` 276 | 277 | ## Acknowledgements 278 | 279 | Since this code uses [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), you have all the powers that come with it. 280 | 281 | Other source code I would like to acknowledge: 282 | 283 | - [Coqui-TTS](https://github.com/coqui-ai/TTS/tree/dev): For helping me figure out how to make cython binaries pip installable and encouragement 284 | - [Hugging Face Diffusers](https://huggingface.co/): For their awesome diffusers library and its components 285 | - [Grad-TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS): For the monotonic alignment search source code 286 | - [torchdyn](https://github.com/DiffEqML/torchdyn): Useful for trying other ODE solvers during research and development 287 | - [labml.ai](https://nn.labml.ai/transformers/rope/index.html): For the RoPE implementation 288 | -------------------------------------------------------------------------------- /matcha/models/baselightningmodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a base lightning module that can be used to train a model. 3 | The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. 
4 | """ 5 | import inspect 6 | from abc import ABC 7 | from typing import Any, Dict 8 | 9 | import torch 10 | from lightning import LightningModule 11 | from lightning.pytorch.utilities import grad_norm 12 | 13 | from matcha import utils 14 | from matcha.utils.utils import plot_tensor 15 | 16 | log = utils.get_pylogger(__name__) 17 | 18 | 19 | class BaseLightningClass(LightningModule, ABC): 20 | def update_data_statistics(self, data_statistics): 21 | if data_statistics is None: 22 | data_statistics = { 23 | "mel_mean": 0.0, 24 | "mel_std": 1.0, 25 | } 26 | 27 | self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) 28 | self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) 29 | 30 | def configure_optimizers(self) -> Any: 31 | optimizer = self.hparams.optimizer(params=self.parameters()) 32 | if self.hparams.scheduler not in (None, {}): 33 | scheduler_args = {} 34 | # Manage last epoch for exponential schedulers 35 | if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: 36 | if hasattr(self, "ckpt_loaded_epoch"): 37 | current_epoch = self.ckpt_loaded_epoch - 1 38 | else: 39 | current_epoch = -1 40 | 41 | scheduler_args.update({"optimizer": optimizer}) 42 | scheduler = self.hparams.scheduler.scheduler(**scheduler_args) 43 | scheduler.last_epoch = current_epoch 44 | return { 45 | "optimizer": optimizer, 46 | "lr_scheduler": { 47 | "scheduler": scheduler, 48 | "interval": self.hparams.scheduler.lightning_args.interval, 49 | "frequency": self.hparams.scheduler.lightning_args.frequency, 50 | "name": "learning_rate", 51 | }, 52 | } 53 | 54 | return {"optimizer": optimizer} 55 | 56 | def get_losses(self, batch): 57 | x, x_lengths = batch["x"], batch["x_lengths"] 58 | y, y_lengths = batch["y"], batch["y_lengths"] 59 | # wav, wav_lengths = batch["wav"], batch["wav_lengths"] 60 | spks = batch["spks"] 61 | # print("y", y.shape) 62 | dur_loss, prior_loss, diff_loss, mel_loss, loss_disc, loss_gen, y_hat_mel, y_slice = self( 63 | x=x, 64 | x_lengths=x_lengths, 65 | y=y, 66 | y_lengths=y_lengths, 67 | spks=spks, 68 | out_size=self.out_size, 69 | # wav=wav, 70 | # wav_lengths=wav_lengths, 71 | ) 72 | return ( 73 | { 74 | "dur_loss": dur_loss, 75 | "prior_loss": prior_loss, 76 | "diff_loss": diff_loss, 77 | "mel_loss": mel_loss, 78 | "loss_disc": loss_disc, 79 | "loss_gen": loss_gen, 80 | }, 81 | { 82 | "y_hat_mel": y_hat_mel, 83 | "y_slice": y_slice, 84 | } 85 | ) 86 | 87 | def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 88 | self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init 89 | 90 | def training_step(self, batch: Any, batch_idx: int): 91 | loss_dict, plot_dict = self.get_losses(batch) 92 | 93 | self.log( 94 | "step", 95 | float(self.global_step), 96 | on_step=True, 97 | on_epoch=True, 98 | logger=True, 99 | sync_dist=True, 100 | ) 101 | 102 | self.log( 103 | "sub_loss/train_dur_loss", 104 | loss_dict["dur_loss"], 105 | on_step=True, 106 | on_epoch=True, 107 | logger=True, 108 | sync_dist=True, 109 | ) 110 | self.log( 111 | "sub_loss/train_prior_loss", 112 | loss_dict["prior_loss"], 113 | on_step=True, 114 | on_epoch=True, 115 | logger=True, 116 | sync_dist=True, 117 | ) 118 | self.log( 119 | "sub_loss/train_diff_loss", 120 | loss_dict["diff_loss"], 121 | on_step=True, 122 | on_epoch=True, 123 | logger=True, 124 | sync_dist=True, 125 | ) 126 | self.log( 127 | "sub_loss/train_mel_loss", 128 | loss_dict["mel_loss"], 129 | on_step=True, 130 | on_epoch=True, 131 |
logger=True, 132 | sync_dist=True, 133 | ) 134 | self.log( 135 | "sub_loss/train_loss_disc", 136 | loss_dict["loss_disc"], 137 | on_step=True, 138 | on_epoch=True, 139 | logger=True, 140 | sync_dist=True, 141 | ) 142 | self.log( 143 | "sub_loss/train_loss_gen", 144 | loss_dict["loss_gen"], 145 | on_step=True, 146 | on_epoch=True, 147 | logger=True, 148 | sync_dist=True, 149 | ) 150 | 151 | total_loss = sum(loss_dict.values()) 152 | self.log( 153 | "loss/train", 154 | total_loss, 155 | on_step=True, 156 | on_epoch=True, 157 | logger=True, 158 | prog_bar=True, 159 | sync_dist=True, 160 | ) 161 | self.logger.experiment.add_image( 162 | f"generated_slice/train_step", 163 | plot_tensor(plot_dict["y_hat_mel"].squeeze().detach().cpu()[0].squeeze()), 164 | self.current_epoch, 165 | dataformats="HWC", 166 | ) 167 | self.logger.experiment.add_image( 168 | f"real_slice/train_step", 169 | plot_tensor(plot_dict["y_slice"].squeeze().detach().cpu()[0].squeeze()), 170 | self.current_epoch, 171 | dataformats="HWC", 172 | ) 173 | return {"loss": total_loss, "log": loss_dict} 174 | 175 | def validation_step(self, batch: Any, batch_idx: int): 176 | loss_dict, plot_dict = self.get_losses(batch) 177 | self.log( 178 | "sub_loss/val_dur_loss", 179 | loss_dict["dur_loss"], 180 | on_step=True, 181 | on_epoch=True, 182 | logger=True, 183 | sync_dist=True, 184 | ) 185 | self.log( 186 | "sub_loss/val_prior_loss", 187 | loss_dict["prior_loss"], 188 | on_step=True, 189 | on_epoch=True, 190 | logger=True, 191 | sync_dist=True, 192 | ) 193 | self.log( 194 | "sub_loss/val_diff_loss", 195 | loss_dict["diff_loss"], 196 | on_step=True, 197 | on_epoch=True, 198 | logger=True, 199 | sync_dist=True, 200 | ) 201 | self.log( 202 | "sub_loss/val_mel_loss", 203 | loss_dict["mel_loss"], 204 | on_step=True, 205 | on_epoch=True, 206 | logger=True, 207 | sync_dist=True, 208 | ) 209 | self.log( 210 | "sub_loss/val_loss_disc", 211 | loss_dict["loss_disc"], 212 | on_step=True, 213 | on_epoch=True, 214 | logger=True, 215 | sync_dist=True, 216 | ) 217 | self.log( 218 | "sub_loss/val_loss_gen", 219 | loss_dict["loss_gen"], 220 | on_step=True, 221 | on_epoch=True, 222 | logger=True, 223 | sync_dist=True, 224 | ) 225 | 226 | total_loss = sum(loss_dict.values()) 227 | self.log( 228 | "loss/val", 229 | total_loss, 230 | on_step=True, 231 | on_epoch=True, 232 | logger=True, 233 | prog_bar=True, 234 | sync_dist=True, 235 | ) 236 | 237 | self.logger.experiment.add_image( 238 | f"generated_slice/val", 239 | plot_tensor(plot_dict["y_hat_mel"].squeeze().cpu()[0]), 240 | self.current_epoch, 241 | dataformats="HWC", 242 | ) 243 | self.logger.experiment.add_image( 244 | f"real_slice/val", 245 | plot_tensor(plot_dict["y_slice"].squeeze().cpu()[0]), 246 | self.current_epoch, 247 | dataformats="HWC", 248 | ) 249 | 250 | return total_loss 251 | 252 | def on_validation_end(self) -> None: 253 | if self.trainer.is_global_zero: 254 | one_batch = next(iter(self.trainer.val_dataloaders)) 255 | 256 | if self.current_epoch == 0: 257 | log.debug("Plotting original samples") 258 | for i in range(2): 259 | y = one_batch["y"][i].unsqueeze(0).to(self.device) 260 | self.logger.experiment.add_image( 261 | f"original/{i}", 262 | plot_tensor(y.squeeze().cpu()), 263 | self.current_epoch, 264 | dataformats="HWC", 265 | ) 266 | 267 | log.debug("Synthesising...") 268 | for i in range(2): 269 | x = one_batch["x"][i].unsqueeze(0).to(self.device) 270 | x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) 271 | spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if 
one_batch["spks"] is not None else None 272 | output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks) 273 | y_enc, y_dec, y_mel = output["encoder_outputs"], output["decoder_outputs"], output["mel"] 274 | attn = output["attn"] 275 | self.logger.experiment.add_image( 276 | f"generated_enc/{i}", 277 | plot_tensor(y_enc.squeeze().cpu()), 278 | self.current_epoch, 279 | dataformats="HWC", 280 | ) 281 | self.logger.experiment.add_image( 282 | f"generated_dec/{i}", 283 | plot_tensor(y_dec.squeeze().cpu()), 284 | self.current_epoch, 285 | dataformats="HWC", 286 | ) 287 | self.logger.experiment.add_image( 288 | f"generated_mel/{i}", 289 | plot_tensor(y_mel.squeeze().cpu()), 290 | self.current_epoch, 291 | dataformats="HWC", 292 | ) 293 | self.logger.experiment.add_image( 294 | f"alignment/{i}", 295 | plot_tensor(attn.squeeze().cpu()), 296 | self.current_epoch, 297 | dataformats="HWC", 298 | ) 299 | 300 | def on_before_optimizer_step(self, optimizer): 301 | self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) 302 | -------------------------------------------------------------------------------- /matcha/models/matcha_tts.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import math 3 | import random 4 | 5 | import torch 6 | 7 | from matcha.models.components.vits_posterior import PosteriorEncoder 8 | 9 | from matcha.utils.monotonic_align import maximum_path 10 | from matcha import utils 11 | from matcha.models.baselightningmodule import BaseLightningClass 12 | from matcha.models.components.flow_matching import CFM 13 | from matcha.models.components.text_encoder import TextEncoder 14 | from matcha.utils.model import ( 15 | denormalize, 16 | duration_loss, 17 | fix_len_compatibility, 18 | generate_path, 19 | sequence_mask, 20 | ) 21 | from matcha.utils.audio import mel_spectrogram 22 | 23 | from torch.nn import functional as F 24 | 25 | from matcha.hifigan.models import Generator as HiFiGAN 26 | from matcha.hifigan.models import MultiPeriodDiscriminator 27 | from matcha.hifigan.models import discriminator_loss, generator_loss, feature_loss 28 | from matcha.hifigan.config import v1 29 | from matcha.hifigan.env import AttrDict 30 | from matcha.hifigan.meldataset import mel_spectrogram 31 | from matcha.models.components import commons 32 | from matcha.utils.model import normalize 33 | log = utils.get_pylogger(__name__) 34 | 35 | 36 | class MatchaTTS(BaseLightningClass): # 🍵 37 | def __init__( 38 | self, 39 | n_vocab, 40 | n_spks, 41 | spk_emb_dim, 42 | n_feats, 43 | encoder, 44 | decoder, 45 | cfm, 46 | data_statistics, 47 | out_size, 48 | optimizer=None, 49 | scheduler=None, 50 | ): 51 | super().__init__() 52 | 53 | self.save_hyperparameters(logger=False) 54 | 55 | self.n_vocab = n_vocab 56 | self.n_spks = n_spks 57 | self.spk_emb_dim = spk_emb_dim 58 | self.n_feats = n_feats 59 | self.out_size = out_size 60 | 61 | # if n_spks > 1: 62 | # self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) 63 | 64 | self.h = AttrDict(v1) 65 | self.hifigan = HiFiGAN(self.h) 66 | self.wav2mel = mel_spectrogram 67 | 68 | self.encoder = TextEncoder( 69 | encoder.encoder_type, 70 | encoder.encoder_params, 71 | encoder.duration_predictor_params, 72 | n_vocab, 73 | n_spks, 74 | spk_emb_dim, 75 | ) 76 | self.enc_spec = PosteriorEncoder( 77 | n_feats, 78 | n_feats, 79 | n_feats, 80 | 5, 81 | 1, 82 | 16, 83 | gin_channels=spk_emb_dim, 84 | ) 85 | self.decoder = CFM( 86 | in_channels=2 * 
encoder.encoder_params.n_feats, 87 | out_channel=encoder.encoder_params.n_feats, 88 | cfm_params=cfm, 89 | decoder_params=decoder, 90 | n_spks=n_spks, 91 | spk_emb_dim=spk_emb_dim, 92 | ) 93 | self.hifigan_disc = MultiPeriodDiscriminator() 94 | self.update_data_statistics(data_statistics) 95 | 96 | @torch.inference_mode() 97 | def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0): 98 | """ 99 | Generates mel-spectrogram from text. Returns: 100 | 1. encoder outputs 101 | 2. decoder outputs 102 | 3. generated alignment 103 | 104 | Args: 105 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. 106 | shape: (batch_size, max_text_length) 107 | x_lengths (torch.Tensor): lengths of texts in batch. 108 | shape: (batch_size,) 109 | n_timesteps (int): number of steps to use for reverse diffusion in decoder. 110 | temperature (float, optional): controls variance of terminal distribution. 111 | spks (torch.Tensor, optional): speaker ids. 112 | shape: (batch_size,) 113 | length_scale (float, optional): controls speech pace. 114 | Increase value to slow down generated speech and vice versa. 115 | 116 | Returns: 117 | dict: { 118 | "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), 119 | # Average mel spectrogram generated by the encoder 120 | "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), 121 | # Refined mel spectrogram improved by the CFM 122 | "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length), 123 | # Alignment map between text and mel spectrogram 124 | "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), 125 | # Denormalized mel spectrogram 126 | "mel_lengths": torch.Tensor, shape: (batch_size,), 127 | # Lengths of mel spectrograms 128 | "rtf": float, 129 | # Real-time factor 130 | """ 131 | # For RTF computation 132 | t = dt.datetime.now() 133 | 134 | # if self.n_spks > 1: 135 | # # Get speaker embedding 136 | # spks = self.spk_emb(spks.long()) 137 | spks = spks.squeeze(2) 138 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 139 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) 140 | 141 | w = torch.exp(logw) * x_mask 142 | w_ceil = torch.ceil(w) * length_scale 143 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 144 | y_max_length = y_lengths.max() 145 | y_max_length_ = fix_len_compatibility(y_max_length) 146 | 147 | # Using obtained durations `w` construct alignment map `attn` 148 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) 149 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 150 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) 151 | 152 | # Align encoded text and get mu_y 153 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 154 | mu_y = mu_y.transpose(1, 2) 155 | 156 | encoder_outputs = mu_y[:, :, :y_max_length_] 157 | 158 | # Generate sample tracing the probability flow 159 | decoder_outputs = self.decoder(encoder_outputs, y_mask, n_timesteps, temperature, spks) 160 | decoder_outputs = decoder_outputs[:, :, :y_max_length_] 161 | 162 | hifigan_out = self.hifigan(decoder_outputs) 163 | mel = self.wav2mel(hifigan_out.squeeze(1), num_mels=80, sampling_rate=22050, hop_size=256, win_size=1024, n_fft=1024, fmin=0, fmax=8000, center=False) 164 | # normalize mel 165 | 166 | mel = normalize(mel, self.mel_mean, self.mel_std) 167 | 168 | t = (dt.datetime.now() - t).total_seconds() 169 | rtf = t * 22050 /
(decoder_outputs.shape[-1] * 256) 170 | 171 | return { 172 | "encoder_outputs": encoder_outputs, 173 | "decoder_outputs": decoder_outputs, 174 | "attn": attn[:, :, :y_max_length], 175 | "hifigan_out": hifigan_out, 176 | "mel": denormalize(mel, self.mel_mean, self.mel_std), 177 | "mel_lengths": y_lengths, 178 | "rtf": rtf, 179 | } 180 | 181 | def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, wav=None, wav_lengths=None): 182 | """ 183 | Computes 3 losses: 184 | 1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS). 185 | 2. prior loss: loss between mel-spectrogram and encoder outputs. 186 | 3. flow matching loss: loss between mel-spectrogram and decoder outputs. 187 | 188 | Args: 189 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. 190 | shape: (batch_size, max_text_length) 191 | x_lengths (torch.Tensor): lengths of texts in batch. 192 | shape: (batch_size,) 193 | y (torch.Tensor): batch of corresponding mel-spectrograms. 194 | shape: (batch_size, n_feats, max_mel_length) 195 | y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. 196 | shape: (batch_size,) 197 | out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. 198 | Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. 199 | spks (torch.Tensor, optional): speaker ids. 200 | shape: (batch_size,) 201 | """ 202 | # if self.n_spks > 1: 203 | # # Get speaker embedding 204 | # spks = self.spk_emb(spks) 205 | spks = spks.squeeze(2) 206 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 207 | 208 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) 209 | 210 | y_max_length = y.shape[-1] 211 | z_spec, spec_mask = self.enc_spec(y, y_lengths, g=spks.unsqueeze(1).transpose(1,2)) 212 | 213 | y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) 214 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 215 | # z_spec = y 216 | spec_mask = y_mask 217 | z_spec = z_spec * spec_mask 218 | with torch.no_grad(): 219 | # negative cross-entropy 220 | s_p_sq_r = torch.ones_like(mu_x) # [b, d, t] 221 | neg_cent1 = torch.sum( 222 | -0.5 * math.log(2 * math.pi)- torch.zeros_like(mu_x), [1], keepdim=True 223 | ) 224 | # s_p_sq_r = torch.exp(-2 * log_x) # [b, d, t] 225 | # neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - log_x, [1], keepdim=True) # [b, 1, t_s] 226 | 227 | neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (z_spec**2), s_p_sq_r) 228 | neg_cent3 = torch.einsum("bdt, bds -> bts", z_spec, (mu_x * s_p_sq_r)) 229 | neg_cent4 = torch.sum( 230 | -0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True 231 | ) 232 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 233 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 234 | from matcha.utils.monotonic_align_vits import maximum_path 235 | attn = ( 236 | maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() 237 | ) 238 | 239 | logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask 240 | dur_loss = duration_loss(logw, logw_, x_lengths) 241 | attn = attn.squeeze(1).transpose(1,2) 242 | 243 | # Align encoded text with mel-spectrogram and get mu_y segment 244 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 245 | mu_y = mu_y.transpose(1, 2) 246 | 247 | # Compute loss of the decoder 248 | 249 | diff_loss, _ = self.decoder.compute_loss(x1=z_spec.detach(), mask=y_mask, mu=mu_y, spks=spks, 
cond=cond) 250 | 251 | # prior_loss = torch.sum(0.5 * ((z_spec - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) 252 | # prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) 253 | prior_loss = torch.FloatTensor([0.0]).to(z_spec.device) 254 | z_spec = 1e-4 * torch.randn_like(z_spec) + z_spec 255 | SEGMENT_SIZE = 8192//256 256 | z_sliced, ids_slice = commons.rand_slice_segments( 257 | z_spec, y_lengths , segment_size=SEGMENT_SIZE 258 | ) 259 | 260 | 261 | output_sliced_wav = self.hifigan(z_sliced) 262 | # real_wav_slice = commons.slice_segments( 263 | # wav, ids_slice * 256, 4096 264 | # ) 265 | # y_d_hat_r, y_d_hat_g, _, _ = self.hifigan_disc(real_wav_slice, output_sliced_wav.detach()) 266 | # loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) 267 | # y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.hifigan_disc(real_wav_slice, output_sliced_wav) 268 | # loss_gen, losses_gen = generator_loss(y_d_hat_g) 269 | # loss_fm = feature_loss(fmap_r, fmap_g) 270 | # loss_gen += loss_fm 271 | y_slice = commons.slice_segments( 272 | y, ids_slice, SEGMENT_SIZE) 273 | y_hat_mel = self.wav2mel( 274 | output_sliced_wav.squeeze(1), 275 | num_mels=80, 276 | sampling_rate=22050, 277 | hop_size=256, 278 | win_size=1024, 279 | n_fft=1024, 280 | fmin=0, 281 | fmax=8000, 282 | center=False, 283 | ) 284 | 285 | # denorm_y = denormalize(y_slice, self.mel_mean, self.mel_std) 286 | y_hat_mel = normalize(y_hat_mel, self.mel_mean, self.mel_std) 287 | mel_loss = F.l1_loss(y_slice, y_hat_mel) 288 | loss_disc, loss_gen = torch.Tensor([0.0]).to(z_spec.device), torch.Tensor([0.0]).to(z_spec.device) 289 | return dur_loss, prior_loss, diff_loss, mel_loss*45, loss_disc, loss_gen, y_hat_mel, y_slice 290 | -------------------------------------------------------------------------------- /matcha/hifigan/models.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d 7 | from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm 8 | 9 | from .xutils import get_padding, init_weights 10 | 11 | LRELU_SLOPE = 0.1 12 | 13 | 14 | class ResBlock1(torch.nn.Module): 15 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 16 | super().__init__() 17 | self.h = h 18 | self.convs1 = nn.ModuleList( 19 | [ 20 | weight_norm( 21 | Conv1d( 22 | channels, 23 | channels, 24 | kernel_size, 25 | 1, 26 | dilation=dilation[0], 27 | padding=get_padding(kernel_size, dilation[0]), 28 | ) 29 | ), 30 | weight_norm( 31 | Conv1d( 32 | channels, 33 | channels, 34 | kernel_size, 35 | 1, 36 | dilation=dilation[1], 37 | padding=get_padding(kernel_size, dilation[1]), 38 | ) 39 | ), 40 | weight_norm( 41 | Conv1d( 42 | channels, 43 | channels, 44 | kernel_size, 45 | 1, 46 | dilation=dilation[2], 47 | padding=get_padding(kernel_size, dilation[2]), 48 | ) 49 | ), 50 | ] 51 | ) 52 | self.convs1.apply(init_weights) 53 | 54 | self.convs2 = nn.ModuleList( 55 | [ 56 | weight_norm( 57 | Conv1d( 58 | channels, 59 | channels, 60 | kernel_size, 61 | 1, 62 | dilation=1, 63 | padding=get_padding(kernel_size, 1), 64 | ) 65 | ), 66 | weight_norm( 67 | Conv1d( 68 | channels, 69 | channels, 70 | kernel_size, 71 | 1, 72 | dilation=1, 73 | padding=get_padding(kernel_size, 1), 74 | ) 75 | ), 76 | weight_norm( 77 | Conv1d( 78 | channels, 79 | channels, 80 | kernel_size, 81 | 1, 82 | dilation=1, 83 | 
padding=get_padding(kernel_size, 1), 84 | ) 85 | ), 86 | ] 87 | ) 88 | self.convs2.apply(init_weights) 89 | 90 | def forward(self, x): 91 | for c1, c2 in zip(self.convs1, self.convs2): 92 | xt = F.leaky_relu(x, LRELU_SLOPE) 93 | xt = c1(xt) 94 | xt = F.leaky_relu(xt, LRELU_SLOPE) 95 | xt = c2(xt) 96 | x = xt + x 97 | return x 98 | 99 | def remove_weight_norm(self): 100 | for l in self.convs1: 101 | remove_weight_norm(l) 102 | for l in self.convs2: 103 | remove_weight_norm(l) 104 | 105 | 106 | class ResBlock2(torch.nn.Module): 107 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 108 | super().__init__() 109 | self.h = h 110 | self.convs = nn.ModuleList( 111 | [ 112 | weight_norm( 113 | Conv1d( 114 | channels, 115 | channels, 116 | kernel_size, 117 | 1, 118 | dilation=dilation[0], 119 | padding=get_padding(kernel_size, dilation[0]), 120 | ) 121 | ), 122 | weight_norm( 123 | Conv1d( 124 | channels, 125 | channels, 126 | kernel_size, 127 | 1, 128 | dilation=dilation[1], 129 | padding=get_padding(kernel_size, dilation[1]), 130 | ) 131 | ), 132 | ] 133 | ) 134 | self.convs.apply(init_weights) 135 | 136 | def forward(self, x): 137 | for c in self.convs: 138 | xt = F.leaky_relu(x, LRELU_SLOPE) 139 | xt = c(xt) 140 | x = xt + x 141 | return x 142 | 143 | def remove_weight_norm(self): 144 | for l in self.convs: 145 | remove_weight_norm(l) 146 | 147 | 148 | class Generator(torch.nn.Module): 149 | def __init__(self, h): 150 | super().__init__() 151 | self.h = h 152 | self.num_kernels = len(h.resblock_kernel_sizes) 153 | self.num_upsamples = len(h.upsample_rates) 154 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) 155 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 156 | 157 | self.ups = nn.ModuleList() 158 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 159 | self.ups.append( 160 | weight_norm( 161 | ConvTranspose1d( 162 | h.upsample_initial_channel // (2**i), 163 | h.upsample_initial_channel // (2 ** (i + 1)), 164 | k, 165 | u, 166 | padding=(k - u) // 2, 167 | ) 168 | ) 169 | ) 170 | 171 | self.resblocks = nn.ModuleList() 172 | for i in range(len(self.ups)): 173 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 174 | for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 175 | self.resblocks.append(resblock(h, ch, k, d)) 176 | 177 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 178 | self.ups.apply(init_weights) 179 | self.conv_post.apply(init_weights) 180 | 181 | def forward(self, x): 182 | x = self.conv_pre(x) 183 | for i in range(self.num_upsamples): 184 | x = F.leaky_relu(x, LRELU_SLOPE) 185 | x = self.ups[i](x) 186 | xs = None 187 | for j in range(self.num_kernels): 188 | if xs is None: 189 | xs = self.resblocks[i * self.num_kernels + j](x) 190 | else: 191 | xs += self.resblocks[i * self.num_kernels + j](x) 192 | x = xs / self.num_kernels 193 | x = F.leaky_relu(x) 194 | x = self.conv_post(x) 195 | x = torch.tanh(x) 196 | 197 | return x 198 | 199 | def remove_weight_norm(self): 200 | print("Removing weight norm...") 201 | for l in self.ups: 202 | remove_weight_norm(l) 203 | for l in self.resblocks: 204 | l.remove_weight_norm() 205 | remove_weight_norm(self.conv_pre) 206 | remove_weight_norm(self.conv_post) 207 | 208 | 209 | class DiscriminatorP(torch.nn.Module): 210 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 211 | super().__init__() 212 | self.period = period 213 | norm_f = weight_norm if use_spectral_norm is False 
else spectral_norm 214 | self.convs = nn.ModuleList( 215 | [ 216 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 217 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 218 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 219 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 220 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 221 | ] 222 | ) 223 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 224 | 225 | def forward(self, x): 226 | fmap = [] 227 | 228 | # 1d to 2d 229 | b, c, t = x.shape 230 | if t % self.period != 0: # pad first 231 | n_pad = self.period - (t % self.period) 232 | x = F.pad(x, (0, n_pad), "reflect") 233 | t = t + n_pad 234 | x = x.view(b, c, t // self.period, self.period) 235 | 236 | for l in self.convs: 237 | x = l(x) 238 | x = F.leaky_relu(x, LRELU_SLOPE) 239 | fmap.append(x) 240 | x = self.conv_post(x) 241 | fmap.append(x) 242 | x = torch.flatten(x, 1, -1) 243 | 244 | return x, fmap 245 | 246 | 247 | class MultiPeriodDiscriminator(torch.nn.Module): 248 | def __init__(self): 249 | super().__init__() 250 | self.discriminators = nn.ModuleList( 251 | [ 252 | DiscriminatorP(2), 253 | DiscriminatorP(3), 254 | DiscriminatorP(5), 255 | DiscriminatorP(7), 256 | DiscriminatorP(11), 257 | ] 258 | ) 259 | 260 | def forward(self, y, y_hat): 261 | y_d_rs = [] 262 | y_d_gs = [] 263 | fmap_rs = [] 264 | fmap_gs = [] 265 | for _, d in enumerate(self.discriminators): 266 | y_d_r, fmap_r = d(y) 267 | y_d_g, fmap_g = d(y_hat) 268 | y_d_rs.append(y_d_r) 269 | fmap_rs.append(fmap_r) 270 | y_d_gs.append(y_d_g) 271 | fmap_gs.append(fmap_g) 272 | 273 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 274 | 275 | 276 | class DiscriminatorS(torch.nn.Module): 277 | def __init__(self, use_spectral_norm=False): 278 | super().__init__() 279 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm 280 | self.convs = nn.ModuleList( 281 | [ 282 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 283 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 284 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 285 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 286 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 287 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 288 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 289 | ] 290 | ) 291 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 292 | 293 | def forward(self, x): 294 | fmap = [] 295 | for l in self.convs: 296 | x = l(x) 297 | x = F.leaky_relu(x, LRELU_SLOPE) 298 | fmap.append(x) 299 | x = self.conv_post(x) 300 | fmap.append(x) 301 | x = torch.flatten(x, 1, -1) 302 | 303 | return x, fmap 304 | 305 | 306 | class MultiScaleDiscriminator(torch.nn.Module): 307 | def __init__(self): 308 | super().__init__() 309 | self.discriminators = nn.ModuleList( 310 | [ 311 | DiscriminatorS(use_spectral_norm=True), 312 | DiscriminatorS(), 313 | DiscriminatorS(), 314 | ] 315 | ) 316 | self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) 317 | 318 | def forward(self, y, y_hat): 319 | y_d_rs = [] 320 | y_d_gs = [] 321 | fmap_rs = [] 322 | fmap_gs = [] 323 | for i, d in enumerate(self.discriminators): 324 | if i != 0: 325 | y = self.meanpools[i - 1](y) 326 | y_hat = self.meanpools[i - 1](y_hat) 327 | y_d_r, fmap_r = d(y) 328 | y_d_g, fmap_g = d(y_hat) 329 | y_d_rs.append(y_d_r) 330 | 
fmap_rs.append(fmap_r) 331 | y_d_gs.append(y_d_g) 332 | fmap_gs.append(fmap_g) 333 | 334 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 335 | 336 | 337 | def feature_loss(fmap_r, fmap_g): 338 | loss = 0 339 | for dr, dg in zip(fmap_r, fmap_g): 340 | for rl, gl in zip(dr, dg): 341 | loss += torch.mean(torch.abs(rl - gl)) 342 | 343 | return loss * 2 344 | 345 | 346 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 347 | loss = 0 348 | r_losses = [] 349 | g_losses = [] 350 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 351 | r_loss = torch.mean((1 - dr) ** 2) 352 | g_loss = torch.mean(dg**2) 353 | loss += r_loss + g_loss 354 | r_losses.append(r_loss.item()) 355 | g_losses.append(g_loss.item()) 356 | 357 | return loss, r_losses, g_losses 358 | 359 | 360 | def generator_loss(disc_outputs): 361 | loss = 0 362 | gen_losses = [] 363 | for dg in disc_outputs: 364 | l = torch.mean((1 - dg) ** 2) 365 | gen_losses.append(l) 366 | loss += l 367 | 368 | return loss, gen_losses 369 | -------------------------------------------------------------------------------- /matcha/models/components/transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from diffusers.models.attention import ( 6 | GEGLU, 7 | GELU, 8 | AdaLayerNorm, 9 | AdaLayerNormZero, 10 | ApproximateGELU, 11 | ) 12 | from diffusers.models.attention_processor import Attention 13 | from diffusers.models.lora import LoRACompatibleLinear 14 | from diffusers.utils.torch_utils import maybe_allow_in_graph 15 | 16 | 17 | class SnakeBeta(nn.Module): 18 | """ 19 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 20 | Shape: 21 | - Input: (B, C, T) 22 | - Output: (B, C, T), same shape as the input 23 | Parameters: 24 | - alpha - trainable parameter that controls frequency 25 | - beta - trainable parameter that controls magnitude 26 | References: 27 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 28 | https://arxiv.org/abs/2006.08195 29 | Examples: 30 | >>> a1 = snakebeta(256) 31 | >>> x = torch.randn(256) 32 | >>> x = a1(x) 33 | """ 34 | 35 | def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): 36 | """ 37 | Initialization. 38 | INPUT: 39 | - in_features: shape of the input 40 | - alpha - trainable parameter that controls frequency 41 | - beta - trainable parameter that controls magnitude 42 | alpha is initialized to 1 by default, higher values = higher-frequency. 43 | beta is initialized to 1 by default, higher values = higher-magnitude. 44 | alpha will be trained along with the rest of your model. 
45 | """ 46 | super().__init__() 47 | self.in_features = out_features if isinstance(out_features, list) else [out_features] 48 | self.proj = LoRACompatibleLinear(in_features, out_features) 49 | 50 | # initialize alpha 51 | self.alpha_logscale = alpha_logscale 52 | if self.alpha_logscale: # log scale alphas initialized to zeros 53 | self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) 54 | self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) 55 | else: # linear scale alphas initialized to ones 56 | self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) 57 | self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) 58 | 59 | self.alpha.requires_grad = alpha_trainable 60 | self.beta.requires_grad = alpha_trainable 61 | 62 | self.no_div_by_zero = 0.000000001 63 | 64 | def forward(self, x): 65 | """ 66 | Forward pass of the function. 67 | Applies the function to the input elementwise. 68 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 69 | """ 70 | x = self.proj(x) 71 | if self.alpha_logscale: 72 | alpha = torch.exp(self.alpha) 73 | beta = torch.exp(self.beta) 74 | else: 75 | alpha = self.alpha 76 | beta = self.beta 77 | 78 | x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) 79 | 80 | return x 81 | 82 | 83 | class FeedForward(nn.Module): 84 | r""" 85 | A feed-forward layer. 86 | 87 | Parameters: 88 | dim (`int`): The number of channels in the input. 89 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 90 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 91 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 92 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 93 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 94 | """ 95 | 96 | def __init__( 97 | self, 98 | dim: int, 99 | dim_out: Optional[int] = None, 100 | mult: int = 4, 101 | dropout: float = 0.0, 102 | activation_fn: str = "geglu", 103 | final_dropout: bool = False, 104 | ): 105 | super().__init__() 106 | inner_dim = int(dim * mult) 107 | dim_out = dim_out if dim_out is not None else dim 108 | 109 | if activation_fn == "gelu": 110 | act_fn = GELU(dim, inner_dim) 111 | if activation_fn == "gelu-approximate": 112 | act_fn = GELU(dim, inner_dim, approximate="tanh") 113 | elif activation_fn == "geglu": 114 | act_fn = GEGLU(dim, inner_dim) 115 | elif activation_fn == "geglu-approximate": 116 | act_fn = ApproximateGELU(dim, inner_dim) 117 | elif activation_fn == "snakebeta": 118 | act_fn = SnakeBeta(dim, inner_dim) 119 | 120 | self.net = nn.ModuleList([]) 121 | # project in 122 | self.net.append(act_fn) 123 | # project dropout 124 | self.net.append(nn.Dropout(dropout)) 125 | # project out 126 | self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) 127 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 128 | if final_dropout: 129 | self.net.append(nn.Dropout(dropout)) 130 | 131 | def forward(self, hidden_states): 132 | for module in self.net: 133 | hidden_states = module(hidden_states) 134 | return hidden_states 135 | 136 | 137 | @maybe_allow_in_graph 138 | class BasicTransformerBlock(nn.Module): 139 | r""" 140 | A basic Transformer block. 141 | 142 | Parameters: 143 | dim (`int`): The number of channels in the input and output. 144 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 
145 | attention_head_dim (`int`): The number of channels in each head. 146 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 147 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 148 | only_cross_attention (`bool`, *optional*): 149 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 150 | double_self_attention (`bool`, *optional*): 151 | Whether to use two self-attention layers. In this case no cross attention layers are used. 152 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 153 | num_embeds_ada_norm (: 154 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 155 | attention_bias (: 156 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 157 | """ 158 | 159 | def __init__( 160 | self, 161 | dim: int, 162 | num_attention_heads: int, 163 | attention_head_dim: int, 164 | dropout=0.0, 165 | cross_attention_dim: Optional[int] = None, 166 | activation_fn: str = "geglu", 167 | num_embeds_ada_norm: Optional[int] = None, 168 | attention_bias: bool = False, 169 | only_cross_attention: bool = False, 170 | double_self_attention: bool = False, 171 | upcast_attention: bool = False, 172 | norm_elementwise_affine: bool = True, 173 | norm_type: str = "layer_norm", 174 | final_dropout: bool = False, 175 | ): 176 | super().__init__() 177 | self.only_cross_attention = only_cross_attention 178 | 179 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 180 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 181 | 182 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 183 | raise ValueError( 184 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 185 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 186 | ) 187 | 188 | # Define 3 blocks. Each block has its own normalization layer. 189 | # 1. Self-Attn 190 | if self.use_ada_layer_norm: 191 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 192 | elif self.use_ada_layer_norm_zero: 193 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 194 | else: 195 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 196 | self.attn1 = Attention( 197 | query_dim=dim, 198 | heads=num_attention_heads, 199 | dim_head=attention_head_dim, 200 | dropout=dropout, 201 | bias=attention_bias, 202 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 203 | upcast_attention=upcast_attention, 204 | ) 205 | 206 | # 2. Cross-Attn 207 | if cross_attention_dim is not None or double_self_attention: 208 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 209 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 210 | # the second cross attention block. 
211 | self.norm2 = ( 212 | AdaLayerNorm(dim, num_embeds_ada_norm) 213 | if self.use_ada_layer_norm 214 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 215 | ) 216 | self.attn2 = Attention( 217 | query_dim=dim, 218 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 219 | heads=num_attention_heads, 220 | dim_head=attention_head_dim, 221 | dropout=dropout, 222 | bias=attention_bias, 223 | upcast_attention=upcast_attention, 224 | # scale_qk=False, # uncomment this to not to use flash attention 225 | ) # is self-attn if encoder_hidden_states is none 226 | else: 227 | self.norm2 = None 228 | self.attn2 = None 229 | 230 | # 3. Feed-forward 231 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 232 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) 233 | 234 | # let chunk size default to None 235 | self._chunk_size = None 236 | self._chunk_dim = 0 237 | 238 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): 239 | # Sets chunk feed-forward 240 | self._chunk_size = chunk_size 241 | self._chunk_dim = dim 242 | 243 | def forward( 244 | self, 245 | hidden_states: torch.FloatTensor, 246 | attention_mask: Optional[torch.FloatTensor] = None, 247 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 248 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 249 | timestep: Optional[torch.LongTensor] = None, 250 | cross_attention_kwargs: Dict[str, Any] = None, 251 | class_labels: Optional[torch.LongTensor] = None, 252 | ): 253 | # Notice that normalization is always applied before the real computation in the following blocks. 254 | # 1. Self-Attention 255 | if self.use_ada_layer_norm: 256 | norm_hidden_states = self.norm1(hidden_states, timestep) 257 | elif self.use_ada_layer_norm_zero: 258 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 259 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 260 | ) 261 | else: 262 | norm_hidden_states = self.norm1(hidden_states) 263 | 264 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} 265 | 266 | attn_output = self.attn1( 267 | norm_hidden_states, 268 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 269 | attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, 270 | **cross_attention_kwargs, 271 | ) 272 | if self.use_ada_layer_norm_zero: 273 | attn_output = gate_msa.unsqueeze(1) * attn_output 274 | hidden_states = attn_output + hidden_states 275 | 276 | # 2. Cross-Attention 277 | if self.attn2 is not None: 278 | norm_hidden_states = ( 279 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 280 | ) 281 | 282 | attn_output = self.attn2( 283 | norm_hidden_states, 284 | encoder_hidden_states=encoder_hidden_states, 285 | attention_mask=encoder_attention_mask, 286 | **cross_attention_kwargs, 287 | ) 288 | hidden_states = attn_output + hidden_states 289 | 290 | # 3. 
Feed-forward 291 | norm_hidden_states = self.norm3(hidden_states) 292 | 293 | if self.use_ada_layer_norm_zero: 294 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 295 | 296 | if self._chunk_size is not None: 297 | # "feed_forward_chunk_size" can be used to save memory 298 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: 299 | raise ValueError( 300 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 301 | ) 302 | 303 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size 304 | ff_output = torch.cat( 305 | [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], 306 | dim=self._chunk_dim, 307 | ) 308 | else: 309 | ff_output = self.ff(norm_hidden_states) 310 | 311 | if self.use_ada_layer_norm_zero: 312 | ff_output = gate_mlp.unsqueeze(1) * ff_output 313 | 314 | hidden_states = ff_output + hidden_states 315 | 316 | return hidden_states 317 | --------------------------------------------------------------------------------
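For readers skimming the dump, here is a minimal usage sketch of the `synthesise()` API defined in `matcha/models/matcha_tts.py` above. It only restates the signature and return keys shown in that file; everything else is an assumption: how the model instance is constructed or loaded, the `spk_emb_dim` value, and the expectation that `spks` is a 3-D tensor (the fork calls `spks.squeeze(2)`). The numeric defaults below are illustrative, not values taken from this repository.

```python
# Hypothetical wrapper around MatchaTTS.synthesise(); a sketch under the assumptions
# stated above, not the repository's own inference script.
import torch


@torch.inference_mode()
def tts(model, phoneme_ids: list, spk_emb_dim: int = 64,
        n_timesteps: int = 10, temperature: float = 0.667, length_scale: float = 1.0):
    x = torch.tensor(phoneme_ids, dtype=torch.long).unsqueeze(0)  # (1, max_text_length)
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long)     # (1,)
    spks = torch.zeros(1, spk_emb_dim, 1)                         # assumed (batch, spk_emb_dim, 1)
    out = model.synthesise(
        x, x_lengths,
        n_timesteps=n_timesteps,    # ODE solver steps in the CFM decoder
        temperature=temperature,    # variance of the terminal distribution
        spks=spks,
        length_scale=length_scale,  # >1.0 slows the generated speech down
    )
    # `out` also carries "encoder_outputs", "decoder_outputs", "attn",
    # "mel_lengths" and "rtf", per the return dict in matcha_tts.py.
    return out["hifigan_out"], out["mel"]
```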
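The SnakeBeta docstring in `matcha/models/components/transformer.py` gives the activation as SnakeBeta(x) = x + (1/b) * sin^2(a * x). The sketch below isolates just that elementwise formula to make the math concrete; the actual module additionally applies a `LoRACompatibleLinear` projection and exponentiates `alpha`/`beta` when `alpha_logscale` is set, which is omitted here.

```python
# Standalone illustration of the SnakeBeta activation formula (projection omitted).
import torch


def snake_beta(x: torch.Tensor, alpha: float = 1.0, beta: float = 1.0,
               eps: float = 1e-9) -> torch.Tensor:
    # SnakeBeta(x) = x + (1 / beta) * sin^2(alpha * x); eps guards against division by zero
    return x + (1.0 / (beta + eps)) * torch.sin(alpha * x) ** 2


x = torch.linspace(-3.0, 3.0, 7)
print(snake_beta(x))  # adds a non-negative periodic bump on top of the identity
```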