├── .gitignore ├── .project-root ├── 20k_tboard.png ├── LICENSE ├── LJSpeech_Sample_100_epochs.wav ├── README.md ├── architecture.jpg ├── configs ├── __init__.py ├── callbacks │ ├── default.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ ├── none.yaml │ └── rich_progress_bar.yaml ├── data │ ├── ljspeech.yaml │ └── vctk.yaml ├── debug │ ├── default.yaml │ ├── fdr.yaml │ ├── limit.yaml │ ├── overfit.yaml │ └── profiler.yaml ├── eval.yaml ├── experiment │ ├── ljspeech.yaml │ ├── ljspeech_min_memory.yaml │ └── multispeaker.yaml ├── extras │ └── default.yaml ├── hparams_search │ └── mnist_optuna.yaml ├── hydra │ └── default.yaml ├── local │ └── .gitkeep ├── logger │ ├── aim.yaml │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ ├── cfm │ │ └── default.yaml │ ├── decoder │ │ └── default.yaml │ ├── encoder │ │ └── default.yaml │ ├── optimizer │ │ └── adam.yaml │ ├── pflow.yaml │ └── scheduler │ │ └── default.yaml ├── paths │ └── default.yaml ├── train.yaml └── trainer │ ├── cpu.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ ├── default.yaml │ ├── gpu.yaml │ └── mps.yaml ├── encodec_poc.wav ├── init_tensorboard_11_13_23.png ├── notebooks ├── .gitkeep ├── dry_run.ipynb ├── dry_run_py.py ├── model_protoyping.ipynb └── synthesis.ipynb ├── pflow ├── __init__.py ├── cli.py ├── data │ ├── __init__.py │ ├── components │ │ └── __init__.py │ └── text_mel_datamodule.py ├── hifigan │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── config.py │ ├── denoiser.py │ ├── env.py │ ├── meldataset.py │ ├── models.py │ └── xutils.py ├── models │ ├── __init__.py │ ├── baselightningmodule.py │ ├── components │ │ ├── __init__.py │ │ ├── aligner.py │ │ ├── attentions.py │ │ ├── commons.py │ │ ├── decoder.py │ │ ├── flow_matching.py │ │ ├── speech_prompt_encoder.py │ │ ├── speech_prompt_encoder_v0.py │ │ ├── test.py │ │ ├── text_encoder.py │ │ ├── transformer.py │ │ ├── vits_modules.py │ │ ├── vits_posterior.py │ │ ├── vits_wn_decoder.py │ │ └── wn_pflow_decoder.py │ └── pflow_tts.py ├── onnx │ ├── __init__.py │ ├── export.py │ └── infer.py ├── text │ ├── __init__.py │ ├── cleaners.py │ ├── numbers.py │ └── symbols.py ├── train.py └── utils │ ├── __init__.py │ ├── audio.py │ ├── generate_data_statistics.py │ ├── instantiators.py │ ├── logging_utils.py │ ├── model.py │ ├── monotonic_align │ ├── __init__.py │ └── core.pyx │ ├── pylogger.py │ ├── rich_utils.py │ └── utils.py ├── requirements.txt ├── samples ├── download_54.wav ├── download_55.wav ├── download_56.wav ├── download_57.wav ├── download_58.wav ├── download_64.wav └── download_65.wav ├── setup.py └── val_out_tboard.png /.gitignore: -------------------------------------------------------------------------------- 1 | logs 2 | __pycache__ 3 | .ipynb_checkpoints 4 | .*.swp 5 | 6 | build 7 | *.c 8 | pflow/utils/monotonic_align/*.pyd 9 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /20k_tboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/20k_tboard.png -------------------------------------------------------------------------------- 
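A note on the `.project-root` marker listed in the tree above: the training and CLI scripts locate the repository root by searching parent directories for this file rather than relying on the current working directory. Below is a minimal sketch of how such an indicator file is typically consumed, assuming the `rootutils` helper used by the lightning-hydra-template that this layout follows; the exact call in `pflow/train.py` may differ.

    import rootutils

    # Search upward from this file for the ".project-root" indicator, add the found root
    # to sys.path, and export it as the PROJECT_ROOT environment variable, which
    # configs/paths/default.yaml later reads via ${oc.env:PROJECT_ROOT}.
    root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)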
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 p0p 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LJSpeech_Sample_100_epochs.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/LJSpeech_Sample_100_epochs.wav -------------------------------------------------------------------------------- /architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/architecture.jpg -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - model_summary.yaml 4 | - rich_progress_bar.yaml 5 | - _self_ 6 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: ${paths.output_dir}/checkpoints # directory to save the model file 6 | filename: checkpoint_{epoch:03d} # checkpoint filename 7 | monitor: epoch # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 10 # save k best models (determined by above metric) 11 | mode: "max" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | 
every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: 100 # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 3 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /configs/data/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | _target_: pflow.data.text_mel_datamodule.TextMelDataModule 2 | name: ljspeech 3 | train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt 4 | valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt 5 | batch_size: 32 6 | num_workers: 20 7 | pin_memory: True 8 | cleaners: [english_cleaners2] 9 | add_blank: True 10 | n_spks: 1 11 | n_fft: 1024 12 | n_feats: 80 13 | sample_rate: 22050 14 | hop_length: 256 15 | win_length: 1024 16 | f_min: 0 17 | f_max: 8000 18 | data_statistics: # Computed for ljspeech dataset 19 | mel_mean: -5.536622 20 | mel_std: 2.116101 21 | seed: ${seed} 22 | min_sample_size: 4 23 | -------------------------------------------------------------------------------- /configs/data/vctk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ljspeech 3 | - _self_ 4 | 5 | _target_: pflow.data.text_mel_datamodule.TextMelDataModule 6 | name: vctk 7 | train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt 8 | valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt 9 | batch_size: 32 10 | add_blank: True 11 | n_spks: 109 12 | data_statistics: # Computed for vctk dataset 13 | mel_mean: -6.630575 14 | mel_std: 2.482914 15 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | # callbacks: null 11 | # logger: null 12 | 13 | extras: 14 | 
ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 1 10 | # profiler: "simple" 11 | profiler: "advanced" 12 | # profiler: "pytorch" 13 | accelerator: gpu 14 | 15 | limit_train_batches: 0.02 16 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: mnist # choose datamodule with `test_dataloader()` for evaluation 6 | - model: mnist 7 | - logger: null 8 | - trainer: default 9 | - paths: default 10 | - extras: default 11 | - hydra: default 12 | 13 | task_name: "eval" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for evaluation 18 | ckpt_path: ??? 
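# note: the `data: mnist` and `model: mnist` defaults above are template placeholders with
# no matching configs in this repo (configs/data and configs/model ship ljspeech/vctk and
# pflow instead); override them together with the checkpoint when evaluating, e.g. an
# illustrative command-line override:
#   data=ljspeech model=pflow ckpt_path=<path/to/checkpoint.ckpt>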
19 | -------------------------------------------------------------------------------- /configs/experiment/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=ljspeech 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech 15 | -------------------------------------------------------------------------------- /configs/experiment/ljspeech_min_memory.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=ljspeech_min_memory 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech_min 15 | 16 | 17 | model: 18 | out_size: 172 19 | -------------------------------------------------------------------------------- /configs/experiment/multispeaker.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: vctk.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["multispeaker"] 13 | 14 | run_name: multispeaker 15 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 
11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | data.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log 20 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group 
runs under experiment name 15 | experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet 5 | - csv 6 | # - mlflow 7 | # - neptune 8 | - tensorboard 9 | - wandb 10 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | 
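# to browse these logs locally, point TensorBoard at the logging directory, e.g.
# `tensorboard --logdir logs/train` (illustrative path; each run writes its event
# files under <output_dir>/tensorboard/ as configured by save_dir above)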
prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "lightning-hydra-template" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/model/cfm/default.yaml: -------------------------------------------------------------------------------- 1 | name: CFM 2 | solver: euler 3 | sigma_min: 1e-4 4 | -------------------------------------------------------------------------------- /configs/model/decoder/default.yaml: -------------------------------------------------------------------------------- 1 | channels: [256, 256] 2 | dropout: 0.05 3 | attention_head_dim: 64 4 | n_blocks: 1 5 | num_mid_blocks: 2 6 | num_heads: 2 7 | act_fn: snakebeta 8 | -------------------------------------------------------------------------------- /configs/model/encoder/default.yaml: -------------------------------------------------------------------------------- 1 | encoder_type: RoPE Encoder 2 | encoder_params: 3 | n_feats: ${model.n_feats} 4 | n_channels: 192 5 | filter_channels: 768 6 | filter_channels_dp: 256 7 | n_heads: 2 8 | n_layers: 6 9 | kernel_size: 3 10 | p_dropout: 0.1 11 | spk_emb_dim: 64 12 | n_spks: 1 13 | prenet: true 14 | 15 | duration_predictor_params: 16 | filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp} 17 | kernel_size: 3 18 | p_dropout: ${model.encoder.encoder_params.p_dropout} 19 | -------------------------------------------------------------------------------- /configs/model/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | _partial_: true 3 | lr: 1e-4 4 | weight_decay: 0.0 5 | -------------------------------------------------------------------------------- /configs/model/pflow.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - encoder: default.yaml 4 | - decoder: default.yaml 5 | - cfm: default.yaml 6 | - optimizer: adam.yaml 7 | 8 | _target_: pflow.models.pflow_tts.pflowTTS 9 | n_vocab: 178 10 | n_spks: ${data.n_spks} 11 | spk_emb_dim: 64 12 | n_feats: 80 13 | data_statistics: ${data.data_statistics} 14 | out_size: null # Must be divisible by 4 15 | prompt_size: 264 #number of mel frames (3s at 22050 SR; 3*22050//256) 16 | dur_p_use_log: False -------------------------------------------------------------------------------- /configs/model/scheduler/default.yaml: -------------------------------------------------------------------------------- 1 | scheduler: 2 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 3 | _partial_: true 4 | mode: min 5 | factor: 0.1 6 | patience: 10 -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to 
root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: ljspeech 8 | - model: pflow 9 | - callbacks: default 10 | - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | run_name: ??? 
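# run_name has no default and must resolve before a run starts: the experiment configs
# set it for you (e.g. `python train.py experiment=ljspeech`), or pass it explicitly,
# e.g. `python train.py run_name=my_run` (illustrative commands, mirroring the overrides above)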
34 | 35 | # tags to help you identify your experiments 36 | # you can overwrite this in experiment configs 37 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 38 | tags: ["dev"] 39 | 40 | # set False to skip model training 41 | train: True 42 | 43 | # evaluate on test set, using best model weights achieved during training 44 | # lightning chooses best weights based on the metric specified in checkpoint callback 45 | test: True 46 | 47 | # simply provide checkpoint path to resume training 48 | ckpt_path: null 49 | 50 | # transfer learning checkpoint path to initialize model weights 51 | transfer_ckpt_path: null 52 | 53 | # seed for random number generators in pytorch, numpy and python.random 54 | seed: 1234 55 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: [0,1] 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | max_epochs: -1 6 | 7 | accelerator: gpu 8 | devices: [0] 9 | 10 | # mixed precision for extra speed-up 11 | precision: 16-mixed 12 | 13 | # perform a validation loop every N training epochs 14 | check_val_every_n_epoch: 1 15 | 16 | # set True to to ensure deterministic results 17 | # makes training slower but gives more reproducibility than just setting seeds 18 | deterministic: False 19 | 20 | gradient_clip_val: 5.0 21 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /encodec_poc.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/encodec_poc.wav -------------------------------------------------------------------------------- /init_tensorboard_11_13_23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/init_tensorboard_11_13_23.png 
-------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/notebooks/.gitkeep -------------------------------------------------------------------------------- /notebooks/dry_run.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append('..')" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 5, 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "name": "stdout", 20 | "output_type": "stream", 21 | "text": [ 22 | "tensor(4.0884, grad_fn=) tensor(1.5378, grad_fn=) tensor(6.8176, grad_fn=)\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "from pflow.models.pflow_tts import pflowTTS\n", 28 | "import torch\n", 29 | "from dataclasses import dataclass\n", 30 | "\n", 31 | "@dataclass\n", 32 | "class DurationPredictorParams:\n", 33 | " filter_channels_dp: int\n", 34 | " kernel_size: int\n", 35 | " p_dropout: float\n", 36 | "\n", 37 | "@dataclass\n", 38 | "class EncoderParams:\n", 39 | " n_feats: int\n", 40 | " n_channels: int\n", 41 | " filter_channels: int\n", 42 | " filter_channels_dp: int\n", 43 | " n_heads: int\n", 44 | " n_layers: int\n", 45 | " kernel_size: int\n", 46 | " p_dropout: float\n", 47 | " spk_emb_dim: int\n", 48 | " n_spks: int\n", 49 | " prenet: bool\n", 50 | "\n", 51 | "@dataclass\n", 52 | "class CFMParams:\n", 53 | " name: str\n", 54 | " solver: str\n", 55 | " sigma_min: float\n", 56 | "\n", 57 | "# Example usage\n", 58 | "duration_predictor_params = DurationPredictorParams(\n", 59 | " filter_channels_dp=256,\n", 60 | " kernel_size=3,\n", 61 | " p_dropout=0.1\n", 62 | ")\n", 63 | "\n", 64 | "encoder_params = EncoderParams(\n", 65 | " n_feats=80,\n", 66 | " n_channels=192,\n", 67 | " filter_channels=768,\n", 68 | " filter_channels_dp=256,\n", 69 | " n_heads=2,\n", 70 | " n_layers=6,\n", 71 | " kernel_size=3,\n", 72 | " p_dropout=0.1,\n", 73 | " spk_emb_dim=64,\n", 74 | " n_spks=1,\n", 75 | " prenet=True\n", 76 | ")\n", 77 | "\n", 78 | "cfm_params = CFMParams(\n", 79 | " name='CFM',\n", 80 | " solver='euler',\n", 81 | " sigma_min=1e-4\n", 82 | ")\n", 83 | "\n", 84 | "@dataclass\n", 85 | "class EncoderOverallParams:\n", 86 | " encoder_type: str\n", 87 | " encoder_params: EncoderParams\n", 88 | " duration_predictor_params: DurationPredictorParams\n", 89 | "\n", 90 | "encoder_overall_params = EncoderOverallParams(\n", 91 | " encoder_type='RoPE Encoder',\n", 92 | " encoder_params=encoder_params,\n", 93 | " duration_predictor_params=duration_predictor_params\n", 94 | ")\n", 95 | "\n", 96 | "model = pflowTTS(\n", 97 | " n_vocab=100,\n", 98 | " n_feats=80,\n", 99 | " encoder=encoder_overall_params,\n", 100 | " decoder=None,\n", 101 | " cfm=cfm_params,\n", 102 | " data_statistics=None,\n", 103 | ")\n", 104 | "\n", 105 | "x = torch.randint(0, 100, (4, 20))\n", 106 | "x_lengths = torch.randint(10, 20, (4,))\n", 107 | "y = torch.randn(4, 80, 500)\n", 108 | "y_lengths = torch.randint(300, 500, (4,))\n", 109 | "\n", 110 | "dur_loss, prior_loss, diff_loss, attn = model(x, x_lengths, y, y_lengths)\n", 111 | "\n", 112 | "print(dur_loss, prior_loss, diff_loss)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 6, 118 | "metadata": 
{}, 119 | "outputs": [ 120 | { 121 | "data": { 122 | "text/plain": [ 123 | "{'encoder_outputs': tensor([[[ 0.5504, 0.1342, 0.1342, ..., 1.2872, 1.2872, 0.5121],\n", 124 | " [ 0.2854, -0.0067, -0.0067, ..., -0.2164, -0.2164, 0.1162],\n", 125 | " [ 1.0909, 0.0971, 0.0971, ..., -0.4140, -0.4140, 0.0093],\n", 126 | " ...,\n", 127 | " [-0.6167, 0.0214, 0.0214, ..., 0.1322, 0.1322, 0.0024],\n", 128 | " [ 0.7357, 0.7161, 0.7161, ..., 0.0576, 0.0576, 0.1908],\n", 129 | " [-0.3782, -0.0351, -0.0351, ..., 0.5459, 0.5459, -0.3888]]]),\n", 130 | " 'decoder_outputs': tensor([[[ 0.2233, 0.6986, -0.4587, ..., 1.7759, -1.5674, -0.4869],\n", 131 | " [ 0.3813, 0.3476, 0.1070, ..., -1.4641, -0.0952, 1.0354],\n", 132 | " [ 1.4565, 1.2124, -0.3740, ..., -0.8082, 0.4223, -1.3775],\n", 133 | " ...,\n", 134 | " [ 1.7223, -1.4008, 0.5498, ..., 0.7512, 0.2925, -0.6928],\n", 135 | " [ 0.8185, 0.6916, 0.1859, ..., -0.4052, -0.8805, 0.8896],\n", 136 | " [ 0.9668, -0.3577, -0.0522, ..., -0.3391, 1.8045, 1.4378]]]),\n", 137 | " 'attn': tensor([[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 138 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 139 | " [0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 140 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 141 | " [0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 142 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 143 | " [0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 144 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 145 | " [0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 146 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 147 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0.,\n", 148 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 149 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.,\n", 150 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 151 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,\n", 152 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 153 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 154 | " 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 155 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 156 | " 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 157 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 158 | " 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 159 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 160 | " 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 161 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 162 | " 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 163 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 164 | " 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],\n", 165 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 166 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n", 167 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 168 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n", 169 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
0., 0., 0.,\n", 170 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n", 171 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 172 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.],\n", 173 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 174 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n", 175 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 176 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]),\n", 177 | " 'mel': tensor([[[ 0.2233, 0.6986, -0.4587, ..., 1.7759, -1.5674, -0.4869],\n", 178 | " [ 0.3813, 0.3476, 0.1070, ..., -1.4641, -0.0952, 1.0354],\n", 179 | " [ 1.4565, 1.2124, -0.3740, ..., -0.8082, 0.4223, -1.3775],\n", 180 | " ...,\n", 181 | " [ 1.7223, -1.4008, 0.5498, ..., 0.7512, 0.2925, -0.6928],\n", 182 | " [ 0.8185, 0.6916, 0.1859, ..., -0.4052, -0.8805, 0.8896],\n", 183 | " [ 0.9668, -0.3577, -0.0522, ..., -0.3391, 1.8045, 1.4378]]]),\n", 184 | " 'mel_lengths': tensor([32]),\n", 185 | " 'rtf': 1.1681655029296876}" 186 | ] 187 | }, 188 | "execution_count": 6, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "x = torch.randint(0, 100, (1, 20))\n", 195 | "x_lengths = torch.randint(10, 20, (1,))\n", 196 | "y_slice = torch.randn(1, 80, 264)\n", 197 | "\n", 198 | "model.synthesise(x, x_lengths, y_slice, n_timesteps=10)" 199 | ] 200 | } 201 | ], 202 | "metadata": { 203 | "kernelspec": { 204 | "display_name": "pytorch", 205 | "language": "python", 206 | "name": "python3" 207 | }, 208 | "language_info": { 209 | "codemirror_mode": { 210 | "name": "ipython", 211 | "version": 3 212 | }, 213 | "file_extension": ".py", 214 | "mimetype": "text/x-python", 215 | "name": "python", 216 | "nbconvert_exporter": "python", 217 | "pygments_lexer": "ipython3", 218 | "version": "3.10.9" 219 | } 220 | }, 221 | "nbformat": 4, 222 | "nbformat_minor": 2 223 | } 224 | -------------------------------------------------------------------------------- /notebooks/dry_run_py.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | from pflow.models.pflow_tts import pflowTTS 5 | import torch 6 | from dataclasses import dataclass 7 | 8 | @dataclass 9 | class DurationPredictorParams: 10 | filter_channels_dp: int 11 | kernel_size: int 12 | p_dropout: float 13 | 14 | @dataclass 15 | class EncoderParams: 16 | n_feats: int 17 | n_channels: int 18 | filter_channels: int 19 | filter_channels_dp: int 20 | n_heads: int 21 | n_layers: int 22 | kernel_size: int 23 | p_dropout: float 24 | spk_emb_dim: int 25 | n_spks: int 26 | prenet: bool 27 | 28 | @dataclass 29 | class CFMParams: 30 | name: str 31 | solver: str 32 | sigma_min: float 33 | 34 | # Example usage 35 | duration_predictor_params = DurationPredictorParams( 36 | filter_channels_dp=256, 37 | kernel_size=3, 38 | p_dropout=0.1 39 | ) 40 | 41 | encoder_params = EncoderParams( 42 | n_feats=80, 43 | n_channels=192, 44 | filter_channels=768, 45 | filter_channels_dp=256, 46 | n_heads=2, 47 | n_layers=6, 48 | kernel_size=3, 49 | p_dropout=0.1, 50 | spk_emb_dim=64, 51 | n_spks=1, 52 | prenet=True 53 | ) 54 | 55 | cfm_params = CFMParams( 56 | name='CFM', 57 | solver='euler', 58 | sigma_min=1e-4 59 | ) 60 | 61 | @dataclass 62 | class EncoderOverallParams: 63 | encoder_type: str 64 | encoder_params: EncoderParams 65 | duration_predictor_params: DurationPredictorParams 66 | 67 | encoder_overall_params = 
EncoderOverallParams( 68 | encoder_type='RoPE Encoder', 69 | encoder_params=encoder_params, 70 | duration_predictor_params=duration_predictor_params 71 | ) 72 | 73 | @dataclass 74 | class DecoderParams: 75 | channels: tuple 76 | dropout: float 77 | attention_head_dim: int 78 | n_blocks: int 79 | num_mid_blocks: int 80 | num_heads: int 81 | act_fn: str 82 | 83 | decoder_params = DecoderParams( 84 | channels=(256, 256), 85 | dropout=0.05, 86 | attention_head_dim=64, 87 | n_blocks=1, 88 | num_mid_blocks=2, 89 | num_heads=2, 90 | act_fn='snakebeta', 91 | ) 92 | 93 | model = pflowTTS( 94 | n_vocab=100, 95 | n_feats=80, 96 | encoder=encoder_overall_params, 97 | decoder=decoder_params.__dict__, 98 | cfm=cfm_params, 99 | data_statistics=None, 100 | ) 101 | 102 | x = torch.randint(0, 100, (4, 20)) 103 | x_lengths = torch.randint(10, 20, (4,)) 104 | y = torch.randn(4, 80, 500) 105 | y_lengths = torch.randint(300, 500, (4,)) 106 | 107 | dur_loss, prior_loss, diff_loss, attn = model(x, x_lengths, y, y_lengths) 108 | 109 | print(dur_loss, prior_loss, diff_loss) 110 | -------------------------------------------------------------------------------- /notebooks/synthesis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "148f4bc0-c28e-4670-9a5e-4c7928ab8992", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%env CUDA_VISIBLE_DEVICES=\"0\"\n", 11 | "import sys\n", 12 | "sys.path.append('..')" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "8d5876c0-b47e-4c80-9e9c-62550f81b64e", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import datetime as dt\n", 23 | "from pathlib import Path\n", 24 | "\n", 25 | "import IPython.display as ipd\n", 26 | "import numpy as np\n", 27 | "import soundfile as sf\n", 28 | "import torch\n", 29 | "from tqdm.auto import tqdm\n", 30 | "\n", 31 | "from pflow.models.pflow_tts import pflowTTS\n", 32 | "from pflow.text import sequence_to_text, text_to_sequence\n", 33 | "from pflow.utils.model import denormalize\n", 34 | "from pflow.utils.utils import get_user_data_dir, intersperse\n", 35 | "\n", 36 | "from pflow.hifigan.config import v1\n", 37 | "from pflow.hifigan.denoiser import Denoiser\n", 38 | "from pflow.hifigan.env import AttrDict\n", 39 | "from pflow.hifigan.models import Generator as HiFiGAN" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "b1a30306-588c-4f22-8d9b-e2676880b0e5", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "%load_ext autoreload\n", 50 | "%autoreload 2\n", 51 | "%matplotlib inline\n", 52 | "# This allows for real time code changes being reflected in the notebook, no need to restart the kernel" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "a312856b-01a9-4d75-a4c8-4666dffa0692", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "23b3bbf9", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "device = \"cpu\"" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "3d3f3db9", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "get_user_data_dir()" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "id": 
"88f3b3c3-d014-443b-84eb-e143cdec3e21", 88 | "metadata": {}, 89 | "source": [ 90 | "## Filepaths" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "7640a4c1-44ce-447c-a8ff-45012fb7bddd", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "PFLOW_CHECKPOINT = \"\" #fill in the path to the pflow checkpoint\n", 101 | "HIFIGAN_CHECKPOINT = get_user_data_dir() / \"hifigan_T2_v1\"\n", 102 | "OUTPUT_FOLDER = \"synth_output\"" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "6477a3a9-71f2-4d2f-bb86-bdf3e31c2461", 108 | "metadata": {}, 109 | "source": [ 110 | "## Load TTS model" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "26a16230-04ba-4825-a844-2fb5ab945e24", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "def load_model(checkpoint_path):\n", 121 | " model = pflowTTS.load_from_checkpoint(checkpoint_path, map_location=device)\n", 122 | " model.eval()\n", 123 | " return model\n", 124 | "count_params = lambda x: f\"{sum(p.numel() for p in x.parameters()):,}\"\n", 125 | "\n", 126 | "\n", 127 | "model = load_model(PFLOW_CHECKPOINT)\n", 128 | "print(f\"Model loaded! Parameter count: {count_params(model)}\")" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "def load_vocoder(checkpoint_path):\n", 138 | " h = AttrDict(v1)\n", 139 | " hifigan = HiFiGAN(h).to(device)\n", 140 | " hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)['generator'])\n", 141 | " _ = hifigan.eval()\n", 142 | " hifigan.remove_weight_norm()\n", 143 | " return hifigan\n", 144 | "\n", 145 | "vocoder = load_vocoder(HIFIGAN_CHECKPOINT)\n", 146 | "denoiser = Denoiser(vocoder, mode='zeros')" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "4cbc2ba0-09ff-40e2-9e60-6b77b534f9fb", 152 | "metadata": {}, 153 | "source": [ 154 | "### Helper functions to synthesise" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "880a1879-24fd-4757-849c-850339120796", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "@torch.inference_mode()\n", 165 | "def process_text(text: str):\n", 166 | " x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2']), 0),dtype=torch.long, device=device)[None]\n", 167 | " x_lengths = torch.tensor([x.shape[-1]],dtype=torch.long, device=device)\n", 168 | " x_phones = sequence_to_text(x.squeeze(0).tolist())\n", 169 | " return {\n", 170 | " 'x_orig': text,\n", 171 | " 'x': x,\n", 172 | " 'x_lengths': x_lengths,\n", 173 | " 'x_phones': x_phones\n", 174 | " }\n", 175 | "\n", 176 | "\n", 177 | "@torch.inference_mode()\n", 178 | "def synthesise(text, prompt):\n", 179 | " text_processed = process_text(text)\n", 180 | " start_t = dt.datetime.now()\n", 181 | " output = model.synthesise(\n", 182 | " text_processed['x'], \n", 183 | " text_processed['x_lengths'],\n", 184 | " n_timesteps=n_timesteps,\n", 185 | " temperature=temperature,\n", 186 | " length_scale=length_scale,\n", 187 | " prompt=prompt\n", 188 | " )\n", 189 | " # merge everything to one dict \n", 190 | " output.update({'start_t': start_t, **text_processed})\n", 191 | " return output\n", 192 | "\n", 193 | "@torch.inference_mode()\n", 194 | "def to_waveform(mel, vocoder):\n", 195 | " audio = vocoder(mel).clamp(-1, 1)\n", 196 | " audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()\n", 197 | " return 
audio.cpu().squeeze()\n", 198 | " \n", 199 | "# def save_to_folder(filename: str, output: dict, folder: str):\n", 200 | "# folder = Path(folder)\n", 201 | "# folder.mkdir(exist_ok=True, parents=True)\n", 202 | "# np.save(folder / f'{filename}', output['mel'].cpu().numpy())\n", 203 | "# sf.write(folder / f'{filename}.wav', output['waveform'], 22050, 'PCM_24')" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "id": "78f857e3-2ef7-4c86-b776-596c4d3cf875", 209 | "metadata": {}, 210 | "source": [ 211 | "## Setup text to synthesise" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "id": "2e0a9acd-0845-4192-ba09-b9683e28a3ac", 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "texts = [\n", 222 | "\n", 223 | "]" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "id": "a9da9e2d-99b9-4c6f-8a08-c828e2cba121", 229 | "metadata": {}, 230 | "source": [ 231 | "### Hyperparameters" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "f0d216e5-4895-4da8-9d24-9e61021d2556", 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "## Number of ODE Solver steps\n", 242 | "n_timesteps = 10\n", 243 | "\n", 244 | "## Changes to the speaking rate\n", 245 | "length_scale=1.0\n", 246 | "\n", 247 | "## Sampling temperature\n", 248 | "temperature = 0.667" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "import torchaudio\n", 258 | "import glob\n", 259 | "wav_files = glob.glob(\"/*.wav\") ## fill in the path to the LJSpeech-1.1 dataset\n", 260 | "wav, sr = torchaudio.load(wav_files[0])\n", 261 | "from pflow.data.text_mel_datamodule import mel_spectrogram\n", 262 | "mel = mel_spectrogram(\n", 263 | " wav,\n", 264 | " 1024,\n", 265 | " 80,\n", 266 | " 22050,\n", 267 | " 256,\n", 268 | " 1024,\n", 269 | " 0,\n", 270 | " 8000,\n", 271 | " center=False,\n", 272 | " )" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "prompt = mel #load a mel spectrogram from a file and paste it here; check dimensions [batch, channels, time]" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "id": "b93aac89-c7f8-4975-8510-4e763c9689f4", 287 | "metadata": {}, 288 | "source": [ 289 | "## Synthesis" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "5a227963-aa12-43b9-a706-1168b6fc0ba5", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "outputs, rtfs = [], []\n", 300 | "rtfs_w = []\n", 301 | "for i, text in enumerate(tqdm(texts)):\n", 302 | " prompt = prompt[:,:,:264]\n", 303 | " from pflow.utils.model import normalize\n", 304 | " prompt = normalize(prompt, model.mel_mean, model.mel_std)\n", 305 | " output = synthesise(text, prompt) #, torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0))\n", 306 | " output['waveform'] = to_waveform(output['mel'], vocoder)\n", 307 | "\n", 308 | " # Compute Real Time Factor (RTF) with HiFi-GAN\n", 309 | " t = (dt.datetime.now() - output['start_t']).total_seconds()\n", 310 | " rtf_w = t * 22050 / (output['waveform'].shape[-1])\n", 311 | "\n", 312 | " ## Pretty print\n", 313 | " print(f\"{'*' * 53}\")\n", 314 | " print(f\"Input text - {i}\")\n", 315 | " print(f\"{'-' * 53}\")\n", 316 | " print(output['x_orig'])\n", 317 | " print(f\"{'*' * 53}\")\n", 318 | " print(f\"Phonetised text - {i}\")\n", 
319 | " print(f\"{'-' * 53}\")\n", 320 | " print(output['x_phones'])\n", 321 | " print(f\"{'*' * 53}\")\n", 322 | " print(f\"RTF:\\t\\t{output['rtf']:.6f}\")\n", 323 | " print(f\"RTF Waveform:\\t{rtf_w:.6f}\")\n", 324 | " rtfs.append(output['rtf'])\n", 325 | " rtfs_w.append(rtf_w)\n", 326 | "\n", 327 | " # Display the synthesised waveform\n", 328 | " ipd.display(ipd.Audio(output['waveform'], rate=22050))\n", 329 | "\n", 330 | " ## Save the generated waveform\n", 331 | "# save_to_folder(i, output, OUTPUT_FOLDER)\n", 332 | "\n", 333 | "print(f\"Number of ODE steps: {n_timesteps}\")\n", 334 | "print(f\"Mean RTF:\\t\\t\\t\\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}\")\n", 335 | "print(f\"Mean RTF Waveform (incl. vocoder):\\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}\")" 336 | ] 337 | } 338 | ], 339 | "metadata": { 340 | "kernelspec": { 341 | "display_name": "Python 3 (ipykernel)", 342 | "language": "python", 343 | "name": "python3" 344 | }, 345 | "language_info": { 346 | "codemirror_mode": { 347 | "name": "ipython", 348 | "version": 3 349 | }, 350 | "file_extension": ".py", 351 | "mimetype": "text/x-python", 352 | "name": "python", 353 | "nbconvert_exporter": "python", 354 | "pygments_lexer": "ipython3", 355 | "version": "3.9.13" 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 5 360 | } 361 | -------------------------------------------------------------------------------- /pflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/pflow/__init__.py -------------------------------------------------------------------------------- /pflow/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/pflow/data/__init__.py -------------------------------------------------------------------------------- /pflow/data/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/pflow/data/components/__init__.py -------------------------------------------------------------------------------- /pflow/data/text_mel_datamodule.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | import torchaudio as ta 6 | from lightning import LightningDataModule 7 | from torch.utils.data.dataloader import DataLoader 8 | 9 | from pflow.text import text_to_sequence 10 | from pflow.utils.audio import mel_spectrogram 11 | from pflow.utils.model import fix_len_compatibility, normalize 12 | from pflow.utils.utils import intersperse 13 | 14 | 15 | def parse_filelist(filelist_path, split_char="|"): 16 | with open(filelist_path, encoding="utf-8") as f: 17 | filepaths_and_text = [line.strip().split(split_char) for line in f] 18 | return filepaths_and_text 19 | 20 | 21 | class TextMelDataModule(LightningDataModule): 22 | def __init__( # pylint: disable=unused-argument 23 | self, 24 | name, 25 | train_filelist_path, 26 | valid_filelist_path, 27 | batch_size, 28 | num_workers, 29 | pin_memory, 30 | cleaners, 31 | add_blank, 32 | n_spks, 33 | n_fft, 34 | n_feats, 35 | sample_rate, 36 | hop_length, 37 | win_length, 38 | f_min, 39 | f_max, 40 | data_statistics, 41 | seed, 42 | min_sample_size, 43 | 
): 44 | super().__init__() 45 | 46 | # this line allows to access init params with 'self.hparams' attribute 47 | # also ensures init params will be stored in ckpt 48 | self.save_hyperparameters(logger=False) 49 | 50 | def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument 51 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 52 | 53 | This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be 54 | careful not to execute things like random split twice! 55 | """ 56 | # load and split datasets only if not loaded already 57 | 58 | self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init 59 | self.hparams.train_filelist_path, 60 | self.hparams.n_spks, 61 | self.hparams.cleaners, 62 | self.hparams.add_blank, 63 | self.hparams.n_fft, 64 | self.hparams.n_feats, 65 | self.hparams.sample_rate, 66 | self.hparams.hop_length, 67 | self.hparams.win_length, 68 | self.hparams.f_min, 69 | self.hparams.f_max, 70 | self.hparams.data_statistics, 71 | self.hparams.seed, 72 | self.hparams.min_sample_size, 73 | ) 74 | self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init 75 | self.hparams.valid_filelist_path, 76 | self.hparams.n_spks, 77 | self.hparams.cleaners, 78 | self.hparams.add_blank, 79 | self.hparams.n_fft, 80 | self.hparams.n_feats, 81 | self.hparams.sample_rate, 82 | self.hparams.hop_length, 83 | self.hparams.win_length, 84 | self.hparams.f_min, 85 | self.hparams.f_max, 86 | self.hparams.data_statistics, 87 | self.hparams.seed, 88 | self.hparams.min_sample_size, 89 | ) 90 | 91 | def train_dataloader(self): 92 | return DataLoader( 93 | dataset=self.trainset, 94 | batch_size=self.hparams.batch_size, 95 | num_workers=self.hparams.num_workers, 96 | pin_memory=self.hparams.pin_memory, 97 | shuffle=True, 98 | collate_fn=TextMelBatchCollate(self.hparams.n_spks), 99 | ) 100 | 101 | def val_dataloader(self): 102 | return DataLoader( 103 | dataset=self.validset, 104 | batch_size=self.hparams.batch_size, 105 | num_workers=self.hparams.num_workers, 106 | pin_memory=self.hparams.pin_memory, 107 | shuffle=False, 108 | collate_fn=TextMelBatchCollate(self.hparams.n_spks), 109 | ) 110 | 111 | def teardown(self, stage: Optional[str] = None): 112 | """Clean up after fit or test.""" 113 | pass # pylint: disable=unnecessary-pass 114 | 115 | def state_dict(self): # pylint: disable=no-self-use 116 | """Extra things to save to checkpoint.""" 117 | return {} 118 | 119 | def load_state_dict(self, state_dict: Dict[str, Any]): 120 | """Things to do when loading checkpoint.""" 121 | pass # pylint: disable=unnecessary-pass 122 | 123 | 124 | class TextMelDataset(torch.utils.data.Dataset): 125 | def __init__( 126 | self, 127 | filelist_path, 128 | n_spks, 129 | cleaners, 130 | add_blank=True, 131 | n_fft=1024, 132 | n_mels=80, 133 | sample_rate=22050, 134 | hop_length=256, 135 | win_length=1024, 136 | f_min=0.0, 137 | f_max=8000, 138 | data_parameters=None, 139 | seed=None, 140 | min_sample_size=4, 141 | ): 142 | self.filepaths_and_text = parse_filelist(filelist_path) 143 | self.n_spks = n_spks 144 | self.cleaners = cleaners 145 | self.add_blank = add_blank 146 | self.n_fft = n_fft 147 | self.n_mels = n_mels 148 | self.sample_rate = sample_rate 149 | self.hop_length = hop_length 150 | self.win_length = win_length 151 | self.f_min = f_min 152 | self.f_max = f_max 153 | self.min_sample_size = min_sample_size 154 | if data_parameters is not None: 155 | self.data_parameters = data_parameters 156 | 
else: 157 | self.data_parameters = {"mel_mean": 0, "mel_std": 1} 158 | random.seed(seed) 159 | random.shuffle(self.filepaths_and_text) 160 | 161 | def get_datapoint(self, filepath_and_text): 162 | if self.n_spks > 1: 163 | filepath, spk, text = ( 164 | filepath_and_text[0], 165 | int(filepath_and_text[1]), 166 | filepath_and_text[2], 167 | ) 168 | else: 169 | filepath, text = filepath_and_text[0], filepath_and_text[1] 170 | spk = None 171 | 172 | text = self.get_text(text, add_blank=self.add_blank) 173 | mel, audio = self.get_mel(filepath) 174 | # TODO: make dictionary to get different spec for same speaker 175 | # right now naively repeating target mel for testing purposes 176 | return {"x": text, "y": mel, "spk": spk, "wav":audio} 177 | 178 | def get_mel(self, filepath): 179 | audio, sr = ta.load(filepath) 180 | assert sr == self.sample_rate 181 | mel = mel_spectrogram( 182 | audio, 183 | self.n_fft, 184 | self.n_mels, 185 | self.sample_rate, 186 | self.hop_length, 187 | self.win_length, 188 | self.f_min, 189 | self.f_max, 190 | center=False, 191 | ).squeeze() 192 | mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"]) 193 | return mel, audio 194 | 195 | def get_text(self, text, add_blank=True): 196 | text_norm = text_to_sequence(text, self.cleaners) 197 | if self.add_blank: 198 | text_norm = intersperse(text_norm, 0) 199 | text_norm = torch.IntTensor(text_norm) 200 | return text_norm 201 | 202 | def __getitem__(self, index): 203 | datapoint = self.get_datapoint(self.filepaths_and_text[index]) 204 | if datapoint["wav"].shape[1] <= self.min_sample_size * self.sample_rate: 205 | ''' 206 | skip datapoint if too short (<4s , prompt is 3s) 207 | TODO To not waste data, we can concatenate wavs less than 3s and use them 208 | TODO as a hyperparameter; multispeaker dataset can use another wav of same speaker 209 | ''' 210 | return self.__getitem__(random.randint(0, len(self.filepaths_and_text)-1)) 211 | return datapoint 212 | 213 | def __len__(self): 214 | return len(self.filepaths_and_text) 215 | 216 | 217 | class TextMelBatchCollate: 218 | def __init__(self, n_spks): 219 | self.n_spks = n_spks 220 | 221 | def __call__(self, batch): 222 | B = len(batch) 223 | y_max_length = max([item["y"].shape[-1] for item in batch]) 224 | y_max_length = fix_len_compatibility(y_max_length) 225 | wav_max_length = y_max_length * 256 226 | x_max_length = max([item["x"].shape[-1] for item in batch]) 227 | n_feats = batch[0]["y"].shape[-2] 228 | 229 | y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) 230 | x = torch.zeros((B, x_max_length), dtype=torch.long) 231 | wav = torch.zeros((B, 1, wav_max_length), dtype=torch.float32) 232 | y_lengths, x_lengths = [], [] 233 | wav_lengths = [] 234 | spks = [] 235 | for i, item in enumerate(batch): 236 | y_, x_ = item["y"], item["x"] 237 | wav_ = item["wav"][:,:wav_max_length] if item["wav"].shape[-1] > wav_max_length else item["wav"] 238 | y_lengths.append(y_.shape[-1]) 239 | x_lengths.append(x_.shape[-1]) 240 | wav_lengths.append(wav_.shape[-1]) 241 | y[i, :, : y_.shape[-1]] = y_ 242 | x[i, : x_.shape[-1]] = x_ 243 | wav[i, :, : wav_.shape[-1]] = wav_ 244 | spks.append(item["spk"]) 245 | 246 | y_lengths = torch.tensor(y_lengths, dtype=torch.long) 247 | x_lengths = torch.tensor(x_lengths, dtype=torch.long) 248 | wav_lengths = torch.tensor(wav_lengths, dtype=torch.long) 249 | spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None 250 | 251 | return { 252 | "x": x, 253 | "x_lengths": x_lengths, 254 | "y": y, 
255 | "y_lengths": y_lengths, 256 | "spks": spks, 257 | "wav":wav, 258 | "wav_lengths":wav_lengths, 259 | "prompt_spec": y, 260 | "prompt_lengths": y_lengths, 261 | } -------------------------------------------------------------------------------- /pflow/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pflow/hifigan/README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis 2 | 3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae 4 | 5 | In our [paper](https://arxiv.org/abs/2010.05646), 6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository. 8 | 9 | **Abstract :** 10 | Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. 11 | Although such methods improve the sampling efficiency and memory usage, 12 | their sample quality has not yet reached that of autoregressive and flow-based generative models. 13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. 14 | As speech audio consists of sinusoidal signals with various periods, 15 | we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. 16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method 17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than 18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen 19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times 20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart. 21 | 22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. 23 | 24 | ## Pre-requisites 25 | 26 | 1. Python >= 3.6 27 | 2. Clone this repository. 28 | 3. Install python requirements. Please refer [requirements.txt](requirements.txt) 29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). 30 | And move all wav files to `LJSpeech-1.1/wavs` 31 | 32 | ## Training 33 | 34 | ``` 35 | python train.py --config config_v1.json 36 | ``` 37 | 38 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
39 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
40 | You can change the path by adding the `--checkpoint_path` option. 41 | 42 | Validation loss during training with the V1 generator.
43 | ![validation loss](./validation_loss.png) 44 | 45 | ## Pretrained Model 46 | 47 | You can also use pretrained models we provide.
48 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
49 | Details of each folder are as follows: 50 | 51 | | Folder Name | Generator | Dataset | Fine-Tuned | 52 | | ------------ | --------- | --------- | ------------------------------------------------------ | 53 | | LJ_V1 | V1 | LJSpeech | No | 54 | | LJ_V2 | V2 | LJSpeech | No | 55 | | LJ_V3 | V3 | LJSpeech | No | 56 | | LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 57 | | LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 58 | | LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 59 | | VCTK_V1 | V1 | VCTK | No | 60 | | VCTK_V2 | V2 | VCTK | No | 61 | | VCTK_V3 | V3 | VCTK | No | 62 | | UNIVERSAL_V1 | V1 | Universal | No | 63 | 64 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. 65 | 66 | ## Fine-Tuning 67 | 68 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
69 | The file name of the generated mel-spectrogram should match that of the audio file, and the extension should be `.npy`.
70 | Example: 71 | ` Audio File : LJ001-0001.wav 72 | Mel-Spectrogram File : LJ001-0001.npy` 73 | 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
74 | 3. Run the following command. 75 | ``` 76 | python train.py --fine_tuning True --config config_v1.json 77 | ``` 78 | For other command line options, please refer to the training section. 79 | 80 | ## Inference from wav file 81 | 82 | 1. Make `test_files` directory and copy wav files into the directory. 83 | 2. Run the following command. 84 | ` python inference.py --checkpoint_file [generator checkpoint file path]` 85 | Generated wav files are saved in `generated_files` by default.
86 | You can change the path by adding `--output_dir` option. 87 | 88 | ## Inference for end-to-end speech synthesis 89 | 90 | 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
91 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), 92 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. 93 | 2. Run the following command. 94 | ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` 95 | Generated wav files are saved in `generated_files_from_mel` by default.
96 | You can change the path by adding `--output_dir` option. 97 | 98 | ## Acknowledgements 99 | 100 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) 101 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. 102 | -------------------------------------------------------------------------------- /pflow/hifigan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/pflow/hifigan/__init__.py -------------------------------------------------------------------------------- /pflow/hifigan/config.py: -------------------------------------------------------------------------------- 1 | v1 = { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0004, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | "upsample_rates": [8, 8, 2, 2], 11 | "upsample_kernel_sizes": [16, 16, 4, 4], 12 | "upsample_initial_channel": 512, 13 | "resblock_kernel_sizes": [3, 7, 11], 14 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 15 | "resblock_initial_channel": 256, 16 | "segment_size": 8192, 17 | "num_mels": 80, 18 | "num_freq": 1025, 19 | "n_fft": 1024, 20 | "hop_size": 256, 21 | "win_size": 1024, 22 | "sampling_rate": 22050, 23 | "fmin": 0, 24 | "fmax": 8000, 25 | "fmax_loss": None, 26 | "num_workers": 4, 27 | "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, 28 | } 29 | -------------------------------------------------------------------------------- /pflow/hifigan/denoiser.py: -------------------------------------------------------------------------------- 1 | # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py 2 | 3 | """Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" 4 | import torch 5 | 6 | 7 | class Denoiser(torch.nn.Module): 8 | """Removes model bias from audio produced with waveglow""" 9 | 10 | def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): 11 | super().__init__() 12 | self.filter_length = filter_length 13 | self.hop_length = int(filter_length / n_overlap) 14 | self.win_length = win_length 15 | 16 | dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device 17 | self.device = device 18 | if mode == "zeros": 19 | mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) 20 | elif mode == "normal": 21 | mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) 22 | else: 23 | raise Exception(f"Mode {mode} if not supported") 24 | 25 | def stft_fn(audio, n_fft, hop_length, win_length, window): 26 | spec = torch.stft( 27 | audio, 28 | n_fft=n_fft, 29 | hop_length=hop_length, 30 | win_length=win_length, 31 | window=window, 32 | return_complex=True, 33 | ) 34 | spec = torch.view_as_real(spec) 35 | return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) 36 | 37 | self.stft = lambda x: stft_fn( 38 | audio=x, 39 | n_fft=self.filter_length, 40 | hop_length=self.hop_length, 41 | win_length=self.win_length, 42 | window=torch.hann_window(self.win_length, device=device), 43 | ) 44 | self.istft = lambda x, y: torch.istft( 45 | torch.complex(x * torch.cos(y), x * torch.sin(y)), 46 | n_fft=self.filter_length, 47 | 
hop_length=self.hop_length, 48 | win_length=self.win_length, 49 | window=torch.hann_window(self.win_length, device=device), 50 | ) 51 | 52 | with torch.no_grad(): 53 | bias_audio = vocoder(mel_input).float().squeeze(0) 54 | bias_spec, _ = self.stft(bias_audio) 55 | 56 | self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) 57 | 58 | @torch.inference_mode() 59 | def forward(self, audio, strength=0.0005): 60 | audio_spec, audio_angles = self.stft(audio) 61 | audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength 62 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) 63 | audio_denoised = self.istft(audio_spec_denoised, audio_angles) 64 | return audio_denoised 65 | -------------------------------------------------------------------------------- /pflow/hifigan/env.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import os 4 | import shutil 5 | 6 | 7 | class AttrDict(dict): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.__dict__ = self 11 | 12 | 13 | def build_env(config, config_name, path): 14 | t_path = os.path.join(path, config_name) 15 | if config != t_path: 16 | os.makedirs(path, exist_ok=True) 17 | shutil.copyfile(config, os.path.join(path, config_name)) 18 | -------------------------------------------------------------------------------- /pflow/hifigan/meldataset.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import math 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | from librosa.filters import mel as librosa_mel_fn 11 | from librosa.util import normalize 12 | from scipy.io.wavfile import read 13 | 14 | MAX_WAV_VALUE = 32768.0 15 | 16 | 17 | def load_wav(full_path): 18 | sampling_rate, data = read(full_path) 19 | return data, sampling_rate 20 | 21 | 22 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 23 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 24 | 25 | 26 | def dynamic_range_decompression(x, C=1): 27 | return np.exp(x) / C 28 | 29 | 30 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 31 | return torch.log(torch.clamp(x, min=clip_val) * C) 32 | 33 | 34 | def dynamic_range_decompression_torch(x, C=1): 35 | return torch.exp(x) / C 36 | 37 | 38 | def spectral_normalize_torch(magnitudes): 39 | output = dynamic_range_compression_torch(magnitudes) 40 | return output 41 | 42 | 43 | def spectral_de_normalize_torch(magnitudes): 44 | output = dynamic_range_decompression_torch(magnitudes) 45 | return output 46 | 47 | 48 | mel_basis = {} 49 | hann_window = {} 50 | 51 | 52 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 53 | if torch.min(y) < -1.0: 54 | print("min value is ", torch.min(y)) 55 | if torch.max(y) > 1.0: 56 | print("max value is ", torch.max(y)) 57 | 58 | global mel_basis, hann_window # pylint: disable=global-statement 59 | if fmax not in mel_basis: 60 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 61 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 62 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 63 | 64 | y = torch.nn.functional.pad( 65 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 66 | ) 67 | y = y.squeeze(1) 68 | 69 | spec 
= torch.view_as_real( 70 | torch.stft( 71 | y, 72 | n_fft, 73 | hop_length=hop_size, 74 | win_length=win_size, 75 | window=hann_window[str(y.device)], 76 | center=center, 77 | pad_mode="reflect", 78 | normalized=False, 79 | onesided=True, 80 | return_complex=True, 81 | ) 82 | ) 83 | 84 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 85 | 86 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 87 | spec = spectral_normalize_torch(spec) 88 | 89 | return spec 90 | 91 | 92 | def get_dataset_filelist(a): 93 | with open(a.input_training_file, encoding="utf-8") as fi: 94 | training_files = [ 95 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 96 | ] 97 | 98 | with open(a.input_validation_file, encoding="utf-8") as fi: 99 | validation_files = [ 100 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 101 | ] 102 | return training_files, validation_files 103 | 104 | 105 | class MelDataset(torch.utils.data.Dataset): 106 | def __init__( 107 | self, 108 | training_files, 109 | segment_size, 110 | n_fft, 111 | num_mels, 112 | hop_size, 113 | win_size, 114 | sampling_rate, 115 | fmin, 116 | fmax, 117 | split=True, 118 | shuffle=True, 119 | n_cache_reuse=1, 120 | device=None, 121 | fmax_loss=None, 122 | fine_tuning=False, 123 | base_mels_path=None, 124 | ): 125 | self.audio_files = training_files 126 | random.seed(1234) 127 | if shuffle: 128 | random.shuffle(self.audio_files) 129 | self.segment_size = segment_size 130 | self.sampling_rate = sampling_rate 131 | self.split = split 132 | self.n_fft = n_fft 133 | self.num_mels = num_mels 134 | self.hop_size = hop_size 135 | self.win_size = win_size 136 | self.fmin = fmin 137 | self.fmax = fmax 138 | self.fmax_loss = fmax_loss 139 | self.cached_wav = None 140 | self.n_cache_reuse = n_cache_reuse 141 | self._cache_ref_count = 0 142 | self.device = device 143 | self.fine_tuning = fine_tuning 144 | self.base_mels_path = base_mels_path 145 | 146 | def __getitem__(self, index): 147 | filename = self.audio_files[index] 148 | if self._cache_ref_count == 0: 149 | audio, sampling_rate = load_wav(filename) 150 | audio = audio / MAX_WAV_VALUE 151 | if not self.fine_tuning: 152 | audio = normalize(audio) * 0.95 153 | self.cached_wav = audio 154 | if sampling_rate != self.sampling_rate: 155 | raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") 156 | self._cache_ref_count = self.n_cache_reuse 157 | else: 158 | audio = self.cached_wav 159 | self._cache_ref_count -= 1 160 | 161 | audio = torch.FloatTensor(audio) 162 | audio = audio.unsqueeze(0) 163 | 164 | if not self.fine_tuning: 165 | if self.split: 166 | if audio.size(1) >= self.segment_size: 167 | max_audio_start = audio.size(1) - self.segment_size 168 | audio_start = random.randint(0, max_audio_start) 169 | audio = audio[:, audio_start : audio_start + self.segment_size] 170 | else: 171 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 172 | 173 | mel = mel_spectrogram( 174 | audio, 175 | self.n_fft, 176 | self.num_mels, 177 | self.sampling_rate, 178 | self.hop_size, 179 | self.win_size, 180 | self.fmin, 181 | self.fmax, 182 | center=False, 183 | ) 184 | else: 185 | mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) 186 | mel = torch.from_numpy(mel) 187 | 188 | if len(mel.shape) < 3: 189 | mel = mel.unsqueeze(0) 190 | 191 | if self.split: 192 | frames_per_seg = 
math.ceil(self.segment_size / self.hop_size) 193 | 194 | if audio.size(1) >= self.segment_size: 195 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) 196 | mel = mel[:, :, mel_start : mel_start + frames_per_seg] 197 | audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] 198 | else: 199 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") 200 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 201 | 202 | mel_loss = mel_spectrogram( 203 | audio, 204 | self.n_fft, 205 | self.num_mels, 206 | self.sampling_rate, 207 | self.hop_size, 208 | self.win_size, 209 | self.fmin, 210 | self.fmax_loss, 211 | center=False, 212 | ) 213 | 214 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 215 | 216 | def __len__(self): 217 | return len(self.audio_files) 218 | -------------------------------------------------------------------------------- /pflow/hifigan/models.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d 7 | from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm 8 | 9 | from .xutils import get_padding, init_weights 10 | 11 | LRELU_SLOPE = 0.1 12 | 13 | 14 | class ResBlock1(torch.nn.Module): 15 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 16 | super().__init__() 17 | self.h = h 18 | self.convs1 = nn.ModuleList( 19 | [ 20 | weight_norm( 21 | Conv1d( 22 | channels, 23 | channels, 24 | kernel_size, 25 | 1, 26 | dilation=dilation[0], 27 | padding=get_padding(kernel_size, dilation[0]), 28 | ) 29 | ), 30 | weight_norm( 31 | Conv1d( 32 | channels, 33 | channels, 34 | kernel_size, 35 | 1, 36 | dilation=dilation[1], 37 | padding=get_padding(kernel_size, dilation[1]), 38 | ) 39 | ), 40 | weight_norm( 41 | Conv1d( 42 | channels, 43 | channels, 44 | kernel_size, 45 | 1, 46 | dilation=dilation[2], 47 | padding=get_padding(kernel_size, dilation[2]), 48 | ) 49 | ), 50 | ] 51 | ) 52 | self.convs1.apply(init_weights) 53 | 54 | self.convs2 = nn.ModuleList( 55 | [ 56 | weight_norm( 57 | Conv1d( 58 | channels, 59 | channels, 60 | kernel_size, 61 | 1, 62 | dilation=1, 63 | padding=get_padding(kernel_size, 1), 64 | ) 65 | ), 66 | weight_norm( 67 | Conv1d( 68 | channels, 69 | channels, 70 | kernel_size, 71 | 1, 72 | dilation=1, 73 | padding=get_padding(kernel_size, 1), 74 | ) 75 | ), 76 | weight_norm( 77 | Conv1d( 78 | channels, 79 | channels, 80 | kernel_size, 81 | 1, 82 | dilation=1, 83 | padding=get_padding(kernel_size, 1), 84 | ) 85 | ), 86 | ] 87 | ) 88 | self.convs2.apply(init_weights) 89 | 90 | def forward(self, x): 91 | for c1, c2 in zip(self.convs1, self.convs2): 92 | xt = F.leaky_relu(x, LRELU_SLOPE) 93 | xt = c1(xt) 94 | xt = F.leaky_relu(xt, LRELU_SLOPE) 95 | xt = c2(xt) 96 | x = xt + x 97 | return x 98 | 99 | def remove_weight_norm(self): 100 | for l in self.convs1: 101 | remove_weight_norm(l) 102 | for l in self.convs2: 103 | remove_weight_norm(l) 104 | 105 | 106 | class ResBlock2(torch.nn.Module): 107 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 108 | super().__init__() 109 | self.h = h 110 | self.convs = nn.ModuleList( 111 | [ 112 | weight_norm( 113 | Conv1d( 114 | channels, 115 | channels, 116 | kernel_size, 117 | 1, 118 | dilation=dilation[0], 119 | 
padding=get_padding(kernel_size, dilation[0]), 120 | ) 121 | ), 122 | weight_norm( 123 | Conv1d( 124 | channels, 125 | channels, 126 | kernel_size, 127 | 1, 128 | dilation=dilation[1], 129 | padding=get_padding(kernel_size, dilation[1]), 130 | ) 131 | ), 132 | ] 133 | ) 134 | self.convs.apply(init_weights) 135 | 136 | def forward(self, x): 137 | for c in self.convs: 138 | xt = F.leaky_relu(x, LRELU_SLOPE) 139 | xt = c(xt) 140 | x = xt + x 141 | return x 142 | 143 | def remove_weight_norm(self): 144 | for l in self.convs: 145 | remove_weight_norm(l) 146 | 147 | 148 | class Generator(torch.nn.Module): 149 | def __init__(self, h): 150 | super().__init__() 151 | self.h = h 152 | self.num_kernels = len(h.resblock_kernel_sizes) 153 | self.num_upsamples = len(h.upsample_rates) 154 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) 155 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 156 | 157 | self.ups = nn.ModuleList() 158 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 159 | self.ups.append( 160 | weight_norm( 161 | ConvTranspose1d( 162 | h.upsample_initial_channel // (2**i), 163 | h.upsample_initial_channel // (2 ** (i + 1)), 164 | k, 165 | u, 166 | padding=(k - u) // 2, 167 | ) 168 | ) 169 | ) 170 | 171 | self.resblocks = nn.ModuleList() 172 | for i in range(len(self.ups)): 173 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 174 | for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 175 | self.resblocks.append(resblock(h, ch, k, d)) 176 | 177 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 178 | self.ups.apply(init_weights) 179 | self.conv_post.apply(init_weights) 180 | 181 | def forward(self, x): 182 | x = self.conv_pre(x) 183 | for i in range(self.num_upsamples): 184 | x = F.leaky_relu(x, LRELU_SLOPE) 185 | x = self.ups[i](x) 186 | xs = None 187 | for j in range(self.num_kernels): 188 | if xs is None: 189 | xs = self.resblocks[i * self.num_kernels + j](x) 190 | else: 191 | xs += self.resblocks[i * self.num_kernels + j](x) 192 | x = xs / self.num_kernels 193 | x = F.leaky_relu(x) 194 | x = self.conv_post(x) 195 | x = torch.tanh(x) 196 | 197 | return x 198 | 199 | def remove_weight_norm(self): 200 | print("Removing weight norm...") 201 | for l in self.ups: 202 | remove_weight_norm(l) 203 | for l in self.resblocks: 204 | l.remove_weight_norm() 205 | remove_weight_norm(self.conv_pre) 206 | remove_weight_norm(self.conv_post) 207 | 208 | 209 | class DiscriminatorP(torch.nn.Module): 210 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 211 | super().__init__() 212 | self.period = period 213 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm 214 | self.convs = nn.ModuleList( 215 | [ 216 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 217 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 218 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 219 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 220 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 221 | ] 222 | ) 223 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 224 | 225 | def forward(self, x): 226 | fmap = [] 227 | 228 | # 1d to 2d 229 | b, c, t = x.shape 230 | if t % self.period != 0: # pad first 231 | n_pad = self.period - (t % self.period) 232 | x = F.pad(x, (0, n_pad), 
"reflect") 233 | t = t + n_pad 234 | x = x.view(b, c, t // self.period, self.period) 235 | 236 | for l in self.convs: 237 | x = l(x) 238 | x = F.leaky_relu(x, LRELU_SLOPE) 239 | fmap.append(x) 240 | x = self.conv_post(x) 241 | fmap.append(x) 242 | x = torch.flatten(x, 1, -1) 243 | 244 | return x, fmap 245 | 246 | 247 | class MultiPeriodDiscriminator(torch.nn.Module): 248 | def __init__(self): 249 | super().__init__() 250 | self.discriminators = nn.ModuleList( 251 | [ 252 | DiscriminatorP(2), 253 | DiscriminatorP(3), 254 | DiscriminatorP(5), 255 | DiscriminatorP(7), 256 | DiscriminatorP(11), 257 | ] 258 | ) 259 | 260 | def forward(self, y, y_hat): 261 | y_d_rs = [] 262 | y_d_gs = [] 263 | fmap_rs = [] 264 | fmap_gs = [] 265 | for _, d in enumerate(self.discriminators): 266 | y_d_r, fmap_r = d(y) 267 | y_d_g, fmap_g = d(y_hat) 268 | y_d_rs.append(y_d_r) 269 | fmap_rs.append(fmap_r) 270 | y_d_gs.append(y_d_g) 271 | fmap_gs.append(fmap_g) 272 | 273 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 274 | 275 | 276 | class DiscriminatorS(torch.nn.Module): 277 | def __init__(self, use_spectral_norm=False): 278 | super().__init__() 279 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm 280 | self.convs = nn.ModuleList( 281 | [ 282 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 283 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 284 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 285 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 286 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 287 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 288 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 289 | ] 290 | ) 291 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 292 | 293 | def forward(self, x): 294 | fmap = [] 295 | for l in self.convs: 296 | x = l(x) 297 | x = F.leaky_relu(x, LRELU_SLOPE) 298 | fmap.append(x) 299 | x = self.conv_post(x) 300 | fmap.append(x) 301 | x = torch.flatten(x, 1, -1) 302 | 303 | return x, fmap 304 | 305 | 306 | class MultiScaleDiscriminator(torch.nn.Module): 307 | def __init__(self): 308 | super().__init__() 309 | self.discriminators = nn.ModuleList( 310 | [ 311 | DiscriminatorS(use_spectral_norm=True), 312 | DiscriminatorS(), 313 | DiscriminatorS(), 314 | ] 315 | ) 316 | self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) 317 | 318 | def forward(self, y, y_hat): 319 | y_d_rs = [] 320 | y_d_gs = [] 321 | fmap_rs = [] 322 | fmap_gs = [] 323 | for i, d in enumerate(self.discriminators): 324 | if i != 0: 325 | y = self.meanpools[i - 1](y) 326 | y_hat = self.meanpools[i - 1](y_hat) 327 | y_d_r, fmap_r = d(y) 328 | y_d_g, fmap_g = d(y_hat) 329 | y_d_rs.append(y_d_r) 330 | fmap_rs.append(fmap_r) 331 | y_d_gs.append(y_d_g) 332 | fmap_gs.append(fmap_g) 333 | 334 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 335 | 336 | 337 | def feature_loss(fmap_r, fmap_g): 338 | loss = 0 339 | for dr, dg in zip(fmap_r, fmap_g): 340 | for rl, gl in zip(dr, dg): 341 | loss += torch.mean(torch.abs(rl - gl)) 342 | 343 | return loss * 2 344 | 345 | 346 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 347 | loss = 0 348 | r_losses = [] 349 | g_losses = [] 350 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 351 | r_loss = torch.mean((1 - dr) ** 2) 352 | g_loss = torch.mean(dg**2) 353 | loss += r_loss + g_loss 354 | r_losses.append(r_loss.item()) 355 | g_losses.append(g_loss.item()) 356 | 357 | return loss, r_losses, g_losses 358 | 359 | 360 | def 
generator_loss(disc_outputs): 361 | loss = 0 362 | gen_losses = [] 363 | for dg in disc_outputs: 364 | l = torch.mean((1 - dg) ** 2) 365 | gen_losses.append(l) 366 | loss += l 367 | 368 | return loss, gen_losses 369 | -------------------------------------------------------------------------------- /pflow/hifigan/xutils.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import glob 4 | import os 5 | 6 | import matplotlib 7 | import torch 8 | from torch.nn.utils import weight_norm 9 | 10 | matplotlib.use("Agg") 11 | import matplotlib.pylab as plt 12 | 13 | 14 | def plot_spectrogram(spectrogram): 15 | fig, ax = plt.subplots(figsize=(10, 2)) 16 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 17 | plt.colorbar(im, ax=ax) 18 | 19 | fig.canvas.draw() 20 | plt.close() 21 | 22 | return fig 23 | 24 | 25 | def init_weights(m, mean=0.0, std=0.01): 26 | classname = m.__class__.__name__ 27 | if classname.find("Conv") != -1: 28 | m.weight.data.normal_(mean, std) 29 | 30 | 31 | def apply_weight_norm(m): 32 | classname = m.__class__.__name__ 33 | if classname.find("Conv") != -1: 34 | weight_norm(m) 35 | 36 | 37 | def get_padding(kernel_size, dilation=1): 38 | return int((kernel_size * dilation - dilation) / 2) 39 | 40 | 41 | def load_checkpoint(filepath, device): 42 | assert os.path.isfile(filepath) 43 | print(f"Loading '{filepath}'") 44 | checkpoint_dict = torch.load(filepath, map_location=device) 45 | print("Complete.") 46 | return checkpoint_dict 47 | 48 | 49 | def save_checkpoint(filepath, obj): 50 | print(f"Saving checkpoint to {filepath}") 51 | torch.save(obj, filepath) 52 | print("Complete.") 53 | 54 | 55 | def scan_checkpoint(cp_dir, prefix): 56 | pattern = os.path.join(cp_dir, prefix + "????????") 57 | cp_list = glob.glob(pattern) 58 | if len(cp_list) == 0: 59 | return None 60 | return sorted(cp_list)[-1] 61 | -------------------------------------------------------------------------------- /pflow/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/pflow/models/__init__.py -------------------------------------------------------------------------------- /pflow/models/baselightningmodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a base lightning module that can be used to train a model. 3 | The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. 
4 | """ 5 | import inspect 6 | from abc import ABC 7 | from typing import Any, Dict 8 | 9 | import torch 10 | from lightning import LightningModule 11 | from lightning.pytorch.utilities import grad_norm 12 | 13 | from pflow import utils 14 | from pflow.utils.utils import plot_tensor 15 | from pflow.models.components import commons 16 | 17 | log = utils.get_pylogger(__name__) 18 | 19 | 20 | class BaseLightningClass(LightningModule, ABC): 21 | def update_data_statistics(self, data_statistics): 22 | if data_statistics is None: 23 | data_statistics = { 24 | "mel_mean": 0.0, 25 | "mel_std": 1.0, 26 | } 27 | 28 | self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) 29 | self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) 30 | 31 | def configure_optimizers(self) -> Any: 32 | optimizer = self.hparams.optimizer(params=self.parameters()) 33 | if self.hparams.scheduler not in (None, {}): 34 | scheduler_args = {} 35 | # Manage last epoch for exponential schedulers 36 | if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: 37 | if hasattr(self, "ckpt_loaded_epoch"): 38 | current_epoch = self.ckpt_loaded_epoch - 1 39 | else: 40 | current_epoch = -1 41 | 42 | scheduler_args.update({"optimizer": optimizer}) 43 | scheduler = self.hparams.scheduler.scheduler(**scheduler_args) 44 | print(self.ckpt_loaded_epoch - 1) 45 | if hasattr(self, "ckpt_loaded_epoch"): 46 | scheduler.last_epoch = self.ckpt_loaded_epoch - 1 47 | else: 48 | scheduler.last_epoch = -1 49 | return { 50 | "optimizer": optimizer, 51 | "lr_scheduler": { 52 | "scheduler": scheduler, 53 | # "interval": self.hparams.scheduler.lightning_args.interval, 54 | # "frequency": self.hparams.scheduler.lightning_args.frequency, 55 | # "name": "learning_rate", 56 | "monitor": "val_loss", 57 | }, 58 | } 59 | 60 | return {"optimizer": optimizer} 61 | 62 | def get_losses(self, batch): 63 | x, x_lengths = batch["x"], batch["x_lengths"] 64 | y, y_lengths = batch["y"], batch["y_lengths"] 65 | # prompt_spec = batch["prompt_spec"] 66 | # prompt_lengths = batch["prompt_lengths"] 67 | # prompt_slice, ids_slice = commons.rand_slice_segments( 68 | # prompt_spec, 69 | # prompt_lengths, 70 | # self.prompt_size 71 | # ) 72 | prompt_slice = None 73 | dur_loss, prior_loss, diff_loss, attn = self( 74 | x=x, 75 | x_lengths=x_lengths, 76 | y=y, 77 | y_lengths=y_lengths, 78 | prompt=prompt_slice, 79 | ) 80 | return ({ 81 | "dur_loss": dur_loss, 82 | "prior_loss": prior_loss, 83 | "diff_loss": diff_loss, 84 | }, 85 | { 86 | "attn": attn 87 | } 88 | ) 89 | 90 | def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 91 | self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init 92 | 93 | def training_step(self, batch: Any, batch_idx: int): 94 | loss_dict, attn_dict = self.get_losses(batch) 95 | 96 | self.log( 97 | "step", 98 | float(self.global_step), 99 | on_step=True, 100 | on_epoch=True, 101 | logger=True, 102 | sync_dist=True, 103 | ) 104 | 105 | self.log( 106 | "sub_loss/train_dur_loss", 107 | loss_dict["dur_loss"], 108 | on_step=True, 109 | on_epoch=True, 110 | logger=True, 111 | sync_dist=True, 112 | ) 113 | self.log( 114 | "sub_loss/train_prior_loss", 115 | loss_dict["prior_loss"], 116 | on_step=True, 117 | on_epoch=True, 118 | logger=True, 119 | sync_dist=True, 120 | ) 121 | self.log( 122 | "sub_loss/train_diff_loss", 123 | loss_dict["diff_loss"], 124 | on_step=True, 125 | on_epoch=True, 126 | logger=True, 127 | sync_dist=True, 128 | ) 129 | 130 | 
total_loss = sum(loss_dict.values()) 131 | self.log( 132 | "loss/train", 133 | total_loss, 134 | on_step=True, 135 | on_epoch=True, 136 | logger=True, 137 | prog_bar=True, 138 | sync_dist=True, 139 | ) 140 | attn = attn_dict["attn"][0] 141 | self.logger.experiment.add_image( 142 | f"train/alignment", 143 | plot_tensor(attn.cpu()), 144 | self.current_epoch, 145 | dataformats="HWC", 146 | ) 147 | return {"loss": total_loss, "log": loss_dict} 148 | 149 | def validation_step(self, batch: Any, batch_idx: int): 150 | loss_dict, attn_dict = self.get_losses(batch) 151 | self.log( 152 | "sub_loss/val_dur_loss", 153 | loss_dict["dur_loss"], 154 | on_step=True, 155 | on_epoch=True, 156 | logger=True, 157 | sync_dist=True, 158 | ) 159 | self.log( 160 | "sub_loss/val_prior_loss", 161 | loss_dict["prior_loss"], 162 | on_step=True, 163 | on_epoch=True, 164 | logger=True, 165 | sync_dist=True, 166 | ) 167 | self.log( 168 | "sub_loss/val_diff_loss", 169 | loss_dict["diff_loss"], 170 | on_step=True, 171 | on_epoch=True, 172 | logger=True, 173 | sync_dist=True, 174 | ) 175 | 176 | total_loss = sum(loss_dict.values()) 177 | self.log( 178 | "loss/val", 179 | total_loss, 180 | on_step=True, 181 | on_epoch=True, 182 | logger=True, 183 | prog_bar=True, 184 | sync_dist=True, 185 | ) 186 | 187 | attn = attn_dict["attn"][0] 188 | self.logger.experiment.add_image( 189 | f"val/alignment", 190 | plot_tensor(attn.cpu()), 191 | self.current_epoch, 192 | dataformats="HWC", 193 | ) 194 | return total_loss 195 | 196 | def on_validation_end(self) -> None: 197 | if self.trainer.is_global_zero: 198 | one_batch = next(iter(self.trainer.val_dataloaders)) 199 | 200 | if self.current_epoch == 0: 201 | log.debug("Plotting original samples") 202 | for i in range(2): 203 | y = one_batch["y"][i].unsqueeze(0).to(self.device) 204 | self.logger.experiment.add_image( 205 | f"original/{i}", 206 | plot_tensor(y.squeeze().cpu()), 207 | self.current_epoch, 208 | dataformats="HWC", 209 | ) 210 | 211 | log.debug("Synthesising...") 212 | for i in range(2): 213 | x = one_batch["x"][i].unsqueeze(0).to(self.device) 214 | x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) 215 | y = one_batch["y"][i].unsqueeze(0).to(self.device) 216 | y_lengths = one_batch["y_lengths"][i].unsqueeze(0).to(self.device) 217 | # prompt = one_batch["prompt_spec"][i].unsqueeze(0).to(self.device) 218 | # prompt_lengths = one_batch["prompt_lengths"][i].unsqueeze(0).to(self.device) 219 | prompt = y 220 | prompt_lengths = y_lengths 221 | prompt_slice, ids_slice = commons.rand_slice_segments( 222 | prompt, prompt_lengths, self.prompt_size 223 | ) 224 | output = self.synthesise(x[:, :x_lengths], x_lengths, prompt=prompt_slice, n_timesteps=10, guidance_scale=0.0) 225 | y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] 226 | attn = output["attn"] 227 | self.logger.experiment.add_image( 228 | f"generated_enc/{i}", 229 | plot_tensor(y_enc.squeeze().cpu()), 230 | self.current_epoch, 231 | dataformats="HWC", 232 | ) 233 | self.logger.experiment.add_image( 234 | f"generated_dec/{i}", 235 | plot_tensor(y_dec.squeeze().cpu()), 236 | self.current_epoch, 237 | dataformats="HWC", 238 | ) 239 | self.logger.experiment.add_image( 240 | f"alignment/{i}", 241 | plot_tensor(attn.squeeze().cpu()), 242 | self.current_epoch, 243 | dataformats="HWC", 244 | ) 245 | 246 | def on_before_optimizer_step(self, optimizer): 247 | self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) 248 | 
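A note on the prompt handling above: `training_step` and `validation_step` simply sum the duration, prior, and flow-matching sub-losses returned by `get_losses`, while `on_validation_end` cuts a fixed-length speech prompt out of the padded target mel with `commons.rand_slice_segments` before calling `synthesise`. The sketch below shows that slicing step in isolation; the tensor shapes and the `prompt_size` value (roughly three seconds of frames at a 256-sample hop) are illustrative assumptions, not values read from the configs.

```
import torch

from pflow.models.components import commons

# Illustrative shapes only: 2 utterances, 80 mel bins, up to 600 frames after padding.
y = torch.randn(2, 80, 600)            # padded target mels (batch, n_feats, frames)
y_lengths = torch.tensor([600, 480])   # true frame count per utterance

# Assumed prompt length: ~3 s of 22.05 kHz audio at hop_length 256 is about
# 22050 * 3 / 256 ≈ 258 frames; the real value comes from the model config.
prompt_size = 258

# Randomly slice a prompt_size-frame window from each mel, respecting y_lengths.
prompt_slice, ids_slice = commons.rand_slice_segments(y, y_lengths, prompt_size)
print(prompt_slice.shape)              # torch.Size([2, 80, 258])
print(ids_slice)                       # start frame chosen independently per item
```

In `get_losses` the corresponding call is commented out and `prompt` is passed as `None`; `on_validation_end` performs the slice explicitly with `self.prompt_size` before synthesis.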
-------------------------------------------------------------------------------- /pflow/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/pflow/models/components/__init__.py -------------------------------------------------------------------------------- /pflow/models/components/aligner.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import numpy as np 3 | 4 | import torch 5 | from torch import nn, Tensor 6 | from torch.nn import Module 7 | import torch.nn.functional as F 8 | 9 | from einops import rearrange, repeat 10 | 11 | from beartype import beartype 12 | from beartype.typing import Optional 13 | 14 | def exists(val): 15 | return val is not None 16 | 17 | class AlignerNet(Module): 18 | """alignment model https://arxiv.org/pdf/2108.10447.pdf """ 19 | def __init__( 20 | self, 21 | dim_in=80, 22 | dim_hidden=512, 23 | attn_channels=80, 24 | temperature=0.0005, 25 | ): 26 | super().__init__() 27 | self.temperature = temperature 28 | 29 | self.key_layers = nn.ModuleList([ 30 | nn.Conv1d( 31 | dim_hidden, 32 | dim_hidden * 2, 33 | kernel_size=3, 34 | padding=1, 35 | bias=True, 36 | ), 37 | nn.ReLU(inplace=True), 38 | nn.Conv1d(dim_hidden * 2, attn_channels, kernel_size=1, padding=0, bias=True) 39 | ]) 40 | 41 | self.query_layers = nn.ModuleList([ 42 | nn.Conv1d( 43 | dim_in, 44 | dim_in * 2, 45 | kernel_size=3, 46 | padding=1, 47 | bias=True, 48 | ), 49 | nn.ReLU(inplace=True), 50 | nn.Conv1d(dim_in * 2, dim_in, kernel_size=1, padding=0, bias=True), 51 | nn.ReLU(inplace=True), 52 | nn.Conv1d(dim_in, attn_channels, kernel_size=1, padding=0, bias=True) 53 | ]) 54 | 55 | @beartype 56 | def forward( 57 | self, 58 | queries: Tensor, 59 | keys: Tensor, 60 | mask: Optional[Tensor] = None 61 | ): 62 | key_out = keys 63 | for layer in self.key_layers: 64 | key_out = layer(key_out) 65 | 66 | query_out = queries 67 | for layer in self.query_layers: 68 | query_out = layer(query_out) 69 | 70 | key_out = rearrange(key_out, 'b c t -> b t c') 71 | query_out = rearrange(query_out, 'b c t -> b t c') 72 | 73 | attn_logp = torch.cdist(query_out, key_out) 74 | attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...') 75 | 76 | if exists(mask): 77 | mask = rearrange(mask.bool(), '... c -> ... 
1 c') 78 | attn_logp.data.masked_fill_(~mask, -torch.finfo(attn_logp.dtype).max) 79 | 80 | attn = attn_logp.softmax(dim = -1) 81 | return attn, attn_logp 82 | 83 | def pad_tensor(input, pad, value=0): 84 | pad = [item for sublist in reversed(pad) for item in sublist] # Flatten the tuple 85 | assert len(pad) // 2 == len(input.shape), 'Padding dimensions do not match input dimensions' 86 | return F.pad(input, pad, mode='constant', value=value) 87 | 88 | def maximum_path(value, mask, const=None): 89 | device = value.device 90 | dtype = value.dtype 91 | if not exists(const): 92 | const = torch.tensor(float('-inf')).to(device) # Patch for Sphinx complaint 93 | value = value * mask 94 | 95 | b, t_x, t_y = value.shape 96 | direction = torch.zeros(value.shape, dtype=torch.int64, device=device) 97 | v = torch.zeros((b, t_x), dtype=torch.float32, device=device) 98 | x_range = torch.arange(t_x, dtype=torch.float32, device=device).view(1, -1) 99 | 100 | for j in range(t_y): 101 | v0 = pad_tensor(v, ((0, 0), (1, 0)), value = const)[:, :-1] 102 | v1 = v 103 | max_mask = v1 >= v0 104 | v_max = torch.where(max_mask, v1, v0) 105 | direction[:, :, j] = max_mask 106 | 107 | index_mask = x_range <= j 108 | v = torch.where(index_mask.view(1,-1), v_max + value[:, :, j], const) 109 | 110 | direction = torch.where(mask.bool(), direction, 1) 111 | 112 | path = torch.zeros(value.shape, dtype=torch.float32, device=device) 113 | index = mask[:, :, 0].sum(1).long() - 1 114 | index_range = torch.arange(b, device=device) 115 | 116 | for j in reversed(range(t_y)): 117 | path[index_range, index, j] = 1 118 | index = index + direction[index_range, index, j] - 1 119 | 120 | path = path * mask.float() 121 | path = path.to(dtype=dtype) 122 | return path 123 | 124 | class ForwardSumLoss(Module): 125 | def __init__( 126 | self, 127 | blank_logprob = -1 128 | ): 129 | super().__init__() 130 | self.blank_logprob = blank_logprob 131 | 132 | self.ctc_loss = torch.nn.CTCLoss( 133 | blank = 0, # check this value 134 | zero_infinity = True 135 | ) 136 | 137 | def forward(self, attn_logprob, key_lens, query_lens): 138 | device, blank_logprob = attn_logprob.device, self.blank_logprob 139 | max_key_len = attn_logprob.size(-1) 140 | 141 | # Reorder input to [query_len, batch_size, key_len] 142 | attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t') 143 | 144 | # Add blank label 145 | attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), value = blank_logprob) 146 | 147 | # Convert to log probabilities 148 | # Note: Mask out probs beyond key_len 149 | mask_value = -torch.finfo(attn_logprob.dtype).max 150 | attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value) 151 | 152 | attn_logprob = attn_logprob.log_softmax(dim = -1) 153 | 154 | # Target sequences 155 | target_seqs = torch.arange(1, max_key_len + 1, device=device, dtype=torch.long) 156 | target_seqs = repeat(target_seqs, 'n -> b n', b = key_lens.numel()) 157 | 158 | # Evaluate CTC loss 159 | cost = self.ctc_loss(attn_logprob, target_seqs, query_lens, key_lens) 160 | 161 | return cost 162 | 163 | class BinLoss(Module): 164 | def forward(self, attn_hard, attn_logprob, key_lens): 165 | batch, device = attn_logprob.shape[0], attn_logprob.device 166 | max_key_len = attn_logprob.size(-1) 167 | 168 | # Reorder input to [query_len, batch_size, key_len] 169 | attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t') 170 | attn_hard = rearrange(attn_hard, 'b t c -> c b t') 171 | 172 | mask_value = 
-torch.finfo(attn_logprob.dtype).max 173 | 174 | attn_logprob.masked_fill_(torch.arange(max_key_len, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value) 175 | attn_logprob = attn_logprob.log_softmax(dim = -1) 176 | 177 | return (attn_hard * attn_logprob).sum() / batch 178 | 179 | class Aligner(Module): 180 | def __init__( 181 | self, 182 | dim_in, 183 | dim_hidden, 184 | attn_channels=80, 185 | temperature=0.0005 186 | ): 187 | super().__init__() 188 | self.dim_in = dim_in 189 | self.dim_hidden = dim_hidden 190 | self.attn_channels = attn_channels 191 | self.temperature = temperature 192 | self.aligner = AlignerNet( 193 | dim_in = self.dim_in, 194 | dim_hidden = self.dim_hidden, 195 | attn_channels = self.attn_channels, 196 | temperature = self.temperature 197 | ) 198 | 199 | def forward( 200 | self, 201 | x, 202 | x_mask, 203 | y, 204 | y_mask 205 | ): 206 | alignment_soft, alignment_logprob = self.aligner(y, rearrange(x, 'b d t -> b t d'), x_mask) 207 | 208 | x_mask = rearrange(x_mask, '... i -> ... i 1') 209 | y_mask = rearrange(y_mask, '... j -> ... 1 j') 210 | attn_mask = x_mask * y_mask 211 | attn_mask = rearrange(attn_mask, 'b 1 i j -> b i j') 212 | 213 | alignment_soft = rearrange(alignment_soft, 'b 1 c t -> b t c') 214 | alignment_mask = maximum_path(alignment_soft, attn_mask) 215 | 216 | alignment_hard = torch.sum(alignment_mask, -1).int() 217 | return alignment_hard, alignment_soft, alignment_logprob, alignment_mask 218 | 219 | if __name__ == '__main__': 220 | batch_size = 10 221 | seq_len_y = 200 # length of sequence y 222 | seq_len_x = 35 223 | feature_dim = 80 # feature dimension 224 | 225 | x = torch.randn(batch_size, 512, seq_len_x) 226 | x = x.transpose(1,2) #dim-1 is the channels for conv 227 | y = torch.randn(batch_size, seq_len_y, feature_dim) 228 | y = y.transpose(1,2) #dim-1 is the channels for conv 229 | 230 | # Create masks 231 | x_mask = torch.ones(batch_size, 1, seq_len_x) 232 | y_mask = torch.ones(batch_size, 1, seq_len_y) 233 | 234 | align = Aligner(dim_in = 80, dim_hidden=512, attn_channels=80) 235 | alignment_hard, alignment_soft, alignment_logprob, alignment_mas = align(x, x_mask, y, y_mask) -------------------------------------------------------------------------------- /pflow/models/components/commons.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/jaywalnut310/vits 2 | import math 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def init_weights(m, mean=0.0, std=0.01): 8 | classname = m.__class__.__name__ 9 | if classname.find("Conv") != -1: 10 | m.weight.data.normal_(mean, std) 11 | 12 | 13 | def get_padding(kernel_size, dilation=1): 14 | return int((kernel_size * dilation - dilation) / 2) 15 | 16 | 17 | def convert_pad_shape(pad_shape): 18 | l = pad_shape[::-1] 19 | pad_shape = [item for sublist in l for item in sublist] 20 | return pad_shape 21 | 22 | 23 | def intersperse(lst, item): 24 | result = [item] * (len(lst) * 2 + 1) 25 | result[1::2] = lst 26 | return result 27 | 28 | 29 | def kl_divergence(m_p, logs_p, m_q, logs_q): 30 | """KL(P||Q)""" 31 | kl = (logs_q - logs_p) - 0.5 32 | kl += ( 33 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 34 | ) 35 | return kl 36 | 37 | 38 | def rand_gumbel(shape): 39 | """Sample from the Gumbel distribution, protect from overflows.""" 40 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 41 | return -torch.log(-torch.log(uniform_samples)) 42 | 43 | 44 | 
def rand_gumbel_like(x): 45 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 46 | return g 47 | 48 | 49 | def slice_segments(x, ids_str, segment_size=4): 50 | ret = torch.zeros_like(x[:, :, :segment_size]) 51 | for i in range(x.size(0)): 52 | idx_str = ids_str[i] 53 | idx_end = idx_str + segment_size 54 | ret[i] = x[i, :, idx_str:idx_end] 55 | return ret 56 | 57 | 58 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 59 | b, d, t = x.size() 60 | if x_lengths is None: 61 | x_lengths = t 62 | ids_str_max = x_lengths - segment_size + 1 63 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 64 | ids_str = torch.max(torch.zeros(ids_str.size()).to(ids_str.device), ids_str).to( 65 | dtype=torch.long 66 | ) 67 | ret = slice_segments(x, ids_str, segment_size) 68 | return ret, ids_str 69 | 70 | 71 | def rand_slice_segments_for_cat(x, x_lengths=None, segment_size=4): 72 | b, d, t = x.size() 73 | if x_lengths is None: 74 | x_lengths = t 75 | ids_str_max = x_lengths - segment_size + 1 76 | ids_str = torch.rand([b // 2]).to(device=x.device) 77 | ids_str = (torch.cat([ids_str, ids_str], dim=0) * ids_str_max).to(dtype=torch.long) 78 | ids_str = torch.max(torch.zeros(ids_str.size()).to(ids_str.device), ids_str).to( 79 | dtype=torch.long 80 | ) 81 | ret = slice_segments(x, ids_str, segment_size) 82 | return ret, ids_str 83 | 84 | 85 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 86 | position = torch.arange(length, dtype=torch.float) 87 | num_timescales = channels // 2 88 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 89 | num_timescales - 1 90 | ) 91 | inv_timescales = min_timescale * torch.exp( 92 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 93 | ) 94 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 95 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 96 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 97 | signal = signal.view(1, channels, length) 98 | return signal 99 | 100 | 101 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 102 | b, channels, length = x.size() 103 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 104 | return x + signal.to(dtype=x.dtype, device=x.device) 105 | 106 | 107 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 108 | b, channels, length = x.size() 109 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 110 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 111 | 112 | 113 | def subsequent_mask(length): 114 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 115 | return mask 116 | 117 | 118 | @torch.jit.script 119 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 120 | n_channels_int = n_channels[0] 121 | in_act = input_a + input_b 122 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 123 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 124 | acts = t_act * s_act 125 | return acts 126 | 127 | 128 | def convert_pad_shape(pad_shape): 129 | l = pad_shape[::-1] 130 | pad_shape = [item for sublist in l for item in sublist] 131 | return pad_shape 132 | 133 | 134 | def shift_1d(x): 135 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 136 | return x 137 | 138 | 139 | def sequence_mask(length, max_length=None): 140 | if max_length is None: 141 | max_length = length.max() 142 | 
x = torch.arange(max_length, dtype=length.dtype, device=length.device) 143 | return x.unsqueeze(0) < length.unsqueeze(1) 144 | 145 | 146 | def generate_path(duration, mask): 147 | """ 148 | duration: [b, 1, t_x] 149 | mask: [b, 1, t_y, t_x] 150 | """ 151 | device = duration.device 152 | 153 | b, _, t_y, t_x = mask.shape 154 | cum_duration = torch.cumsum(duration, -1) 155 | 156 | cum_duration_flat = cum_duration.view(b * t_x) 157 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 158 | path = path.view(b, t_x, t_y) 159 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 160 | path = path.unsqueeze(1).transpose(2, 3) * mask 161 | return path 162 | 163 | 164 | def clip_grad_value_(parameters, clip_value, norm_type=2): 165 | if isinstance(parameters, torch.Tensor): 166 | parameters = [parameters] 167 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 168 | norm_type = float(norm_type) 169 | if clip_value is not None: 170 | clip_value = float(clip_value) 171 | 172 | total_norm = 0 173 | for p in parameters: 174 | param_norm = p.grad.data.norm(norm_type) 175 | total_norm += param_norm.item() ** norm_type 176 | if clip_value is not None: 177 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 178 | total_norm = total_norm ** (1.0 / norm_type) 179 | return total_norm 180 | -------------------------------------------------------------------------------- /pflow/models/components/flow_matching.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from pflow.models.components.decoder import Decoder 7 | from pflow.models.components.wn_pflow_decoder import DiffSingerNet 8 | from pflow.models.components.vits_wn_decoder import VitsWNDecoder 9 | 10 | from pflow.utils.pylogger import get_pylogger 11 | 12 | log = get_pylogger(__name__) 13 | 14 | 15 | class BASECFM(torch.nn.Module, ABC): 16 | def __init__( 17 | self, 18 | n_feats, 19 | cfm_params, 20 | n_spks=1, 21 | spk_emb_dim=128, 22 | ): 23 | super().__init__() 24 | self.n_feats = n_feats 25 | self.n_spks = n_spks 26 | self.spk_emb_dim = spk_emb_dim 27 | self.solver = cfm_params.solver 28 | if hasattr(cfm_params, "sigma_min"): 29 | self.sigma_min = cfm_params.sigma_min 30 | else: 31 | self.sigma_min = 1e-4 32 | 33 | self.estimator = None 34 | 35 | @torch.inference_mode() 36 | def forward(self, mu, mask, n_timesteps, temperature=1.0, cond=None, training=False, guidance_scale=0.0): 37 | """Forward diffusion 38 | 39 | Args: 40 | mu (torch.Tensor): output of encoder 41 | shape: (batch_size, n_feats, mel_timesteps) 42 | mask (torch.Tensor): output_mask 43 | shape: (batch_size, 1, mel_timesteps) 44 | n_timesteps (int): number of diffusion steps 45 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 
46 | cond: Not used but kept for future purposes 47 | 48 | Returns: 49 | sample: generated mel-spectrogram 50 | shape: (batch_size, n_feats, mel_timesteps) 51 | """ 52 | z = torch.randn_like(mu) * temperature 53 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) 54 | if self.solver == "euler": 55 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, cond=cond, training=training, guidance_scale=guidance_scale) 56 | elif self.solver == "heun": 57 | return self.solve_heun(z, t_span=t_span, mu=mu, mask=mask, cond=cond, training=training, guidance_scale=guidance_scale) 58 | elif self.solver == "midpoint": 59 | return self.solve_midpoint(z, t_span=t_span, mu=mu, mask=mask, cond=cond, training=training, guidance_scale=guidance_scale) 60 | 61 | def solve_euler(self, x, t_span, mu, mask, cond, training=False, guidance_scale=0.0): 62 | """ 63 | Fixed euler solver for ODEs. 64 | Args: 65 | x (torch.Tensor): random noise 66 | t_span (torch.Tensor): n_timesteps interpolated 67 | shape: (n_timesteps + 1,) 68 | mu (torch.Tensor): output of encoder 69 | shape: (batch_size, n_feats, mel_timesteps) 70 | mask (torch.Tensor): output_mask 71 | shape: (batch_size, 1, mel_timesteps) 72 | cond: Not used but kept for future purposes 73 | """ 74 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 75 | 76 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 77 | # Or in future might add like a return_all_steps flag 78 | sol = [] 79 | steps = 1 80 | while steps <= len(t_span) - 1: 81 | dphi_dt = self.func_dphi_dt(x, mask, mu, t, cond, training=training, guidance_scale=guidance_scale) 82 | x = x + dt * dphi_dt 83 | t = t + dt 84 | sol.append(x) 85 | if steps < len(t_span) - 1: 86 | dt = t_span[steps + 1] - t 87 | steps += 1 88 | 89 | return sol[-1] 90 | 91 | def solve_heun(self, x, t_span, mu, mask, cond, training=False, guidance_scale=0.0): 92 | """ 93 | Fixed heun solver for ODEs. 94 | Args: 95 | x (torch.Tensor): random noise 96 | t_span (torch.Tensor): n_timesteps interpolated 97 | shape: (n_timesteps + 1,) 98 | mu (torch.Tensor): output of encoder 99 | shape: (batch_size, n_feats, mel_timesteps) 100 | mask (torch.Tensor): output_mask 101 | shape: (batch_size, 1, mel_timesteps) 102 | cond: Not used but kept for future purposes 103 | """ 104 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 105 | 106 | #-! : reserved space for debugger 107 | sol = [] 108 | steps = 1 109 | 110 | while steps <= len(t_span) - 1: 111 | dphi_dt = self.func_dphi_dt(x, mask, mu, t, cond, training=training, guidance_scale=guidance_scale) 112 | dphi_dt_2 = self.func_dphi_dt(x + dt * dphi_dt, mask, mu, t+dt, cond, training=training, guidance_scale=guidance_scale) 113 | 114 | #- Euler's -> Y'n+1' = Y'n' + h * F(X'n', Y'n') 115 | # x = x + dt * dphi_dt 116 | 117 | #- Heun's -> Y'n+1' = Y'n' + h * 0.5( F(X'n', Y'n') + F(X'n' + h, Y'n' + h * F(X'n', Y'n') ) ) 118 | x = x + dt * 0.5 * (dphi_dt + dphi_dt_2) 119 | t = t + dt 120 | 121 | sol.append(x) 122 | if steps < len(t_span) - 1: 123 | dt = t_span[steps + 1] - t 124 | steps += 1 125 | 126 | return sol[-1] 127 | 128 | def solve_midpoint(self, x, t_span, mu, mask, cond, training=False, guidance_scale=0.0): 129 | """ 130 | Fixed midpoint solver for ODEs. 
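(Explicit midpoint / RK2 rule: each step first takes a half Euler step to t + dt/2, re-evaluates the vector field there, and uses that mid-interval slope for the full update, which is what the x = x + dt * dphi_dt_2 line below implements.)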
131 | Args: 132 | x (torch.Tensor): random noise 133 | t_span (torch.Tensor): n_timesteps interpolated 134 | shape: (n_timesteps + 1,) 135 | mu (torch.Tensor): output of encoder 136 | shape: (batch_size, n_feats, mel_timesteps) 137 | mask (torch.Tensor): output_mask 138 | shape: (batch_size, 1, mel_timesteps) 139 | cond: Not used but kept for future purposes 140 | """ 141 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 142 | 143 | # -! : reserved space for debugger 144 | sol = [] 145 | steps = 1 146 | 147 | while steps <= len(t_span) - 1: 148 | dphi_dt = self.func_dphi_dt(x, mask, mu, t, cond, training=training, guidance_scale=guidance_scale) 149 | dphi_dt_2 = self.func_dphi_dt(x + dt * 0.5 * dphi_dt, mask, mu, t + dt * 0.5, cond, training=training, guidance_scale=guidance_scale) 150 | 151 | # - Euler's -> Y'n+1' = Y'n' + h * F(X'n', Y'n') 152 | # x = x + dt * dphi_dt 153 | 154 | #- midpoint -> Y'n+1' = Y'n' + h * F(X'n' + 0.5 * h, Y'n' + 0.5 * h * F(X'n', Y'n') ) 155 | x = x + dt * dphi_dt_2 156 | t = t + dt 157 | 158 | sol.append(x) 159 | if steps < len(t_span) - 1: 160 | dt = t_span[steps + 1] - t 161 | steps += 1 162 | 163 | return sol[-1] 164 | 165 | def func_dphi_dt(self, x, mask, mu, t, cond, training=False, guidance_scale=0.0): 166 | dphi_dt = self.estimator(x, mask, mu, t, cond, training=training) 167 | 168 | if guidance_scale > 0.0: 169 | mu_avg = mu.mean(2, keepdims=True).expand_as(mu) 170 | dphi_avg = self.estimator(x, mask, mu_avg, t, cond, training=training) 171 | dphi_dt = dphi_dt + guidance_scale * (dphi_dt - dphi_avg) 172 | 173 | return dphi_dt 174 | 175 | def compute_loss(self, x1, mask, mu, cond=None, training=True, loss_mask=None): 176 | """Computes diffusion loss 177 | 178 | Args: 179 | x1 (torch.Tensor): Target 180 | shape: (batch_size, n_feats, mel_timesteps) 181 | mask (torch.Tensor): target mask 182 | shape: (batch_size, 1, mel_timesteps) 183 | mu (torch.Tensor): output of encoder 184 | shape: (batch_size, n_feats, mel_timesteps) 185 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
186 | shape: (batch_size, spk_emb_dim) 187 | 188 | Returns: 189 | loss: conditional flow matching loss 190 | y: conditional flow 191 | shape: (batch_size, n_feats, mel_timesteps) 192 | """ 193 | b, _, t = mu.shape 194 | 195 | # random timestep 196 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 197 | # sample noise p(x_0) 198 | z = torch.randn_like(x1) 199 | 200 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 201 | u = x1 - (1 - self.sigma_min) * z 202 | # y = u * t + z 203 | estimator_out = self.estimator(y, mask, mu, t.squeeze(), training=training) 204 | 205 | if loss_mask is not None: 206 | mask = loss_mask 207 | loss = F.mse_loss(estimator_out*mask, u*mask, reduction="sum") / ( 208 | torch.sum(mask) * u.shape[1] 209 | ) 210 | return loss, y 211 | 212 | 213 | class CFM(BASECFM): 214 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params): 215 | super().__init__( 216 | n_feats=in_channels, 217 | cfm_params=cfm_params, 218 | ) 219 | 220 | # Just change the architecture of the estimator here 221 | self.estimator = Decoder(in_channels=in_channels*2, out_channels=out_channel, **decoder_params) 222 | # self.estimator = DiffSingerNet(in_dims=in_channels, encoder_hidden=out_channel) 223 | # self.estimator = VitsWNDecoder( 224 | # in_channels=in_channels, 225 | # out_channels=out_channel, 226 | # hidden_channels=out_channel, 227 | # kernel_size=3, 228 | # dilation_rate=1, 229 | # n_layers=18, 230 | # gin_channels=out_channel*2 231 | # ) 232 | 233 | -------------------------------------------------------------------------------- /pflow/models/components/test.py: -------------------------------------------------------------------------------- 1 | from pflow.hifigan.meldataset import mel_spectrogram 2 | import torch 3 | 4 | audio = torch.randn(2,1, 1000) 5 | mels = mel_spectrogram(audio, 1024, 80, 22050, 256, 1024, 0, 8000, center=False) 6 | print(mels.shape) -------------------------------------------------------------------------------- /pflow/models/components/vits_modules.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/jaywalnut310/vits 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from pflow.models.components import commons 7 | 8 | LRELU_SLOPE = 0.1 9 | 10 | class LayerNorm(nn.Module): 11 | def __init__(self, channels, eps=1e-5): 12 | super().__init__() 13 | self.channels = channels 14 | self.eps = eps 15 | 16 | self.gamma = nn.Parameter(torch.ones(channels)) 17 | self.beta = nn.Parameter(torch.zeros(channels)) 18 | 19 | def forward(self, x): 20 | x = x.transpose(1, -1) 21 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 22 | return x.transpose(1, -1) 23 | 24 | 25 | class ConvReluNorm(nn.Module): 26 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 27 | super().__init__() 28 | self.in_channels = in_channels 29 | self.hidden_channels = hidden_channels 30 | self.out_channels = out_channels 31 | self.kernel_size = kernel_size 32 | self.n_layers = n_layers 33 | self.p_dropout = p_dropout 34 | assert n_layers > 1, "Number of layers should be larger than 0." 
35 | 36 | self.conv_layers = nn.ModuleList() 37 | self.norm_layers = nn.ModuleList() 38 | self.conv_layers.append( 39 | nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2) 40 | ) 41 | self.norm_layers.append(LayerNorm(hidden_channels)) 42 | self.relu_drop = nn.Sequential( 43 | nn.ReLU(), 44 | nn.Dropout(p_dropout)) 45 | for _ in range(n_layers-1): 46 | self.conv_layers.append(nn.Conv1d( 47 | hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2) 48 | ) 49 | self.norm_layers.append(LayerNorm(hidden_channels)) 50 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 51 | self.proj.weight.data.zero_() 52 | self.proj.bias.data.zero_() 53 | 54 | def forward(self, x, x_mask): 55 | x_org = x 56 | for i in range(self.n_layers): 57 | x = self.conv_layers[i](x * x_mask) 58 | x = self.norm_layers[i](x) 59 | x = self.relu_drop(x) 60 | x = x_org + self.proj(x) 61 | return x * x_mask 62 | 63 | 64 | class DDSConv(nn.Module): 65 | """Dialted and Depth-Separable Convolution""" 66 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 67 | super().__init__() 68 | self.channels = channels 69 | self.kernel_size = kernel_size 70 | self.n_layers = n_layers 71 | self.p_dropout = p_dropout 72 | 73 | self.drop = nn.Dropout(p_dropout) 74 | self.convs_sep = nn.ModuleList() 75 | self.convs_1x1 = nn.ModuleList() 76 | self.norms_1 = nn.ModuleList() 77 | self.norms_2 = nn.ModuleList() 78 | for i in range(n_layers): 79 | dilation = kernel_size ** i 80 | padding = (kernel_size * dilation - dilation) // 2 81 | self.convs_sep.append( 82 | nn.Conv1d( 83 | channels, 84 | channels, 85 | kernel_size, 86 | groups=channels, 87 | dilation=dilation, 88 | padding=padding 89 | ) 90 | ) 91 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 92 | self.norms_1.append(LayerNorm(channels)) 93 | self.norms_2.append(LayerNorm(channels)) 94 | 95 | def forward(self, x, x_mask, g=None): 96 | if g is not None: 97 | x = x + g 98 | for i in range(self.n_layers): 99 | y = self.convs_sep[i](x * x_mask) 100 | y = self.norms_1[i](y) 101 | y = F.gelu(y) 102 | y = self.convs_1x1[i](y) 103 | y = self.norms_2[i](y) 104 | y = F.gelu(y) 105 | y = self.drop(y) 106 | x = x + y 107 | return x * x_mask 108 | 109 | 110 | class WN(torch.nn.Module): 111 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 112 | super(WN, self).__init__() 113 | assert(kernel_size % 2 == 1) 114 | self.hidden_channels = hidden_channels 115 | self.kernel_size = kernel_size, 116 | self.dilation_rate = dilation_rate 117 | self.n_layers = n_layers 118 | self.gin_channels = gin_channels 119 | self.p_dropout = p_dropout 120 | 121 | self.in_layers = torch.nn.ModuleList() 122 | self.res_skip_layers = torch.nn.ModuleList() 123 | self.drop = nn.Dropout(p_dropout) 124 | 125 | if gin_channels != 0: 126 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) 127 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 128 | 129 | for i in range(n_layers): 130 | dilation = dilation_rate ** i 131 | padding = int((kernel_size * dilation - dilation) / 2) 132 | in_layer = torch.nn.Conv1d( 133 | hidden_channels, 134 | 2*hidden_channels, 135 | kernel_size, 136 | dilation=dilation, 137 | padding=padding 138 | ) 139 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 140 | self.in_layers.append(in_layer) 141 | 142 | # last one is not necessary 143 | if i < n_layers - 1: 144 | res_skip_channels = 2 * hidden_channels 145 | else: 146 | 
res_skip_channels = hidden_channels 147 | 148 | res_skip_layer = torch.nn.Conv1d( 149 | hidden_channels, res_skip_channels, 1 150 | ) 151 | res_skip_layer = torch.nn.utils.weight_norm( 152 | res_skip_layer, name='weight' 153 | ) 154 | self.res_skip_layers.append(res_skip_layer) 155 | 156 | def forward(self, x, x_mask, g=None, **kwargs): 157 | output = torch.zeros_like(x) 158 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 159 | if g is not None: 160 | g = g.unsqueeze(-1) 161 | g = self.cond_layer(g) 162 | 163 | for i in range(self.n_layers): 164 | x_in = self.in_layers[i](x) 165 | if g is not None: 166 | cond_offset = i * 2 * self.hidden_channels 167 | g_l = g[:, cond_offset:cond_offset+2*self.hidden_channels, :] 168 | else: 169 | g_l = torch.zeros_like(x_in) 170 | 171 | acts = commons.fused_add_tanh_sigmoid_multiply( 172 | x_in, 173 | g_l, 174 | n_channels_tensor 175 | ) 176 | acts = self.drop(acts) 177 | 178 | res_skip_acts = self.res_skip_layers[i](acts) 179 | if i < self.n_layers - 1: 180 | res_acts = res_skip_acts[:, :self.hidden_channels, :] 181 | x = (x + res_acts) * x_mask 182 | output = output + res_skip_acts[:, self.hidden_channels:, :] 183 | else: 184 | output = output + res_skip_acts 185 | return output * x_mask 186 | 187 | def remove_weight_norm(self): 188 | if self.gin_channels != 0: 189 | torch.nn.utils.remove_weight_norm(self.cond_layer) 190 | for l in self.in_layers: 191 | torch.nn.utils.remove_weight_norm(l) 192 | for l in self.res_skip_layers: 193 | torch.nn.utils.remove_weight_norm(l) 194 | 195 | -------------------------------------------------------------------------------- /pflow/models/components/vits_posterior.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | import pflow.models.components.vits_modules as modules 5 | import pflow.models.components.commons as commons 6 | 7 | class PosteriorEncoder(nn.Module): 8 | 9 | def __init__(self, 10 | in_channels, 11 | out_channels, 12 | hidden_channels, 13 | kernel_size, 14 | dilation_rate, 15 | n_layers, 16 | gin_channels=0): 17 | super().__init__() 18 | self.in_channels = in_channels 19 | self.out_channels = out_channels 20 | self.hidden_channels = hidden_channels 21 | self.kernel_size = kernel_size 22 | self.dilation_rate = dilation_rate 23 | self.n_layers = n_layers 24 | self.gin_channels = gin_channels 25 | 26 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 27 | self.enc = modules.WN(hidden_channels, 28 | kernel_size, 29 | dilation_rate, 30 | n_layers, 31 | gin_channels=gin_channels) 32 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 33 | 34 | def forward(self, x, x_lengths, g=None): 35 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 36 | 1).to(x.dtype) 37 | x = self.pre(x) * x_mask 38 | x = self.enc(x, x_mask, g=g) 39 | stats = self.proj(x) * x_mask 40 | # m, logs = torch.split(stats, self.out_channels, dim=1) 41 | # z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask 42 | # z = m * x_mask 43 | return stats, x_mask 44 | -------------------------------------------------------------------------------- /pflow/models/components/vits_wn_decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | import pflow.models.components.vits_modules as modules 7 | import pflow.models.components.commons as commons 8 | 9 | class Mish(nn.Module): 10 | def 
forward(self, x): 11 | return x * torch.tanh(F.softplus(x)) 12 | 13 | 14 | class SinusoidalPosEmb(nn.Module): 15 | def __init__(self, dim): 16 | super(SinusoidalPosEmb, self).__init__() 17 | self.dim = dim 18 | 19 | def forward(self, x, scale=1000): 20 | if x.ndim < 1: 21 | x = x.unsqueeze(0) 22 | device = x.device 23 | half_dim = self.dim // 2 24 | emb = math.log(10000) / (half_dim - 1) 25 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) 26 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) 27 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 28 | return emb 29 | 30 | class VitsWNDecoder(nn.Module): 31 | 32 | def __init__(self, 33 | in_channels, 34 | out_channels, 35 | hidden_channels, 36 | kernel_size, 37 | dilation_rate, 38 | n_layers, 39 | gin_channels=0, 40 | pe_scale=1000 41 | ): 42 | super().__init__() 43 | self.in_channels = in_channels 44 | self.out_channels = out_channels 45 | self.hidden_channels = hidden_channels 46 | self.kernel_size = kernel_size 47 | self.dilation_rate = dilation_rate 48 | self.n_layers = n_layers 49 | self.gin_channels = gin_channels 50 | self.pe_scale = pe_scale 51 | self.time_pos_emb = SinusoidalPosEmb(hidden_channels * 2) 52 | dim = hidden_channels * 2 53 | self.mlp = nn.Sequential( 54 | nn.Linear(dim, dim * 4), 55 | Mish(), 56 | nn.Linear(dim * 4, dim) 57 | ) 58 | 59 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 60 | self.enc = modules.WN(hidden_channels * 2, 61 | kernel_size, 62 | dilation_rate, 63 | n_layers, 64 | gin_channels=gin_channels) 65 | self.proj = nn.Conv1d(hidden_channels * 2, out_channels, 1) 66 | 67 | def forward(self, x, x_mask, mu, t, *args, **kwargs): 68 | # x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 69 | # 1).to(x.dtype) 70 | t = self.time_pos_emb(t, scale=self.pe_scale) 71 | t = self.mlp(t) 72 | 73 | x = self.pre(x) * x_mask 74 | mu = self.pre(mu) 75 | x = torch.cat((x, mu), dim=1) 76 | x = self.enc(x, x_mask, g=t) 77 | stats = self.proj(x) * x_mask 78 | 79 | return stats 80 | -------------------------------------------------------------------------------- /pflow/models/components/wn_pflow_decoder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/cantabile-kwok/VoiceFlow-TTS/blob/main/model/diffsinger.py#L51 3 | This is the original implementation of the DiffSinger model. 4 | It is a slightly modified WV which can be used for initial tests. 5 | Will update this into original p-flow implementation later. 
6 | ''' 7 | import math 8 | 9 | import torch.nn as nn 10 | import torch 11 | from torch.nn import Conv1d, Linear 12 | import math 13 | import torch.nn.functional as F 14 | 15 | 16 | class Mish(nn.Module): 17 | def forward(self, x): 18 | return x * torch.tanh(F.softplus(x)) 19 | 20 | 21 | class SinusoidalPosEmb(nn.Module): 22 | def __init__(self, dim): 23 | super(SinusoidalPosEmb, self).__init__() 24 | self.dim = dim 25 | 26 | def forward(self, x, scale=1000): 27 | if x.ndim < 1: 28 | x = x.unsqueeze(0) 29 | device = x.device 30 | half_dim = self.dim // 2 31 | emb = math.log(10000) / (half_dim - 1) 32 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) 33 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) 34 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 35 | return emb 36 | 37 | 38 | class ResidualBlock(nn.Module): 39 | def __init__(self, encoder_hidden, residual_channels, dilation): 40 | super().__init__() 41 | self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) 42 | self.diffusion_projection = Linear(residual_channels, residual_channels) 43 | self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1) 44 | self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) 45 | 46 | def forward(self, x, conditioner, diffusion_step): 47 | diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) 48 | conditioner = self.conditioner_projection(conditioner) 49 | y = x + diffusion_step 50 | 51 | y = self.dilated_conv(y) + conditioner 52 | 53 | gate, filter = torch.chunk(y, 2, dim=1) 54 | y = torch.sigmoid(gate) * torch.tanh(filter) 55 | 56 | y = self.output_projection(y) 57 | residual, skip = torch.chunk(y, 2, dim=1) 58 | return (x + residual) / math.sqrt(2.0), skip 59 | 60 | class DiffSingerNet(nn.Module): 61 | def __init__( 62 | self, 63 | in_dims=80, 64 | residual_channels=256, 65 | encoder_hidden=80, 66 | dilation_cycle_length=1, 67 | residual_layers=20, 68 | pe_scale=1000 69 | ): 70 | super().__init__() 71 | 72 | self.pe_scale = pe_scale 73 | 74 | self.input_projection = Conv1d(in_dims, residual_channels, 1) 75 | self.time_pos_emb = SinusoidalPosEmb(residual_channels) 76 | dim = residual_channels 77 | self.mlp = nn.Sequential( 78 | nn.Linear(dim, dim * 4), 79 | Mish(), 80 | nn.Linear(dim * 4, dim) 81 | ) 82 | self.residual_layers = nn.ModuleList([ 83 | ResidualBlock(encoder_hidden, residual_channels, 2 ** (i % dilation_cycle_length)) 84 | for i in range(residual_layers) 85 | ]) 86 | self.skip_projection = Conv1d(residual_channels, residual_channels, 1) 87 | self.output_projection = Conv1d(residual_channels, in_dims, 1) 88 | nn.init.zeros_(self.output_projection.weight) 89 | 90 | def forward(self, spec, spec_mask, mu, t, *args, **kwargs): 91 | """ 92 | :param spec: [B, M, T] 93 | :param t: [B, ] 94 | :param mu: [B, M, T] 95 | :return: 96 | """ 97 | # x = spec[:, 0] 98 | x = spec 99 | x = self.input_projection(x) # x [B, residual_channel, T] 100 | 101 | x = F.relu(x) 102 | 103 | t = self.time_pos_emb(t, scale=self.pe_scale) 104 | t = self.mlp(t) 105 | 106 | cond = mu 107 | 108 | skip = [] 109 | for layer_id, layer in enumerate(self.residual_layers): 110 | x, skip_connection = layer(x, cond, t) 111 | skip.append(skip_connection) 112 | 113 | x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers)) 114 | x = self.skip_projection(x) 115 | x = F.relu(x) 116 | x = self.output_projection(x) # [B, M, T] 117 | return x * spec_mask 
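A minimal smoke test for the DiffSingerNet estimator above, in the spirit of components/test.py; the shapes follow the forward() docstring, all tensors are dummy data, and the constructor arguments shown are simply this file's defaults:

import torch
from pflow.models.components.wn_pflow_decoder import DiffSingerNet

net = DiffSingerNet(in_dims=80, residual_channels=256, encoder_hidden=80)
spec = torch.randn(2, 80, 100)      # noisy mel batch [B, M, T]
spec_mask = torch.ones(2, 1, 100)   # frame mask [B, 1, T]
mu = torch.randn(2, 80, 100)        # encoder condition [B, M, T]
t = torch.rand(2)                   # flow/diffusion time per sample [B]
out = net(spec, spec_mask, mu, t)   # estimated vector field, [B, M, T]
print(out.shape)                    # torch.Size([2, 80, 100])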
-------------------------------------------------------------------------------- /pflow/models/pflow_tts.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import math 3 | import random 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from pflow import utils 9 | from pflow.models.baselightningmodule import BaseLightningClass 10 | from pflow.models.components.flow_matching import CFM 11 | from pflow.models.components.speech_prompt_encoder import TextEncoder 12 | from pflow.utils.model import ( 13 | denormalize, 14 | duration_loss, 15 | fix_len_compatibility, 16 | generate_path, 17 | sequence_mask, 18 | ) 19 | from pflow.models.components import commons 20 | from pflow.models.components.aligner import Aligner, ForwardSumLoss, BinLoss 21 | 22 | log = utils.get_pylogger(__name__) 23 | 24 | class pflowTTS(BaseLightningClass): # 25 | def __init__( 26 | self, 27 | n_vocab, 28 | n_feats, 29 | encoder, 30 | decoder, 31 | cfm, 32 | data_statistics, 33 | prompt_size=264, 34 | dur_p_use_log=False, 35 | optimizer=None, 36 | scheduler=None, 37 | **kwargs, 38 | ): 39 | super().__init__() 40 | 41 | self.save_hyperparameters(logger=False) 42 | 43 | self.n_vocab = n_vocab 44 | self.n_feats = n_feats 45 | self.prompt_size = prompt_size 46 | self.dur_p_use_log = dur_p_use_log 47 | speech_in_channels = n_feats 48 | 49 | self.encoder = TextEncoder( 50 | encoder.encoder_type, 51 | encoder.encoder_params, 52 | encoder.duration_predictor_params, 53 | n_vocab, 54 | speech_in_channels, 55 | ) 56 | 57 | # self.aligner = Aligner( 58 | # dim_in=encoder.encoder_params.n_feats, 59 | # dim_hidden=encoder.encoder_params.n_feats, 60 | # attn_channels=encoder.encoder_params.n_feats, 61 | # ) 62 | 63 | # self.aligner_loss = ForwardSumLoss() 64 | # self.bin_loss = BinLoss() 65 | # self.aligner_bin_loss_weight = 0.0 66 | 67 | self.decoder = CFM( 68 | in_channels=encoder.encoder_params.n_feats, 69 | out_channel=encoder.encoder_params.n_feats, 70 | cfm_params=cfm, 71 | decoder_params=decoder, 72 | ) 73 | 74 | self.proj_prompt = torch.nn.Conv1d(encoder.encoder_params.n_channels, self.n_feats, 1) 75 | 76 | self.update_data_statistics(data_statistics) 77 | 78 | @torch.inference_mode() 79 | def synthesise(self, x, x_lengths, prompt, n_timesteps, temperature=1.0, length_scale=1.0, guidance_scale=0.0): 80 | 81 | # For RTF computation 82 | t = dt.datetime.now() 83 | assert prompt is not None, "Prompt must be provided for synthesis" 84 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 85 | mu_x, logw, x_mask = self.encoder(x, x_lengths, prompt) 86 | w = torch.exp(logw) * x_mask 87 | w_ceil = torch.ceil(w) * length_scale 88 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 89 | y_max_length = y_lengths.max() 90 | y_max_length_ = fix_len_compatibility(y_max_length) 91 | 92 | # Using obtained durations `w` construct alignment map `attn` 93 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) 94 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 95 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) 96 | 97 | # Align encoded text and get mu_y 98 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 99 | mu_y = mu_y.transpose(1, 2) 100 | encoder_outputs = mu_y[:, :, :y_max_length] 101 | 102 | # Generate sample tracing the probability flow 103 | decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, guidance_scale=guidance_scale) 104 
| decoder_outputs = decoder_outputs[:, :, :y_max_length] 105 | 106 | t = (dt.datetime.now() - t).total_seconds() 107 | rtf = t * 22050 / (decoder_outputs.shape[-1] * 256) 108 | 109 | return { 110 | "encoder_outputs": encoder_outputs, 111 | "decoder_outputs": decoder_outputs, 112 | "attn": attn[:, :, :y_max_length], 113 | "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std), 114 | "mel_lengths": y_lengths, 115 | "rtf": rtf, 116 | } 117 | 118 | def forward(self, x, x_lengths, y, y_lengths, prompt=None, cond=None, **kwargs): 119 | if prompt is None: 120 | prompt_slice, ids_slice = commons.rand_slice_segments( 121 | y, y_lengths, self.prompt_size 122 | ) 123 | else: 124 | prompt_slice = prompt 125 | mu_x, logw, x_mask = self.encoder(x, x_lengths, prompt_slice) 126 | 127 | y_max_length = y.shape[-1] 128 | 129 | y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) 130 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 131 | 132 | with torch.no_grad(): 133 | # negative cross-entropy 134 | s_p_sq_r = torch.ones_like(mu_x) # [b, d, t] 135 | # s_p_sq_r = torch.exp(-2 * logx) 136 | neg_cent1 = torch.sum( 137 | -0.5 * math.log(2 * math.pi)- torch.zeros_like(mu_x), [1], keepdim=True 138 | ) 139 | # neg_cent1 = torch.sum( 140 | # -0.5 * math.log(2 * math.pi) - logx, [1], keepdim=True 141 | # ) # [b, 1, t_s] 142 | neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r) 143 | neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r)) 144 | neg_cent4 = torch.sum( 145 | -0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True 146 | ) 147 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 148 | 149 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 150 | from pflow.utils.monotonic_align import maximum_path 151 | attn = ( 152 | maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() 153 | ) 154 | 155 | logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask 156 | dur_loss = duration_loss(logw, logw_, x_lengths, use_log=self.dur_p_use_log) 157 | 158 | # aln_hard, aln_soft, aln_log, aln_mask = self.aligner( 159 | # mu_x.transpose(1,2), x_mask, y, y_mask 160 | # ) 161 | # attn = aln_mask.transpose(1,2).unsqueeze(1) 162 | # align_loss = self.aligner_loss(aln_log, x_lengths, y_lengths) 163 | # if self.aligner_bin_loss_weight > 0.: 164 | # align_bin_loss = self.bin_loss(aln_mask, aln_log, x_lengths) * self.aligner_bin_loss_weight 165 | # align_loss = align_loss + align_bin_loss 166 | # dur_loss = F.l1_loss(logw, attn.sum(2)) 167 | # dur_loss = dur_loss + align_loss 168 | 169 | # Align encoded text with mel-spectrogram and get mu_y segment 170 | attn = attn.squeeze(1).transpose(1,2) 171 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 172 | mu_y = mu_y.transpose(1, 2) 173 | 174 | y_loss_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) 175 | if prompt is None: 176 | for i in range(y.size(0)): 177 | y_loss_mask[i,:,ids_slice[i]:ids_slice[i] + self.prompt_size] = False 178 | # Compute loss of the decoder 179 | diff_loss, _ = self.decoder.compute_loss(x1=y.detach(), mask=y_mask, mu=mu_y, cond=cond, loss_mask=y_loss_mask) 180 | 181 | prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_loss_mask) 182 | prior_loss = prior_loss / (torch.sum(y_loss_mask) * self.n_feats) 183 | 184 | return dur_loss, prior_loss, diff_loss, attn -------------------------------------------------------------------------------- /pflow/onnx/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/pflow/onnx/__init__.py -------------------------------------------------------------------------------- /pflow/onnx/export.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') 3 | 4 | import argparse 5 | import random 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | from lightning import LightningModule 11 | 12 | from pflow.cli import VOCODER_URLS, load_pflow, load_vocoder 13 | from pflow.utils.model import normalize 14 | 15 | DEFAULT_OPSET = 15 16 | 17 | SEED = 1234 18 | random.seed(SEED) 19 | np.random.seed(SEED) 20 | torch.manual_seed(SEED) 21 | torch.cuda.manual_seed(SEED) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = False 24 | 25 | 26 | class PflowWithVocoder(LightningModule): 27 | def __init__(self, pflow, vocoder): 28 | super().__init__() 29 | self.pflow = pflow 30 | self.vocoder = vocoder 31 | 32 | def forward(self, x, x_lengths, prompt, scales): 33 | mel, mel_lengths = self.pflow(x, x_lengths, prompt, scales) 34 | wavs = self.vocoder(mel).clamp(-1, 1) 35 | lengths = mel_lengths * 256 36 | return wavs.squeeze(1), lengths 37 | 38 | 39 | def get_exportable_module(pflow, vocoder, n_timesteps): 40 | """ 41 | Return an appropriate `LightningModule` and output-node names 42 | based on whether the vocoder is embedded in the final graph 43 | """ 44 | 45 | def onnx_forward_func(x, x_lengths, prompt, scales, guidance_scale=0.7): 46 | """ 47 | Custom forward function for accepting 48 | scalar parameters as tensors 49 | """ 50 | # Extract scalar parameters from tensors 51 | temperature = scales[0] 52 | length_scale = scales[1] 53 | prompt = normalize(prompt, pflow.mel_mean, pflow.mel_std) 54 | output = pflow.synthesise( 55 | x, 56 | x_lengths, 57 | prompt, 58 | n_timesteps, 59 | temperature, 60 | length_scale, 61 | guidance_scale=guidance_scale 62 | ) 63 | return output["mel"], output["mel_lengths"] 64 | 65 | # Monkey-patch pflow's forward function 66 | pflow.forward = onnx_forward_func 67 | 68 | if vocoder is None: 69 | model, output_names = pflow, ["mel", "mel_lengths"] 70 | else: 71 | model = PflowWithVocoder(pflow, vocoder) 72 | output_names = ["wav", "wav_lengths"] 73 | return model, output_names 74 | 75 | 76 | def get_inputs(): 77 | """ 78 | Create dummy inputs for tracing 79 | """ 80 | dummy_input_length = 50 81 | x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long) 82 | x_lengths = torch.LongTensor([dummy_input_length]) 83 | 84 | prompt = torch.randn(1, 80, 264) #264 is default prompt size 85 | 86 | # Scales 87 | temperature = 0.667 88 | length_scale = 1.0 89 | scales = torch.Tensor([temperature, length_scale]) 90 | 91 | model_inputs = [x, x_lengths, prompt, scales] 92 | input_names = [ 93 | "x", 94 | "x_lengths", 95 | "prompt", 96 | "scales", 97 | ] 98 | 99 | return tuple(model_inputs), input_names 100 | 101 | 102 | def main(): 103 | parser = argparse.ArgumentParser(description="Export pflow-TTS to ONNX") 104 | 105 | parser.add_argument( 106 | "--checkpoint_path", 107 | type=str, 108 | help="Path to the model checkpoint", 109 | ) 110 | parser.add_argument("--output", type=str, help="Path to output `.onnx` file") 111 | 112 | parser.add_argument( 113 | "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion
in decoder (default 5)" 114 | ) 115 | parser.add_argument( 116 | "--vocoder-name", 117 | type=str, 118 | choices=list(VOCODER_URLS.keys()), 119 | default=None, 120 | help="Name of the vocoder to embed in the ONNX graph", 121 | ) 122 | parser.add_argument( 123 | "--vocoder-checkpoint-path", 124 | type=str, 125 | default=None, 126 | help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience", 127 | ) 128 | parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15") 129 | 130 | args = parser.parse_args() 131 | 132 | print(f"Loading pflow checkpoint from {args.checkpoint_path}") 133 | print(f"Setting n_timesteps to {args.n_timesteps}") 134 | 135 | checkpoint_path = Path(args.checkpoint_path) 136 | pflow = load_pflow(checkpoint_path.stem, checkpoint_path, "cpu") 137 | 138 | if args.vocoder_name or args.vocoder_checkpoint_path: 139 | assert ( 140 | args.vocoder_name and args.vocoder_checkpoint_path 141 | ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph." 142 | vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu") 143 | else: 144 | vocoder = None 145 | 146 | dummy_input, input_names = get_inputs() 147 | model, output_names = get_exportable_module(pflow, vocoder, args.n_timesteps) 148 | 149 | # Set dynamic shape for inputs/outputs 150 | dynamic_axes = { 151 | "x": {0: "batch_size", 1: "time"}, 152 | "x_lengths": {0: "batch_size"}, 153 | "prompt": {0: "batch_size", 2: "time"}, 154 | } 155 | 156 | if vocoder is None: 157 | dynamic_axes.update( 158 | { 159 | "mel": {0: "batch_size", 2: "time"}, 160 | "mel_lengths": {0: "batch_size"}, 161 | } 162 | ) 163 | else: 164 | print("Embedding the vocoder in the ONNX graph") 165 | dynamic_axes.update( 166 | { 167 | "wav": {0: "batch_size", 1: "time"}, 168 | "wav_lengths": {0: "batch_size"}, 169 | } 170 | ) 171 | 172 | # Create the output directory (if not exists) 173 | Path(args.output).parent.mkdir(parents=True, exist_ok=True) 174 | 175 | model.to_onnx( 176 | args.output, 177 | dummy_input, 178 | input_names=input_names, 179 | output_names=output_names, 180 | dynamic_axes=dynamic_axes, 181 | opset_version=args.opset, 182 | export_params=True, 183 | do_constant_folding=True, 184 | ) 185 | print(f"ONNX model exported to {args.output}") 186 | 187 | 188 | if __name__ == "__main__": 189 | main() -------------------------------------------------------------------------------- /pflow/onnx/infer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') 3 | 4 | import argparse 5 | import os 6 | import warnings 7 | from pathlib import Path 8 | from time import perf_counter 9 | 10 | import numpy as np 11 | import onnxruntime as ort 12 | import soundfile as sf 13 | import torch 14 | import torchaudio 15 | 16 | from pflow.cli import process_text, plot_spectrogram_to_numpy 17 | from pflow.data.text_mel_datamodule import mel_spectrogram 18 | 19 | def validate_args(args): 20 | assert ( 21 | args.text or args.file 22 | ), "Either text or file must be provided pflowTTS need some text to generate the waveforms." 
23 | assert args.prompt, "Prompt wav must be provided" 24 | assert args.temperature >= 0, "Sampling temperature cannot be negative" 25 | assert args.speaking_rate >= 0, "Speaking rate must be greater than 0" 26 | return args 27 | 28 | 29 | def write_wavs(model, inputs, output_dir, external_vocoder=None): 30 | if external_vocoder is None: 31 | print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly") 32 | t0 = perf_counter() 33 | wavs, wav_lengths = model.run(None, inputs) 34 | infer_secs = perf_counter() - t0 35 | mel_infer_secs = vocoder_infer_secs = None 36 | else: 37 | print("Generating mel using pflowTTS") 38 | mel_t0 = perf_counter() 39 | mels, mel_lengths = model.run(None, inputs) 40 | mel_infer_secs = perf_counter() - mel_t0 41 | print("Generating waveform from mel using external vocoder") 42 | vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels} 43 | vocoder_t0 = perf_counter() 44 | wavs = external_vocoder.run(None, vocoder_inputs)[0] 45 | vocoder_infer_secs = perf_counter() - vocoder_t0 46 | wavs = wavs.squeeze(1) 47 | wav_lengths = mel_lengths * 256 48 | infer_secs = mel_infer_secs + vocoder_infer_secs 49 | 50 | output_dir = Path(output_dir) 51 | output_dir.mkdir(parents=True, exist_ok=True) 52 | for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)): 53 | output_filename = output_dir.joinpath(f"output_{i + 1}.wav") 54 | audio = wav[:wav_length] 55 | print(f"Writing audio to {output_filename}") 56 | sf.write(output_filename, audio, 22050, "PCM_24") 57 | 58 | wav_secs = wav_lengths.sum() / 22050 59 | print(f"Inference seconds: {infer_secs}") 60 | print(f"Generated wav seconds: {wav_secs}") 61 | rtf = infer_secs / wav_secs 62 | if mel_infer_secs is not None: 63 | mel_rtf = mel_infer_secs / wav_secs 64 | print(f"pflowTTS RTF: {mel_rtf}") 65 | if vocoder_infer_secs is not None: 66 | vocoder_rtf = vocoder_infer_secs / wav_secs 67 | print(f"Vocoder RTF: {vocoder_rtf}") 68 | print(f"Overall RTF: {rtf}") 69 | 70 | 71 | def write_mels(model, inputs, output_dir): 72 | t0 = perf_counter() 73 | mels, mel_lengths = model.run(None, inputs) 74 | infer_secs = perf_counter() - t0 75 | 76 | output_dir = Path(output_dir) 77 | output_dir.mkdir(parents=True, exist_ok=True) 78 | for i, mel in enumerate(mels): 79 | output_stem = output_dir.joinpath(f"output_{i + 1}") 80 | plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png")) 81 | np.save(output_stem.with_suffix(".numpy"), mel) 82 | 83 | wav_secs = (mel_lengths * 256).sum() / 22050 84 | print(f"Inference seconds: {infer_secs}") 85 | print(f"Generated wav seconds: {wav_secs}") 86 | rtf = infer_secs / wav_secs 87 | print(f"RTF: {rtf}") 88 | 89 | 90 | def main(): 91 | parser = argparse.ArgumentParser( 92 | description="pflowTTS inference script" 93 | ) 94 | parser.add_argument( 95 | "--model", 96 | type=str, 97 | help="ONNX model to use", 98 | ) 99 | parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)") 100 | parser.add_argument("--text", type=str, default=None, help="Text to synthesize") 101 | parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") 102 | 103 | parser.add_argument("--prompt", type=str, default=None, help="Prompt wav file to use") 104 | 105 | parser.add_argument( 106 | "--temperature", 107 | type=float, 108 | default=0.667, 109 | help="Variance of the x0 noise (default: 0.667)", 110 | ) 111 | parser.add_argument( 112 | "--speaking-rate", 113 | type=float, 114 | default=1.0, 115 | help="change 
the speaking rate; a higher value means a slower speaking rate (default: 1.0)", 116 | ) 117 | parser.add_argument("--gpu", action="store_true", help="Use GPU for inference (default: CPU)") 118 | parser.add_argument( 119 | "--output-dir", 120 | type=str, 121 | default=os.getcwd(), 122 | help="Output folder to save results (default: current dir)", 123 | ) 124 | 125 | args = parser.parse_args() 126 | args = validate_args(args) 127 | 128 | if args.gpu: 129 | providers = ["CUDAExecutionProvider"] 130 | else: 131 | providers = ["CPUExecutionProvider"] 132 | model = ort.InferenceSession(args.model, providers=providers) 133 | 134 | model_inputs = model.get_inputs() 135 | model_outputs = list(model.get_outputs()) 136 | 137 | if args.text: 138 | text_lines = args.text.splitlines() 139 | else: 140 | with open(args.file, encoding="utf-8") as file: 141 | text_lines = file.read().splitlines() 142 | 143 | processed_lines = [process_text(0, line, "cpu") for line in text_lines] 144 | x = [line["x"].squeeze() for line in processed_lines] 145 | # Pad 146 | x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) 147 | x = x.detach().cpu().numpy() 148 | x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64) 149 | # prompt wav 150 | wav_file = args.prompt 151 | prompt_wav, sr = torchaudio.load(wav_file) 152 | prompt = mel_spectrogram( 153 | prompt_wav, 154 | 1024, 155 | 80, 156 | 22050, 157 | 256, 158 | 1024, 159 | 0, 160 | 8000, 161 | center=False, 162 | ) 163 | 164 | inputs = { 165 | "x": x, 166 | "x_lengths": x_lengths, 167 | "prompt": np.array(prompt, dtype=np.float32), 168 | "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32), 169 | } 170 | 171 | has_vocoder_embedded = model_outputs[0].name == "wav" 172 | if has_vocoder_embedded: 173 | write_wavs(model, inputs, args.output_dir) 174 | elif args.vocoder: 175 | external_vocoder = ort.InferenceSession(args.vocoder, providers=providers) 176 | write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder) 177 | else: 178 | warn = "[!] No vocoder is embedded in the graph and no external vocoder was provided. The mel outputs will be written as numpy arrays to `*.npy` files in the output directory" 179 | warnings.warn(warn, UserWarning) 180 | write_mels(model, inputs, args.output_dir) 181 | 182 | 183 | if __name__ == "__main__": 184 | main() -------------------------------------------------------------------------------- /pflow/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from pflow.text import cleaners 3 | from pflow.text.symbols import symbols 4 | 5 | # Mappings from symbol to numeric ID and vice versa: 6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension 8 | 9 | 10 | def text_to_sequence(text, cleaner_names): 11 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
12 | Args: 13 | text: string to convert to a sequence 14 | cleaner_names: names of the cleaner functions to run the text through 15 | Returns: 16 | List of integers corresponding to the symbols in the text 17 | """ 18 | sequence = [] 19 | 20 | clean_text = _clean_text(text, cleaner_names) 21 | for symbol in clean_text: 22 | symbol_id = _symbol_to_id[symbol] 23 | sequence += [symbol_id] 24 | return sequence 25 | 26 | 27 | def cleaned_text_to_sequence(cleaned_text): 28 | """Converts a string of already-cleaned text to a sequence of IDs corresponding to the symbols in the text. 29 | Args: 30 | cleaned_text: cleaned string to convert to a sequence 31 | Returns: 32 | List of integers corresponding to the symbols in the text 33 | """ 34 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] 35 | return sequence 36 | 37 | 38 | def sequence_to_text(sequence): 39 | """Converts a sequence of IDs back to a string""" 40 | result = "" 41 | for symbol_id in sequence: 42 | s = _id_to_symbol[symbol_id] 43 | result += s 44 | return result 45 | 46 | 47 | def _clean_text(text, cleaner_names): 48 | for name in cleaner_names: 49 | cleaner = getattr(cleaners, name, None) 50 | if not cleaner: 51 | raise Exception("Unknown cleaner: %s" % name) 52 | text = cleaner(text) 53 | return text 54 | -------------------------------------------------------------------------------- /pflow/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Cleaners are transformations that run over the input text at both training and eval time. 4 | 5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 7 | 1. "english_cleaners" for English text 8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 11 | the symbols in symbols.py to match your data). 12 | """ 13 | 14 | import logging 15 | import re 16 | 17 | import phonemizer 18 | from unidecode import unidecode 19 | 20 | # To avoid excessive logging, we set the log level of the phonemizer package to CRITICAL 21 | critical_logger = logging.getLogger("phonemizer") 22 | critical_logger.setLevel(logging.CRITICAL) 23 | 24 | # Initializing the phonemizer globally significantly reduces overhead, 25 | # since the phonemizer is not re-initialised at every call. 26 | # Might be less flexible, but it is much, much faster 27 | global_phonemizer = phonemizer.backend.EspeakBackend( 28 | language="en-us", 29 | preserve_punctuation=True, 30 | with_stress=True, 31 | language_switch="remove-flags", 32 | logger=critical_logger, 33 | ) 34 | 35 | 36 | # Regular expression matching whitespace: 37 | _whitespace_re = re.compile(r"\s+") 38 | 39 | # List of (regular expression, replacement) pairs for abbreviations: 40 | _abbreviations = [ 41 | (re.compile("\\b%s\\."
% x[0], re.IGNORECASE), x[1]) 42 | for x in [ 43 | ("mrs", "misess"), 44 | ("mr", "mister"), 45 | ("dr", "doctor"), 46 | ("st", "saint"), 47 | ("co", "company"), 48 | ("jr", "junior"), 49 | ("maj", "major"), 50 | ("gen", "general"), 51 | ("drs", "doctors"), 52 | ("rev", "reverend"), 53 | ("lt", "lieutenant"), 54 | ("hon", "honorable"), 55 | ("sgt", "sergeant"), 56 | ("capt", "captain"), 57 | ("esq", "esquire"), 58 | ("ltd", "limited"), 59 | ("col", "colonel"), 60 | ("ft", "fort"), 61 | ] 62 | ] 63 | 64 | 65 | def expand_abbreviations(text): 66 | for regex, replacement in _abbreviations: 67 | text = re.sub(regex, replacement, text) 68 | return text 69 | 70 | 71 | def lowercase(text): 72 | return text.lower() 73 | 74 | 75 | def collapse_whitespace(text): 76 | return re.sub(_whitespace_re, " ", text) 77 | 78 | 79 | def convert_to_ascii(text): 80 | return unidecode(text) 81 | 82 | 83 | def basic_cleaners(text): 84 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 85 | text = lowercase(text) 86 | text = collapse_whitespace(text) 87 | return text 88 | 89 | 90 | def transliteration_cleaners(text): 91 | """Pipeline for non-English text that transliterates to ASCII.""" 92 | text = convert_to_ascii(text) 93 | text = lowercase(text) 94 | text = collapse_whitespace(text) 95 | return text 96 | 97 | 98 | def english_cleaners2(text): 99 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 100 | text = convert_to_ascii(text) 101 | text = lowercase(text) 102 | text = expand_abbreviations(text) 103 | phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] 104 | phonemes = collapse_whitespace(phonemes) 105 | return phonemes 106 | -------------------------------------------------------------------------------- /pflow/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | import inflect 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return f"{dollars} {dollar_unit}, {cents} {cent_unit}" 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return f"{dollars} {dollar_unit}" 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return f"{cents} {cent_unit}" 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + 
_inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") 60 | else: 61 | return _inflect.number_to_words(num, andword="") 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r"\1 pounds", text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /pflow/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Defines the set of symbols used in text input to the model. 4 | """ 5 | _pad = "_" 6 | _punctuation = ';:,.!?¡¿—…"«»“” ' 7 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 8 | _letters_ipa = ( 9 | "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | ) 11 | 12 | 13 | # Export all symbols: 14 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 15 | 16 | # Special symbol ids 17 | SPACE_ID = symbols.index(" ") 18 | -------------------------------------------------------------------------------- /pflow/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | from typing import Any, Dict, List, Optional, Tuple 4 | 5 | import hydra 6 | import lightning as L 7 | import rootutils 8 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 9 | from lightning.pytorch.loggers import Logger 10 | from omegaconf import DictConfig 11 | import torch 12 | 13 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 14 | # ------------------------------------------------------------------------------------ # 15 | # the setup_root above is equivalent to: 16 | # - adding project root dir to PYTHONPATH 17 | # (so you don't need to force user to install project as a package) 18 | # (necessary before importing any local modules e.g. `from src import utils`) 19 | # - setting up PROJECT_ROOT environment variable 20 | # (which is used as a base for paths in "configs/paths/default.yaml") 21 | # (this way all filepaths are the same no matter where you run the code) 22 | # - loading environment variables from ".env" in root dir 23 | # 24 | # you can remove it if you: 25 | # 1. either install project as a package or move entry files to project root dir 26 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 27 | # 28 | # more info: https://github.com/ashleve/rootutils 29 | # ------------------------------------------------------------------------------------ # 30 | 31 | 32 | from pflow import utils 33 | 34 | log = utils.get_pylogger(__name__) 35 | 36 | 37 | @utils.task_wrapper 38 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 39 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 40 | training. 41 | 42 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 43 | failure. Useful for multiruns, saving info about the crash, etc. 
44 | 45 | :param cfg: A DictConfig configuration composed by Hydra. 46 | :return: A tuple with metrics and dict with all instantiated objects. 47 | """ 48 | # set seed for random number generators in pytorch, numpy and python.random 49 | if cfg.get("seed"): 50 | L.seed_everything(cfg.seed, workers=True) 51 | 52 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access 53 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 54 | 55 | log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access 56 | model: LightningModule = hydra.utils.instantiate(cfg.model) 57 | 58 | log.info("Instantiating callbacks...") 59 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 60 | 61 | log.info("Instantiating loggers...") 62 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) 63 | 64 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access 65 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 66 | 67 | log.info("Instantiating transfer learning...") 68 | if cfg.get("transfer_ckpt_path") is not None: 69 | assert cfg.get("ckpt_path") is None, "Cannot resume from checkpoint and transfer learn at the same time!" 70 | model_state_dict = model.state_dict() 71 | state_dict = torch.load(cfg.get("transfer_ckpt_path"), map_location="cpu")['state_dict'] 72 | for k, v in model_state_dict.items(): 73 | if k not in state_dict or state_dict[k].size() != v.size(): 74 | state_dict[k] = v  # fall back to the model's own weights where the checkpoint does not fit 75 | model.load_state_dict(state_dict, strict=False) 76 | log.info(f"Loaded model from {cfg.get('transfer_ckpt_path')}") 77 | 78 | object_dict = { 79 | "cfg": cfg, 80 | "datamodule": datamodule, 81 | "model": model, 82 | "callbacks": callbacks, 83 | "logger": logger, 84 | "trainer": trainer, 85 | } 86 | 87 | if logger: 88 | log.info("Logging hyperparameters!") 89 | utils.log_hyperparameters(object_dict) 90 | 91 | if cfg.get("train"): 92 | log.info("Starting training!") 93 | if cfg.get("ckpt_path") is not None: 94 | log.info(f"Resuming from checkpoint: {cfg.get('ckpt_path')}") 95 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 96 | else: 97 | trainer.fit(model=model, datamodule=datamodule) 98 | 99 | train_metrics = trainer.callback_metrics 100 | 101 | if cfg.get("test"): 102 | log.info("Starting testing!") 103 | ckpt_path = trainer.checkpoint_callback.best_model_path 104 | if ckpt_path == "": 105 | log.warning("Best ckpt not found! Using current weights for testing...") 106 | ckpt_path = None 107 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 108 | log.info(f"Best ckpt path: {ckpt_path}") 109 | 110 | test_metrics = trainer.callback_metrics 111 | 112 | # merge train and test metrics 113 | metric_dict = {**train_metrics, **test_metrics} 114 | 115 | return metric_dict, object_dict 116 | 117 | 118 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 119 | def main(cfg: DictConfig) -> Optional[float]: 120 | """Main entry point for training. 121 | 122 | :param cfg: DictConfig configuration composed by Hydra. 123 | :return: Optional[float] with optimized metric value. 124 | """ 125 | # apply extra utilities 126 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
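    # Example Hydra launches (illustrative only; config names must exist under configs/, and
    # keys that are not already defined in train.yaml need a leading '+'):
    #   python pflow/train.py experiment=ljspeech
    #   python pflow/train.py experiment=ljspeech ckpt_path=/path/to/last.ckpt            # resume
    #   python pflow/train.py experiment=ljspeech +transfer_ckpt_path=/path/to/pre.ckpt   # warm-start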
127 | utils.extras(cfg) 128 | 129 | # train the model 130 | metric_dict, _ = train(cfg) 131 | 132 | # safely retrieve metric value for hydra-based hyperparameter optimization 133 | metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) 134 | 135 | # return optimized metric 136 | return metric_value 137 | 138 | 139 | if __name__ == "__main__": 140 | main() # pylint: disable=no-value-for-parameter 141 | -------------------------------------------------------------------------------- /pflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pflow.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from pflow.utils.logging_utils import log_hyperparameters 3 | from pflow.utils.pylogger import get_pylogger 4 | from pflow.utils.rich_utils import enforce_tags, print_config_tree 5 | from pflow.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /pflow/utils/audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | def load_wav(full_path): 11 | sampling_rate, data = read(full_path) 12 | return data, sampling_rate 13 | 14 | 15 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 17 | 18 | 19 | def dynamic_range_decompression(x, C=1): 20 | return np.exp(x) / C 21 | 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | 27 | def dynamic_range_decompression_torch(x, C=1): 28 | return torch.exp(x) / C 29 | 30 | 31 | def spectral_normalize_torch(magnitudes): 32 | output = dynamic_range_compression_torch(magnitudes) 33 | return output 34 | 35 | 36 | def spectral_de_normalize_torch(magnitudes): 37 | output = dynamic_range_decompression_torch(magnitudes) 38 | return output 39 | 40 | 41 | mel_basis = {} 42 | hann_window = {} 43 | 44 | 45 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 46 | if torch.min(y) < -1.0: 47 | print("min value is ", torch.min(y)) 48 | if torch.max(y) > 1.0: 49 | print("max value is ", torch.max(y)) 50 | 51 | global mel_basis, hann_window # pylint: disable=global-statement 52 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 53 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 54 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 55 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 56 | 57 | y = torch.nn.functional.pad( 58 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 59 | ) 60 | y = y.squeeze(1) 61 | 62 | spec = torch.view_as_real( 63 | torch.stft( 64 | y, 65 | n_fft, 66 | hop_length=hop_size, 67 | win_length=win_size, 68 | window=hann_window[str(y.device)], 69 | center=center, 70 | pad_mode="reflect", 71 | normalized=False, 72 | onesided=True, 73 | return_complex=True, 74 | ) 75 | ) 76 | 77 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 78 | 79 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 80 | spec = spectral_normalize_torch(spec) 81 | 82 | return spec 
83 | -------------------------------------------------------------------------------- /pflow/utils/generate_data_statistics.py: -------------------------------------------------------------------------------- 1 | r""" 2 | This script computes the dataset-wide mel-spectrogram statistics (mean and standard deviation) and stores them in a JSON file 3 | so they can be loaded for data normalisation during training. 4 | 5 | Parameters are read from the corresponding data config under configs/data 6 | """ 7 | import os 8 | 9 | import sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) 11 | 12 | import argparse 13 | import json 14 | import sys 15 | from pathlib import Path 16 | 17 | import rootutils 18 | import torch 19 | from hydra import compose, initialize 20 | from omegaconf import open_dict 21 | from tqdm.auto import tqdm 22 | 23 | from pflow.data.text_mel_datamodule import TextMelDataModule 24 | from pflow.utils import pylogger 25 | 26 | log = pylogger.get_pylogger(__name__) 27 | 28 | 29 | def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): 30 | """Generate the data mean and standard deviation used for data normalisation 31 | 32 | Args: 33 | data_loader (torch.utils.data.DataLoader): dataloader yielding training batches 34 | out_channels (int): mel spectrogram channels 35 | """ 36 | total_mel_sum = 0 37 | total_mel_sq_sum = 0 38 | total_mel_len = 0 39 | 40 | for batch in tqdm(data_loader, leave=False): 41 | mels = batch["y"] 42 | mel_lengths = batch["y_lengths"] 43 | 44 | total_mel_len += torch.sum(mel_lengths) 45 | total_mel_sum += torch.sum(mels) 46 | total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) 47 | 48 | data_mean = total_mel_sum / (total_mel_len * out_channels) 49 | data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) 50 | 51 | return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} 52 | 53 | 54 | def main(): 55 | parser = argparse.ArgumentParser() 56 | 57 | parser.add_argument( 58 | "-i", 59 | "--input-config", 60 | type=str, 61 | default="vctk.yaml", 62 | help="The name of the yaml config file under configs/data", 63 | ) 64 | 65 | parser.add_argument( 66 | "-b", 67 | "--batch-size", 68 | type=int, 69 | default=256, 70 | help="Batch size used for the statistics pass; a larger value speeds up the computation", 71 | ) 72 | 73 | parser.add_argument( 74 | "-f", 75 | "--force", 76 | action="store_true", 77 | default=False, 78 | required=False, 79 | help="force overwrite the file", 80 | ) 81 | args = parser.parse_args() 82 | output_file = Path(args.input_config).with_suffix(".json") 83 | 84 | if os.path.exists(output_file) and not args.force: 85 | print("File already exists. Use -f to force overwrite") 86 | sys.exit(1) 87 | 88 | with initialize(version_base="1.3", config_path="../../configs/data"): 89 | cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) 90 | 91 | root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") 92 | 93 | with open_dict(cfg): 94 | del cfg["hydra"] 95 | del cfg["_target_"] 96 | cfg["data_statistics"] = None 97 | cfg["seed"] = 1234 98 | cfg["batch_size"] = args.batch_size 99 | cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) 100 | cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) 101 | 102 | text_mel_datamodule = TextMelDataModule(**cfg) 103 | text_mel_datamodule.setup() 104 | data_loader = text_mel_datamodule.train_dataloader() 105 | log.info("Dataloader loaded! 
Now computing stats...") 106 | params = compute_data_statistics(data_loader, cfg["n_feats"]) 107 | print(params) 108 | json.dump( 109 | params, 110 | open(output_file, "w"), 111 | ) 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /pflow/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from pflow.utils import pylogger 9 | 10 | log = pylogger.get_pylogger(__name__) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /pflow/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from pflow.utils import pylogger 7 | 8 | log = pylogger.get_pylogger(__name__) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! 
Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) 39 | 40 | hparams["data"] = cfg["data"] 41 | hparams["trainer"] = cfg["trainer"] 42 | 43 | hparams["callbacks"] = cfg.get("callbacks") 44 | hparams["extras"] = cfg.get("extras") 45 | 46 | hparams["task_name"] = cfg.get("task_name") 47 | hparams["tags"] = cfg.get("tags") 48 | hparams["ckpt_path"] = cfg.get("ckpt_path") 49 | hparams["seed"] = cfg.get("seed") 50 | 51 | # send hparams to all loggers 52 | for logger in trainer.loggers: 53 | logger.log_hyperparams(hparams) 54 | -------------------------------------------------------------------------------- /pflow/utils/model.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def sequence_mask(length, max_length=None): 8 | if max_length is None: 9 | max_length = length.max() 10 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 11 | return x.unsqueeze(0) < length.unsqueeze(1) 12 | 13 | 14 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 15 | factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) 16 | length = (length / factor).ceil() * factor 17 | if not torch.onnx.is_in_onnx_export(): 18 | return length.int().item() 19 | else: 20 | return length 21 | 22 | 23 | def convert_pad_shape(pad_shape): 24 | inverted_shape = pad_shape[::-1] 25 | pad_shape = [item for sublist in inverted_shape for item in sublist] 26 | return pad_shape 27 | 28 | 29 | def generate_path(duration, mask): 30 | device = duration.device 31 | 32 | b, t_x, t_y = mask.shape 33 | cum_duration = torch.cumsum(duration, 1) 34 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 35 | 36 | cum_duration_flat = cum_duration.view(b * t_x) 37 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 38 | path = path.view(b, t_x, t_y) 39 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 40 | path = path * mask 41 | return path 42 | 43 | 44 | def duration_loss(logw, logw_, lengths, use_log=False): 45 | if use_log: 46 | loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) 47 | else: 48 | loss = torch.sum((torch.exp(logw) - torch.exp(logw_)) ** 2) / torch.sum(lengths) 49 | return loss 50 | 51 | 52 | def normalize(data, mu, std): 53 | if not isinstance(mu, (float, int)): 54 | if isinstance(mu, list): 55 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 56 | elif isinstance(mu, torch.Tensor): 57 | mu = mu.to(data.device) 58 | elif isinstance(mu, np.ndarray): 59 | mu = torch.from_numpy(mu).to(data.device) 60 | mu = mu.unsqueeze(-1) 61 | 62 | if not isinstance(std, (float, int)): 63 | if isinstance(std, list): 64 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 65 | elif isinstance(std, torch.Tensor): 66 | std = std.to(data.device) 67 | elif isinstance(std, np.ndarray): 68 | std = torch.from_numpy(std).to(data.device) 69 | std = std.unsqueeze(-1) 70 | 71 | return (data - mu) / std 72 | 73 | 74 | def denormalize(data, mu, std): 75 | if not isinstance(mu, float): 76 | if isinstance(mu, 
list): 77 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 78 | elif isinstance(mu, torch.Tensor): 79 | mu = mu.to(data.device) 80 | elif isinstance(mu, np.ndarray): 81 | mu = torch.from_numpy(mu).to(data.device) 82 | mu = mu.unsqueeze(-1) 83 | 84 | if not isinstance(std, float): 85 | if isinstance(std, list): 86 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 87 | elif isinstance(std, torch.Tensor): 88 | std = std.to(data.device) 89 | elif isinstance(std, np.ndarray): 90 | std = torch.from_numpy(std).to(data.device) 91 | std = std.unsqueeze(-1) 92 | 93 | return data * std + mu 94 | -------------------------------------------------------------------------------- /pflow/utils/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pflow.utils.monotonic_align.core import maximum_path_c 4 | 5 | 6 | def maximum_path(neg_cent, mask): 7 | """Cython optimized version. 8 | neg_cent: [b, t_t, t_s] 9 | mask: [b, t_t, t_s] 10 | """ 11 | device = neg_cent.device 12 | dtype = neg_cent.dtype 13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) 14 | path = np.zeros(neg_cent.shape, dtype=np.int32) 15 | 16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) 17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) 18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 19 | return torch.from_numpy(path).to(device=device, dtype=dtype) 20 | -------------------------------------------------------------------------------- /pflow/utils/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from cython.parallel import prange 3 | 4 | 5 | @cython.boundscheck(False) 6 | @cython.wraparound(False) 7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: 8 | cdef int x 9 | cdef int y 10 | cdef float v_prev 11 | cdef float v_cur 12 | cdef float tmp 13 | cdef int index = t_x - 1 14 | 15 | for y in range(t_y): 16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 17 | if x == y: 18 | v_cur = max_neg_val 19 | else: 20 | v_cur = value[y-1, x] 21 | if x == 0: 22 | if y == 0: 23 | v_prev = 0. 24 | else: 25 | v_prev = max_neg_val 26 | else: 27 | v_prev = value[y-1, x-1] 28 | value[y, x] += max(v_prev, v_cur) 29 | 30 | for y in range(t_y - 1, -1, -1): 31 | path[y, index] = 1 32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): 33 | index = index - 1 34 | 35 | 36 | @cython.boundscheck(False) 37 | @cython.wraparound(False) 38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: 39 | cdef int b = paths.shape[0] 40 | cdef int i 41 | for i in prange(b, nogil=True): 42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) 43 | -------------------------------------------------------------------------------- /pflow/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name: str = __name__) -> logging.Logger: 7 | """Initializes a multi-GPU-friendly python command line logger. 8 | 9 | :param name: The name of the logger, defaults to ``__name__``. 10 | 11 | :return: A logger object. 
12 | """ 13 | logger = logging.getLogger(name) 14 | 15 | # this ensures all logging levels get marked with the rank zero decorator 16 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 17 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 18 | for level in logging_levels: 19 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 20 | 21 | return logger 22 | -------------------------------------------------------------------------------- /pflow/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from pflow.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | _ = ( 48 | queue.append(field) 49 | if field in cfg 50 | else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") 51 | ) 52 | 53 | # add all the other fields to queue (not specified in `print_order`) 54 | for field in cfg: 55 | if field not in queue: 56 | queue.append(field) 57 | 58 | # generate config tree from queue 59 | for field in queue: 60 | branch = tree.add(field, style=style, guide_style=style) 61 | 62 | config_group = cfg[field] 63 | if isinstance(config_group, DictConfig): 64 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 65 | else: 66 | branch_content = str(config_group) 67 | 68 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 69 | 70 | # print config tree 71 | rich.print(tree) 72 | 73 | # save config tree to file 74 | if save_to_file: 75 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 76 | rich.print(tree, file=file) 77 | 78 | 79 | @rank_zero_only 80 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 81 | """Prompts user to input tags from command line if no tags are provided in config. 82 | 83 | :param cfg: A DictConfig composed by Hydra. 84 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 
85 | """ 86 | if not cfg.get("tags"): 87 | if "id" in HydraConfig().cfg.hydra.job: 88 | raise ValueError("Specify tags before launching a multirun!") 89 | 90 | log.warning("No tags provided in config. Prompting user to input tags...") 91 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 92 | tags = [t.strip() for t in tags.split(",") if t != ""] 93 | 94 | with open_dict(cfg): 95 | cfg.tags = tags 96 | 97 | log.info(f"Tags: {cfg.tags}") 98 | 99 | if save_to_file: 100 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 101 | rich.print(cfg.tags, file=file) 102 | -------------------------------------------------------------------------------- /pflow/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | from importlib.util import find_spec 5 | from pathlib import Path 6 | from typing import Any, Callable, Dict, Tuple 7 | 8 | import gdown 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | import wget 13 | from omegaconf import DictConfig 14 | 15 | from pflow.utils import pylogger, rich_utils 16 | 17 | log = pylogger.get_pylogger(__name__) 18 | 19 | 20 | def extras(cfg: DictConfig) -> None: 21 | """Applies optional utilities before the task is started. 22 | 23 | Utilities: 24 | - Ignoring python warnings 25 | - Setting tags from command line 26 | - Rich config printing 27 | 28 | :param cfg: A DictConfig object containing the config tree. 29 | """ 30 | # return if no `extras` config 31 | if not cfg.get("extras"): 32 | log.warning("Extras config not found! ") 33 | return 34 | 35 | # disable python warnings 36 | if cfg.extras.get("ignore_warnings"): 37 | log.info("Disabling python warnings! ") 38 | warnings.filterwarnings("ignore") 39 | 40 | # prompt user to input tags from command line if none are provided in the config 41 | if cfg.extras.get("enforce_tags"): 42 | log.info("Enforcing tags! ") 43 | rich_utils.enforce_tags(cfg, save_to_file=True) 44 | 45 | # pretty print config tree using Rich library 46 | if cfg.extras.get("print_config"): 47 | log.info("Printing config tree with Rich! ") 48 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 49 | 50 | 51 | def task_wrapper(task_func: Callable) -> Callable: 52 | """Optional decorator that controls the failure behavior when executing the task function. 53 | 54 | This wrapper can be used to: 55 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 56 | - save the exception to a `.log` file 57 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 58 | - etc. (adjust depending on your needs) 59 | 60 | Example: 61 | ``` 62 | @utils.task_wrapper 63 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 64 | ... 65 | return metric_dict, object_dict 66 | ``` 67 | 68 | :param task_func: The task function to be wrapped. 69 | 70 | :return: The wrapped task function. 
71 | """ 72 | 73 | def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 74 | # execute the task 75 | try: 76 | metric_dict, object_dict = task_func(cfg=cfg) 77 | 78 | # things to do if exception occurs 79 | except Exception as ex: 80 | # save exception to `.log` file 81 | log.exception("") 82 | 83 | # some hyperparameter combinations might be invalid or cause out-of-memory errors 84 | # so when using hparam search plugins like Optuna, you might want to disable 85 | # raising the below exception to avoid multirun failure 86 | raise ex 87 | 88 | # things to always do after either success or exception 89 | finally: 90 | # display output dir path in terminal 91 | log.info(f"Output dir: {cfg.paths.output_dir}") 92 | 93 | # always close wandb run (even if exception occurs so multirun won't fail) 94 | if find_spec("wandb"): # check if wandb is installed 95 | import wandb 96 | 97 | if wandb.run: 98 | log.info("Closing wandb!") 99 | wandb.finish() 100 | 101 | return metric_dict, object_dict 102 | 103 | return wrap 104 | 105 | 106 | def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: 107 | """Safely retrieves value of the metric logged in LightningModule. 108 | 109 | :param metric_dict: A dict containing metric values. 110 | :param metric_name: The name of the metric to retrieve. 111 | :return: The value of the metric. 112 | """ 113 | if not metric_name: 114 | log.info("Metric name is None! Skipping metric value retrieval...") 115 | return None 116 | 117 | if metric_name not in metric_dict: 118 | raise Exception( 119 | f"Metric value not found! \n" 120 | "Make sure metric name logged in LightningModule is correct!\n" 121 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 122 | ) 123 | 124 | metric_value = metric_dict[metric_name].item() 125 | log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") 126 | 127 | return metric_value 128 | 129 | 130 | def intersperse(lst, item): 131 | # Adds blank symbol 132 | result = [item] * (len(lst) * 2 + 1) 133 | result[1::2] = lst 134 | return result 135 | 136 | 137 | def save_figure_to_numpy(fig): 138 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 139 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 140 | return data 141 | 142 | 143 | def plot_tensor(tensor): 144 | plt.style.use("default") 145 | fig, ax = plt.subplots(figsize=(12, 3)) 146 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 147 | plt.colorbar(im, ax=ax) 148 | plt.tight_layout() 149 | fig.canvas.draw() 150 | data = save_figure_to_numpy(fig) 151 | plt.close() 152 | return data 153 | 154 | 155 | def save_plot(tensor, savepath): 156 | plt.style.use("default") 157 | fig, ax = plt.subplots(figsize=(12, 3)) 158 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 159 | plt.colorbar(im, ax=ax) 160 | plt.tight_layout() 161 | fig.canvas.draw() 162 | plt.savefig(savepath) 163 | plt.close() 164 | 165 | 166 | def to_numpy(tensor): 167 | if isinstance(tensor, np.ndarray): 168 | return tensor 169 | elif isinstance(tensor, torch.Tensor): 170 | return tensor.detach().cpu().numpy() 171 | elif isinstance(tensor, list): 172 | return np.array(tensor) 173 | else: 174 | raise TypeError("Unsupported type for conversion to numpy array") 175 | 176 | 177 | def get_user_data_dir(appname="pflow_tts"): 178 | """ 179 | Args: 180 | appname (str): Name of application 181 | 182 | Returns: 183 | Path: path to user data directory 184 | """ 185 | 186 | PFLOW_HOME = os.environ.get("PFLOW_HOME") 187 | if PFLOW_HOME is not None: 188 | ans = Path(PFLOW_HOME).expanduser().resolve(strict=False) 189 | elif sys.platform == "win32": 190 | import winreg # pylint: disable=import-outside-toplevel 191 | 192 | key = winreg.OpenKey( 193 | winreg.HKEY_CURRENT_USER, 194 | r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", 195 | ) 196 | dir_, _ = winreg.QueryValueEx(key, "Local AppData") 197 | ans = Path(dir_).resolve(strict=False) 198 | elif sys.platform == "darwin": 199 | ans = Path("~/Library/Application Support/").expanduser() 200 | else: 201 | ans = Path.home().joinpath(".local/share") 202 | 203 | final_path = ans.joinpath(appname) 204 | final_path.mkdir(parents=True, exist_ok=True) 205 | return final_path 206 | 207 | 208 | def assert_model_downloaded(checkpoint_path, url, use_wget=False): 209 | if Path(checkpoint_path).exists(): 210 | log.debug(f"[+] Model already present at {checkpoint_path}!") 211 | return 212 | log.info(f"[-] Model not found at {checkpoint_path}! 
Will download it") 213 | checkpoint_path = str(checkpoint_path) 214 | if not use_wget: 215 | gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) 216 | else: 217 | wget.download(url=url, out=checkpoint_path) 218 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch>=2.0.0 3 | torchvision>=0.15.0 4 | lightning>=2.0.0 5 | torchmetrics>=0.11.4 6 | 7 | # --------- hydra --------- # 8 | hydra-core==1.3.2 9 | hydra-colorlog==1.2.0 10 | hydra-optuna-sweeper==1.2.0 11 | 12 | # --------- loggers --------- # 13 | # wandb 14 | # neptune-client 15 | # mlflow 16 | # comet-ml 17 | # aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 18 | 19 | # --------- others --------- # 20 | rootutils # standardizing the project root setup 21 | pre-commit # hooks for applying linters on commit 22 | rich # beautiful text formatting in terminal 23 | pytest # tests 24 | # sh # for running bash commands in some tests (linux/macos only) 25 | phonemizer # phonemization of text 26 | tensorboard 27 | librosa 28 | Cython 29 | numpy 30 | einops 31 | inflect 32 | Unidecode 33 | scipy 34 | torchaudio 35 | matplotlib 36 | pandas 37 | conformer==0.3.2 38 | diffusers==0.21.3 39 | notebook 40 | ipywidgets 41 | gradio 42 | gdown 43 | wget 44 | seaborn 45 | beartype -------------------------------------------------------------------------------- /samples/download_54.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/samples/download_54.wav -------------------------------------------------------------------------------- /samples/download_55.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/samples/download_55.wav -------------------------------------------------------------------------------- /samples/download_56.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/samples/download_56.wav -------------------------------------------------------------------------------- /samples/download_57.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/samples/download_57.wav -------------------------------------------------------------------------------- /samples/download_58.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/samples/download_58.wav -------------------------------------------------------------------------------- /samples/download_64.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/samples/download_64.wav -------------------------------------------------------------------------------- /samples/download_65.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/samples/download_65.wav -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from Cython.Build import cythonize 2 | import numpy 3 | from setuptools import Extension, setup 4 | 5 | # exts = [ 6 | # Extension( 7 | # name="pflow.utils.monotonic_align.core", 8 | # sources=["pflow/utils/monotonic_align/core.pyx"], 9 | # ) 10 | # ] 11 | # setup(name='monotonic_align', 12 | # ext_modules=cythonize(exts, language_level=3), 13 | # include_dirs=[numpy.get_include()]) 14 | 15 | setup( 16 | name="monotonic_align", 17 | ext_modules=cythonize("pflow/utils/monotonic_align/core.pyx"), 18 | include_dirs=[numpy.get_include()], 19 | ) 20 | -------------------------------------------------------------------------------- /val_out_tboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/pflowtts_pytorch/629b10bb525f0681327e36cfc91c176a5e5af17d/val_out_tboard.png --------------------------------------------------------------------------------