├── data ├── .gitkeep └── LJSpeech-1.1 │ ├── energy │ └── .gitkeep │ ├── mel │ └── .gitkeep │ ├── pitch │ └── .gitkeep │ ├── wavs │ └── .gitkeep │ ├── durations │ └── .gitkeep │ ├── stats.json │ └── val.txt ├── fs2 ├── __init__.py ├── VERSION ├── data │ ├── __init__.py │ ├── components │ │ └── __init__.py │ └── text_mel_datamodule.py ├── hifigan │ ├── __init__.py │ ├── env.py │ ├── config.py │ ├── LICENSE │ ├── xutils.py │ ├── denoiser.py │ ├── README.md │ ├── meldataset.py │ └── models.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── postnet.py │ │ ├── flow_matching.py │ │ ├── variance_adaptor.py │ │ └── transformer.py │ ├── fastspeech2.py │ └── baselightningmodule.py ├── onnx │ ├── __init__.py │ ├── export.py │ └── infer.py ├── utils │ ├── monotonic_align │ │ ├── setup.py │ │ ├── __init__.py │ │ └── core.pyx │ ├── __init__.py │ ├── pylogger.py │ ├── logging_utils.py │ ├── instantiators.py │ ├── audio.py │ ├── rich_utils.py │ ├── model.py │ ├── preprocess.py │ └── utils.py ├── text │ ├── symbols.py │ ├── __init__.py │ ├── numbers.py │ └── cleaners.py └── train.py ├── notebooks └── .gitkeep ├── configs ├── local │ └── .gitkeep ├── callbacks │ ├── none.yaml │ ├── default.yaml │ ├── rich_progress_bar.yaml │ ├── model_summary.yaml │ └── model_checkpoint.yaml ├── trainer │ ├── cpu.yaml │ ├── gpu.yaml │ ├── mps.yaml │ ├── ddp_sim.yaml │ ├── ddp.yaml │ └── default.yaml ├── model │ ├── postnet │ │ └── default.yaml │ ├── optimizer │ │ └── adam.yaml │ ├── variance_adaptor │ │ └── default.yaml │ ├── encoder │ │ └── default.yaml │ ├── decoder │ │ └── default.yaml │ └── fastspeech2.yaml ├── __init__.py ├── debug │ ├── fdr.yaml │ ├── overfit.yaml │ ├── limit.yaml │ ├── profiler.yaml │ └── default.yaml ├── logger │ ├── many_loggers.yaml │ ├── csv.yaml │ ├── tensorboard.yaml │ ├── neptune.yaml │ ├── mlflow.yaml │ ├── comet.yaml │ ├── wandb.yaml │ └── aim.yaml ├── extras │ └── default.yaml ├── eval.yaml ├── experiment │ ├── multispeaker.yaml │ ├── ljspeech_min_memory.yaml │ ├── fs2_ryan_det.yaml │ ├── ljspeech.yaml │ └── hifi_dataset_piper_phonemizer.yaml ├── data │ ├── vctk.yaml │ ├── hi-fi_en-US_female.yaml │ ├── ryan.yaml │ └── ljspeech.yaml ├── hydra │ └── default.yaml ├── paths │ └── default.yaml ├── train.yaml └── hparams_search │ └── mnist_optuna.yaml ├── .project-root ├── scripts ├── preprocess_datasets.sh └── schedule.sh ├── .env.example ├── MANIFEST.in ├── .github ├── codecov.yml ├── dependabot.yml ├── PULL_REQUEST_TEMPLATE.md └── release-drafter.yml ├── requirements.txt ├── LICENSE ├── pyproject.toml ├── Makefile ├── setup.py ├── .pre-commit-config.yaml ├── .gitignore └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fs2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fs2/VERSION: -------------------------------------------------------------------------------- 1 | 0.0.1 2 | 
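Note on the layout above: the data/LJSpeech-1.1 subfolders (mel, pitch, energy, durations) only contain .gitkeep placeholders and are meant to be filled by the preprocessing step. A minimal sketch of going from this layout to a first training run, assuming the LJSpeech corpus has already been placed under data/LJSpeech-1.1; the editable install is an assumption based on setup.py, and the preprocessing output locations are inferred from the folder names rather than documented here:

pip install -e .
python fs2/utils/preprocess.py -i ljspeech   # as in scripts/preprocess_datasets.sh; expected to populate mel/, pitch/, energy/, durations/ and stats.json
python fs2/train.py experiment=ljspeech      # uses configs/experiment/ljspeech.yaml
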
-------------------------------------------------------------------------------- /fs2/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fs2/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fs2/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fs2/onnx/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/LJSpeech-1.1/energy/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/LJSpeech-1.1/mel/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/LJSpeech-1.1/pitch/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/LJSpeech-1.1/wavs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fs2/data/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fs2/models/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/LJSpeech-1.1/durations/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /configs/model/postnet/default.yaml: 
-------------------------------------------------------------------------------- 1 | n_channels: 512 2 | kernel_size: 5 3 | n_layers: 5 4 | dropout: 0.5 -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /configs/model/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | _partial_: true 3 | lr: 1e-4 4 | weight_decay: 0.0 5 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - model_summary.yaml 4 | - rich_progress_bar.yaml 5 | - _self_ 6 | -------------------------------------------------------------------------------- /scripts/preprocess_datasets.sh: -------------------------------------------------------------------------------- 1 | python fs2/utils/preprocess.py -i ljspeech 2 | python fs2/utils/preprocess.py -i ryan 3 | python fs2/utils/preprocess.py -i tsg2 -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: [0,1] 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python fs2/train.py 6 | python fs2/train.py -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/model/variance_adaptor/default.yaml: -------------------------------------------------------------------------------- 1 | duration_prediction_type: det 2 | d_model: ${model.encoder.d_model} 3 | n_bins: 256 4 | hidden_dim: 256 5 | kernel_size: 3 6 | dropout: 0.5 7 | 8 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet 5 | - csv 6 | # - mlflow 7 | # - neptune 8 | - tensorboard 9 | - wandb 10 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: 
-------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /configs/model/encoder/default.yaml: -------------------------------------------------------------------------------- 1 | encoder_type: FFTransformer 2 | n_layer: 4 3 | n_head: 2 4 | d_model: 256 5 | d_head: 128 6 | d_inner: 1024 7 | kernel_size: [9, 1] 8 | dropout: 0.1 9 | dropatt: 0.1 10 | dropemb: 0.0 -------------------------------------------------------------------------------- /configs/model/decoder/default.yaml: -------------------------------------------------------------------------------- 1 | decoder_type: FFTransformer 2 | n_layer: 4 3 | n_head: 2 4 | d_model: 256 5 | d_head: 128 6 | d_inner: 1024 7 | kernel_size: [9, 9] 8 | dropout: 0.1 9 | dropatt: 0.1 10 | dropemb: 0.0 11 | -------------------------------------------------------------------------------- /fs2/utils/monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | # from distutils.core import setup 2 | # from Cython.Build import cythonize 3 | # import numpy 4 | 5 | # setup(name='monotonic_align', 6 | # ext_modules=cythonize("core.pyx"), 7 | # include_dirs=[numpy.get_include()]) 8 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: 
lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 3 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 1 10 | # profiler: "simple" 11 | profiler: "advanced" 12 | # profiler: "pytorch" 13 | accelerator: gpu 14 | 15 | limit_train_batches: 0.02 16 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /fs2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from fs2.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from fs2.utils.logging_utils import log_hyperparameters 3 | from fs2.utils.pylogger import get_pylogger 4 | from fs2.utils.rich_utils import enforce_tags, print_config_tree 5 | from fs2.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /data/LJSpeech-1.1/stats.json: -------------------------------------------------------------------------------- 1 | { 2 | "pitch_min": 62.217014, 3 | "pitch_max": 792.962036, 4 | "pitch_mean": 211.046158, 5 | "pitch_std": 53.012085, 6 | "energy_min": 0.023226, 7 | "energy_max": 241.037918, 8 | "energy_mean": 21.821531, 9 | "energy_std": 18.17124, 10 | "mel_mean": -5.517035, 11 | "mel_std": 2.064413 12 | } -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # example of file for storing private and user specific environment variables, like keys or system paths 2 | # rename it to ".env" (excluded from version control by default) 3 | # .env is loaded by train.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | MY_VAR="/home/user/my/system/path" 7 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` 
command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE.txt 3 | include requirements.*.txt 4 | include *.cff 5 | include requirements.txt 6 | include matcha/VERSION 7 | recursive-include matcha *.json 8 | recursive-include matcha *.html 9 | recursive-include matcha *.png 10 | recursive-include matcha *.md 11 | recursive-include matcha *.py 12 | recursive-include matcha *.pyx 13 | recursive-exclude tests * 14 | prune tests* 15 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: mnist # choose datamodule with `test_dataloader()` for evaluation 6 | - model: mnist 7 | - logger: null 8 | - trainer: default 9 | - paths: default 10 | - extras: default 11 | - hydra: default 12 | 13 | task_name: "eval" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for evaluation 18 | ckpt_path: ??? 19 | -------------------------------------------------------------------------------- /configs/experiment/multispeaker.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: vctk.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["multispeaker"] 13 | 14 | run_name: multispeaker 15 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/model/fastspeech2.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - encoder: default.yaml 4 | - variance_adaptor: default.yaml 5 | - decoder: default.yaml 6 | - postnet: default.yaml 7 | - optimizer: adam.yaml 8 | 9 | _target_: fs2.models.fastspeech2.FastSpeech2 10 | n_vocab: 178 11 | n_spks: ${data.n_spks} 12 | spk_emb_dim: ${model.encoder.d_model} 13 | n_feats: 80 14 | data_statistics: ${data.data_statistics} 15 | add_postnet: true 16 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | # measures overall project coverage 4 | project: 5 | default: 6 | threshold: 100% # how much decrease in coverage is needed to not consider success 7 | 8 | # measures PR or single commit coverage 9 | 
patch: 10 | default: 11 | threshold: 100% # how much decrease in coverage is needed to not consider success 12 | 13 | 14 | # project: off 15 | # patch: off 16 | -------------------------------------------------------------------------------- /configs/data/vctk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ljspeech 3 | - _self_ 4 | 5 | _target_: matcha.data.text_mel_datamodule.TextMelDataModule 6 | name: vctk 7 | train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt 8 | valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt 9 | batch_size: 32 10 | add_blank: True 11 | n_spks: 109 12 | data_statistics: # Computed for vctk dataset 13 | mel_mean: -6.630575 14 | mel_std: 2.482914 15 | -------------------------------------------------------------------------------- /configs/experiment/ljspeech_min_memory.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech_min 15 | 16 | 17 | model: 18 | out_size: 172 19 | -------------------------------------------------------------------------------- /configs/experiment/fs2_ryan_det.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: ryan.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ryan"] 13 | 14 | run_name: fs2_ryan_det 15 | 16 | 17 | trainer: 18 | max_epochs: 3001 19 | max_steps: 200000 20 | -------------------------------------------------------------------------------- /configs/experiment/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: LJSpeech 15 | 16 | 17 | trainer: 18 | max_steps: 2000000 19 | check_val_every_n_epoch: 5 20 | -------------------------------------------------------------------------------- /fs2/hifigan/env.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import os 4 | import shutil 5 | 6 | 7 | class AttrDict(dict): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.__dict__ = self 11 | 12 | 13 | def build_env(config, config_name, path): 14 | t_path = os.path.join(path, config_name) 15 | if config != t_path: 16 | os.makedirs(path, exist_ok=True) 17 | shutil.copyfile(config, os.path.join(path, config_name)) 18 | -------------------------------------------------------------------------------- /configs/experiment/hifi_dataset_piper_phonemizer.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=hifi_dataset_piper_phonemizer 5 | 6 | defaults: 7 | - override /data: hi-fi_en-US_female.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"] 13 | 14 | run_name: hi-fi_en-US_female_piper_phonemizer 15 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | max_epochs: -1 6 | 7 | accelerator: gpu 8 | devices: [0] 9 | 10 | # mixed precision for extra speed-up 11 | precision: 16-mixed 12 | 13 | # perform a validation loop every N training epochs 14 | check_val_every_n_epoch: 1 15 | 16 | # set True to ensure deterministic results 17 | # makes training slower but gives more reproducibility than just setting seeds 18 | deterministic: False 19 | 20 | gradient_clip_val: 5.0 21 | -------------------------------------------------------------------------------- /configs/data/hi-fi_en-US_female.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ljspeech 3 | - _self_ 4 | 5 | # Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/ 6 | _target_: matcha.data.text_mel_datamodule.TextMelDataModule 7 | name: hi-fi_en-US_female 8 | train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt 9 | valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt 10 | batch_size: 32 11 | cleaners: [english_cleaners_piper] 12 | data_statistics: # Computed for this dataset 13 | mel_mean: -6.38385 14 | mel_std: 2.541796 15 | -------------------------------------------------------------------------------- /configs/data/ryan.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ljspeech 3 | - _self_ 4 | 5 | name: ryan 6 | train_filelist_path: data/filelists/ryan_train.csv 7 | valid_filelist_path: data/filelists/ryan_val.csv 8 | data_statistics: 9 | pitch_min: 58.112518 10 | pitch_max: 795.156067 11 | pitch_mean: 148.475662 12 | pitch_std: 42.28849 13 | energy_min: 0.01115 14 | energy_max: 141.908127 15 | energy_mean: 34.910458 16 | energy_std: 24.780809 17 | mel_mean: -4.715792 18 | mel_std: 2.124477 19 | processed_folder_path: data/processed_data/ryan # It should match the dataset name -------------------------------------------------------------------------------- /fs2/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Defines the set of symbols used in text input to the model.
4 | """ 5 | _pad = "_" 6 | _punctuation = ';:,.!?¡¿—…"«»“” ' 7 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 8 | _letters_ipa = ( 9 | "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | ) 11 | 12 | 13 | # Export all symbols: 14 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 15 | 16 | # Special symbol ids 17 | SPACE_ID = symbols.index(" ") 18 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "lightning-hydra-template" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log 20 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /fs2/utils/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from fs2.utils.monotonic_align.core import maximum_path_c 5 | 6 | 7 | def maximum_path(value, mask): 8 | """Cython optimised version. 
9 | value: [b, t_x, t_y] 10 | mask: [b, t_x, t_y] 11 | """ 12 | value = value * mask 13 | device = value.device 14 | dtype = value.dtype 15 | value = value.data.cpu().numpy().astype(np.float32) 16 | path = np.zeros_like(value).astype(np.int32) 17 | mask = mask.data.cpu().numpy() 18 | 19 | t_x_max = mask.sum(1)[:, 0].astype(np.int32) 20 | t_y_max = mask.sum(2)[:, 0].astype(np.int32) 21 | maximum_path_c(path, value, t_x_max, t_y_max) 22 | return torch.from_numpy(path).to(device=device, dtype=dtype) 23 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | target-branch: "dev" 11 | schedule: 12 | interval: "daily" 13 | ignore: 14 | - dependency-name: "pytorch-lightning" 15 | update-types: ["version-update:semver-patch"] 16 | - dependency-name: "torchmetrics" 17 | update-types: ["version-update:semver-patch"] 18 | -------------------------------------------------------------------------------- /fs2/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name: str = __name__) -> logging.Logger: 7 | """Initializes a multi-GPU-friendly python command line logger. 8 | 9 | :param name: The name of the logger, defaults to ``__name__``. 10 | 11 | :return: A logger object. 
12 | """ 13 | logger = logging.getLogger(name) 14 | 15 | # this ensures all logging levels get marked with the rank zero decorator 16 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 17 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 18 | for level in logging_levels: 19 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 20 | 21 | return logger 22 | -------------------------------------------------------------------------------- /configs/data/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | _target_: fs2.data.text_mel_datamodule.TextMelDataModule 2 | name: ljspeech 3 | train_filelist_path: data/LJSpeech-1.1/train.txt 4 | valid_filelist_path: data/LJSpeech-1.1/val.txt 5 | batch_size: 256 6 | num_workers: 20 7 | pin_memory: True 8 | cleaners: [english_cleaners2] 9 | add_blank: True 10 | n_spks: 1 11 | n_fft: 1024 12 | n_feats: 80 13 | sample_rate: 22050 14 | hop_length: 256 15 | win_length: 1024 16 | f_min: 0 17 | f_max: 8000 18 | data_statistics: # Computed for ljspeech dataset 19 | pitch_min: 67.836174 20 | pitch_max: 792.962036 21 | pitch_mean: 211.046158 22 | pitch_std: 53.012085 23 | energy_min: 0.023226 24 | energy_max: 241.037918 25 | energy_mean: 21.821531 26 | energy_std: 18.17124 27 | mel_mean: -5.517035 28 | mel_std: 2.064413 29 | seed: ${seed} 30 | generate_properties: false -------------------------------------------------------------------------------- /fs2/hifigan/config.py: -------------------------------------------------------------------------------- 1 | v1 = { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0004, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | "upsample_rates": [8, 8, 2, 2], 11 | "upsample_kernel_sizes": [16, 16, 4, 4], 12 | "upsample_initial_channel": 512, 13 | "resblock_kernel_sizes": [3, 7, 11], 14 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 15 | "resblock_initial_channel": 256, 16 | "segment_size": 8192, 17 | "num_mels": 80, 18 | "num_freq": 1025, 19 | "n_fft": 1024, 20 | "hop_size": 256, 21 | "win_size": 1024, 22 | "sampling_rate": 22050, 23 | "fmin": 0, 24 | "fmax": 8000, 25 | "fmax_loss": None, 26 | "num_workers": 4, 27 | "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, 28 | } 29 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | 9 | 10 | Fixes #\ 11 | 12 | ## Before submitting 13 | 14 | - [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? 15 | - [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? 16 | - [ ] Did you list all the **breaking changes** introduced by this pull request? 17 | - [ ] Did you **test your PR locally** with `pytest` command? 18 | - [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command? 19 | 20 | ## Did you have fun? 
21 | 22 | Make sure you had fun coding 🙃 23 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "v$RESOLVED_VERSION" 2 | tag-template: "v$RESOLVED_VERSION" 3 | 4 | categories: 5 | - title: "🚀 Features" 6 | labels: 7 | - "feature" 8 | - "enhancement" 9 | - title: "🐛 Bug Fixes" 10 | labels: 11 | - "fix" 12 | - "bugfix" 13 | - "bug" 14 | - title: "🧹 Maintenance" 15 | labels: 16 | - "maintenance" 17 | - "dependencies" 18 | - "refactoring" 19 | - "cosmetic" 20 | - "chore" 21 | - title: "📝️ Documentation" 22 | labels: 23 | - "documentation" 24 | - "docs" 25 | 26 | change-template: "- $TITLE @$AUTHOR (#$NUMBER)" 27 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions 28 | 29 | version-resolver: 30 | major: 31 | labels: 32 | - "major" 33 | minor: 34 | labels: 35 | - "minor" 36 | patch: 37 | labels: 38 | - "patch" 39 | default: patch 40 | 41 | template: | 42 | ## Changes 43 | 44 | $CHANGES 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch>=2.0.0 3 | torchvision>=0.15.0 4 | lightning>=2.0.0 5 | torchmetrics>=0.11.4 6 | 7 | # --------- hydra --------- # 8 | hydra-core==1.3.2 9 | hydra-colorlog==1.2.0 10 | hydra-optuna-sweeper==1.2.0 11 | 12 | # --------- loggers --------- # 13 | # wandb 14 | # neptune-client 15 | # mlflow 16 | # comet-ml 17 | # aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 18 | 19 | # --------- others --------- # 20 | rootutils # standardizing the project root setup 21 | pre-commit # hooks for applying linters on commit 22 | rich # beautiful text formatting in terminal 23 | pytest # tests 24 | # sh # for running bash commands in some tests (linux/macos only) 25 | phonemizer # phonemization of text 26 | tensorboard 27 | librosa 28 | Cython 29 | numpy==1.26.4 30 | einops 31 | inflect 32 | Unidecode 33 | scipy 34 | torchaudio 35 | matplotlib 36 | pandas 37 | notebook 38 | ipywidgets 39 | gradio==3.43.2 40 | gdown 41 | wget 42 | seaborn 43 | pyworld -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | # callbacks: null 11 | # logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | 
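The debug configs above are not active by default: the top-level train config (configs/train.yaml, later in this listing) sets debug: null, and a debug profile is switched on as a Hydra command-line override. A short sketch of how that looks; the run_name values are illustrative placeholders, needed only because train.yaml marks run_name as required:

python fs2/train.py debug=default run_name=debug_run     # 1 full epoch on CPU, single worker
python fs2/train.py debug=fdr run_name=fdr_run           # fast_dev_run: 1 train, 1 val, 1 test step
python fs2/train.py debug=overfit run_name=overfit_run   # overfit 3 batches, callbacks disabled
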
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shivam Mehta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /fs2/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "cython==0.29.35", "numpy==1.24.3", "packaging"] 3 | 4 | [tool.black] 5 | line-length = 120 6 | target-version = ['py310'] 7 | exclude = ''' 8 | 9 | ( 10 | /( 11 | \.eggs # exclude a few common directories in the 12 | | \.git # root of the project 13 | | \.hg 14 | | \.mypy_cache 15 | | \.tox 16 | | \.venv 17 | | _build 18 | | buck-out 19 | | build 20 | | dist 21 | )/ 22 | | foo.py # also separately exclude a file named foo.py in 23 | # the root of the project 24 | ) 25 | ''' 26 | 27 | [tool.pytest.ini_options] 28 | addopts = [ 29 | "--color=yes", 30 | "--durations=0", 31 | "--strict-markers", 32 | "--doctest-modules", 33 | ] 34 | filterwarnings = [ 35 | "ignore::DeprecationWarning", 36 | "ignore::UserWarning", 37 | ] 38 | log_cli = "True" 39 | markers = [ 40 | "slow: slow tests", 41 | ] 42 | minversion = "6.0" 43 | testpaths = "tests/" 44 | 45 | [tool.coverage.report] 46 | exclude_lines = [ 47 | "pragma: nocover", 48 | "raise NotImplementedError", 49 | "raise NotImplementedError()", 50 | "if __name__ == .__main__.:", 51 | ] 52 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: ${paths.output_dir}/checkpoints # directory to save the model file 6 | filename: checkpoint_{epoch:03d} # checkpoint filename 7 | monitor: epoch # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 10 # save k best models (determined by above metric) 11 | mode: "max" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: 100 # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . 
| grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | create-package: ## Create wheel and tar gz 17 | rm -rf dist/ 18 | python setup.py bdist_wheel --plat-name=manylinux1_x86_64 19 | python setup.py sdist 20 | python -m twine upload dist/* --verbose --skip-existing 21 | 22 | format: ## Run pre-commit hooks 23 | pre-commit run -a 24 | 25 | sync: ## Merge changes from main branch to your current branch 26 | git pull 27 | git pull origin main 28 | 29 | test: ## Run not slow tests 30 | pytest -k "not slow" 31 | 32 | test-full: ## Run all tests 33 | pytest 34 | 35 | train-ljspeech: ## Train the model 36 | python matcha/train.py experiment=ljspeech 37 | 38 | train-ljspeech-min: ## Train the model with minimum memory 39 | python matcha/train.py experiment=ljspeech_min_memory 40 | 41 | start_app: ## Start the app 42 | python matcha/app.py 43 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | 4 | import numpy 5 | from setuptools import find_packages, setup 6 | 7 | with open("README.md", encoding="utf-8") as readme_file: 8 | README = readme_file.read() 9 | 10 | cwd = os.path.dirname(os.path.abspath(__file__)) 11 | with open(os.path.join(cwd, "fs2", "VERSION")) as fin: 12 | version = fin.read().strip() 13 | 14 | setup( 15 | name="fs2", 16 | version=version, 17 | description="I got pissed at all the other implementations for not working properly, so I made my own.", 18 | long_description=README, 19 | long_description_content_type="text/markdown", 20 | author="Shivam Mehta", 21 | author_email="shivam.mehta25@gmail.com", 22 | url="https://shivammehta25.github.io/Matcha-TTS", 23 | install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))], 24 | include_dirs=[numpy.get_include()], 25 | include_package_data=True, 26 | packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]), 27 | # use this to customize global commands available in the terminal after installing the package 28 | entry_points={ 29 | "console_scripts": [ 30 | "betterfs2=fs2.cli:cli", 31 | # "betterfs2-tts-app=fs2.app:main", 32 | ] 33 | }, 34 | python_requires=">=3.9.0", 35 | ) 36 | -------------------------------------------------------------------------------- /configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group runs under experiment name 15 | experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 
22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /fs2/models/components/postnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Postnet(nn.Module): 7 | def __init__(self, in_dim, n_channels, kernel_size, n_layers, dropout): 8 | super(Postnet, self).__init__() 9 | self.convolutions = nn.ModuleList() 10 | assert kernel_size % 2 == 1 11 | for i in range(n_layers): 12 | cur_layers = ( 13 | [ 14 | nn.Conv1d( 15 | in_dim if i == 0 else n_channels, 16 | n_channels if i < n_layers - 1 else in_dim, 17 | kernel_size=kernel_size, 18 | padding=((kernel_size - 1) // 2), 19 | ), 20 | nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim), 21 | ] 22 | + ([nn.Tanh()] if i < n_layers - 1 else []) 23 | + [nn.Dropout(dropout)] 24 | ) 25 | nn.init.xavier_uniform_( 26 | cur_layers[0].weight, 27 | torch.nn.init.calculate_gain("tanh" if i < n_layers - 1 else "linear"), 28 | ) 29 | self.convolutions.append(nn.Sequential(*cur_layers)) 30 | 31 | def forward(self, x): 32 | x = x.transpose(1, 2) # B x T x C -> B x C x T 33 | for conv in self.convolutions: 34 | x = conv(x) 35 | return x.transpose(1, 2) -------------------------------------------------------------------------------- /fs2/utils/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | cimport cython 4 | cimport numpy as np 5 | 6 | from cython.parallel import prange 7 | 8 | 9 | @cython.boundscheck(False) 10 | @cython.wraparound(False) 11 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: 12 | cdef int x 13 | cdef int y 14 | cdef float v_prev 15 | cdef float v_cur 16 | cdef float tmp 17 | cdef int index = t_x - 1 18 | 19 | for y in range(t_y): 20 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 21 | if x == y: 22 | v_cur = max_neg_val 23 | else: 24 | v_cur = value[x, y-1] 25 | if x == 0: 26 | if y == 0: 27 | v_prev = 0. 
28 | else: 29 | v_prev = max_neg_val 30 | else: 31 | v_prev = value[x-1, y-1] 32 | value[x, y] = max(v_cur, v_prev) + value[x, y] 33 | 34 | for y in range(t_y - 1, -1, -1): 35 | path[index, y] = 1 36 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): 37 | index = index - 1 38 | 39 | 40 | @cython.boundscheck(False) 41 | @cython.wraparound(False) 42 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: 43 | cdef int b = values.shape[0] 44 | 45 | cdef int i 46 | for i in prange(b, nogil=True): 47 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) 48 | -------------------------------------------------------------------------------- /fs2/hifigan/xutils.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import glob 4 | import os 5 | 6 | import matplotlib 7 | import torch 8 | from torch.nn.utils import weight_norm 9 | 10 | matplotlib.use("Agg") 11 | import matplotlib.pylab as plt 12 | 13 | 14 | def plot_spectrogram(spectrogram): 15 | fig, ax = plt.subplots(figsize=(10, 2)) 16 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 17 | plt.colorbar(im, ax=ax) 18 | 19 | fig.canvas.draw() 20 | plt.close() 21 | 22 | return fig 23 | 24 | 25 | def init_weights(m, mean=0.0, std=0.01): 26 | classname = m.__class__.__name__ 27 | if classname.find("Conv") != -1: 28 | m.weight.data.normal_(mean, std) 29 | 30 | 31 | def apply_weight_norm(m): 32 | classname = m.__class__.__name__ 33 | if classname.find("Conv") != -1: 34 | weight_norm(m) 35 | 36 | 37 | def get_padding(kernel_size, dilation=1): 38 | return int((kernel_size * dilation - dilation) / 2) 39 | 40 | 41 | def load_checkpoint(filepath, device): 42 | assert os.path.isfile(filepath) 43 | print(f"Loading '{filepath}'") 44 | checkpoint_dict = torch.load(filepath, map_location=device) 45 | print("Complete.") 46 | return checkpoint_dict 47 | 48 | 49 | def save_checkpoint(filepath, obj): 50 | print(f"Saving checkpoint to {filepath}") 51 | torch.save(obj, filepath) 52 | print("Complete.") 53 | 54 | 55 | def scan_checkpoint(cp_dir, prefix): 56 | pattern = os.path.join(cp_dir, prefix + "????????") 57 | cp_list = glob.glob(pattern) 58 | if len(cp_list) == 0: 59 | return None 60 | return sorted(cp_list)[-1] 61 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.10 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.5.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | # - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-toml 16 | - id: check-case-conflict 17 | - id: check-added-large-files 18 | 19 | # python code formatting 20 | - repo: https://github.com/psf/black 21 | rev: 23.12.1 22 | hooks: 23 | - id: black 24 | args: [--line-length, "120"] 25 | 26 | # python import sorting 27 | - repo: https://github.com/PyCQA/isort 28 | rev: 5.13.2 29 | hooks: 30 | - id: isort 31 | args: ["--profile", "black", "--filter-files"] 32 | 33 | # python upgrading syntax to newer version 34 | - repo: https://github.com/asottile/pyupgrade 35 | rev: v3.15.0 36 | 
hooks: 37 | - id: pyupgrade 38 | args: [--py38-plus] 39 | 40 | # python check (PEP8), programming errors and code complexity 41 | - repo: https://github.com/PyCQA/flake8 42 | rev: 7.0.0 43 | hooks: 44 | - id: flake8 45 | args: 46 | [ 47 | "--max-line-length", "120", 48 | "--extend-ignore", 49 | "E203,E402,E501,F401,F841,RST2,RST301", 50 | "--exclude", 51 | "logs/*,data/*,matcha/hifigan/*", 52 | ] 53 | additional_dependencies: [flake8-rst-docstrings==0.3.0] 54 | 55 | # pylint 56 | - repo: https://github.com/pycqa/pylint 57 | rev: v3.0.3 58 | hooks: 59 | - id: pylint 60 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: ljspeech 8 | - model: fastspeech2 9 | - callbacks: default 10 | - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | run_name: ??? 34 | 35 | # tags to help you identify your experiments 36 | # you can overwrite this in experiment configs 37 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 38 | tags: ["dev"] 39 | 40 | # set False to skip model training 41 | train: True 42 | 43 | # evaluate on test set, using best model weights achieved during training 44 | # lightning chooses best weights based on the metric specified in checkpoint callback 45 | test: True 46 | 47 | # simply provide checkpoint path to resume training 48 | ckpt_path: null 49 | 50 | # seed for random number generators in pytorch, numpy and python.random 51 | seed: 1234 52 | -------------------------------------------------------------------------------- /fs2/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from fs2.text import cleaners 3 | from fs2.text.symbols import symbols 4 | 5 | # Mappings from symbol to numeric ID and vice versa: 6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension 8 | 9 | 10 | def text_to_sequence(text, cleaner_names): 11 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
12 | Args: 13 | text: string to convert to a sequence 14 | cleaner_names: names of the cleaner functions to run the text through 15 | Returns: 16 | List of integers corresponding to the symbols in the text 17 | """ 18 | sequence = [] 19 | 20 | clean_text = _clean_text(text, cleaner_names) 21 | for symbol in clean_text: 22 | symbol_id = _symbol_to_id[symbol] 23 | sequence += [symbol_id] 24 | return sequence 25 | 26 | 27 | def cleaned_text_to_sequence(cleaned_text): 28 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 29 | Args: 30 | text: string to convert to a sequence 31 | Returns: 32 | List of integers corresponding to the symbols in the text 33 | """ 34 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] 35 | return sequence 36 | 37 | 38 | def sequence_to_text(sequence): 39 | """Converts a sequence of IDs back to a string""" 40 | result = "" 41 | for symbol_id in sequence: 42 | s = _id_to_symbol[symbol_id] 43 | result += s 44 | return result 45 | 46 | 47 | def _clean_text(text, cleaner_names): 48 | for name in cleaner_names: 49 | cleaner = getattr(cleaners, name) 50 | if not cleaner: 51 | raise Exception("Unknown cleaner: %s" % name) 52 | text = cleaner(text) 53 | return text 54 | -------------------------------------------------------------------------------- /fs2/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from fs2.utils import pylogger 7 | 8 | log = pylogger.get_pylogger(__name__) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! 
Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) 39 | 40 | hparams["data"] = cfg["data"] 41 | hparams["trainer"] = cfg["trainer"] 42 | 43 | hparams["callbacks"] = cfg.get("callbacks") 44 | hparams["extras"] = cfg.get("extras") 45 | 46 | hparams["task_name"] = cfg.get("task_name") 47 | hparams["tags"] = cfg.get("tags") 48 | hparams["ckpt_path"] = cfg.get("ckpt_path") 49 | hparams["seed"] = cfg.get("seed") 50 | 51 | # send hparams to all loggers 52 | for logger in trainer.loggers: 53 | logger.log_hyperparams(hparams) 54 | -------------------------------------------------------------------------------- /configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | data.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /fs2/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from fs2.utils import pylogger 9 | 10 | log = pylogger.get_pylogger(__name__) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: 
DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /fs2/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | import inflect 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return f"{dollars} {dollar_unit}, {cents} {cent_unit}" 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return f"{dollars} {dollar_unit}" 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return f"{cents} {cent_unit}" 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return 
_inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") 60 | else: 61 | return _inflect.number_to_words(num, andword="") 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r"\1 pounds", text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /fs2/utils/audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | def load_wav(full_path): 11 | sampling_rate, data = read(full_path) 12 | return data, sampling_rate 13 | 14 | 15 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 17 | 18 | 19 | def dynamic_range_decompression(x, C=1): 20 | return np.exp(x) / C 21 | 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | 27 | def dynamic_range_decompression_torch(x, C=1): 28 | return torch.exp(x) / C 29 | 30 | 31 | def spectral_normalize_torch(magnitudes): 32 | output = dynamic_range_compression_torch(magnitudes) 33 | return output 34 | 35 | 36 | def spectral_de_normalize_torch(magnitudes): 37 | output = dynamic_range_decompression_torch(magnitudes) 38 | return output 39 | 40 | 41 | mel_basis = {} 42 | hann_window = {} 43 | 44 | 45 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 46 | if torch.min(y) < -1.0: 47 | print("min value is ", torch.min(y)) 48 | if torch.max(y) > 1.0: 49 | print("max value is ", torch.max(y)) 50 | 51 | global mel_basis, hann_window # pylint: disable=global-statement 52 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 53 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 54 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 55 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 56 | 57 | y = torch.nn.functional.pad( 58 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 59 | ) 60 | y = y.squeeze(1) 61 | 62 | spec = torch.view_as_real( 63 | torch.stft( 64 | y, 65 | n_fft, 66 | hop_length=hop_size, 67 | win_length=win_size, 68 | window=hann_window[str(y.device)], 69 | center=center, 70 | pad_mode="reflect", 71 | normalized=False, 72 | onesided=True, 73 | return_complex=True, 74 | ) 75 | ) 76 | 77 | magnitudes = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 78 | 79 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], magnitudes) 80 | spec = spectral_normalize_torch(spec) 81 | energy = torch.norm(magnitudes, dim=1) 82 | 83 | return spec, energy 84 | -------------------------------------------------------------------------------- /fs2/hifigan/denoiser.py: -------------------------------------------------------------------------------- 1 | # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py 2 | 3 | """Waveglow style 
denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" 4 | import torch 5 | 6 | 7 | class Denoiser(torch.nn.Module): 8 | """Removes model bias from audio produced with waveglow""" 9 | 10 | def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): 11 | super().__init__() 12 | self.filter_length = filter_length 13 | self.hop_length = int(filter_length / n_overlap) 14 | self.win_length = win_length 15 | 16 | dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device 17 | self.device = device 18 | if mode == "zeros": 19 | mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) 20 | elif mode == "normal": 21 | mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) 22 | else: 23 | raise Exception(f"Mode {mode} if not supported") 24 | 25 | def stft_fn(audio, n_fft, hop_length, win_length, window): 26 | spec = torch.stft( 27 | audio, 28 | n_fft=n_fft, 29 | hop_length=hop_length, 30 | win_length=win_length, 31 | window=window, 32 | return_complex=True, 33 | ) 34 | spec = torch.view_as_real(spec) 35 | return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) 36 | 37 | self.stft = lambda x: stft_fn( 38 | audio=x, 39 | n_fft=self.filter_length, 40 | hop_length=self.hop_length, 41 | win_length=self.win_length, 42 | window=torch.hann_window(self.win_length, device=device), 43 | ) 44 | self.istft = lambda x, y: torch.istft( 45 | torch.complex(x * torch.cos(y), x * torch.sin(y)), 46 | n_fft=self.filter_length, 47 | hop_length=self.hop_length, 48 | win_length=self.win_length, 49 | window=torch.hann_window(self.win_length, device=device), 50 | ) 51 | 52 | with torch.no_grad(): 53 | bias_audio = vocoder(mel_input).float().squeeze(0) 54 | bias_spec, _ = self.stft(bias_audio) 55 | 56 | self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) 57 | 58 | @torch.inference_mode() 59 | def forward(self, audio, strength=0.0005): 60 | audio_spec, audio_angles = self.stft(audio) 61 | audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength 62 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) 63 | audio_denoised = self.istft(audio_spec_denoised, audio_angles) 64 | return audio_denoised 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # Lightning-Hydra-Template 148 | configs/local/default.yaml 149 | /data/ 150 | /logs/ 151 | .env 152 | 153 | # Aim logging 154 | .aim 155 | 156 | # Cython complied files 157 | matcha/utils/monotonic_align/core.c 158 | 159 | # Ignoring hifigan checkpoint 160 | generator_v1 161 | g_02500000 162 | gradio_cached_examples/ 163 | synth_output/ 164 | -------------------------------------------------------------------------------- /fs2/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from fs2.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 
35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | _ = ( 48 | queue.append(field) 49 | if field in cfg 50 | else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") 51 | ) 52 | 53 | # add all the other fields to queue (not specified in `print_order`) 54 | for field in cfg: 55 | if field not in queue: 56 | queue.append(field) 57 | 58 | # generate config tree from queue 59 | for field in queue: 60 | branch = tree.add(field, style=style, guide_style=style) 61 | 62 | config_group = cfg[field] 63 | if isinstance(config_group, DictConfig): 64 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 65 | else: 66 | branch_content = str(config_group) 67 | 68 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 69 | 70 | # print config tree 71 | rich.print(tree) 72 | 73 | # save config tree to file 74 | if save_to_file: 75 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 76 | rich.print(tree, file=file) 77 | 78 | 79 | @rank_zero_only 80 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 81 | """Prompts user to input tags from command line if no tags are provided in config. 82 | 83 | :param cfg: A DictConfig composed by Hydra. 84 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 85 | """ 86 | if not cfg.get("tags"): 87 | if "id" in HydraConfig().cfg.hydra.job: 88 | raise ValueError("Specify tags before launching a multirun!") 89 | 90 | log.warning("No tags provided in config. Prompting user to input tags...") 91 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 92 | tags = [t.strip() for t in tags.split(",") if t != ""] 93 | 94 | with open_dict(cfg): 95 | cfg.tags = tags 96 | 97 | log.info(f"Tags: {cfg.tags}") 98 | 99 | if save_to_file: 100 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 101 | rich.print(cfg.tags, file=file) 102 | -------------------------------------------------------------------------------- /fs2/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Cleaners are transformations that run over the input text at both training and eval time. 4 | 5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 7 | 1. "english_cleaners" for English text 8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 11 | the symbols in symbols.py to match your data). 
12 | """ 13 | 14 | import logging 15 | import re 16 | 17 | import phonemizer 18 | from unidecode import unidecode 19 | 20 | # To avoid excessive logging we set the log level of the phonemizer package to Critical 21 | critical_logger = logging.getLogger("phonemizer") 22 | critical_logger.setLevel(logging.CRITICAL) 23 | 24 | # Intializing the phonemizer globally significantly reduces the speed 25 | # now the phonemizer is not initialising at every call 26 | # Might be less flexible, but it is much-much faster 27 | global_phonemizer = phonemizer.backend.EspeakBackend( 28 | language="en-us", 29 | preserve_punctuation=True, 30 | with_stress=True, 31 | language_switch="remove-flags", 32 | logger=critical_logger, 33 | ) 34 | 35 | 36 | # Regular expression matching whitespace: 37 | _whitespace_re = re.compile(r"\s+") 38 | 39 | # List of (regular expression, replacement) pairs for abbreviations: 40 | _abbreviations = [ 41 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) 42 | for x in [ 43 | ("mrs", "misess"), 44 | ("mr", "mister"), 45 | ("dr", "doctor"), 46 | ("st", "saint"), 47 | ("co", "company"), 48 | ("jr", "junior"), 49 | ("maj", "major"), 50 | ("gen", "general"), 51 | ("drs", "doctors"), 52 | ("rev", "reverend"), 53 | ("lt", "lieutenant"), 54 | ("hon", "honorable"), 55 | ("sgt", "sergeant"), 56 | ("capt", "captain"), 57 | ("esq", "esquire"), 58 | ("ltd", "limited"), 59 | ("col", "colonel"), 60 | ("ft", "fort"), 61 | ] 62 | ] 63 | 64 | 65 | def expand_abbreviations(text): 66 | for regex, replacement in _abbreviations: 67 | text = re.sub(regex, replacement, text) 68 | return text 69 | 70 | 71 | def lowercase(text): 72 | return text.lower() 73 | 74 | 75 | def collapse_whitespace(text): 76 | return re.sub(_whitespace_re, " ", text) 77 | 78 | 79 | def convert_to_ascii(text): 80 | return unidecode(text) 81 | 82 | 83 | def basic_cleaners(text): 84 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 85 | text = lowercase(text) 86 | text = collapse_whitespace(text) 87 | return text 88 | 89 | 90 | def transliteration_cleaners(text): 91 | """Pipeline for non-English text that transliterates to ASCII.""" 92 | text = convert_to_ascii(text) 93 | text = lowercase(text) 94 | text = collapse_whitespace(text) 95 | return text 96 | 97 | 98 | def english_cleaners2(text): 99 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 100 | text = convert_to_ascii(text) 101 | text = lowercase(text) 102 | text = expand_abbreviations(text) 103 | phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] 104 | phonemes = collapse_whitespace(phonemes) 105 | return phonemes 106 | 107 | 108 | # def english_cleaners_piper(text): 109 | # """Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress""" 110 | # text = convert_to_ascii(text) 111 | # text = lowercase(text) 112 | # text = expand_abbreviations(text) 113 | # phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0]) 114 | # phonemes = collapse_whitespace(phonemes) 115 | # return phonemes 116 | -------------------------------------------------------------------------------- /fs2/utils/model.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def sequence_mask(length, max_length=None): 9 | if max_length is None: 10 | max_length = length.max() 11 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 12 | return x.unsqueeze(0) < length.unsqueeze(1) 13 | 14 | def convert_pad_shape(pad_shape): 15 | inverted_shape = pad_shape[::-1] 16 | pad_shape = [item for sublist in inverted_shape for item in sublist] 17 | return pad_shape 18 | 19 | 20 | def generate_path(duration, mask): 21 | device = duration.device 22 | 23 | b, t_x, t_y = mask.shape 24 | cum_duration = torch.cumsum(duration, 1) 25 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 26 | 27 | cum_duration_flat = cum_duration.view(b * t_x) 28 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 29 | path = path.view(b, t_x, t_y) 30 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 31 | path = path * mask 32 | return path 33 | 34 | def invert_log_norm(data, mu, std): 35 | # log -> normalise inverse := denormalise -> exp 36 | data = denormalize(data, mu, std) 37 | return torch.exp(data) - 1.0 38 | 39 | def normalize(data, mu, std): 40 | if not isinstance(mu, (float, int)): 41 | if isinstance(mu, list): 42 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 43 | elif isinstance(mu, torch.Tensor): 44 | mu = mu.to(data.device) 45 | elif isinstance(mu, np.ndarray): 46 | mu = torch.from_numpy(mu).to(data.device) 47 | mu = mu.unsqueeze(-1) 48 | 49 | if not isinstance(std, (float, int)): 50 | if isinstance(std, list): 51 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 52 | elif isinstance(std, torch.Tensor): 53 | std = std.to(data.device) 54 | elif isinstance(std, np.ndarray): 55 | std = torch.from_numpy(std).to(data.device) 56 | std = std.unsqueeze(-1) 57 | 58 | return (data - mu) / std 59 | 60 | 61 | def denormalize(data, mu, std): 62 | if not isinstance(mu, float): 63 | if isinstance(mu, list): 64 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 65 | elif isinstance(mu, torch.Tensor): 66 | mu = mu.to(data.device) 67 | elif isinstance(mu, np.ndarray): 68 | mu = torch.from_numpy(mu).to(data.device) 69 | mu = mu.unsqueeze(-1) 70 | 71 | if not isinstance(std, float): 72 | if isinstance(std, list): 73 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 74 | elif isinstance(std, torch.Tensor): 75 | std = std.to(data.device) 76 | elif isinstance(std, np.ndarray): 77 | std = torch.from_numpy(std).to(data.device) 78 | std = std.unsqueeze(-1) 79 | 80 | return data * std + mu 81 | 82 | def expand_lengths(enc_out, durations , pace: float = 1.0): 83 | """If target=None, then predicted durations are applied""" 84 | dtype = enc_out.dtype 85 | reps = durations.float() / pace 86 | reps = (reps + 0.5).long() 87 | dec_lens = reps.sum(dim=1) 88 | 89 | max_len = dec_lens.max() 90 | reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), 
value=0.0), 91 | dim=1)[:, None, :] 92 | reps_cumsum = reps_cumsum.to(dtype) 93 | 94 | range_ = torch.arange(max_len, device=enc_out.device)[None, :, None] 95 | mult = ((reps_cumsum[:, :, :-1] <= range_) & 96 | (reps_cumsum[:, :, 1:] > range_)) 97 | mult = mult.to(dtype) 98 | enc_rep = torch.matmul(mult, enc_out) 99 | 100 | return enc_rep, dec_lens 101 | 102 | 103 | def expand_lengths_slow(x, durations): 104 | # x: B x T x C, durations: B x T 105 | out_lens = durations.sum(dim=1) 106 | max_len = out_lens.max() 107 | bsz, seq_len, dim = x.size() 108 | out = x.new_zeros((bsz, max_len, dim)) 109 | 110 | for b in range(bsz): 111 | indices = [] 112 | for t in range(seq_len): 113 | indices.extend([t] * durations[b, t].item()) 114 | indices = torch.tensor(indices, dtype=torch.long).to(x.device) 115 | out_len = out_lens[b].item() 116 | out[b, :out_len] = x[b].index_select(0, indices) 117 | 118 | return out, out_lens 119 | -------------------------------------------------------------------------------- /fs2/train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import hydra 4 | import lightning as L 5 | import rootutils 6 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 7 | from lightning.pytorch.loggers import Logger 8 | from omegaconf import DictConfig 9 | 10 | from fs2 import utils 11 | 12 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 13 | # ------------------------------------------------------------------------------------ # 14 | # the setup_root above is equivalent to: 15 | # - adding project root dir to PYTHONPATH 16 | # (so you don't need to force user to install project as a package) 17 | # (necessary before importing any local modules e.g. `from src import utils`) 18 | # - setting up PROJECT_ROOT environment variable 19 | # (which is used as a base for paths in "configs/paths/default.yaml") 20 | # (this way all filepaths are the same no matter where you run the code) 21 | # - loading environment variables from ".env" in root dir 22 | # 23 | # you can remove it if you: 24 | # 1. either install project as a package or move entry files to project root dir 25 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 26 | # 27 | # more info: https://github.com/ashleve/rootutils 28 | # ------------------------------------------------------------------------------------ # 29 | 30 | 31 | log = utils.get_pylogger(__name__) 32 | 33 | 34 | @utils.task_wrapper 35 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 36 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 37 | training. 38 | 39 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 40 | failure. Useful for multiruns, saving info about the crash, etc. 41 | 42 | :param cfg: A DictConfig configuration composed by Hydra. 43 | :return: A tuple with metrics and dict with all instantiated objects. 
44 | """ 45 | # set seed for random number generators in pytorch, numpy and python.random 46 | if cfg.get("seed"): 47 | L.seed_everything(cfg.seed, workers=True) 48 | 49 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access 50 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 51 | 52 | log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access 53 | model: LightningModule = hydra.utils.instantiate(cfg.model) 54 | 55 | log.info("Instantiating callbacks...") 56 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 57 | 58 | log.info("Instantiating loggers...") 59 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) 60 | 61 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access 62 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 63 | 64 | object_dict = { 65 | "cfg": cfg, 66 | "datamodule": datamodule, 67 | "model": model, 68 | "callbacks": callbacks, 69 | "logger": logger, 70 | "trainer": trainer, 71 | } 72 | 73 | if logger: 74 | log.info("Logging hyperparameters!") 75 | utils.log_hyperparameters(object_dict) 76 | 77 | if cfg.get("train"): 78 | log.info("Starting training!") 79 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 80 | 81 | train_metrics = trainer.callback_metrics 82 | 83 | if cfg.get("test"): 84 | log.info("Starting testing!") 85 | ckpt_path = trainer.checkpoint_callback.best_model_path 86 | if ckpt_path == "": 87 | log.warning("Best ckpt not found! Using current weights for testing...") 88 | ckpt_path = None 89 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 90 | log.info(f"Best ckpt path: {ckpt_path}") 91 | 92 | test_metrics = trainer.callback_metrics 93 | 94 | # merge train and test metrics 95 | metric_dict = {**train_metrics, **test_metrics} 96 | 97 | return metric_dict, object_dict 98 | 99 | 100 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 101 | def main(cfg: DictConfig) -> Optional[float]: 102 | """Main entry point for training. 103 | 104 | :param cfg: DictConfig configuration composed by Hydra. 105 | :return: Optional[float] with optimized metric value. 106 | """ 107 | # apply extra utilities 108 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
109 | utils.extras(cfg) 110 | 111 | # train the model 112 | metric_dict, _ = train(cfg) 113 | 114 | # safely retrieve metric value for hydra-based hyperparameter optimization 115 | metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) 116 | 117 | # return optimized metric 118 | return metric_value 119 | 120 | 121 | if __name__ == "__main__": 122 | main() # pylint: disable=no-value-for-parameter 123 | -------------------------------------------------------------------------------- /fs2/models/components/flow_matching.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from fs2.models.components.postnet import Decoder 7 | from fs2.utils.pylogger import get_pylogger 8 | 9 | log = get_pylogger(__name__) 10 | 11 | 12 | class BASECFM(torch.nn.Module, ABC): 13 | def __init__( 14 | self, 15 | n_feats, 16 | cfm_params, 17 | n_spks=1, 18 | spk_emb_dim=128, 19 | ): 20 | super().__init__() 21 | self.n_feats = n_feats 22 | self.n_spks = n_spks 23 | self.spk_emb_dim = spk_emb_dim 24 | self.solver = cfm_params.solver 25 | if hasattr(cfm_params, "sigma_min"): 26 | self.sigma_min = cfm_params.sigma_min 27 | else: 28 | self.sigma_min = 1e-4 29 | 30 | self.estimator = None 31 | 32 | @torch.inference_mode() 33 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): 34 | """Forward diffusion 35 | 36 | Args: 37 | mu (torch.Tensor): output of encoder 38 | shape: (batch_size, n_feats, mel_timesteps) 39 | mask (torch.Tensor): output_mask 40 | shape: (batch_size, 1, mel_timesteps) 41 | n_timesteps (int): number of diffusion steps 42 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 43 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 44 | shape: (batch_size, spk_emb_dim) 45 | cond: Not used but kept for future purposes 46 | 47 | Returns: 48 | sample: generated mel-spectrogram 49 | shape: (batch_size, n_feats, mel_timesteps) 50 | """ 51 | z = torch.randn_like(mu) * temperature 52 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) 53 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) 54 | 55 | def solve_euler(self, x, t_span, mu, mask, spks, cond): 56 | """ 57 | Fixed euler solver for ODEs. 58 | Args: 59 | x (torch.Tensor): random noise 60 | t_span (torch.Tensor): n_timesteps interpolated 61 | shape: (n_timesteps + 1,) 62 | mu (torch.Tensor): output of encoder 63 | shape: (batch_size, n_feats, mel_timesteps) 64 | mask (torch.Tensor): output_mask 65 | shape: (batch_size, 1, mel_timesteps) 66 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 
67 | shape: (batch_size, spk_emb_dim) 68 | cond: Not used but kept for future purposes 69 | """ 70 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 71 | 72 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 73 | # Or in future might add like a return_all_steps flag 74 | sol = [] 75 | 76 | for step in range(1, len(t_span)): 77 | dphi_dt = self.estimator(x, mask, mu, t, spks, cond) 78 | 79 | x = x + dt * dphi_dt 80 | t = t + dt 81 | sol.append(x) 82 | if step < len(t_span) - 1: 83 | dt = t_span[step + 1] - t 84 | 85 | return sol[-1] 86 | 87 | def compute_loss(self, x1, mask, mu, spks=None, cond=None): 88 | """Computes diffusion loss 89 | 90 | Args: 91 | x1 (torch.Tensor): Target 92 | shape: (batch_size, n_feats, mel_timesteps) 93 | mask (torch.Tensor): target mask 94 | shape: (batch_size, 1, mel_timesteps) 95 | mu (torch.Tensor): output of encoder 96 | shape: (batch_size, n_feats, mel_timesteps) 97 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 98 | shape: (batch_size, spk_emb_dim) 99 | 100 | Returns: 101 | loss: conditional flow matching loss 102 | y: conditional flow 103 | shape: (batch_size, n_feats, mel_timesteps) 104 | """ 105 | b, _, t = mu.shape 106 | 107 | # random timestep 108 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 109 | # sample noise p(x_0) 110 | z = torch.randn_like(x1) 111 | 112 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 113 | u = x1 - (1 - self.sigma_min) * z 114 | 115 | loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( 116 | torch.sum(mask) * u.shape[1] 117 | ) 118 | return loss, y 119 | 120 | 121 | class CFM(BASECFM): 122 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): 123 | super().__init__( 124 | n_feats=in_channels, 125 | cfm_params=cfm_params, 126 | n_spks=n_spks, 127 | spk_emb_dim=spk_emb_dim, 128 | ) 129 | 130 | in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) 131 | # Just change the architecture of the estimator here 132 | self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) 133 | -------------------------------------------------------------------------------- /fs2/models/fastspeech2.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import math 3 | import random 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from einops import rearrange 9 | 10 | from fs2 import utils 11 | from fs2.models.baselightningmodule import BaseLightningClass 12 | from fs2.models.components.postnet import Postnet 13 | from fs2.models.components.transformer import FFTransformer 14 | from fs2.models.components.variance_adaptor import VarianceAdaptor 15 | from fs2.utils.model import denormalize, invert_log_norm 16 | 17 | log = utils.get_pylogger(__name__) 18 | 19 | 20 | class FastSpeech2(BaseLightningClass): 21 | def __init__( 22 | self, 23 | n_vocab, 24 | n_spks, 25 | spk_emb_dim, 26 | n_feats, 27 | encoder, 28 | decoder, 29 | variance_adaptor, 30 | postnet, 31 | data_statistics, 32 | add_postnet=True, 33 | optimizer=None, 34 | scheduler=None, 35 | ): 36 | super().__init__() 37 | 38 | self.save_hyperparameters(logger=False) 39 | 40 | self.n_vocab = n_vocab 41 | self.n_spks = n_spks 42 | self.spk_emb_dim = spk_emb_dim 43 | self.n_feats = n_feats 44 | self.update_data_statistics(data_statistics) 45 | 46 | self.encoder = FFTransformer( 47 | 
n_layer=encoder.n_layer, 48 | n_head=encoder.n_head, 49 | d_model=encoder.d_model, 50 | d_head=encoder.d_head, 51 | d_inner=encoder.d_inner, 52 | kernel_size=encoder.kernel_size, 53 | dropout=encoder.dropout, 54 | dropatt=encoder.dropatt, 55 | dropemb=encoder.dropemb, 56 | embed_input=True, 57 | d_embed=encoder.d_model, 58 | n_embed=n_vocab 59 | ) 60 | 61 | 62 | if n_spks > 1: 63 | self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) 64 | 65 | self.variance_adapter = VarianceAdaptor(variance_adaptor, self.pitch_min, self.pitch_max, self.energy_min, self.energy_max) 66 | 67 | self.decoder = FFTransformer( 68 | n_layer=decoder.n_layer, 69 | n_head=decoder.n_head, 70 | d_model=decoder.d_model, 71 | d_head=decoder.d_head, 72 | d_inner=decoder.d_inner, 73 | kernel_size=decoder.kernel_size, 74 | dropout=decoder.dropout, 75 | dropatt=decoder.dropatt, 76 | dropemb=decoder.dropemb, 77 | embed_input=False, 78 | d_embed=decoder.d_model, 79 | ) 80 | 81 | self.out_proj = nn.Linear(decoder.d_model, n_feats) 82 | 83 | if add_postnet: 84 | self.postnet = Postnet( 85 | self.n_feats, 86 | postnet.n_channels, 87 | postnet.kernel_size, 88 | postnet.n_layers, 89 | postnet.dropout, 90 | ) 91 | else: 92 | self.postnet = None 93 | 94 | 95 | def forward(self, x, x_lengths, y, y_lengths, durations, pitches, energies, spks=None): 96 | 97 | x, x_mask = self.encoder(x, x_lengths) 98 | 99 | if self.n_spks > 1: 100 | spk_emb = self.spk_emb(spks) 101 | x = x + spk_emb.unsqueeze(1) 102 | 103 | # teacher forced durations during training 104 | outputs, losses = self.variance_adapter(x, x_mask, durations, pitches, energies) 105 | 106 | decoder_out, y_mask = self.decoder(outputs['x_upscaled'], y_lengths) 107 | 108 | y_hat = self.out_proj(decoder_out) * y_mask 109 | 110 | if self.postnet is not None: 111 | y_hat_post = y_hat + (self.postnet(y_hat) * y_mask) 112 | 113 | 114 | mel_loss = F.l1_loss(y_hat, rearrange(y, "b c t-> b t c"), reduction="mean") 115 | postnet_mel_loss = F.l1_loss(y_hat_post, rearrange(y, "b c t-> b t c"), reduction="mean") if self.postnet is not None else 0.0 116 | 117 | losses.update({"mel_loss": mel_loss, "postnet_mel_loss": postnet_mel_loss}) 118 | return losses 119 | 120 | 121 | 122 | @torch.inference_mode() 123 | def synthesise(self, x, x_lengths, spks=None, length_scale=1.0, p_factor=1.0, e_factor=1.0, d_factor=1.0): 124 | # For RTF computation 125 | t = dt.datetime.now() 126 | 127 | x, x_mask = self.encoder(x, x_lengths) 128 | 129 | if self.n_spks > 1: 130 | spk_emb = self.spk_emb(spks) 131 | x = x + spk_emb.unsqueeze(1) 132 | 133 | # teacher forced durations during training 134 | var_ada_outputs = self.variance_adapter.synthesise(x, x_mask, d_factor=length_scale, p_factor=p_factor, e_factor=e_factor) 135 | 136 | decoder_out, y_mask = self.decoder(var_ada_outputs['x_upscaled'], var_ada_outputs['out_lens']) 137 | 138 | y_hat = self.out_proj(decoder_out) * y_mask 139 | 140 | if self.postnet is not None: 141 | y_hat_post = y_hat + (self.postnet(y_hat) * y_mask) 142 | 143 | t = (dt.datetime.now() - t).total_seconds() 144 | rtf = t * 22050 / (y_hat_post.shape[1] * 256) 145 | 146 | return { 147 | "mel" : denormalize(y_hat_post, self.mel_mean, self.mel_std).transpose(1, 2), 148 | "decoder_output": denormalize(y_hat, self.mel_mean, self.mel_std).transpose(1, 2), 149 | "dur_pred": var_ada_outputs["dur_pred"], 150 | "pitch_pred": denormalize(var_ada_outputs["log_pitch_pred"], self.pitch_mean, self.pitch_std), 151 | "energy_pred": denormalize(var_ada_outputs["log_energy_pred"], self.energy_mean, 
self.energy_std), 152 | "rtf": rtf, 153 | } -------------------------------------------------------------------------------- /fs2/hifigan/README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis 2 | 3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae 4 | 5 | In our [paper](https://arxiv.org/abs/2010.05646), 6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository. 8 | 9 | **Abstract :** 10 | Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. 11 | Although such methods improve the sampling efficiency and memory usage, 12 | their sample quality has not yet reached that of autoregressive and flow-based generative models. 13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. 14 | As speech audio consists of sinusoidal signals with various periods, 15 | we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. 16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method 17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than 18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen 19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times 20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart. 21 | 22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. 23 | 24 | ## Pre-requisites 25 | 26 | 1. Python >= 3.6 27 | 2. Clone this repository. 28 | 3. Install python requirements. Please refer [requirements.txt](requirements.txt) 29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). 30 | And move all wav files to `LJSpeech-1.1/wavs` 31 | 32 | ## Training 33 | 34 | ``` 35 | python train.py --config config_v1.json 36 | ``` 37 | 38 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
39 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
40 | You can change the path by adding `--checkpoint_path` option. 41 | 42 | Validation loss during training with V1 generator.
43 | ![validation loss](./validation_loss.png) 44 | 45 | ## Pretrained Model 46 | 47 | You can also use pretrained models we provide.
48 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
49 | Details of each folder are as follows:
50 | 
51 | | Folder Name | Generator | Dataset | Fine-Tuned |
52 | | ------------ | --------- | --------- | ------------------------------------------------------ |
53 | | LJ_V1 | V1 | LJSpeech | No |
54 | | LJ_V2 | V2 | LJSpeech | No |
55 | | LJ_V3 | V3 | LJSpeech | No |
56 | | LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
57 | | LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
58 | | LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
59 | | VCTK_V1 | V1 | VCTK | No |
60 | | VCTK_V2 | V2 | VCTK | No |
61 | | VCTK_V3 | V3 | VCTK | No |
62 | | UNIVERSAL_V1 | V1 | Universal | No |
63 | 
64 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
65 | 
66 | ## Fine-Tuning
67 | 
68 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
69 | The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
70 | Example: 71 | ` Audio File : LJ001-0001.wav 72 | Mel-Spectrogram File : LJ001-0001.npy` 73 | 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
74 | 3. Run the following command. 75 | ``` 76 | python train.py --fine_tuning True --config config_v1.json 77 | ``` 78 | For other command line options, please refer to the training section. 79 | 80 | ## Inference from wav file 81 | 82 | 1. Make `test_files` directory and copy wav files into the directory. 83 | 2. Run the following command. 84 | ` python inference.py --checkpoint_file [generator checkpoint file path]` 85 | Generated wav files are saved in `generated_files` by default.
86 | You can change the path by adding `--output_dir` option. 87 | 88 | ## Inference for end-to-end speech synthesis 89 | 90 | 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
91 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), 92 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. 93 | 2. Run the following command. 94 | ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` 95 | Generated wav files are saved in `generated_files_from_mel` by default.
96 | You can change the path by adding `--output_dir` option. 97 | 98 | ## Acknowledgements 99 | 100 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) 101 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. 102 | -------------------------------------------------------------------------------- /fs2/onnx/export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from lightning import LightningModule 8 | 9 | from fs2.cli import VOCODER_URLS, load_model, load_vocoder 10 | 11 | DEFAULT_OPSET = 15 12 | 13 | SEED = 1234 14 | random.seed(SEED) 15 | np.random.seed(SEED) 16 | torch.manual_seed(SEED) 17 | torch.cuda.manual_seed(SEED) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | 21 | 22 | class MatchaWithVocoder(LightningModule): 23 | def __init__(self, matcha, vocoder): 24 | super().__init__() 25 | self.matcha = matcha 26 | self.vocoder = vocoder 27 | 28 | def forward(self, x, x_lengths, scales, spks=None): 29 | mel, mel_lengths = self.matcha(x, x_lengths, scales, spks) 30 | wavs = self.vocoder(mel).clamp(-1, 1) 31 | lengths = mel_lengths * 256 32 | return wavs.squeeze(1), lengths 33 | 34 | 35 | def get_exportable_module(matcha, vocoder, n_timesteps): 36 | """ 37 | Return an appropriate `LighteningModule` and output-node names 38 | based on whether the vocoder is embedded in the final graph 39 | """ 40 | 41 | def onnx_forward_func(x, x_lengths, scales, spks=None): 42 | """ 43 | Custom forward function for accepting 44 | scaler parameters as tensors 45 | """ 46 | # Extract scaler parameters from tensors 47 | temperature = scales[0] 48 | length_scale = scales[1] 49 | output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale) 50 | return output["mel"], output["mel_lengths"] 51 | 52 | # Monkey-patch Matcha's forward function 53 | matcha.forward = onnx_forward_func 54 | 55 | if vocoder is None: 56 | model, output_names = matcha, ["mel", "mel_lengths"] 57 | else: 58 | model = MatchaWithVocoder(matcha, vocoder) 59 | output_names = ["wav", "wav_lengths"] 60 | return model, output_names 61 | 62 | 63 | def get_inputs(is_multi_speaker): 64 | """ 65 | Create dummy inputs for tracing 66 | """ 67 | dummy_input_length = 50 68 | x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long) 69 | x_lengths = torch.LongTensor([dummy_input_length]) 70 | 71 | # Scales 72 | temperature = 0.667 73 | length_scale = 1.0 74 | scales = torch.Tensor([temperature, length_scale]) 75 | 76 | model_inputs = [x, x_lengths, scales] 77 | input_names = [ 78 | "x", 79 | "x_lengths", 80 | "scales", 81 | ] 82 | 83 | if is_multi_speaker: 84 | spks = torch.LongTensor([1]) 85 | model_inputs.append(spks) 86 | input_names.append("spks") 87 | 88 | return tuple(model_inputs), input_names 89 | 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX") 93 | 94 | parser.add_argument( 95 | "checkpoint_path", 96 | type=str, 97 | help="Path to the model checkpoint", 98 | ) 99 | parser.add_argument("output", type=str, help="Path to output `.onnx` file") 100 | parser.add_argument( 101 | "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)" 102 | ) 103 | parser.add_argument( 104 | "--vocoder-name", 105 | type=str, 106 | 
choices=list(VOCODER_URLS.keys()), 107 | default=None, 108 | help="Name of the vocoder to embed in the ONNX graph", 109 | ) 110 | parser.add_argument( 111 | "--vocoder-checkpoint-path", 112 | type=str, 113 | default=None, 114 | help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience", 115 | ) 116 | parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15") 117 | 118 | args = parser.parse_args() 119 | 120 | print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}") 121 | print(f"Setting n_timesteps to {args.n_timesteps}") 122 | 123 | checkpoint_path = Path(args.checkpoint_path) 124 | matcha = load_model(checkpoint_path.stem, checkpoint_path, "cpu") 125 | 126 | if args.vocoder_name or args.vocoder_checkpoint_path: 127 | assert ( 128 | args.vocoder_name and args.vocoder_checkpoint_path 129 | ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph." 130 | vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu") 131 | else: 132 | vocoder = None 133 | 134 | is_multi_speaker = matcha.n_spks > 1 135 | 136 | dummy_input, input_names = get_inputs(is_multi_speaker) 137 | model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps) 138 | 139 | # Set dynamic shape for inputs/outputs 140 | dynamic_axes = { 141 | "x": {0: "batch_size", 1: "time"}, 142 | "x_lengths": {0: "batch_size"}, 143 | } 144 | 145 | if vocoder is None: 146 | dynamic_axes.update( 147 | { 148 | "mel": {0: "batch_size", 2: "time"}, 149 | "mel_lengths": {0: "batch_size"}, 150 | } 151 | ) 152 | else: 153 | print("Embedding the vocoder in the ONNX graph") 154 | dynamic_axes.update( 155 | { 156 | "wav": {0: "batch_size", 1: "time"}, 157 | "wav_lengths": {0: "batch_size"}, 158 | } 159 | ) 160 | 161 | if is_multi_speaker: 162 | dynamic_axes["spks"] = {0: "batch_size"} 163 | 164 | # Create the output directory (if not exists) 165 | Path(args.output).parent.mkdir(parents=True, exist_ok=True) 166 | 167 | model.to_onnx( 168 | args.output, 169 | dummy_input, 170 | input_names=input_names, 171 | output_names=output_names, 172 | dynamic_axes=dynamic_axes, 173 | opset_version=args.opset, 174 | export_params=True, 175 | do_constant_folding=True, 176 | ) 177 | print(f"[🍵] ONNX model exported to {args.output}") 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # BetterFastSpeech 2 4 | 5 | 6 | [![python](https://img.shields.io/badge/-Python_3.10-blue?logo=python&logoColor=white)](https://www.python.org/downloads/release/python-3100/) 7 | [![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) 8 | [![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/) 9 | [![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/) 10 | [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) 11 | [![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) 12 | 13 | 14 |
15 | 16 | This is the ordinary FastSpeech 2 architecture with some modifications. I just wanted to make the code base better and more readable, and to finally have an open-source implementation of [FastSpeech 2](https://arxiv.org/abs/2006.04558) that doesn't sound bad and is easier to hack and work with. 17 | 18 | If you like this, you will love [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS). 19 | 20 | Changes from the original architecture: 21 | - Instead of using MFA, I obtained alignments from a pretrained Matcha-TTS model. 22 | - To save myself from the pain of setting up and training MFA 23 | - Used IPA phonemes with blanks in between phones. 24 | - No LR decay 25 | - Duration prediction in the log domain 26 | - Everyone seems to be using the postnet from Tacotron 2; I've used it as well. 27 | 28 | 29 | [Link to LJ Speech checkpoint](https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/betterfs2_ljspeech.ckpt) 30 | **Running the code locally with the CLI will auto-download the checkpoint as well.** 31 | 32 | ## Installation 33 | 34 | 1. Create an environment (suggested but optional) 35 | 36 | ```bash 37 | conda create -n betterfs2 python=3.10 -y 38 | conda activate betterfs2 39 | ``` 40 | 41 | 2. Install from source 42 | 43 | ```bash 44 | git clone https://github.com/shivammehta25/BetterFastSpeech2.git 45 | cd BetterFastSpeech2 46 | pip install -e . 47 | ``` 48 | 49 | 3. Run the CLI / Gradio app / Jupyter notebook 50 | 51 | ```bash 52 | # This will download the required models 53 | betterfs2 --text "<INPUT TEXT>" 54 | ``` 55 | 56 | or open `synthesis.ipynb` in a Jupyter notebook 57 | 58 | ## Train with your own dataset 59 | Let's assume we are training with LJ Speech. 60 | 61 | 1. Download the dataset from [here](https://keithito.com/LJ-Speech-Dataset/), extract it to `data/LJSpeech-1.1`, and prepare the file lists to point to the extracted data like for [item 5 in the setup of the NVIDIA Tacotron 2 repo](https://github.com/NVIDIA/tacotron2#setup). 62 | 63 | 64 | 2. [Train a Matcha-TTS model to extract durations, or use a pretrained model if you have one.](https://github.com/shivammehta25/Matcha-TTS/wiki/Improve-GPU-utilisation-by-extracting-phoneme-alignments) 65 | 66 | Your data directory should look like: 67 | ```bash 68 | data/ 69 | └── LJSpeech-1.1 70 | ├── durations/ # Here 71 | ├── metadata.csv 72 | ├── README 73 | ├── test.txt 74 | ├── train.txt 75 | ├── val.txt 76 | └── wavs/ 77 | ``` 78 | 79 | 3. Clone and enter the BetterFastSpeech2 repository 80 | 81 | ```bash 82 | git clone https://github.com/shivammehta25/BetterFastSpeech2.git 83 | cd BetterFastSpeech2 84 | ``` 85 | 86 | 4. Install the package from source 87 | 88 | ```bash 89 | pip install -e . 90 | ``` 91 | 92 | 5. Go to `configs/data/ljspeech.yaml` and change the filelist paths to point to your train and validation filelists: 93 | 94 | ```yaml 95 | train_filelist_path: data/LJSpeech-1.1/train.txt 96 | valid_filelist_path: data/LJSpeech-1.1/val.txt 97 | ``` 98 | 99 | 6. Generate normalisation statistics using the dataset's yaml configuration file 100 | 101 | ```bash 102 | python fs2/utils/preprocess.py -i ljspeech 103 | # Output: 104 | #{'pitch_min': 67.836174, 'pitch_max': 578.637146, 'pitch_mean': 207.001846, 'pitch_std': 52.747742, 'energy_min': 0.084354, 'energy_max': 190.849121, 'energy_mean': 21.330254, 'energy_std': 17.663319, 'mel_mean': -5.554245, 'mel_std': 2.059021} 105 | ``` 106 | 107 | Update these values in `configs/data/ljspeech.yaml` under the `data_statistics` key:
108 | 109 | ```yaml 110 | data_statistics: # Computed for ljspeech dataset 111 | pitch_min: 67.836174 112 | pitch_max: 792.962036 113 | pitch_mean: 211.046158 114 | pitch_std: 53.012085 115 | energy_min: 0.023226 116 | energy_max: 241.037918 117 | energy_mean: 21.821531 118 | energy_std: 18.17124 119 | mel_mean: -5.517035 120 | mel_std: 2.064413 121 | ``` 122 | 123 | 124 | 125 | 7. Run the training script 126 | 127 | ```bash 128 | python fs2/train.py experiment=ljspeech 129 | ``` 130 | 131 | - For multi-GPU training, run 132 | 133 | ```bash 134 | python fs2/train.py experiment=ljspeech trainer.devices=[0,1] 135 | ``` 136 | 137 | 8. Synthesise from the custom-trained model 138 | 139 | ```bash 140 | betterfs2 --text "<INPUT TEXT>" --checkpoint_path <PATH TO CHECKPOINT> 141 | ``` 142 | 143 | 144 | ## Citation information 145 | 146 | If you use our code or otherwise find this work useful, please cite our paper: 147 | 148 | ```text 149 | @inproceedings{mehta2024matcha, 150 | title={Matcha-{TTS}: A fast {TTS} architecture with conditional flow matching}, 151 | author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje}, 152 | booktitle={Proc. ICASSP}, 153 | year={2024} 154 | } 155 | ``` 156 | 157 | ## Acknowledgements 158 | 159 | Since this code uses [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), you have all the powers that come with it. 160 | 161 | Other source code we would like to acknowledge: 162 | 163 | - [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS): Base TTS from which we get alignments. 164 | - [FastPitch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch): For the transformer implementation. 165 | - [FastSpeech 2](https://github.com/ming024/FastSpeech2): For the variance predictor implementations. 166 | -------------------------------------------------------------------------------- /fs2/models/components/variance_adaptor.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable, Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import pack, rearrange 8 | 9 | import fs2.utils as utils 10 | from fs2.utils.model import expand_lengths 11 | 12 | log = utils.get_pylogger(__name__) 13 | 14 | 15 | class VariancePredictor(nn.Module): 16 | def __init__(self, args): 17 | super().__init__() 18 | self.conv1 = nn.Sequential( 19 | nn.Conv1d( 20 | args.d_model, 21 | args.hidden_dim, 22 | kernel_size=args.kernel_size, 23 | padding=(args.kernel_size - 1) // 2, 24 | ), 25 | nn.ReLU(), 26 | ) 27 | self.ln1 = nn.LayerNorm(args.hidden_dim) 28 | self.dropout_module = nn.Dropout( 29 | p=args.dropout 30 | ) 31 | self.conv2 = nn.Sequential( 32 | nn.Conv1d( 33 | args.hidden_dim, 34 | args.hidden_dim, 35 | kernel_size=args.kernel_size, 36 | padding=(args.kernel_size - 1) // 2, 37 | ), 38 | nn.ReLU(), 39 | ) 40 | self.ln2 = nn.LayerNorm(args.hidden_dim) 41 | self.proj = nn.Linear(args.hidden_dim, 1) 42 | 43 | def forward(self, x, mask): 44 | # Input: B x T x C; Output: B x T 45 | x = self.conv1((x * mask).transpose(1, 2)).transpose(1, 2) 46 | x = self.dropout_module(self.ln1(x)) 47 | x = self.conv2((x * mask).transpose(1, 2)).transpose(1, 2) 48 | x = self.dropout_module(self.ln2(x)) 49 | return (self.proj(x) * mask).squeeze(dim=2) 50 | 51 | class LayerNorm(nn.Module): 52 | def __init__(self, channels, eps=1e-4): 53 |
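# NOTE: unlike torch.nn.LayerNorm (which normalises over the last dimension), this LayerNorm normalises over the channel dimension (dim=1) of (B, C, T) inputs, applying the per-channel gamma/beta scale and shift defined below.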
super().__init__() 54 | self.channels = channels 55 | self.eps = eps 56 | 57 | self.gamma = torch.nn.Parameter(torch.ones(channels)) 58 | self.beta = torch.nn.Parameter(torch.zeros(channels)) 59 | 60 | def forward(self, x): 61 | n_dims = len(x.shape) 62 | mean = torch.mean(x, 1, keepdim=True) 63 | variance = torch.mean((x - mean) ** 2, 1, keepdim=True) 64 | 65 | x = (x - mean) * torch.rsqrt(variance + self.eps) 66 | 67 | shape = [1, -1] + [1] * (n_dims - 2) 68 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 69 | return x 70 | 71 | 72 | class VarianceAdaptor(nn.Module): 73 | def __init__(self, args, pitch_min, pitch_max, energy_min, energy_max): 74 | super().__init__() 75 | self.args = args 76 | self.duration_predictor = VariancePredictor(args) 77 | self.pitch_predictor = VariancePredictor(args) 78 | self.energy_predictor = VariancePredictor(args) 79 | 80 | n_bins, steps = self.args.n_bins, self.args.n_bins - 1 81 | self.pitch_bins = nn.Parameter(torch.linspace(pitch_min.item(), pitch_max.item(), steps), requires_grad=False) 82 | self.embed_pitch = nn.Embedding(n_bins, args.d_model) 83 | nn.init.normal_(self.embed_pitch.weight, mean=0, std=args.d_model**-0.5) 84 | self.energy_bins = nn.Parameter(torch.linspace(energy_min.item(), energy_max.item(), steps), requires_grad=False) 85 | self.embed_energy = nn.Embedding(n_bins, args.d_model) 86 | nn.init.normal_(self.embed_energy.weight, mean=0, std=args.d_model**-0.5) 87 | 88 | def get_pitch_emb(self, x, x_mask, tgt=None, factor=1.0): 89 | out = self.pitch_predictor(x, x_mask) 90 | if tgt is None: 91 | out = out * factor 92 | emb = self.embed_pitch(torch.bucketize(out, self.pitch_bins)) 93 | else: 94 | emb = self.embed_pitch(torch.bucketize(tgt, self.pitch_bins)) 95 | return out, emb * x_mask 96 | 97 | def get_energy_emb(self, x, x_mask, tgt=None, factor=1.0): 98 | out = self.energy_predictor(x, x_mask) 99 | if tgt is None: 100 | out = out * factor 101 | emb = self.embed_energy(torch.bucketize(out, self.energy_bins)) 102 | else: 103 | emb = self.embed_energy(torch.bucketize(tgt, self.energy_bins)) 104 | return out, emb * x_mask 105 | 106 | 107 | def forward( 108 | self, 109 | x, 110 | x_mask, 111 | durations, 112 | pitches, 113 | energies, 114 | ): 115 | # x: B x T x C 116 | # Get log durations 117 | logw = torch.log(durations + 1e-8) * x_mask.squeeze(2) 118 | 119 | logw_hat = self.duration_predictor(x, x_mask) 120 | dur_loss = F.mse_loss(logw_hat, logw, reduction="sum") / torch.sum(x_mask) 121 | 122 | 123 | log_pitch_out, pitch_emb = self.get_pitch_emb(x, x_mask, pitches) 124 | x = x + pitch_emb 125 | log_energy_out, energy_emb = self.get_energy_emb(x, x_mask, energies) 126 | x = x + energy_emb 127 | 128 | x, out_lens = expand_lengths(x, durations) 129 | 130 | pitch_loss = F.mse_loss(log_pitch_out, pitches) 131 | energy_loss = F.mse_loss(log_energy_out, energies) 132 | 133 | outputs = { 134 | 'x_upscaled': x, 135 | 'out_lens': out_lens, 136 | } 137 | losses = { 138 | 'dur_loss': dur_loss, 139 | 'pitch_loss': pitch_loss, 140 | 'energy_loss': energy_loss, 141 | } 142 | 143 | return outputs, losses 144 | 145 | @torch.inference_mode() 146 | def synthesise( 147 | self, 148 | x, 149 | x_mask, 150 | d_factor=1.0, 151 | p_factor=1.0, 152 | e_factor=1.0, 153 | ): 154 | # x: B x T x C 155 | # Get log durations 156 | 157 | logw_hat = self.duration_predictor(x, x_mask) 158 | 159 | w = torch.exp(logw_hat) * x_mask.squeeze(2) 160 | w_ceil = torch.ceil(w) * d_factor 161 | dur_out = torch.clamp(w_ceil.long(), min=0) 162 | 163 | log_pitch_out, 
pitch_emb = self.get_pitch_emb(x, x_mask, factor=p_factor) 164 | log_pitch_out, _ = expand_lengths(log_pitch_out.unsqueeze(2), dur_out) 165 | x = x + pitch_emb 166 | 167 | log_energy_out, energy_emb = self.get_energy_emb(x, x_mask, factor=e_factor) 168 | log_energy_out, _ = expand_lengths(log_energy_out.unsqueeze(2), dur_out) 169 | x = x + energy_emb 170 | 171 | x, out_lens = expand_lengths(x, dur_out) 172 | 173 | return { 174 | 'x_upscaled': x, 175 | 'out_lens': out_lens, 176 | 'dur_pred': dur_out, 177 | 'log_pitch_pred': log_pitch_out, 178 | 'log_energy_pred': log_energy_out, 179 | } -------------------------------------------------------------------------------- /fs2/onnx/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from pathlib import Path 5 | from time import perf_counter 6 | 7 | import numpy as np 8 | import onnxruntime as ort 9 | import soundfile as sf 10 | import torch 11 | 12 | from fs2.cli import plot_spectrogram_to_numpy, process_text 13 | 14 | 15 | def validate_args(args): 16 | assert ( 17 | args.text or args.file 18 | ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms." 19 | assert args.temperature >= 0, "Sampling temperature cannot be negative" 20 | assert args.speaking_rate >= 0, "Speaking rate must be greater than 0" 21 | return args 22 | 23 | 24 | def write_wavs(model, inputs, output_dir, external_vocoder=None): 25 | if external_vocoder is None: 26 | print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly") 27 | t0 = perf_counter() 28 | wavs, wav_lengths = model.run(None, inputs) 29 | infer_secs = perf_counter() - t0 30 | mel_infer_secs = vocoder_infer_secs = None 31 | else: 32 | print("[🍵] Generating mel using Matcha") 33 | mel_t0 = perf_counter() 34 | mels, mel_lengths = model.run(None, inputs) 35 | mel_infer_secs = perf_counter() - mel_t0 36 | print("Generating waveform from mel using external vocoder") 37 | vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels} 38 | vocoder_t0 = perf_counter() 39 | wavs = external_vocoder.run(None, vocoder_inputs)[0] 40 | vocoder_infer_secs = perf_counter() - vocoder_t0 41 | wavs = wavs.squeeze(1) 42 | wav_lengths = mel_lengths * 256 43 | infer_secs = mel_infer_secs + vocoder_infer_secs 44 | 45 | output_dir = Path(output_dir) 46 | output_dir.mkdir(parents=True, exist_ok=True) 47 | for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)): 48 | output_filename = output_dir.joinpath(f"output_{i + 1}.wav") 49 | audio = wav[:wav_length] 50 | print(f"Writing audio to {output_filename}") 51 | sf.write(output_filename, audio, 22050, "PCM_24") 52 | 53 | wav_secs = wav_lengths.sum() / 22050 54 | print(f"Inference seconds: {infer_secs}") 55 | print(f"Generated wav seconds: {wav_secs}") 56 | rtf = infer_secs / wav_secs 57 | if mel_infer_secs is not None: 58 | mel_rtf = mel_infer_secs / wav_secs 59 | print(f"Matcha RTF: {mel_rtf}") 60 | if vocoder_infer_secs is not None: 61 | vocoder_rtf = vocoder_infer_secs / wav_secs 62 | print(f"Vocoder RTF: {vocoder_rtf}") 63 | print(f"Overall RTF: {rtf}") 64 | 65 | 66 | def write_mels(model, inputs, output_dir): 67 | t0 = perf_counter() 68 | mels, mel_lengths = model.run(None, inputs) 69 | infer_secs = perf_counter() - t0 70 | 71 | output_dir = Path(output_dir) 72 | output_dir.mkdir(parents=True, exist_ok=True) 73 | for i, mel in enumerate(mels): 74 | output_stem = output_dir.joinpath(f"output_{i + 1}") 75 | 
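# Write a PNG preview of each generated mel and save the raw array so it can later be passed to an external vocoder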
plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png")) 76 | np.save(output_stem.with_suffix(".numpy"), mel) 77 | 78 | wav_secs = (mel_lengths * 256).sum() / 22050 79 | print(f"Inference seconds: {infer_secs}") 80 | print(f"Generated wav seconds: {wav_secs}") 81 | rtf = infer_secs / wav_secs 82 | print(f"RTF: {rtf}") 83 | 84 | 85 | def main(): 86 | parser = argparse.ArgumentParser( 87 | description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching" 88 | ) 89 | parser.add_argument( 90 | "model", 91 | type=str, 92 | help="ONNX model to use", 93 | ) 94 | parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)") 95 | parser.add_argument("--text", type=str, default=None, help="Text to synthesize") 96 | parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") 97 | parser.add_argument("--spk", type=int, default=None, help="Speaker ID") 98 | parser.add_argument( 99 | "--temperature", 100 | type=float, 101 | default=0.667, 102 | help="Variance of the x0 noise (default: 0.667)", 103 | ) 104 | parser.add_argument( 105 | "--speaking-rate", 106 | type=float, 107 | default=1.0, 108 | help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)", 109 | ) 110 | parser.add_argument("--gpu", action="store_true", help="Use CPU for inference (default: use GPU if available)") 111 | parser.add_argument( 112 | "--output-dir", 113 | type=str, 114 | default=os.getcwd(), 115 | help="Output folder to save results (default: current dir)", 116 | ) 117 | 118 | args = parser.parse_args() 119 | args = validate_args(args) 120 | 121 | if args.gpu: 122 | providers = ["GPUExecutionProvider"] 123 | else: 124 | providers = ["CPUExecutionProvider"] 125 | model = ort.InferenceSession(args.model, providers=providers) 126 | 127 | model_inputs = model.get_inputs() 128 | model_outputs = list(model.get_outputs()) 129 | 130 | if args.text: 131 | text_lines = args.text.splitlines() 132 | else: 133 | with open(args.file, encoding="utf-8") as file: 134 | text_lines = file.read().splitlines() 135 | 136 | processed_lines = [process_text(0, line, "cpu") for line in text_lines] 137 | x = [line["x"].squeeze() for line in processed_lines] 138 | # Pad 139 | x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) 140 | x = x.detach().cpu().numpy() 141 | x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64) 142 | inputs = { 143 | "x": x, 144 | "x_lengths": x_lengths, 145 | "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32), 146 | } 147 | is_multi_speaker = len(model_inputs) == 4 148 | if is_multi_speaker: 149 | if args.spk is None: 150 | args.spk = 0 151 | warn = "[!] Speaker ID not provided! Using speaker ID 0" 152 | warnings.warn(warn, UserWarning) 153 | inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64) 154 | 155 | has_vocoder_embedded = model_outputs[0].name == "wav" 156 | if has_vocoder_embedded: 157 | write_wavs(model, inputs, args.output_dir) 158 | elif args.vocoder: 159 | external_vocoder = ort.InferenceSession(args.vocoder, providers=providers) 160 | write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder) 161 | else: 162 | warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. 
The mel output will be written as numpy arrays to `*.npy` files in the output directory" 163 | warnings.warn(warn, UserWarning) 164 | write_mels(model, inputs, args.output_dir) 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /fs2/hifigan/meldataset.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import math 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | from librosa.filters import mel as librosa_mel_fn 11 | from librosa.util import normalize 12 | from scipy.io.wavfile import read 13 | 14 | MAX_WAV_VALUE = 32768.0 15 | 16 | 17 | def load_wav(full_path): 18 | sampling_rate, data = read(full_path) 19 | return data, sampling_rate 20 | 21 | 22 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 23 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 24 | 25 | 26 | def dynamic_range_decompression(x, C=1): 27 | return np.exp(x) / C 28 | 29 | 30 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 31 | return torch.log(torch.clamp(x, min=clip_val) * C) 32 | 33 | 34 | def dynamic_range_decompression_torch(x, C=1): 35 | return torch.exp(x) / C 36 | 37 | 38 | def spectral_normalize_torch(magnitudes): 39 | output = dynamic_range_compression_torch(magnitudes) 40 | return output 41 | 42 | 43 | def spectral_de_normalize_torch(magnitudes): 44 | output = dynamic_range_decompression_torch(magnitudes) 45 | return output 46 | 47 | 48 | mel_basis = {} 49 | hann_window = {} 50 | 51 | 52 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 53 | if torch.min(y) < -1.0: 54 | print("min value is ", torch.min(y)) 55 | if torch.max(y) > 1.0: 56 | print("max value is ", torch.max(y)) 57 | 58 | global mel_basis, hann_window # pylint: disable=global-statement 59 | if fmax not in mel_basis: 60 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 61 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 62 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 63 | 64 | y = torch.nn.functional.pad( 65 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 66 | ) 67 | y = y.squeeze(1) 68 | 69 | spec = torch.view_as_real( 70 | torch.stft( 71 | y, 72 | n_fft, 73 | hop_length=hop_size, 74 | win_length=win_size, 75 | window=hann_window[str(y.device)], 76 | center=center, 77 | pad_mode="reflect", 78 | normalized=False, 79 | onesided=True, 80 | return_complex=True, 81 | ) 82 | ) 83 | 84 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 85 | 86 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 87 | spec = spectral_normalize_torch(spec) 88 | 89 | return spec 90 | 91 | 92 | def get_dataset_filelist(a): 93 | with open(a.input_training_file, encoding="utf-8") as fi: 94 | training_files = [ 95 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 96 | ] 97 | 98 | with open(a.input_validation_file, encoding="utf-8") as fi: 99 | validation_files = [ 100 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 101 | ] 102 | return training_files, validation_files 103 | 104 | 105 | class MelDataset(torch.utils.data.Dataset): 106 | def __init__( 107 | self, 108 | training_files, 109 | segment_size, 
110 | n_fft, 111 | num_mels, 112 | hop_size, 113 | win_size, 114 | sampling_rate, 115 | fmin, 116 | fmax, 117 | split=True, 118 | shuffle=True, 119 | n_cache_reuse=1, 120 | device=None, 121 | fmax_loss=None, 122 | fine_tuning=False, 123 | base_mels_path=None, 124 | ): 125 | self.audio_files = training_files 126 | random.seed(1234) 127 | if shuffle: 128 | random.shuffle(self.audio_files) 129 | self.segment_size = segment_size 130 | self.sampling_rate = sampling_rate 131 | self.split = split 132 | self.n_fft = n_fft 133 | self.num_mels = num_mels 134 | self.hop_size = hop_size 135 | self.win_size = win_size 136 | self.fmin = fmin 137 | self.fmax = fmax 138 | self.fmax_loss = fmax_loss 139 | self.cached_wav = None 140 | self.n_cache_reuse = n_cache_reuse 141 | self._cache_ref_count = 0 142 | self.device = device 143 | self.fine_tuning = fine_tuning 144 | self.base_mels_path = base_mels_path 145 | 146 | def __getitem__(self, index): 147 | filename = self.audio_files[index] 148 | if self._cache_ref_count == 0: 149 | audio, sampling_rate = load_wav(filename) 150 | audio = audio / MAX_WAV_VALUE 151 | if not self.fine_tuning: 152 | audio = normalize(audio) * 0.95 153 | self.cached_wav = audio 154 | if sampling_rate != self.sampling_rate: 155 | raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") 156 | self._cache_ref_count = self.n_cache_reuse 157 | else: 158 | audio = self.cached_wav 159 | self._cache_ref_count -= 1 160 | 161 | audio = torch.FloatTensor(audio) 162 | audio = audio.unsqueeze(0) 163 | 164 | if not self.fine_tuning: 165 | if self.split: 166 | if audio.size(1) >= self.segment_size: 167 | max_audio_start = audio.size(1) - self.segment_size 168 | audio_start = random.randint(0, max_audio_start) 169 | audio = audio[:, audio_start : audio_start + self.segment_size] 170 | else: 171 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 172 | 173 | mel = mel_spectrogram( 174 | audio, 175 | self.n_fft, 176 | self.num_mels, 177 | self.sampling_rate, 178 | self.hop_size, 179 | self.win_size, 180 | self.fmin, 181 | self.fmax, 182 | center=False, 183 | ) 184 | else: 185 | mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) 186 | mel = torch.from_numpy(mel) 187 | 188 | if len(mel.shape) < 3: 189 | mel = mel.unsqueeze(0) 190 | 191 | if self.split: 192 | frames_per_seg = math.ceil(self.segment_size / self.hop_size) 193 | 194 | if audio.size(1) >= self.segment_size: 195 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) 196 | mel = mel[:, :, mel_start : mel_start + frames_per_seg] 197 | audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] 198 | else: 199 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") 200 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 201 | 202 | mel_loss = mel_spectrogram( 203 | audio, 204 | self.n_fft, 205 | self.num_mels, 206 | self.sampling_rate, 207 | self.hop_size, 208 | self.win_size, 209 | self.fmin, 210 | self.fmax_loss, 211 | center=False, 212 | ) 213 | 214 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 215 | 216 | def __len__(self): 217 | return len(self.audio_files) 218 | -------------------------------------------------------------------------------- /fs2/models/components/transformer.py: -------------------------------------------------------------------------------- 1 | # Taken from 
FastPitch 2 | # https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/FastPitch 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from fs2.utils.model import sequence_mask 8 | 9 | 10 | class PositionalEmbedding(nn.Module): 11 | def __init__(self, demb): 12 | super(PositionalEmbedding, self).__init__() 13 | self.demb = demb 14 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 15 | self.register_buffer('inv_freq', inv_freq) 16 | 17 | def forward(self, pos_seq, bsz=None): 18 | sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1), 19 | torch.unsqueeze(self.inv_freq, 0)) 20 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1) 21 | if bsz is not None: 22 | return pos_emb[None, :, :].expand(bsz, -1, -1) 23 | else: 24 | return pos_emb[None, :, :] 25 | 26 | 27 | class PositionwiseConvFF(nn.Module): 28 | def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False): 29 | super(PositionwiseConvFF, self).__init__() 30 | 31 | self.d_model = d_model 32 | self.d_inner = d_inner 33 | self.dropout = dropout 34 | 35 | self.CoreNet = nn.Sequential( 36 | nn.Conv1d(d_model, d_inner, kernel_size[0], padding=(kernel_size[0] - 1) // 2 ), 37 | nn.ReLU(), 38 | # nn.Dropout(dropout), # worse convergence 39 | nn.Conv1d(d_inner, d_model, kernel_size[1], padding=(kernel_size[1] - 1) // 2), 40 | nn.Dropout(dropout), 41 | ) 42 | self.layer_norm = nn.LayerNorm(d_model) 43 | self.pre_lnorm = pre_lnorm 44 | 45 | def forward(self, inp): 46 | return self._forward(inp) 47 | 48 | def _forward(self, inp): 49 | if self.pre_lnorm: 50 | # layer normalization + positionwise feed-forward 51 | core_out = inp.transpose(1, 2) 52 | core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype)) 53 | core_out = core_out.transpose(1, 2) 54 | 55 | # residual connection 56 | output = core_out + inp 57 | else: 58 | # positionwise feed-forward 59 | core_out = inp.transpose(1, 2) 60 | core_out = self.CoreNet(core_out) 61 | core_out = core_out.transpose(1, 2) 62 | 63 | # residual connection + layer normalization 64 | output = self.layer_norm(inp + core_out).to(inp.dtype) 65 | 66 | return output 67 | 68 | 69 | class MultiHeadAttn(nn.Module): 70 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, 71 | pre_lnorm=False): 72 | super(MultiHeadAttn, self).__init__() 73 | 74 | self.n_head = n_head 75 | self.d_model = d_model 76 | self.d_head = d_head 77 | self.scale = 1 / (d_head ** 0.5) 78 | self.pre_lnorm = pre_lnorm 79 | 80 | self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head) 81 | self.drop = nn.Dropout(dropout) 82 | self.dropatt = nn.Dropout(dropatt) 83 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) 84 | self.layer_norm = nn.LayerNorm(d_model) 85 | 86 | def forward(self, inp, attn_mask=None): 87 | return self._forward(inp, attn_mask) 88 | 89 | def _forward(self, inp, attn_mask=None): 90 | residual = inp 91 | 92 | if self.pre_lnorm: 93 | # layer normalization 94 | inp = self.layer_norm(inp) 95 | 96 | n_head, d_head = self.n_head, self.d_head 97 | 98 | head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2) 99 | head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head) 100 | head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head) 101 | head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head) 102 | 103 | q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head) 104 | k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head) 105 | v = head_v.permute(2, 0, 1, 
3).reshape(-1, inp.size(1), d_head) 106 | 107 | attn_score = torch.bmm(q, k.transpose(1, 2)) 108 | attn_score.mul_(self.scale) 109 | 110 | if attn_mask is not None: 111 | attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype) 112 | attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1) 113 | attn_score.masked_fill_(attn_mask.to(torch.bool), -float('inf')) 114 | 115 | attn_prob = F.softmax(attn_score, dim=2) 116 | attn_prob = self.dropatt(attn_prob) 117 | attn_vec = torch.bmm(attn_prob, v) 118 | 119 | attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head) 120 | attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view( 121 | inp.size(0), inp.size(1), n_head * d_head) 122 | 123 | # linear projection 124 | attn_out = self.o_net(attn_vec) 125 | attn_out = self.drop(attn_out) 126 | 127 | if self.pre_lnorm: 128 | # residual connection 129 | output = residual + attn_out 130 | else: 131 | # residual connection + layer normalization 132 | output = self.layer_norm(residual + attn_out) 133 | 134 | output = output.to(attn_out.dtype) 135 | 136 | return output 137 | 138 | 139 | class TransformerLayer(nn.Module): 140 | def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout, 141 | **kwargs): 142 | super(TransformerLayer, self).__init__() 143 | 144 | self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) 145 | self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout, 146 | pre_lnorm=kwargs.get('pre_lnorm')) 147 | 148 | def forward(self, dec_inp, mask=None): 149 | output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2)) 150 | output *= mask 151 | output = self.pos_ff(output) 152 | output *= mask 153 | return output 154 | 155 | 156 | class FFTransformer(nn.Module): 157 | def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size, 158 | dropout, dropatt, dropemb=0.0, embed_input=True, 159 | n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False): 160 | super(FFTransformer, self).__init__() 161 | self.d_model = d_model 162 | self.n_head = n_head 163 | self.d_head = d_head 164 | self.padding_idx = padding_idx 165 | 166 | if embed_input: 167 | self.word_emb = nn.Embedding(n_embed, d_embed or d_model) 168 | torch.nn.init.normal_(self.word_emb.weight, 0.0, self.d_model**-0.5) 169 | else: 170 | self.word_emb = None 171 | 172 | self.pos_emb = PositionalEmbedding(self.d_model) 173 | self.drop = nn.Dropout(dropemb) 174 | self.layers = nn.ModuleList() 175 | 176 | for _ in range(n_layer): 177 | self.layers.append( 178 | TransformerLayer( 179 | n_head, d_model, d_head, d_inner, kernel_size, dropout, 180 | dropatt=dropatt, pre_lnorm=pre_lnorm) 181 | ) 182 | 183 | def forward(self, dec_inp, seq_lens=None, conditioning=0): 184 | if self.word_emb is None: 185 | inp = dec_inp 186 | mask = sequence_mask(seq_lens).unsqueeze(2) 187 | else: 188 | inp = self.word_emb(dec_inp) 189 | # [bsz x L x 1] 190 | mask = sequence_mask(seq_lens).unsqueeze(2) 191 | 192 | pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype) 193 | pos_emb = self.pos_emb(pos_seq) * mask 194 | 195 | out = self.drop(inp + pos_emb + conditioning) 196 | 197 | for layer in self.layers: 198 | out = layer(out, mask=mask) 199 | 200 | # out = self.drop(out) 201 | return out, mask -------------------------------------------------------------------------------- /fs2/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The file creates a pickle file where the values needed for loading of dataset is stored 
and the model can load it 3 | when needed. 4 | 5 | Parameters from hparam.py will be used 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | from pathlib import Path 12 | 13 | import lightning 14 | import numpy as np 15 | import rootutils 16 | import torch 17 | from hydra import compose, initialize 18 | from omegaconf import DictConfig, open_dict 19 | from torch import nn 20 | from tqdm.auto import tqdm 21 | 22 | from fs2.data.text_mel_datamodule import TextMelDataModule 23 | from fs2.utils.logging_utils import pylogger 24 | from fs2.utils.utils import to_numpy 25 | 26 | log = pylogger.get_pylogger(__name__) 27 | 28 | 29 | @torch.inference_mode() 30 | def generate_preprocessing_files(dataset: torch.utils.data.Dataset, output_folder: Path, cfg: DictConfig, save_stats=True): 31 | """Generate durations from the model for each datapoint and save it in a folder 32 | 33 | Args: 34 | data_loader (torch.utils.data.DataLoader): Dataloader 35 | model (nn.Module): MatchaTTS model 36 | device (torch.device): GPU or CPU 37 | """ 38 | x_lengths = 0 39 | 40 | # Pitch stats 41 | pitch_min = float("inf") 42 | pitch_max = -float("inf") 43 | pitch_sum = 0 44 | pitch_sq_sum = 0 45 | 46 | # Energy stats 47 | energy_min = float("inf") 48 | energy_max = -float("inf") 49 | energy_sum = 0 50 | energy_sq_sum = 0 51 | 52 | # Mel stats 53 | mel_sum = 0 54 | mel_sq_sum = 0 55 | total_mel_len = 0 56 | 57 | processed_folder_name = output_folder 58 | assert (processed_folder_name/ "durations").exists(), "Durations folder not found, it must be generated beforehand for this script to work" 59 | pitch_folder, energy_folder, mel_folder = init_folders(processed_folder_name) 60 | 61 | # Benefit of doing it over batch is the added speed due to multiprocessing 62 | for batch in tqdm(dataset, desc="🍵 Preprocessing durations 🍵"): 63 | # Get pre generated durations with Matcha-TTS 64 | for i in range(batch['x'].shape[0]): 65 | filname = Path(batch['filepaths'][i]).stem 66 | inp_len = batch['x_lengths'][i] 67 | mel_len = batch['y_lengths'][i] 68 | pitch = batch['pitches'][i][:inp_len] 69 | pitch_min = min(pitch_min, torch.min(pitch).item()) 70 | pitch_max = max(pitch_max, torch.max(pitch).item()) 71 | 72 | np.save(pitch_folder / f"{filname}.npy", to_numpy(pitch)) 73 | energy = batch['energies'][i][:inp_len] 74 | energy_min = min(energy_min, torch.min(energy).item()) 75 | energy_max = max(energy_max, torch.max(energy).item()) 76 | np.save(energy_folder / f"{filname}.npy", to_numpy(energy)) 77 | mel_spec = batch['y'][i][:, :mel_len] 78 | np.save(mel_folder / f"{filname}.npy", to_numpy(mel_spec)) 79 | 80 | # normalisation statistics 81 | pitch_sum += torch.sum(pitch) 82 | pitch_sq_sum += torch.sum(torch.pow(pitch, 2)) 83 | energy_sum += torch.sum(energy) 84 | energy_sq_sum += torch.sum(torch.pow(energy, 2)) 85 | x_lengths += inp_len 86 | 87 | mel_sum += torch.sum(mel_spec) 88 | mel_sq_sum += torch.sum(mel_spec ** 2) 89 | total_mel_len += mel_len 90 | 91 | # Save normalisation statistics 92 | pitch_mean = pitch_sum / x_lengths 93 | pitch_std = torch.sqrt((pitch_sq_sum / x_lengths) - torch.pow(pitch_mean, 2)) 94 | 95 | energy_mean = energy_sum / x_lengths 96 | energy_std = torch.sqrt((energy_sq_sum / x_lengths) - torch.pow(energy_mean,2)) 97 | 98 | mel_mean = mel_sum / (total_mel_len * cfg['n_feats']) 99 | mel_std = torch.sqrt((mel_sq_sum / (total_mel_len * cfg['n_feats'])) - torch.pow(mel_mean, 2)) 100 | 101 | 102 | stats = { 103 | "pitch_min": round(pitch_min, 6), 104 | "pitch_max": round(pitch_max, 6), 105 
| "pitch_mean": round(pitch_mean.item(), 6), 106 | "pitch_std": round(pitch_std.item(), 6), 107 | "energy_min": round(energy_min, 6), 108 | "energy_max": round(energy_max, 6), 109 | "energy_mean": round(energy_mean.item(), 6), 110 | "energy_std": round(energy_std.item(), 6), 111 | "mel_mean": round(mel_mean.item(), 6), 112 | "mel_std": round(mel_std.item(), 6), 113 | } 114 | 115 | print(stats) 116 | if save_stats: 117 | with open(processed_folder_name / "stats.json", "w") as f: 118 | json.dump(stats,f, indent=4) 119 | else: 120 | print("Stats not saved!") 121 | 122 | print("[+] Done! features saved to: ", processed_folder_name) 123 | 124 | def init_folders(processed_folder_name): 125 | pitch_folder = processed_folder_name / "pitch" 126 | energy_folder = processed_folder_name / "energy" 127 | mel_folder = processed_folder_name / "mel" 128 | pitch_folder.mkdir(parents=True, exist_ok=True) 129 | energy_folder.mkdir(parents=True, exist_ok=True) 130 | mel_folder.mkdir(parents=True, exist_ok=True) 131 | return pitch_folder,energy_folder, mel_folder 132 | 133 | 134 | 135 | def main(): 136 | parser = argparse.ArgumentParser() 137 | 138 | parser.add_argument( 139 | "-i", 140 | "--input-config", 141 | type=str, 142 | default="vctk.yaml", 143 | help="The name of the yaml config file under configs/data", 144 | ) 145 | 146 | parser.add_argument( 147 | "-b", 148 | "--batch-size", 149 | type=int, 150 | default="32", 151 | help="Can have increased batch size for faster computation", 152 | ) 153 | 154 | parser.add_argument( 155 | "-f", 156 | "--force", 157 | action="store_true", 158 | default=False, 159 | required=False, 160 | help="force overwrite the file", 161 | ) 162 | 163 | parser.add_argument( 164 | "-o", 165 | "--output-folder", 166 | type=str, 167 | default=None, 168 | help="Output folder to save the data statistics", 169 | ) 170 | args = parser.parse_args() 171 | 172 | with initialize(version_base="1.3", config_path="../../configs/data"): 173 | cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) 174 | 175 | root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") 176 | 177 | with open_dict(cfg): 178 | del cfg["hydra"] 179 | del cfg["_target_"] 180 | cfg["seed"] = 1234 181 | cfg["batch_size"] = args.batch_size 182 | cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) 183 | cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) 184 | cfg['generate_properties'] = True 185 | # Remove this after testing let the multiprocessing do its job 186 | # cfg['num_workers'] = 0 187 | 188 | if args.output_folder is not None: 189 | output_folder = Path(args.output_folder) 190 | else: 191 | output_folder = Path(cfg["train_filelist_path"]).parent 192 | 193 | # if os.path.exists(output_folder) and not args.force: 194 | # print("Folder already exists. 
Use -f to force overwrite") 195 | # sys.exit(1) 196 | 197 | output_folder.mkdir(parents=True, exist_ok=True) 198 | 199 | print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}") 200 | 201 | 202 | text_mel_datamodule = TextMelDataModule(**cfg) 203 | text_mel_datamodule.setup() 204 | try: 205 | print("Computing stats for training set if exists...") 206 | train_dataloader = text_mel_datamodule.train_dataloader() 207 | generate_preprocessing_files(train_dataloader, output_folder, cfg, save_stats=True) 208 | except lightning.fabric.utilities.exceptions.MisconfigurationException: 209 | print("No training set found") 210 | 211 | try: 212 | print("Computing stats for validation set if exists...") 213 | val_dataloader = text_mel_datamodule.val_dataloader() 214 | generate_preprocessing_files(val_dataloader, output_folder, cfg) 215 | except lightning.fabric.utilities.exceptions.MisconfigurationException: 216 | print("No validation set found") 217 | 218 | try: 219 | print("Computing stats for test set if exists...") 220 | test_dataloader = text_mel_datamodule.test_dataloader() 221 | generate_preprocessing_files(test_dataloader, output_folder, cfg) 222 | except lightning.fabric.utilities.exceptions.MisconfigurationException: 223 | print("No test set found") 224 | 225 | print(f"[+] Done! features saved to: {output_folder}") 226 | 227 | 228 | if __name__ == "__main__": 229 | # Helps with generating durations for the dataset to train other architectures 230 | # that cannot learn to align due to limited size of dataset 231 | # Example usage: 232 | # python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model 233 | # This will create a folder in data/processed_data/durations/ljspeech with the durations 234 | main() 235 | -------------------------------------------------------------------------------- /fs2/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | from importlib.util import find_spec 5 | from pathlib import Path 6 | from typing import Any, Callable, Dict, Tuple 7 | 8 | import gdown 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | import wget 13 | from omegaconf import DictConfig 14 | 15 | from fs2.utils import pylogger, rich_utils 16 | 17 | log = pylogger.get_pylogger(__name__) 18 | 19 | 20 | def extras(cfg: DictConfig) -> None: 21 | """Applies optional utilities before the task is started. 22 | 23 | Utilities: 24 | - Ignoring python warnings 25 | - Setting tags from command line 26 | - Rich config printing 27 | 28 | :param cfg: A DictConfig object containing the config tree. 29 | """ 30 | # return if no `extras` config 31 | if not cfg.get("extras"): 32 | log.warning("Extras config not found! ") 33 | return 34 | 35 | # disable python warnings 36 | if cfg.extras.get("ignore_warnings"): 37 | log.info("Disabling python warnings! ") 38 | warnings.filterwarnings("ignore") 39 | 40 | # prompt user to input tags from command line if none are provided in the config 41 | if cfg.extras.get("enforce_tags"): 42 | log.info("Enforcing tags! ") 43 | rich_utils.enforce_tags(cfg, save_to_file=True) 44 | 45 | # pretty print config tree using Rich library 46 | if cfg.extras.get("print_config"): 47 | log.info("Printing config tree with Rich! 
") 48 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 49 | 50 | 51 | def task_wrapper(task_func: Callable) -> Callable: 52 | """Optional decorator that controls the failure behavior when executing the task function. 53 | 54 | This wrapper can be used to: 55 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 56 | - save the exception to a `.log` file 57 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 58 | - etc. (adjust depending on your needs) 59 | 60 | Example: 61 | ``` 62 | @utils.task_wrapper 63 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 64 | ... 65 | return metric_dict, object_dict 66 | ``` 67 | 68 | :param task_func: The task function to be wrapped. 69 | 70 | :return: The wrapped task function. 71 | """ 72 | 73 | def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 74 | # execute the task 75 | try: 76 | metric_dict, object_dict = task_func(cfg=cfg) 77 | 78 | # things to do if exception occurs 79 | except Exception as ex: 80 | # save exception to `.log` file 81 | log.exception("") 82 | 83 | # some hyperparameter combinations might be invalid or cause out-of-memory errors 84 | # so when using hparam search plugins like Optuna, you might want to disable 85 | # raising the below exception to avoid multirun failure 86 | raise ex 87 | 88 | # things to always do after either success or exception 89 | finally: 90 | # display output dir path in terminal 91 | log.info(f"Output dir: {cfg.paths.output_dir}") 92 | 93 | # always close wandb run (even if exception occurs so multirun won't fail) 94 | if find_spec("wandb"): # check if wandb is installed 95 | import wandb 96 | 97 | if wandb.run: 98 | log.info("Closing wandb!") 99 | wandb.finish() 100 | 101 | return metric_dict, object_dict 102 | 103 | return wrap 104 | 105 | 106 | def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: 107 | """Safely retrieves value of the metric logged in LightningModule. 108 | 109 | :param metric_dict: A dict containing metric values. 110 | :param metric_name: The name of the metric to retrieve. 111 | :return: The value of the metric. 112 | """ 113 | if not metric_name: 114 | log.info("Metric name is None! Skipping metric value retrieval...") 115 | return None 116 | 117 | if metric_name not in metric_dict: 118 | raise ValueError( 119 | f"Metric value not found! \n" 120 | "Make sure metric name logged in LightningModule is correct!\n" 121 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 122 | ) 123 | 124 | metric_value = metric_dict[metric_name].item() 125 | log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") 126 | 127 | return metric_value 128 | 129 | 130 | def intersperse(lst, item): 131 | # Adds blank symbol 132 | result = [item] * (len(lst) * 2 + 1) 133 | result[1::2] = lst 134 | return result 135 | 136 | 137 | def save_figure_to_numpy(fig): 138 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 139 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 140 | return data 141 | 142 | 143 | def plot_tensor(tensor): 144 | plt.style.use("default") 145 | fig, ax = plt.subplots(figsize=(12, 3)) 146 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 147 | plt.colorbar(im, ax=ax) 148 | plt.tight_layout() 149 | fig.canvas.draw() 150 | data = save_figure_to_numpy(fig) 151 | plt.close() 152 | return data 153 | 154 | def plot_line(tensor, min_value, max_value, y_log=True): 155 | plt.style.use("default") 156 | fig, ax = plt.subplots(figsize=(12, 3)) 157 | ax.plot(tensor) 158 | ax.set_ylim(min_value, max_value) 159 | ax.set_yscale("log" if y_log else "linear") 160 | plt.tight_layout() 161 | fig.canvas.draw() 162 | data = save_figure_to_numpy(fig) 163 | plt.close() 164 | return data 165 | 166 | def save_plot(tensor, savepath): 167 | plt.style.use("default") 168 | fig, ax = plt.subplots(figsize=(12, 3)) 169 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 170 | plt.colorbar(im, ax=ax) 171 | plt.tight_layout() 172 | fig.canvas.draw() 173 | plt.savefig(savepath) 174 | plt.close() 175 | 176 | def to_torch(x, dtype): 177 | return torch.tensor(x, dtype=dtype) if not isinstance(x, torch.Tensor) else x 178 | 179 | 180 | def to_numpy(tensor): 181 | if isinstance(tensor, np.ndarray): 182 | return tensor 183 | elif isinstance(tensor, torch.Tensor): 184 | return tensor.detach().cpu().numpy() 185 | elif isinstance(tensor, list): 186 | return np.array(tensor) 187 | else: 188 | raise TypeError("Unsupported type for conversion to numpy array") 189 | 190 | 191 | def get_user_data_dir(appname="matcha_tts"): 192 | """ 193 | Args: 194 | appname (str): Name of application 195 | 196 | Returns: 197 | Path: path to user data directory 198 | """ 199 | 200 | MATCHA_HOME = os.environ.get("MATCHA_HOME") 201 | if MATCHA_HOME is not None: 202 | ans = Path(MATCHA_HOME).expanduser().resolve(strict=False) 203 | elif sys.platform == "win32": 204 | import winreg # pylint: disable=import-outside-toplevel 205 | 206 | key = winreg.OpenKey( 207 | winreg.HKEY_CURRENT_USER, 208 | r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", 209 | ) 210 | dir_, _ = winreg.QueryValueEx(key, "Local AppData") 211 | ans = Path(dir_).resolve(strict=False) 212 | elif sys.platform == "darwin": 213 | ans = Path("~/Library/Application Support/").expanduser() 214 | else: 215 | ans = Path.home().joinpath(".local/share") 216 | 217 | final_path = ans.joinpath(appname) 218 | final_path.mkdir(parents=True, exist_ok=True) 219 | return final_path 220 | 221 | 222 | def assert_model_downloaded(checkpoint_path, url, use_wget=True): 223 | if Path(checkpoint_path).exists(): 224 | log.debug(f"[+] Model already present at {checkpoint_path}!") 225 | print(f"[+] Model already present at {checkpoint_path}!") 226 | return 227 | log.info(f"[-] Model not found at {checkpoint_path}! Will download it") 228 | print(f"[-] Model not found at {checkpoint_path}! 
Will download it") 229 | checkpoint_path = str(checkpoint_path) 230 | if not use_wget: 231 | gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) 232 | else: 233 | wget.download(url=url, out=checkpoint_path) 234 | 235 | 236 | 237 | def trim_or_pad_to_target_length( 238 | data_1d_or_2d: np.ndarray, target_length: int 239 | ) -> np.ndarray: 240 | assert len(data_1d_or_2d.shape) in {1, 2} 241 | delta = data_1d_or_2d.shape[0] - target_length 242 | if delta >= 0: # trim if being longer 243 | data_1d_or_2d = data_1d_or_2d[: target_length] 244 | else: # pad if being shorter 245 | if len(data_1d_or_2d.shape) == 1: 246 | data_1d_or_2d = np.concatenate( 247 | [data_1d_or_2d, np.zeros(-delta)], axis=0 248 | ) 249 | else: 250 | data_1d_or_2d = np.concatenate( 251 | [data_1d_or_2d, np.zeros((-delta, data_1d_or_2d.shape[1]))], 252 | axis=0 253 | ) 254 | return data_1d_or_2d -------------------------------------------------------------------------------- /fs2/models/baselightningmodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a base lightning module that can be used to train a model. 3 | The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. 4 | """ 5 | import inspect 6 | from abc import ABC 7 | from typing import Any, Dict 8 | 9 | import matplotlib.pyplot as plt 10 | import torch 11 | from lightning import LightningModule 12 | from lightning.pytorch.utilities import grad_norm 13 | 14 | from fs2 import utils 15 | from fs2.utils.model import (denormalize, expand_lengths, invert_log_norm, 16 | normalize) 17 | from fs2.utils.utils import plot_line, plot_tensor, save_figure_to_numpy 18 | 19 | log = utils.get_pylogger(__name__) 20 | 21 | 22 | class BaseLightningClass(LightningModule, ABC): 23 | def update_data_statistics(self, data_statistics): 24 | if data_statistics is None: 25 | raise ValueError(f"data_statistics are not computed. 
\ 26 | Please run python fs2/utils/preprocess.py -i \ 27 | to get statistics and update them in data_statistics field.") 28 | 29 | self.register_buffer("pitch_mean", torch.tensor(data_statistics["pitch_mean"])) 30 | self.register_buffer("pitch_std", torch.tensor(data_statistics["pitch_std"])) 31 | 32 | self.register_buffer("energy_mean", torch.tensor(data_statistics["energy_mean"])) 33 | self.register_buffer("energy_std", torch.tensor(data_statistics["energy_std"])) 34 | 35 | self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) 36 | self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) 37 | 38 | pitch_min = normalize(torch.tensor(data_statistics["pitch_min"]), self.pitch_mean, self.pitch_std) 39 | pitch_max = normalize(torch.tensor(data_statistics["pitch_max"]), self.pitch_mean, self.pitch_std) 40 | energy_min = normalize(torch.tensor(data_statistics["energy_min"]), self.energy_mean, self.energy_std) 41 | energy_max = normalize(torch.tensor(data_statistics["energy_max"]), self.energy_mean, self.energy_std) 42 | 43 | self.register_buffer("pitch_min", pitch_min) 44 | self.register_buffer("pitch_max", pitch_max) 45 | self.register_buffer("energy_min", energy_min) 46 | self.register_buffer("energy_max", energy_max) 47 | 48 | def configure_optimizers(self) -> Any: 49 | optimizer = self.hparams.optimizer(params=self.parameters()) 50 | if self.hparams.scheduler not in (None, {}): 51 | scheduler_args = {} 52 | # Manage last epoch for exponential schedulers 53 | if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: 54 | if hasattr(self, "ckpt_loaded_epoch"): 55 | current_epoch = self.ckpt_loaded_epoch - 1 56 | else: 57 | current_epoch = -1 58 | 59 | scheduler_args.update({"optimizer": optimizer}) 60 | scheduler = self.hparams.scheduler.scheduler(**scheduler_args) 61 | scheduler.last_epoch = current_epoch 62 | return { 63 | "optimizer": optimizer, 64 | "lr_scheduler": { 65 | "scheduler": scheduler, 66 | "interval": self.hparams.scheduler.lightning_args.interval, 67 | "frequency": self.hparams.scheduler.lightning_args.frequency, 68 | "name": "learning_rate", 69 | }, 70 | } 71 | 72 | return {"optimizer": optimizer} 73 | 74 | def get_losses(self, batch): 75 | x, x_lengths = batch["x"], batch["x_lengths"] 76 | y, y_lengths = batch["y"], batch["y_lengths"] 77 | spks = batch["spks"] 78 | durations, pitches, energies = batch["durations"], batch["pitches"], batch["energies"] 79 | 80 | 81 | losses = self( 82 | x=x, 83 | x_lengths=x_lengths, 84 | y=y, 85 | y_lengths=y_lengths, 86 | durations=durations, 87 | pitches=pitches, 88 | energies=energies, 89 | spks=spks, 90 | ) 91 | return losses 92 | 93 | def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 94 | self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init 95 | 96 | def training_step(self, batch: Any, batch_idx: int): 97 | loss_dict = self.get_losses(batch) 98 | self.log( 99 | "step", 100 | float(self.global_step), 101 | on_step=True, 102 | prog_bar=True, 103 | logger=True, 104 | sync_dist=True, 105 | ) 106 | for loss in loss_dict: 107 | self.log( 108 | f"sub_loss/train_{loss}", 109 | loss_dict[loss], 110 | on_step=True, 111 | on_epoch=True, 112 | logger=True, 113 | sync_dist=True, 114 | ) 115 | total_loss = sum(loss_dict.values()) 116 | self.log( 117 | "loss/train", 118 | total_loss, 119 | on_step=True, 120 | on_epoch=True, 121 | logger=True, 122 | prog_bar=True, 123 | sync_dist=True, 124 | ) 125 | 126 | return {"loss": total_loss, 
"log": loss_dict} 127 | 128 | def validation_step(self, batch: Any, batch_idx: int): 129 | loss_dict = self.get_losses(batch) 130 | for loss in loss_dict: 131 | self.log( 132 | f"sub_loss/val_{loss}", 133 | loss_dict[loss], 134 | on_step=True, 135 | on_epoch=True, 136 | logger=True, 137 | sync_dist=True, 138 | ) 139 | total_loss = sum(loss_dict.values()) 140 | self.log( 141 | "loss/val", 142 | total_loss, 143 | on_step=True, 144 | on_epoch=True, 145 | logger=True, 146 | prog_bar=True, 147 | sync_dist=True, 148 | ) 149 | 150 | return {"loss": total_loss, "log": loss_dict} 151 | 152 | def on_validation_end(self) -> None: 153 | if self.trainer.is_global_zero: 154 | one_batch = next(iter(self.trainer.val_dataloaders)) 155 | if self.current_epoch == 0: 156 | log.debug("Plotting original samples") 157 | for i in range(2): 158 | y = denormalize(one_batch["y"][i].unsqueeze(0).to(self.device), self.mel_mean, self.mel_std)[:, :, :one_batch["y_lengths"][i]] 159 | durations = one_batch["durations"][i].to(self.device)[:one_batch["x_lengths"][i]] 160 | 161 | original_pitch = one_batch["pitches"][i].unsqueeze(0).to(self.device)[:, :one_batch["x_lengths"][i]] 162 | original_pitch, _ = expand_lengths(original_pitch.unsqueeze(2), durations.unsqueeze(0)) 163 | original_pitch = denormalize(original_pitch, self.pitch_mean, self.pitch_std) 164 | 165 | original_energy = one_batch["energies"][i].unsqueeze(0).to(self.device)[:, :one_batch["x_lengths"][i]] 166 | original_energy, _ = expand_lengths(original_energy.unsqueeze(2), durations.unsqueeze(0)) 167 | original_energy = denormalize(original_energy, self.energy_mean, self.energy_std) 168 | 169 | self.logger.experiment.add_image( 170 | f"original/mel_{i}", 171 | self.plot_mel([(y.squeeze().cpu().numpy(), original_pitch.cpu().squeeze(), original_energy.cpu().squeeze())], [f"Data_{i}"]), 172 | self.current_epoch, 173 | dataformats="HWC", 174 | ) 175 | 176 | log.debug("Synthesising...") 177 | for i in range(2): 178 | x = one_batch["x"][i].unsqueeze(0).to(self.device) 179 | x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) 180 | spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None 181 | output = self.synthesise(x[:, :x_lengths], x_lengths, spks=spks) 182 | decoder_output, y_pred = output["decoder_output"], output["mel"] 183 | pitch_pred, energy_pred = output["pitch_pred"], output["energy_pred"] 184 | 185 | self.logger.experiment.add_image( 186 | f"dec_output/{i}", 187 | plot_tensor(decoder_output.squeeze().cpu()), 188 | self.current_epoch, 189 | dataformats="HWC", 190 | ) 191 | 192 | self.logger.experiment.add_image( 193 | f"generated/mel_{i}", 194 | self.plot_mel([(y_pred.squeeze().cpu().numpy(), pitch_pred.cpu().squeeze(), energy_pred.cpu().squeeze())], [f"Generated_{i}"]), 195 | self.current_epoch, 196 | dataformats="HWC", 197 | ) 198 | 199 | 200 | def on_before_optimizer_step(self, optimizer): 201 | self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) 202 | 203 | 204 | 205 | def plot_mel(self, data, titles, show=False): 206 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 207 | if titles is None: 208 | titles = [None for i in range(len(data))] 209 | 210 | pitch_max = denormalize(self.pitch_max, self.pitch_mean, self.pitch_std).cpu().item() 211 | energy_min = denormalize(self.energy_min, self.energy_mean, self.energy_std).cpu().item() 212 | energy_max = denormalize(self.energy_max, self.energy_mean, self.energy_std).cpu().item() 213 | 214 | def add_axis(fig, 
old_ax): 215 | ax = fig.add_axes(old_ax.get_position(), anchor="W") 216 | ax.set_facecolor("None") 217 | return ax 218 | 219 | for i in range(len(data)): 220 | mel, pitch, energy = data[i] 221 | axes[i][0].imshow(mel, origin="lower") 222 | axes[i][0].set_aspect(2.5, adjustable="box") 223 | axes[i][0].set_ylim(0, mel.shape[0]) 224 | axes[i][0].set_title(titles[i], fontsize="medium") 225 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 226 | axes[i][0].set_anchor("W") 227 | 228 | ax1 = add_axis(fig, axes[i][0]) 229 | ax1.plot(pitch, color="tomato") 230 | ax1.set_xlim(0, mel.shape[1]) 231 | ax1.set_ylim(0, pitch_max) 232 | ax1.set_ylabel("F0", color="tomato") 233 | ax1.tick_params( 234 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False 235 | ) 236 | 237 | ax2 = add_axis(fig, axes[i][0]) 238 | ax2.plot(energy, color="darkviolet") 239 | ax2.set_xlim(0, mel.shape[1]) 240 | ax2.set_ylim(energy_min, energy_max) 241 | ax2.set_ylabel("Energy", color="darkviolet") 242 | ax2.yaxis.set_label_position("right") 243 | ax2.tick_params( 244 | labelsize="x-small", 245 | colors="darkviolet", 246 | bottom=False, 247 | labelbottom=False, 248 | left=False, 249 | labelleft=False, 250 | right=True, 251 | labelright=True, 252 | ) 253 | fig.canvas.draw() 254 | if show: 255 | plt.show() 256 | return 257 | 258 | data = save_figure_to_numpy(fig) 259 | plt.close() 260 | return data -------------------------------------------------------------------------------- /fs2/hifigan/models.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d 7 | from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm 8 | 9 | from .xutils import get_padding, init_weights 10 | 11 | LRELU_SLOPE = 0.1 12 | 13 | 14 | class ResBlock1(torch.nn.Module): 15 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 16 | super().__init__() 17 | self.h = h 18 | self.convs1 = nn.ModuleList( 19 | [ 20 | weight_norm( 21 | Conv1d( 22 | channels, 23 | channels, 24 | kernel_size, 25 | 1, 26 | dilation=dilation[0], 27 | padding=get_padding(kernel_size, dilation[0]), 28 | ) 29 | ), 30 | weight_norm( 31 | Conv1d( 32 | channels, 33 | channels, 34 | kernel_size, 35 | 1, 36 | dilation=dilation[1], 37 | padding=get_padding(kernel_size, dilation[1]), 38 | ) 39 | ), 40 | weight_norm( 41 | Conv1d( 42 | channels, 43 | channels, 44 | kernel_size, 45 | 1, 46 | dilation=dilation[2], 47 | padding=get_padding(kernel_size, dilation[2]), 48 | ) 49 | ), 50 | ] 51 | ) 52 | self.convs1.apply(init_weights) 53 | 54 | self.convs2 = nn.ModuleList( 55 | [ 56 | weight_norm( 57 | Conv1d( 58 | channels, 59 | channels, 60 | kernel_size, 61 | 1, 62 | dilation=1, 63 | padding=get_padding(kernel_size, 1), 64 | ) 65 | ), 66 | weight_norm( 67 | Conv1d( 68 | channels, 69 | channels, 70 | kernel_size, 71 | 1, 72 | dilation=1, 73 | padding=get_padding(kernel_size, 1), 74 | ) 75 | ), 76 | weight_norm( 77 | Conv1d( 78 | channels, 79 | channels, 80 | kernel_size, 81 | 1, 82 | dilation=1, 83 | padding=get_padding(kernel_size, 1), 84 | ) 85 | ), 86 | ] 87 | ) 88 | self.convs2.apply(init_weights) 89 | 90 | def forward(self, x): 91 | for c1, c2 in zip(self.convs1, self.convs2): 92 | xt = F.leaky_relu(x, LRELU_SLOPE) 93 | xt = c1(xt) 94 | xt = F.leaky_relu(xt, LRELU_SLOPE) 95 | xt = c2(xt) 96 | x = xt 
+ x 97 | return x 98 | 99 | def remove_weight_norm(self): 100 | for l in self.convs1: 101 | remove_weight_norm(l) 102 | for l in self.convs2: 103 | remove_weight_norm(l) 104 | 105 | 106 | class ResBlock2(torch.nn.Module): 107 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 108 | super().__init__() 109 | self.h = h 110 | self.convs = nn.ModuleList( 111 | [ 112 | weight_norm( 113 | Conv1d( 114 | channels, 115 | channels, 116 | kernel_size, 117 | 1, 118 | dilation=dilation[0], 119 | padding=get_padding(kernel_size, dilation[0]), 120 | ) 121 | ), 122 | weight_norm( 123 | Conv1d( 124 | channels, 125 | channels, 126 | kernel_size, 127 | 1, 128 | dilation=dilation[1], 129 | padding=get_padding(kernel_size, dilation[1]), 130 | ) 131 | ), 132 | ] 133 | ) 134 | self.convs.apply(init_weights) 135 | 136 | def forward(self, x): 137 | for c in self.convs: 138 | xt = F.leaky_relu(x, LRELU_SLOPE) 139 | xt = c(xt) 140 | x = xt + x 141 | return x 142 | 143 | def remove_weight_norm(self): 144 | for l in self.convs: 145 | remove_weight_norm(l) 146 | 147 | 148 | class Generator(torch.nn.Module): 149 | def __init__(self, h): 150 | super().__init__() 151 | self.h = h 152 | self.num_kernels = len(h.resblock_kernel_sizes) 153 | self.num_upsamples = len(h.upsample_rates) 154 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) 155 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 156 | 157 | self.ups = nn.ModuleList() 158 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 159 | self.ups.append( 160 | weight_norm( 161 | ConvTranspose1d( 162 | h.upsample_initial_channel // (2**i), 163 | h.upsample_initial_channel // (2 ** (i + 1)), 164 | k, 165 | u, 166 | padding=(k - u) // 2, 167 | ) 168 | ) 169 | ) 170 | 171 | self.resblocks = nn.ModuleList() 172 | for i in range(len(self.ups)): 173 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 174 | for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 175 | self.resblocks.append(resblock(h, ch, k, d)) 176 | 177 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 178 | self.ups.apply(init_weights) 179 | self.conv_post.apply(init_weights) 180 | 181 | def forward(self, x): 182 | x = self.conv_pre(x) 183 | for i in range(self.num_upsamples): 184 | x = F.leaky_relu(x, LRELU_SLOPE) 185 | x = self.ups[i](x) 186 | xs = None 187 | for j in range(self.num_kernels): 188 | if xs is None: 189 | xs = self.resblocks[i * self.num_kernels + j](x) 190 | else: 191 | xs += self.resblocks[i * self.num_kernels + j](x) 192 | x = xs / self.num_kernels 193 | x = F.leaky_relu(x) 194 | x = self.conv_post(x) 195 | x = torch.tanh(x) 196 | 197 | return x 198 | 199 | def remove_weight_norm(self): 200 | for l in self.ups: 201 | remove_weight_norm(l) 202 | for l in self.resblocks: 203 | l.remove_weight_norm() 204 | remove_weight_norm(self.conv_pre) 205 | remove_weight_norm(self.conv_post) 206 | 207 | 208 | class DiscriminatorP(torch.nn.Module): 209 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 210 | super().__init__() 211 | self.period = period 212 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm 213 | self.convs = nn.ModuleList( 214 | [ 215 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 216 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 217 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 218 | 
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 219 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 220 | ] 221 | ) 222 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 223 | 224 | def forward(self, x): 225 | fmap = [] 226 | 227 | # 1d to 2d 228 | b, c, t = x.shape 229 | if t % self.period != 0: # pad first 230 | n_pad = self.period - (t % self.period) 231 | x = F.pad(x, (0, n_pad), "reflect") 232 | t = t + n_pad 233 | x = x.view(b, c, t // self.period, self.period) 234 | 235 | for l in self.convs: 236 | x = l(x) 237 | x = F.leaky_relu(x, LRELU_SLOPE) 238 | fmap.append(x) 239 | x = self.conv_post(x) 240 | fmap.append(x) 241 | x = torch.flatten(x, 1, -1) 242 | 243 | return x, fmap 244 | 245 | 246 | class MultiPeriodDiscriminator(torch.nn.Module): 247 | def __init__(self): 248 | super().__init__() 249 | self.discriminators = nn.ModuleList( 250 | [ 251 | DiscriminatorP(2), 252 | DiscriminatorP(3), 253 | DiscriminatorP(5), 254 | DiscriminatorP(7), 255 | DiscriminatorP(11), 256 | ] 257 | ) 258 | 259 | def forward(self, y, y_hat): 260 | y_d_rs = [] 261 | y_d_gs = [] 262 | fmap_rs = [] 263 | fmap_gs = [] 264 | for _, d in enumerate(self.discriminators): 265 | y_d_r, fmap_r = d(y) 266 | y_d_g, fmap_g = d(y_hat) 267 | y_d_rs.append(y_d_r) 268 | fmap_rs.append(fmap_r) 269 | y_d_gs.append(y_d_g) 270 | fmap_gs.append(fmap_g) 271 | 272 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 273 | 274 | 275 | class DiscriminatorS(torch.nn.Module): 276 | def __init__(self, use_spectral_norm=False): 277 | super().__init__() 278 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm 279 | self.convs = nn.ModuleList( 280 | [ 281 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 282 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 283 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 284 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 285 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 286 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 287 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 288 | ] 289 | ) 290 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 291 | 292 | def forward(self, x): 293 | fmap = [] 294 | for l in self.convs: 295 | x = l(x) 296 | x = F.leaky_relu(x, LRELU_SLOPE) 297 | fmap.append(x) 298 | x = self.conv_post(x) 299 | fmap.append(x) 300 | x = torch.flatten(x, 1, -1) 301 | 302 | return x, fmap 303 | 304 | 305 | class MultiScaleDiscriminator(torch.nn.Module): 306 | def __init__(self): 307 | super().__init__() 308 | self.discriminators = nn.ModuleList( 309 | [ 310 | DiscriminatorS(use_spectral_norm=True), 311 | DiscriminatorS(), 312 | DiscriminatorS(), 313 | ] 314 | ) 315 | self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) 316 | 317 | def forward(self, y, y_hat): 318 | y_d_rs = [] 319 | y_d_gs = [] 320 | fmap_rs = [] 321 | fmap_gs = [] 322 | for i, d in enumerate(self.discriminators): 323 | if i != 0: 324 | y = self.meanpools[i - 1](y) 325 | y_hat = self.meanpools[i - 1](y_hat) 326 | y_d_r, fmap_r = d(y) 327 | y_d_g, fmap_g = d(y_hat) 328 | y_d_rs.append(y_d_r) 329 | fmap_rs.append(fmap_r) 330 | y_d_gs.append(y_d_g) 331 | fmap_gs.append(fmap_g) 332 | 333 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 334 | 335 | 336 | def feature_loss(fmap_r, fmap_g): 337 | loss = 0 338 | for dr, dg in zip(fmap_r, fmap_g): 339 | for rl, gl in zip(dr, dg): 340 | loss += torch.mean(torch.abs(rl - gl)) 341 | 342 | return loss * 2 343 | 344 | 
345 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 346 | loss = 0 347 | r_losses = [] 348 | g_losses = [] 349 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 350 | r_loss = torch.mean((1 - dr) ** 2) 351 | g_loss = torch.mean(dg**2) 352 | loss += r_loss + g_loss 353 | r_losses.append(r_loss.item()) 354 | g_losses.append(g_loss.item()) 355 | 356 | return loss, r_losses, g_losses 357 | 358 | 359 | def generator_loss(disc_outputs): 360 | loss = 0 361 | gen_losses = [] 362 | for dg in disc_outputs: 363 | l = torch.mean((1 - dg) ** 2) 364 | gen_losses.append(l) 365 | loss += l 366 | 367 | return loss, gen_losses 368 | -------------------------------------------------------------------------------- /data/LJSpeech-1.1/val.txt: -------------------------------------------------------------------------------- 1 | data/LJSpeech-1.1/wavs/LJ016-0288.wav|"Müller, Müller, He's the man," till a diversion was created by the appearance of the gallows, which was received with continuous yells. 2 | data/LJSpeech-1.1/wavs/LJ028-0275.wav|At last, in the twentieth month, 3 | data/LJSpeech-1.1/wavs/LJ019-0273.wav|which Sir Joshua Jebb told the committee he considered the proper elements of penal discipline. 4 | data/LJSpeech-1.1/wavs/LJ021-0145.wav|From those willing to join in establishing this hoped-for period of peace, 5 | data/LJSpeech-1.1/wavs/LJ009-0076.wav|We come to the sermon. 6 | data/LJSpeech-1.1/wavs/LJ048-0194.wav|during the morning of November twenty-two prior to the motorcade. 7 | data/LJSpeech-1.1/wavs/LJ049-0050.wav|Hill had both feet on the car and was climbing aboard to assist President and Mrs. Kennedy. 8 | data/LJSpeech-1.1/wavs/LJ022-0023.wav|The overwhelming majority of people in this country know how to sift the wheat from the chaff in what they hear and what they read. 9 | data/LJSpeech-1.1/wavs/LJ034-0053.wav|reached the same conclusion as Latona that the prints found on the cartons were those of Lee Harvey Oswald. 10 | data/LJSpeech-1.1/wavs/LJ035-0129.wav|and she must have run down the stairs ahead of Oswald and would probably have seen or heard him. 11 | data/LJSpeech-1.1/wavs/LJ039-0075.wav|once you know that you must put the crosshairs on the target and that is all that is necessary. 12 | data/LJSpeech-1.1/wavs/LJ046-0184.wav|but there is a system for the immediate notification of the Secret Service by the confining institution when a subject is released or escapes. 13 | data/LJSpeech-1.1/wavs/LJ003-0111.wav|He was in consequence put out of the protection of their internal law, end quote. Their code was a subject of some curiosity. 14 | data/LJSpeech-1.1/wavs/LJ037-0234.wav|Mrs. Mary Brock, the wife of a mechanic who worked at the station, was there at the time and she saw a white male, 15 | data/LJSpeech-1.1/wavs/LJ047-0044.wav|Oswald was, however, willing to discuss his contacts with Soviet authorities. He denied having any involvement with Soviet intelligence agencies 16 | data/LJSpeech-1.1/wavs/LJ028-0081.wav|Years later, when the archaeologists could readily distinguish the false from the true, 17 | data/LJSpeech-1.1/wavs/LJ012-0161.wav|he was reported to have fallen away to a shadow. 18 | data/LJSpeech-1.1/wavs/LJ009-0114.wav|Mr. Wakefield winds up his graphic but somewhat sensational account by describing another religious service, which may appropriately be inserted here. 19 | data/LJSpeech-1.1/wavs/LJ028-0335.wav|accordingly they committed to him the command of their whole army, and put the keys of their city into his hands. 
20 | data/LJSpeech-1.1/wavs/LJ005-0014.wav|Speaking on a debate on prison matters, he declared that 21 | data/LJSpeech-1.1/wavs/LJ008-0294.wav|nearly indefinitely deferred. 22 | data/LJSpeech-1.1/wavs/LJ028-0307.wav|then let twenty days pass, and at the end of that time station near the Chaldasan gates a body of four thousand. 23 | data/LJSpeech-1.1/wavs/LJ046-0058.wav|During his Presidency, Franklin D. Roosevelt made almost four hundred journeys and traveled more than three hundred fifty thousand miles. 24 | data/LJSpeech-1.1/wavs/LJ046-0146.wav|The criteria in effect prior to November twenty-two, nineteen sixty-three, for determining whether to accept material for the PRS general files 25 | data/LJSpeech-1.1/wavs/LJ017-0131.wav|even when the high sheriff had told him there was no possibility of a reprieve, and within a few hours of execution. 26 | data/LJSpeech-1.1/wavs/LJ002-0018.wav|The inadequacy of the jail was noticed and reported upon again and again by the grand juries of the city of London, 27 | data/LJSpeech-1.1/wavs/LJ019-0257.wav|Here the tread-wheel was in use, there cellular cranks, or hard-labor machines. 28 | data/LJSpeech-1.1/wavs/LJ034-0042.wav|that he could only testify with certainty that the print was less than three days old. 29 | data/LJSpeech-1.1/wavs/LJ031-0070.wav|Dr. Clark, who most closely observed the head wound, 30 | data/LJSpeech-1.1/wavs/LJ012-0035.wav|the number and names on watches, were carefully removed or obliterated after the goods passed out of his hands. 31 | data/LJSpeech-1.1/wavs/LJ050-0168.wav|with the particular purposes of the agency involved. The Commission recognizes that this is a controversial area 32 | data/LJSpeech-1.1/wavs/LJ036-0103.wav|The police asked him whether he could pick out his passenger from the lineup. 33 | data/LJSpeech-1.1/wavs/LJ016-0318.wav|Other officials, great lawyers, governors of prisons, and chaplains supported this view. 34 | data/LJSpeech-1.1/wavs/LJ034-0198.wav|Euins, who was on the southwest corner of Elm and Houston Streets testified that he could not describe the man he saw in the window. 35 | data/LJSpeech-1.1/wavs/LJ049-0026.wav|On occasion the Secret Service has been permitted to have an agent riding in the passenger compartment with the President. 36 | data/LJSpeech-1.1/wavs/LJ011-0096.wav|He married a lady also belonging to the Society of Friends, who brought him a large fortune, which, and his own money, he put into a city firm, 37 | data/LJSpeech-1.1/wavs/LJ040-0002.wav|Chapter seven. Lee Harvey Oswald: Background and Possible Motives, Part one. 38 | data/LJSpeech-1.1/wavs/LJ014-0030.wav|These were damnatory facts which well supported the prosecution. 39 | data/LJSpeech-1.1/wavs/LJ043-0002.wav|The Warren Commission Report. By The President's Commission on the Assassination of President Kennedy. Chapter seven. Lee Harvey Oswald: 40 | data/LJSpeech-1.1/wavs/LJ029-0022.wav|The original plan called for the President to spend only one day in the State, making whirlwind visits to Dallas, Fort Worth, San Antonio, and Houston. 41 | data/LJSpeech-1.1/wavs/LJ014-0020.wav|He was soon afterwards arrested on suspicion, and a search of his lodgings brought to light several garments saturated with blood; 42 | data/LJSpeech-1.1/wavs/LJ040-0027.wav|He was never satisfied with anything. 43 | data/LJSpeech-1.1/wavs/LJ028-0093.wav|but his scribe wrote it in the manner customary for the scribes of those days to write of their royal masters. 44 | data/LJSpeech-1.1/wavs/LJ004-0152.wav|although at Mr. 
Buxton's visit a new jail was in process of erection, the first step towards reform since Howard's visitation in seventeen seventy-four. 45 | data/LJSpeech-1.1/wavs/LJ008-0111.wav|They entered a "stone cold room," and were presently joined by the prisoner. 46 | data/LJSpeech-1.1/wavs/LJ017-0044.wav|and the deepest anxiety was felt that the crime, if crime there had been, should be brought home to its perpetrator. 47 | data/LJSpeech-1.1/wavs/LJ033-0047.wav|I noticed when I went out that the light was on, end quote, 48 | data/LJSpeech-1.1/wavs/LJ028-0008.wav|you tap gently with your heel upon the shoulder of the dromedary to urge her on. 49 | data/LJSpeech-1.1/wavs/LJ016-0179.wav|contracted with sheriffs and conveners to work by the job. 50 | data/LJSpeech-1.1/wavs/LJ005-0201.wav|as is shown by the report of the Commissioners to inquire into the state of the municipal corporations in eighteen thirty-five. 51 | data/LJSpeech-1.1/wavs/LJ035-0019.wav|drove to the northwest corner of Elm and Houston, and parked approximately ten feet from the traffic signal. 52 | data/LJSpeech-1.1/wavs/LJ031-0038.wav|The first physician to see the President at Parkland Hospital was Dr. Charles J. Carrico, a resident in general surgery. 53 | data/LJSpeech-1.1/wavs/LJ017-0070.wav|but his sporting operations did not prosper, and he became a needy man, always driven to desperate straits for cash. 54 | data/LJSpeech-1.1/wavs/LJ007-0154.wav|These pungent and well-grounded strictures applied with still greater force to the unconvicted prisoner, the man who came to the prison innocent, and still uncontaminated, 55 | data/LJSpeech-1.1/wavs/LJ002-0043.wav|long narrow rooms -- one thirty-six feet, six twenty-three feet, and the eighth eighteen, 56 | data/LJSpeech-1.1/wavs/LJ004-0096.wav|the fatal consequences whereof might be prevented if the justices of the peace were duly authorized 57 | data/LJSpeech-1.1/wavs/LJ018-0081.wav|his defense being that he had intended to commit suicide, but that, on the appearance of this officer who had wronged him, 58 | data/LJSpeech-1.1/wavs/LJ042-0129.wav|No night clubs or bowling alleys, no places of recreation except the trade union dances. I have had enough. 59 | data/LJSpeech-1.1/wavs/LJ008-0278.wav|or theirs might be one of many, and it might be considered necessary to "make an example." 60 | data/LJSpeech-1.1/wavs/LJ015-0203.wav|but were the precautions too minute, the vigilance too close to be eluded or overcome? 61 | data/LJSpeech-1.1/wavs/LJ018-0239.wav|His disappearance gave color and substance to evil reports already in circulation that the will and conveyance above referred to 62 | data/LJSpeech-1.1/wavs/LJ021-0066.wav|together with a great increase in the payrolls, there has come a substantial rise in the total of industrial profits 63 | data/LJSpeech-1.1/wavs/LJ024-0083.wav|This plan of mine is no attack on the Court; 64 | data/LJSpeech-1.1/wavs/LJ008-0258.wav|Let me retrace my steps, and speak more in detail of the treatment of the condemned in those bloodthirsty and brutally indifferent days, 65 | data/LJSpeech-1.1/wavs/LJ038-0199.wav|eleven. If I am alive and taken prisoner, 66 | data/LJSpeech-1.1/wavs/LJ045-0230.wav|when he was finally apprehended in the Texas Theatre. Although it is not fully corroborated by others who were present, 67 | data/LJSpeech-1.1/wavs/LJ027-0141.wav|is closely reproduced in the life-history of existing deer. 
Or, in other words, 68 | data/LJSpeech-1.1/wavs/LJ016-0020.wav|He never reached the cistern, but fell back into the yard, injuring his legs severely. 69 | data/LJSpeech-1.1/wavs/LJ012-0250.wav|On the seventh July, eighteen thirty-seven, 70 | data/LJSpeech-1.1/wavs/LJ001-0110.wav|Even the Caslon type when enlarged shows great shortcomings in this respect: 71 | data/LJSpeech-1.1/wavs/LJ047-0148.wav|On October twenty-five, 72 | data/LJSpeech-1.1/wavs/LJ031-0134.wav|On one occasion Mrs. Johnson, accompanied by two Secret Service agents, left the room to see Mrs. Kennedy and Mrs. Connally. 73 | data/LJSpeech-1.1/wavs/LJ036-0174.wav|This is the approximate time he entered the roominghouse, according to Earlene Roberts, the housekeeper there. 74 | data/LJSpeech-1.1/wavs/LJ026-0068.wav|Energy enters the plant, to a small extent, 75 | data/LJSpeech-1.1/wavs/LJ034-0160.wav|on Brennan's subsequent certain identification of Lee Harvey Oswald as the man he saw fire the rifle. 76 | data/LJSpeech-1.1/wavs/LJ013-0164.wav|who came from his room ready dressed, a suspicious circumstance, as he was always late in the morning. 77 | data/LJSpeech-1.1/wavs/LJ014-0263.wav|When other pleasures palled he took a theatre, and posed as a munificent patron of the dramatic art. 78 | data/LJSpeech-1.1/wavs/LJ005-0079.wav|and improve the morals of the prisoners, and shall insure the proper measure of punishment to convicted offenders. 79 | data/LJSpeech-1.1/wavs/LJ048-0228.wav|and others who were present say that no agent was inebriated or acted improperly. 80 | data/LJSpeech-1.1/wavs/LJ027-0052.wav|These principles of homology are essential to a correct interpretation of the facts of morphology. 81 | data/LJSpeech-1.1/wavs/LJ004-0045.wav|Mr. Sturges Bourne, Sir James Mackintosh, Sir James Scarlett, and William Wilberforce. 82 | data/LJSpeech-1.1/wavs/LJ012-0042.wav|which he kept concealed in a hiding-place with a trap-door just under his bed. 83 | data/LJSpeech-1.1/wavs/LJ014-0110.wav|At the first the boxes were impounded, opened, and found to contain many of O'Connor's effects. 84 | data/LJSpeech-1.1/wavs/LJ028-0506.wav|A modern artist would have difficulty in doing such accurate work. 85 | data/LJSpeech-1.1/wavs/LJ014-0010.wav|yet he could not overcome the strange fascination it had for him, and remained by the side of the corpse till the stretcher came. 86 | data/LJSpeech-1.1/wavs/LJ042-0096.wav|(old exchange rate) in addition to his factory salary of approximately equal amount 87 | data/LJSpeech-1.1/wavs/LJ031-0202.wav|Mrs. Kennedy chose the hospital in Bethesda for the autopsy because the President had served in the Navy. 88 | data/LJSpeech-1.1/wavs/LJ012-0235.wav|While they were in a state of insensibility the murder was committed. 89 | data/LJSpeech-1.1/wavs/LJ019-0186.wav|seeing that since the establishment of the Central Criminal Court, Newgate received prisoners for trial from several counties, 90 | data/LJSpeech-1.1/wavs/LJ018-0098.wav|and recognized as one of the frequenters of the bogus law-stationers. His arrest led to that of others. 91 | data/LJSpeech-1.1/wavs/LJ036-0077.wav|Roger D. Craig, a deputy sheriff of Dallas County, 92 | data/LJSpeech-1.1/wavs/LJ045-0140.wav|The arguments he used to justify his use of the alias suggest that Oswald may have come to think that the whole world was becoming involved 93 | data/LJSpeech-1.1/wavs/LJ029-0032.wav|According to O'Donnell, quote, we had a motorcade wherever we went, end quote. 
94 | data/LJSpeech-1.1/wavs/LJ003-0345.wav|All the committee could do in this respect was to throw the responsibility on others. 95 | data/LJSpeech-1.1/wavs/LJ008-0307.wav|afterwards express a wish to murder the Recorder for having kept them so long in suspense. 96 | data/LJSpeech-1.1/wavs/LJ043-0030.wav|If somebody did that to me, a lousy trick like that, to take my wife away, and all the furniture, I would be mad as hell, too. 97 | data/LJSpeech-1.1/wavs/LJ009-0238.wav|After this the sheriffs sent for another rope, but the spectators interfered, and the man was carried back to jail. 98 | data/LJSpeech-1.1/wavs/LJ039-0223.wav|Oswald's Marine training in marksmanship, his other rifle experience and his established familiarity with this particular weapon 99 | data/LJSpeech-1.1/wavs/LJ014-0076.wav|He was seen afterwards smoking and talking with his hosts in their back parlor, and never seen again alive. 100 | data/LJSpeech-1.1/wavs/LJ016-0138.wav|at a distance from the prison. -------------------------------------------------------------------------------- /fs2/data/text_mel_datamodule.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Any, Dict, Optional 4 | 5 | import numpy as np 6 | import pyworld as pw 7 | import torch 8 | import torchaudio as ta 9 | from lightning import LightningDataModule 10 | from scipy.interpolate import interp1d 11 | from torch.utils.data.dataloader import DataLoader 12 | 13 | from fs2.text import text_to_sequence 14 | from fs2.utils.audio import mel_spectrogram 15 | from fs2.utils.model import normalize 16 | from fs2.utils.utils import intersperse, to_torch, trim_or_pad_to_target_length 17 | 18 | 19 | def parse_filelist(filelist_path, split_char="|"): 20 | with open(filelist_path, encoding="utf-8") as f: 21 | filepaths_and_text = [line.strip().split(split_char) for line in f] 22 | return filepaths_and_text 23 | 24 | 25 | class TextMelDataModule(LightningDataModule): 26 | def __init__( # pylint: disable=unused-argument 27 | self, 28 | name, 29 | train_filelist_path, 30 | valid_filelist_path, 31 | batch_size, 32 | num_workers, 33 | pin_memory, 34 | cleaners, 35 | add_blank, 36 | n_spks, 37 | n_fft, 38 | n_feats, 39 | sample_rate, 40 | hop_length, 41 | win_length, 42 | f_min, 43 | f_max, 44 | data_statistics, 45 | seed, 46 | generate_properties, 47 | ): 48 | super().__init__() 49 | 50 | # this line allows to access init params with 'self.hparams' attribute 51 | # also ensures init params will be stored in ckpt 52 | self.save_hyperparameters(logger=False) 53 | 54 | def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument 55 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 56 | 57 | This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be 58 | careful not to execute things like random split twice! 
59 | """ 60 | # load and split datasets only if not loaded already 61 | 62 | self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init 63 | self.hparams.train_filelist_path, 64 | self.hparams.n_spks, 65 | self.hparams.cleaners, 66 | self.hparams.add_blank, 67 | self.hparams.n_fft, 68 | self.hparams.n_feats, 69 | self.hparams.sample_rate, 70 | self.hparams.hop_length, 71 | self.hparams.win_length, 72 | self.hparams.f_min, 73 | self.hparams.f_max, 74 | self.hparams.data_statistics, 75 | self.hparams.seed, 76 | self.hparams.generate_properties, 77 | ) 78 | self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init 79 | self.hparams.valid_filelist_path, 80 | self.hparams.n_spks, 81 | self.hparams.cleaners, 82 | self.hparams.add_blank, 83 | self.hparams.n_fft, 84 | self.hparams.n_feats, 85 | self.hparams.sample_rate, 86 | self.hparams.hop_length, 87 | self.hparams.win_length, 88 | self.hparams.f_min, 89 | self.hparams.f_max, 90 | self.hparams.data_statistics, 91 | self.hparams.seed, 92 | self.hparams.generate_properties, 93 | ) 94 | 95 | def train_dataloader(self): 96 | return DataLoader( 97 | dataset=self.trainset, 98 | batch_size=self.hparams.batch_size, 99 | num_workers=self.hparams.num_workers, 100 | pin_memory=self.hparams.pin_memory, 101 | shuffle=True, 102 | collate_fn=TextMelBatchCollate(self.hparams.n_spks), 103 | ) 104 | 105 | def val_dataloader(self): 106 | return DataLoader( 107 | dataset=self.validset, 108 | batch_size=self.hparams.batch_size, 109 | num_workers=self.hparams.num_workers, 110 | pin_memory=self.hparams.pin_memory, 111 | shuffle=False, 112 | collate_fn=TextMelBatchCollate(self.hparams.n_spks), 113 | ) 114 | 115 | def teardown(self, stage: Optional[str] = None): 116 | """Clean up after fit or test.""" 117 | pass # pylint: disable=unnecessary-pass 118 | 119 | def state_dict(self): # pylint: disable=no-self-use 120 | """Extra things to save to checkpoint.""" 121 | return {} 122 | 123 | def load_state_dict(self, state_dict: Dict[str, Any]): 124 | """Things to do when loading checkpoint.""" 125 | pass # pylint: disable=unnecessary-pass 126 | 127 | 128 | class TextMelDataset(torch.utils.data.Dataset): 129 | def __init__( 130 | self, 131 | filelist_path, 132 | n_spks, 133 | cleaners, 134 | add_blank=True, 135 | n_fft=1024, 136 | n_mels=80, 137 | sample_rate=22050, 138 | hop_length=256, 139 | win_length=1024, 140 | f_min=0.0, 141 | f_max=8000, 142 | data_statistics=None, 143 | seed=None, 144 | generate_properties=True, 145 | ): 146 | self.filepaths_and_text = parse_filelist(filelist_path) 147 | self.n_spks = n_spks 148 | self.cleaners = cleaners 149 | self.add_blank = add_blank 150 | self.n_fft = n_fft 151 | self.n_mels = n_mels 152 | self.sample_rate = sample_rate 153 | self.hop_length = hop_length 154 | self.win_length = win_length 155 | self.f_min = f_min 156 | self.f_max = f_max 157 | self.data_statistics = data_statistics 158 | self.generate_properties = generate_properties 159 | self.processed_folder_path = Path(filelist_path).parent 160 | 161 | random.seed(seed) 162 | random.shuffle(self.filepaths_and_text) 163 | 164 | def load_durations(self, filepath, text): 165 | durs = np.load(Path(self.processed_folder_path) / 'durations' / Path(Path(filepath).stem).with_suffix(".npy")).astype(int) 166 | assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match" 167 | return durs 168 | 169 | 170 | def get_pitch(self, filepath, phoneme_durations): 171 | _waveform, _sr = ta.load(filepath) 172 | 
_waveform = _waveform.squeeze(0).double().numpy() 173 | assert _sr == self.sample_rate, f"Sample rate mismatch => Found: {_sr} != {self.sample_rate} = Expected" 174 | 175 | pitch, t = pw.dio( 176 | _waveform, self.sample_rate, frame_period=self.hop_length / self.sample_rate * 1000 177 | ) 178 | pitch = pw.stonemask(_waveform, pitch, t, self.sample_rate) 179 | # A cool function taken from fairseq 180 | # https://github.com/facebookresearch/fairseq/blob/3f0f20f2d12403629224347664b3e75c13b2c8e0/examples/speech_synthesis/data_utils.py#L99 181 | pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations)) 182 | 183 | # Interpolate to cover the unvoiced segments as well 184 | nonzero_ids = np.where(pitch != 0)[0] 185 | 186 | interp_fn = interp1d( 187 | nonzero_ids, 188 | pitch[nonzero_ids], 189 | fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), 190 | bounds_error=False, 191 | ) 192 | pitch = interp_fn(np.arange(0, len(pitch))) 193 | 194 | # Compute phoneme-wise average 195 | d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations])) 196 | pitch = np.array( 197 | [ 198 | np.mean(pitch[d_cumsum[i-1]: d_cumsum[i]]) 199 | for i in range(1, len(d_cumsum)) 200 | ] 201 | ) 202 | assert len(pitch) == len(phoneme_durations) 203 | return pitch 204 | 205 | def mean_phoneme_energy(self, energy, phoneme_durations): 206 | energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations)) 207 | d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations])) 208 | energy = np.array( 209 | [ 210 | np.mean(energy[d_cumsum[i - 1]: d_cumsum[i]]) 211 | for i in range(1, len(d_cumsum)) 212 | ] 213 | ) 214 | assert len(energy) == len(phoneme_durations) 215 | 216 | # if log_scale: 217 | # # In fairseq they do it 218 | # energy = np.log(energy + 1) 219 | 220 | return energy 221 | 222 | def get_datapoint(self, filepath_and_text): 223 | if self.n_spks > 1: 224 | filepath, spk, text = ( 225 | filepath_and_text[0], 226 | int(filepath_and_text[1]), 227 | filepath_and_text[2], 228 | ) 229 | else: 230 | filepath, text = filepath_and_text[0], filepath_and_text[1] 231 | spk = None 232 | 233 | if self.generate_properties: 234 | text = self.get_text(text, add_blank=self.add_blank) 235 | phoneme_durations = self.load_durations(filepath, text) 236 | mel, energy = self.get_mel(filepath) 237 | energy = self.mean_phoneme_energy(energy.squeeze().cpu().numpy(), phoneme_durations) 238 | pitch = self.get_pitch(filepath, phoneme_durations) 239 | # Do not normalise them in this case as this is supposed to be called by 240 | # python fs2/utils/preprocess.py -i ljspeech 241 | else: 242 | text = self.get_text(text, add_blank=self.add_blank) 243 | phoneme_durations = self.load_durations(filepath, text) 244 | assert len(phoneme_durations) == len(text) 245 | pitch = np.load(Path(self.processed_folder_path) / 'pitch' / Path(Path(filepath).stem).with_suffix(".npy")) 246 | pitch = normalize(pitch, self.data_statistics['pitch_mean'], self.data_statistics['pitch_std']) 247 | assert len(pitch) == len(text) 248 | mel = np.load(Path(self.processed_folder_path) / 'mel' / Path(Path(filepath).stem).with_suffix(".npy")) 249 | mel = normalize(mel, self.data_statistics['mel_mean'], self.data_statistics['mel_std']) 250 | assert mel.shape[-1] == sum(phoneme_durations) 251 | energy = np.load(Path(self.processed_folder_path) / 'energy' / Path(Path(filepath).stem).with_suffix(".npy")) 252 | energy = normalize(energy, self.data_statistics['energy_mean'], self.data_statistics['energy_std']) 253 | assert len(energy) == 
len(text) 254 | 255 | return {"x": text, "y": mel, "spk": spk, 'filepath': filepath, 256 | 'energy': energy, 'pitch': pitch, 'duration': phoneme_durations} 257 | 258 | def get_mel(self, filepath): 259 | audio, sr = ta.load(filepath) 260 | assert sr == self.sample_rate 261 | mel, energy = mel_spectrogram( 262 | audio, 263 | self.n_fft, 264 | self.n_mels, 265 | self.sample_rate, 266 | self.hop_length, 267 | self.win_length, 268 | self.f_min, 269 | self.f_max, 270 | center=False, 271 | ) 272 | return mel, energy 273 | 274 | def get_text(self, text, add_blank=True): 275 | text_norm = text_to_sequence(text, self.cleaners) 276 | if self.add_blank: 277 | text_norm = intersperse(text_norm, 0) 278 | text_norm = torch.IntTensor(text_norm) 279 | return text_norm 280 | 281 | def __getitem__(self, index): 282 | datapoint = self.get_datapoint(self.filepaths_and_text[index]) 283 | return datapoint 284 | 285 | def __len__(self): 286 | return len(self.filepaths_and_text) 287 | 288 | 289 | class TextMelBatchCollate: 290 | def __init__(self, n_spks): 291 | self.n_spks = n_spks 292 | 293 | def __call__(self, batch): 294 | B = len(batch) 295 | y_max_length = max([item["y"].shape[-1] for item in batch]) 296 | x_max_length = max([item["x"].shape[-1] for item in batch]) 297 | n_feats = batch[0]["y"].shape[-2] 298 | 299 | y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) 300 | x = torch.zeros((B, x_max_length), dtype=torch.long) 301 | pitches = torch.zeros((B, x_max_length), dtype=torch.float) 302 | energies = torch.zeros((B, x_max_length), dtype=torch.float) 303 | durations = torch.zeros((B, x_max_length), dtype=torch.long) 304 | 305 | 306 | y_lengths, x_lengths = [], [] 307 | spks = [] 308 | 309 | filepaths = [] 310 | for i, item in enumerate(batch): 311 | y_, x_ = item["y"], item["x"] 312 | y_lengths.append(y_.shape[-1]) 313 | x_lengths.append(x_.shape[-1]) 314 | y[i, :, : y_.shape[-1]] = to_torch(y_, torch.float32) 315 | x[i, : x_.shape[-1]] = to_torch(x_, torch.long) 316 | spks.append(item["spk"]) 317 | 318 | pitches[i, : item["pitch"].shape[-1]] = to_torch(item["pitch"], torch.float) 319 | energies[i, : item["energy"].shape[-1]] = to_torch(item["energy"], torch.float) 320 | durations[i, : item["duration"].shape[-1]] = to_torch(item["duration"], torch.float) 321 | filepaths.append(item['filepath']) 322 | 323 | y_lengths = torch.tensor(y_lengths, dtype=torch.long) 324 | x_lengths = torch.tensor(x_lengths, dtype=torch.long) 325 | spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None 326 | 327 | return { 328 | "x": x, 329 | "x_lengths": x_lengths, 330 | "y": y, 331 | "y_lengths": y_lengths, 332 | "spks": spks, 333 | 'pitches': pitches, 334 | 'energies': energies, 335 | 'durations': durations, 336 | 'filepaths': filepaths 337 | } --------------------------------------------------------------------------------
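Note on the phoneme-level features above: in fs2/data/text_mel_datamodule.py, get_pitch and mean_phoneme_energy reduce a frame-level contour to one value per phoneme by averaging over segments whose boundaries are the cumulative sum of the phoneme durations. The snippet below is a minimal, self-contained sketch of that averaging step, assuming only NumPy; the helper name average_by_duration and the toy numbers are illustrative and not part of the repository.

import numpy as np


def average_by_duration(contour, durations):
    # contour: frame-level values (e.g. F0 or energy), one value per mel frame
    # durations: integer number of frames per phoneme; cumsum gives segment boundaries
    assert len(contour) == int(np.sum(durations)), "contour must cover every frame"
    bounds = np.cumsum(np.concatenate([[0], durations])).astype(int)
    # one mean per phoneme segment; zero-length durations are not handled in this sketch
    return np.array(
        [contour[bounds[i - 1]:bounds[i]].mean() for i in range(1, len(bounds))]
    )


# Toy example: three phonemes spanning 2 + 3 + 1 = 6 frames
contour = np.array([100.0, 110.0, 200.0, 210.0, 220.0, 50.0])
durations = np.array([2, 3, 1])
print(average_by_duration(contour, durations))  # [105. 210.  50.]

In the dataset class the resulting per-phoneme pitch and energy vectors have the same length as the phoneme sequence, which is why TextMelBatchCollate pads them to the batch's maximum text length rather than to the mel length.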