├── data └── .gitkeep ├── logs └── .gitkeep ├── notebooks ├── .gitkeep ├── test.ipynb └── generate.ipynb ├── scripts ├── .gitkeep ├── dump_latents.py ├── dump_resampled.py └── dump_durations.py ├── tests ├── __init__.py └── helpers │ ├── __init__.py │ ├── run_sh_command.py │ ├── package_available.py │ └── run_if.py ├── configs ├── callbacks │ ├── none.yaml │ ├── lr_monitor.yaml │ ├── rich_progress_bar.yaml │ ├── model_summary.yaml │ ├── default.yaml │ ├── early_stopping.yaml │ └── model_checkpoint.yaml ├── trainer │ ├── cpu.yaml │ ├── gpu.yaml │ ├── mps.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ └── default.yaml ├── __init__.py ├── debug │ ├── fdr.yaml │ ├── profiler.yaml │ ├── overfit.yaml │ ├── limit.yaml │ └── default.yaml ├── logger │ ├── many_loggers.yaml │ ├── csv.yaml │ ├── tensorboard.yaml │ ├── neptune.yaml │ ├── mlflow.yaml │ ├── comet.yaml │ ├── wandb.yaml │ └── aim.yaml ├── extras │ └── default.yaml ├── eval.yaml ├── data │ ├── libritts.yaml │ ├── libritts_clean.yaml │ ├── multilingual.yaml │ ├── libritts_r.yaml │ ├── aihub_libri_korean.yaml │ ├── aihub_libri_japanese.yaml │ └── multilingual_lang_id.yaml ├── hydra │ └── default.yaml ├── paths │ └── default.yaml ├── experiment │ ├── korean_base.yaml │ ├── japanese_base.yaml │ ├── libritts_base.yaml │ ├── default.yaml │ ├── multilingual_base.yaml │ ├── multilingual_lang_id.yaml │ ├── libritts_clean_small.yaml │ └── libritts_small.yaml ├── train.yaml └── model │ ├── pflow_base.yaml │ └── pflow_small.yaml ├── pflow_encodec ├── __init__.py ├── data │ ├── __init__.py │ ├── tokenizer.py │ ├── text_latent_dur_dataset.py │ ├── sampler.py │ └── datamodule.py ├── models │ ├── __init__.py │ ├── lightning_modules │ │ ├── __init__.py │ │ └── pflow.py │ └── pflow.py ├── utils │ ├── helper.py │ ├── __init__.py │ ├── logging_utils.py │ ├── instantiators.py │ ├── export.py │ ├── pylogger.py │ ├── rich_utils.py │ └── utils.py ├── modules │ ├── __init__.py │ ├── grl.py │ ├── duration_predictor.py │ ├── spk_enc.py │ ├── text_enc.py │ ├── flow_matching.py │ └── transformer.py ├── eval.py └── train.py ├── .project-root ├── samples ├── lj_prompt.wav ├── jsut_prompt.wav ├── jsut_sample.wav ├── kss_prompt.wav ├── kss_sample.wav ├── code_switch_jsut.wav ├── code_switch_kss.wav ├── code_switch_libri.wav ├── libritts_r_prompt.wav └── libritts_r_sample.wav ├── screenshots └── pflow_libri_tb.png ├── infer-requirements.txt ├── .env.example ├── pyproject.toml ├── setup.py ├── Makefile ├── requirements.txt ├── MODEL_CARD.md ├── .gitignore ├── .pre-commit-config.yaml └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pflow_encodec/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pflow_encodec/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pflow_encodec/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pflow_encodec/models/lightning_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pflow_encodec/utils/helper.py: -------------------------------------------------------------------------------- 1 | def exists(x): 2 | return x is not None 3 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /samples/lj_prompt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/samples/lj_prompt.wav -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /samples/jsut_prompt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/samples/jsut_prompt.wav -------------------------------------------------------------------------------- /samples/jsut_sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/samples/jsut_sample.wav 
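A note on the `.project-root` file shown above: in lightning-hydra-template-style projects such as this one, that marker file is typically resolved by rootutils (listed in requirements.txt) at the top of the entry scripts. The snippet below is an illustrative sketch of that pattern, not a file from this repository; the exact call in pflow_encodec/train.py and pflow_encodec/eval.py may differ.

# Illustrative sketch (assumption, not repository code): resolving the .project-root marker with rootutils.
import rootutils

# Walks up from this file until ".project-root" is found, then (by default) sets the
# PROJECT_ROOT environment variable consumed by ${oc.env:PROJECT_ROOT} in configs/paths/default.yaml,
# loads the .env file described in .env.example, and here additionally adds the root to sys.path.
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)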
-------------------------------------------------------------------------------- /samples/kss_prompt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/samples/kss_prompt.wav -------------------------------------------------------------------------------- /samples/kss_sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/samples/kss_sample.wav -------------------------------------------------------------------------------- /samples/code_switch_jsut.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/samples/code_switch_jsut.wav -------------------------------------------------------------------------------- /samples/code_switch_kss.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/samples/code_switch_kss.wav -------------------------------------------------------------------------------- /samples/code_switch_libri.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/samples/code_switch_libri.wav -------------------------------------------------------------------------------- /samples/libritts_r_prompt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/samples/libritts_r_prompt.wav -------------------------------------------------------------------------------- /samples/libritts_r_sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/samples/libritts_r_sample.wav -------------------------------------------------------------------------------- /screenshots/pflow_libri_tb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seastar105/pflow-encodec/HEAD/screenshots/pflow_libri_tb.png -------------------------------------------------------------------------------- /infer-requirements.txt: -------------------------------------------------------------------------------- 1 | deepfilternet 2 | git+https://github.com/seastar105/audiocraft.git#egg=audiocraft 3 | vocos 4 | gradio 5 | -------------------------------------------------------------------------------- /configs/callbacks/lr_monitor.yaml: -------------------------------------------------------------------------------- 1 | learning_rate_monitor: 2 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 3 | logging_interval: "step" 4 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: 4 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | 
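The callback and trainer YAMLs above (like the logger, data, and model configs later in this listing) are Hydra `_target_` specs rather than code: `instantiate_callbacks` in pflow_encodec/utils/instantiators.py, shown further below, passes each entry to `hydra.utils.instantiate`, which imports the `_target_` class and calls it with the remaining keys as keyword arguments. The following self-contained sketch mirrors configs/callbacks/lr_monitor.yaml to illustrate that mechanism; it is an example, not part of the repository.

# Sketch of how a `_target_` config (values copied from configs/callbacks/lr_monitor.yaml)
# becomes a live object via Hydra's instantiate.
import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "lightning.pytorch.callbacks.LearningRateMonitor",
        "logging_interval": "step",
    }
)
callback = hydra.utils.instantiate(cfg)  # equivalent to LearningRateMonitor(logging_interval="step")
print(type(callback).__name__)  # -> LearningRateMonitor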
-------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet 5 | - csv 6 | # - mlflow 7 | # - neptune 8 | - tensorboard 9 | - wandb 10 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: 
lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /pflow_encodec/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from pflow_encodec.modules.duration_predictor import DurationPredictor 2 | from pflow_encodec.modules.flow_matching import FlowMatchingTransformer 3 | from pflow_encodec.modules.grl import GradientReversal 4 | from pflow_encodec.modules.spk_enc import SpeakerEncoder 5 | from pflow_encodec.modules.text_enc import TextEncoder 6 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # example of file for storing private and user specific environment variables, like keys or system paths 2 | # rename it to ".env" (excluded from version control by default) 3 | # .env is loaded by train.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | MY_VAR="/home/user/my/system/path" 7 | -------------------------------------------------------------------------------- /pflow_encodec/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pflow_encodec.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from pflow_encodec.utils.logging_utils import log_hyperparameters 3 | from pflow_encodec.utils.pylogger import RankedLogger 4 | from pflow_encodec.utils.rich_utils import enforce_tags, print_config_tree 5 | from pflow_encodec.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 
3 | defaults: 4 | - _self_ 5 | - data: mnist # choose datamodule with `test_dataloader()` for evaluation 6 | - model: mnist 7 | - logger: null 8 | - trainer: default 9 | - paths: default 10 | - extras: default 11 | - hydra: default 12 | 13 | task_name: "eval" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for evaluation 18 | ckpt_path: ??? 19 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/data/libritts.yaml: -------------------------------------------------------------------------------- 1 | _target_: pflow_encodec.data.datamodule.TextLatentLightningDataModule 2 | 3 | train_tsv_path: /home/seastar105/datasets/libritts/train.tsv 4 | val_tsv_path: /home/seastar105/datasets/libritts/dev.tsv 5 | add_trailing_silence: True 6 | batch_durations: 50.0 7 | min_duration: 3.5 8 | max_duration: 15.0 9 | boundaries: [3.0, 5.0, 7.0, 10.0, 15.0] 10 | num_workers: 8 11 | return_upsampled: False 12 | max_frame: 1500 # 20s 13 | text2latent_rate: 1.5 # 50Hz:75Hz 14 | mean: -0.5385722746271299 15 | std: 4.867310381340673 16 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 10 7 | 8 | accelerator: cpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | # precision: 16 13 | 14 | # perform a validation loop every N training epochs 15 | check_val_every_n_epoch: 1 16 | 17 | # set True to to ensure deterministic results 18 | # makes training slower but gives more reproducibility than just setting seeds 19 | deterministic: False 20 | -------------------------------------------------------------------------------- /configs/data/libritts_clean.yaml: -------------------------------------------------------------------------------- 1 | _target_: pflow_encodec.data.datamodule.TextLatentLightningDataModule 2 | 3 | train_tsv_path: /home/seastar105/datasets/libritts/train-clean.tsv 4 | val_tsv_path: /home/seastar105/datasets/libritts/dev-clean.tsv 5 | add_trailing_silence: True 6 | batch_durations: 50.0 7 | min_duration: 3.5 8 | max_duration: 15.0 9 | boundaries: [3.0, 5.0, 7.0, 10.0, 15.0] 10 | num_workers: 8 11 | return_upsampled: False 12 | max_frame: 1500 # 20s 13 | text2latent_rate: 1.5 # 50Hz:75Hz 14 | mean: -0.5370954750544933 15 | std: 4.868750292663943 16 | -------------------------------------------------------------------------------- /configs/data/multilingual.yaml: -------------------------------------------------------------------------------- 1 | _target_: pflow_encodec.data.datamodule.TextLatentLightningDataModule 2 | 3 | train_tsv_path: /home/seastar105/datasets/multilingual/train_meta.tsv 4 | val_tsv_path: 
/home/seastar105/datasets/multilingual/dev_meta.tsv 5 | add_trailing_silence: True 6 | batch_durations: 50.0 7 | min_duration: 3.5 8 | max_duration: 15.0 9 | boundaries: [3.0, 5.0, 7.0, 10.0, 15.0] 10 | num_workers: 8 11 | return_upsampled: False 12 | max_frame: 1500 # 20s 13 | text2latent_rate: 1.5 # 50Hz:75Hz 14 | mean: -0.5408026252250875 15 | std: 4.998098761068811 16 | -------------------------------------------------------------------------------- /configs/data/libritts_r.yaml: -------------------------------------------------------------------------------- 1 | _target_: pflow_encodec.data.datamodule.TextLatentLightningDataModule 2 | 3 | train_tsv_path: /home/seastar105/datasets/libritts_r/train_duration.tsv 4 | val_tsv_path: /home/seastar105/datasets/libritts_r/dev_duration.tsv 5 | add_trailing_silence: True 6 | batch_durations: 50.0 7 | min_duration: 3.5 8 | max_duration: 15.0 9 | boundaries: [3.0, 5.0, 7.0, 10.0, 15.0] 10 | num_workers: 8 11 | return_upsampled: False 12 | max_frame: 1500 # 20s 13 | text2latent_rate: 1.5 # 50Hz:75Hz 14 | mean: -0.5444963574409485 15 | std: 5.242217063903809 16 | -------------------------------------------------------------------------------- /configs/data/aihub_libri_korean.yaml: -------------------------------------------------------------------------------- 1 | _target_: pflow_encodec.data.datamodule.TextLatentLightningDataModule 2 | 3 | train_tsv_path: /home/seastar105/datasets/aihub_libri/korean/train/meta.tsv 4 | val_tsv_path: /home/seastar105/datasets/aihub_libri/korean/dev/meta.tsv 5 | add_trailing_silence: True 6 | batch_durations: 50.0 7 | min_duration: 3.5 8 | max_duration: 15.0 9 | boundaries: [3.0, 5.0, 7.0, 10.0, 15.0] 10 | num_workers: 8 11 | return_upsampled: False 12 | max_frame: 1500 # 20s 13 | text2latent_rate: 1.5 # 50Hz:75Hz 14 | mean: -0.5410842929830185 15 | std: 4.892964436708968 16 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = [ 3 | "--color=yes", 4 | "--durations=0", 5 | "--strict-markers", 6 | "--doctest-modules", 7 | ] 8 | filterwarnings = [ 9 | "ignore::DeprecationWarning", 10 | "ignore::UserWarning", 11 | ] 12 | log_cli = "True" 13 | markers = [ 14 | "slow: slow tests", 15 | ] 16 | minversion = "6.0" 17 | testpaths = "tests/" 18 | 19 | [tool.coverage.report] 20 | exclude_lines = [ 21 | "pragma: nocover", 22 | "raise NotImplementedError", 23 | "raise NotImplementedError()", 24 | "if __name__ == .__main__.:", 25 | ] 26 | -------------------------------------------------------------------------------- /configs/data/aihub_libri_japanese.yaml: -------------------------------------------------------------------------------- 1 | _target_: pflow_encodec.data.datamodule.TextLatentLightningDataModule 2 | 3 | train_tsv_path: /home/seastar105/datasets/aihub_libri/japanese/train/meta.tsv 4 | val_tsv_path: /home/seastar105/datasets/aihub_libri/japanese/dev/meta.tsv 5 | add_trailing_silence: True 6 | batch_durations: 50.0 7 | min_duration: 3.5 8 | max_duration: 15.0 9 | boundaries: [3.0, 5.0, 7.0, 10.0, 15.0] 10 | num_workers: 8 11 | return_upsampled: False 12 | max_frame: 1500 # 20s 13 | text2latent_rate: 1.5 # 50Hz:75Hz 14 | mean: -0.5383276429997254 15 | std: 4.951283757776964 16 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: 
-------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "lightning-hydra-template" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /tests/helpers/run_sh_command.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from tests.helpers.package_available import _SH_AVAILABLE 6 | 7 | if _SH_AVAILABLE: 8 | import sh 9 | 10 | 11 | def run_sh_command(command: List[str]) -> None: 12 | """Default method for executing shell commands with `pytest` and `sh` package. 13 | 14 | :param command: A list of shell commands as strings. 15 | """ 16 | msg = None 17 | try: 18 | sh.python(command) 19 | except sh.ErrorReturnCode as e: 20 | msg = e.stderr.decode() 21 | if msg: 22 | pytest.fail(msg=msg) 23 | -------------------------------------------------------------------------------- /configs/data/multilingual_lang_id.yaml: -------------------------------------------------------------------------------- 1 | _target_: pflow_encodec.data.datamodule.TextLatentLightningDataModule 2 | 3 | train_tsv_path: /home/seastar105/datasets/multilingual/train_meta_lang.tsv 4 | val_tsv_path: /home/seastar105/datasets/multilingual/dev_meta_lang.tsv 5 | add_trailing_silence: True 6 | batch_durations: 50.0 7 | min_duration: 3.5 8 | max_duration: 15.0 9 | boundaries: [3.0, 5.0, 7.0, 10.0, 15.0] 10 | num_workers: 8 11 | return_upsampled: False 12 | max_frame: 1500 # 20s 13 | text2latent_rate: 1.5 # 50Hz:75Hz 14 | mean: -0.5408026252250875 15 | std: 4.998098761068811 16 | use_lang_id: True 17 | languages: 18 | - en 19 | - ja 20 | - ko 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="src", 7 | version="0.0.1", 8 | description="Describe Your Cool Project", 9 | author="", 10 | author_email="", 11 | url="https://github.com/user/project", 12 | install_requires=["lightning", "hydra-core"], 13 | packages=find_packages(), 14 | # use this to customize global commands available in the terminal after installing the package 15 | entry_points={ 16 | "console_scripts": [ 17 | "train_command = src.train:main", 18 | "eval_command = src.eval:main", 19 | ] 20 | }, 21 | ) 22 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: 
${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${task_name}.log 20 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_summary 3 | - rich_progress_bar 4 | - lr_monitor 5 | - _self_ 6 | 7 | model_summary: 8 | max_depth: -1 9 | 10 | val_checkpoint: 11 | dirpath: ${paths.output_dir}/checkpoints 12 | filename: ??? 13 | monitor: ??? 14 | mode: ??? 15 | save_last: True 16 | auto_insert_metric_name: False 17 | save_top_k: 3 18 | every_n_train_steps: 5000 19 | 20 | step_checkpoint: 21 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 22 | dirpath: ${paths.output_dir}/checkpoints 23 | filename: "step_{step:06d}" 24 | monitor: "step" 25 | mode: "max" 26 | save_last: True 27 | auto_insert_metric_name: False 28 | save_top_k: 1 29 | every_n_train_steps: 5000 30 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . 
| grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | format: ## Run pre-commit hooks 17 | pre-commit run -a 18 | 19 | sync: ## Merge changes from main branch to your current branch 20 | git pull 21 | git pull origin main 22 | 23 | test: ## Run not slow tests 24 | pytest -k "not slow" 25 | 26 | test-full: ## Run all tests 27 | pytest 28 | 29 | train: ## Train the model 30 | python src/train.py 31 | -------------------------------------------------------------------------------- /pflow_encodec/modules/grl.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/tadeephuy/GradientReversal/tree/master 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Function 5 | 6 | 7 | class GradientReversal(Function): 8 | @staticmethod 9 | def forward(ctx, x, alpha): 10 | ctx.save_for_backward(x, alpha) 11 | return x 12 | 13 | @staticmethod 14 | def backward(ctx, grad_output): 15 | grad_input = None 16 | _, alpha = ctx.saved_tensors 17 | if ctx.needs_input_grad[0]: 18 | grad_input = -alpha * grad_output 19 | return grad_input, None 20 | 21 | 22 | revgrad = GradientReversal.apply 23 | 24 | 25 | class GradientReversal(nn.Module): 26 | def __init__(self, alpha): 27 | super().__init__() 28 | self.alpha = torch.tensor(alpha, requires_grad=False) 29 | 30 | def forward(self, x): 31 | return revgrad(x, self.alpha) 32 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /configs/experiment/korean_base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /data: aihub_libri_korean.yaml 5 | - override /model: pflow_base.yaml 6 | - override /callbacks: default.yaml 7 | - override /trainer: gpu.yaml 8 | - override /logger: tensorboard.yaml 9 | 10 | task_name: pflow 11 | tags: ["pflow"] 12 | seed: 998244353 13 | test: False 14 | 15 | data: 16 | batch_durations: 100.0 17 | num_workers: 8 18 | callbacks: 19 | val_checkpoint: 20 | filename: "val_latent_loss_{val/latent_loss:.4f}-{step:06d}" 21 | monitor: val/latent_loss 22 | mode: "min" 23 | model: 24 | scheduler: 25 | total_steps: ${trainer.max_steps} 26 | pct_start: 0.01 27 | sample_freq: 5000 28 | 
mean: ${data.mean} 29 | std: ${data.std} 30 | trainer: 31 | max_steps: 500000 32 | max_epochs: 10000 # arbitrary large number 33 | precision: bf16-mixed 34 | accumulate_grad_batches: 1 35 | gradient_clip_val: 0.2 36 | num_nodes: 1 37 | devices: 1 38 | hydra: 39 | run: 40 | dir: ${paths.log_dir}/${task_name}/runs/korean_base_bs100 41 | -------------------------------------------------------------------------------- /configs/experiment/japanese_base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /data: aihub_libri_japanese.yaml 5 | - override /model: pflow_base.yaml 6 | - override /callbacks: default.yaml 7 | - override /trainer: gpu.yaml 8 | - override /logger: tensorboard.yaml 9 | 10 | task_name: pflow 11 | tags: ["pflow"] 12 | seed: 998244353 13 | test: False 14 | 15 | data: 16 | batch_durations: 100.0 17 | num_workers: 8 18 | callbacks: 19 | val_checkpoint: 20 | filename: "val_latent_loss_{val/latent_loss:.4f}-{step:06d}" 21 | monitor: val/latent_loss 22 | mode: "min" 23 | model: 24 | scheduler: 25 | total_steps: ${trainer.max_steps} 26 | pct_start: 0.01 27 | sample_freq: 5000 28 | mean: ${data.mean} 29 | std: ${data.std} 30 | trainer: 31 | max_steps: 500000 32 | max_epochs: 10000 # arbitrary large number 33 | precision: bf16-mixed 34 | accumulate_grad_batches: 1 35 | gradient_clip_val: 0.2 36 | num_nodes: 1 37 | devices: 1 38 | hydra: 39 | run: 40 | dir: ${paths.log_dir}/${task_name}/runs/japanese_base_bs100 41 | -------------------------------------------------------------------------------- /configs/experiment/libritts_base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /data: libritts_r.yaml 5 | - override /model: pflow_base.yaml 6 | - override /callbacks: default.yaml 7 | - override /trainer: gpu.yaml 8 | - override /logger: tensorboard.yaml 9 | 10 | task_name: pflow 11 | tags: ["pflow"] 12 | seed: 998244353 13 | test: False 14 | 15 | data: 16 | batch_durations: 100.0 17 | num_workers: 8 18 | callbacks: 19 | val_checkpoint: 20 | filename: "val_latent_loss_{val/latent_loss:.4f}-{step:06d}" 21 | monitor: val/latent_loss 22 | mode: "min" 23 | model: 24 | scheduler: 25 | total_steps: ${trainer.max_steps} 26 | pct_start: 0.01 27 | sample_freq: 5000 28 | mean: ${data.mean} 29 | std: ${data.std} 30 | trainer: 31 | max_steps: 500000 32 | max_epochs: 10000 # arbitrary large number 33 | precision: bf16-mixed 34 | accumulate_grad_batches: 4 35 | gradient_clip_val: 1.0 36 | num_nodes: 1 37 | devices: 1 38 | hydra: 39 | run: 40 | dir: ${paths.log_dir}/${task_name}/runs/libritts_base_fixed_bs100x4_clip1.0_skipconnect 41 | -------------------------------------------------------------------------------- /configs/experiment/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /data: ??? 
5 | - override /model: pflow_base.yaml 6 | - override /callbacks: default.yaml 7 | - override /trainer: gpu.yaml 8 | - override /logger: tensorboard.yaml 9 | 10 | task_name: pflow 11 | tags: ["pflow"] 12 | seed: 998244353 13 | test: False 14 | 15 | data: 16 | batch_durations: 100.0 17 | num_workers: 8 18 | callbacks: 19 | val_checkpoint: 20 | filename: "val_latent_loss_{val/latent_loss:.4f}-{step:06d}" 21 | monitor: val/latent_loss 22 | mode: "min" 23 | model: 24 | scheduler: 25 | total_steps: ${trainer.max_steps} 26 | pct_start: 0.02 27 | sample_freq: 5000 28 | sample_idx: [] # sample indices used for sampling while train. idx will be used to choose samples from validation dataset. 29 | mean: ${data.mean} 30 | std: ${data.std} 31 | trainer: 32 | max_steps: 500000 33 | max_epochs: 10000 # arbitrary large number 34 | precision: bf16-mixed 35 | accumulate_grad_batches: 4 36 | gradient_clip_val: 0.2 37 | num_nodes: 1 38 | devices: 1 39 | hydra: 40 | run: 41 | dir: ??? 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | # recommend install torch manually 3 | # conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=12.1 -c pytorch -c nvidia 4 | # torch>=2.2.0 5 | # torchaudio>=2.2.0 6 | lightning>=2.0.0 7 | torchmetrics>=0.11.4 8 | 9 | # --------- hydra --------- # 10 | hydra-core==1.3.2 11 | hydra-colorlog==1.2.0 12 | hydra-optuna-sweeper==1.2.0 13 | 14 | # --------- loggers --------- # 15 | # wandb 16 | # neptune-client 17 | # mlflow 18 | # comet-ml 19 | # aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 20 | 21 | # --------- others --------- # 22 | rootutils # standardizing the project root setup 23 | pre-commit # hooks for applying linters on commit 24 | rich # beautiful text formatting in terminal 25 | pytest # tests 26 | einops 27 | descript-audiotools 28 | --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/nightly/pt2.2.0/cu121 29 | fairseq2==0.3.0.dev202402202355+cu121 30 | git+https://github.com/seastar105/seamless_communication.git#egg=seamless_communication 31 | pandas 32 | torchdiffeq 33 | transformers 34 | -------------------------------------------------------------------------------- /tests/helpers/package_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import pkg_resources 4 | from lightning.fabric.accelerators import TPUAccelerator 5 | 6 | 7 | def _package_available(package_name: str) -> bool: 8 | """Check if a package is available in your environment. 9 | 10 | :param package_name: The name of the package to be checked. 11 | 12 | :return: `True` if the package is available. `False` otherwise. 
13 | """ 14 | try: 15 | return pkg_resources.require(package_name) is not None 16 | except pkg_resources.DistributionNotFound: 17 | return False 18 | 19 | 20 | _TPU_AVAILABLE = TPUAccelerator.is_available() 21 | 22 | _IS_WINDOWS = platform.system() == "Windows" 23 | 24 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") 25 | 26 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") 27 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") 28 | 29 | _WANDB_AVAILABLE = _package_available("wandb") 30 | _NEPTUNE_AVAILABLE = _package_available("neptune") 31 | _COMET_AVAILABLE = _package_available("comet_ml") 32 | _MLFLOW_AVAILABLE = _package_available("mlflow") 33 | -------------------------------------------------------------------------------- /configs/experiment/multilingual_base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /data: multilingual.yaml 5 | - override /model: pflow_base.yaml 6 | - override /callbacks: default.yaml 7 | - override /trainer: gpu.yaml 8 | - override /logger: tensorboard.yaml 9 | 10 | task_name: pflow 11 | tags: ["pflow"] 12 | seed: 998244353 13 | test: False 14 | 15 | data: 16 | batch_durations: 100.0 17 | num_workers: 8 18 | callbacks: 19 | val_checkpoint: 20 | filename: "val_latent_loss_{val/latent_loss:.4f}-{step:06d}" 21 | monitor: val/latent_loss 22 | mode: "min" 23 | model: 24 | scheduler: 25 | total_steps: ${trainer.max_steps} 26 | pct_start: 0.02 27 | sample_freq: 5000 28 | mean: ${data.mean} 29 | std: ${data.std} 30 | # net_ckpt_path: /home/seastar105/Work/pflow-encodec/checkpoints/multilingual_base.ckpt 31 | trainer: 32 | max_steps: 500000 33 | max_epochs: 10000 # arbitrary large number 34 | precision: bf16-mixed 35 | accumulate_grad_batches: 1 36 | gradient_clip_val: 0.2 37 | num_nodes: 1 38 | devices: 1 39 | hydra: 40 | run: 41 | dir: ${paths.log_dir}/${task_name}/runs/multilingual_base_bs100x4_test 42 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html 2 | 3 | early_stopping: 4 | _target_: lightning.pytorch.callbacks.EarlyStopping 5 | monitor: ??? # quantity to be monitored, must be specified !!! 6 | min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement 7 | patience: 3 # number of checks with no improvement after which training will be stopped 8 | verbose: False # verbosity mode 9 | mode: "min" # "max" means higher metric value is better, can be also "min" 10 | strict: True # whether to crash the training if monitor is not found in the validation metrics 11 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 12 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 13 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 14 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 15 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 16 | -------------------------------------------------------------------------------- /configs/experiment/multilingual_lang_id.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /data: multilingual_lang_id.yaml 5 | - override /model: pflow_small.yaml 6 | - override /callbacks: default.yaml 7 | - override /trainer: gpu.yaml 8 | - override /logger: tensorboard.yaml 9 | 10 | task_name: pflow 11 | tags: ["pflow"] 12 | seed: 998244353 13 | test: False 14 | 15 | data: 16 | batch_durations: 300.0 17 | num_workers: 12 18 | callbacks: 19 | val_checkpoint: 20 | filename: "val_latent_loss_{val/latent_loss:.4f}-{step:06d}" 21 | monitor: val/latent_loss 22 | mode: "min" 23 | model: 24 | net: 25 | num_languages: 3 26 | p_drop_lang: 0.1 27 | scheduler: 28 | total_steps: ${trainer.max_steps} 29 | pct_start: 0.02 30 | sample_freq: 5000 31 | mean: ${data.mean} 32 | std: ${data.std} 33 | languages: ${data.languages} 34 | max_lang_loss: 10.0 35 | trainer: 36 | max_steps: 250000 37 | max_epochs: 10000 # arbitrary large number 38 | precision: bf16-mixed 39 | accumulate_grad_batches: 1 40 | gradient_clip_val: 1.0 41 | num_nodes: 1 42 | devices: 1 43 | hydra: 44 | run: 45 | dir: ${paths.log_dir}/${task_name}/runs/multilingual_lang_id_small_bs300_no_grl 46 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: null # directory to save the model file 6 | filename: null # checkpoint filename 7 | monitor: null # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 1 # save k best models (determined by above metric) 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: null # number of epochs between checkpoints 17 | save_on_train_epoch_end: null 
# whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /configs/experiment/libritts_clean_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /data: libritts_clean.yaml 5 | - override /model: pflow_small.yaml 6 | - override /callbacks: default.yaml 7 | - override /trainer: gpu.yaml 8 | - override /logger: tensorboard.yaml 9 | 10 | task_name: pflow 11 | tags: ["pflow"] 12 | seed: 998244353 13 | test: False 14 | 15 | data: 16 | batch_durations: 400.0 17 | num_workers: 8 18 | callbacks: 19 | val_checkpoint: 20 | filename: "val_latent_loss_{val/latent_loss:.4f}-{step:06d}" 21 | monitor: val/latent_loss 22 | mode: "min" 23 | model: 24 | sample_idx: [0, 500, 1000, 1500, 2000, 2500] 25 | scheduler: 26 | total_steps: ${trainer.max_steps} 27 | pct_start: 0.0125 # 5000 warmup steps 28 | final_div_factor: 0.04 29 | net: 30 | flow_matching_attn_processor: sdpa 31 | sample_freq: 5000 32 | mean: ${data.mean} 33 | std: ${data.std} 34 | 35 | trainer: 36 | max_steps: 1000000 37 | max_epochs: 10000 # arbitrary large number 38 | precision: bf16-mixed 39 | accumulate_grad_batches: 1 40 | gradient_clip_val: 1.0 41 | num_nodes: 1 42 | devices: 1 43 | hydra: 44 | run: 45 | dir: ${paths.log_dir}/${task_name}/runs/true_libritts_clean_small_bs400_retrain 46 | -------------------------------------------------------------------------------- /configs/experiment/libritts_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /data: libritts.yaml 5 | - override /model: pflow_small.yaml 6 | - override /callbacks: default.yaml 7 | - override /trainer: gpu.yaml 8 | - override /logger: tensorboard.yaml 9 | 10 | task_name: pflow 11 | tags: ["pflow"] 12 | seed: 998244353 13 | test: False 14 | 15 | data: 16 | batch_durations: 400.0 17 | num_workers: 8 18 | callbacks: 19 | val_checkpoint: 20 | filename: "val_latent_loss_{val/latent_loss:.4f}-{step:06d}" 21 | monitor: val/latent_loss 22 | mode: "min" 23 | model: 24 | optimizer: 25 | _target_: torch.optim.AdamW 26 | fused: False 27 | scheduler: 28 | total_steps: ${trainer.max_steps} 29 | pct_start: 0.005 # 5000 warmup steps 30 | final_div_factor: 0.04 31 | net: 32 | flow_matching_attn_processor: sdpa 33 | sample_freq: 5000 34 | mean: ${data.mean} 35 | std: ${data.std} 36 | 37 | trainer: 38 | max_steps: 1000000 39 | max_epochs: 10000 # arbitrary large number 40 | precision: bf16-mixed 41 | accumulate_grad_batches: 1 42 | gradient_clip_val: 1.0 43 | num_nodes: 1 44 | devices: 1 45 | hydra: 46 | run: 47 | dir: ${paths.log_dir}/${task_name}/runs/true_libritts_small_bs400 48 | -------------------------------------------------------------------------------- /pflow_encodec/modules/duration_predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from pflow_encodec.modules.transformer import Wav2Vec2PositionEncoderLayer 7 | from pflow_encodec.utils.helper import exists 8 | 9 | 10 | class DurationPredictor(nn.Module): 11 | def __init__(self, dim_input: int, dim: int, depth: int, kernel_size: int, dropout: float): 12 | super().__init__() 13 | self.input_proj = nn.Linear(dim_input, dim) 14 | 15 | self.convs = nn.ModuleList() 16 | self.dropout = 
nn.Dropout(dropout) 17 | for _ in range(depth): 18 | self.convs.append(Wav2Vec2PositionEncoderLayer(dim, kernel_size, groups=1)) 19 | self.output_proj = nn.Linear(dim, 1) 20 | 21 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: 22 | x = self.input_proj(x) 23 | if exists(mask): 24 | mask = mask[..., None] 25 | x = x.masked_fill(~mask, 0.0) 26 | x = x.transpose(1, 2) 27 | for conv in self.convs: 28 | x = conv(x) 29 | x = self.dropout(x) 30 | x = x.transpose(1, 2) 31 | if exists(mask): 32 | x = x.masked_fill(~mask, 0.0) 33 | return self.output_proj(x).squeeze(-1) 34 | -------------------------------------------------------------------------------- /configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group runs under experiment name 15 | experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 
25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /scripts/dump_latents.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from tqdm.auto import tqdm 8 | 9 | from pflow_encodec.data.tokenizer import EncodecTokenizer 10 | 11 | torch.backends.cuda.matmul.allow_tf32 = True 12 | torch.backends.cudnn.allow_tf32 = True 13 | 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument("--input_tsv", type=str, required=True) 19 | parser.add_argument("--output_ext", type=str, help="output csv file", default=".latent.npy") 20 | 21 | args = parser.parse_args() 22 | 23 | df = pd.read_csv(args.input_tsv, sep="\t", engine="pyarrow") 24 | tokenizer = EncodecTokenizer() 25 | paths = df["audio_path"].tolist() 26 | cnt = 0 27 | val_sum = 0 28 | square_sum = 0 29 | with torch.inference_mode(): 30 | for path in tqdm(paths): 31 | output_path = Path(path).with_suffix(args.output_ext) 32 | if output_path.exists(): 33 | continue 34 | latent = tokenizer.encode_file(path, return_code=False) 35 | cnt += latent.numel() 36 | val_sum += latent.sum().item() 37 | square_sum += (latent**2).sum().item() 38 | np.save(output_path, latent.cpu().numpy().astype(np.float32)) 39 | print(f"mean: {val_sum / cnt}, std: {np.sqrt(square_sum / cnt - (val_sum / cnt) ** 2)}") 40 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: mnist 8 | - model: mnist 9 | - callbacks: default 10 | - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. 
`python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | # tags to help you identify your experiments 34 | # you can overwrite this in experiment configs 35 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 36 | tags: ["dev"] 37 | 38 | # set False to skip model training 39 | train: True 40 | 41 | # evaluate on test set, using best model weights achieved during training 42 | # lightning chooses best weights based on the metric specified in checkpoint callback 43 | test: True 44 | 45 | # simply provide checkpoint path to resume training 46 | ckpt_path: null 47 | 48 | # seed for random number generators in pytorch, numpy and python.random 49 | seed: null 50 | -------------------------------------------------------------------------------- /pflow_encodec/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning_utilities.core.rank_zero import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from pflow_encodec.utils import pylogger 7 | 8 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) 39 | 40 | hparams["data"] = cfg["data"] 41 | hparams["trainer"] = cfg["trainer"] 42 | 43 | hparams["callbacks"] = cfg.get("callbacks") 44 | hparams["extras"] = cfg.get("extras") 45 | 46 | hparams["task_name"] = cfg.get("task_name") 47 | hparams["tags"] = cfg.get("tags") 48 | hparams["ckpt_path"] = cfg.get("ckpt_path") 49 | hparams["seed"] = cfg.get("seed") 50 | 51 | # send hparams to all loggers 52 | for logger in trainer.loggers: 53 | logger.log_hyperparams(hparams) 54 | -------------------------------------------------------------------------------- /pflow_encodec/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from pflow_encodec.utils import pylogger 9 | 10 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 
17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /pflow_encodec/utils/export.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import hydra 4 | import torch 5 | 6 | 7 | def export_lightning_ckpt(input_path, output_path): 8 | ckpt = torch.load(input_path, map_location="cpu") 9 | state_dict = ckpt["state_dict"] 10 | cfg = ckpt["hyper_parameters"] 11 | model_config = cfg["net"] 12 | data_config = {} 13 | data_config["mean"] = cfg["mean"] 14 | data_config["std"] = cfg["std"] 15 | data_config["text2latent_ratio"] = cfg["text2latent_ratio"] 16 | if "languages" in cfg and cfg["languages"] is not None: 17 | languages = cfg["languages"] 18 | lang2idx = {lang: idx for idx, lang in enumerate(languages)} 19 | data_config["lang2idx"] = lang2idx 20 | state_dict = {k[len("net.") :]: v for k, v in state_dict.items()} 21 | 22 | model = hydra.utils.instantiate(model_config) 23 | model.load_state_dict(state_dict) 24 | 25 | Path(output_path).parent.mkdir(parents=True, exist_ok=True) 26 | 27 | torch.save( 28 | { 29 | "state_dict": model.state_dict(), 30 | "data_config": data_config, 31 | "model_config": model_config, 32 | }, 33 | output_path, 34 | ) 35 | 36 | 37 | def merge_ckpts(ckpt_paths, output_path): 38 | assert len(ckpt_paths) > 1, "Please provide more than one checkpoint path" 39 | ckpts = [torch.load(p, map_location="cpu") for p in ckpt_paths] 40 | state_dicts = [ckpt["state_dict"] for ckpt in ckpts] 41 | keys = set(state_dicts[0].keys()) 42 | 43 | # key check 44 | for key in keys: 45 | for state_dict in state_dicts: 46 | assert key in state_dict, f"{key} not found in state_dict" 47 | 48 | # shape check 49 | for key in keys: 50 | tensors = [state_dict[key] for state_dict in state_dicts] 51 | shapes = [t.shape for t in tensors] 52 | assert len(set(shapes)) == 1, f"Shapes of {key} are not the same: {shapes}" 53 | 54 | new_state_dict = {} 55 | for key in keys: 56 | new_state_dict[key] = torch.stack([state_dict[key] for state_dict in state_dicts]).mean(dim=0) 57 | 58 | Path(output_path).parent.mkdir(parents=True, exist_ok=True) 59 | data_config = ckpts[0]["data_config"] 60 | 
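# NOTE: data_config and model_config are copied from the first checkpoint only, so averaging assumes
# every checkpoint was exported (via export_lightning_ckpt above) with the same architecture and
# normalization statistics; the checks above only compare state_dict keys and tensor shapes.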
model_config = ckpts[0]["model_config"] 61 | torch.save( 62 | { 63 | "state_dict": new_state_dict, 64 | "data_config": data_config, 65 | "model_config": model_config, 66 | }, 67 | output_path, 68 | ) 69 | -------------------------------------------------------------------------------- /scripts/dump_resampled.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | import torch 6 | from audiotools import AudioSignal 7 | from tqdm.auto import tqdm 8 | 9 | 10 | class AudioOnlyDataset(torch.utils.data.Dataset): 11 | def __init__(self, input_path_list, output_sr): 12 | super().__init__() 13 | self.input_path_list = input_path_list 14 | self.output_sr = output_sr 15 | 16 | def __len__(self): 17 | return len(self.input_path_list) 18 | 19 | def __getitem__(self, idx): 20 | try: 21 | input_path = self.input_path_list[idx] 22 | signal = AudioSignal(input_path) 23 | if signal.sample_rate != self.output_sr: 24 | signal = signal.resample(self.output_sr) 25 | if signal.num_channels > 1: 26 | signal = signal.to_mono() 27 | signal = signal.ensure_max_of_audio() 28 | duration = signal.duration 29 | except Exception as e: 30 | print(f"Error in {input_path}: {e}") 31 | return None, None 32 | 33 | return signal, duration 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser() 38 | 39 | parser.add_argument("--input_path_list", type=str, required=True) 40 | parser.add_argument("--output_path_list", type=str, required=True) 41 | parser.add_argument("--output_sr", type=int, default=24000) 42 | parser.add_argument("--output_tsv_path", type=str, required=True) 43 | 44 | arguments = parser.parse_args() 45 | with open(arguments.input_path_list) as f: 46 | input_path_list = [line.strip() for line in f] 47 | with open(arguments.output_path_list) as f: 48 | output_path_list = [line.strip() for line in f] 49 | 50 | ds = AudioOnlyDataset(input_path_list, arguments.output_sr) 51 | dl = torch.utils.data.DataLoader(ds, batch_size=None, num_workers=32) 52 | durations = [] 53 | for i, (signal, duration) in tqdm(enumerate(dl), total=len(dl)): 54 | output_path = output_path_list[i] 55 | try: 56 | Path(output_path).parent.mkdir(parents=True, exist_ok=True) 57 | if not Path(output_path).exists(): 58 | signal.write(output_path) 59 | durations.append(duration) 60 | except Exception as e: 61 | print(f"Error in {output_path}: {e}") 62 | durations.append(None)  # keep durations aligned with output_path_list 63 | 64 | df = pd.DataFrame({"audio_path": output_path_list, "duration": durations}) 65 | df.to_csv(arguments.output_tsv_path, sep="\t", index=False) 66 | -------------------------------------------------------------------------------- /MODEL_CARD.md: -------------------------------------------------------------------------------- 1 | # Model Card for P-Flow Encodec TTS (English, Korean, Japanese) 2 | 3 | ## Model Details 4 | 5 | ### Model Description 6 | 7 | P-Flow Encodec is a Text-to-Speech model based on the paper [P-Flow: A Fast and Data-Efficient Zero-Shot TTS through Speech Prompting](https://openreview.net/pdf?id=zNA7u7wtIN), with some modifications. You can check the differences [here](https://github.com/seastar105/pflow-encodec?tab=readme-ov-file#difference-from-paper). The model relies on the Encodec model and the Multiband Diffusion decoder, both from Meta.
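For reference, the pre-quantization EnCodec latents that this model is trained to predict can be inspected with the repository's `EncodecTokenizer` (defined in `pflow_encodec/data/tokenizer.py`, reproduced further below). The following is a minimal, illustrative round-trip sketch: the audio path is a placeholder, a CUDA device is assumed, and only the EnCodec encoder/decoder is exercised, not the TTS model itself.

```python
import torch

from pflow_encodec.data.tokenizer import EncodecTokenizer

tokenizer = EncodecTokenizer(device="cuda")  # defaults to facebook/encodec_24khz

with torch.inference_mode():
    # continuous (1, T_latent, 128) latents before quantization, the same tensors
    # that scripts/dump_latents.py stores next to each audio file
    latents = tokenizer.encode_file("path/to/audio.wav", return_code=False)
    # quantize the latents and decode them back to a (1, 1, T) waveform
    audio = tokenizer.decode_latents(latents)

print(latents.shape, audio.shape)
```

Decoding through the Multiband Diffusion decoder mentioned above is the higher-quality path; the sketch uses the plain EnCodec decoder simply because that is what `EncodecTokenizer` exposes.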
8 | 9 | - **Developed by**: [seastar105](https://github.com/seastar105) 10 | - **Model type**: Text-to-Speech 11 | - **Language**: English, Korean, Japanese 12 | - **License**: MIT for the code. The model weights are also free to use, but you should indicate that they were trained with data from AI Hub (e.g. "This research (paper) used datasets from 'The Open AI Dataset Project (AI-Hub, S. Korea)'. All data information can be accessed through 'AI-Hub (www.aihub.or.kr)'.") 13 | - **Model version**: 1.0 14 | - **Code Repository**: https://github.com/seastar105/pflow-encodec 15 | - **Model Repository**: https://huggingface.co/seastar105/pflow-encodec-ejk 16 | 17 | ## Intended Use 18 | 19 | ### Primary intended use 20 | 21 | This model is trained for zero-shot multilingual TTS. It can generate speech from text in English, Korean, and Japanese. The primary intended use is research, as a baseline for multilingual and code-switching TTS. 22 | 23 | ### Out of scope use cases 24 | 25 | This model should not be used for generating or editing someone's speech without their consent, including but not limited to government leaders, political figures, and celebrities. 26 | 27 | ## Training Details 28 | 29 | - **Training dataset**: LibriTTS-R, plus the Korean and Japanese corpora from the [AIHub 131](https://aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=data&dataSetSn=71524) dataset (Multi-lingual Read Speech corpus for Translation). Only samples with a duration between 3.5 and 15 seconds are used: 380 hours for English, 637 hours for Japanese, and 705 hours for Korean. 30 | - **Finetuned from**: The multilingual model is initialized from a merge of the pretrained monolingual models for the languages mentioned above. Each monolingual model is trained for ~250K steps, their weights are averaged, and the merged model is then finetuned for ~250K steps on all languages. 31 | - **Compute Resource**: All models were trained on a single RTX 4090 GPU. 100K steps take about one day with 4 gradient accumulation steps and batch_durations of 100. 32 | -------------------------------------------------------------------------------- /pflow_encodec/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Mapping, Optional 3 | 4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only 5 | 6 | 7 | class RankedLogger(logging.LoggerAdapter): 8 | """A multi-GPU-friendly python command line logger.""" 9 | 10 | def __init__( 11 | self, 12 | name: str = __name__, 13 | rank_zero_only: bool = False, 14 | extra: Optional[Mapping[str, object]] = None, 15 | ) -> None: 16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes with their rank 17 | prefixed in the log message. 18 | 19 | :param name: The name of the logger. Default is ``__name__``. 20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. 21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 22 | """ 23 | logger = logging.getLogger(name) 24 | super().__init__(logger=logger, extra=extra) 25 | self.rank_zero_only = rank_zero_only 26 | 27 | def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: 28 | """Delegate a log call to the underlying logger, after prefixing its message with the rank of the process it's 29 | being logged from.
If `'rank'` is provided, then the log will only occur on that rank/process. 30 | 31 | :param level: The level to log at. Look at `logging.__init__.py` for more information. 32 | :param msg: The message to log. 33 | :param rank: The rank to log at. 34 | :param args: Additional args to pass to the underlying logging function. 35 | :param kwargs: Any additional keyword args to pass to the underlying logging function. 36 | """ 37 | if self.isEnabledFor(level): 38 | msg, kwargs = self.process(msg, kwargs) 39 | current_rank = getattr(rank_zero_only, "rank", None) 40 | if current_rank is None: 41 | raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") 42 | msg = rank_prefixed_message(msg, current_rank) 43 | if self.rank_zero_only: 44 | if current_rank == 0: 45 | self.logger.log(level, msg, *args, **kwargs) 46 | else: 47 | if rank is None: 48 | self.logger.log(level, msg, *args, **kwargs) 49 | elif current_rank == rank: 50 | self.logger.log(level, msg, *args, **kwargs) 51 | -------------------------------------------------------------------------------- /configs/model/pflow_base.yaml: -------------------------------------------------------------------------------- 1 | _target_: pflow_encodec.models.lightning_modules.pflow.PFlowLightningModule 2 | 3 | net: 4 | _target_: pflow_encodec.models.pflow.PFlow 5 | feature_dim: 128 6 | text_encoder_vocab_size: 10904 7 | text_encoder_embed_dim: 192 8 | text_encoder_conv_pos_depth: 2 9 | text_encoder_conv_pos_kernel_size: 15 10 | text_encoder_conv_pos_groups: 16 11 | text_encoder_depth: 6 12 | text_encoder_dim: 192 13 | text_encoder_dim_head: 96 14 | text_encoder_heads: 2 15 | text_encoder_ff_mult: 4.0 16 | text_encoder_attn_dropout: 0.1 17 | text_encoder_ff_dropout: 0.0 18 | text_encoder_attn_processor: naive 19 | text_encoder_norm_type: ada_proj 20 | text_encoder_ff_type: conv 21 | text_encoder_ff_kernel_size: 3 22 | text_encoder_ff_groups: 1 23 | text_encoder_scale_type: ada_proj 24 | speaker_encoder_dim_input: 128 25 | speaker_encoder_conv_pos_depth: 2 26 | speaker_encoder_conv_pos_kernel_size: 15 27 | speaker_encoder_conv_pos_groups: 16 28 | speaker_encoder_depth: 2 29 | speaker_encoder_dim: 192 30 | speaker_encoder_dim_head: 96 31 | speaker_encoder_heads: 2 32 | speaker_encoder_ff_mult: 4.0 33 | speaker_encoder_attn_dropout: 0.1 34 | speaker_encoder_ff_dropout: 0.0 35 | speaker_encoder_attn_processor: naive 36 | speaker_encoder_norm_type: layer 37 | speaker_encoder_ff_type: conv 38 | speaker_encoder_ff_kernel_size: 3 39 | speaker_encoder_ff_groups: 1 40 | speaker_encoder_scale_type: none 41 | flow_matching_dim_time: 2048 42 | flow_matching_conv_pos_kernel_size: 31 43 | flow_matching_conv_pos_depth: 2 44 | flow_matching_conv_pos_groups: 16 45 | flow_matching_depth: 6 46 | flow_matching_dim: 512 47 | flow_matching_dim_head: 128 48 | flow_matching_heads: 4 49 | flow_matching_ff_mult: 4.0 50 | flow_matching_attn_dropout: 0.1 51 | flow_matching_ff_dropout: 0.0 52 | flow_matching_attn_processor: naive 53 | flow_matching_norm_type: ada_embed 54 | flow_matching_ff_type: conv 55 | flow_matching_ff_kernel_size: 3 56 | flow_matching_ff_groups: 2 57 | flow_matching_scale_type: ada_embed 58 | duration_predictor_dim: 256 59 | duration_predictor_depth: 2 60 | duration_predictor_kernel_size: 3 61 | duration_predictor_dropout: 0.1 62 | p_uncond: 0.1 63 | interpolate_mode: linear 64 | sigma: 0.01 # from pflow paper 65 | optimizer: 66 | _target_: torch.optim.Adam 67 | lr: 0.0001 68 | scheduler: 69 | _target_: torch.optim.lr_scheduler.OneCycleLR 70 | 
max_lr: 0.0001 71 | anneal_strategy: linear 72 | total_steps: ??? 73 | pct_start: ??? 74 | sample_freq: 2000 75 | sample_idx: [0, 1000, 2000, 3000, 4000, 5000] 76 | mean: ??? 77 | std: ??? 78 | text2latent_ratio: 1.5 79 | -------------------------------------------------------------------------------- /configs/model/pflow_small.yaml: -------------------------------------------------------------------------------- 1 | _target_: pflow_encodec.models.lightning_modules.pflow.PFlowLightningModule 2 | 3 | net: 4 | _target_: pflow_encodec.models.pflow.PFlow 5 | feature_dim: 128 6 | text_encoder_vocab_size: 10904 7 | text_encoder_embed_dim: 192 8 | text_encoder_conv_pos_depth: 2 9 | text_encoder_conv_pos_kernel_size: 15 10 | text_encoder_conv_pos_groups: 16 11 | text_encoder_depth: 4 12 | text_encoder_dim: 192 13 | text_encoder_dim_head: 96 14 | text_encoder_heads: 2 15 | text_encoder_ff_mult: 4.0 16 | text_encoder_attn_dropout: 0.1 17 | text_encoder_ff_dropout: 0.0 18 | text_encoder_attn_processor: naive 19 | text_encoder_norm_type: ada_proj 20 | text_encoder_ff_type: conv 21 | text_encoder_ff_kernel_size: 3 22 | text_encoder_ff_groups: 1 23 | text_encoder_scale_type: ada_proj 24 | speaker_encoder_dim_input: 128 25 | speaker_encoder_conv_pos_depth: 2 26 | speaker_encoder_conv_pos_kernel_size: 15 27 | speaker_encoder_conv_pos_groups: 16 28 | speaker_encoder_depth: 2 29 | speaker_encoder_dim: 256 30 | speaker_encoder_dim_head: 96 31 | speaker_encoder_heads: 2 32 | speaker_encoder_ff_mult: 4.0 33 | speaker_encoder_attn_dropout: 0.1 34 | speaker_encoder_ff_dropout: 0.0 35 | speaker_encoder_attn_processor: naive 36 | speaker_encoder_norm_type: layer 37 | speaker_encoder_ff_type: conv 38 | speaker_encoder_ff_kernel_size: 3 39 | speaker_encoder_ff_groups: 1 40 | speaker_encoder_scale_type: none 41 | flow_matching_dim_time: 2048 42 | flow_matching_conv_pos_kernel_size: 31 43 | flow_matching_conv_pos_depth: 2 44 | flow_matching_conv_pos_groups: 16 45 | flow_matching_depth: 8 46 | flow_matching_dim: 256 47 | flow_matching_dim_head: 64 48 | flow_matching_heads: 4 49 | flow_matching_ff_mult: 4.0 50 | flow_matching_attn_dropout: 0.1 51 | flow_matching_ff_dropout: 0.0 52 | flow_matching_attn_processor: naive 53 | flow_matching_norm_type: ada_embed 54 | flow_matching_ff_type: conv 55 | flow_matching_ff_kernel_size: 3 56 | flow_matching_ff_groups: 2 57 | flow_matching_scale_type: ada_embed 58 | duration_predictor_dim: 256 59 | duration_predictor_depth: 2 60 | duration_predictor_kernel_size: 3 61 | duration_predictor_dropout: 0.1 62 | p_uncond: 0.1 63 | interpolate_mode: linear 64 | sigma: 0.01 # from pflow paper 65 | optimizer: 66 | _target_: torch.optim.Adam 67 | lr: 0.0001 68 | scheduler: 69 | _target_: torch.optim.lr_scheduler.OneCycleLR 70 | max_lr: 0.0001 71 | anneal_strategy: linear 72 | total_steps: ??? 73 | pct_start: ??? 74 | sample_freq: 2000 75 | sample_idx: [0, 1000, 2000, 3000, 4000, 5000] 76 | mean: ??? 77 | std: ??? 
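# NOTE: `???` marks a mandatory (missing) value in Hydra/OmegaConf; total_steps, pct_start, mean and std
# must be supplied elsewhere (e.g. by an experiment config or a command-line override) before training.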
78 | text2latent_ratio: 1.5 79 | -------------------------------------------------------------------------------- /pflow_encodec/modules/spk_enc.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from pflow_encodec.modules.transformer import ( 7 | MultiHeadAttention, 8 | Transformer, 9 | Wav2Vec2StackedPositionEncoder, 10 | ) 11 | 12 | 13 | class SpeakerEncoder(nn.Module): 14 | def __init__( 15 | self, 16 | dim_input: int, 17 | conv_pos_kernel_size: int, 18 | conv_pos_depth: int, 19 | conv_pos_groups: int, 20 | depth: int, 21 | dim: int, 22 | dim_head: int, 23 | heads: int, 24 | ff_mult: float, 25 | attn_dropout: float, 26 | ff_dropout: float, 27 | attn_processor: Literal["naive", "sdpa"] = "naive", 28 | norm_type: Literal["layer", "ada_proj", "ada_embed"] = "layer", 29 | ff_type: Literal["conv", "linear"] = "linear", 30 | ff_kernel_size: Optional[int] = None, 31 | ff_groups: Optional[int] = None, 32 | layer_norm_eps: float = 1e-6, 33 | scale_type: Literal["none", "ada_proj", "ada_embed"] = "none", 34 | pool_query_range: float = 0.02, 35 | ): 36 | super().__init__() 37 | self.proj = nn.Linear(dim_input, dim) 38 | 39 | self.conv_pos = Wav2Vec2StackedPositionEncoder( 40 | depth=conv_pos_depth, 41 | dim=dim, 42 | kernel_size=conv_pos_kernel_size, 43 | groups=conv_pos_groups, 44 | ) 45 | self.encoder = Transformer( 46 | depth=depth, 47 | dim=dim, 48 | dim_head=dim_head, 49 | heads=heads, 50 | ff_mult=ff_mult, 51 | attn_dropout=attn_dropout, 52 | ff_dropout=ff_dropout, 53 | attn_processor=attn_processor, 54 | norm_type=norm_type, 55 | ff_type=ff_type, 56 | ff_kernel_size=ff_kernel_size, 57 | ff_groups=ff_groups, 58 | layer_norm_eps=layer_norm_eps, 59 | scale_type=scale_type, 60 | ) 61 | self.query = nn.Parameter(torch.randn(1, 1, dim)) 62 | nn.init.trunc_normal_(self.query, mean=0, std=pool_query_range) 63 | 64 | self.pool = MultiHeadAttention( 65 | dim=dim, 66 | dim_head=dim, 67 | heads=1, 68 | processor=attn_processor, 69 | ) 70 | 71 | def reset_parameters(self): 72 | self.conv_pos.reset_parameters() 73 | 74 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): 75 | x = self.proj(x) 76 | x = x + self.conv_pos(x, mask) 77 | x = self.encoder(x, mask) 78 | emb = self.pool(self.query, context=x, mask=mask) 79 | return emb 80 | -------------------------------------------------------------------------------- /scripts/dump_durations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from seamless_communication.models.aligner.alignment_extractor import AlignmentExtractor 9 | from tqdm.auto import tqdm 10 | 11 | torch.backends.cuda.matmul.allow_tf32 = True 12 | torch.backends.cudnn.allow_tf32 = True 13 | logging.basicConfig(level=logging.INFO) 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument("--input_tsv", type=str, required=True) 19 | parser.add_argument( 20 | "--output_ext", type=str, help="output extension of character duration", default=".duration.npy" 21 | ) 22 | parser.add_argument("--empty_cache_rate", type=int, default=5000) 23 | args = parser.parse_args() 24 | 25 | extractor = AlignmentExtractor( 26 | aligner_model_name_or_card="nar_t2u_aligner", 27 | unit_extractor_model_name_or_card="xlsr2_1b_v2", 28 | 
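# the unit extractor below turns layer-35 XLS-R (xlsr2_1b_v2) features into discrete units with the
# 10k k-means model, and nar_t2u_aligner aligns the input text to those units to get per-character durations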
unit_extractor_output_layer=35, 29 | unit_extractor_kmeans_model_uri="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy", 30 | device=torch.device("cuda"), 31 | dtype=torch.float16, 32 | ) 33 | 34 | df = pd.read_csv(args.input_tsv, sep="\t", engine="pyarrow") 35 | paths = df["audio_path"].tolist() 36 | texts = df["text"].tolist() 37 | errors = [] 38 | with torch.inference_mode(): 39 | for idx, (path, text) in tqdm(enumerate(zip(paths, texts)), total=len(paths)): 40 | if args.empty_cache_rate > 0 and idx % args.empty_cache_rate == 0: 41 | torch.cuda.empty_cache() 42 | logging.info("Cleaned CUDA cache") 43 | output_path = Path(path).with_suffix(args.output_ext) 44 | if output_path.exists(): 45 | continue 46 | try: 47 | durations, token_ids, tokens = extractor.extract_alignment( 48 | path, 49 | text, 50 | plot=False, 51 | add_trailing_silence=True, 52 | ) 53 | assert ( 54 | durations.shape[-1] == token_ids.shape[-1] 55 | ), f"Text token and duration shape mismatch: {durations.shape} != {token_ids.shape}, path={path}, text={text}" 56 | np.save(output_path, durations.cpu().numpy().astype(np.int64)) 57 | except Exception as e: 58 | errors.append((path, text, str(e))) 59 | print(f"Error in {path}: {e}") # fallback to cpu? 60 | if errors: 61 | logging.error(f"Errors: {errors}") 62 | with open(Path(args.input_tsv).parent / "errors.txt", "w") as f: 63 | for error in errors: 64 | f.write(f"{error}\n") 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | ### VisualStudioCode 133 | .vscode/* 134 | !.vscode/settings.json 135 | !.vscode/tasks.json 136 | !.vscode/launch.json 137 | !.vscode/extensions.json 138 | *.code-workspace 139 | **/.vscode 140 | 141 | # JetBrains 142 | .idea/ 143 | 144 | # Data & Models 145 | *.h5 146 | *.tar 147 | *.tar.gz 148 | 149 | # Lightning-Hydra-Template 150 | configs/local/default.yaml 151 | /data/ 152 | /logs/ 153 | .env 154 | 155 | # Aim logging 156 | .aim 157 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | 19 | # python code formatting 20 | - repo: https://github.com/psf/black 21 | rev: 23.1.0 22 | hooks: 23 | - id: black 24 | args: [--line-length, "120"] 25 | 26 | # python import sorting 27 | - repo: https://github.com/PyCQA/isort 28 | rev: 5.12.0 29 | hooks: 30 | - id: isort 31 | args: ["--profile", "black", "--filter-files"] 32 | 33 | # python upgrading syntax to newer version 34 | - repo: https://github.com/asottile/pyupgrade 35 | rev: v3.3.1 36 | hooks: 37 | - id: pyupgrade 38 | args: [--py38-plus] 39 | 40 | # python docstring formatting 41 | - repo: https://github.com/myint/docformatter 42 | rev: v1.5.1 43 | hooks: 44 | - id: docformatter 45 | args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] 46 | 47 | # python check (PEP8), programming errors and code complexity 48 | - repo: https://github.com/PyCQA/flake8 49 | rev: 6.0.0 50 | hooks: 51 | - id: flake8 52 | args: 53 | [ 54 | "--extend-ignore", 55 | "E203,E402,E501,F401,F841,E731", 56 | "--exclude", 57 | "logs/*,data/*", 58 | ] 59 | 60 | # python security linter 61 | - repo: https://github.com/PyCQA/bandit 62 | rev: "1.7.5" 63 | hooks: 64 | - id: bandit 65 | args: ["-s", "B101"] 66 | 67 | # yaml formatting 68 | - repo: https://github.com/pre-commit/mirrors-prettier 69 | rev: v3.0.0-alpha.6 70 | hooks: 71 | - id: prettier 72 | types: [yaml] 73 | exclude: "environment.yaml" 74 | 75 | # shell scripts linter 76 | - repo: https://github.com/shellcheck-py/shellcheck-py 77 | rev: v0.9.0.2 78 | hooks: 79 | - id: shellcheck 80 | 81 | # md formatting 82 | - repo: https://github.com/executablebooks/mdformat 83 | rev: 0.7.16 84 | hooks: 85 | - id: mdformat 86 | args: ["--number"] 87 | additional_dependencies: 88 | - mdformat-gfm 89 | - mdformat-tables 90 | - mdformat_frontmatter 91 | # - mdformat-toc 92 | # - mdformat-black 93 | 94 | # word spelling 
linter 95 | - repo: https://github.com/codespell-project/codespell 96 | rev: v2.2.4 97 | hooks: 98 | - id: codespell 99 | args: 100 | - --skip=logs/**,data/**,*.ipynb,*.md 101 | - --ignore-words-list=numer 102 | 103 | # jupyter notebook cell output clearing 104 | - repo: https://github.com/kynan/nbstripout 105 | rev: 0.6.1 106 | hooks: 107 | - id: nbstripout 108 | 109 | # jupyter notebook linting 110 | - repo: https://github.com/nbQA-dev/nbQA 111 | rev: 1.6.3 112 | hooks: 113 | - id: nbqa-black 114 | args: ["--line-length=120"] 115 | - id: nbqa-isort 116 | args: ["--profile=black"] 117 | - id: nbqa-flake8 118 | args: 119 | [ 120 | "--extend-ignore=E203,E402,E501,F401,F841", 121 | "--exclude=logs/*,data/*", 122 | ] 123 | -------------------------------------------------------------------------------- /pflow_encodec/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning_utilities.core.rank_zero import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from pflow_encodec.utils import pylogger 13 | 14 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | queue.append(field) if field in cfg else log.warning( 48 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 49 | ) 50 | 51 | # add all the other fields to queue (not specified in `print_order`) 52 | for field in cfg: 53 | if field not in queue: 54 | queue.append(field) 55 | 56 | # generate config tree from queue 57 | for field in queue: 58 | branch = tree.add(field, style=style, guide_style=style) 59 | 60 | config_group = cfg[field] 61 | if isinstance(config_group, DictConfig): 62 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 63 | else: 64 | branch_content = str(config_group) 65 | 66 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 67 | 68 | # print config tree 69 | rich.print(tree) 70 | 71 | # save config tree to file 72 | if save_to_file: 73 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 74 | rich.print(tree, file=file) 75 | 76 | 77 | @rank_zero_only 78 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 79 | """Prompts user to input tags from command line if no tags are provided in config. 
80 | 81 | :param cfg: A DictConfig composed by Hydra. 82 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 83 | """ 84 | if not cfg.get("tags"): 85 | if "id" in HydraConfig().cfg.hydra.job: 86 | raise ValueError("Specify tags before launching a multirun!") 87 | 88 | log.warning("No tags provided in config. Prompting user to input tags...") 89 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 90 | tags = [t.strip() for t in tags.split(",") if t != ""] 91 | 92 | with open_dict(cfg): 93 | cfg.tags = tags 94 | 95 | log.info(f"Tags: {cfg.tags}") 96 | 97 | if save_to_file: 98 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 99 | rich.print(cfg.tags, file=file) 100 | -------------------------------------------------------------------------------- /pflow_encodec/eval.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | import hydra 4 | import rootutils 5 | from lightning import LightningDataModule, LightningModule, Trainer 6 | from lightning.pytorch.loggers import Logger 7 | from omegaconf import DictConfig 8 | 9 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 10 | # ------------------------------------------------------------------------------------ # 11 | # the setup_root above is equivalent to: 12 | # - adding project root dir to PYTHONPATH 13 | # (so you don't need to force user to install project as a package) 14 | # (necessary before importing any local modules e.g. `from src import utils`) 15 | # - setting up PROJECT_ROOT environment variable 16 | # (which is used as a base for paths in "configs/paths/default.yaml") 17 | # (this way all filepaths are the same no matter where you run the code) 18 | # - loading environment variables from ".env" in root dir 19 | # 20 | # you can remove it if you: 21 | # 1. either install project as a package or move entry files to project root dir 22 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 23 | # 24 | # more info: https://github.com/ashleve/rootutils 25 | # ------------------------------------------------------------------------------------ # 26 | 27 | from pflow_encodec.utils import ( 28 | RankedLogger, 29 | extras, 30 | instantiate_loggers, 31 | log_hyperparameters, 32 | task_wrapper, 33 | ) 34 | 35 | log = RankedLogger(__name__, rank_zero_only=True) 36 | 37 | 38 | @task_wrapper 39 | def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 40 | """Evaluates given checkpoint on a datamodule testset. 41 | 42 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 43 | failure. Useful for multiruns, saving info about the crash, etc. 44 | 45 | :param cfg: DictConfig configuration composed by Hydra. 46 | :return: Tuple[dict, dict] with metrics and dict with all instantiated objects. 
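Example invocation (illustrative checkpoint path): `python pflow_encodec/eval.py ckpt_path=/path/to/model.ckpt`; `ckpt_path` is required, as asserted below.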
47 | """ 48 | assert cfg.ckpt_path 49 | 50 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") 51 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 52 | 53 | log.info(f"Instantiating model <{cfg.model._target_}>") 54 | model: LightningModule = hydra.utils.instantiate(cfg.model) 55 | 56 | log.info("Instantiating loggers...") 57 | logger: List[Logger] = instantiate_loggers(cfg.get("logger")) 58 | 59 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 60 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger) 61 | 62 | object_dict = { 63 | "cfg": cfg, 64 | "datamodule": datamodule, 65 | "model": model, 66 | "logger": logger, 67 | "trainer": trainer, 68 | } 69 | 70 | if logger: 71 | log.info("Logging hyperparameters!") 72 | log_hyperparameters(object_dict) 73 | 74 | log.info("Starting testing!") 75 | trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) 76 | 77 | # for predictions use trainer.predict(...) 78 | # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path) 79 | 80 | metric_dict = trainer.callback_metrics 81 | 82 | return metric_dict, object_dict 83 | 84 | 85 | @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml") 86 | def main(cfg: DictConfig) -> None: 87 | """Main entry point for evaluation. 88 | 89 | :param cfg: DictConfig configuration composed by Hydra. 90 | """ 91 | # apply extra utilities 92 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 93 | extras(cfg) 94 | 95 | evaluate(cfg) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /pflow_encodec/data/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from audiotools import AudioSignal 3 | from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer 4 | from transformers import EncodecModel 5 | 6 | 7 | class EncodecTokenizer: 8 | def __init__( 9 | self, model_name: str = "facebook/encodec_24khz", device: str = "cuda", dtype: torch.dtype = torch.float32 10 | ): 11 | model = EncodecModel.from_pretrained(model_name) 12 | 13 | self.device = torch.device(device) 14 | self.dtype = dtype 15 | self.codec: EncodecModel = model.to(self.device, dtype=self.dtype).eval() 16 | 17 | self.sample_rate = self.codec.config.sampling_rate 18 | 19 | def load_audio(self, path: str) -> torch.Tensor: 20 | """Load audio file and transform it to the correct format for the model. 21 | 22 | Args: 23 | path (str): audio file path 24 | Returns: 25 | audio (torch.Tensor): audio tensor of shape (1, 1, T) 26 | """ 27 | signal = AudioSignal(path) 28 | if signal.sample_rate != self.sample_rate: 29 | signal = signal.resample(self.sample_rate) 30 | if signal.num_channels > 1: 31 | signal = signal.to_mono() 32 | return signal.audio_data.to(device=self.device, dtype=self.dtype) 33 | 34 | def encode_audio(self, audio: torch.Tensor, return_code: bool = False) -> torch.Tensor: 35 | """Encode audio to latent space, return discrete tokens if return_latent is False. 36 | 37 | Args: 38 | audio (torch.Tensor): audio tensor of shape (1, 1, T) 39 | return_latent (bool, optional): return discrete tokens if False, return continuous latent before quantization if True. 
40 | 41 | Returns: 42 | torch.Tensor: encoded tokens or latent 43 | """ 44 | latents = self.codec.encoder(audio).transpose(-2, -1) 45 | if return_code: 46 | return self.codec.quantizer.encode(latents.transpose(-2, -1)).transpose(0, 1) 47 | return latents 48 | 49 | def encode_file(self, path: str, return_code: bool = False) -> torch.Tensor: 50 | audio = self.load_audio(path) 51 | return self.encode_audio(audio, return_code) 52 | 53 | def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: 54 | """Decode discrete tokens to audio. 55 | 56 | Args: 57 | codes (torch.Tensor): discrete tokens of shape (1, Q, T) 58 | 59 | Returns: 60 | torch.Tensor: audio tensor of shape (1, 1, T) 61 | """ 62 | return self.codec.decode(codes[None], [None])[0] 63 | 64 | def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: 65 | """Decode continuous latent to audio. 66 | 67 | Args: 68 | latents (torch.Tensor): continuous latent of shape (1, T, D) 69 | 70 | Returns: 71 | torch.Tensor: audio tensor of shape (1, 1, T) 72 | """ 73 | codes = self.codec.quantizer.encode(latents.transpose(-2, -1)).transpose(0, 1) 74 | return self.decode_codes(codes) 75 | 76 | def quantize_latents(self, latents: torch.Tensor) -> torch.Tensor: 77 | """Quantize continuous latent to discrete tokens. 78 | 79 | Args: 80 | latents (torch.Tensor): continuous latent of shape (1, T, D) 81 | 82 | Returns: 83 | torch.Tensor: discrete tokens of shape (1, Q, T) 84 | """ 85 | return self.codec.quantizer.encode(latents.transpose(-2, -1)).transpose(0, 1) 86 | 87 | 88 | class TextTokenizer: 89 | def __init__(self, add_trailing_silence: bool = True) -> None: 90 | text_tokenizer = load_unity_char_tokenizer("nar_t2u_aligner") 91 | self.tokenizer = text_tokenizer.create_raw_encoder() 92 | self.vocab_info = text_tokenizer.vocab_info 93 | 94 | self.bos_idx = self.vocab_info.bos_idx 95 | self.eos_idx = self.vocab_info.eos_idx 96 | self.pad_idx = self.vocab_info.pad_idx 97 | 98 | self.add_trailing_silence = add_trailing_silence 99 | 100 | def encode_text(self, text: str) -> torch.Tensor: 101 | """Encode text to discrete tokens. 102 | 103 | Args: 104 | text (str): input text 105 | 106 | Returns: 107 | torch.Tensor: discrete tokens 108 | """ 109 | tokens = self.tokenizer(text) 110 | if self.add_trailing_silence: 111 | tokens = torch.cat([tokens, tokens[0:1]]) 112 | return tokens 113 | -------------------------------------------------------------------------------- /pflow_encodec/utils/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from importlib.util import find_spec 3 | from typing import Any, Callable, Dict, Optional, Tuple 4 | 5 | from omegaconf import DictConfig 6 | 7 | from pflow_encodec.utils import pylogger, rich_utils 8 | 9 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 10 | 11 | 12 | def extras(cfg: DictConfig) -> None: 13 | """Applies optional utilities before the task is started. 14 | 15 | Utilities: 16 | - Ignoring python warnings 17 | - Setting tags from command line 18 | - Rich config printing 19 | 20 | :param cfg: A DictConfig object containing the config tree. 21 | """ 22 | # return if no `extras` config 23 | if not cfg.get("extras"): 24 | log.warning("Extras config not found! ") 25 | return 26 | 27 | # disable python warnings 28 | if cfg.extras.get("ignore_warnings"): 29 | log.info("Disabling python warnings! 
") 30 | warnings.filterwarnings("ignore") 31 | 32 | # prompt user to input tags from command line if none are provided in the config 33 | if cfg.extras.get("enforce_tags"): 34 | log.info("Enforcing tags! ") 35 | rich_utils.enforce_tags(cfg, save_to_file=True) 36 | 37 | # pretty print config tree using Rich library 38 | if cfg.extras.get("print_config"): 39 | log.info("Printing config tree with Rich! ") 40 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 41 | 42 | 43 | def task_wrapper(task_func: Callable) -> Callable: 44 | """Optional decorator that controls the failure behavior when executing the task function. 45 | 46 | This wrapper can be used to: 47 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 48 | - save the exception to a `.log` file 49 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 50 | - etc. (adjust depending on your needs) 51 | 52 | Example: 53 | ``` 54 | @utils.task_wrapper 55 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 56 | ... 57 | return metric_dict, object_dict 58 | ``` 59 | 60 | :param task_func: The task function to be wrapped. 61 | 62 | :return: The wrapped task function. 63 | """ 64 | 65 | def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 66 | # execute the task 67 | try: 68 | metric_dict, object_dict = task_func(cfg=cfg) 69 | 70 | # things to do if exception occurs 71 | except Exception as ex: 72 | # save exception to `.log` file 73 | log.exception("") 74 | 75 | # some hyperparameter combinations might be invalid or cause out-of-memory errors 76 | # so when using hparam search plugins like Optuna, you might want to disable 77 | # raising the below exception to avoid multirun failure 78 | raise ex 79 | 80 | # things to always do after either success or exception 81 | finally: 82 | # display output dir path in terminal 83 | log.info(f"Output dir: {cfg.paths.output_dir}") 84 | 85 | # always close wandb run (even if exception occurs so multirun won't fail) 86 | if find_spec("wandb"): # check if wandb is installed 87 | import wandb 88 | 89 | if wandb.run: 90 | log.info("Closing wandb!") 91 | wandb.finish() 92 | 93 | return metric_dict, object_dict 94 | 95 | return wrap 96 | 97 | 98 | def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]: 99 | """Safely retrieves value of the metric logged in LightningModule. 100 | 101 | :param metric_dict: A dict containing metric values. 102 | :param metric_name: If provided, the name of the metric to retrieve. 103 | :return: If a metric name was provided, the value of the metric. 104 | """ 105 | if not metric_name: 106 | log.info("Metric name is None! Skipping metric value retrieval...") 107 | return None 108 | 109 | if metric_name not in metric_dict: 110 | raise Exception( 111 | f"Metric value not found! \n" 112 | "Make sure metric name logged in LightningModule is correct!\n" 113 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 114 | ) 115 | 116 | metric_value = metric_dict[metric_name].item() 117 | log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") 118 | 119 | return metric_value 120 | -------------------------------------------------------------------------------- /pflow_encodec/data/text_latent_dur_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | from pflow_encodec.data.tokenizer import TextTokenizer 10 | 11 | 12 | class TextLatentDataset(Dataset): 13 | """Dataset for Voicebox Training, returns text tokens, duration, and pre-quantize encodec latent. 14 | 15 | text_tokens: torch.Tensor, shape (1, T_text) 16 | duration: torch.Tensor, shape (1, T_text) 17 | latent: torch.Tensor, shape (1, T_latent, D_latent) 18 | """ 19 | 20 | def __init__( 21 | self, 22 | tsv_path: str, 23 | add_trailing_silence: bool = True, 24 | mean: float = 0.0, 25 | std: float = 1.0, 26 | min_duration: float = 3.0, 27 | max_duration: float = 15.0, 28 | ): 29 | df = pd.read_csv(tsv_path, sep="\t", engine="pyarrow") 30 | df = df[df["duration"] >= min_duration] 31 | df = df[df["duration"] <= max_duration] 32 | 33 | self.paths = df["audio_path"].tolist() 34 | self.texts = df["text"].tolist() 35 | 36 | self.audio_durations = df["duration"].tolist() 37 | 38 | self.tokenizer = TextTokenizer(add_trailing_silence=add_trailing_silence) 39 | 40 | self.mean = mean 41 | self.std = std 42 | 43 | def __len__(self): 44 | return len(self.paths) 45 | 46 | def __getitem__(self, idx): 47 | path = self.paths[idx] 48 | text = self.texts[idx] 49 | text_tokens = self.tokenizer.encode_text(text).squeeze() 50 | 51 | latent_npy_path = Path(path).with_suffix(".latent.npy") 52 | duration_npy_path = Path(path).with_suffix(".duration.npy") 53 | 54 | latent = torch.from_numpy(np.load(latent_npy_path)).squeeze().unsqueeze(0) 55 | duration = torch.from_numpy(np.load(duration_npy_path)).squeeze().unsqueeze(0) 56 | 57 | # text_tokens = torch.repeat_interleave(text_tokens, duration, dim=0) 58 | if text_tokens.ndim == 1: 59 | text_tokens = text_tokens.unsqueeze(0) 60 | 61 | if text_tokens.shape[-1] != duration.shape[-1]: 62 | raise ValueError( 63 | f"Text token and duration shape mismatch: {text_tokens.shape} != {duration.shape}, path={path}, text={text}" 64 | ) 65 | 66 | latent = (latent - self.mean) / self.std 67 | 68 | return text_tokens, duration, latent 69 | 70 | 71 | class TextLatentLangDataset(Dataset): 72 | """Dataset for Voicebox Training, returns text tokens, duration, and pre-quantize encodec latent. 
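In addition to the three tensors below, it also returns a scalar lang_id tensor: the index of the sample's language in the `languages` list passed to the constructor.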
73 | 74 | text_tokens: torch.Tensor, shape (1, T_text) 75 | duration: torch.Tensor, shape (1, T_text) 76 | latent: torch.Tensor, shape (1, T_latent, D_latent) 77 | """ 78 | 79 | def __init__( 80 | self, 81 | tsv_path: str, 82 | add_trailing_silence: bool = True, 83 | mean: float = 0.0, 84 | std: float = 1.0, 85 | min_duration: float = 3.0, 86 | max_duration: float = 15.0, 87 | languages: list[str] = None, 88 | ): 89 | df = pd.read_csv(tsv_path, sep="\t", engine="pyarrow") 90 | df = df[df["duration"] >= min_duration] 91 | df = df[df["duration"] <= max_duration] 92 | 93 | self.paths = df["audio_path"].tolist() 94 | self.texts = df["text"].tolist() 95 | self.languages = df["lang"].tolist() 96 | 97 | self.audio_durations = df["duration"].tolist() 98 | 99 | self.tokenizer = TextTokenizer(add_trailing_silence=add_trailing_silence) 100 | 101 | self.mean = mean 102 | self.std = std 103 | 104 | if languages is not None: 105 | self.lang2idx = {lang: idx for idx, lang in enumerate(languages)} 106 | 107 | def __len__(self): 108 | return len(self.paths) 109 | 110 | def __getitem__(self, idx): 111 | path = self.paths[idx] 112 | text = self.texts[idx] 113 | text_tokens = self.tokenizer.encode_text(text).squeeze() 114 | 115 | latent_npy_path = Path(path).with_suffix(".latent.npy") 116 | duration_npy_path = Path(path).with_suffix(".duration.npy") 117 | 118 | latent = torch.from_numpy(np.load(latent_npy_path)).squeeze().unsqueeze(0) 119 | duration = torch.from_numpy(np.load(duration_npy_path)).squeeze().unsqueeze(0) 120 | 121 | # text_tokens = torch.repeat_interleave(text_tokens, duration, dim=0) 122 | if text_tokens.ndim == 1: 123 | text_tokens = text_tokens.unsqueeze(0) 124 | 125 | if text_tokens.shape[-1] != duration.shape[-1]: 126 | raise ValueError( 127 | f"Text token and duration shape mismatch: {text_tokens.shape} != {duration.shape}, path={path}, text={text}" 128 | ) 129 | 130 | latent = (latent - self.mean) / self.std 131 | 132 | language = self.languages[idx] 133 | lang_id = torch.tensor(self.lang2idx[language]) 134 | 135 | return text_tokens, duration, latent, lang_id 136 | -------------------------------------------------------------------------------- /pflow_encodec/modules/text_enc.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from pflow_encodec.modules.transformer import ( 7 | AlibiPositionalBias, 8 | Transformer, 9 | Wav2Vec2StackedPositionEncoder, 10 | ) 11 | 12 | 13 | class TextEncoder(nn.Module): 14 | def __init__( 15 | self, 16 | vocab_size: int, 17 | dim_text: int, 18 | dim_spk: int, 19 | dim_output: int, 20 | conv_pos_kernel_size: int, 21 | conv_pos_depth: int, 22 | conv_pos_groups: int, 23 | depth: int, 24 | dim: int, 25 | dim_head: int, 26 | heads: int, 27 | ff_mult: float, 28 | attn_dropout: float, 29 | ff_dropout: float, 30 | attn_processor: Literal["naive", "sdpa"] = "naive", 31 | norm_type: Literal["layer", "ada_proj", "ada_embed"] = "layer", 32 | ff_type: Literal["conv", "linear"] = "linear", 33 | ff_kernel_size: Optional[int] = None, 34 | ff_groups: Optional[int] = None, 35 | layer_norm_eps: float = 1e-6, 36 | scale_type: Literal["none", "ada_proj", "ada_embed"] = "none", 37 | ): 38 | super().__init__() 39 | 40 | self.text_emb = nn.Embedding(vocab_size, dim_text) 41 | self.input_proj = nn.Linear(dim_text, dim) 42 | 43 | self.conv_pos = Wav2Vec2StackedPositionEncoder( 44 | depth=conv_pos_depth, 45 | dim=dim, 46 | kernel_size=conv_pos_kernel_size, 
47 | groups=conv_pos_groups, 48 | ) 49 | 50 | self.norm_type = norm_type 51 | self.scale_type = scale_type 52 | 53 | if norm_type == "ada_embed": 54 | self.adaln_linear = nn.Linear(dim_spk, dim * 4) 55 | if scale_type == "ada_embed": 56 | self.ada_scale_linear = nn.Linear(dim_spk, dim * 2) 57 | 58 | self.transformer = Transformer( 59 | depth=depth, 60 | dim=dim, 61 | dim_head=dim_head, 62 | heads=heads, 63 | ff_mult=ff_mult, 64 | attn_dropout=attn_dropout, 65 | ff_dropout=ff_dropout, 66 | attn_processor=attn_processor, 67 | norm_type=norm_type, 68 | ff_type=ff_type, 69 | ff_kernel_size=ff_kernel_size, 70 | ff_groups=ff_groups, 71 | layer_norm_eps=layer_norm_eps, 72 | scale_type=scale_type, 73 | dim_cond=dim if norm_type == "ada_embed" else dim_spk, 74 | ) 75 | 76 | self.output_proj = nn.Linear(dim, dim_output) 77 | 78 | self.alibi = AlibiPositionalBias(heads) 79 | 80 | def reset_parameters(self): 81 | self.conv_pos.reset_parameters() 82 | 83 | # init adaln 84 | if self.norm_type == "ada_embed": 85 | nn.init.zeros_(self.adaln_linear.weight) 86 | nn.init.zeros_(self.adaln_linear.bias) 87 | 88 | if self.scale_type == "ada_embed": 89 | nn.init.zeros_(self.ada_scale_linear.weight) 90 | nn.init.zeros_(self.ada_scale_linear.bias) 91 | 92 | self.transformer.reset_adaln_parameters() 93 | 94 | # zero init output proj 95 | nn.init.zeros_(self.output_proj.weight) 96 | nn.init.zeros_(self.output_proj.bias) 97 | 98 | def forward( 99 | self, 100 | text_tokens: torch.Tensor, 101 | spk_emb: torch.Tensor, 102 | padding_mask: Optional[torch.Tensor] = None, 103 | lang_emb: Optional[torch.Tensor] = None, 104 | ): 105 | x = self.input_proj(self.text_emb(text_tokens)) 106 | x = x + self.conv_pos(x, padding_mask) 107 | 108 | cond = spk_emb 109 | if lang_emb is not None: 110 | cond = spk_emb + lang_emb 111 | cond_input = dict() 112 | if self.norm_type == "ada_proj": 113 | cond_input["attn_norm_cond"] = cond 114 | cond_input["ff_norm_cond"] = cond 115 | cond_input["final_norm_cond"] = cond 116 | elif self.norm_type == "ada_embed": 117 | attn_norm_scale, attn_norm_bias, ff_norm_scale, ff_norm_bias = self.adaln_linear(cond).chunk(4, dim=-1) 118 | cond_input["attn_norm_cond"] = torch.cat([attn_norm_scale, attn_norm_bias], dim=-1) 119 | cond_input["ff_norm_cond"] = torch.cat([ff_norm_scale, ff_norm_bias], dim=-1) 120 | cond_input["final_norm_cond"] = cond 121 | 122 | if self.scale_type == "ada_proj": 123 | cond_input["attn_scale_cond"] = cond 124 | cond_input["ff_scale_cond"] = cond 125 | elif self.scale_type == "ada_embed": 126 | attn_scale, ff_scale = self.ada_scale_linear(cond).chunk(2, dim=-1) 127 | cond_input["attn_scale_cond"] = attn_scale 128 | cond_input["ff_scale_cond"] = ff_scale 129 | 130 | seq_len = x.size(1) 131 | bias = self.alibi(seq_len) 132 | x = self.transformer(x, mask=padding_mask, cond_input=cond_input, bias=bias) 133 | return self.output_proj(x), x 134 | -------------------------------------------------------------------------------- /tests/helpers/run_if.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | 3 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 4 | """ 5 | 6 | import sys 7 | from typing import Any, Dict, Optional 8 | 9 | import pytest 10 | import torch 11 | from packaging.version import Version 12 | from pkg_resources import get_distribution 13 | from pytest import MarkDecorator 14 | 15 | from tests.helpers.package_available import ( 16 | _COMET_AVAILABLE, 17 | _DEEPSPEED_AVAILABLE, 18 | 
_FAIRSCALE_AVAILABLE, 19 | _IS_WINDOWS, 20 | _MLFLOW_AVAILABLE, 21 | _NEPTUNE_AVAILABLE, 22 | _SH_AVAILABLE, 23 | _TPU_AVAILABLE, 24 | _WANDB_AVAILABLE, 25 | ) 26 | 27 | 28 | class RunIf: 29 | """RunIf wrapper for conditional skipping of tests. 30 | 31 | Fully compatible with `@pytest.mark`. 32 | 33 | Example: 34 | 35 | ```python 36 | @RunIf(min_torch="1.8") 37 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 38 | def test_wrapper(arg1): 39 | assert arg1 > 0 40 | ``` 41 | """ 42 | 43 | def __new__( 44 | cls, 45 | min_gpus: int = 0, 46 | min_torch: Optional[str] = None, 47 | max_torch: Optional[str] = None, 48 | min_python: Optional[str] = None, 49 | skip_windows: bool = False, 50 | sh: bool = False, 51 | tpu: bool = False, 52 | fairscale: bool = False, 53 | deepspeed: bool = False, 54 | wandb: bool = False, 55 | neptune: bool = False, 56 | comet: bool = False, 57 | mlflow: bool = False, 58 | **kwargs: Dict[Any, Any], 59 | ) -> MarkDecorator: 60 | """Creates a new `@RunIf` `MarkDecorator` decorator. 61 | 62 | :param min_gpus: Min number of GPUs required to run test. 63 | :param min_torch: Minimum pytorch version to run test. 64 | :param max_torch: Maximum pytorch version to run test. 65 | :param min_python: Minimum python version required to run test. 66 | :param skip_windows: Skip test for Windows platform. 67 | :param tpu: If TPU is available. 68 | :param sh: If `sh` module is required to run the test. 69 | :param fairscale: If `fairscale` module is required to run the test. 70 | :param deepspeed: If `deepspeed` module is required to run the test. 71 | :param wandb: If `wandb` module is required to run the test. 72 | :param neptune: If `neptune` module is required to run the test. 73 | :param comet: If `comet` module is required to run the test. 74 | :param mlflow: If `mlflow` module is required to run the test. 75 | :param kwargs: Native `pytest.mark.skipif` keyword arguments. 
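Example (illustrative): `@RunIf(min_gpus=1, skip_windows=True)` skips the decorated test unless at least one GPU is visible and the platform is not Windows.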
76 | """ 77 | conditions = [] 78 | reasons = [] 79 | 80 | if min_gpus: 81 | conditions.append(torch.cuda.device_count() < min_gpus) 82 | reasons.append(f"GPUs>={min_gpus}") 83 | 84 | if min_torch: 85 | torch_version = get_distribution("torch").version 86 | conditions.append(Version(torch_version) < Version(min_torch)) 87 | reasons.append(f"torch>={min_torch}") 88 | 89 | if max_torch: 90 | torch_version = get_distribution("torch").version 91 | conditions.append(Version(torch_version) >= Version(max_torch)) 92 | reasons.append(f"torch<{max_torch}") 93 | 94 | if min_python: 95 | py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 96 | conditions.append(Version(py_version) < Version(min_python)) 97 | reasons.append(f"python>={min_python}") 98 | 99 | if skip_windows: 100 | conditions.append(_IS_WINDOWS) 101 | reasons.append("does not run on Windows") 102 | 103 | if tpu: 104 | conditions.append(not _TPU_AVAILABLE) 105 | reasons.append("TPU") 106 | 107 | if sh: 108 | conditions.append(not _SH_AVAILABLE) 109 | reasons.append("sh") 110 | 111 | if fairscale: 112 | conditions.append(not _FAIRSCALE_AVAILABLE) 113 | reasons.append("fairscale") 114 | 115 | if deepspeed: 116 | conditions.append(not _DEEPSPEED_AVAILABLE) 117 | reasons.append("deepspeed") 118 | 119 | if wandb: 120 | conditions.append(not _WANDB_AVAILABLE) 121 | reasons.append("wandb") 122 | 123 | if neptune: 124 | conditions.append(not _NEPTUNE_AVAILABLE) 125 | reasons.append("neptune") 126 | 127 | if comet: 128 | conditions.append(not _COMET_AVAILABLE) 129 | reasons.append("comet") 130 | 131 | if mlflow: 132 | conditions.append(not _MLFLOW_AVAILABLE) 133 | reasons.append("mlflow") 134 | 135 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 136 | return pytest.mark.skipif( 137 | condition=any(conditions), 138 | reason=f"Requires: [{' + '.join(reasons)}]", 139 | **kwargs, 140 | ) 141 | -------------------------------------------------------------------------------- /pflow_encodec/train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import hydra 4 | import lightning as L 5 | import rootutils 6 | import torch 7 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 8 | from lightning.pytorch.loggers import Logger 9 | from omegaconf import DictConfig 10 | 11 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 12 | # ------------------------------------------------------------------------------------ # 13 | # the setup_root above is equivalent to: 14 | # - adding project root dir to PYTHONPATH 15 | # (so you don't need to force user to install project as a package) 16 | # (necessary before importing any local modules e.g. `from src import utils`) 17 | # - setting up PROJECT_ROOT environment variable 18 | # (which is used as a base for paths in "configs/paths/default.yaml") 19 | # (this way all filepaths are the same no matter where you run the code) 20 | # - loading environment variables from ".env" in root dir 21 | # 22 | # you can remove it if you: 23 | # 1. either install project as a package or move entry files to project root dir 24 | # 2. set `root_dir` to "." 
in "configs/paths/default.yaml" 25 | # 26 | # more info: https://github.com/ashleve/rootutils 27 | # ------------------------------------------------------------------------------------ # 28 | 29 | from pflow_encodec.utils import ( 30 | RankedLogger, 31 | extras, 32 | get_metric_value, 33 | instantiate_callbacks, 34 | instantiate_loggers, 35 | log_hyperparameters, 36 | task_wrapper, 37 | ) 38 | 39 | log = RankedLogger(__name__, rank_zero_only=True) 40 | 41 | 42 | @task_wrapper 43 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 44 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during training. 45 | 46 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 47 | failure. Useful for multiruns, saving info about the crash, etc. 48 | 49 | :param cfg: A DictConfig configuration composed by Hydra. 50 | :return: A tuple with metrics and dict with all instantiated objects. 51 | """ 52 | # set seed for random number generators in pytorch, numpy and python.random 53 | torch.set_float32_matmul_precision("high") 54 | torch.backends.cuda.matmul.allow_tf32 = True 55 | torch.backends.cudnn.allow_tf32 = True 56 | 57 | if cfg.get("seed"): 58 | L.seed_everything(cfg.seed, workers=True) 59 | 60 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") 61 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 62 | 63 | log.info(f"Instantiating model <{cfg.model._target_}>") 64 | model: LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False) 65 | 66 | log.info("Instantiating callbacks...") 67 | callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) 68 | 69 | log.info("Instantiating loggers...") 70 | logger: List[Logger] = instantiate_loggers(cfg.get("logger")) 71 | 72 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 73 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 74 | 75 | object_dict = { 76 | "cfg": cfg, 77 | "datamodule": datamodule, 78 | "model": model, 79 | "callbacks": callbacks, 80 | "logger": logger, 81 | "trainer": trainer, 82 | } 83 | 84 | if logger: 85 | log.info("Logging hyperparameters!") 86 | log_hyperparameters(object_dict) 87 | 88 | if cfg.get("train"): 89 | log.info("Starting training!") 90 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 91 | 92 | train_metrics = trainer.callback_metrics 93 | 94 | if cfg.get("test"): 95 | log.info("Starting testing!") 96 | ckpt_path = trainer.checkpoint_callback.best_model_path 97 | if ckpt_path == "": 98 | log.warning("Best ckpt not found! Using current weights for testing...") 99 | ckpt_path = None 100 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 101 | log.info(f"Best ckpt path: {ckpt_path}") 102 | 103 | test_metrics = trainer.callback_metrics 104 | 105 | # merge train and test metrics 106 | metric_dict = {**train_metrics, **test_metrics} 107 | 108 | return metric_dict, object_dict 109 | 110 | 111 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 112 | def main(cfg: DictConfig) -> Optional[float]: 113 | """Main entry point for training. 114 | 115 | :param cfg: DictConfig configuration composed by Hydra. 116 | :return: Optional[float] with optimized metric value. 117 | """ 118 | # apply extra utilities 119 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
120 | extras(cfg) 121 | 122 | # train the model 123 | metric_dict, _ = train(cfg) 124 | 125 | # safely retrieve metric value for hydra-based hyperparameter optimization 126 | metric_value = get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) 127 | 128 | # return optimized metric 129 | return metric_value 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /pflow_encodec/data/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch.utils.data.distributed import DistributedSampler 5 | 6 | from pflow_encodec.utils.pylogger import RankedLogger 7 | 8 | logger = RankedLogger(__name__) 9 | 10 | 11 | class DistributedBucketSampler(DistributedSampler): 12 | """Distributed Bucket Sampler for dynamic batching. 13 | 14 | it gathers samples with similar length into a batch. each batch comes from a single bucket. 15 | bucket[i] contains samples with length in (boundaries[i], boundaries[i+1]]. samples with length < first bucket and 16 | length > last bucket will be discarded. 17 | Args: 18 | dataset: dataset to sample from. it should have a lengths attribute 19 | batch_durations: number of frames in a batch 20 | boundaries: a list of boundaries for bucketing, samples with length in (boundaries[i], boundaries[i+1]] 21 | will be put into bucket i. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | dataset, 27 | batch_durations: float, 28 | boundaries: List[float], 29 | num_replicas=None, 30 | rank=None, 31 | shuffle: bool = True, 32 | drop_last: bool = True, 33 | ): 34 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last) 35 | 36 | self.durations = dataset.audio_durations 37 | self.batch_durations = batch_durations 38 | self.boundaries = boundaries 39 | 40 | self.buckets = self._create_bucket() 41 | logger.info(f"Created {len(self.buckets)} buckets") 42 | logger.info(f"Boundaries: {self.boundaries}") 43 | bucket_sizes = [len(bucket) for bucket in self.buckets] 44 | logger.info(f"Bucket sizes: {bucket_sizes}") 45 | self.batches = self._create_batches() 46 | 47 | def _bisect(self, x, lo=0, hi=None): 48 | if hi is None: 49 | hi = len(self.boundaries) - 1 50 | 51 | if hi > lo: 52 | mid = (hi + lo) // 2 53 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 54 | return mid 55 | elif x <= self.boundaries[mid]: 56 | return self._bisect(x, lo, mid) 57 | else: 58 | return self._bisect(x, mid + 1, hi) 59 | else: 60 | return -1 61 | 62 | def _create_bucket(self): 63 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 64 | for i in range(len(self.durations)): 65 | length = self.durations[i] 66 | idx_bucket = self._bisect(length) 67 | if idx_bucket != -1: 68 | buckets[idx_bucket].append(i) 69 | 70 | for i in range(len(buckets) - 1, 0, -1): 71 | if len(buckets[i]) == 0: 72 | buckets.pop(i) 73 | self.boundaries.pop(i + 1) 74 | 75 | return buckets 76 | 77 | def _create_batches(self): 78 | g = torch.Generator() 79 | g.manual_seed(self.epoch) 80 | buckets = self.buckets 81 | indices = [] 82 | if self.shuffle: 83 | for bucket in buckets: 84 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 85 | else: 86 | for bucket in buckets: 87 | indices.append(list(range(len(bucket)))) 88 | batches = [] 89 | for bucket_id, bucket in enumerate(buckets): 90 | bucket_indices = indices[bucket_id] 91 | bucket_batches = [] 92 | 
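# greedily fill batches from this bucket: a batch is closed once adding the next
# sample would push its total audio duration over `batch_durations`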
current_batch = [] 93 | durations = 0 94 | # since we can not guarantee every process has the same number of batches, iterate all samples to generate batches 95 | for bucket_idx in bucket_indices: 96 | sample_idx = bucket[bucket_idx] 97 | sample_duration = self.durations[sample_idx] 98 | if durations + sample_duration > self.batch_durations and current_batch: 99 | bucket_batches.append(current_batch) 100 | current_batch = [] 101 | durations = 0 102 | durations += sample_duration 103 | current_batch.append(sample_idx) 104 | 105 | if not bucket_batches: 106 | # there's no batch made, just append the current batch 107 | bucket_batches.append(current_batch) 108 | elif current_batch and not self.drop_last: 109 | # there's still samples left in the current batch 110 | bucket_batches.append(current_batch) 111 | if len(bucket_batches) % self.num_replicas != 0: 112 | # the number of batches should be a multiple of num_replicas, duplicate random batches to make it so 113 | remainder = self.num_replicas - (len(bucket_batches) % self.num_replicas) 114 | assert remainder > 0 115 | for _ in range(remainder): 116 | random_idx = torch.randint(0, len(bucket_batches), (1,), generator=g).item() 117 | bucket_batches.append(bucket_batches[random_idx]) 118 | batches.extend(bucket_batches) 119 | assert len(batches) % self.num_replicas == 0 120 | if self.shuffle: 121 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 122 | batches = [batches[i] for i in batch_ids] 123 | return batches 124 | 125 | def __iter__(self): 126 | """# of batches should be multiple of num_replicas""" 127 | self.batches = self._create_batches() 128 | return iter(self.batches) 129 | 130 | def __len__(self): 131 | return len(self.batches) 132 | -------------------------------------------------------------------------------- /pflow_encodec/modules/flow_matching.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Literal, Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | 8 | from pflow_encodec.modules.transformer import ( 9 | AlibiPositionalBias, 10 | Transformer, 11 | Wav2Vec2StackedPositionEncoder, 12 | ) 13 | 14 | 15 | class TimestepEmbedder(nn.Module): 16 | def __init__(self, dim: int, dim_time: int, max_period: int = 10000): 17 | super().__init__() 18 | self.dim = dim 19 | self.max_period = max_period 20 | half = dim // 2 21 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) 22 | self.register_buffer("freqs", freqs) 23 | self.net = nn.Sequential( 24 | nn.Linear(dim, dim_time), 25 | nn.SiLU(), 26 | nn.Linear(dim_time, dim_time), 27 | ) 28 | 29 | def forward(self, t: torch.Tensor): 30 | dtype = self.freqs.dtype 31 | args = t[:, None].float() * self.freqs[None] 32 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 33 | if self.dim % 2: 34 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 35 | embedding = embedding.to(dtype=dtype) 36 | return self.net(embedding) 37 | 38 | 39 | class FlowMatchingTransformer(nn.Module): 40 | def __init__( 41 | self, 42 | dim_input: int, 43 | dim_ctx: int, 44 | dim_output: int, 45 | dim_time: int, 46 | conv_pos_kernel_size: int, 47 | conv_pos_depth: int, 48 | conv_pos_groups: int, 49 | depth: int, 50 | dim: int, 51 | dim_head: int, 52 | heads: int, 53 | ff_mult: float, 54 | attn_dropout: float, 55 | ff_dropout: float, 56 | attn_processor: Literal["naive", "sdpa"] = "naive", 57 | 
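# norm_type/scale_type select how the timestep embedding conditions the transformer
# blocks: with "ada_embed", shared AdaLN scale/bias terms are computed once via
# `adaln_linear`/`ada_scale_linear` in forward(); with "ada_proj", the time embedding
# itself is passed to the blocks as the conditioning input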
norm_type: Literal["layer", "ada_proj", "ada_embed"] = "layer", 58 | ff_type: Literal["conv", "linear"] = "linear", 59 | ff_kernel_size: Optional[int] = None, 60 | ff_groups: Optional[int] = None, 61 | layer_norm_eps: float = 1e-6, 62 | scale_type: Literal["none", "ada_proj", "ada_embed"] = "none", 63 | ): 64 | super().__init__() 65 | 66 | self.input_proj = nn.Linear(dim_input + dim_ctx, dim) 67 | 68 | self.time_embed = TimestepEmbedder(dim, dim_time) 69 | self.conv_pos = Wav2Vec2StackedPositionEncoder( 70 | depth=conv_pos_depth, 71 | dim=dim, 72 | kernel_size=conv_pos_kernel_size, 73 | groups=conv_pos_groups, 74 | ) 75 | 76 | self.norm_type = norm_type 77 | self.scale_type = scale_type 78 | 79 | if norm_type == "ada_embed": 80 | self.adaln_linear = nn.Linear(dim_time, dim * 4) 81 | if scale_type == "ada_embed": 82 | self.ada_scale_linear = nn.Linear(dim_time, dim * 2) 83 | 84 | self.transformer = Transformer( 85 | depth=depth, 86 | dim=dim, 87 | dim_head=dim_head, 88 | heads=heads, 89 | ff_mult=ff_mult, 90 | attn_dropout=attn_dropout, 91 | ff_dropout=ff_dropout, 92 | attn_processor=attn_processor, 93 | norm_type=norm_type, 94 | ff_type=ff_type, 95 | ff_kernel_size=ff_kernel_size, 96 | ff_groups=ff_groups, 97 | layer_norm_eps=layer_norm_eps, 98 | scale_type=scale_type, 99 | dim_cond=dim if norm_type == "ada_embed" else dim_time, 100 | dim_final_norm_cond=dim_time if norm_type == "ada_embed" else None, 101 | use_skip_connection=True, 102 | ) 103 | 104 | self.output_proj = nn.Linear(dim, dim_output) 105 | 106 | self.alibi = AlibiPositionalBias(heads) 107 | 108 | def reset_parameters(self): 109 | self.conv_pos.reset_parameters() 110 | nn.init.trunc_normal_(self.time_embed.net[0].weight, std=0.02) 111 | nn.init.zeros_(self.time_embed.net[0].bias) 112 | nn.init.trunc_normal_(self.time_embed.net[2].weight, std=0.02) 113 | nn.init.zeros_(self.time_embed.net[2].bias) 114 | 115 | # init adaln 116 | if self.norm_type == "ada_embed": 117 | nn.init.zeros_(self.adaln_linear.weight) 118 | nn.init.zeros_(self.adaln_linear.bias) 119 | 120 | if self.scale_type == "ada_embed": 121 | nn.init.zeros_(self.ada_scale_linear.weight) 122 | nn.init.zeros_(self.ada_scale_linear.bias) 123 | 124 | self.transformer.reset_adaln_parameters() 125 | 126 | # zero init output proj 127 | nn.init.zeros_(self.output_proj.weight) 128 | nn.init.zeros_(self.output_proj.bias) 129 | 130 | def forward( 131 | self, 132 | x: torch.Tensor, 133 | x_ctx: torch.Tensor, 134 | times: torch.Tensor, 135 | padding_mask: Optional[torch.Tensor] = None, 136 | drop_ctx: Optional[torch.Tensor] = None, 137 | ) -> torch.Tensor: 138 | # apply dropout to context 139 | x_ctx = torch.where(rearrange(drop_ctx, "b -> b 1 1"), 0, x_ctx) 140 | 141 | x = torch.cat([x, x_ctx], dim=-1) 142 | x = self.input_proj(x) 143 | 144 | x = x + self.conv_pos(x, mask=padding_mask) 145 | cond = self.time_embed(times).unsqueeze(1) 146 | 147 | cond_input = dict() 148 | if self.norm_type == "ada_proj": 149 | cond_input["attn_norm_cond"] = cond 150 | cond_input["ff_norm_cond"] = cond 151 | cond_input["final_norm_cond"] = cond 152 | elif self.norm_type == "ada_embed": 153 | attn_norm_scale, attn_norm_bias, ff_norm_scale, ff_norm_bias = self.adaln_linear(cond).chunk(4, dim=-1) 154 | cond_input["attn_norm_cond"] = torch.cat([attn_norm_scale, attn_norm_bias], dim=-1) 155 | cond_input["ff_norm_cond"] = torch.cat([ff_norm_scale, ff_norm_bias], dim=-1) 156 | cond_input["final_norm_cond"] = cond 157 | 158 | if self.scale_type == "ada_proj": 159 | cond_input["attn_scale_cond"] = cond 160 
| cond_input["ff_scale_cond"] = cond 161 | elif self.scale_type == "ada_embed": 162 | attn_scale, ff_scale = self.ada_scale_linear(cond).chunk(2, dim=-1) 163 | cond_input["attn_scale_cond"] = attn_scale 164 | cond_input["ff_scale_cond"] = ff_scale 165 | 166 | seq_len = x.size(1) 167 | bias = self.alibi(seq_len) 168 | x = self.transformer(x, mask=padding_mask, cond_input=cond_input, bias=bias) 169 | 170 | return self.output_proj(x) 171 | -------------------------------------------------------------------------------- /pflow_encodec/data/datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | import lightning as L 5 | import torch 6 | from torch.utils.data import DataLoader 7 | 8 | from pflow_encodec.data.sampler import DistributedBucketSampler 9 | from pflow_encodec.data.text_latent_dur_dataset import ( 10 | TextLatentDataset, 11 | TextLatentLangDataset, 12 | ) 13 | 14 | 15 | class TextLatentLightningDataModule(L.LightningDataModule): 16 | def __init__( 17 | self, 18 | train_tsv_path: str, 19 | val_tsv_path: str, 20 | add_trailing_silence: bool = True, 21 | batch_durations: float = 50.0, 22 | min_duration: float = 3.0, 23 | max_duration: float = 15.0, 24 | boundaries: List[float] = [3.0, 5.0, 7.0, 10.0, 15.0], 25 | num_workers: int = 4, 26 | return_upsampled: bool = False, 27 | max_frame: int = 1500, 28 | text2latent_rate: float = 1.5, 29 | mean: float = -0.5444963574409485, 30 | std: float = 5.242217063903809, 31 | use_lang_id: bool = False, 32 | languages: Optional[List[str]] = None, 33 | ): 34 | super().__init__() 35 | self.train_tsv_path = train_tsv_path 36 | self.val_tsv_path = val_tsv_path 37 | 38 | self.train_ds: TextLatentDataset = None 39 | self.val_ds: TextLatentDataset = None 40 | 41 | self.add_trailing_silence = add_trailing_silence 42 | self.pad_idx = None 43 | 44 | self.batch_durations = batch_durations 45 | self.min_duration = min_duration 46 | self.max_duration = max_duration 47 | self.boundaries = boundaries 48 | self.num_workers = num_workers 49 | self.return_upsampled = return_upsampled 50 | # do not use return_upsampled 51 | assert not self.return_upsampled, "return_upsampled is not supported" 52 | 53 | self.max_frame = max_frame 54 | self.text2latent_rate = text2latent_rate 55 | 56 | self.mean = mean 57 | self.std = std 58 | self.use_lang_id = use_lang_id 59 | if languages is not None: 60 | self.languages = languages 61 | 62 | def setup(self, stage: str): 63 | if stage != "fit": 64 | raise ValueError(f"Stage {stage} is not supported") 65 | dataset_cls = TextLatentLangDataset if self.use_lang_id else TextLatentDataset 66 | aux_dict = {} 67 | if self.use_lang_id: 68 | if self.languages is None: 69 | raise ValueError("languages must be provided when use_lang_id is True") 70 | aux_dict["languages"] = self.languages 71 | self.train_ds = dataset_cls( 72 | self.train_tsv_path, 73 | add_trailing_silence=self.add_trailing_silence, 74 | mean=self.mean, 75 | std=self.std, 76 | min_duration=self.min_duration, 77 | max_duration=self.max_duration, 78 | **aux_dict, 79 | ) 80 | self.val_ds = dataset_cls( 81 | self.val_tsv_path, 82 | add_trailing_silence=self.add_trailing_silence, 83 | mean=self.mean, 84 | std=self.std, 85 | min_duration=self.min_duration, 86 | max_duration=self.max_duration, 87 | **aux_dict, 88 | ) 89 | 90 | self.pad_idx = self.train_ds.tokenizer.pad_idx 91 | 92 | def prepare_data(self): 93 | if not os.path.exists(self.train_tsv_path): 94 | raise FileNotFoundError(f"File 
{self.train_tsv_path} does not exist") 95 | if not os.path.exists(self.val_tsv_path): 96 | raise FileNotFoundError(f"File {self.val_tsv_path} does not exist") 97 | 98 | def _collate(self, batch): 99 | result = {} 100 | if self.use_lang_id: 101 | text_tokens, durations, latents, languages = map(list, zip(*batch)) 102 | lang_ids = torch.stack([lang for lang in languages]) 103 | result["lang_ids"] = lang_ids 104 | else: 105 | text_tokens, durations, latents = map(list, zip(*batch)) 106 | # used for training AudioModel 107 | for t, d in zip(text_tokens, durations): 108 | if t.shape[-1] != d.shape[-1]: 109 | raise ValueError(f"Text token and duration shape mismatch: {t.shape} != {d.shape}") 110 | max_text_len = max(t.shape[-1] for t in text_tokens) 111 | text_token_lens = torch.tensor([t.shape[-1] for t in text_tokens]) 112 | text_tokens = torch.cat( 113 | [torch.nn.functional.pad(t, (0, max_text_len - t.shape[-1]), value=self.pad_idx) for t in text_tokens], 114 | dim=0, 115 | ) 116 | 117 | max_duration_len = max(d.shape[-1] for d in durations) 118 | duration_lens = torch.tensor([d.shape[-1] for d in durations]) 119 | durations = torch.cat( 120 | [torch.nn.functional.pad(d, (0, max_duration_len - d.shape[-1]), value=0) for d in durations], 121 | dim=0, 122 | ) 123 | 124 | max_latent_len = max(latent.shape[-2] for latent in latents) 125 | latent_lens = torch.tensor([latent.shape[-2] for latent in latents]) 126 | latents = torch.cat( 127 | [ 128 | torch.nn.functional.pad(latent, (0, 0, 0, max_latent_len - latent.shape[-2]), value=0) 129 | for latent in latents 130 | ], 131 | dim=0, 132 | ) 133 | result["text_tokens"] = text_tokens 134 | result["text_token_lens"] = text_token_lens 135 | result["durations"] = durations 136 | result["duration_lens"] = duration_lens 137 | result["latents"] = latents 138 | result["latent_lens"] = latent_lens 139 | return result 140 | 141 | def train_dataloader(self): 142 | world_size = 1 if not torch.distributed.is_initialized() else None 143 | rank = 0 if not torch.distributed.is_initialized() else None 144 | sampler = DistributedBucketSampler( 145 | self.train_ds, 146 | batch_durations=self.batch_durations, 147 | boundaries=self.boundaries, 148 | shuffle=True, 149 | num_replicas=world_size, 150 | rank=rank, 151 | ) 152 | return DataLoader( 153 | self.train_ds, 154 | batch_sampler=sampler, 155 | num_workers=self.num_workers, 156 | collate_fn=self._collate, 157 | pin_memory=True, 158 | ) 159 | 160 | def val_dataloader(self): 161 | world_size = 1 if not torch.distributed.is_initialized() else None 162 | rank = 0 if not torch.distributed.is_initialized() else None 163 | sampler = DistributedBucketSampler( 164 | self.val_ds, 165 | batch_durations=self.batch_durations, 166 | boundaries=self.boundaries, 167 | num_replicas=world_size, 168 | rank=rank, 169 | drop_last=False, 170 | shuffle=True, 171 | ) 172 | return DataLoader( 173 | self.val_ds, 174 | batch_sampler=sampler, 175 | num_workers=self.num_workers, 176 | collate_fn=self._collate, 177 | shuffle=False, 178 | ) 179 | -------------------------------------------------------------------------------- /notebooks/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 
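    "# smoke-test notebook: instantiates the Transformer/SpeakerEncoder modules to check\n",
    "# parameter counts and output shapes, then runs one batch from the datamodule through PFlow\n",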
19 | "import torch\n", 20 | "import torch.nn as nn\n", 21 | "from tqdm.auto import tqdm\n", 22 | "\n", 23 | "from pflow_encodec.modules.spk_enc import SpeakerEncoder\n", 24 | "from pflow_encodec.modules.transformer import Transformer" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "text_encoder = Transformer(\n", 34 | " depth=6,\n", 35 | " dim=192,\n", 36 | " dim_head=96,\n", 37 | " heads=2,\n", 38 | " ff_mult=4.0,\n", 39 | " attn_dropout=0.1,\n", 40 | " ff_dropout=0.0,\n", 41 | " norm_type=\"ada_embed\",\n", 42 | " ff_type=\"conv\",\n", 43 | " ff_kernel_size=9,\n", 44 | " ff_groups=4,\n", 45 | " scale_type=\"ada_embed\",\n", 46 | " dim_cond=192,\n", 47 | ")\n", 48 | "cond_linear = nn.Linear(192, 192 * 6)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "x = torch.randn(1, 64, 192)\n", 58 | "cond = torch.randn(1, 1, 192)\n", 59 | "attn_norm_scale, attn_norm_bias, attn_scale, ff_norm_scale, ff_norm_bias, ff_scale = cond_linear(cond).chunk(6, dim=-1)\n", 60 | "cond_input = {\n", 61 | " \"attn_norm_cond\": torch.cat([attn_norm_scale, attn_norm_bias], dim=-1),\n", 62 | " \"attn_scale_cond\": attn_scale,\n", 63 | " \"ff_norm_cond\": torch.cat([ff_norm_scale, ff_norm_bias], dim=-1),\n", 64 | " \"ff_scale_cond\": ff_scale,\n", 65 | " \"final_norm_cond\": cond,\n", 66 | "}\n", 67 | "out = text_encoder(x, cond_input=cond_input)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "sum(p.numel() for p in text_encoder.parameters()) / 1e6" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "decoder = Transformer(\n", 86 | " depth=12,\n", 87 | " dim=512,\n", 88 | " dim_head=64,\n", 89 | " heads=8,\n", 90 | " ff_mult=4.0,\n", 91 | " attn_dropout=0.1,\n", 92 | " ff_dropout=0.0,\n", 93 | " norm_type=\"ada_embed\",\n", 94 | " ff_type=\"conv\",\n", 95 | " ff_kernel_size=3,\n", 96 | " ff_groups=4,\n", 97 | " scale_type=\"ada_embed\",\n", 98 | " dim_cond=512,\n", 99 | ")" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "sum(p.numel() for p in decoder.parameters()) / 1e6" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "spk_encoder = SpeakerEncoder(\n", 118 | " dim_input=128,\n", 119 | " depth=2,\n", 120 | " dim=192,\n", 121 | " dim_head=96,\n", 122 | " heads=2,\n", 123 | " ff_mult=4.0,\n", 124 | " attn_dropout=0.1,\n", 125 | " ff_dropout=0.0,\n", 126 | " norm_type=\"layer\",\n", 127 | " ff_type=\"conv\",\n", 128 | " ff_kernel_size=9,\n", 129 | " ff_groups=4,\n", 130 | " scale_type=\"none\",\n", 131 | ")" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "sum(p.numel() for p in spk_encoder.parameters()) / 1e6" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "prompt = torch.randn(1, 225, 128)\n", 150 | "spk_encoder(prompt).shape" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 
158 | "source": [ 159 | "from pflow_encodec.data.datamodule import TextLatentLightningDataModule\n", 160 | "\n", 161 | "dm = TextLatentLightningDataModule(\n", 162 | " train_tsv_path=\"/home/seastar105/datasets/libritts_r/train_duration.tsv\",\n", 163 | " val_tsv_path=\"/home/seastar105/datasets/libritts_r/dev_duration.tsv\",\n", 164 | " num_workers=8,\n", 165 | " return_upsampled=False,\n", 166 | ")\n", 167 | "dm.setup(\"fit\")\n", 168 | "dl = dm.train_dataloader()" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "from pflow_encodec.models.pflow import PFlow\n", 178 | "\n", 179 | "model = PFlow()" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "batch = next(iter(dl))" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "text_tokens, text_token_lens, durations, duration_lens, latents, latent_lens = batch" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "import torch\n", 207 | "\n", 208 | "\n", 209 | "def slice_segments(x, ids_str, segment_size=4):\n", 210 | " ret = torch.zeros_like(x[:, :segment_size, :])\n", 211 | " for i in range(x.size(0)):\n", 212 | " idx_str = ids_str[i]\n", 213 | " idx_end = idx_str + segment_size\n", 214 | " ret[i] = x[i, idx_str:idx_end, :]\n", 215 | " return ret\n", 216 | "\n", 217 | "\n", 218 | "def rand_slice_segments(x, x_lengths=None, segment_size=4):\n", 219 | " b, t, d = x.size()\n", 220 | " if x_lengths is None:\n", 221 | " x_lengths = t\n", 222 | " ids_str_max = x_lengths - segment_size + 1\n", 223 | " ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)\n", 224 | " ids_str = torch.max(torch.zeros(ids_str.size()).to(ids_str.device), ids_str).to(dtype=torch.long)\n", 225 | " ret = slice_segments(x, ids_str, segment_size)\n", 226 | " mask = torch.arange(t, device=x.device).expand(b, t) >= ids_str.unsqueeze(1)\n", 227 | " mask &= torch.arange(t, device=x.device).expand(b, t) < (ids_str + segment_size).unsqueeze(1)\n", 228 | " return ret, mask" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "prompts, prompt_masks = rand_slice_segments(latents, latent_lens, segment_size=225)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "prompts.shape" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "model(text_tokens, text_token_lens, durations, duration_lens, latents, latent_lens, prompts, prompt_masks)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [] 264 | } 265 | ], 266 | "metadata": { 267 | "kernelspec": { 268 | "display_name": "pflow-encodec", 269 | "language": "python", 270 | "name": "python3" 271 | }, 272 | "language_info": { 273 | "codemirror_mode": { 274 | "name": "ipython", 275 | "version": 3 276 | }, 277 | "file_extension": ".py", 278 | "mimetype": "text/x-python", 279 | "name": "python", 280 | "nbconvert_exporter": "python", 281 | 
"pygments_lexer": "ipython3", 282 | "version": "3.10.13" 283 | } 284 | }, 285 | "nbformat": 4, 286 | "nbformat_minor": 2 287 | } 288 | -------------------------------------------------------------------------------- /pflow_encodec/models/lightning_modules/pflow.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Mapping, Optional 2 | 3 | import hydra 4 | import lightning as L 5 | import torch 6 | from audiotools import AudioSignal 7 | from lightning.pytorch.utilities import grad_norm 8 | from omegaconf import DictConfig 9 | 10 | from pflow_encodec.data.tokenizer import EncodecTokenizer 11 | from pflow_encodec.models.pflow import PFlow 12 | from pflow_encodec.utils.pylogger import RankedLogger 13 | 14 | logger = RankedLogger(__name__) 15 | 16 | 17 | class PFlowLightningModule(L.LightningModule): 18 | def __init__( 19 | self, 20 | net: DictConfig, 21 | optimizer: DictConfig, 22 | scheduler: DictConfig, 23 | prompt_length: int = 225, 24 | sample_freq: int = 10000, 25 | sample_idx: List[int] = [], 26 | mean: float = 0.0, 27 | std: float = 1.0, 28 | text2latent_ratio: float = 1.5, 29 | net_ckpt_path: Optional[str] = None, 30 | languages: Optional[List[str]] = None, 31 | max_lang_loss: Optional[float] = None, 32 | ): 33 | super().__init__() 34 | self.save_hyperparameters(logger=False) 35 | 36 | self.save_hyperparameters() 37 | 38 | self.net: PFlow = hydra.utils.instantiate(net, _recursive_=False) 39 | 40 | self.optimizer_cfg = optimizer 41 | self.scheduler_cfg = scheduler 42 | 43 | self.prompt_length = prompt_length 44 | 45 | self.sample_freq = sample_freq 46 | self.sample_idx = sample_idx 47 | 48 | self.first_sample = True 49 | self.codec = [EncodecTokenizer(device="cpu")] # avoid move codec to gpu for memory reduction 50 | 51 | self.mean = mean 52 | self.std = std 53 | self.text2latent_ratio = text2latent_ratio 54 | 55 | if languages is not None: 56 | self.languages = languages 57 | self.lang2idx = {lang: idx for idx, lang in enumerate(languages)} 58 | self.max_lang_loss = max_lang_loss 59 | 60 | if net_ckpt_path is not None: 61 | logger.info(f"Loading model from {net_ckpt_path}") 62 | missing, unexpected = self.net.load_state_dict( 63 | torch.load(net_ckpt_path, map_location="cpu")["state_dict"], strict=False 64 | ) 65 | if missing: 66 | logger.warning(f"Missing keys: {missing}") 67 | 68 | if unexpected: 69 | logger.warning(f"Unexpected keys: {unexpected}") 70 | 71 | def configure_optimizers(self): 72 | optimizer = hydra.utils.instantiate(self.optimizer_cfg, params=self.net.parameters()) 73 | if self.scheduler_cfg is None: 74 | return [optimizer], [] 75 | scheduler = hydra.utils.instantiate(self.scheduler_cfg, optimizer=optimizer) 76 | return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}] 77 | 78 | def on_before_optimizer_step(self, optimizer): 79 | self.log( 80 | "other/grad_norm", 81 | grad_norm(self.net, norm_type=2)["grad_2.0_norm_total"], 82 | on_step=True, 83 | on_epoch=False, 84 | prog_bar=True, 85 | logger=True, 86 | ) 87 | 88 | def get_prompt(self, latents: torch.Tensor, latent_lens: torch.Tensor): 89 | b, t, d = latents.shape 90 | max_start = latent_lens - self.prompt_length 91 | start_idx = (torch.rand((b,), device=latents.device) * max_start).long().clamp(min=0) 92 | prompts = torch.zeros((b, self.prompt_length, d), device=latents.device, dtype=latents.dtype) 93 | for i in range(latents.shape[0]): 94 | prompts[i] = latents[i, start_idx[i] : start_idx[i] + self.prompt_length] 95 | 96 | 
max_len = latent_lens.max() 97 | prompt_mask = torch.arange(max_len, device=latent_lens.device).expand(latent_lens.shape[0], -1) < ( 98 | start_idx.unsqueeze(1) + self.prompt_length 99 | ) 100 | prompt_mask &= torch.arange(max_len, device=latent_lens.device).expand( 101 | latent_lens.shape[0], -1 102 | ) >= start_idx.unsqueeze(1) 103 | return prompts, prompt_mask 104 | 105 | def get_input(self, batch): 106 | text_tokens = batch["text_tokens"] 107 | text_token_lens = batch["text_token_lens"] 108 | durations = batch["durations"] 109 | duration_lens = batch["duration_lens"] 110 | latents = batch["latents"] 111 | latent_lens = batch["latent_lens"] 112 | prompts, prompt_masks = self.get_prompt(latents, latent_lens) 113 | lang_ids = None 114 | if "lang_ids" in batch: 115 | lang_ids = batch["lang_ids"] 116 | return ( 117 | text_tokens, 118 | text_token_lens, 119 | durations, 120 | duration_lens, 121 | latents, 122 | latent_lens, 123 | prompts, 124 | prompt_masks, 125 | lang_ids, 126 | ) 127 | 128 | def training_step(self, batch, batch_idx): 129 | ( 130 | text_tokens, 131 | text_token_lens, 132 | durations, 133 | duration_lens, 134 | latents, 135 | latent_lens, 136 | prompts, 137 | prompt_masks, 138 | lang_ids, 139 | ) = self.get_input(batch) 140 | duration_loss, enc_loss, flow_matching_loss, lang_loss = self.net( 141 | text_tokens, 142 | text_token_lens, 143 | durations, 144 | duration_lens, 145 | latents, 146 | latent_lens, 147 | prompts, 148 | prompt_masks, 149 | lang_ids=lang_ids, 150 | ) 151 | 152 | self.log("train/enc_loss", enc_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 153 | self.log("train/duration_loss", duration_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 154 | self.log( 155 | "train/flow_matching_loss", flow_matching_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True 156 | ) 157 | self.log( 158 | "train/latent_loss", enc_loss + flow_matching_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True 159 | ) 160 | loss = enc_loss + duration_loss + flow_matching_loss 161 | self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 162 | 163 | if lang_loss is not None: 164 | # just for check 165 | self.log("train/lang_loss", lang_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 166 | 167 | if self.global_step % self.sample_freq == 0: 168 | self.log_audio() 169 | 170 | return loss 171 | 172 | def validation_step(self, batch, batch_idx): 173 | ( 174 | text_tokens, 175 | text_token_lens, 176 | durations, 177 | duration_lens, 178 | latents, 179 | latent_lens, 180 | prompts, 181 | prompt_masks, 182 | lang_ids, 183 | ) = self.get_input(batch) 184 | duration_loss, enc_loss, flow_matching_loss, lang_loss = self.net( 185 | text_tokens, 186 | text_token_lens, 187 | durations, 188 | duration_lens, 189 | latents, 190 | latent_lens, 191 | prompts, 192 | prompt_masks, 193 | lang_ids=lang_ids, 194 | ) 195 | 196 | self.log("val/enc_loss", enc_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) 197 | self.log("val/duration_loss", duration_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) 198 | self.log("val/flow_matching_loss", flow_matching_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) 199 | self.log( 200 | "val/latent_loss", enc_loss + flow_matching_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True 201 | ) 202 | loss = enc_loss + duration_loss + flow_matching_loss 203 | self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) 
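# lang_loss (if any) is logged for monitoring only and is not added to the returned loss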
204 | 205 | if lang_loss is not None: 206 | # just for check 207 | self.log("train/lang_loss", lang_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 208 | 209 | return loss 210 | 211 | @torch.inference_mode() 212 | def log_audio(self): 213 | if self.global_rank != 0: 214 | return 215 | 216 | self.net.eval() 217 | codec = self.codec[0] 218 | writer = self.logger.experiment 219 | 220 | def write_to_tb(codec_latent: torch.Tensor, name: str): 221 | # denormalize 222 | codec_latent = codec_latent * self.std + self.mean 223 | with torch.amp.autocast(device_type="cuda", enabled=False): 224 | recon = codec.decode_latents(codec_latent.to(device=codec.device, dtype=codec.dtype)) 225 | signal = AudioSignal(recon, sample_rate=codec.sample_rate).float().ensure_max_of_audio() 226 | signal.write_audio_to_tb(name, writer, self.global_step) 227 | 228 | if self.first_sample: 229 | self.first_sample = False 230 | for idx, sample_idx in enumerate(self.sample_idx): 231 | sample = self.trainer.datamodule.val_ds[sample_idx] 232 | if len(sample) == 3: 233 | _, _, latent = sample 234 | else: 235 | _, _, latent, _ = sample 236 | write_to_tb(latent, f"recon/sample_{idx}.wav") 237 | 238 | # sample with gt duration 239 | for idx, sample_idx in enumerate(self.sample_idx): 240 | sample = self.trainer.datamodule.val_ds[sample_idx] 241 | if len(sample) == 3: 242 | text_token, duration, latent = sample 243 | lang_id = None 244 | else: 245 | text_token, duration, latent, lang = sample 246 | lang_id = lang.unsqueeze(0).to(self.device) 247 | start_idx = torch.randint(0, latent.shape[-2] - self.prompt_length, (1,)) 248 | prompt = latent[:, start_idx : start_idx + self.prompt_length] 249 | sampled = self.net.generate( 250 | text_token.to(self.device), 251 | prompt.to(device=self.device, dtype=self.dtype), 252 | duration.to(self.device), 253 | upscale_ratio=self.text2latent_ratio, 254 | lang_ids=lang_id, 255 | ) 256 | write_to_tb(sampled, f"sampled/gt_dur_{idx}.wav") 257 | 258 | # sample with pred duration 259 | for idx, sample_idx in enumerate(self.sample_idx): 260 | sample = self.trainer.datamodule.val_ds[sample_idx] 261 | if len(sample) == 3: 262 | text_token, duration, latent = sample 263 | lang_id = None 264 | else: 265 | text_token, duration, latent, lang = sample 266 | lang_id = lang.unsqueeze(0).to(self.device) 267 | start_idx = torch.randint(0, latent.shape[-2] - self.prompt_length, (1,)) 268 | prompt = latent[:, start_idx : start_idx + self.prompt_length] 269 | sampled = self.net.generate( 270 | text_token.to(self.device), 271 | prompt.to(device=self.device, dtype=self.dtype), 272 | upscale_ratio=self.text2latent_ratio, 273 | lang_ids=lang_id, 274 | ) 275 | write_to_tb(sampled, f"sampled/pred_dur_{idx}.wav") 276 | 277 | self.net.train() 278 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PFlow Encodec 2 | 3 | Implementation of TTS based on paper [P-Flow: A Fast and Data-Efficient Zero-Shot TTS through Speech Prompting](https://openreview.net/pdf?id=zNA7u7wtIN). You can check main differences between implementation and paper in [Differences](#difference-from-paper) section. 4 | 5 | # Main goal of this project 6 | 7 | I have two goals to achieve in this project. It seems work but, really poor at Japanese and numbers. 
8 | 9 | - First, I want to test character-based input with [SeamlessM4T](https://arxiv.org/abs/2308.11596)'s [Aligner](https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/unity2_aligner_README.md) for English, Korean, Japanese and other languages, but mainly for the three languages mentioned above. 10 | - Second, a zero-shot multilingual TTS model. Since this model will be trained with SentencePiece tokenizer input, it does not need a phonemizer, so it can easily be adapted to other languages the tokenizer supports. Check out the tokenizer's supported languages [here](https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/nar_t2u_aligner.yaml) 11 | 12 | # Samples 13 | 14 | Samples generated from a model trained on LibriTTS-R and the Korean and Japanese corpora of the AIHub 131 datasets. All samples are decoded with the MultiBand-Diffusion model from [AudioCraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MBD.md). 15 | The pretrained checkpoint used here is available on [huggingface](https://huggingface.co/seastar105/pflow-encodec-ejk). 16 | 17 | You can check how to use it in the [sample notebook](https://github.com/seastar105/pflow-encodec/blob/main/notebooks/generate.ipynb). 18 | 19 | Currently, the speaker embedding of the multilingual model seems to be highly entangled with language information, which hurts its zero-shot capability. I'm planning to train a new model with language IDs to reduce language bias in the speaker embedding. 20 | 21 | Code-switch Text: There's famous japanese sentence, つきがきれいですね, which means 나는 당신을 사랑합니다. 22 | 23 | English Prompt Generation 24 | 25 | https://github.com/seastar105/pflow-encodec/assets/30820469/57a0450b-e1b2-48b6-b0ec-9433806edb10 26 | 27 | Japanese Prompt Generation 28 | 29 | https://github.com/seastar105/pflow-encodec/assets/30820469/bf5e4c29-2545-411a-adbc-b461a5c2cefa 30 | 31 | Korean Prompt Generation 32 | 33 | https://github.com/seastar105/pflow-encodec/assets/30820469/74f2ff7a-554d-4797-9841-a8b7b74d9fbf 34 | 35 | English Text: P-Flow encodec is Text-to-Speech model trained on Encodec latent space, using Flow Matching. 36 | 37 | Prompt Audio (from [LibriTTS-R](https://www.openslr.org/141/)) 38 | 39 | https://github.com/seastar105/pflow-encodec/assets/30820469/a3c1b3d8-ea94-4cb7-bd21-7226e3fd54b1 40 | 41 | Generated Audio 42 | 43 | https://github.com/seastar105/pflow-encodec/assets/30820469/1de00f81-4c87-402e-a4bc-66deb29c194d 44 | 45 | Japanese Text: こんにちは、初めまして。あなたの名前はなんですか?これは音声合成モデルから作られた音声です。 46 | 47 | Prompt Audio (from [JSUT](https://sites.google.com/site/shinnosuketakamichi/publication/jsut)) 48 | 49 | https://github.com/seastar105/pflow-encodec/assets/30820469/fb4f1a10-fb8b-413e-8bec-d1d0f58d8423 50 | 51 | Generated Audio 52 | 53 | https://github.com/seastar105/pflow-encodec/assets/30820469/137d4e34-f674-4681-a652-93c4a44f4554 54 | 55 | Korean Text: 백남준은 미디어 아트의 개척자로서 다양한 테크놀로지를 이용하여 실험적이고 창의적으로 작업했다. 56 | 57 | Prompt Audio (from [KSS](https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset)) 58 | 59 | https://github.com/seastar105/pflow-encodec/assets/30820469/db3435d0-8e8f-45ef-b3b3-a164ad316d71 60 | 61 | Generated Audio 62 | 63 | https://github.com/seastar105/pflow-encodec/assets/30820469/8dff38ec-a2d7-49a6-80fb-de6012b33a1b 64 | 65 | # Environment Setup 66 | 67 | I developed this in WSL on Windows 11 and have not tested other platforms or torch versions. I recommend using a conda environment.
68 | 69 | ```bash 70 | conda create -n pflow-encodec -y python=3.10 71 | conda activate pflow-encodec 72 | conda install -y pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=12.1 -c pytorch -c nvidia 73 | conda install -y -c conda-forge libsndfile==1.0.31 74 | pip install -r requirements.txt 75 | pip install -r infer-requirements.txt 76 | ``` 77 | 78 | # Dataset Preparation 79 | 80 | ## meta tsv file 81 | 82 | First of all, you need to prepare a tsv file, which contains three columns: `audio_path`, `text`, and `duration`. Each column is separated by a tab. 83 | 84 | `audio_path` is the path to the audio file, `text` is the transcript of the audio file, and `duration` is the duration of the audio file in seconds. 85 | 86 | ### Example 87 | 88 | ```tsv 89 | audio_path text duration 90 | /path/to/audio1.wav Hello, World! 1.5 91 | /path/to/audio2.wav 안녕하세요, 세계! 2.0 92 | /path/to/audio3.wav こんにちは、世界! 2.5 93 | ``` 94 | 95 | ## Dump encodec latent and sentencepiece token durations 96 | 97 | Here, the Encodec latent is used as the model output, and the duration per token as the target of the duration predictor. 98 | 99 | You can dump Encodec latents and SentencePiece token durations with the following commands. 100 | 101 | ```bash 102 | python scripts/dump_durations.py --input_tsv <path_to_meta_tsv> 103 | python scripts/dump_latents.py --input_tsv <path_to_meta_tsv> 104 | ``` 105 | 106 | These commands require a GPU, and `scripts/dump_durations.py` may require more than 8GB of GPU memory. 107 | 108 | `scripts/dump_durations.py` takes about 6 hours for 1000 hours of audio files. `scripts/dump_latents.py` takes about 4 hours for 1000 hours of audio files. Both times were measured on an RTX 4090. 109 | 110 | Each script will create two files per audio file: 111 | `.latent.npy` and `.duration.npy`. 112 | 113 | **NOTE: `scripts/dump_latents.py` will print out the global mean and std of the dataset's latents. You should keep these values since they are used for training the model.** 114 | 115 | Now, you can start training. 116 | 117 | # Train Model 118 | 119 | The repository's code is based on [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template). 120 | 121 | After preparing the dataset, you can start training once you have set up a dataset config and an experiment config. Let your dataset name be `new_dataset`. First, set up the dataset config in `configs/data/new_dataset.yaml`. 122 | 123 | ```yaml 124 | _target_: pflow_encodec.data.datamodule.TextLatentLightningDataModule 125 | 126 | train_tsv_path: 127 | val_tsv_path: 128 | add_trailing_silence: True 129 | batch_durations: 50.0 # mini-batch duration in seconds 130 | min_duration: 3.5 # minimum duration of files, this value MUST be bigger than 3.0 131 | max_duration: 15.0 132 | boundaries: [3.0, 5.0, 7.0, 10.0, 15.0] 133 | num_workers: 8 134 | return_upsampled: False 135 | max_frame: 1500 # 20s 136 | text2latent_rate: 1.5 # 50Hz:75Hz 137 | mean: 138 | std: 139 | ``` 140 | 141 | Fill `train_tsv_path`, `val_tsv_path`, `mean`, and `std` with your dataset's meta paths and mean/std values. 142 | 143 | Then, create a config in `configs/experiment/new_dataset.yaml` based on `configs/experiment/default.yaml`. 144 | 145 | ```yaml 146 | # @package _global_ 147 | 148 | defaults: 149 | - override /data: new_dataset.yaml # your dataset config name here!!!
150 | - override /model: pflow_base.yaml 151 | - override /callbacks: default.yaml 152 | - override /trainer: gpu.yaml 153 | - override /logger: tensorboard.yaml 154 | 155 | task_name: pflow 156 | tags: ["pflow"] 157 | seed: 998244353 158 | test: False 159 | 160 | callbacks: 161 | val_checkpoint: 162 | filename: "val_latent_loss_{val/latent_loss:.4f}-{step:06d}" 163 | monitor: val/latent_loss 164 | mode: "min" 165 | model: 166 | scheduler: 167 | total_steps: ${trainer.max_steps} 168 | pct_start: 0.02 169 | sample_freq: 5000 170 | sample_idx: [] # sample indices used for sampling while train. idx will be used to choose samples from validation dataset. so this value should not be greater than len(val_dataset) 171 | mean: ${data.mean} 172 | std: ${data.std} 173 | trainer: 174 | max_steps: 500000 175 | max_epochs: 10000 # arbitrary large number 176 | precision: bf16-mixed # you should check if your GPU supports bf16 177 | accumulate_grad_batches: 4 # effective batch size 178 | gradient_clip_val: 0.2 179 | num_nodes: 1 180 | devices: 1 181 | hydra: 182 | run: 183 | dir: 184 | ``` 185 | 186 | now you can run training with following command. 187 | 188 | ```bash 189 | python pflow_encodec/train.py experiment=new_dataset 190 | ``` 191 | 192 | **NOTE: If you want to train model with multiple GPUs, you should adjust trainer.num_nodes and trainer.devices in experiment config. Also you should set trainer.use_distributed_sampler to be False. For more detailed information, check out Pytorch Lightning's documents.** 193 | 194 | Example of single node 4 gpus 195 | 196 | ``` 197 | trainer: 198 | num_nodes: 1 199 | devices: 4 200 | use_distributed_sampler: False 201 | ``` 202 | 203 | # Pre-trained models 204 | 205 | | Language | Weights | Model Card | 206 | | ----------------- | ----------------------------------------------------------------------------- | --------------------------------------------------------------------------- | 207 | | MultiLingual(EJK) | [🤗 Hub](https://huggingface.co/seastar105/pflow-encodec-ejk) | [Link](https://github.com/seastar105/pflow-encodec/blob/main/MODEL_CARD.md) | 208 | | English | [🤗 Hub](https://huggingface.co/seastar105/pflow-encodec-libritts) | | 209 | | Japanese | [🤗 Hub](https://huggingface.co/seastar105/pflow-encodec-aihub-libri-japanese) | | 210 | | Korean | [🤗 Hub](https://huggingface.co/seastar105/pflow-encodec-aihub-libri-korean) | | 211 | 212 | # TODO 213 | 214 | - [x] Implement baseline model. 215 | - [x] Train model on libritts-r. 216 | - [ ] Simple gradio demo. 217 | - [x] Dataset preparation documentation. 218 | - [x] Train model on another language, i'm planning to train on Korean and Japanese. 219 | - [x] Multilingual model. 220 | - [ ] Test Language ID embedding in Text Encoder for Multilingual Model 221 | - [ ] Train small bert with SeamlessM4T's tokenizer then apply it to Text Encoder. 222 | 223 | # Difference from paper 224 | 225 | I did not conduct ablation studies for each changes due to lack of resources. 226 | 227 | - Use [Encodec](https://github.com/facebookresearch/audiocraft/blob/main/docs/ENCODEC.md) instead of MelSpectrogram. 228 | - Use character-base input instead of phoneme, and GT duration as a target of duration predictor instead of MAS. 229 | - Use AdaLN-Zero from [DiT](https://arxiv.org/abs/2212.09748) for speaker-conditioned text encoder instead of concat and self-attention. 
230 | - Use transformer as Flow Matching decoder instead of Wavenet blocks with AdaLN-Single timestep conditioning from [PixArt-α](https://arxiv.org/abs/2310.00426) 231 | - Use attention pooling instead of mean pooling to get fixed-size speaker embedding as P-Flow used in their ablation study. 232 | - Use conv-feedforward(FFT Block from Fastspeech) and GeGLU 233 | - Use Alibi + Convolution positional encoding in transformer, from data2vec 2.0 and voicebox 234 | - Use null cond for CFG sampling instead of mean-pooled hidden vectors. 235 | 236 | # Credits 237 | 238 | - I borrowed some code from [VITS repo](https://github.com/jaywalnut310/vits), [voicebox-pytorch](https://github.com/lucidrains/voicebox-pytorch), and [fairseq2](https://github.com/facebookresearch/fairseq2). 239 | - This research used datasets from 'The Open AI Dataset Project (AI-Hub, S. Korea)'. All data information can be accessed through 'AI-Hub (www.aihub.or.kr) 240 | -------------------------------------------------------------------------------- /notebooks/generate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import hydra\n", 10 | "import torch\n", 11 | "from audiocraft.models import MultiBandDiffusion\n", 12 | "from audiotools import AudioSignal\n", 13 | "from huggingface_hub import hf_hub_download\n", 14 | "\n", 15 | "from pflow_encodec.data.tokenizer import EncodecTokenizer, TextTokenizer" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def load_model(ckpt_path, device=\"cpu\"):\n", 25 | " ckpt = torch.load(ckpt_path, map_location=\"cpu\")\n", 26 | "\n", 27 | " model = hydra.utils.instantiate(ckpt[\"model_config\"])\n", 28 | " model.load_state_dict(ckpt[\"state_dict\"])\n", 29 | " model = model.eval().to(device)\n", 30 | "\n", 31 | " return model, ckpt[\"data_config\"]" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "ckpt_path = hf_hub_download(repo_id=\"seastar105/pflow-encodec-ejk\", filename=\"multilingual_base_bs100x4.ckpt\")" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "model, data_config = load_model(ckpt_path, \"cuda\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "e_prompt = hf_hub_download(repo_id=\"seastar105/pflow-encodec-ejk\", filename=\"samples/libritts_r_prompt.wav\")\n", 59 | "j_prompt = hf_hub_download(repo_id=\"seastar105/pflow-encodec-ejk\", filename=\"samples/jsut_prompt.wav\")\n", 60 | "k_prompt = hf_hub_download(repo_id=\"seastar105/pflow-encodec-ejk\", filename=\"samples/kss_prompt.wav\")" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "text_tokenizer = TextTokenizer()" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "encodec_tokenizer = EncodecTokenizer()" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "mbd_model = MultiBandDiffusion.get_mbd_24khz(bw=6)" 88 | ] 89 
| }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "@torch.inference_mode()\n", 97 | "def pflow_inference(\n", 98 | " model, text, prompt_path, data_config, cfg_scale=1.0, n_steps=16, ode_method=\"midpoint\", return_latent=False\n", 99 | "):\n", 100 | " device = next(model.parameters()).device\n", 101 | " prompt = encodec_tokenizer.encode_file(prompt_path).to(device)\n", 102 | " mean = data_config[\"mean\"]\n", 103 | " std = data_config[\"std\"]\n", 104 | " upscale_ratio = data_config[\"text2latent_ratio\"]\n", 105 | "\n", 106 | " text_token = text_tokenizer.encode_text(text).to(device).unsqueeze(0)\n", 107 | " prompt = (prompt - mean) / std\n", 108 | " result = model.generate(\n", 109 | " text_token, prompt, cfg_scale=cfg_scale, n_steps=n_steps, ode_method=ode_method, upscale_ratio=upscale_ratio\n", 110 | " )\n", 111 | " result = result * std + mean\n", 112 | " if return_latent:\n", 113 | " return result.cpu()\n", 114 | " recon = encodec_tokenizer.decode_latents(result.to(device=encodec_tokenizer.device, dtype=encodec_tokenizer.dtype))\n", 115 | " return recon.cpu()" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "@torch.inference_mode()\n", 125 | "def mbd_decode(mbd_model, latent):\n", 126 | " codes = encodec_tokenizer.quantize_latents(latent.to(device=encodec_tokenizer.device))\n", 127 | " recon = mbd_model.tokens_to_wav(codes[:, :8, :])\n", 128 | " return recon.cpu()" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "e_text = \"P-Flow encodec is Text-to-Speech model trained on Encodec latent space, using Flow Matching.\"" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "latents = pflow_inference(\n", 147 | " model, e_text, e_prompt, data_config, cfg_scale=1.2, n_steps=16, ode_method=\"midpoint\", return_latent=True\n", 148 | ")\n", 149 | "pflow_result = (\n", 150 | " encodec_tokenizer.decode_latents(latents.to(device=encodec_tokenizer.device, dtype=encodec_tokenizer.dtype))\n", 151 | " .detach()\n", 152 | " .cpu()\n", 153 | ")\n", 154 | "pflow_signal = AudioSignal(pflow_result, 24000).normalize(-23).ensure_max_of_audio()\n", 155 | "pflow_signal.embed()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "mbd_recon = mbd_decode(mbd_model, latents)\n", 165 | "mbd_signal = AudioSignal(mbd_recon, 24000).normalize(-23).ensure_max_of_audio()\n", 166 | "mbd_signal.embed()" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "j_text = \"こんにちは、初めまして。あなたの名前はなんですか?これは音声合成モデルから作られた音声です。\"" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "latents = pflow_inference(\n", 185 | " model, j_text, j_prompt, data_config, cfg_scale=1.2, n_steps=16, ode_method=\"midpoint\", return_latent=True\n", 186 | ")\n", 187 | "pflow_result = (\n", 188 | " encodec_tokenizer.decode_latents(latents.to(device=encodec_tokenizer.device, dtype=encodec_tokenizer.dtype))\n", 189 | " .detach()\n", 190 | " .cpu()\n", 191 | ")\n", 192 | "pflow_signal = 
AudioSignal(pflow_result, 24000).normalize(-23).ensure_max_of_audio()\n", 193 | "pflow_signal.embed()" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "mbd_recon = mbd_decode(mbd_model, latents)\n", 203 | "mbd_signal = AudioSignal(mbd_recon, 24000).normalize(-23).ensure_max_of_audio()\n", 204 | "mbd_signal.embed()" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "k_text = \"백남준은 미디어 아트의 개척자로서 다양한 테크놀로지를 이용하여 실험적이고 창의적으로 작업했다.\"" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "latents = pflow_inference(\n", 223 | " model, k_text, k_prompt, data_config, cfg_scale=1.2, n_steps=16, ode_method=\"midpoint\", return_latent=True\n", 224 | ")\n", 225 | "pflow_result = (\n", 226 | " encodec_tokenizer.decode_latents(latents.to(device=encodec_tokenizer.device, dtype=encodec_tokenizer.dtype))\n", 227 | " .detach()\n", 228 | " .cpu()\n", 229 | ")\n", 230 | "pflow_signal = AudioSignal(pflow_result, 24000).normalize(-23).ensure_max_of_audio()\n", 231 | "pflow_signal.embed()" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "mbd_recon = mbd_decode(mbd_model, latents)\n", 241 | "mbd_signal = AudioSignal(mbd_recon, 24000).normalize(-23).ensure_max_of_audio()\n", 242 | "mbd_signal.embed()" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "code_text = \"There's famous japanese sentence, つきがきれいですね, which means 나는 당신을 사랑합니다.\"" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "latents = pflow_inference(\n", 261 | " model, code_text, e_prompt, data_config, cfg_scale=1.2, n_steps=16, ode_method=\"midpoint\", return_latent=True\n", 262 | ")\n", 263 | "pflow_result = (\n", 264 | " encodec_tokenizer.decode_latents(latents.to(device=encodec_tokenizer.device, dtype=encodec_tokenizer.dtype))\n", 265 | " .detach()\n", 266 | " .cpu()\n", 267 | ")\n", 268 | "pflow_signal = AudioSignal(pflow_result, 24000).normalize(-23).ensure_max_of_audio()\n", 269 | "pflow_signal.embed()" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "mbd_recon = mbd_decode(mbd_model, latents)\n", 279 | "mbd_signal = AudioSignal(mbd_recon, 24000).normalize(-23).ensure_max_of_audio()\n", 280 | "mbd_signal.embed()" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "latents = pflow_inference(\n", 290 | " model, code_text, j_prompt, data_config, cfg_scale=1.2, n_steps=16, ode_method=\"midpoint\", return_latent=True\n", 291 | ")\n", 292 | "pflow_result = (\n", 293 | " encodec_tokenizer.decode_latents(latents.to(device=encodec_tokenizer.device, dtype=encodec_tokenizer.dtype))\n", 294 | " .detach()\n", 295 | " .cpu()\n", 296 | ")\n", 297 | "pflow_signal = AudioSignal(pflow_result, 24000).normalize(-23).ensure_max_of_audio()\n", 298 | "pflow_signal.embed()" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 
| "outputs": [], 306 | "source": [ 307 | "mbd_recon = mbd_decode(mbd_model, latents)\n", 308 | "mbd_signal = AudioSignal(mbd_recon, 24000).normalize(-23).ensure_max_of_audio()\n", 309 | "mbd_signal.embed()" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "latents = pflow_inference(\n", 319 | " model, code_text, k_prompt, data_config, cfg_scale=1.2, n_steps=16, ode_method=\"midpoint\", return_latent=True\n", 320 | ")\n", 321 | "pflow_result = (\n", 322 | " encodec_tokenizer.decode_latents(latents.to(device=encodec_tokenizer.device, dtype=encodec_tokenizer.dtype))\n", 323 | " .detach()\n", 324 | " .cpu()\n", 325 | ")\n", 326 | "pflow_signal = AudioSignal(pflow_result, 24000).normalize(-23).ensure_max_of_audio()\n", 327 | "pflow_signal.embed()" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "mbd_recon = mbd_decode(mbd_model, latents)\n", 337 | "mbd_signal = AudioSignal(mbd_recon, 24000).normalize(-23).ensure_max_of_audio()\n", 338 | "mbd_signal.embed()" 339 | ] 340 | } 341 | ], 342 | "metadata": { 343 | "kernelspec": { 344 | "display_name": "pflow-encodec", 345 | "language": "python", 346 | "name": "python3" 347 | }, 348 | "language_info": { 349 | "codemirror_mode": { 350 | "name": "ipython", 351 | "version": 3 352 | }, 353 | "file_extension": ".py", 354 | "mimetype": "text/x-python", 355 | "name": "python", 356 | "nbconvert_exporter": "python", 357 | "pygments_lexer": "ipython3", 358 | "version": "3.10.13" 359 | } 360 | }, 361 | "nbformat": 4, 362 | "nbformat_minor": 2 363 | } 364 | -------------------------------------------------------------------------------- /pflow_encodec/models/pflow.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchdiffeq 7 | from einops import rearrange 8 | 9 | from pflow_encodec.modules import ( 10 | DurationPredictor, 11 | FlowMatchingTransformer, 12 | GradientReversal, 13 | SpeakerEncoder, 14 | TextEncoder, 15 | ) 16 | 17 | 18 | class PFlow(nn.Module): 19 | def __init__( 20 | self, 21 | feature_dim: int = 128, 22 | text_encoder_vocab_size: int = 10094, 23 | text_encoder_embed_dim: int = 192, 24 | text_encoder_conv_pos_depth: int = 2, 25 | text_encoder_conv_pos_kernel_size: int = 15, 26 | text_encoder_conv_pos_groups: int = 16, 27 | text_encoder_depth: int = 6, 28 | text_encoder_dim: int = 192, 29 | text_encoder_dim_head: int = 96, 30 | text_encoder_heads: int = 2, 31 | text_encoder_ff_mult: float = 4.0, 32 | text_encoder_attn_dropout: float = 0.1, 33 | text_encoder_ff_dropout: float = 0.0, 34 | text_encoder_attn_processor: str = "naive", 35 | text_encoder_norm_type: str = "ada_proj", 36 | text_encoder_ff_type: str = "conv", 37 | text_encoder_ff_kernel_size: int = 3, 38 | text_encoder_ff_groups: int = 1, 39 | text_encoder_scale_type: str = "ada_proj", 40 | speaker_encoder_dim_input: int = 128, 41 | speaker_encoder_conv_pos_depth: int = 2, 42 | speaker_encoder_conv_pos_kernel_size: int = 15, 43 | speaker_encoder_conv_pos_groups: int = 16, 44 | speaker_encoder_depth: int = 2, 45 | speaker_encoder_dim: int = 192, 46 | speaker_encoder_dim_head: int = 96, 47 | speaker_encoder_heads: int = 2, 48 | speaker_encoder_ff_mult: float = 4.0, 49 | speaker_encoder_attn_dropout: float = 0.1, 50 | 
speaker_encoder_ff_dropout: float = 0.0, 51 | speaker_encoder_attn_processor: str = "naive", 52 | speaker_encoder_norm_type: str = "layer", 53 | speaker_encoder_ff_type: str = "conv", 54 | speaker_encoder_ff_kernel_size: int = 3, 55 | speaker_encoder_ff_groups: int = 1, 56 | speaker_encoder_scale_type: str = "none", 57 | flow_matching_dim_time: int = 2048, 58 | flow_matching_conv_pos_kernel_size: int = 31, 59 | flow_matching_conv_pos_depth: int = 2, 60 | flow_matching_conv_pos_groups: int = 16, 61 | flow_matching_depth: int = 6, 62 | flow_matching_dim: int = 512, 63 | flow_matching_dim_head: int = 128, 64 | flow_matching_heads: int = 4, 65 | flow_matching_ff_mult: float = 4.0, 66 | flow_matching_attn_dropout: float = 0.1, 67 | flow_matching_ff_dropout: float = 0.0, 68 | flow_matching_attn_processor: str = "naive", 69 | flow_matching_norm_type: str = "ada_embed", 70 | flow_matching_ff_type: str = "conv", 71 | flow_matching_ff_kernel_size: int = 3, 72 | flow_matching_ff_groups: int = 2, 73 | flow_matching_scale_type: str = "ada_embed", 74 | duration_predictor_dim: int = 256, 75 | duration_predictor_depth: int = 2, 76 | duration_predictor_kernel_size: int = 3, 77 | duration_predictor_dropout: float = 0.1, 78 | p_uncond: float = 0.1, 79 | interpolate_mode: str = "linear", 80 | sigma: float = 0.01, # from pflow paper 81 | num_languages: int = 0, 82 | p_drop_lang: float = 0.2, 83 | ): 84 | super().__init__() 85 | 86 | self.text_encoder = TextEncoder( 87 | vocab_size=text_encoder_vocab_size, 88 | dim_text=text_encoder_embed_dim, 89 | dim_spk=speaker_encoder_dim, 90 | dim_output=feature_dim, 91 | conv_pos_kernel_size=text_encoder_conv_pos_kernel_size, 92 | conv_pos_depth=text_encoder_conv_pos_depth, 93 | conv_pos_groups=text_encoder_conv_pos_groups, 94 | depth=text_encoder_depth, 95 | dim=text_encoder_dim, 96 | dim_head=text_encoder_dim_head, 97 | heads=text_encoder_heads, 98 | ff_mult=text_encoder_ff_mult, 99 | attn_dropout=text_encoder_attn_dropout, 100 | ff_dropout=text_encoder_ff_dropout, 101 | attn_processor=text_encoder_attn_processor, 102 | norm_type=text_encoder_norm_type, 103 | ff_type=text_encoder_ff_type, 104 | ff_kernel_size=text_encoder_ff_kernel_size, 105 | ff_groups=text_encoder_ff_groups, 106 | scale_type=text_encoder_scale_type, 107 | ) 108 | 109 | self.spk_encoder = SpeakerEncoder( 110 | dim_input=speaker_encoder_dim_input, 111 | conv_pos_kernel_size=speaker_encoder_conv_pos_kernel_size, 112 | conv_pos_depth=speaker_encoder_conv_pos_depth, 113 | conv_pos_groups=speaker_encoder_conv_pos_groups, 114 | depth=speaker_encoder_depth, 115 | dim=speaker_encoder_dim, 116 | dim_head=speaker_encoder_dim_head, 117 | heads=speaker_encoder_heads, 118 | ff_mult=speaker_encoder_ff_mult, 119 | attn_dropout=speaker_encoder_attn_dropout, 120 | ff_dropout=speaker_encoder_ff_dropout, 121 | attn_processor=speaker_encoder_attn_processor, 122 | norm_type=speaker_encoder_norm_type, 123 | ff_type=speaker_encoder_ff_type, 124 | ff_kernel_size=speaker_encoder_ff_kernel_size, 125 | ff_groups=speaker_encoder_ff_groups, 126 | scale_type=speaker_encoder_scale_type, 127 | ) 128 | 129 | self.flow_matching_decoder = FlowMatchingTransformer( 130 | dim_input=feature_dim, 131 | dim_ctx=feature_dim, 132 | dim_output=feature_dim, 133 | dim_time=flow_matching_dim_time, 134 | conv_pos_kernel_size=flow_matching_conv_pos_kernel_size, 135 | conv_pos_depth=flow_matching_conv_pos_depth, 136 | conv_pos_groups=flow_matching_conv_pos_groups, 137 | depth=flow_matching_depth, 138 | dim=flow_matching_dim, 139 | 
dim_head=flow_matching_dim_head, 140 | heads=flow_matching_heads, 141 | ff_mult=flow_matching_ff_mult, 142 | attn_dropout=flow_matching_attn_dropout, 143 | ff_dropout=flow_matching_ff_dropout, 144 | attn_processor=flow_matching_attn_processor, 145 | norm_type=flow_matching_norm_type, 146 | ff_type=flow_matching_ff_type, 147 | ff_kernel_size=flow_matching_ff_kernel_size, 148 | ff_groups=flow_matching_ff_groups, 149 | scale_type=flow_matching_scale_type, 150 | ) 151 | 152 | self.duration_predictor = DurationPredictor( 153 | dim_input=text_encoder_embed_dim, 154 | dim=duration_predictor_dim, 155 | depth=duration_predictor_depth, 156 | kernel_size=duration_predictor_kernel_size, 157 | dropout=duration_predictor_dropout, 158 | ) 159 | if num_languages > 0: 160 | self.lang_emb = nn.Embedding(num_languages + 1, speaker_encoder_dim, padding_idx=num_languages) 161 | self.lang_head = nn.Sequential(GradientReversal(1.0), nn.Linear(speaker_encoder_dim, num_languages)) 162 | 163 | self.reset_parameters() 164 | 165 | self.p_uncond = p_uncond 166 | self.interpolate_mode = interpolate_mode 167 | self.sigma = sigma 168 | 169 | if num_languages > 0: 170 | self.num_languages = num_languages 171 | self.p_drop_lang = p_drop_lang 172 | 173 | def reset_parameters(self): 174 | def default_init(m): 175 | if isinstance(m, nn.Linear): 176 | nn.init.xavier_uniform_(m.weight) 177 | if m.bias is not None: 178 | nn.init.constant_(m.bias, 0) 179 | elif isinstance(m, nn.Conv1d): 180 | nn.init.xavier_uniform_(m.weight) 181 | if m.bias is not None: 182 | nn.init.constant_(m.bias, 0) 183 | elif isinstance(m, nn.Embedding): 184 | nn.init.normal_(m.weight, mean=0, std=0.02) 185 | 186 | self.apply(default_init) 187 | 188 | # init conv pos 189 | self.text_encoder.reset_parameters() 190 | self.spk_encoder.reset_parameters() 191 | self.flow_matching_decoder.reset_parameters() 192 | 193 | def length_to_attn_mask(self, lens: Optional[torch.Tensor] = None) -> torch.Tensor: 194 | if lens is None: 195 | return None 196 | attn_mask = torch.arange(lens.max()).to(lens.device) < lens.unsqueeze(-1) 197 | return attn_mask 198 | 199 | def length_regulator(self, embs, emb_lens, durations, duration_lens): 200 | # can we do it faster? unpad then expand then pad? 
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py 201 | max_len = 0 202 | expanded = [] 203 | for emb, token_len, dur, dur_len in zip(embs, emb_lens, durations, duration_lens): 204 | emb = emb[:token_len, :] 205 | dur = dur[:dur_len] 206 | expanded.append(torch.repeat_interleave(emb, dur, dim=0).unsqueeze(0)) 207 | max_len = max(max_len, expanded[-1].shape[-2]) 208 | return torch.cat([F.pad(e, (0, 0, 0, max_len - e.shape[-2]), value=0) for e in expanded], dim=0) 209 | 210 | def duration_loss(self, duration_pred, durations, duration_lens): 211 | mask = self.length_to_attn_mask(duration_lens) 212 | pred = duration_pred[mask] 213 | target = durations[mask] 214 | log_target = torch.log1p(target) 215 | loss = F.mse_loss(pred, log_target) 216 | return loss 217 | 218 | def enc_loss(self, h, latents, latent_lens, prompt_masks): 219 | assert h.shape == latents.shape, f"Shape mismatch: {h.shape} != {latents.shape}" 220 | attn_mask = self.length_to_attn_mask(latent_lens) 221 | loss_mask = ~prompt_masks & attn_mask 222 | pred = h[loss_mask] 223 | target = latents[loss_mask] 224 | loss = F.mse_loss(pred, target) 225 | return loss 226 | 227 | @staticmethod 228 | def interpolate(h: torch.Tensor, latent: torch.Tensor, mode: str = "linear") -> torch.Tensor: 229 | assert mode in ["linear", "nearest"], f"Interpolation mode {mode} is not supported" 230 | latent_len = latent.shape[-2] 231 | h_len = h.shape[-2] 232 | if latent_len == h_len: 233 | return h 234 | h = rearrange(h, "b t c -> b c t") 235 | h = F.interpolate(h, size=latent_len, mode=mode) 236 | h = rearrange(h, "b c t -> b t c") 237 | return h 238 | 239 | def forward( 240 | self, 241 | text_tokens, 242 | text_token_lens, 243 | durations, 244 | duration_lens, 245 | latents, 246 | latent_lens, 247 | prompts, 248 | prompt_masks, 249 | lang_ids=None, 250 | ): 251 | # text encoder, speaker encoder 252 | spk_emb = self.spk_encoder(prompts) 253 | text_padding_mask = self.length_to_attn_mask(text_token_lens) 254 | lang_emb = None 255 | lang_loss = None 256 | if lang_ids is not None: 257 | if self.training: 258 | lang_logits = self.lang_head(spk_emb).squeeze(1) 259 | lang_loss = F.cross_entropy(lang_logits, lang_ids) 260 | batch_size = lang_ids.shape[0] 261 | lang_drop_mask = torch.rand((batch_size,)).to(lang_ids.device) < self.p_drop_lang 262 | lang_ids = lang_ids * ~lang_drop_mask 263 | lang_emb = self.lang_emb(lang_ids).unsqueeze(1) 264 | h, text_emb = self.text_encoder( 265 | text_tokens=text_tokens, spk_emb=spk_emb, lang_emb=lang_emb, padding_mask=text_padding_mask 266 | ) 267 | 268 | # duration predictor 269 | duration_pred = self.duration_predictor(text_emb.detach(), text_padding_mask) 270 | duration_loss = self.duration_loss(duration_pred, durations, duration_lens) 271 | 272 | # encoder loss 273 | h = self.length_regulator(h, text_token_lens, durations, duration_lens) 274 | h = self.interpolate(h, latents, mode=self.interpolate_mode) 275 | enc_loss = self.enc_loss(h, latents, latent_lens, prompt_masks) 276 | 277 | # flow matching 278 | times = torch.rand((h.shape[0],)).to(h.device) 279 | times = rearrange(times, "b -> b 1 1") 280 | x0 = torch.randn_like(latents) 281 | xt = (1 - (1 - self.sigma) * times) * x0 + times * latents 282 | flow = latents - (1 - self.sigma) * x0 283 | times = rearrange(times, "b 1 1 -> b") 284 | drop_cond = torch.rand((h.shape[0],)).to(h.device) < self.p_uncond 285 | x_ctx = h 286 | latent_padding_mask = self.length_to_attn_mask(latent_lens) 287 | 288 | vt = self.flow_matching_decoder( 289 | 
x=xt, x_ctx=x_ctx, times=times, padding_mask=latent_padding_mask, drop_ctx=drop_cond 290 | ) 291 | loss_mask = ~prompt_masks & latent_padding_mask 292 | flow_matching_loss = F.mse_loss(vt[loss_mask], flow[loss_mask]) 293 | 294 | return duration_loss, enc_loss, flow_matching_loss, lang_loss 295 | 296 | @torch.no_grad() 297 | def generate( 298 | self, 299 | text_tokens, 300 | prompts, 301 | durations=None, 302 | n_steps: int = 16, 303 | ode_method: str = "midpoint", 304 | cfg_scale: float = 1.0, 305 | upscale_ratio: float = 1.5, 306 | lang_ids=None, 307 | ): 308 | assert text_tokens.shape[0] == 1, "generation with batch size > 1 is not supported yet" 309 | spk_emb = self.spk_encoder(prompts) 310 | lang_emb = None 311 | if lang_ids is not None: 312 | lang_emb = self.lang_emb(lang_ids).unsqueeze(1) 313 | h, text_emb = self.text_encoder(text_tokens=text_tokens, spk_emb=spk_emb, lang_emb=lang_emb) 314 | 315 | if durations is None: 316 | duration_pred = self.duration_predictor(text_emb.detach()) 317 | durations = torch.expm1(duration_pred).clamp(min=1).ceil().long() 318 | 319 | h = torch.repeat_interleave(h, durations.squeeze(), dim=1) 320 | upscale_len = round(h.shape[-2] * upscale_ratio) 321 | h = rearrange(h, "b t c -> b c t") 322 | h = F.interpolate(h, size=upscale_len, mode=self.interpolate_mode) 323 | h = rearrange(h, "b c t -> b t c") 324 | 325 | def sample_fn(t, x_t): 326 | batch_size = x_t.shape[0] 327 | t = t.expand(batch_size) 328 | drop_cond = torch.zeros_like(t).bool() 329 | v = self.flow_matching_decoder(x_t, h, t, drop_ctx=drop_cond) 330 | if cfg_scale != 0: 331 | v_null = self.flow_matching_decoder(x_t, h, t, drop_ctx=~drop_cond) 332 | v = (1 + cfg_scale) * v - v_null 333 | 334 | return v 335 | 336 | times = torch.linspace(0, 1, n_steps + 1).to(h.device) 337 | x0 = torch.randn_like(h) 338 | traj = torchdiffeq.odeint(sample_fn, x0, times, atol=1e-4, rtol=1e-4, method=ode_method) 339 | return traj[-1] 340 | -------------------------------------------------------------------------------- /pflow_encodec/modules/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | from typing import Dict, Literal, Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from einops import einsum, rearrange 9 | from torch.nn.utils import remove_weight_norm, weight_norm 10 | 11 | from pflow_encodec.utils.helper import exists 12 | 13 | 14 | class AdaptiveLayerNormProj(nn.Module): 15 | def __init__(self, dim: int, dim_cond: int, eps: float = 1e-6): 16 | super().__init__() 17 | self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) 18 | 19 | self.scale = nn.Linear(dim_cond, dim) 20 | self.bias = nn.Linear(dim_cond, dim) 21 | 22 | def reset_parameters(self): 23 | nn.init.zeros_(self.scale.weight) 24 | nn.init.zeros_(self.scale.bias) 25 | nn.init.zeros_(self.bias.weight) 26 | nn.init.zeros_(self.bias.bias) 27 | 28 | def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: 29 | scale = self.scale(cond) 30 | bias = self.bias(cond) 31 | return self.norm(x) * (1 + scale) + bias 32 | 33 | 34 | class AdaptiveLayerNormEmbed(nn.Module): 35 | def __init__(self, dim: int, eps: float = 1e-6): 36 | super().__init__() 37 | self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) 38 | 39 | self.scale = nn.Parameter(torch.randn(1, 1, dim) / dim**0.5) 40 | self.bias = nn.Parameter(torch.zeros(1, 1, dim)) 41 | 42 | def forward(self, x: torch.Tensor, cond: 
torch.Tensor) -> torch.Tensor: 43 | scale, bias = cond.chunk(2, dim=-1) 44 | scale = self.scale + scale 45 | bias = self.bias + bias 46 | return self.norm(x) * (1 + scale) + bias 47 | 48 | 49 | class AdaptiveScaleProj(nn.Module): 50 | def __init__(self, dim: int, dim_cond: int): 51 | super().__init__() 52 | self.scale = nn.Linear(dim_cond, dim) 53 | 54 | def reset_parameters(self): 55 | nn.init.zeros_(self.scale.weight) 56 | nn.init.zeros_(self.scale.bias) 57 | 58 | def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: 59 | scale = self.scale(cond) 60 | return x * scale 61 | 62 | 63 | class AdaptiveScaleEmbed(nn.Module): 64 | def __init__(self, dim: int): 65 | super().__init__() 66 | self.scale = nn.Parameter(torch.randn(1, 1, dim) / dim**0.5) 67 | 68 | def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: 69 | scale = self.scale + cond 70 | return x * scale 71 | 72 | 73 | class GEGLU(nn.Module): 74 | def __init__(self): 75 | super().__init__() 76 | 77 | def forward(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: 78 | x, gate = x.chunk(2, dim=dim) 79 | return F.gelu(gate) * x 80 | 81 | 82 | class ConvFeedForward(nn.Module): 83 | def __init__(self, dim: int, mult: float, kernel_size: int, groups: int = 1, dropout: float = 0.0): 84 | super().__init__() 85 | intermediate_dim = int(dim * mult * 3 / 4) 86 | self.conv1 = nn.Conv1d(dim, 2 * intermediate_dim, kernel_size, padding="same", groups=groups) 87 | self.act = GEGLU() 88 | self.dropout = nn.Dropout(dropout) 89 | self.conv2 = nn.Conv1d(intermediate_dim, dim, kernel_size, padding="same", groups=groups) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = rearrange(x, "b t d -> b d t") 93 | x = self.conv1(x) 94 | x = self.act(x, dim=1) 95 | x = self.dropout(x) 96 | x = self.conv2(x) 97 | x = rearrange(x, "b d t -> b t d") 98 | return x 99 | 100 | 101 | class FeedForward(nn.Module): 102 | def __init__(self, dim: int, mult: float, dropout: float = 0.0): 103 | super().__init__() 104 | intermediate_dim = int(dim * mult * 2 / 3) 105 | self.proj1 = nn.Linear(dim, 2 * intermediate_dim) 106 | self.act = GEGLU() 107 | self.dropout = nn.Dropout(dropout) 108 | self.proj2 = nn.Linear(intermediate_dim, dim) 109 | 110 | def forward(self, x: torch.Tensor) -> torch.Tensor: 111 | x = self.proj1(x) 112 | x = self.act(x) 113 | x = self.dropout(x) 114 | x = self.proj2(x) 115 | return x 116 | 117 | 118 | class MultiHeadAttention(nn.Module): 119 | def __init__( 120 | self, 121 | dim: int, 122 | dim_head: int, 123 | heads: int, 124 | dim_context: Optional[int] = None, 125 | scale: Optional[float] = None, 126 | dropout: float = 0.0, 127 | processor: Literal["naive", "sdpa", "flash"] = "naive", 128 | ): 129 | super().__init__() 130 | self.dim = dim 131 | self.dim_head = dim_head 132 | self.dim_context = dim_context if exists(dim_context) else dim 133 | self.scale = scale if exists(scale) else dim_head ** -0.5 134 | self.processor = processor 135 | self.heads = heads 136 | 137 | inner_dim = dim_head * heads 138 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 139 | self.to_k = nn.Linear(self.dim_context, inner_dim, bias=False) 140 | self.to_v = nn.Linear(self.dim_context, inner_dim, bias=False) 141 | self.dropout = nn.Dropout(dropout) # apply to attn score 142 | 143 | self.to_out = nn.Linear(inner_dim, dim) 144 | 145 | self.attn_processor_dict = { 146 | "naive": self.naive_attention, 147 | "sdpa": self.sdpa_attention, 148 | } 149 | 150 | if self.processor not in self.attn_processor_dict: 151 | raise 
NotImplementedError(f"processor {self.processor} is not implemented yet") 152 | 153 | def process_attn_mask_bias(self, mask, bias): 154 | if not exists(bias): 155 | return mask, False 156 | 157 | if exists(mask): 158 | bias = bias.masked_fill(~mask, -torch.finfo(bias.dtype).max) 159 | return bias, True 160 | 161 | def naive_attention(self, q, k, v, mask, bias, **attn_kwargs): 162 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)) 163 | 164 | attn_mask, is_bias = self.process_attn_mask_bias(mask, bias) 165 | dots = einsum(q, k, "b h i d, b h j d -> b h i j") * self.scale 166 | 167 | if exists(attn_mask): 168 | if is_bias: 169 | dots = dots + attn_mask 170 | else: 171 | dots.masked_fill_(~attn_mask, -torch.finfo(dots.dtype).max) 172 | 173 | attn = dots.softmax(dim=-1) 174 | attn = self.dropout(attn) 175 | out = einsum(attn, v, "b h i j, b h j d -> b h i d") 176 | out = rearrange(out, "b h n d -> b n (h d)") 177 | 178 | return out 179 | 180 | def sdpa_attention(self, q, k, v, mask, bias, **attn_kwargs): 181 | if not hasattr(F, "scaled_dot_product_attention"): 182 | raise RuntimeError( 183 | "torch.nn.functional.scaled_dot_product_attention is not available. Please upgrade to PyTorch 2.0.0 or later." 184 | ) 185 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)) 186 | 187 | attn_mask, _ = self.process_attn_mask_bias(mask, bias) 188 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=self.dropout.p) 189 | out = rearrange(out, "b h n d -> b n (h d)") 190 | return out 191 | 192 | def get_attn_processor(self, processor): 193 | assert processor in self.attn_processor_dict, f"processor {processor} is not implemented yet" 194 | return self.attn_processor_dict[processor] 195 | 196 | def forward( 197 | self, 198 | x: torch.Tensor, 199 | context: Optional[torch.Tensor] = None, 200 | mask: Optional[torch.Tensor] = None, 201 | bias: Optional[torch.Tensor] = None, 202 | **attn_kwargs, 203 | ): 204 | if not exists(context): 205 | context = x 206 | 207 | b, t, d = x.shape 208 | q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) 209 | 210 | attn_output = self.get_attn_processor(self.processor)(q, k, v, mask, bias, **attn_kwargs) 211 | 212 | return self.to_out(attn_output) 213 | 214 | 215 | # code from https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/models/wav2vec2/position_encoder.py 216 | class Wav2Vec2PositionEncoderLayer(nn.Module): 217 | def __init__( 218 | self, 219 | dim: int, 220 | kernel_size: int, 221 | groups: int, 222 | ) -> None: 223 | super().__init__() 224 | 225 | self.conv = nn.Conv1d( 226 | dim, 227 | dim, 228 | kernel_size, 229 | padding="same", 230 | groups=groups, 231 | ) 232 | 233 | self.layer_norm = nn.LayerNorm(dim) 234 | self.activation = nn.GELU() 235 | 236 | def forward(self, encodings: torch.Tensor) -> torch.Tensor: 237 | encodings = self.conv(encodings) 238 | 239 | encodings = encodings.transpose(1, 2) # (B, D, T) -> (B, T, D) 240 | encodings = self.layer_norm(encodings) 241 | encodings = encodings.transpose(1, 2) # (B, T, D) -> (B, D, T) 242 | 243 | encodings = self.activation(encodings) 244 | return encodings 245 | 246 | 247 | class Wav2Vec2StackedPositionEncoder(nn.Module): 248 | def __init__( 249 | self, 250 | depth: int, 251 | dim: int, 252 | kernel_size: int, 253 | groups: int, 254 | ) -> None: 255 | super().__init__() 256 | 257 | k = max(3, kernel_size // depth) 258 | 259 | self.layers = nn.Sequential() 260 | 261 | for _ in range(depth): 262 | 
layer = Wav2Vec2PositionEncoderLayer( 263 | dim, 264 | k, 265 | groups, 266 | ) 267 | 268 | self.layers.append(layer) 269 | 270 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: 271 | if exists(mask): 272 | mask = mask[..., None] 273 | x = x.masked_fill(~mask, 0.0) 274 | 275 | x = x.transpose(1, 2) 276 | x = self.layers(x) 277 | x = x.transpose(1, 2) 278 | 279 | if exists(mask): 280 | x = x.masked_fill(~mask, 0.0) 281 | 282 | return x 283 | 284 | def reset_parameters(self): 285 | def init_(m): 286 | if isinstance(m, nn.Conv1d): 287 | model_dim, kernel_size = m.in_channels, m.kernel_size[0] 288 | try: 289 | remove_weight_norm(m) 290 | except ValueError: 291 | # Raised during the `__init__` call since we don't have the weight 292 | # norm hook registered yet. Safe to ignore. 293 | pass 294 | 295 | nn.init.normal_(m.weight, mean=0.0, std=(4.0 / (kernel_size * model_dim)) ** 0.5) 296 | 297 | weight_norm(m, dim=2) 298 | 299 | if m.bias is not None: 300 | nn.init.constant_(m.bias, 0.0) 301 | 302 | self.apply(init_) 303 | 304 | 305 | class AlibiPositionalBias(nn.Module): 306 | def __init__(self, heads: int): 307 | super().__init__() 308 | self.heads = heads 309 | 310 | slopes = torch.Tensor(self._get_slopes(heads)) 311 | slopes = rearrange(slopes, "h -> 1 h 1 1") 312 | self.register_buffer("slopes", slopes, persistent=False) 313 | self.register_buffer("bias", None, persistent=False) 314 | 315 | def get_bias(self, seq_len: int): 316 | i_arange = torch.arange(seq_len, device=self.device) 317 | j_arange = torch.arange(seq_len, device=self.device) 318 | bias = -torch.abs(rearrange(j_arange, "j -> 1 1 j") - rearrange(i_arange, "i -> 1 i 1")) 319 | return bias.unsqueeze(0) 320 | 321 | @staticmethod 322 | def _get_slopes(heads): 323 | def get_slopes_power_of_2(n): 324 | start = 2 ** (-(2 ** -(math.log2(n) - 3))) 325 | ratio = start 326 | return [start * ratio**i for i in range(n)] 327 | 328 | if math.log2(heads).is_integer(): 329 | return get_slopes_power_of_2(heads) 330 | 331 | closest_power_of_2 = 2 ** math.floor(math.log2(heads)) 332 | return ( 333 | get_slopes_power_of_2(closest_power_of_2) 334 | + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: heads - closest_power_of_2] 335 | ) 336 | 337 | @property 338 | def device(self): 339 | return next(self.buffers()).device 340 | 341 | def forward(self, seq_len: int): 342 | if exists(self.bias) and self.bias.shape[-1] >= seq_len: 343 | return self.bias[..., -seq_len:, -seq_len:] 344 | 345 | bias = self.get_bias(seq_len) 346 | bias = bias * self.slopes 347 | 348 | self.register_buffer("bias", bias, persistent=False) 349 | 350 | return self.bias 351 | 352 | 353 | class Transformer(nn.Module): 354 | def __init__( 355 | self, 356 | depth: int, 357 | dim: int, 358 | dim_head: int, 359 | heads: int, 360 | ff_mult: float, 361 | attn_dropout: float, 362 | ff_dropout: float, 363 | dim_cond: Optional[int] = None, 364 | attn_processor: Literal["naive", "sdpa"] = "naive", 365 | norm_type: Literal["layer", "ada_proj", "ada_embed"] = "layer", 366 | ff_type: Literal["conv", "linear"] = "linear", 367 | ff_kernel_size: Optional[int] = None, 368 | ff_groups: Optional[int] = None, 369 | layer_norm_eps: float = 1e-6, 370 | scale_type: Literal["none", "ada_proj", "ada_embed"] = "none", 371 | use_skip_connection: bool = False, 372 | dim_final_norm_cond: Optional[int] = None, 373 | ): 374 | super().__init__() 375 | self.layers = nn.ModuleList([]) 376 | 377 | self.norm_type = norm_type 378 | norm_class = self.get_norm_class(norm_type, 
dim_cond) 379 | 380 | self.ff_type = ff_type 381 | ff_class = self.get_ff_class(ff_type, ff_kernel_size, ff_groups) 382 | 383 | self.scale_type = scale_type 384 | if self.scale_type != "none": 385 | assert ( 386 | self.norm_type == self.scale_type 387 | ), f"norm type {self.norm_type} and scale type {self.scale_type} must be the same" 388 | scale_class = self.get_scale_class(scale_type, dim, dim_cond) 389 | 390 | self.layers = nn.ModuleList([]) 391 | for ind in range(depth): 392 | layer = ind + 1 393 | has_skip = use_skip_connection and layer > (depth // 2) 394 | self.layers.append( 395 | nn.ModuleList( 396 | [ 397 | nn.Linear(dim * 2, dim) if has_skip else None, 398 | norm_class(dim, eps=layer_norm_eps), 399 | MultiHeadAttention( 400 | dim=dim, 401 | dim_head=dim_head, 402 | heads=heads, 403 | scale=None, 404 | dropout=attn_dropout, 405 | processor=attn_processor, 406 | ), 407 | scale_class(), 408 | norm_class(dim, eps=layer_norm_eps), 409 | ff_class(dim=dim, mult=ff_mult, dropout=ff_dropout), 410 | scale_class(), 411 | ] 412 | ) 413 | ) 414 | 415 | if self.norm_type == "ada_embed": 416 | assert exists(dim_final_norm_cond), "dim_final_norm_cond must be provided when using ada_embed" 417 | 418 | self.final_norm = ( 419 | nn.LayerNorm(dim, eps=layer_norm_eps) 420 | if self.norm_type == "layer" 421 | else AdaptiveLayerNormProj( 422 | dim, dim_cond=dim_cond if self.norm_type == "ada_proj" else dim_final_norm_cond, eps=layer_norm_eps 423 | ) 424 | ) 425 | 426 | def reset_adaln_parameters(self): 427 | def init_(m): 428 | if isinstance(m, AdaptiveLayerNormProj): 429 | m.reset_parameters() 430 | 431 | self.apply(init_) 432 | 433 | @staticmethod 434 | def expand_mask(mask: Optional[torch.Tensor] = None): 435 | if exists(mask): 436 | if mask.ndim == 2: # B L 437 | mask = rearrange(mask, "b j -> b 1 1 j") 438 | elif mask.ndim == 3: # B q_len k_len 439 | mask = rearrange(mask, "b i j -> b 1 i j") 440 | return mask 441 | 442 | @staticmethod 443 | def get_norm_class(norm_type, dim_cond): 444 | if norm_type == "layer": 445 | return nn.LayerNorm 446 | elif norm_type == "ada_proj": 447 | return partial(AdaptiveLayerNormProj, dim_cond=dim_cond) 448 | elif norm_type == "ada_embed": 449 | return AdaptiveLayerNormEmbed 450 | else: 451 | raise NotImplementedError(f"norm type {norm_type} is not implemented yet") 452 | 453 | @staticmethod 454 | def get_scale_class(scale_type, dim, dim_cond): 455 | if scale_type == "none": 456 | return nn.Identity 457 | elif scale_type == "ada_proj": 458 | return partial(AdaptiveScaleProj, dim=dim, dim_cond=dim_cond) 459 | elif scale_type == "ada_embed": 460 | return partial(AdaptiveScaleEmbed, dim=dim) 461 | else: 462 | raise NotImplementedError(f"scale type {scale_type} is not implemented yet") 463 | 464 | @staticmethod 465 | def get_ff_class(ff_type, kernel_size, groups): 466 | if ff_type == "conv": 467 | return partial(ConvFeedForward, kernel_size=kernel_size, groups=groups) 468 | elif ff_type == "linear": 469 | return FeedForward 470 | else: 471 | raise NotImplementedError(f"ff type {ff_type} is not implemented yet") 472 | 473 | def forward( 474 | self, 475 | x: torch.Tensor, 476 | context: Optional[torch.Tensor] = None, 477 | mask: Optional[torch.Tensor] = None, 478 | bias: Optional[torch.Tensor] = None, 479 | cond_input: Dict[str, torch.Tensor] = dict(), 480 | ): 481 | mask = self.expand_mask(mask) 482 | if exists(bias): 483 | assert bias.ndim == 4, f"bias must have 4 dimensions in Transformer, got {bias.ndim}" 484 | 485 | skip_connects = [] 486 | for skip_combiner, 
attn_norm, attn, attn_scale, ff_norm, ff, ff_scale in self.layers: 487 | if not exists(skip_combiner): 488 | skip_connects.append(x) 489 | else: 490 | skip_connect = skip_connects.pop() 491 | x = torch.cat([x, skip_connect], dim=-1) 492 | x = skip_combiner(x) 493 | residual = x 494 | if self.norm_type == "layer": 495 | x = attn_norm(x) 496 | else: 497 | x = attn_norm(x, cond=cond_input.get("attn_norm_cond", None)) 498 | x = attn(x, context=context, mask=mask, bias=bias) 499 | if self.scale_type != "none": 500 | x = attn_scale(x, cond=cond_input.get("attn_scale_cond", None)) 501 | x = x + residual 502 | 503 | residual = x 504 | if self.norm_type == "layer": 505 | x = ff_norm(x) 506 | else: 507 | x = ff_norm(x, cond=cond_input.get("ff_norm_cond", None)) 508 | x = ff(x) 509 | if self.scale_type != "none": 510 | x = ff_scale(x, cond=cond_input.get("ff_scale_cond", None)) 511 | x = x + residual 512 | 513 | final_output = ( 514 | self.final_norm(x) 515 | if self.norm_type == "layer" 516 | else self.final_norm(x, cond=cond_input.get("final_norm_cond", None)) 517 | ) 518 | return final_output 519 | --------------------------------------------------------------------------------
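Note: as a quick sanity check of the Transformer block defined in pflow_encodec/modules/transformer.py above, the following minimal sketch instantiates it with small, illustrative hyperparameters and pushes a dummy batch through it. The import path follows the repository layout, but the depth, dimensions, and sequence length are assumptions chosen for demonstration only, not values taken from the repository's configs.

import torch

from pflow_encodec.modules.transformer import Transformer

# Small, illustrative hyperparameters (assumed for this sketch, not the repo's config values).
model = Transformer(
    depth=2,
    dim=192,
    dim_head=96,
    heads=2,
    ff_mult=4.0,
    attn_dropout=0.0,
    ff_dropout=0.0,
    attn_processor="naive",  # "sdpa" also works on PyTorch >= 2.0
    norm_type="layer",       # plain LayerNorm, so no conditioning input is required
    ff_type="linear",
)

x = torch.randn(1, 50, 192)                 # (batch, time, dim)
mask = torch.ones(1, 50, dtype=torch.bool)  # True marks valid positions
out = model(x, mask=mask)                   # -> shape (1, 50, 192)

With norm_type="layer" and scale_type left at its default of "none", the adaptive-norm and adaptive-scale branches in Transformer.forward are bypassed, so no cond_input dictionary needs to be supplied; the conditioned variants ("ada_proj", "ada_embed") additionally require dim_cond (and, for "ada_embed", dim_final_norm_cond) plus the corresponding cond_input entries at call time.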