├── .dockerignore ├── .gitignore ├── .pre-commit-config.yaml ├── .project-root ├── LICENSE ├── README.md ├── dockerfile ├── fish_vocoder ├── configs │ ├── callbacks │ │ ├── default.yaml │ │ ├── early_stopping.yaml │ │ ├── learning_rate_monitor.yaml │ │ ├── model_checkpoint.yaml │ │ ├── model_summary.yaml │ │ ├── none.yaml │ │ └── rich_progress_bar.yaml │ ├── data │ │ ├── dataset │ │ │ └── vocoder-train.yaml │ │ └── vocoder.yaml │ ├── extras │ │ └── default.yaml │ ├── hydra │ │ └── default.yaml │ ├── logger │ │ ├── tensorboard.yaml │ │ └── wandb.yaml │ ├── model │ │ ├── gan.yaml │ │ ├── generator │ │ │ ├── firefly-gan-base.yaml │ │ │ ├── hifigan-vae.yaml │ │ │ ├── hifigan.yaml │ │ │ ├── vocos-huge.yaml │ │ │ ├── vocos-small-vae.yaml │ │ │ ├── vocos-small.yaml │ │ │ └── vocos.yaml │ │ ├── resolution │ │ │ ├── 24000_2048_3072.yaml │ │ │ ├── 24000_256_1024.yaml │ │ │ └── 44100_512_2048.yaml │ │ ├── spectrogram │ │ │ ├── linear.yaml │ │ │ └── mel.yaml │ │ ├── vae.yaml │ │ └── vqvae.yaml │ ├── paths │ │ └── default.yaml │ ├── train.yaml │ └── trainer │ │ └── default.yaml ├── data │ ├── datamodules │ │ └── naive.py │ ├── datasets │ │ ├── mix.py │ │ └── vocoder.py │ └── transforms │ │ ├── crop.py │ │ ├── discontinuous.py │ │ ├── hq_pitch_shift.py │ │ ├── load.py │ │ ├── loudness.py │ │ ├── pad.py │ │ └── spectrogram.py ├── eval.py ├── models │ ├── gan.py │ ├── vae.py │ └── vocoder.py ├── modules │ ├── discriminators │ │ ├── mpd.py │ │ └── mrd.py │ ├── encoders │ │ ├── convnext.py │ │ ├── hubert.py │ │ ├── mms.py │ │ └── posterior_encoder.py │ ├── generators │ │ ├── bigvgan.py │ │ ├── hifigan.py │ │ ├── refinegan.py │ │ ├── unify.py │ │ └── vocos.py │ └── losses │ │ └── stft.py ├── schedulers │ └── warmup_cosine.py ├── test.py ├── train.py └── utils │ ├── __init__.py │ ├── file.py │ ├── grad_norm.py │ ├── instantiators.py │ ├── logger.py │ ├── logging_utils.py │ ├── mask.py │ ├── rich_utils.py │ ├── utils.py │ └── viz.py ├── pdm.lock ├── pyproject.toml └── scripts ├── convert_diffsinger_mel.py ├── random_copy.py ├── test_firefly_gan.sh ├── test_hifigan.sh ├── test_vocos_huge.sh ├── train_convnext_bigvgan_base.sh ├── train_convnext_hifigan_base.sh ├── train_convnext_hifigan_vae.sh ├── train_vocos.sh ├── train_vocos_huge.sh ├── train_vocos_huge_full.sh └── vocos_gen.py /.dockerignore: -------------------------------------------------------------------------------- 1 | logs 2 | checkpoints 3 | .venv 4 | dataset 5 | build 6 | tests 7 | .git 8 | data 9 | build 10 | *.log 11 | other 12 | results 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Text tool 55 | tools/text/create_symbol_dict.py 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | /data 135 | /dataset 136 | .vscode 137 | 138 | *.pt 139 | *.pth 140 | hifigan/model 141 | output 142 | lightning_logs 143 | logs 144 | wandb 145 | *.ckpt 146 | checkpoints 147 | filelists 148 | raw 149 | results 150 | 151 | configs/exp_*.py 152 | exp_*.sh 153 | .DS_Store 154 | .vscode 155 | exported 156 | pitches_editor 157 | .pdm-python 158 | .hydra 159 | .pgx.* 160 | other 161 | filelist.* 162 | run.sh 163 | /generated_* 164 | /generated 165 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_schedule: monthly 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.6.0 7 | hooks: 8 | - id: check-yaml 9 | - id: end-of-file-fixer 10 | - id: trailing-whitespace 11 | 12 | - repo: https://github.com/astral-sh/ruff-pre-commit 13 | rev: v0.5.6 14 | hooks: 15 | - id: ruff 16 | args: [ --fix ] 17 | 18 | - repo: https://github.com/psf/black 19 | rev: 24.8.0 20 | hooks: 21 | - id: black 22 | 23 | - repo: https://github.com/codespell-project/codespell 24 | rev: v2.3.0 25 | hooks: 26 | - id: codespell 27 | files: ^.*\.(py|md|rst|yml)$ 28 | args: [-L=fro] 29 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fishaudio/vocoder/4488f62ca4fefcdcc26998d836605a3d6f5baae5/.project-root -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Fish Audio. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fish Vocoder 2 | 3 | This repo is designed as a uniform interface for developing various vocoders (a launch sketch follows the config list below). 4 | 5 | Configs: 6 | - [x] hifigan (baseline): HiFiGAN generator with UnivNet discriminators. 7 | - [x] bigvgan: BigVGAN generator. 8 | - [x] vocos: Vocos (ConvNext) generator. 9 | - [x] refinegan: RefineGAN generator. 10 | - [ ] firefly-gan: convnext encoder + hifigan generator.
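A minimal launch sketch (assumptions: the entry point is `fish_vocoder/train.py` as listed in the tree above, the Hydra override names mirror the config groups in this repo, and `PROJECT_ROOT` is the environment variable required by `configs/paths/default.yaml`; adjust to your checkout):

```bash
# Hypothetical invocation; config group names are taken from the configs in this repo.
export PROJECT_ROOT=$(pwd)

# Baseline: HiFiGAN generator with the default GAN task (configs/model/gan.yaml).
python fish_vocoder/train.py tags='[baseline]'

# Swap the generator config group (e.g. to Vocos) and log to TensorBoard instead of W&B.
python fish_vocoder/train.py model/generator=vocos logger=tensorboard tags='[vocos]'
```

The `scripts/train_*.sh` helpers presumably wrap the same entry point for the released configurations.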
11 | 12 | ## References 13 | - TIMM: https://github.com/huggingface/pytorch-image-models 14 | - BigVGAN: https://github.com/NVIDIA/BigVGAN 15 | - Vocos: https://github.com/charactr-platform/vocos 16 | - UnivNet: https://github.com/mindslab-ai/univnet 17 | - ConvNext: https://github.com/facebookresearch/ConvNeXt 18 | - HiFiGAN: https://github.com/jik876/hifi-gan 19 | - Fish Diffusion: https://github.com/fishaudio/fish-diffusion 20 | - RefineGAN: https://arxiv.org/abs/2111.00962 21 | - Encodec: https://github.com/facebookresearch/encodec 22 | - EVA-GAN: https://arxiv.org/abs/2402.00892 23 | -------------------------------------------------------------------------------- /dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.06-py3 2 | 3 | # Install system dependencies 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | RUN apt-get update && apt-get install -y git curl build-essential ffmpeg libsm6 libxext6 libjpeg-dev \ 6 | zlib1g-dev aria2 zsh openssh-server sudo && \ 7 | apt-get clean && rm -rf /var/lib/apt/lists/* 8 | 9 | # Install code server and zsh 10 | RUN wget -c https://github.com/coder/code-server/releases/download/v4.5.1/code-server_4.5.1_amd64.deb && \ 11 | dpkg -i ./code-server_4.5.1_amd64.deb && \ 12 | code-server --install-extension ms-python.python && \ 13 | rm ./code-server_4.5.1_amd64.deb && \ 14 | sh -c "$(curl https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" "" --unattended 15 | 16 | # Install s5cmd 17 | RUN curl -L https://github.com/peak/s5cmd/releases/download/v2.1.0-beta.1/s5cmd_2.1.0-beta.1_Linux-64bit.tar.gz | tar xvz -C /tmp && \ 18 | mv /tmp/s5cmd /usr/local/bin/s5cmd && s5cmd --help 19 | 20 | # Install dependencies 21 | WORKDIR /root/exp 22 | COPY pyproject.toml . 23 | RUN pip3 install . 24 | 25 | CMD ["code-server", "--auth", "password", "--bind-addr", "0.0.0.0:8888", "/root/exp"] 26 | -------------------------------------------------------------------------------- /fish_vocoder/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint 3 | - model_summary 4 | - rich_progress_bar 5 | - learning_rate_monitor 6 | - _self_ 7 | 8 | model_checkpoint: 9 | dirpath: ${paths.output_dir}/checkpoints 10 | filename: "step_{step:09d}" 11 | save_last: True # additionally always save an exact copy of the last checkpoint to a file last.ckpt 12 | save_top_k: -1 # -1 keeps every checkpoint 13 | monitor: step # use step to monitor checkpoints 14 | mode: max # save the latest checkpoint with the highest global_step 15 | every_n_epochs: null # don't save checkpoints by epoch end 16 | every_n_train_steps: 20000 # save a checkpoint every 20000 steps 17 | auto_insert_metric_name: False 18 | 19 | model_summary: 20 | max_depth: -1 21 | -------------------------------------------------------------------------------- /fish_vocoder/configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html 2 | 3 | early_stopping: 4 | _target_: lightning.pytorch.callbacks.EarlyStopping 5 | monitor: ??? # quantity to be monitored, must be specified !!! 6 | min_delta: 0.
# minimum change in the monitored quantity to qualify as an improvement 7 | patience: 3 # number of checks with no improvement after which training will be stopped 8 | verbose: False # verbosity mode 9 | mode: "min" # "max" means higher metric value is better, can be also "min" 10 | strict: True # whether to crash the training if monitor is not found in the validation metrics 11 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 12 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 13 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 14 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 15 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 16 | -------------------------------------------------------------------------------- /fish_vocoder/configs/callbacks/learning_rate_monitor.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html 2 | 3 | learning_rate_monitor: 4 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 5 | logging_interval: step 6 | log_momentum: False 7 | -------------------------------------------------------------------------------- /fish_vocoder/configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: null # directory to save the model file 6 | filename: null # checkpoint filename 7 | monitor: null # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 1 # save k best models (determined by above metric) 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: null # number of epochs between checkpoints 17 | save_on_train_epoch_end: False # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /fish_vocoder/configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /fish_vocoder/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/fishaudio/vocoder/4488f62ca4fefcdcc26998d836605a3d6f5baae5/fish_vocoder/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /fish_vocoder/configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /fish_vocoder/configs/data/dataset/vocoder-train.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_vocoder.data.datasets.vocoder.VocoderDataset 2 | transform: 3 | _target_: torch.nn.Sequential 4 | _args_: 5 | - _target_: fish_vocoder.data.transforms.load.LoadAudio 6 | sampling_rate: ${model.sampling_rate} 7 | - _target_: fish_vocoder.data.transforms.hq_pitch_shift.RandomHQPitchShift 8 | probability: 0.5 9 | sampling_rate: ${model.sampling_rate} 10 | pitch_range: 12 11 | - _target_: fish_vocoder.data.transforms.loudness.RandomLoudness 12 | probability: 0.5 13 | loudness_range: [0.1, 0.9] 14 | - _target_: fish_vocoder.data.transforms.crop.RandomCrop 15 | probability: 1 16 | crop_length: "${eval: '${model.hop_length} * ${model.num_frames}'}" 17 | - _target_: fish_vocoder.data.transforms.pad.Pad 18 | multiple_of: ${model.hop_length} 19 | -------------------------------------------------------------------------------- /fish_vocoder/configs/data/vocoder.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset@datasets.train.datasets.hifi-8000h.dataset: vocoder-train 3 | - dataset@datasets.train.datasets.vocoder-data-441.dataset: vocoder-train 4 | - dataset@datasets.train.datasets.libritts-train.dataset: vocoder-train 5 | - _self_ 6 | 7 | _target_: fish_vocoder.data.datamodules.naive.NaiveDataModule 8 | 9 | batch_size: 16 10 | val_batch_size: 2 11 | num_workers: 8 12 | 13 | collate_fn: 14 | _target_: fish_vocoder.data.datasets.vocoder.collate_fn 15 | _partial_: true 16 | 17 | datasets: 18 | train: 19 | _target_: fish_vocoder.data.datasets.mix.MixDatast 20 | datasets: 21 | hifi-8000h: 22 | dataset: 23 | root: filelist.hifi-8000h.train 24 | prob: 0.8 25 | vocoder-data-441: 26 | dataset: 27 | root: filelist.vocoder_data_441.train 28 | prob: 0.1 29 | libritts-train: 30 | dataset: 31 | root: filelist.libritts.train 32 | prob: 0.1 33 | 34 | val: 35 | _target_: fish_vocoder.data.datasets.vocoder.VocoderDataset 36 | root: dataset/valid 37 | transform: 38 | _target_: torch.nn.Sequential 39 | _args_: 40 | - _target_: fish_vocoder.data.transforms.load.LoadAudio 41 | sampling_rate: ${model.sampling_rate} 42 | - _target_: fish_vocoder.data.transforms.crop.RandomCrop 43 | probability: 1 44 | crop_length: "${eval: '${model.hop_length} * 1000'}" 45 | - _target_: fish_vocoder.data.transforms.pad.Pad 46 | multiple_of: ${model.hop_length} 47 | -------------------------------------------------------------------------------- /fish_vocoder/configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | 
print_config: True 9 | -------------------------------------------------------------------------------- /fish_vocoder/configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.run_dir} 11 | -------------------------------------------------------------------------------- /fish_vocoder/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /fish_vocoder/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "fish-vocoder" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/gan.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - resolution@_here_: "44100_512_2048" 3 | - generator: hifigan 4 | - spectrogram@mel_transforms.modules.input: mel 5 | - spectrogram@mel_transforms.modules.loss: mel 6 | - _self_ 7 | 8 | _target_: fish_vocoder.models.gan.GANModel 9 | 10 | # While the generator runs on the full-length segment, we crop 32 11 | # frames for the discriminator to save memory.
12 | num_frames: 128 13 | crop_length: "${eval: '${model.hop_length} * 32'}" 14 | 15 | optimizer: 16 | _target_: torch.optim.AdamW 17 | _partial_: true 18 | lr: 1 19 | betas: [0.8, 0.99] 20 | eps: 1e-6 21 | 22 | lr_scheduler: 23 | _target_: torch.optim.lr_scheduler.LambdaLR 24 | _partial_: true 25 | lr_lambda: 26 | _target_: fish_vocoder.schedulers.warmup_cosine.LambdaWarmUpCosineScheduler 27 | val_base: 1e-4 28 | val_final: 0 29 | max_decay_steps: "${eval: ${trainer.max_steps} // 2}" 30 | 31 | mel_transforms: 32 | _target_: torch.nn.ModuleDict 33 | modules: 34 | input: {} 35 | loss: {} 36 | 37 | generator: {} 38 | 39 | discriminators: 40 | _target_: torch.nn.ModuleDict 41 | modules: 42 | mpd: 43 | _target_: fish_vocoder.modules.discriminators.mpd.MultiPeriodDiscriminator 44 | periods: [3, 5, 7, 11, 17, 23, 37] 45 | 46 | mrd: 47 | _target_: fish_vocoder.modules.discriminators.mrd.MultiResolutionDiscriminator 48 | resolutions: 49 | - ["${model.n_fft}", "${model.hop_length}", "${model.win_length}"] 50 | - [1024, 120, 600] 51 | - [2048, 240, 1200] 52 | - [4096, 480, 2400] 53 | - [512, 50, 240] 54 | 55 | multi_resolution_stft_loss: 56 | _target_: fish_vocoder.modules.losses.stft.MultiResolutionSTFTLoss 57 | resolutions: ${model.discriminators.modules.mrd.resolutions} 58 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/generator/firefly-gan-base.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_vocoder.modules.generators.unify.UnifyGenerator 2 | backbone: 3 | _target_: fish_vocoder.modules.encoders.convnext.ConvNeXtEncoder 4 | input_channels: ${model.num_mels} 5 | depths: [3, 3, 9, 3] 6 | dims: [128, 256, 384, 512] 7 | drop_path_rate: 0.2 8 | kernel_size: 7 9 | head: 10 | _target_: fish_vocoder.modules.generators.hifigan.HiFiGANGenerator 11 | hop_length: ${model.hop_length} 12 | upsample_rates: [8, 8, 2, 2, 2] 13 | upsample_kernel_sizes: [16, 16, 4, 4, 4] 14 | resblock_kernel_sizes: [3, 7, 11] 15 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 16 | num_mels: 512 # consistent with the output of the backbone 17 | upsample_initial_channel: 512 18 | use_template: false 19 | pre_conv_kernel_size: 13 20 | post_conv_kernel_size: 13 21 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/generator/hifigan-vae.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.ModuleDict 2 | modules: 3 | encoder: 4 | _target_: fish_vocoder.modules.encoders.hubert.HubertEncoder 5 | model_name: "facebook/hubert-base-ls960" 6 | freeze_backbone: true 7 | output_size: 512 8 | decoder: 9 | _target_: fish_vocoder.modules.generators.hifigan.HiFiGANGenerator 10 | hop_length: 640 # 2*320 at 16kHz, which is 40ms 11 | upsample_rates: [8, 5, 4, 2, 2] # aka. 
strides 12 | upsample_kernel_sizes: [16, 10, 8, 4, 4] 13 | resblock_kernel_sizes: [3, 7, 11] 14 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 15 | num_mels: 512 16 | upsample_initial_channel: 512 17 | use_template: false 18 | pre_conv_kernel_size: 7 19 | post_conv_kernel_size: 7 20 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/generator/hifigan.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_vocoder.modules.generators.hifigan.HiFiGANGenerator 2 | hop_length: ${model.hop_length} 3 | upsample_rates: [8, 8, 2, 2, 2] # aka. strides 4 | upsample_kernel_sizes: [16, 16, 8, 2, 2] 5 | resblock_kernel_sizes: [3, 7, 11] 6 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 7 | num_mels: ${model.num_mels} 8 | upsample_initial_channel: 512 9 | use_template: false 10 | pre_conv_kernel_size: 7 11 | post_conv_kernel_size: 7 12 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/generator/vocos-huge.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - vocos 3 | - _self_ 4 | 5 | backbone: 6 | depths: [3, 3, 27, 3] 7 | dims: [352, 704, 1408, 2816] 8 | drop_path_rate: 0.4 9 | kernel_sizes: [7] 10 | head: 11 | dim: 2816 12 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/generator/vocos-small-vae.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.ModuleDict 2 | modules: 3 | encoder: 4 | _target_: fish_vocoder.modules.encoders.convnext.ConvNeXtEncoder 5 | input_channels: "${eval: '${model.n_fft} // 2 + 1'}" 6 | depths: [6] 7 | dims: [512] 8 | drop_path_rate: 0.1 9 | kernel_sizes: [7] 10 | decoder: 11 | _target_: fish_vocoder.modules.generators.vocos.VocosGenerator 12 | backbone: 13 | _target_: fish_vocoder.modules.encoders.convnext.ConvNeXtEncoder 14 | input_channels: 512 15 | depths: [6, 3] 16 | dims: [512, 1024] 17 | drop_path_rate: 0.1 18 | kernel_sizes: [7] 19 | head: 20 | _target_: fish_vocoder.modules.generators.vocos.ISTFTHead 21 | dim: "${eval: '${..backbone.dims} [-1]'}" 22 | n_fft: ${model.n_fft} 23 | hop_length: ${model.hop_length} 24 | win_length: ${model.win_length} 25 | padding: same 26 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/generator/vocos-small.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_vocoder.modules.generators.vocos.VocosGenerator 2 | backbone: 3 | _target_: fish_vocoder.modules.encoders.convnext.ConvNeXtEncoder 4 | input_channels: ${model.num_mels} 5 | depths: [8] 6 | dims: [512] 7 | drop_path_rate: 0.1 8 | kernel_sizes: [7] 9 | head: 10 | _target_: fish_vocoder.modules.generators.vocos.ISTFTHead 11 | dim: 512 12 | n_fft: ${model.n_fft} 13 | hop_length: ${model.hop_length} 14 | win_length: ${model.win_length} 15 | padding: same 16 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/generator/vocos.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_vocoder.modules.generators.unify.UnifyGenerator 2 | backbone: 3 | _target_: fish_vocoder.modules.encoders.convnext.ConvNeXtEncoder 4 | input_channels: ${model.num_mels} 5 | depths: [3, 3, 27, 3] 6 | dims: [128, 256, 512, 1024] 7 | 
drop_path_rate: 0.4 8 | kernel_size: 7 9 | head: 10 | _target_: fish_vocoder.modules.generators.vocos.ISTFTHead 11 | dim: 1024 12 | n_fft: ${model.n_fft} 13 | hop_length: ${model.hop_length} 14 | win_length: ${model.win_length} 15 | padding: same 16 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/resolution/24000_2048_3072.yaml: -------------------------------------------------------------------------------- 1 | sampling_rate: 24000 2 | num_mels: 100 3 | n_fft: 3072 4 | hop_length: 2048 5 | win_length: 3072 6 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/resolution/24000_256_1024.yaml: -------------------------------------------------------------------------------- 1 | sampling_rate: 24000 2 | num_mels: 100 3 | n_fft: 1024 4 | hop_length: 256 5 | win_length: 1024 6 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/resolution/44100_512_2048.yaml: -------------------------------------------------------------------------------- 1 | sampling_rate: 44100 2 | num_mels: 128 3 | n_fft: 2048 4 | hop_length: 512 5 | win_length: 2048 6 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/spectrogram/linear.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_vocoder.data.transforms.spectrogram.LinearSpectrogram 2 | n_fft: ${model.n_fft} 3 | hop_length: ${model.hop_length} 4 | win_length: ${model.win_length} 5 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/spectrogram/mel.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_vocoder.data.transforms.spectrogram.LogMelSpectrogram 2 | n_fft: ${model.n_fft} 3 | hop_length: ${model.hop_length} 4 | win_length: ${model.win_length} 5 | sample_rate: ${model.sampling_rate} 6 | n_mels: ${model.num_mels} 7 | f_min: 0 8 | f_max: "${eval: '${model.sampling_rate} // 2'}" 9 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/vae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gan 3 | - override spectrogram@mel_transforms.modules.input: linear 4 | - _self_ 5 | 6 | _target_: fish_vocoder.models.vae.VAEModel 7 | latent_size: 256 8 | -------------------------------------------------------------------------------- /fish_vocoder/configs/model/vqvae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gan 3 | - override resolution@_here_: "24000_2048_3072" 4 | - override generator: vocos-small-vae 5 | - override spectrogram@mel_transforms.modules.input: null 6 | - _self_ 7 | 8 | _target_: fish_vocoder.models.vae.VQVAEModel 9 | codebook_size: 4096 10 | num_quantizers: 1 11 | latent_size: 512 12 | num_frames: 32 13 | # crop_length: "${eval: '${model.hop_length} * 8'}" 14 | 15 | # Reduce discriminator periods to save memory. 
16 | discriminators: 17 | modules: 18 | mpd: 19 | periods: [2, 3, 5, 7, 11] 20 | 21 | mrd: 22 | resolutions: 23 | - ["${model.n_fft}", "${model.hop_length}", "${model.win_length}"] 24 | - [1024, 120, 600] 25 | - [2048, 240, 1200] 26 | - [4096, 480, 2400] 27 | -------------------------------------------------------------------------------- /fish_vocoder/configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to runs directory 13 | run_dir: ${paths.log_dir}/${task_name} 14 | 15 | # path to ckpt directory 16 | ckpt_dir: ${paths.run_dir}/checkpoints/ 17 | 18 | # path to output directory, created dynamically by hydra 19 | # path generation pattern is specified in `configs/hydra/default.yaml` 20 | # use it to store all files generated during the run, like ckpts and metrics 21 | output_dir: ${hydra:runtime.output_dir} 22 | 23 | # path to working directory 24 | work_dir: ${hydra:runtime.cwd} 25 | -------------------------------------------------------------------------------- /fish_vocoder/configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: vocoder 8 | - model: gan 9 | - callbacks: default 10 | - logger: wandb 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. 
`python train.py debug=default`) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | # tags to help you identify your experiments 34 | # you can overwrite this in experiment configs 35 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 36 | tags: ["dev"] 37 | 38 | # set False to skip model training 39 | train: True 40 | 41 | # evaluate on test set, using best model weights achieved during training 42 | # lightning chooses best weights based on the metric specified in checkpoint callback 43 | test: False 44 | 45 | # compile model for faster training with pytorch 2.0 46 | compile: False 47 | 48 | # simply provide checkpoint path to resume training 49 | ckpt_path: null 50 | resume_weights_only: False 51 | 52 | # seed for random number generators in pytorch, numpy and python.random 53 | seed: 594461 54 | 55 | # Inference Arguments 56 | input_path: null 57 | output_path: null 58 | pitch_shift: 0 59 | -------------------------------------------------------------------------------- /fish_vocoder/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | max_steps: 10_000_000 # 10M steps 5 | 6 | accelerator: cuda 7 | devices: auto 8 | num_nodes: 1 9 | strategy: ddp_find_unused_parameters_true 10 | 11 | # 32-bit precision (tf32 on compatible GPUs) 12 | precision: "32" 13 | 14 | # disable validation by epoch end 15 | check_val_every_n_epoch: null 16 | val_check_interval: 5000 17 | 18 | # set True to ensure deterministic results 19 | # makes training slower but gives more reproducibility than just setting seeds 20 | deterministic: False 21 | 22 | # Use torch.backends.cudnn.benchmark to speed up training 23 | benchmark: True 24 | -------------------------------------------------------------------------------- /fish_vocoder/data/datamodules/naive.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import lightning as L 4 | from torch.utils.data import DataLoader, Dataset, IterableDataset 5 | 6 | 7 | class NaiveDataModule(L.LightningDataModule): 8 | def __init__( 9 | self, 10 | datasets: Dict[str, Dataset], 11 | batch_size: int, 12 | num_workers: int, 13 | pin_memory: bool = False, 14 | drop_last: bool = False, 15 | persistent_workers: bool = True, 16 | collate_fn: Optional[callable] = None, 17 | train_batch_size: Optional[int] = None, 18 | val_batch_size: Optional[int] = None, 19 | test_batch_size: Optional[int] = None, 20 | ): 21 | super().__init__() 22 | 23 | self.splits: dict[str, Dataset] = datasets 24 | self.batch_size = batch_size 25 | self.train_batch_size = train_batch_size or batch_size 26 | self.val_batch_size = val_batch_size or batch_size 27 | self.test_batch_size = test_batch_size or batch_size 28 | self.num_workers = num_workers 29 | self.pin_memory = pin_memory 30 | self.drop_last = drop_last 31 | self.persistent_workers = persistent_workers 32 | self.collate_fn = collate_fn 33 | 34 | def train_dataloader(self) -> DataLoader: 35 | return DataLoader( 36 | self.splits["train"], 37 | batch_size=self.train_batch_size, 38 | num_workers=self.num_workers, 39 | pin_memory=self.pin_memory, 40 | drop_last=self.drop_last, 41 | persistent_workers=self.persistent_workers, 42 | collate_fn=self.collate_fn, 43 | shuffle=not isinstance(self.splits["train"], IterableDataset), 44 | ) 45 |
46 | def val_dataloader(self) -> Optional[DataLoader]: 47 | return DataLoader( 48 | self.splits["val"], 49 | batch_size=self.val_batch_size, 50 | num_workers=self.num_workers, 51 | pin_memory=self.pin_memory, 52 | drop_last=self.drop_last, 53 | persistent_workers=self.persistent_workers, 54 | collate_fn=self.collate_fn, 55 | shuffle=False, 56 | ) 57 | 58 | def test_dataloader(self) -> Optional[DataLoader]: 59 | # This probably won't be used 60 | 61 | assert self.batch_size == 1, "Batch size must be 1 for test set" 62 | 63 | return DataLoader( 64 | self.splits["test"], 65 | batch_size=self.test_batch_size, 66 | num_workers=self.num_workers, 67 | pin_memory=self.pin_memory, 68 | drop_last=False, 69 | persistent_workers=self.persistent_workers, 70 | collate_fn=self.collate_fn, 71 | shuffle=False, 72 | ) 73 | -------------------------------------------------------------------------------- /fish_vocoder/data/datasets/mix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from torch.distributed import get_rank, is_initialized 7 | from torch.utils.data import IterableDataset 8 | 9 | 10 | class MixDatast(IterableDataset): 11 | def __init__(self, datasets: dict[str, dict]): 12 | values = list(datasets.values()) 13 | probs = [v["prob"] for v in values] 14 | self.datasets = [v["dataset"] for v in values] 15 | 16 | total_probs = sum(probs) 17 | self.probs = [p / total_probs for p in probs] 18 | 19 | def __iter__(self): 20 | rank = get_rank() if is_initialized() else 0 21 | seed = (42 + rank * 114 + os.getpid() * 514) % 2**32 22 | 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | 28 | while True: 29 | # Randomly select a dataset 30 | dataset = random.choices(self.datasets, weights=self.probs)[0] 31 | data = random.choice(dataset) 32 | 33 | yield data 34 | -------------------------------------------------------------------------------- /fish_vocoder/data/datasets/vocoder.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch.utils.data import Dataset 7 | 8 | from fish_vocoder.utils.file import AUDIO_EXTENSIONS, list_files 9 | 10 | 11 | class VocoderDataset(Dataset): 12 | def __init__( 13 | self, 14 | root: str | Path, 15 | transform: Optional[Callable[[Tensor], Tensor]] = None, 16 | ) -> None: 17 | super().__init__() 18 | 19 | assert Path(root).exists(), f"Path {root} does not exist." 20 | assert transform is not None, "transform must be provided." 
21 | 22 | root = Path(root) 23 | 24 | if root.is_dir(): 25 | self.audio_paths = list_files(root, AUDIO_EXTENSIONS, recursive=True) 26 | else: 27 | self.audio_paths = root.read_text().splitlines() 28 | 29 | self.transform = transform 30 | 31 | def __len__(self): 32 | return len(self.audio_paths) 33 | 34 | def __getitem__(self, idx): 35 | audio = self.audio_paths[idx] 36 | audio = self.transform(audio) 37 | 38 | # Do normalization to avoid clipping 39 | if audio.abs().max() >= 1.0: 40 | audio /= audio.abs().max() / 0.99 41 | 42 | return { 43 | "audio": audio, 44 | } 45 | 46 | 47 | def collate_fn(batch): 48 | lengths = [b["audio"].shape[-1] for b in batch] 49 | max_len = max(lengths) 50 | 51 | for i, b in enumerate(batch): 52 | pad = max_len - b["audio"].shape[-1] 53 | batch[i]["audio"] = torch.nn.functional.pad(b["audio"], (0, pad)) 54 | 55 | return { 56 | "audio": torch.stack([b["audio"] for b in batch]), 57 | "lengths": torch.tensor(lengths, dtype=torch.long), 58 | } 59 | -------------------------------------------------------------------------------- /fish_vocoder/data/transforms/crop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | 5 | class RandomCrop(nn.Module): 6 | def __init__( 7 | self, 8 | probability: float = 1.0, 9 | crop_length: int = 44100 * 3, 10 | ) -> None: 11 | super().__init__() 12 | 13 | self.probability = probability 14 | self.crop_length = crop_length 15 | 16 | def forward(self, waveform: Tensor) -> Tensor: 17 | if torch.rand(1) > self.probability: 18 | return waveform 19 | 20 | if waveform.shape[-1] <= self.crop_length: 21 | return waveform 22 | 23 | start_idx = torch.randint(0, waveform.shape[-1] - self.crop_length, (1,)).item() 24 | end_idx = start_idx + self.crop_length 25 | 26 | return waveform[..., start_idx:end_idx] 27 | -------------------------------------------------------------------------------- /fish_vocoder/data/transforms/discontinuous.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | 5 | class RandomDiscontinuous(nn.Module): 6 | def __init__( 7 | self, 8 | probability: float = 1.0, 9 | silent_range: tuple[float, float] = (0.01, 0.1), 10 | silent_ratio_range: tuple[float, float] = (0.1, 0.2), 11 | sampling_rate: int = 44100, 12 | ) -> None: 13 | super().__init__() 14 | 15 | self.probability = probability 16 | self.sampling_rate = sampling_rate 17 | self.silent_range = ( 18 | int(silent_range[0] * sampling_rate), 19 | int(silent_range[1] * sampling_rate), 20 | ) 21 | self.silent_ratio_range = silent_ratio_range 22 | 23 | def forward(self, waveform: Tensor) -> Tensor: 24 | if torch.rand(1) > self.probability: 25 | return waveform 26 | 27 | current_silent_length = 0 28 | total_silent_length = torch.randint( 29 | int(self.silent_ratio_range[0] * waveform.shape[-1]), 30 | int(self.silent_ratio_range[1] * waveform.shape[-1]), 31 | (1,), 32 | ).item() 33 | 34 | while current_silent_length < total_silent_length: 35 | silent_length = torch.randint(*self.silent_range, (1,)).item() 36 | start_idx = torch.randint( 37 | 0, waveform.shape[-1] - silent_length, (1,) 38 | ).item() 39 | end_idx = start_idx + silent_length 40 | current_silent_length += silent_length 41 | 42 | # 0: all silent, 1: linear fade in and out 43 | silent_mode = torch.randint(0, 2, (1,)).item() 44 | 45 | if silent_mode == 0: 46 | waveform[..., start_idx:end_idx] = 0 47 | elif silent_mode == 1: 48 | waveform[..., 
start_idx:end_idx] *= torch.cat( 49 | ( 50 | torch.linspace(0, 1, silent_length // 2), 51 | torch.linspace(1, 0, silent_length - silent_length // 2), 52 | ) 53 | ) 54 | 55 | return waveform 56 | -------------------------------------------------------------------------------- /fish_vocoder/data/transforms/hq_pitch_shift.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio.functional as AF 3 | from torch import Tensor, nn 4 | 5 | 6 | class RandomHQPitchShift(nn.Module): 7 | def __init__( 8 | self, 9 | probability: float = 1.0, 10 | pitch_range: int | tuple[int, int] = 12, 11 | sampling_rate: int = 44100, 12 | ) -> None: 13 | super().__init__() 14 | 15 | self.probability = probability 16 | 17 | if isinstance(pitch_range, int): 18 | pitch_range = (-pitch_range, pitch_range) 19 | 20 | self.pitch_range = pitch_range 21 | self.sampling_rate = sampling_rate 22 | 23 | def forward(self, waveform: Tensor) -> Tensor: 24 | if torch.rand(1) > self.probability: 25 | return waveform 26 | 27 | pitch_shift = torch.randint(*self.pitch_range, (1,)).item() 28 | duration_shift = 2 ** (pitch_shift / 12) 29 | 30 | orig_freq = round(self.sampling_rate * duration_shift) 31 | orig_freq = orig_freq - (orig_freq % 100) # avoid creating lots of windows 32 | 33 | y = AF.resample(waveform, orig_freq=orig_freq, new_freq=self.sampling_rate) 34 | 35 | return y 36 | -------------------------------------------------------------------------------- /fish_vocoder/data/transforms/load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from torch import Tensor, nn 4 | from torchaudio import functional as AF 5 | 6 | 7 | class LoadAudio(nn.Module): 8 | def __init__(self, sampling_rate: int = 44100, to_mono: bool = True): 9 | super().__init__() 10 | 11 | self.sampling_rate = sampling_rate 12 | self.to_mono = to_mono 13 | 14 | def forward(self, audio_path: str) -> Tensor: 15 | try: 16 | audio, sr = torchaudio.load(audio_path) 17 | except Exception: 18 | audio, sr = ( 19 | torch.zeros((1, self.sampling_rate * 10), dtype=torch.float32), 20 | self.sampling_rate, 21 | ) 22 | 23 | audio = AF.resample(audio, orig_freq=sr, new_freq=self.sampling_rate) 24 | 25 | # If audio is not mono, convert it to mono 26 | if self.to_mono and audio.shape[0] > 1: 27 | audio = audio.mean(dim=0, keepdim=True) 28 | 29 | return audio 30 | -------------------------------------------------------------------------------- /fish_vocoder/data/transforms/loudness.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | 5 | class RandomLoudness(nn.Module): 6 | def __init__( 7 | self, probability: float = 1.0, loudness_range: tuple[float, float] = (0.1, 0.9) 8 | ) -> None: 9 | super().__init__() 10 | 11 | self.probability = probability 12 | self.loudness_range = loudness_range 13 | 14 | def forward(self, waveform: Tensor) -> Tensor: 15 | if torch.rand(1) > self.probability: 16 | return waveform 17 | 18 | new_loudness = ( 19 | torch.rand(1).item() * (self.loudness_range[1] - self.loudness_range[0]) 20 | + self.loudness_range[0] 21 | ) 22 | max_loudness = torch.max(torch.abs(waveform)) 23 | waveform = waveform * (new_loudness / (max_loudness + 1e-5)) 24 | 25 | return waveform 26 | 27 | 28 | class LoudnessNorm(nn.Module): 29 | def __init__(self, probability: float = 1.0) -> None: 30 | super().__init__() 31 | 32 | self.probability = probability 33 | 34 | def
forward(self, waveform: Tensor) -> Tensor: 35 | if torch.rand(1) > self.probability: 36 | return waveform 37 | 38 | max_loudness = torch.max(torch.abs(waveform)) 39 | waveform = waveform / (max_loudness + 1e-5) 40 | 41 | return waveform 42 | -------------------------------------------------------------------------------- /fish_vocoder/data/transforms/pad.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch import Tensor, nn 4 | 5 | 6 | class Pad(nn.Module): 7 | def __init__( 8 | self, 9 | multiple_of: Optional[int] = None, 10 | target_length: Optional[int] = None, 11 | ) -> None: 12 | super().__init__() 13 | 14 | assert ( 15 | multiple_of is not None or target_length is not None 16 | ), "Either multiple_of or target_length must be specified." 17 | assert ( 18 | multiple_of is None or target_length is None 19 | ), "Only one of multiple_of or target_length must be specified." 20 | 21 | self.multiple_of = multiple_of 22 | self.target_length = target_length 23 | 24 | def forward(self, waveform: Tensor) -> Tensor: 25 | if self.multiple_of is not None: 26 | pad = self.multiple_of - (waveform.shape[-1] % self.multiple_of) 27 | 28 | if pad == self.multiple_of: 29 | return waveform 30 | else: 31 | pad = self.target_length - waveform.shape[-1] 32 | 33 | return nn.functional.pad(waveform, (pad // 2, pad - (pad // 2)), "constant") 34 | -------------------------------------------------------------------------------- /fish_vocoder/data/transforms/spectrogram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | from torchaudio.transforms import MelScale 4 | 5 | 6 | class LinearSpectrogram(nn.Module): 7 | def __init__( 8 | self, 9 | n_fft=2048, 10 | win_length=2048, 11 | hop_length=512, 12 | center=False, 13 | mode="pow2_sqrt", 14 | ): 15 | super().__init__() 16 | 17 | self.n_fft = n_fft 18 | self.win_length = win_length 19 | self.hop_length = hop_length 20 | self.center = center 21 | self.mode = mode 22 | 23 | self.register_buffer("window", torch.hann_window(win_length)) 24 | 25 | def forward(self, y: Tensor) -> Tensor: 26 | if y.ndim == 3: 27 | y = y.squeeze(1) 28 | 29 | y = torch.nn.functional.pad( 30 | y.unsqueeze(1), 31 | ( 32 | (self.win_length - self.hop_length) // 2, 33 | (self.win_length - self.hop_length + 1) // 2, 34 | ), 35 | mode="reflect", 36 | ).squeeze(1) 37 | 38 | spec = torch.stft( 39 | y, 40 | self.n_fft, 41 | hop_length=self.hop_length, 42 | win_length=self.win_length, 43 | window=self.window, 44 | center=self.center, 45 | pad_mode="reflect", 46 | normalized=False, 47 | onesided=True, 48 | return_complex=True, 49 | ) 50 | 51 | spec = torch.view_as_real(spec) 52 | 53 | if self.mode == "pow2_sqrt": 54 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 55 | 56 | return spec 57 | 58 | 59 | class LogMelSpectrogram(nn.Module): 60 | def __init__( 61 | self, 62 | sample_rate=44100, 63 | n_fft=2048, 64 | win_length=2048, 65 | hop_length=512, 66 | n_mels=128, 67 | center=False, 68 | f_min=0.0, 69 | f_max=None, 70 | ): 71 | super().__init__() 72 | 73 | self.sample_rate = sample_rate 74 | self.n_fft = n_fft 75 | self.win_length = win_length 76 | self.hop_length = hop_length 77 | self.center = center 78 | self.n_mels = n_mels 79 | self.f_min = f_min 80 | self.f_max = f_max or sample_rate // 2 81 | 82 | self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) 83 | self.mel_scale = MelScale( 84 | self.n_mels, 85 | 
self.sample_rate, 86 | self.f_min, 87 | self.f_max, 88 | self.n_fft // 2 + 1, 89 | "slaney", 90 | "slaney", 91 | ) 92 | 93 | def compress(self, x: Tensor) -> Tensor: 94 | return torch.log(torch.clamp(x, min=1e-5)) 95 | 96 | def decompress(self, x: Tensor) -> Tensor: 97 | return torch.exp(x) 98 | 99 | def forward(self, x: Tensor) -> Tensor: 100 | x = self.spectrogram(x) 101 | x = self.mel_scale(x) 102 | x = self.compress(x) 103 | 104 | return x 105 | -------------------------------------------------------------------------------- /fish_vocoder/eval.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | 4 | import click 5 | import librosa 6 | import numpy as np 7 | import torch 8 | import torchaudio 9 | from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality 10 | from tqdm import tqdm 11 | 12 | from fish_vocoder.data.transforms.spectrogram import LogMelSpectrogram 13 | 14 | 15 | def pesq_nb(target, preds, sr): 16 | target = torchaudio.functional.resample(target, orig_freq=sr, new_freq=8000) 17 | preds = torchaudio.functional.resample(preds, orig_freq=sr, new_freq=8000) 18 | 19 | return perceptual_evaluation_speech_quality(preds, target, 8000, "nb").item() 20 | 21 | 22 | def pesq_wb(target, preds, sr): 23 | target = torchaudio.functional.resample(target, orig_freq=sr, new_freq=16000) 24 | preds = torchaudio.functional.resample(preds, orig_freq=sr, new_freq=16000) 25 | 26 | return perceptual_evaluation_speech_quality(preds, target, 16000, "wb").item() 27 | 28 | 29 | def spec_difference(spec, target, preds): 30 | target = spec(target[None]) 31 | preds = spec(preds[None]) 32 | 33 | return torch.mean(torch.abs(target - preds)).item() 34 | 35 | 36 | @click.command() 37 | @click.argument("source", type=click.Path(exists=True, dir_okay=True, file_okay=False)) 38 | @click.argument( 39 | "generated", type=click.Path(exists=True, dir_okay=True, file_okay=False) 40 | ) 41 | @click.option("--sr", default=24000) 42 | @click.option("--glob-pattern", default="*.wav") 43 | @click.option("--is-vocal/--is-instrumental", default=True) 44 | def main(source, generated, sr, glob_pattern, is_vocal): 45 | source = Path(source) 46 | generated = Path(generated) 47 | 48 | assert source.is_dir() 49 | assert generated.is_dir() 50 | 51 | source_files = sorted(list(source.rglob(glob_pattern))) 52 | scores = defaultdict(list) 53 | bar = tqdm(source_files) 54 | 55 | mel_spec = LogMelSpectrogram(sr, 1024, 1024, 256, 128, center=False) 56 | 57 | for idx, source_file in enumerate(tqdm(source_files)): 58 | generated_file = generated / source_file.relative_to(source) 59 | 60 | if not generated_file.exists(): 61 | generated_file = generated_file.with_suffix(".flac") 62 | 63 | if not generated_file.exists(): 64 | print(f"{generated_file} does not exist") 65 | continue 66 | 67 | source_audio, _ = librosa.load(source_file, sr=sr) 68 | generated_audio, _ = librosa.load(generated_file, sr=sr) 69 | 70 | min_len = min(len(source_audio), len(generated_audio)) 71 | assert max(len(source_audio) - min_len, len(generated_audio) - min_len) < 1000 72 | 73 | source_audio = source_audio[:min_len] 74 | generated_audio = generated_audio[:min_len] 75 | 76 | source_audio = torch.from_numpy(source_audio) 77 | generated_audio = torch.from_numpy(generated_audio) 78 | 79 | try: 80 | if is_vocal: 81 | scores["pesq_nb"].append(pesq_nb(source_audio, generated_audio, sr)) 82 | scores["pesq_wb"].append(pesq_wb(source_audio, 
generated_audio, sr)) 83 | 84 | scores["spec_diff"].append( 85 | spec_difference(mel_spec, source_audio, generated_audio) 86 | ) 87 | except Exception: 88 | print(f"Error processing {source_file}") 89 | continue 90 | 91 | if idx % 10 == 0: 92 | all_metrics = [f"{k}: {np.mean(v):.2f}" for k, v in scores.items()] 93 | bar.set_description(", ".join(all_metrics)) 94 | 95 | print("Average scores:") 96 | for k, v in scores.items(): 97 | print(f" {k}: {np.mean(v):.2f}") 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /fish_vocoder/models/gan.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from fish_vocoder.models.vocoder import VocoderModel 8 | from fish_vocoder.modules.losses.stft import MultiResolutionSTFTLoss 9 | from fish_vocoder.utils.grad_norm import grad_norm 10 | from fish_vocoder.utils.mask import sequence_mask 11 | 12 | 13 | class GANModel(VocoderModel): 14 | def __init__( 15 | self, 16 | sampling_rate: int, 17 | n_fft: int, 18 | hop_length: int, 19 | win_length: int, 20 | num_mels: int, 21 | optimizer: Callable, 22 | lr_scheduler: Callable, 23 | mel_transforms: nn.ModuleDict, 24 | generator: nn.Module, 25 | discriminators: nn.ModuleDict, 26 | multi_resolution_stft_loss: MultiResolutionSTFTLoss, 27 | num_frames: int, 28 | crop_length: int | None = None, 29 | ): 30 | super().__init__( 31 | sampling_rate=sampling_rate, 32 | n_fft=n_fft, 33 | hop_length=hop_length, 34 | win_length=win_length, 35 | num_mels=num_mels, 36 | ) 37 | 38 | # Model parameters 39 | self.optimizer_builder = optimizer 40 | self.lr_scheduler_builder = lr_scheduler 41 | 42 | # Spectrogram transforms 43 | self.mel_transforms = mel_transforms 44 | 45 | # Generator and discriminators 46 | self.generator = generator 47 | self.discriminators = discriminators 48 | 49 | # Loss 50 | self.multi_resolution_stft_loss = multi_resolution_stft_loss 51 | 52 | # Crop length for saving memory 53 | self.num_frames = num_frames 54 | self.crop_length = crop_length 55 | 56 | # Disable automatic optimization 57 | self.automatic_optimization = False 58 | 59 | def configure_optimizers(self): 60 | # Need two optimizers and two schedulers 61 | optimizer_generator = self.optimizer_builder(self.generator.parameters()) 62 | optimizer_discriminator = self.optimizer_builder( 63 | self.discriminators.parameters() 64 | ) 65 | 66 | lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator) 67 | lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator) 68 | 69 | return ( 70 | { 71 | "optimizer": optimizer_generator, 72 | "lr_scheduler": { 73 | "scheduler": lr_scheduler_generator, 74 | "interval": "step", 75 | "name": "optimizer/generator", 76 | }, 77 | }, 78 | { 79 | "optimizer": optimizer_discriminator, 80 | "lr_scheduler": { 81 | "scheduler": lr_scheduler_discriminator, 82 | "interval": "step", 83 | "name": "optimizer/discriminator", 84 | }, 85 | }, 86 | ) 87 | 88 | def training_generator(self, audio, audio_mask): 89 | fake_audio, base_loss = self.forward(audio, audio_mask) 90 | 91 | assert fake_audio.shape == audio.shape 92 | 93 | # Apply mask 94 | audio = audio * audio_mask 95 | fake_audio = fake_audio * audio_mask 96 | 97 | # Multi-Resolution STFT Loss 98 | sc_loss, mag_loss = self.multi_resolution_stft_loss( 99 | fake_audio.squeeze(1), audio.squeeze(1) 100 | 
) 101 | loss_stft = sc_loss + mag_loss 102 | 103 | self.log( 104 | "train/generator/stft", 105 | loss_stft, 106 | on_step=True, 107 | on_epoch=False, 108 | prog_bar=True, 109 | logger=True, 110 | sync_dist=True, 111 | ) 112 | 113 | # L1 Mel-Spectrogram Loss 114 | # This is not used in backpropagation currently 115 | audio_mel = self.mel_transforms.loss(audio.squeeze(1)) 116 | fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1)) 117 | loss_mel = F.l1_loss(audio_mel, fake_audio_mel) 118 | 119 | self.log( 120 | "train/generator/mel", 121 | loss_mel, 122 | on_step=True, 123 | on_epoch=False, 124 | prog_bar=True, 125 | logger=True, 126 | sync_dist=True, 127 | ) 128 | 129 | # Now, we need to reduce the length of the audio to save memory 130 | if self.crop_length is not None and audio.shape[2] > self.crop_length: 131 | slice_idx = torch.randint(0, audio.shape[-1] - self.crop_length, (1,)) 132 | 133 | audio = audio[..., slice_idx : slice_idx + self.crop_length] 134 | fake_audio = fake_audio[..., slice_idx : slice_idx + self.crop_length] 135 | audio_mask = audio_mask[..., slice_idx : slice_idx + self.crop_length] 136 | 137 | assert audio.shape == fake_audio.shape == audio_mask.shape 138 | 139 | # Adv Loss 140 | loss_adv_all = 0 141 | 142 | for key, disc in self.discriminators.items(): 143 | score_fakes, feat_fake = disc(fake_audio) 144 | _, feat_real = disc(audio) 145 | 146 | # Adversarial Loss 147 | loss_fake = 0 148 | for score_fake in score_fakes: 149 | loss_fake += torch.mean((1 - score_fake) ** 2) 150 | 151 | # Feature Matching Loss 152 | loss_fm = 0 153 | for dr, dg in zip(feat_real, feat_fake): 154 | for rl, gl in zip(dr, dg): 155 | loss_fm += F.l1_loss(rl, gl) 156 | 157 | self.log( 158 | f"train/generator/adv_{key}", 159 | loss_fake, 160 | on_step=True, 161 | on_epoch=False, 162 | prog_bar=False, 163 | logger=True, 164 | sync_dist=True, 165 | ) 166 | 167 | self.log( 168 | f"train/generator/adv_fm_{key}", 169 | loss_fm, 170 | on_step=True, 171 | on_epoch=False, 172 | prog_bar=False, 173 | logger=True, 174 | sync_dist=True, 175 | ) 176 | 177 | loss_adv_all += loss_fake + loss_fm 178 | 179 | loss_adv_all /= len(self.discriminators) 180 | loss_gen_all = base_loss + loss_stft * 2.5 + loss_mel * 45 + loss_adv_all 181 | 182 | self.log( 183 | "train/generator/all", 184 | loss_gen_all, 185 | on_step=True, 186 | on_epoch=False, 187 | prog_bar=True, 188 | logger=True, 189 | sync_dist=True, 190 | ) 191 | 192 | return loss_gen_all, audio, fake_audio 193 | 194 | def training_discriminator(self, audio, fake_audio): 195 | loss_disc_all = 0 196 | 197 | for key, disc in self.discriminators.items(): 198 | scores, _ = disc(audio) 199 | score_fakes, _ = disc(fake_audio.detach()) 200 | 201 | loss_disc = 0 202 | 203 | for score, score_fake in zip(scores, score_fakes): 204 | loss_disc += torch.mean((score - 1) ** 2) + torch.mean( 205 | (score_fake) ** 2 206 | ) 207 | 208 | self.log( 209 | f"train/discriminator/{key}", 210 | loss_disc, 211 | on_step=True, 212 | on_epoch=False, 213 | prog_bar=False, 214 | logger=True, 215 | sync_dist=True, 216 | ) 217 | 218 | loss_disc_all += loss_disc 219 | 220 | loss_disc_all /= len(self.discriminators) 221 | 222 | self.log( 223 | "train/discriminator/all", 224 | loss_disc_all, 225 | on_step=True, 226 | on_epoch=False, 227 | prog_bar=True, 228 | logger=True, 229 | sync_dist=True, 230 | ) 231 | 232 | return loss_disc_all 233 | 234 | def training_step(self, batch, batch_idx): 235 | optim_g, optim_d = self.optimizers() 236 | 237 | audio, lengths = batch["audio"], 
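# Least-squares GAN objective used above for each sub-discriminator output `score`
# (a minimal sketch; the training code additionally sums the feature-matching L1
# term over every intermediate feature map):
#   generator:      mean((1 - D(G(s)))^2)
#   discriminator:  mean((D(x) - 1)^2) + mean(D(G(s))^2)
import torch

def lsgan_generator_loss(score_fake: torch.Tensor) -> torch.Tensor:
    return torch.mean((1 - score_fake) ** 2)

def lsgan_discriminator_loss(score_real: torch.Tensor, score_fake: torch.Tensor) -> torch.Tensor:
    return torch.mean((score_real - 1) ** 2) + torch.mean(score_fake**2)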
batch["lengths"] 238 | audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32) 239 | 240 | # Generator 241 | optim_g.zero_grad() 242 | loss_gen_all, audio, fake_audio = self.training_generator(audio, audio_mask) 243 | self.manual_backward(loss_gen_all) 244 | 245 | self.log( 246 | "train/generator/grad_norm", 247 | grad_norm(self.generator.parameters()), 248 | on_step=True, 249 | on_epoch=False, 250 | prog_bar=False, 251 | logger=True, 252 | sync_dist=True, 253 | ) 254 | 255 | optim_g.step() 256 | 257 | # Discriminator 258 | assert fake_audio.shape == audio.shape 259 | 260 | optim_d.zero_grad() 261 | loss_disc_all = self.training_discriminator(audio, fake_audio) 262 | self.manual_backward(loss_disc_all) 263 | 264 | for key, disc in self.discriminators.items(): 265 | self.log( 266 | f"train/discriminator/grad_norm_{key}", 267 | grad_norm(disc.parameters()), 268 | on_step=True, 269 | on_epoch=False, 270 | prog_bar=False, 271 | logger=True, 272 | sync_dist=True, 273 | ) 274 | 275 | optim_d.step() 276 | 277 | # Manual LR Scheduler 278 | scheduler_g, scheduler_d = self.lr_schedulers() 279 | scheduler_g.step() 280 | scheduler_d.step() 281 | 282 | def forward(self, audio, mask=None, input_spec=None): 283 | if input_spec is None: 284 | input_spec = self.mel_transforms.input(audio.squeeze(1)) 285 | 286 | fake_audio = self.generator(input_spec) 287 | 288 | return fake_audio, 0 289 | 290 | def validation_step(self, batch: Any, batch_idx: int): 291 | audio, lengths = batch["audio"], batch["lengths"] 292 | audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32) 293 | 294 | # Generator 295 | fake_audio, _ = self.forward(audio, audio_mask) 296 | assert fake_audio.shape == audio.shape 297 | 298 | # Apply mask 299 | audio = audio * audio_mask 300 | fake_audio = fake_audio * audio_mask 301 | 302 | # L1 Mel-Spectrogram Loss 303 | audio_mel = self.mel_transforms.loss(audio.squeeze(1)) 304 | fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1)) 305 | loss_mel = F.l1_loss(audio_mel, fake_audio_mel) 306 | 307 | self.log( 308 | "val/metrics/mel", 309 | loss_mel, 310 | on_step=False, 311 | on_epoch=True, 312 | prog_bar=True, 313 | logger=True, 314 | sync_dist=True, 315 | ) 316 | 317 | # Report other metrics 318 | self.report_val_metrics(fake_audio, audio, lengths) 319 | -------------------------------------------------------------------------------- /fish_vocoder/models/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from encodec.quantization.core_vq import ResidualVectorQuantization, VectorQuantization 3 | 4 | from fish_vocoder.models.gan import GANModel 5 | 6 | 7 | class VAEModel(GANModel): 8 | def __init__(self, latent_size: int, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | 11 | self.latent_size = latent_size 12 | 13 | def forward(self, audio, mask): 14 | input_spec = self.mel_transforms.input(audio.squeeze(1)) 15 | 16 | latent = self.generator.encoder(input_spec) 17 | mean, logvar = torch.chunk(latent, 2, dim=1) 18 | z = self.reparameterize(mean, logvar) 19 | fake_audio = self.generator.decoder(z) 20 | 21 | kl_loss = self.kl_loss(mean, logvar) 22 | 23 | self.log( 24 | "train/generator/kl", 25 | kl_loss, 26 | on_step=True, 27 | on_epoch=False, 28 | prog_bar=True, 29 | logger=True, 30 | sync_dist=True, 31 | ) 32 | 33 | return fake_audio, kl_loss 34 | 35 | def reparameterize(self, mean, logvar): 36 | if self.training: 37 | std = torch.exp(0.5 * logvar) 38 | eps = 
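# The reparameterization trick continued below: sampling z = mean + eps * std keeps
# the stochastic latent differentiable w.r.t. (mean, logvar) because eps ~ N(0, I)
# carries no parameters. Minimal standalone sketch:
import torch

def reparameterize(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    std = torch.exp(0.5 * logvar)   # logvar parameterizes log(sigma^2)
    eps = torch.randn_like(std)     # gradients do not flow through the noise
    return mean + eps * std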
torch.randn_like(std) 39 | 40 | return mean + eps * std 41 | 42 | return mean 43 | 44 | @staticmethod 45 | def kl_loss(mean, logvar): 46 | # B, D, T -> B, 1, T 47 | losses = 0.5 * (mean**2 + torch.exp(logvar) - logvar - 1) 48 | return losses.mean() 49 | 50 | 51 | class VQVAEModel(GANModel): 52 | def __init__( 53 | self, 54 | latent_size: int, 55 | codebook_size: int, 56 | num_quantizers: int = 1, 57 | *args, 58 | **kwargs, 59 | ): 60 | super().__init__(*args, **kwargs) 61 | 62 | self.latent_size = latent_size 63 | self.codebook_size = codebook_size 64 | self.num_quantizers = num_quantizers 65 | 66 | if num_quantizers > 1: 67 | self.vq = ResidualVectorQuantization( 68 | dim=latent_size, 69 | codebook_size=codebook_size, 70 | num_quantizers=num_quantizers, 71 | kmeans_init=False, 72 | ) 73 | else: 74 | self.vq = VectorQuantization( 75 | dim=latent_size, 76 | codebook_size=codebook_size, 77 | kmeans_init=False, 78 | ) 79 | 80 | def forward(self, audio, mask, input_spec=None): 81 | latent = self.generator.encoder(audio, mask) 82 | quantize, _, vq_loss = self.vq(latent) 83 | 84 | if self.num_quantizers > 1: 85 | vq_loss = vq_loss.mean() 86 | 87 | fake_audio = self.generator.decoder(quantize) 88 | 89 | assert abs(fake_audio.size(2) - audio.size(2)) <= self.hop_length 90 | 91 | if fake_audio.size(2) > audio.size(2): 92 | fake_audio = fake_audio[:, :, : audio.size(2)] 93 | else: 94 | fake_audio = torch.nn.functional.pad( 95 | fake_audio, (0, audio.size(2) - fake_audio.size(2)) 96 | ) 97 | 98 | stage = "train" if self.training else "val" 99 | self.log( 100 | f"{stage}/generator/vq", 101 | vq_loss, 102 | on_step=True, 103 | on_epoch=False, 104 | prog_bar=True, 105 | logger=True, 106 | sync_dist=True, 107 | ) 108 | 109 | return fake_audio, 0 # vq_loss * 5 110 | -------------------------------------------------------------------------------- /fish_vocoder/models/vocoder.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | import wandb 4 | from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger 5 | from matplotlib import pyplot as plt 6 | from torchaudio.functional import resample 7 | from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality 8 | 9 | from fish_vocoder.data.transforms.spectrogram import LogMelSpectrogram 10 | from fish_vocoder.utils.viz import plot_mel 11 | 12 | 13 | class VocoderModel(L.LightningModule): 14 | def __init__( 15 | self, 16 | sampling_rate: int, 17 | n_fft: int, 18 | hop_length: int, 19 | win_length: int, 20 | num_mels: int, 21 | ): 22 | super().__init__() 23 | 24 | # Base parameters 25 | self.sampling_rate = sampling_rate 26 | self.n_fft = n_fft 27 | self.hop_length = hop_length 28 | self.win_length = win_length 29 | self.num_mels = num_mels 30 | 31 | # Mel-Spectrogram for visualization 32 | self.viz_mel_transform = LogMelSpectrogram( 33 | sample_rate=self.sampling_rate, 34 | n_fft=self.n_fft, 35 | win_length=self.win_length, 36 | hop_length=self.hop_length, 37 | n_mels=self.num_mels, 38 | ) 39 | 40 | @torch.no_grad() 41 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 42 | def pesq(self, y_hat, y, sr=16000): 43 | y_hat = resample(y_hat, self.sampling_rate, sr) 44 | y = resample(y, self.sampling_rate, sr) 45 | 46 | return perceptual_evaluation_speech_quality(y_hat, y, sr, "wb").mean() 47 | 48 | @torch.no_grad() 49 | def report_val_metrics(self, y_g_hat, y, lengths): 50 | # PESQ 51 | pesq = self.pesq(y_g_hat, y) 52 | 53 | self.log( 54 | 
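# Closed-form KL term used by VAEModel.kl_loss above: for a diagonal Gaussian
# q = N(mean, exp(logvar)) against the standard normal prior N(0, I),
#   KL(q || N(0, I)) = 0.5 * (mean^2 + exp(logvar) - logvar - 1), averaged elementwise.
# Minimal sketch:
import torch

def kl_to_standard_normal(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    return (0.5 * (mean.pow(2) + logvar.exp() - logvar - 1)).mean()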
"val/metrics/pesq", 55 | pesq, 56 | on_step=False, 57 | on_epoch=True, 58 | prog_bar=False, 59 | logger=True, 60 | sync_dist=True, 61 | ) 62 | 63 | # Mel-Spectrogram 64 | y_mel = self.viz_mel_transform(y.squeeze(1)) 65 | y_g_hat_mel = self.viz_mel_transform(y_g_hat.squeeze(1)) 66 | 67 | for idx, (mel, gen_mel, audio, gen_audio, audio_len) in enumerate( 68 | zip(y_mel, y_g_hat_mel, y.detach().cpu(), y_g_hat.detach().cpu(), lengths) 69 | ): 70 | mel_len = audio_len // self.hop_length 71 | 72 | image_mels = plot_mel( 73 | [ 74 | gen_mel[:, :mel_len], 75 | mel[:, :mel_len], 76 | ], 77 | ["Sampled Spectrogram", "Ground-Truth Spectrogram"], 78 | ) 79 | 80 | if isinstance(self.logger, WandbLogger): 81 | self.logger.experiment.log( 82 | { 83 | "reconstruction_mel": wandb.Image(image_mels, caption="mels"), 84 | "wavs": [ 85 | wandb.Audio( 86 | audio[0, :audio_len], 87 | sample_rate=self.sampling_rate, 88 | caption="gt", 89 | ), 90 | wandb.Audio( 91 | gen_audio[0, :audio_len], 92 | sample_rate=self.sampling_rate, 93 | caption="prediction", 94 | ), 95 | ], 96 | }, 97 | ) 98 | 99 | if isinstance(self.logger, TensorBoardLogger): 100 | self.logger.experiment.add_figure( 101 | f"sample-{idx}/mels", 102 | image_mels, 103 | global_step=self.global_step, 104 | ) 105 | self.logger.experiment.add_audio( 106 | f"sample-{idx}/wavs/gt", 107 | audio[0, :audio_len], 108 | self.global_step, 109 | sample_rate=self.sampling_rate, 110 | ) 111 | self.logger.experiment.add_audio( 112 | f"sample-{idx}/wavs/prediction", 113 | gen_audio[0, :audio_len], 114 | self.global_step, 115 | sample_rate=self.sampling_rate, 116 | ) 117 | 118 | plt.close(image_mels) 119 | -------------------------------------------------------------------------------- /fish_vocoder/modules/discriminators/mpd.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.utils.parametrizations import spectral_norm, weight_norm 7 | 8 | 9 | class DiscriminatorP(nn.Module): 10 | def __init__( 11 | self, 12 | *, 13 | period: int, 14 | kernel_size: int = 5, 15 | stride: int = 3, 16 | use_spectral_norm: bool = False, 17 | channels: Optional[list[int]] = None, 18 | ) -> None: 19 | super(DiscriminatorP, self).__init__() 20 | 21 | self.period = period 22 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm 23 | 24 | if channels is None: 25 | channels = [1, 64, 128, 256, 512, 1024] 26 | 27 | self.convs = nn.ModuleList( 28 | [ 29 | norm_f( 30 | nn.Conv2d( 31 | in_channels, 32 | out_channels, 33 | (kernel_size, 1), 34 | (stride, 1), 35 | padding=(kernel_size // 2, 0), 36 | ) 37 | ) 38 | for in_channels, out_channels in zip(channels[:-1], channels[1:]) 39 | ] 40 | ) 41 | self.conv_post = norm_f(nn.Conv2d(channels[-1], 1, (3, 1), 1, padding=(1, 0))) 42 | 43 | def forward(self, x): 44 | fmap = [] 45 | 46 | # 1d to 2d 47 | b, c, t = x.shape 48 | if t % self.period != 0: # pad first 49 | n_pad = self.period - (t % self.period) 50 | x = F.pad(x, (0, n_pad), "constant") 51 | t = t + n_pad 52 | x = x.view(b, c, t // self.period, self.period) 53 | 54 | for conv in self.convs: 55 | x = conv(x) 56 | x = F.silu(x, inplace=True) 57 | fmap.append(x) 58 | 59 | x = self.conv_post(x) 60 | fmap.append(x) 61 | x = torch.flatten(x, 1, -1) 62 | 63 | return x, fmap 64 | 65 | 66 | class MultiPeriodDiscriminator(nn.Module): 67 | def __init__(self, periods: Optional[list[int]] = None): 68 | super().__init__() 69 | 70 
| if periods is None: 71 | periods = [2, 3, 5, 7, 11] 72 | 73 | self.discriminators = nn.ModuleList( 74 | [DiscriminatorP(period=period) for period in periods] 75 | ) 76 | 77 | def forward( 78 | self, x: torch.Tensor 79 | ) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]: 80 | scores, feature_map = [], [] 81 | 82 | for disc in self.discriminators: 83 | res, fmap = disc(x) 84 | 85 | scores.append(res) 86 | feature_map.append(fmap) 87 | 88 | return scores, feature_map 89 | -------------------------------------------------------------------------------- /fish_vocoder/modules/discriminators/mrd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.parametrizations import spectral_norm, weight_norm 5 | 6 | 7 | class DiscriminatorR(torch.nn.Module): 8 | def __init__( 9 | self, 10 | *, 11 | n_fft: int = 1024, 12 | hop_length: int = 120, 13 | win_length: int = 600, 14 | use_spectral_norm: bool = False, 15 | ): 16 | super(DiscriminatorR, self).__init__() 17 | 18 | self.n_fft = n_fft 19 | self.hop_length = hop_length 20 | self.win_length = win_length 21 | 22 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm 23 | 24 | self.convs = nn.ModuleList( 25 | [ 26 | norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), 27 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 28 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 29 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 30 | norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), 31 | ] 32 | ) 33 | self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) 34 | 35 | def forward(self, x): 36 | fmap = [] 37 | 38 | x = self.spectrogram(x) 39 | x = x.unsqueeze(1) 40 | 41 | for conv in self.convs: 42 | x = conv(x) 43 | x = F.silu(x, inplace=True) 44 | fmap.append(x) 45 | 46 | x = self.conv_post(x) 47 | fmap.append(x) 48 | x = torch.flatten(x, 1, -1) 49 | 50 | return x, fmap 51 | 52 | def spectrogram(self, x): 53 | x = F.pad( 54 | x, 55 | ( 56 | (self.n_fft - self.hop_length) // 2, 57 | (self.n_fft - self.hop_length + 1) // 2, 58 | ), 59 | mode="reflect", 60 | ) 61 | x = x.squeeze(1) 62 | x = torch.stft( 63 | x, 64 | n_fft=self.n_fft, 65 | hop_length=self.hop_length, 66 | win_length=self.win_length, 67 | center=False, 68 | return_complex=True, 69 | ) 70 | x = torch.view_as_real(x) # [B, F, TT, 2] 71 | mag = torch.norm(x, p=2, dim=-1) # [B, F, TT] 72 | 73 | return mag 74 | 75 | 76 | class MultiResolutionDiscriminator(torch.nn.Module): 77 | def __init__(self, resolutions: list[tuple[int]]): 78 | super().__init__() 79 | 80 | self.discriminators = nn.ModuleList( 81 | [ 82 | DiscriminatorR( 83 | n_fft=n_fft, hop_length=hop_length, win_length=win_length 84 | ) 85 | for n_fft, hop_length, win_length in resolutions 86 | ] 87 | ) 88 | 89 | def forward( 90 | self, x: torch.Tensor 91 | ) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]: 92 | scores, feature_map = [], [] 93 | 94 | for disc in self.discriminators: 95 | res, fmap = disc(x) 96 | 97 | scores.append(res) 98 | feature_map.append(fmap) 99 | 100 | scores = torch.cat(scores, dim=1) 101 | 102 | return scores, feature_map 103 | -------------------------------------------------------------------------------- /fish_vocoder/modules/encoders/convnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch 
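# The two discriminator families above look at the signal differently. DiscriminatorP
# (mpd.py) folds the waveform into a 2-D "period" grid before its Conv2d stack:
# (B, 1, T) is right-padded until T is divisible by the period p, then viewed as
# (B, 1, T // p, p). DiscriminatorR (mrd.py) instead works on a magnitude spectrogram:
# reflect padding, a center=False STFT, then the modulus. Minimal sketches of both
# front-ends:
import torch
import torch.nn.functional as F

def fold_by_period(x: torch.Tensor, period: int) -> torch.Tensor:
    b, c, t = x.shape
    if t % period != 0:
        x = F.pad(x, (0, period - t % period), "constant")
        t = x.shape[-1]
    return x.view(b, c, t // period, period)   # e.g. (2, 1, 44100) with p=5 -> (2, 1, 8820, 5)

def magnitude_spectrogram(x: torch.Tensor, n_fft=1024, hop=120, win=600) -> torch.Tensor:
    # x: (B, 1, T) waveform, mirroring DiscriminatorR.spectrogram above
    x = F.pad(x, ((n_fft - hop) // 2, (n_fft - hop + 1) // 2), mode="reflect").squeeze(1)
    spec = torch.stft(x, n_fft, hop_length=hop, win_length=win,
                      center=False, return_complex=True)
    return spec.abs()   # equal to torch.norm(torch.view_as_real(spec), p=2, dim=-1)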
import nn 4 | 5 | 6 | # DropPath copied from timm library 7 | def drop_path( 8 | x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True 9 | ): 10 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 11 | 12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 16 | 'survival rate' as the argument. 17 | 18 | """ # noqa: E501 19 | 20 | if drop_prob == 0.0 or not training: 21 | return x 22 | keep_prob = 1 - drop_prob 23 | shape = (x.shape[0],) + (1,) * ( 24 | x.ndim - 1 25 | ) # work with diff dim tensors, not just 2D ConvNets 26 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 27 | if keep_prob > 0.0 and scale_by_keep: 28 | random_tensor.div_(keep_prob) 29 | return x * random_tensor 30 | 31 | 32 | class DropPath(nn.Module): 33 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501 34 | 35 | def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): 36 | super(DropPath, self).__init__() 37 | self.drop_prob = drop_prob 38 | self.scale_by_keep = scale_by_keep 39 | 40 | def forward(self, x): 41 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 42 | 43 | def extra_repr(self): 44 | return f"drop_prob={round(self.drop_prob,3):0.3f}" 45 | 46 | 47 | class LayerNorm(nn.Module): 48 | r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. 49 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 50 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 51 | with shape (batch_size, channels, height, width). 52 | """ # noqa: E501 53 | 54 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 55 | super().__init__() 56 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 57 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 58 | self.eps = eps 59 | self.data_format = data_format 60 | if self.data_format not in ["channels_last", "channels_first"]: 61 | raise NotImplementedError 62 | self.normalized_shape = (normalized_shape,) 63 | 64 | def forward(self, x): 65 | if self.data_format == "channels_last": 66 | return F.layer_norm( 67 | x, self.normalized_shape, self.weight, self.bias, self.eps 68 | ) 69 | elif self.data_format == "channels_first": 70 | u = x.mean(1, keepdim=True) 71 | s = (x - u).pow(2).mean(1, keepdim=True) 72 | x = (x - u) / torch.sqrt(s + self.eps) 73 | x = self.weight[:, None] * x + self.bias[:, None] 74 | return x 75 | 76 | 77 | # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py 78 | class ConvNeXtBlock(nn.Module): 79 | r"""ConvNeXt Block. There are two equivalent implementations: 80 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 81 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 82 | We use (2) as we find it slightly faster in PyTorch 83 | 84 | Args: 85 | dim (int): Number of input channels. 86 | drop_path (float): Stochastic depth rate. 
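# DropPath (stochastic depth) above: during training, each sample keeps its residual
# branch with probability keep_prob and is rescaled by 1 / keep_prob so the expected
# value is unchanged; at eval time it is the identity. Minimal numeric sketch:
import torch

x = torch.ones(4, 8, 16)                            # (N, C, L)
keep_prob = 0.8                                     # i.e. drop_prob = 0.2
mask = x.new_empty((4, 1, 1)).bernoulli_(keep_prob)
out = x * mask / keep_prob                          # dropped samples become zero, kept ones scale up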
Default: 0.0 87 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 88 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. 89 | kernel_size (int): Kernel size for depthwise conv. Default: 7. 90 | dilation (int): Dilation for depthwise conv. Default: 1. 91 | """ # noqa: E501 92 | 93 | def __init__( 94 | self, 95 | dim: int, 96 | drop_path: float = 0.0, 97 | layer_scale_init_value: float = 1e-6, 98 | mlp_ratio: float = 4.0, 99 | kernel_size: int = 7, 100 | dilation: int = 1, 101 | ): 102 | super().__init__() 103 | 104 | self.dwconv = nn.Conv1d( 105 | dim, 106 | dim, 107 | kernel_size=kernel_size, 108 | padding=int(dilation * (kernel_size - 1) / 2), 109 | groups=dim, 110 | ) # depthwise conv 111 | self.norm = LayerNorm(dim, eps=1e-6) 112 | self.pwconv1 = nn.Linear( 113 | dim, int(mlp_ratio * dim) 114 | ) # pointwise/1x1 convs, implemented with linear layers 115 | self.act = nn.GELU() 116 | self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) 117 | self.gamma = ( 118 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 119 | if layer_scale_init_value > 0 120 | else None 121 | ) 122 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 123 | 124 | def forward(self, x, apply_residual: bool = True): 125 | input = x 126 | 127 | x = self.dwconv(x) 128 | x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) 129 | x = self.norm(x) 130 | x = self.pwconv1(x) 131 | x = self.act(x) 132 | x = self.pwconv2(x) 133 | 134 | if self.gamma is not None: 135 | x = self.gamma * x 136 | 137 | x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) 138 | x = self.drop_path(x) 139 | 140 | if apply_residual: 141 | x = input + x 142 | 143 | return x 144 | 145 | 146 | class ConvNeXtEncoder(nn.Module): 147 | def __init__( 148 | self, 149 | input_channels: int = 3, 150 | depths: list[int] = [3, 3, 9, 3], 151 | dims: list[int] = [96, 192, 384, 768], 152 | drop_path_rate: float = 0.0, 153 | layer_scale_init_value: float = 1e-6, 154 | kernel_size: int = 7, 155 | ): 156 | super().__init__() 157 | assert len(depths) == len(dims) 158 | 159 | self.downsample_layers = nn.ModuleList() 160 | stem = nn.Sequential( 161 | nn.Conv1d( 162 | input_channels, 163 | dims[0], 164 | kernel_size=kernel_size, 165 | padding=kernel_size // 2, 166 | padding_mode="zeros", 167 | ), 168 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), 169 | ) 170 | self.downsample_layers.append(stem) 171 | 172 | for i in range(len(depths) - 1): 173 | mid_layer = nn.Sequential( 174 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 175 | nn.Conv1d(dims[i], dims[i + 1], kernel_size=1), 176 | ) 177 | self.downsample_layers.append(mid_layer) 178 | 179 | self.stages = nn.ModuleList() 180 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 181 | 182 | cur = 0 183 | for i in range(len(depths)): 184 | stage = nn.Sequential( 185 | *[ 186 | ConvNeXtBlock( 187 | dim=dims[i], 188 | drop_path=dp_rates[cur + j], 189 | layer_scale_init_value=layer_scale_init_value, 190 | kernel_size=kernel_size, 191 | ) 192 | for j in range(depths[i]) 193 | ] 194 | ) 195 | self.stages.append(stage) 196 | cur += depths[i] 197 | 198 | self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") 199 | self.apply(self._init_weights) 200 | 201 | def _init_weights(self, m): 202 | if isinstance(m, (nn.Conv1d, nn.Linear)): 203 | nn.init.trunc_normal_(m.weight, std=0.02) 204 | nn.init.constant_(m.bias, 0) 205 | 206 | def forward( 207 | self, 208 | x: 
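# ConvNeXtEncoder above keeps the time resolution: every "downsample" layer is a
# stride-1 conv that only changes the channel width, so the output length equals the
# input length and the channel count equals dims[-1]. Minimal shape sketch (the
# 128-mel configuration here is an assumption, not taken from this file):
import torch

encoder = ConvNeXtEncoder(input_channels=128, depths=[3, 3, 9, 3], dims=[128, 256, 384, 512])
mel = torch.randn(2, 128, 200)    # (B, n_mels, frames)
features = encoder(mel)           # -> torch.Size([2, 512, 200])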
torch.Tensor, 209 | ) -> torch.Tensor: 210 | for i in range(len(self.downsample_layers)): 211 | x = self.downsample_layers[i](x) 212 | x = self.stages[i](x) 213 | 214 | return self.norm(x) 215 | -------------------------------------------------------------------------------- /fish_vocoder/modules/encoders/hubert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import HubertModel 4 | 5 | 6 | class HubertEncoder(nn.Module): 7 | def __init__( 8 | self, 9 | model_name: str = "facebook/hubert-large-ll60k", 10 | freeze_backbone: bool = True, 11 | output_size: int = 1024, 12 | ): 13 | super().__init__() 14 | 15 | self.model = HubertModel.from_pretrained(model_name) 16 | self.freeze_backbone = freeze_backbone 17 | 18 | if self.freeze_backbone: 19 | for param in self.model.parameters(): 20 | param.requires_grad = False 21 | 22 | self.post = nn.Sequential( 23 | nn.Conv1d( 24 | self.model.config.hidden_size, output_size, kernel_size=3, padding=1 25 | ), 26 | nn.SiLU(), 27 | nn.Conv1d(output_size, output_size, stride=2, kernel_size=3, padding=1), 28 | nn.SiLU(), 29 | nn.Conv1d(output_size, output_size, kernel_size=1), 30 | ) 31 | 32 | def forward( 33 | self, 34 | x: torch.Tensor, 35 | mask: torch.Tensor = None, 36 | ) -> torch.Tensor: 37 | if x.ndim == 3: 38 | assert x.shape[1] == 1 and mask.shape[1] == 1 39 | x = x.squeeze(1) 40 | mask = mask.squeeze(1) 41 | 42 | if self.freeze_backbone: 43 | with torch.no_grad(): 44 | x = self.model(x, attention_mask=mask) 45 | else: 46 | x = self.model(x, attention_mask=mask) 47 | 48 | x = x.last_hidden_state.transpose(1, 2) 49 | x = self.post(x) 50 | 51 | return x 52 | 53 | 54 | if __name__ == "__main__": 55 | model = HubertEncoder() 56 | x = torch.randn(1, 16000) 57 | y = model(x) 58 | print(y.shape) 59 | -------------------------------------------------------------------------------- /fish_vocoder/modules/encoders/mms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from torch import nn 4 | from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model 5 | 6 | 7 | class MMSEncoder(nn.Module): 8 | def __init__(self, sampling_rate: int = 44100, hop_length: int = 512) -> None: 9 | super().__init__() 10 | 11 | self.sample_rate = sampling_rate 12 | self.hop_length = hop_length 13 | self.processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/mms-300m") 14 | self.model = Wav2Vec2Model.from_pretrained("facebook/mms-300m") 15 | 16 | @torch.no_grad() 17 | def forward(self, x): 18 | num_frames = x.shape[-1] // self.hop_length 19 | 20 | x = torchaudio.functional.resample( 21 | x, orig_freq=self.sample_rate, new_freq=16000 22 | ) 23 | x = [i.cpu().numpy() for i in x] 24 | input_values = self.processor( 25 | x, return_tensors="pt", padding=True, sampling_rate=16000 26 | ).input_values 27 | input_values = input_values.to(self.model.device) 28 | 29 | x = self.model(input_values).last_hidden_state 30 | x = x.transpose(1, 2) 31 | x = torch.functional.F.interpolate(x, size=num_frames, mode="nearest") 32 | 33 | return x 34 | -------------------------------------------------------------------------------- /fish_vocoder/modules/encoders/posterior_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrowed from RVC https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 
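# MMSEncoder above keeps its features aligned to the vocoder frame rate: audio is
# resampled to the 16 kHz rate expected by MMS/wav2vec2, encoded, and the hidden
# states are nearest-interpolated back to num_frames = T // hop_length frames.
# Minimal sketch of that bookkeeping with a stand-in feature tensor (no model call):
import torch
import torch.nn.functional as F

sampling_rate, hop_length = 44100, 512
wav = torch.randn(1, 1, sampling_rate * 2)       # 2 s of audio
num_frames = wav.shape[-1] // hop_length         # 172 vocoder frames
hidden = torch.randn(1, 1024, 99)                # pretend wav2vec2 output, (B, C, T')
aligned = F.interpolate(hidden, size=num_frames, mode="nearest")   # (1, 1024, 172)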
from fish_vocoder.utils.mask import sequence_mask 9 | 10 | 11 | class WaveNet(nn.Module): 12 | def __init__( 13 | self, 14 | hidden_channels, 15 | kernel_size, 16 | dilation_rate, 17 | dilation_cycle, 18 | n_layers, 19 | p_dropout=0, 20 | ): 21 | super().__init__() 22 | 23 | assert kernel_size % 2 == 1 24 | self.hidden_channels = hidden_channels 25 | self.kernel_size = (kernel_size,) 26 | self.dilation_rate = dilation_rate 27 | self.n_layers = n_layers 28 | self.p_dropout = p_dropout 29 | 30 | self.in_layers = nn.ModuleList() 31 | self.res_skip_layers = nn.ModuleList() 32 | self.drop = nn.Dropout(p_dropout) 33 | 34 | for i in range(n_layers): 35 | dilation = dilation_rate ** (i % dilation_cycle) 36 | padding = int((kernel_size * dilation - dilation) / 2) 37 | in_layer = nn.Conv1d( 38 | hidden_channels, 39 | 2 * hidden_channels, 40 | kernel_size, 41 | dilation=dilation, 42 | padding=padding, 43 | ) 44 | in_layer = nn.utils.weight_norm(in_layer, name="weight") 45 | self.in_layers.append(in_layer) 46 | 47 | # last one is not necessary 48 | if i < n_layers - 1: 49 | res_skip_channels = 2 * hidden_channels 50 | else: 51 | res_skip_channels = hidden_channels 52 | 53 | res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1) 54 | res_skip_layer = nn.utils.weight_norm(res_skip_layer, name="weight") 55 | self.res_skip_layers.append(res_skip_layer) 56 | 57 | def forward(self, x, x_mask): 58 | output = torch.zeros_like(x) 59 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 60 | 61 | for i in range(self.n_layers): 62 | x_in = self.in_layers[i](x) 63 | 64 | # Gate 65 | in_act = x_in 66 | t_act = torch.tanh(in_act[:, : n_channels_tensor[0], :]) 67 | s_act = torch.sigmoid(in_act[:, n_channels_tensor[0] :, :]) 68 | acts = t_act * s_act 69 | 70 | acts = self.drop(acts) 71 | 72 | res_skip_acts = self.res_skip_layers[i](acts) 73 | if i < self.n_layers - 1: 74 | res_acts = res_skip_acts[:, : self.hidden_channels, :] 75 | x = (x + res_acts) * x_mask 76 | output = output + res_skip_acts[:, self.hidden_channels :, :] 77 | else: 78 | output = output + res_skip_acts 79 | 80 | return output * x_mask 81 | 82 | def remove_weight_norm(self): 83 | if self.gin_channels != 0: 84 | nn.utils.remove_weight_norm(self.cond_layer) 85 | for layer in self.in_layers: 86 | nn.utils.remove_weight_norm(layer) 87 | for layer in self.res_skip_layers: 88 | nn.utils.remove_weight_norm(layer) 89 | 90 | 91 | class PosteriorEncoder(nn.Module): 92 | def __init__( 93 | self, 94 | in_channels, 95 | out_channels, 96 | hidden_channels, 97 | kernel_size=5, 98 | dilation_rate=1, 99 | dilation_cycle=1, 100 | n_layers=16, 101 | mode="vqvae", 102 | ): 103 | super().__init__() 104 | 105 | assert mode in ["vae", "bnvae", "vqvae"] 106 | 107 | self.mode = mode 108 | self.in_channels = in_channels 109 | self.out_channels = out_channels 110 | self.hidden_channels = hidden_channels 111 | self.kernel_size = kernel_size 112 | self.dilation_rate = dilation_rate 113 | self.dilation_cycle = dilation_cycle 114 | self.n_layers = n_layers 115 | 116 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 117 | self.enc = WaveNet( 118 | hidden_channels, 119 | kernel_size=kernel_size, 120 | dilation_rate=dilation_rate, 121 | dilation_cycle=dilation_cycle, 122 | n_layers=n_layers, 123 | ) 124 | self.proj = nn.Conv1d( 125 | hidden_channels, out_channels * 2 if mode != "vqvae" else out_channels, 1 126 | ) 127 | 128 | if mode == "bnvae": 129 | self.mu_bn = nn.BatchNorm1d(out_channels, affine=True) 130 | self.mu_bn.weight.requires_grad = False 131 | 
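# Gated activation unit inside WaveNet.forward above: each in_layer projects to
# 2 * hidden_channels, and a tanh branch is gated by a sigmoid branch, as in the
# original WaveNet. Minimal sketch:
import torch

hidden_channels = 192
x_in = torch.randn(2, 2 * hidden_channels, 100)   # output of one in_layer conv
t_act = torch.tanh(x_in[:, :hidden_channels, :])
s_act = torch.sigmoid(x_in[:, hidden_channels:, :])
acts = t_act * s_act                              # (2, 192, 100), fed to the res/skip conv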
self.mu_bn.weight.fill_(0.5) 132 | 133 | def forward(self, x, x_lengths=None): 134 | if x_lengths is not None: 135 | x_mask = sequence_mask(x_lengths, x.size(2))[:, None].to( 136 | device=x.device, dtype=x.dtype 137 | ) 138 | else: 139 | x_mask = torch.ones(x.size(0), 1, x.size(2), dtype=x.dtype, device=x.device) 140 | 141 | x = self.pre(x) * x_mask 142 | x = self.enc(x, x_mask) 143 | x = self.proj(x) * x_mask 144 | 145 | if self.mode in ["bnvae", "vae"]: 146 | mean, logvar = torch.split(x, self.out_channels, dim=1) 147 | logvar = torch.clamp(logvar, min=-30, max=20) 148 | 149 | if self.mode == "bnvae": 150 | mean = self.mu_bn(mean) 151 | 152 | if self.training: 153 | z = (mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)) * x_mask 154 | else: 155 | z = mean * x_mask 156 | 157 | return z, mean, logvar, x_mask 158 | 159 | elif self.mode == "vqvae": 160 | return x 161 | 162 | def remove_weight_norm(self): 163 | self.enc.remove_weight_norm() 164 | -------------------------------------------------------------------------------- /fish_vocoder/modules/generators/bigvgan.py: -------------------------------------------------------------------------------- 1 | # BigVGAN Adapted from https://github.com/NVIDIA/BigVGAN under the MIT license. 2 | # SNAKE Adapted from https://github.com/EdwardDixon/snake under the MIT license. 3 | 4 | from math import prod 5 | from typing import Callable 6 | 7 | import numpy as np 8 | import torch 9 | from alias_free_torch import Activation1d 10 | from torch import nn, pow, sin 11 | from torch.nn import Conv1d, Parameter 12 | from torch.nn.utils.parametrizations import weight_norm 13 | from torch.nn.utils.parametrize import remove_parametrizations 14 | 15 | from .hifigan import get_padding, init_weights 16 | 17 | 18 | class Snake(nn.Module): 19 | """ 20 | Implementation of a sine-based periodic activation function 21 | Shape: 22 | - Input: (B, C, T) 23 | - Output: (B, C, T), same shape as the input 24 | Parameters: 25 | - alpha - trainable parameter 26 | References: 27 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 28 | https://arxiv.org/abs/2006.08195 29 | Examples: 30 | >>> a1 = snake(256) 31 | >>> x = torch.randn(256) 32 | >>> x = a1(x) 33 | """ # noqa: E501 34 | 35 | def __init__( 36 | self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False 37 | ): 38 | """ 39 | Initialization. 40 | INPUT: 41 | - in_features: shape of the input 42 | - alpha: trainable parameter 43 | alpha is initialized to 1 by default, higher values = higher-frequency. 44 | alpha will be trained along with the rest of your model. 45 | """ 46 | super(Snake, self).__init__() 47 | self.in_features = in_features 48 | 49 | # initialize alpha 50 | self.alpha_logscale = alpha_logscale 51 | if self.alpha_logscale: # log scale alphas initialized to zeros 52 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 53 | else: # linear scale alphas initialized to ones 54 | self.alpha = Parameter(torch.ones(in_features) * alpha) 55 | 56 | self.alpha.requires_grad = alpha_trainable 57 | 58 | self.no_div_by_zero = 0.000000001 59 | 60 | def forward(self, x): 61 | """ 62 | Forward pass of the function. 63 | Applies the function to the input elementwise. 
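# PosteriorEncoder.forward above derives its frame mask from utterance lengths via
# sequence_mask (fish_vocoder.utils.mask, not shown here). A minimal equivalent,
# assuming it returns a boolean (B, max_length) mask:
import torch

def sequence_mask(lengths: torch.Tensor, max_length: int | None = None) -> torch.Tensor:
    if max_length is None:
        max_length = int(lengths.max())
    return torch.arange(max_length, device=lengths.device)[None, :] < lengths[:, None]

# sequence_mask(torch.tensor([3, 5])) ->
# [[True, True, True, False, False],
#  [True, True, True, True,  True ]]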
64 | Snake ∶= x + 1/a * sin^2 (xa) 65 | """ 66 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 67 | if self.alpha_logscale: 68 | alpha = torch.exp(alpha) 69 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 70 | 71 | return x 72 | 73 | 74 | class SnakeBeta(nn.Module): 75 | """ 76 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 77 | Shape: 78 | - Input: (B, C, T) 79 | - Output: (B, C, T), same shape as the input 80 | Parameters: 81 | - alpha - trainable parameter that controls frequency 82 | - beta - trainable parameter that controls magnitude 83 | References: 84 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 85 | https://arxiv.org/abs/2006.08195 86 | Examples: 87 | >>> a1 = snakebeta(256) 88 | >>> x = torch.randn(256) 89 | >>> x = a1(x) 90 | """ # noqa: E501 91 | 92 | def __init__( 93 | self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False 94 | ): 95 | """ 96 | Initialization. 97 | INPUT: 98 | - in_features: shape of the input 99 | - alpha - trainable parameter that controls frequency 100 | - beta - trainable parameter that controls magnitude 101 | alpha is initialized to 1 by default, higher values = higher-frequency. 102 | beta is initialized to 1 by default, higher values = higher-magnitude. 103 | alpha will be trained along with the rest of your model. 104 | """ 105 | super(SnakeBeta, self).__init__() 106 | self.in_features = in_features 107 | 108 | # initialize alpha 109 | self.alpha_logscale = alpha_logscale 110 | if self.alpha_logscale: # log scale alphas initialized to zeros 111 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 112 | self.beta = Parameter(torch.zeros(in_features) * alpha) 113 | else: # linear scale alphas initialized to ones 114 | self.alpha = Parameter(torch.ones(in_features) * alpha) 115 | self.beta = Parameter(torch.ones(in_features) * alpha) 116 | 117 | self.alpha.requires_grad = alpha_trainable 118 | self.beta.requires_grad = alpha_trainable 119 | 120 | self.no_div_by_zero = 0.000000001 121 | 122 | def forward(self, x): 123 | """ 124 | Forward pass of the function. 125 | Applies the function to the input elementwise. 
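# Snake above implements x + (1/alpha) * sin^2(alpha * x): a periodic activation with
# a trainable per-channel frequency, which helps the generator model harmonic
# structure. Minimal sketch mirroring the forward pass:
import torch

alpha = torch.ones(64)                                   # one alpha per channel
x = torch.randn(2, 64, 100)                              # (B, C, T)
a = alpha[None, :, None]
y = x + (1.0 / (a + 1e-9)) * torch.sin(x * a).pow(2)     # same shape as x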
126 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 127 | """ 128 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 129 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 130 | if self.alpha_logscale: 131 | alpha = torch.exp(alpha) 132 | beta = torch.exp(beta) 133 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 134 | 135 | return x 136 | 137 | 138 | class AMPBlock(torch.nn.Module): 139 | def __init__( 140 | self, 141 | channels, 142 | kernel_size=3, 143 | dilation=(1, 3, 5), 144 | activation=SnakeBeta, 145 | snake_logscale=True, 146 | ): 147 | super().__init__() 148 | 149 | self.convs1 = nn.ModuleList( 150 | [ 151 | weight_norm( 152 | Conv1d( 153 | channels, 154 | channels, 155 | kernel_size, 156 | 1, 157 | dilation=dilation[0], 158 | padding=get_padding(kernel_size, dilation[0]), 159 | ) 160 | ), 161 | weight_norm( 162 | Conv1d( 163 | channels, 164 | channels, 165 | kernel_size, 166 | 1, 167 | dilation=dilation[1], 168 | padding=get_padding(kernel_size, dilation[1]), 169 | ) 170 | ), 171 | weight_norm( 172 | Conv1d( 173 | channels, 174 | channels, 175 | kernel_size, 176 | 1, 177 | dilation=dilation[2], 178 | padding=get_padding(kernel_size, dilation[2]), 179 | ) 180 | ), 181 | ] 182 | ) 183 | self.convs1.apply(init_weights) 184 | 185 | self.convs2 = nn.ModuleList( 186 | [ 187 | weight_norm( 188 | Conv1d( 189 | channels, 190 | channels, 191 | kernel_size, 192 | 1, 193 | dilation=1, 194 | padding=get_padding(kernel_size, 1), 195 | ) 196 | ), 197 | weight_norm( 198 | Conv1d( 199 | channels, 200 | channels, 201 | kernel_size, 202 | 1, 203 | dilation=1, 204 | padding=get_padding(kernel_size, 1), 205 | ) 206 | ), 207 | weight_norm( 208 | Conv1d( 209 | channels, 210 | channels, 211 | kernel_size, 212 | 1, 213 | dilation=1, 214 | padding=get_padding(kernel_size, 1), 215 | ) 216 | ), 217 | ] 218 | ) 219 | self.convs2.apply(init_weights) 220 | 221 | self.num_layers = len(self.convs1) + len( 222 | self.convs2 223 | ) # total number of conv layers 224 | 225 | # BigVGAN 226 | self.activations = nn.ModuleList( 227 | [ 228 | Activation1d( 229 | activation=activation(channels, alpha_logscale=snake_logscale) 230 | ) 231 | for _ in range(self.num_layers) 232 | ] 233 | ) 234 | 235 | def forward(self, x): 236 | acts1, acts2 = self.activations[::2], self.activations[1::2] 237 | 238 | for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): 239 | xt = a1(x) 240 | xt = c1(xt) 241 | xt = a2(xt) 242 | xt = c2(xt) 243 | x = xt + x 244 | 245 | return x 246 | 247 | def remove_parametrizations(self): 248 | for conv in self.convs1: 249 | remove_parametrizations(conv) 250 | 251 | for conv in self.convs2: 252 | remove_parametrizations(conv) 253 | 254 | 255 | class BigVGANGenerator(torch.nn.Module): 256 | def __init__( 257 | self, 258 | *, 259 | hop_length: int = 512, 260 | upsample_rates: tuple[int] = (8, 8, 2, 2, 2), 261 | upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2), 262 | resblock_kernel_sizes: tuple[int] = (3, 7, 11), 263 | resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), 264 | num_mels: int = 128, 265 | upsample_initial_channel: int = 512, 266 | activation: Callable = SnakeBeta, 267 | use_template: bool = True, 268 | pre_conv_kernel_size: int = 7, 269 | post_conv_kernel_size: int = 7, 270 | ): 271 | super().__init__() 272 | 273 | assert ( 274 | prod(upsample_rates) == hop_length 275 | ), f"hop_length must be {prod(upsample_rates)}" 276 | 277 | self.conv_pre = weight_norm( 278 | nn.Conv1d( 279 | num_mels, 280 | 
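# The assert above ties the upsampling schedule to the STFT hop: the generator has to
# turn one mel frame into hop_length samples, so the stage-wise rates must multiply
# to hop_length. Sketch for the default schedule:
from math import prod

upsample_rates = (8, 8, 2, 2, 2)
hop_length = 512
assert prod(upsample_rates) == hop_length   # 8 * 8 * 2 * 2 * 2 == 512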
upsample_initial_channel, 281 | pre_conv_kernel_size, 282 | 1, 283 | padding=get_padding(pre_conv_kernel_size), 284 | ) 285 | ) 286 | 287 | self.num_upsamples = len(upsample_rates) 288 | self.num_kernels = len(resblock_kernel_sizes) 289 | 290 | # transposed conv-based upsamplers. does not apply anti-aliasing 291 | self.noise_convs = nn.ModuleList() 292 | self.use_template = use_template 293 | self.ups = nn.ModuleList() 294 | 295 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 296 | c_cur = upsample_initial_channel // (2 ** (i + 1)) 297 | self.ups.append( 298 | weight_norm( 299 | nn.ConvTranspose1d( 300 | upsample_initial_channel // (2**i), 301 | upsample_initial_channel // (2 ** (i + 1)), 302 | k, 303 | u, 304 | padding=(k - u) // 2, 305 | ) 306 | ) 307 | ) 308 | 309 | if not use_template: 310 | continue 311 | 312 | if i + 1 < len(upsample_rates): 313 | stride_f0 = np.prod(upsample_rates[i + 1 :]) 314 | self.noise_convs.append( 315 | Conv1d( 316 | 1, 317 | c_cur, 318 | kernel_size=stride_f0 * 2, 319 | stride=stride_f0, 320 | padding=stride_f0 // 2, 321 | ) 322 | ) 323 | else: 324 | self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) 325 | 326 | # residual blocks using anti-aliased multi-periodicity composition modules (AMP) 327 | self.resblocks = nn.ModuleList() 328 | for i in range(len(self.ups)): 329 | ch = upsample_initial_channel // (2 ** (i + 1)) 330 | 331 | for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes): 332 | self.resblocks.append(AMPBlock(ch, k, d)) 333 | 334 | # post conv 335 | self.activation_post = Activation1d( 336 | activation=activation(ch, alpha_logscale=True) 337 | ) 338 | self.conv_post = weight_norm( 339 | nn.Conv1d( 340 | ch, 341 | 1, 342 | post_conv_kernel_size, 343 | 1, 344 | padding=get_padding(post_conv_kernel_size), 345 | ) 346 | ) 347 | 348 | # weight initialization 349 | self.ups.apply(init_weights) 350 | self.conv_post.apply(init_weights) 351 | 352 | def forward(self, x, template=None): 353 | x = self.conv_pre(x) 354 | 355 | for i in range(self.num_upsamples): 356 | x = self.ups[i](x) 357 | 358 | if self.use_template: 359 | x = x + self.noise_convs[i](template) 360 | 361 | xs = [] 362 | for j in range(self.num_kernels): 363 | xs.append(self.resblocks[i * self.num_kernels + j](x)) 364 | 365 | x = torch.stack(xs, dim=0).mean(dim=0) 366 | 367 | x = self.activation_post(x) 368 | x = self.conv_post(x) 369 | x = torch.tanh(x) 370 | 371 | return x 372 | 373 | def remove_parametrizations(self): 374 | for up in self.ups: 375 | remove_parametrizations(up) 376 | for block in self.resblocks: 377 | block.remove_parametrizations() 378 | remove_parametrizations(self.conv_pre) 379 | remove_parametrizations(self.conv_post) 380 | -------------------------------------------------------------------------------- /fish_vocoder/modules/generators/hifigan.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from math import prod 3 | from typing import Callable 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn import Conv1d 10 | from torch.nn.utils.parametrizations import weight_norm 11 | from torch.nn.utils.parametrize import remove_parametrizations 12 | from torch.utils.checkpoint import checkpoint 13 | 14 | 15 | def init_weights(m, mean=0.0, std=0.01): 16 | classname = m.__class__.__name__ 17 | if classname.find("Conv") != -1: 18 | m.weight.data.normal_(mean, std) 19 | 20 | 21 | def 
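# use_template in BigVGANGenerator above mixes a "template" waveform (e.g. a source /
# excitation signal at the output sample rate) into every upsampling stage:
# noise_convs[i] strides the template down by the product of the remaining upsample
# rates so it matches that stage's temporal resolution. Stride bookkeeping sketch:
import numpy as np

upsample_rates = (8, 8, 2, 2, 2)
strides = [
    int(np.prod(upsample_rates[i + 1:])) if i + 1 < len(upsample_rates) else 1
    for i in range(len(upsample_rates))
]
# strides -> [64, 8, 4, 2, 1]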
get_padding(kernel_size, dilation=1): 22 | return (kernel_size * dilation - dilation) // 2 23 | 24 | 25 | class ResBlock1(torch.nn.Module): 26 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 27 | super().__init__() 28 | 29 | self.convs1 = nn.ModuleList( 30 | [ 31 | weight_norm( 32 | Conv1d( 33 | channels, 34 | channels, 35 | kernel_size, 36 | 1, 37 | dilation=dilation[0], 38 | padding=get_padding(kernel_size, dilation[0]), 39 | ) 40 | ), 41 | weight_norm( 42 | Conv1d( 43 | channels, 44 | channels, 45 | kernel_size, 46 | 1, 47 | dilation=dilation[1], 48 | padding=get_padding(kernel_size, dilation[1]), 49 | ) 50 | ), 51 | weight_norm( 52 | Conv1d( 53 | channels, 54 | channels, 55 | kernel_size, 56 | 1, 57 | dilation=dilation[2], 58 | padding=get_padding(kernel_size, dilation[2]), 59 | ) 60 | ), 61 | ] 62 | ) 63 | self.convs1.apply(init_weights) 64 | 65 | self.convs2 = nn.ModuleList( 66 | [ 67 | weight_norm( 68 | Conv1d( 69 | channels, 70 | channels, 71 | kernel_size, 72 | 1, 73 | dilation=1, 74 | padding=get_padding(kernel_size, 1), 75 | ) 76 | ), 77 | weight_norm( 78 | Conv1d( 79 | channels, 80 | channels, 81 | kernel_size, 82 | 1, 83 | dilation=1, 84 | padding=get_padding(kernel_size, 1), 85 | ) 86 | ), 87 | weight_norm( 88 | Conv1d( 89 | channels, 90 | channels, 91 | kernel_size, 92 | 1, 93 | dilation=1, 94 | padding=get_padding(kernel_size, 1), 95 | ) 96 | ), 97 | ] 98 | ) 99 | self.convs2.apply(init_weights) 100 | 101 | def forward(self, x): 102 | for c1, c2 in zip(self.convs1, self.convs2): 103 | xt = F.silu(x) 104 | xt = c1(xt) 105 | xt = F.silu(xt) 106 | xt = c2(xt) 107 | x = xt + x 108 | return x 109 | 110 | def remove_parametrizations(self): 111 | for conv in self.convs1: 112 | remove_parametrizations(conv) 113 | for conv in self.convs2: 114 | remove_parametrizations(conv) 115 | 116 | 117 | class ParralelBlock(nn.Module): 118 | def __init__( 119 | self, 120 | channels: int, 121 | kernel_sizes: tuple[int] = (3, 7, 11), 122 | dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), 123 | ): 124 | super().__init__() 125 | 126 | assert len(kernel_sizes) == len(dilation_sizes) 127 | 128 | self.blocks = nn.ModuleList() 129 | for k, d in zip(kernel_sizes, dilation_sizes): 130 | self.blocks.append(ResBlock1(channels, k, d)) 131 | 132 | def forward(self, x): 133 | return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0) 134 | 135 | 136 | class HiFiGANGenerator(nn.Module): 137 | def __init__( 138 | self, 139 | *, 140 | hop_length: int = 512, 141 | upsample_rates: tuple[int] = (8, 8, 2, 2, 2), 142 | upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2), 143 | resblock_kernel_sizes: tuple[int] = (3, 7, 11), 144 | resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), 145 | num_mels: int = 128, 146 | upsample_initial_channel: int = 512, 147 | use_template: bool = True, 148 | pre_conv_kernel_size: int = 7, 149 | post_conv_kernel_size: int = 7, 150 | post_activation: Callable = partial(nn.SiLU, inplace=True), 151 | ): 152 | super().__init__() 153 | 154 | assert ( 155 | prod(upsample_rates) == hop_length 156 | ), f"hop_length must be {prod(upsample_rates)}" 157 | 158 | self.conv_pre = weight_norm( 159 | nn.Conv1d( 160 | num_mels, 161 | upsample_initial_channel, 162 | pre_conv_kernel_size, 163 | 1, 164 | padding=get_padding(pre_conv_kernel_size), 165 | ) 166 | ) 167 | 168 | self.num_upsamples = len(upsample_rates) 169 | self.num_kernels = len(resblock_kernel_sizes) 170 | 171 | self.noise_convs = nn.ModuleList() 172 | 
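# get_padding above gives "same"-length padding for a stride-1 dilated conv with an
# odd kernel: pad = (kernel_size * dilation - dilation) // 2 keeps the time axis
# unchanged, which is what allows ResBlock1 to add its input back as a residual.
# Sketch:
import torch
from torch.nn import Conv1d

k, d = 7, 3
conv = Conv1d(16, 16, k, 1, dilation=d, padding=(k * d - d) // 2)
x = torch.randn(1, 16, 100)
assert conv(x).shape == x.shape   # length preserved, so `x = xt + x` is well-formed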
self.use_template = use_template 173 | self.ups = nn.ModuleList() 174 | 175 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 176 | c_cur = upsample_initial_channel // (2 ** (i + 1)) 177 | self.ups.append( 178 | weight_norm( 179 | nn.ConvTranspose1d( 180 | upsample_initial_channel // (2**i), 181 | upsample_initial_channel // (2 ** (i + 1)), 182 | k, 183 | u, 184 | padding=(k - u) // 2, 185 | ) 186 | ) 187 | ) 188 | 189 | if not use_template: 190 | continue 191 | 192 | if i + 1 < len(upsample_rates): 193 | stride_f0 = np.prod(upsample_rates[i + 1 :]) 194 | self.noise_convs.append( 195 | Conv1d( 196 | 1, 197 | c_cur, 198 | kernel_size=stride_f0 * 2, 199 | stride=stride_f0, 200 | padding=stride_f0 // 2, 201 | ) 202 | ) 203 | else: 204 | self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) 205 | 206 | self.resblocks = nn.ModuleList() 207 | for i in range(len(self.ups)): 208 | ch = upsample_initial_channel // (2 ** (i + 1)) 209 | self.resblocks.append( 210 | ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes) 211 | ) 212 | 213 | self.activation_post = post_activation() 214 | self.conv_post = weight_norm( 215 | nn.Conv1d( 216 | ch, 217 | 1, 218 | post_conv_kernel_size, 219 | 1, 220 | padding=get_padding(post_conv_kernel_size), 221 | ) 222 | ) 223 | self.ups.apply(init_weights) 224 | self.conv_post.apply(init_weights) 225 | 226 | def forward(self, x, template=None): 227 | x = self.conv_pre(x) 228 | 229 | for i in range(self.num_upsamples): 230 | x = F.silu(x, inplace=True) 231 | x = self.ups[i](x) 232 | 233 | if self.use_template: 234 | x = x + self.noise_convs[i](template) 235 | 236 | if self.training and self.checkpointing: 237 | x = checkpoint( 238 | self.resblocks[i], 239 | x, 240 | use_reentrant=False, 241 | ) 242 | else: 243 | x = self.resblocks[i](x) 244 | 245 | x = self.activation_post(x) 246 | x = self.conv_post(x) 247 | x = torch.tanh(x) 248 | 249 | return x 250 | 251 | def remove_parametrizations(self): 252 | for up in self.ups: 253 | remove_parametrizations(up) 254 | for block in self.resblocks: 255 | block.remove_parametrizations() 256 | remove_parametrizations(self.conv_pre) 257 | remove_parametrizations(self.conv_post) 258 | -------------------------------------------------------------------------------- /fish_vocoder/modules/generators/refinegan.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.utils.parametrizations import weight_norm 8 | from torch.nn.utils.parametrize import remove_parametrizations 9 | 10 | 11 | def named_apply( 12 | fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False 13 | ) -> nn.Module: 14 | if not depth_first and include_root: 15 | fn(module=module, name=name) 16 | 17 | for child_name, child_module in module.named_children(): 18 | child_name = ".".join((name, child_name)) if name else child_name 19 | named_apply( 20 | fn=fn, 21 | module=child_module, 22 | name=child_name, 23 | depth_first=depth_first, 24 | include_root=True, 25 | ) 26 | 27 | if depth_first and include_root: 28 | fn(module=module, name=name) 29 | 30 | return module 31 | 32 | 33 | def get_padding(kernel_size: int, dilation: int = 1) -> int: 34 | return int((kernel_size * dilation - dilation) / 2) 35 | 36 | 37 | class ResBlock(torch.nn.Module): 38 | def __init__( 39 | self, 40 | *, 41 | in_channels: int, 42 | out_channels: int, 43 | 
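# Note on HiFiGANGenerator.forward above: it checks self.checkpointing to decide
# whether to recompute each ParralelBlock during backward (torch.utils.checkpoint)
# instead of caching activations, trading compute for memory. That attribute is not
# assigned in the __init__ shown in this excerpt, so it is presumably set elsewhere;
# a minimal fix would be a constructor flag, e.g.
#     def __init__(self, *, ..., checkpointing: bool = False):
#         ...
#         self.checkpointing = checkpointing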
kernel_size: int = 7, 44 | dilation: tuple[int] = (1, 3, 5), 45 | leaky_relu_slope: float = 0.2, 46 | ): 47 | super(ResBlock, self).__init__() 48 | 49 | self.leaky_relu_slope = leaky_relu_slope 50 | self.in_channels = in_channels 51 | self.out_channels = out_channels 52 | 53 | self.convs1 = nn.ModuleList( 54 | [ 55 | weight_norm( 56 | nn.Conv1d( 57 | in_channels=in_channels if idx == 0 else out_channels, 58 | out_channels=out_channels, 59 | kernel_size=kernel_size, 60 | stride=1, 61 | dilation=d, 62 | padding=get_padding(kernel_size, d), 63 | ) 64 | ) 65 | for idx, d in enumerate(dilation) 66 | ] 67 | ) 68 | self.convs1.apply(self.init_weights) 69 | 70 | self.convs2 = nn.ModuleList( 71 | [ 72 | weight_norm( 73 | nn.Conv1d( 74 | in_channels=out_channels, 75 | out_channels=out_channels, 76 | kernel_size=kernel_size, 77 | stride=1, 78 | dilation=d, 79 | padding=get_padding(kernel_size, d), 80 | ) 81 | ) 82 | for idx, d in enumerate(dilation) 83 | ] 84 | ) 85 | self.convs2.apply(self.init_weights) 86 | 87 | def forward(self, x): 88 | for idx, (c1, c2) in enumerate(zip(self.convs1, self.convs2)): 89 | xt = F.leaky_relu(x, self.leaky_relu_slope) 90 | xt = c1(xt) 91 | xt = F.leaky_relu(xt, self.leaky_relu_slope) 92 | xt = c2(xt) 93 | 94 | if idx != 0 or self.in_channels == self.out_channels: 95 | x = xt + x 96 | else: 97 | x = xt 98 | 99 | return x 100 | 101 | def remove_parametrizations(self): 102 | for c1, c2 in zip(self.convs1, self.convs2): 103 | remove_parametrizations(c1) 104 | remove_parametrizations(c2) 105 | 106 | def init_weights(self, m): 107 | if type(m) == nn.Conv1d: 108 | m.weight.data.normal_(0, 0.01) 109 | m.bias.data.fill_(0.0) 110 | 111 | 112 | class AdaIN(nn.Module): 113 | def __init__( 114 | self, 115 | *, 116 | channels: int, 117 | leaky_relu_slope: float = 0.2, 118 | ) -> None: 119 | super().__init__() 120 | 121 | self.weight = nn.Parameter(torch.ones(channels)) 122 | self.activation = nn.LeakyReLU(leaky_relu_slope) 123 | 124 | def forward(self, x: torch.Tensor) -> torch.Tensor: 125 | gaussian = torch.randn_like(x) * self.weight[None, :, None] 126 | 127 | return self.activation(x + gaussian) 128 | 129 | 130 | class ParallelResBlock(nn.Module): 131 | def __init__( 132 | self, 133 | *, 134 | in_channels: int, 135 | out_channels: int, 136 | kernel_sizes: int = (3, 7, 11), 137 | dilation: tuple[int] = (1, 3, 5), 138 | leaky_relu_slope: float = 0.2, 139 | ) -> None: 140 | super().__init__() 141 | 142 | self.in_channels = in_channels 143 | self.out_channels = out_channels 144 | 145 | self.input_conv = nn.Conv1d( 146 | in_channels=in_channels, 147 | out_channels=out_channels, 148 | kernel_size=7, 149 | stride=1, 150 | padding=3, 151 | ) 152 | 153 | self.blocks = nn.ModuleList( 154 | [ 155 | nn.Sequential( 156 | AdaIN(channels=out_channels), 157 | ResBlock( 158 | in_channels=out_channels, 159 | out_channels=out_channels, 160 | kernel_size=kernel_size, 161 | dilation=dilation, 162 | leaky_relu_slope=leaky_relu_slope, 163 | ), 164 | AdaIN(channels=out_channels), 165 | ) 166 | for kernel_size in kernel_sizes 167 | ] 168 | ) 169 | 170 | def forward(self, x: torch.Tensor) -> torch.Tensor: 171 | x = self.input_conv(x) 172 | 173 | results = [block(x) for block in self.blocks] 174 | 175 | return torch.mean(torch.stack(results), dim=0) 176 | 177 | def remove_parametrizations(self): 178 | for block in self.blocks: 179 | block[1].remove_parametrizations() 180 | 181 | 182 | class RefineGANGenerator(nn.Module): 183 | def __init__( 184 | self, 185 | *, 186 | sampling_rate: int = 44100, 187 | 
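# AdaIN above is not the classic style-transfer AdaIN: it adds zero-mean Gaussian
# noise with a learned per-channel scale, then applies LeakyReLU (noise injection in
# the RefineGAN / StyleGAN sense). Minimal sketch of the same operation:
import torch
import torch.nn.functional as F

channels = 32
weight = torch.ones(channels)                          # learned per-channel noise scale
x = torch.randn(2, channels, 100)
out = F.leaky_relu(x + torch.randn_like(x) * weight[None, :, None], 0.2)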
hop_length: int = 256, 188 | downsample_rates: tuple[int] = (2, 2, 8, 8), 189 | upsample_rates: tuple[int] = (8, 8, 2, 2), 190 | leaky_relu_slope: float = 0.2, 191 | num_mels: int = 128, 192 | start_channels: int = 16, 193 | ) -> None: 194 | super().__init__() 195 | 196 | self.sampling_rate = sampling_rate 197 | self.hop_length = hop_length 198 | self.downsample_rates = downsample_rates 199 | self.upsample_rates = upsample_rates 200 | self.leaky_relu_slope = leaky_relu_slope 201 | 202 | assert np.prod(downsample_rates) == np.prod(upsample_rates) == hop_length 203 | 204 | self.template_conv = weight_norm( 205 | nn.Conv1d( 206 | in_channels=1, 207 | out_channels=start_channels, 208 | kernel_size=7, 209 | stride=1, 210 | padding=3, 211 | ) 212 | ) 213 | 214 | channels = start_channels 215 | 216 | self.downsample_blocks = nn.ModuleList([]) 217 | for rate in downsample_rates: 218 | new_channels = channels * 2 219 | 220 | self.downsample_blocks.append( 221 | nn.Sequential( 222 | nn.Upsample(scale_factor=1 / rate, mode="linear"), 223 | ResBlock( 224 | in_channels=channels, 225 | out_channels=new_channels, 226 | kernel_size=7, 227 | dilation=(1, 3, 5), 228 | leaky_relu_slope=leaky_relu_slope, 229 | ), 230 | ) 231 | ) 232 | 233 | channels = new_channels 234 | 235 | self.mel_conv = weight_norm( 236 | nn.Conv1d( 237 | in_channels=num_mels, 238 | out_channels=channels, 239 | kernel_size=7, 240 | stride=1, 241 | padding=3, 242 | ) 243 | ) 244 | channels *= 2 245 | 246 | self.upsample_blocks = nn.ModuleList([]) 247 | self.upsample_conv_blocks = nn.ModuleList([]) 248 | 249 | for rate in upsample_rates: 250 | new_channels = channels // 2 251 | 252 | self.upsample_blocks.append(nn.Upsample(scale_factor=rate, mode="linear")) 253 | 254 | self.upsample_conv_blocks.append( 255 | ParallelResBlock( 256 | in_channels=channels + channels // 4, 257 | out_channels=new_channels, 258 | kernel_sizes=(3, 7, 11), 259 | dilation=(1, 3, 5), 260 | leaky_relu_slope=leaky_relu_slope, 261 | ) 262 | ) 263 | 264 | channels = new_channels 265 | 266 | self.output_conv = weight_norm( 267 | nn.Conv1d( 268 | in_channels=channels, 269 | out_channels=1, 270 | kernel_size=7, 271 | stride=1, 272 | padding=3, 273 | ) 274 | ) 275 | 276 | def remove_parametrizations(self) -> None: 277 | remove_parametrizations(self.template_conv) 278 | remove_parametrizations(self.mel_conv) 279 | remove_parametrizations(self.output_conv) 280 | 281 | for block in self.downsample_blocks: 282 | block[1].remove_parametrizations() 283 | 284 | for block in self.upsample_conv_blocks: 285 | block.remove_parametrizations() 286 | 287 | def forward(self, mel: torch.Tensor, template: torch.Tensor) -> torch.Tensor: 288 | """ 289 | Args: 290 | mel (torch.Tensor): [B, mel_bin, T] 291 | template (torch.Tensor): [B, 1, T] 292 | 293 | Returns: 294 | torch.Tensor: [B, 1, T] 295 | """ 296 | 297 | x = self.template_conv(template) 298 | 299 | downs = [] 300 | 301 | for block in self.downsample_blocks: 302 | x = F.leaky_relu(x, self.leaky_relu_slope, inplace=True) 303 | downs.append(x) 304 | x = block(x) 305 | 306 | x = torch.cat([x, self.mel_conv(mel)], dim=1) 307 | 308 | for upsample_block, conv_block, down in zip( 309 | self.upsample_blocks, 310 | self.upsample_conv_blocks, 311 | reversed(downs), 312 | ): 313 | x = F.leaky_relu(x, self.leaky_relu_slope, inplace=True) 314 | x = upsample_block(x) 315 | 316 | x = torch.cat([x, down], dim=1) 317 | x = conv_block(x) 318 | 319 | x = F.leaky_relu(x, self.leaky_relu_slope, inplace=True) 320 | x = self.output_conv(x) 321 | x = 
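# RefineGANGenerator.forward above is a waveform U-Net: the template is progressively
# downsampled (one skip tensor per stage), the mel features join at the bottleneck via
# concatenation, and each upsampling stage concatenates the matching skip before its
# ParallelResBlock. Minimal sketch of the skip bookkeeping with plain interpolation:
import torch
import torch.nn.functional as F

x, downs = torch.randn(1, 16, 1024), []
for rate in (2, 2, 8, 8):                     # downsample path: remember skips
    downs.append(x)
    x = F.interpolate(x, scale_factor=1 / rate, mode="linear")
for rate, skip in zip((8, 8, 2, 2), reversed(downs)):   # upsample path: reuse skips
    x = F.interpolate(x, scale_factor=float(rate), mode="linear")
    x = torch.cat([x, skip], dim=1)           # channels grow; the real blocks then shrink them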
torch.tanh(x) 322 | 323 | return x 324 | -------------------------------------------------------------------------------- /fish_vocoder/modules/generators/unify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class UnifyGenerator(nn.Module): 6 | def __init__( 7 | self, 8 | backbone: nn.Module, 9 | head: nn.Module, 10 | vq: nn.Module | None = None, 11 | ): 12 | super().__init__() 13 | 14 | self.backbone = backbone 15 | self.head = head 16 | self.vq = vq 17 | 18 | def forward(self, x: torch.Tensor, template=None) -> torch.Tensor: 19 | x = self.backbone(x) 20 | 21 | if self.vq is not None: 22 | vq_result = self.vq(x) 23 | x = vq_result.z 24 | 25 | x = self.head(x, template=template) 26 | 27 | if x.ndim == 2: 28 | x = x[:, None, :] 29 | 30 | if self.vq is not None: 31 | return x, vq_result 32 | 33 | return x 34 | 35 | def encode(self, x: torch.Tensor) -> torch.Tensor: 36 | if self.vq is None: 37 | raise ValueError("VQ module is not present in the model.") 38 | 39 | x = self.backbone(x) 40 | vq_result = self.vq(x) 41 | return vq_result.codes 42 | 43 | def decode(self, codes: torch.Tensor, template=None) -> torch.Tensor: 44 | if self.vq is None: 45 | raise ValueError("VQ module is not present in the model.") 46 | 47 | x = self.vq.from_codes(codes)[0] 48 | x = self.head(x, template=template) 49 | 50 | if x.ndim == 2: 51 | x = x[:, None, :] 52 | 53 | return x 54 | 55 | def remove_parametrizations(self): 56 | if hasattr(self.backbone, "remove_parametrizations"): 57 | self.backbone.remove_parametrizations() 58 | 59 | if hasattr(self.head, "remove_parametrizations"): 60 | self.head.remove_parametrizations() 61 | -------------------------------------------------------------------------------- /fish_vocoder/modules/generators/vocos.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from vocos.spectral_ops import ISTFT 4 | 5 | 6 | class ISTFTHead(nn.Module): 7 | """ 8 | ISTFT Head module for predicting STFT complex coefficients. 9 | 10 | Args: 11 | dim (int): Hidden dimension of the model. 12 | n_fft (int): Size of Fourier transform. 13 | hop_length (int): The distance between neighboring sliding window frames, which should align with 14 | the resolution of the input features. 15 | win_length (int): The size of window frame and STFT filter. 16 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 17 | """ # noqa: E501 18 | 19 | def __init__( 20 | self, 21 | dim: int, 22 | n_fft: int, 23 | hop_length: int, 24 | win_length: int, 25 | padding: str = "same", 26 | ): 27 | super().__init__() 28 | 29 | self.n_fft = n_fft 30 | self.hop_length = hop_length 31 | self.win_length = win_length 32 | 33 | self.istft = ISTFT( 34 | n_fft=n_fft, 35 | hop_length=hop_length, 36 | win_length=win_length, 37 | padding=padding, 38 | ) 39 | 40 | out_dim = n_fft * 2 41 | self.out = nn.Conv1d(dim, out_dim, 1) 42 | 43 | def forward(self, x: torch.Tensor) -> torch.Tensor: 44 | """ 45 | Forward pass of the ISTFTHead module. 46 | 47 | Args: 48 | x (Tensor): Input tensor of shape (B, H, L), where B is the batch size, 49 | L is the sequence length, and H denotes the model dimension. 50 | 51 | Returns: 52 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 
53 | """ # noqa: E501 54 | 55 | x = self.out(x) 56 | 57 | mag, p = x.chunk(2, dim=1) 58 | mag = torch.exp(mag) 59 | mag = torch.clip( 60 | mag, max=1e2 61 | ) # safeguard to prevent excessively large magnitudes 62 | 63 | # wrapping happens here. These two lines produce real and imaginary value 64 | x = torch.cos(p) 65 | y = torch.sin(p) 66 | 67 | S = mag * (x + 1j * y) 68 | 69 | return self.istft(S) 70 | -------------------------------------------------------------------------------- /fish_vocoder/modules/losses/stft.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based Loss modules.""" 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | 13 | def stft(x, fft_size, hop_size, win_length, window): 14 | """Perform STFT and convert to magnitude spectrogram. 15 | Args: 16 | x (Tensor): Input signal tensor (B, T). 17 | fft_size (int): FFT size. 18 | hop_size (int): Hop size. 19 | win_length (int): Window length. 20 | window (str): Window function type. 21 | Returns: 22 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 23 | """ 24 | spec = torch.stft( 25 | x, 26 | fft_size, 27 | hop_size, 28 | win_length, 29 | window, 30 | return_complex=True, 31 | pad_mode="reflect", 32 | ) 33 | spec = torch.view_as_real(spec) 34 | 35 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 36 | return torch.sqrt(torch.clamp(spec.pow(2).sum(-1), min=1e-6)).transpose(2, 1) 37 | 38 | 39 | class SpectralConvergengeLoss(nn.Module): 40 | """Spectral convergence loss module.""" 41 | 42 | def __init__(self): 43 | """Initialize spectral convergence loss module.""" 44 | super(SpectralConvergengeLoss, self).__init__() 45 | 46 | def forward(self, x_mag, y_mag): 47 | """Calculate forward propagation. 48 | Args: 49 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 50 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 51 | Returns: 52 | Tensor: Spectral convergence loss value. 53 | """ # noqa: E501 54 | 55 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 56 | 57 | 58 | class LogSTFTMagnitudeLoss(nn.Module): 59 | """Log STFT magnitude loss module.""" 60 | 61 | def __init__(self): 62 | """Initialize los STFT magnitude loss module.""" 63 | super(LogSTFTMagnitudeLoss, self).__init__() 64 | 65 | def forward(self, x_mag, y_mag): 66 | """Calculate forward propagation. 67 | Args: 68 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 69 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 70 | Returns: 71 | Tensor: Log STFT magnitude loss value. 72 | """ # noqa: E501 73 | 74 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 75 | 76 | 77 | class STFTLoss(nn.Module): 78 | """STFT loss module.""" 79 | 80 | def __init__( 81 | self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window 82 | ): 83 | """Initialize STFT loss module.""" 84 | super(STFTLoss, self).__init__() 85 | 86 | self.fft_size = fft_size 87 | self.shift_size = shift_size 88 | self.win_length = win_length 89 | self.register_buffer("window", window(win_length)) 90 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 91 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 92 | 93 | def forward(self, x, y): 94 | """Calculate forward propagation. 
95 | Args: 96 | x (Tensor): Predicted signal (B, T). 97 | y (Tensor): Groundtruth signal (B, T). 98 | Returns: 99 | Tensor: Spectral convergence loss value. 100 | Tensor: Log STFT magnitude loss value. 101 | """ 102 | 103 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 104 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 105 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 106 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 107 | 108 | return sc_loss, mag_loss 109 | 110 | 111 | class MultiResolutionSTFTLoss(nn.Module): 112 | """Multi resolution STFT loss module.""" 113 | 114 | def __init__(self, resolutions, window=torch.hann_window): 115 | super(MultiResolutionSTFTLoss, self).__init__() 116 | 117 | self.stft_losses = nn.ModuleList() 118 | for fs, ss, wl in resolutions: 119 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 120 | 121 | def forward(self, x, y): 122 | """Calculate forward propagation. 123 | Args: 124 | x (Tensor): Predicted signal (B, T). 125 | y (Tensor): Groundtruth signal (B, T). 126 | Returns: 127 | Tensor: Multi resolution spectral convergence loss value. 128 | Tensor: Multi resolution log STFT magnitude loss value. 129 | """ 130 | sc_loss = 0.0 131 | mag_loss = 0.0 132 | for f in self.stft_losses: 133 | sc_l, mag_l = f(x, y) 134 | sc_loss += sc_l 135 | mag_loss += mag_l 136 | 137 | sc_loss /= len(self.stft_losses) 138 | mag_loss /= len(self.stft_losses) 139 | 140 | return sc_loss, mag_loss 141 | -------------------------------------------------------------------------------- /fish_vocoder/schedulers/warmup_cosine.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | class LambdaWarmUpCosineScheduler: 7 | """ 8 | note: use with a base_lr of 1.0 9 | """ 10 | 11 | def __init__( 12 | self, 13 | *, 14 | val_base, 15 | val_final, 16 | max_decay_steps, 17 | val_start=0, 18 | warm_up_steps=0, 19 | ): 20 | """Warmup cosine scheduler 21 | 22 | Args: 23 | val_base (float): the val after warmup 24 | val_final (float): the val at the end of the schedule 25 | max_decay_steps (int): number of steps to decay from val_base to val_final (after warmup) 26 | val_start (float, optional): learning rate at the start of the schedule. Defaults to 0. 27 | warm_up_steps (int, optional): number of steps for the warmup phase. Defaults to 0. 
28 | """ # noqa: E501 29 | 30 | self.val_final = val_final 31 | self.val_base = val_base 32 | self.warm_up_steps = warm_up_steps 33 | self.val_start = val_start 34 | self.val_base_decay_steps = max_decay_steps 35 | self.last_lr = 0.0 36 | 37 | def schedule(self, n): 38 | if n < self.warm_up_steps: 39 | lr = ( 40 | self.val_base - self.val_start 41 | ) / self.warm_up_steps * n + self.val_start 42 | self.last_lr = lr 43 | 44 | return lr 45 | 46 | t = (n - self.warm_up_steps) / (self.val_base_decay_steps - self.warm_up_steps) 47 | t = min(t, 1.0) 48 | lr = self.val_final + 0.5 * (self.val_base - self.val_final) * ( 49 | 1 + math.cos(t * torch.pi) 50 | ) 51 | self.last_lr = lr 52 | 53 | return lr 54 | 55 | def __call__(self, n): 56 | return self.schedule(n) 57 | -------------------------------------------------------------------------------- /fish_vocoder/test.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | import hydra 5 | import librosa 6 | import pyrootutils 7 | import soundfile as sf 8 | import torch 9 | import torch.nn.functional as F 10 | from hydra.utils import instantiate 11 | from lightning import LightningModule 12 | from omegaconf import DictConfig, OmegaConf 13 | 14 | # Allow TF32 on Ampere GPUs 15 | torch.set_float32_matmul_precision("high") 16 | 17 | # register eval resolver and root 18 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 19 | OmegaConf.register_new_resolver("eval", eval) 20 | 21 | # flake8: noqa: E402 22 | from fish_vocoder.utils.logger import logger 23 | 24 | 25 | @hydra.main(config_path="configs", version_base="1.3", config_name="train") 26 | @torch.no_grad() 27 | def main(cfg: DictConfig): 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | logger.info(f"Using device: {device}") 30 | 31 | model: LightningModule = instantiate(cfg.model) 32 | ckpt = torch.load(cfg.ckpt_path, map_location=device) 33 | 34 | if "state_dict" in ckpt: 35 | ckpt = ckpt["state_dict"] 36 | 37 | model.load_state_dict(ckpt) 38 | model.eval() 39 | model.to(device) 40 | 41 | if hasattr(model.generator, "remove_weight_norm"): 42 | model.generator.remove_weight_norm() 43 | 44 | input_path = Path(cfg.input_path) 45 | 46 | if input_path.is_file(): 47 | audios = [input_path] 48 | input_path = input_path.parent 49 | elif input_path.is_dir(): 50 | audios = list(input_path.rglob("*")) 51 | 52 | for audio_path in audios: 53 | if audio_path.suffix in [".wav", ".flac", ".mp3"]: 54 | gt_y, sr = librosa.load(audio_path, sr=cfg.model.sampling_rate, mono=False) 55 | 56 | # If mono, add a channel dimension 57 | if len(gt_y.shape) == 1: 58 | gt_y = gt_y[None, :] 59 | 60 | # If we have more than one channel, switch to batched mode 61 | if cfg.pitch_shift != 0: 62 | gt_y = librosa.effects.pitch_shift(gt_y, sr=sr, n_steps=cfg.pitch_shift) 63 | 64 | gt_y = torch.from_numpy(gt_y)[:, None].to(model.device, torch.float32) 65 | lengths = torch.IntTensor([gt_y.shape[-1]]) 66 | gt_y = F.pad( 67 | gt_y, 68 | (0, cfg.model.hop_length - (cfg.model.hop_length % gt_y.shape[-1])), 69 | ) 70 | logger.info(f"gt_y shape: {gt_y.shape}, lengths: {lengths}") 71 | inputs = model.mel_transforms.input(gt_y.squeeze(1)) 72 | 73 | elif audio_path.suffix in [".pt", ".pth"]: 74 | input_mels = torch.load(audio_path, map_location=model.device).to( 75 | torch.float32 76 | ) 77 | 78 | if len(input_mels.shape) == 2: 79 | input_mels = input_mels[None, ...] 
80 | 81 | if input_mels.shape[-1] == cfg.model.num_mels: 82 | input_mels = input_mels.transpose(1, 2) 83 | 84 | inputs = input_mels 85 | else: 86 | continue 87 | 88 | start = time.time() 89 | fake_audio = model(None, None, input_spec=inputs)[0] 90 | logger.info(f"Time taken: {time.time() - start:.2f}s") 91 | 92 | output_path = ( 93 | Path(cfg.output_path) 94 | / f"{Path(audio_path).relative_to(input_path).with_suffix('.wav')}" 95 | ) 96 | output_path.parent.mkdir(parents=True, exist_ok=True) 97 | 98 | fake_audio = fake_audio.squeeze(1) 99 | sf.write(output_path, fake_audio.cpu().numpy().T, cfg.model.sampling_rate) 100 | logger.info(f"Saved generated audio to {output_path}") 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /fish_vocoder/train.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import hydra 4 | import lightning as L 5 | import pyrootutils 6 | import torch 7 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 8 | from lightning.pytorch.loggers import Logger 9 | from omegaconf import DictConfig, OmegaConf 10 | 11 | # Allow TF32 on Ampere GPUs 12 | torch.set_float32_matmul_precision("high") 13 | torch.backends.cudnn.allow_tf32 = True 14 | 15 | # register eval resolver and root 16 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 17 | OmegaConf.register_new_resolver("eval", eval) 18 | 19 | # flake8: noqa: E402 20 | from fish_vocoder import utils 21 | from fish_vocoder.utils.file import get_latest_checkpoint 22 | from fish_vocoder.utils.logger import logger as log 23 | 24 | 25 | @utils.task_wrapper 26 | def train(cfg: DictConfig) -> Tuple[dict, dict]: 27 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 28 | training. 29 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 30 | failure. Useful for multiruns, saving info about the crash, etc. 31 | Args: 32 | cfg (DictConfig): Configuration composed by Hydra. 33 | Returns: 34 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. 
35 | """ # noqa: E501 36 | 37 | # set seed for random number generators in pytorch, numpy and python.random 38 | if cfg.get("seed"): 39 | L.seed_everything(cfg.seed, workers=True) 40 | 41 | if cfg.get("deterministic"): 42 | torch.use_deterministic_algorithms(True) 43 | 44 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") 45 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 46 | 47 | log.info(f"Instantiating model <{cfg.model._target_}>") 48 | model: LightningModule = hydra.utils.instantiate(cfg.model) 49 | 50 | log.info("Instantiating callbacks...") 51 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 52 | 53 | log.info("Instantiating loggers...") 54 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) 55 | 56 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 57 | trainer: Trainer = hydra.utils.instantiate( 58 | cfg.trainer, callbacks=callbacks, logger=logger 59 | ) 60 | 61 | object_dict = { 62 | "cfg": cfg, 63 | "datamodule": datamodule, 64 | "model": model, 65 | "callbacks": callbacks, 66 | "logger": logger, 67 | "trainer": trainer, 68 | } 69 | 70 | if logger: 71 | log.info("Logging hyperparameters!") 72 | utils.log_hyperparameters(object_dict) 73 | 74 | if cfg.get("compile"): 75 | log.info("Compiling model!") 76 | model = torch.compile(model) 77 | 78 | if cfg.get("train"): 79 | log.info("Starting training!") 80 | 81 | ckpt_path = cfg.get("ckpt_path") 82 | 83 | if ckpt_path is None: 84 | ckpt_path = get_latest_checkpoint(cfg.paths.ckpt_dir) 85 | 86 | if ckpt_path is not None: 87 | log.info(f"Resuming from checkpoint: {ckpt_path}") 88 | 89 | if cfg.get("resume_weights_only"): 90 | log.info("Resuming weights only!") 91 | ckpt = torch.load(ckpt_path, map_location=model.device) 92 | model.load_state_dict( 93 | ckpt["state_dict"] if "state_dict" in ckpt else ckpt, strict=False 94 | ) 95 | ckpt_path = None 96 | 97 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 98 | 99 | train_metrics = trainer.callback_metrics 100 | 101 | if cfg.get("test"): 102 | log.info("Starting testing!") 103 | ckpt_path = trainer.checkpoint_callback.best_model_path 104 | if ckpt_path == "": 105 | log.warning("Best ckpt not found! Using current weights for testing...") 106 | ckpt_path = cfg.get("ckpt_path") 107 | 108 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 109 | log.info(f"Best ckpt path: {ckpt_path}") 110 | 111 | test_metrics = trainer.callback_metrics 112 | 113 | # merge train and test metrics 114 | metric_dict = {**train_metrics, **test_metrics} 115 | 116 | return metric_dict, object_dict 117 | 118 | 119 | @hydra.main(version_base="1.3", config_path="./configs", config_name="train.yaml") 120 | def main(cfg: DictConfig) -> Optional[float]: 121 | # apply extra utilities 122 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
123 | utils.extras(cfg) 124 | 125 | # train the model 126 | train(cfg) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /fish_vocoder/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from fish_vocoder.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from fish_vocoder.utils.logger import logger 3 | from fish_vocoder.utils.logging_utils import log_hyperparameters 4 | from fish_vocoder.utils.rich_utils import enforce_tags, print_config_tree 5 | from fish_vocoder.utils.utils import extras, get_metric_value, task_wrapper 6 | 7 | __all__ = [ 8 | "enforce_tags", 9 | "extras", 10 | "get_metric_value", 11 | "logger", 12 | "instantiate_callbacks", 13 | "instantiate_loggers", 14 | "log_hyperparameters", 15 | "print_config_tree", 16 | "task_wrapper", 17 | ] 18 | -------------------------------------------------------------------------------- /fish_vocoder/utils/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | AUDIO_EXTENSIONS = { 6 | ".mp3", 7 | ".wav", 8 | ".flac", 9 | ".ogg", 10 | ".m4a", 11 | ".wma", 12 | ".aac", 13 | ".aiff", 14 | ".aif", 15 | ".aifc", 16 | } 17 | 18 | 19 | def list_files( 20 | path: Union[Path, str], 21 | extensions: set[str] = None, 22 | recursive: bool = False, 23 | sort: bool = True, 24 | ) -> list[Path]: 25 | """List files in a directory. 26 | 27 | Args: 28 | path (Path): Path to the directory. 29 | extensions (set, optional): Extensions to filter. Defaults to None. 30 | recursive (bool, optional): Whether to search recursively. Defaults to False. 31 | sort (bool, optional): Whether to sort the files. Defaults to True. 32 | 33 | Returns: 34 | list: List of files. 35 | """ 36 | 37 | if isinstance(path, str): 38 | path = Path(path) 39 | 40 | if not path.exists(): 41 | raise FileNotFoundError(f"Directory {path} does not exist.") 42 | 43 | files = ( 44 | [ 45 | Path(os.path.join(root, filename)) 46 | for root, _, filenames in os.walk(path, followlinks=True) 47 | for filename in filenames 48 | if Path(os.path.join(root, filename)).is_file() 49 | ] 50 | if recursive 51 | else [f for f in path.glob("*") if f.is_file()] 52 | ) 53 | 54 | if extensions is not None: 55 | files = [f for f in files if f.suffix in extensions] 56 | 57 | if sort: 58 | files = sorted(files) 59 | 60 | return files 61 | 62 | 63 | def get_latest_checkpoint(path: Path | str) -> Path | None: 64 | # Find the latest checkpoint 65 | ckpt_dir = Path(path) 66 | 67 | if ckpt_dir.exists() is False: 68 | return None 69 | 70 | ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) 71 | if len(ckpts) == 0: 72 | return None 73 | 74 | return ckpts[-1] 75 | -------------------------------------------------------------------------------- /fish_vocoder/utils/grad_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.utils._foreach_utils import ( 6 | _group_tensors_by_device_and_dtype, 7 | _has_foreach_support, 8 | ) 9 | 10 | 11 | @torch.no_grad() 12 | def grad_norm( 13 | parameters: Union[Tensor, list[Tensor]], 14 | norm_type: float = 2.0, 15 | ) -> float: 16 | """ 17 | Returns the norm of the gradients of the given parameters. 
18 | 19 | Args: 20 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 21 | single Tensor that will have gradients normalized 22 | norm_type (float): type of the used p-norm. 23 | 24 | Returns: 25 | Total norm of the parameter gradients (viewed as a single vector). 26 | """ # noqa: E501 27 | 28 | if isinstance(parameters, Tensor): 29 | parameters = [parameters] 30 | 31 | grads = [p.grad for p in parameters if p.grad is not None] 32 | first_device = grads[0].device 33 | grouped_grads: dict[ 34 | tuple[torch.device, torch.dtype], list[list[Tensor]] 35 | ] = _group_tensors_by_device_and_dtype( 36 | [[g.detach() for g in grads]] 37 | ) # type: ignore[assignment] 38 | 39 | norms = [] 40 | for (device, _), [grads] in grouped_grads.items(): 41 | if _has_foreach_support(grads, device=device): 42 | norms.extend(torch._foreach_norm(grads, norm_type)) 43 | else: 44 | norms.extend([torch.norm(g, norm_type) for g in grads]) 45 | 46 | return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) 47 | -------------------------------------------------------------------------------- /fish_vocoder/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | from pytorch_lightning import Callback 6 | from pytorch_lightning.loggers import Logger 7 | 8 | from fish_vocoder.utils.logger import logger as log 9 | 10 | 11 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 12 | """Instantiates callbacks from config.""" 13 | 14 | callbacks: List[Callback] = [] 15 | 16 | if not callbacks_cfg: 17 | log.warning("No callback configs found! Skipping..") 18 | return callbacks 19 | 20 | if not isinstance(callbacks_cfg, DictConfig): 21 | raise TypeError("Callbacks config must be a DictConfig!") 22 | 23 | for _, cb_conf in callbacks_cfg.items(): 24 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 25 | log.info(f"Instantiating callback <{cb_conf._target_}>") 26 | callbacks.append(hydra.utils.instantiate(cb_conf)) 27 | 28 | return callbacks 29 | 30 | 31 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 32 | """Instantiates loggers from config.""" 33 | 34 | logger: List[Logger] = [] 35 | 36 | if not logger_cfg: 37 | log.warning("No logger configs found! 
Skipping...") 38 | return logger 39 | 40 | if not isinstance(logger_cfg, DictConfig): 41 | raise TypeError("Logger config must be a DictConfig!") 42 | 43 | for _, lg_conf in logger_cfg.items(): 44 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 45 | log.info(f"Instantiating logger <{lg_conf._target_}>") 46 | logger.append(hydra.utils.instantiate(lg_conf)) 47 | 48 | return logger 49 | -------------------------------------------------------------------------------- /fish_vocoder/utils/logger.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.utilities import rank_zero_only 2 | from loguru import logger 3 | 4 | # this ensures all logging levels get marked with the rank zero decorator 5 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 6 | logging_levels = ( 7 | "debug", 8 | "info", 9 | "warning", 10 | "error", 11 | "exception", 12 | "critical", 13 | ) 14 | for level in logging_levels: 15 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 16 | -------------------------------------------------------------------------------- /fish_vocoder/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.utilities import rank_zero_only 2 | 3 | from fish_vocoder.utils import logger as log 4 | 5 | 6 | @rank_zero_only 7 | def log_hyperparameters(object_dict: dict) -> None: 8 | """Controls which config parts are saved by lightning loggers. 9 | 10 | Additionally saves: 11 | - Number of model parameters 12 | """ 13 | 14 | hparams = {} 15 | 16 | cfg = object_dict["cfg"] 17 | model = object_dict["model"] 18 | trainer = object_dict["trainer"] 19 | 20 | if not trainer.logger: 21 | log.warning("Logger not found! 
Skipping hyperparameter logging...") 22 | return 23 | 24 | hparams["model"] = cfg["model"] 25 | 26 | # save number of model parameters 27 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 28 | hparams["model/params/trainable"] = sum( 29 | p.numel() for p in model.parameters() if p.requires_grad 30 | ) 31 | hparams["model/params/non_trainable"] = sum( 32 | p.numel() for p in model.parameters() if not p.requires_grad 33 | ) 34 | 35 | hparams["data"] = cfg["data"] 36 | hparams["trainer"] = cfg["trainer"] 37 | 38 | hparams["callbacks"] = cfg.get("callbacks") 39 | hparams["extras"] = cfg.get("extras") 40 | 41 | hparams["task_name"] = cfg.get("task_name") 42 | hparams["tags"] = cfg.get("tags") 43 | hparams["ckpt_path"] = cfg.get("ckpt_path") 44 | hparams["seed"] = cfg.get("seed") 45 | 46 | # send hparams to all loggers 47 | for logger in trainer.loggers: 48 | logger.log_hyperparams(hparams) 49 | -------------------------------------------------------------------------------- /fish_vocoder/utils/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def sequence_mask(lengths, max_length=None) -> torch.Tensor: 5 | if max_length is None: 6 | max_length = lengths.max() 7 | 8 | x = torch.arange(max_length, dtype=lengths.dtype, device=lengths.device) 9 | 10 | return x.unsqueeze(0) < lengths.unsqueeze(1) 11 | -------------------------------------------------------------------------------- /fish_vocoder/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from fish_vocoder.utils import logger as log 13 | 14 | 15 | @rank_zero_only 16 | def print_config_tree( 17 | cfg: DictConfig, 18 | print_order: Sequence[str] = ( 19 | "data", 20 | "model", 21 | "callbacks", 22 | "logger", 23 | "trainer", 24 | "paths", 25 | "extras", 26 | ), 27 | resolve: bool = False, 28 | save_to_file: bool = False, 29 | ) -> None: 30 | """Prints content of DictConfig using Rich library and its tree structure. 31 | 32 | Args: 33 | cfg (DictConfig): Configuration composed by Hydra. 34 | print_order (Sequence[str], optional): Determines in what order config components are printed. 35 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 36 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 37 | """ # noqa: E501 38 | 39 | style = "dim" 40 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 41 | 42 | queue = [] 43 | 44 | # add fields from `print_order` to queue 45 | for field in print_order: 46 | ( 47 | queue.append(field) 48 | if field in cfg 49 | else log.warning( 50 | f"Field '{field}' not found in config. " 51 | + f"Skipping '{field}' config printing..." 
52 | ) 53 | ) 54 | 55 | # add all the other fields to queue (not specified in `print_order`) 56 | for field in cfg: 57 | if field not in queue: 58 | queue.append(field) 59 | 60 | # generate config tree from queue 61 | for field in queue: 62 | branch = tree.add(field, style=style, guide_style=style) 63 | 64 | config_group = cfg[field] 65 | if isinstance(config_group, DictConfig): 66 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 67 | else: 68 | branch_content = str(config_group) 69 | 70 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 71 | 72 | # print config tree 73 | rich.print(tree) 74 | 75 | # save config tree to file 76 | if save_to_file: 77 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 78 | rich.print(tree, file=file) 79 | 80 | 81 | @rank_zero_only 82 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 83 | """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 84 | 85 | if not cfg.get("tags"): 86 | if "id" in HydraConfig().cfg.hydra.job: 87 | raise ValueError("Specify tags before launching a multirun!") 88 | 89 | log.warning("No tags provided in config. Prompting user to input tags...") 90 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 91 | tags = [t.strip() for t in tags.split(",") if t != ""] 92 | 93 | with open_dict(cfg): 94 | cfg.tags = tags 95 | 96 | log.info(f"Tags: {cfg.tags}") 97 | 98 | if save_to_file: 99 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 100 | rich.print(cfg.tags, file=file) 101 | -------------------------------------------------------------------------------- /fish_vocoder/utils/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from importlib.util import find_spec 3 | from typing import Callable 4 | 5 | from omegaconf import DictConfig 6 | 7 | from fish_vocoder.utils import logger as log 8 | from fish_vocoder.utils import rich_utils 9 | 10 | 11 | def extras(cfg: DictConfig) -> None: 12 | """Applies optional utilities before the task is started. 13 | 14 | Utilities: 15 | - Ignoring python warnings 16 | - Setting tags from command line 17 | - Rich config printing 18 | """ 19 | 20 | # return if no `extras` config 21 | if not cfg.get("extras"): 22 | log.warning("Extras config not found! ") 23 | return 24 | 25 | # disable python warnings 26 | if cfg.extras.get("ignore_warnings"): 27 | log.info("Disabling python warnings! ") 28 | warnings.filterwarnings("ignore") 29 | 30 | # prompt user to input tags from command line if none are provided in the config 31 | if cfg.extras.get("enforce_tags"): 32 | log.info("Enforcing tags! ") 33 | rich_utils.enforce_tags(cfg, save_to_file=True) 34 | 35 | # pretty print config tree using Rich library 36 | if cfg.extras.get("print_config"): 37 | log.info("Printing config tree with Rich! ") 38 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 39 | 40 | 41 | def task_wrapper(task_func: Callable) -> Callable: 42 | """Optional decorator that controls the failure behavior when executing the task function. 43 | 44 | This wrapper can be used to: 45 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 46 | - save the exception to a `.log` file 47 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 48 | - etc. 
(adjust depending on your needs) 49 | 50 | Example: 51 | ``` 52 | @utils.task_wrapper 53 | def train(cfg: DictConfig) -> Tuple[dict, dict]: 54 | 55 | ... 56 | 57 | return metric_dict, object_dict 58 | ``` 59 | """ # noqa: E501 60 | 61 | def wrap(cfg: DictConfig): 62 | # execute the task 63 | try: 64 | metric_dict, object_dict = task_func(cfg=cfg) 65 | 66 | # things to do if exception occurs 67 | except Exception as ex: 68 | # save exception to `.log` file 69 | log.exception("") 70 | 71 | # some hyperparameter combinations might be invalid or 72 | # cause out-of-memory errors so when using hparam search 73 | # plugins like Optuna, you might want to disable 74 | # raising the below exception to avoid multirun failure 75 | raise ex 76 | 77 | # things to always do after either success or exception 78 | finally: 79 | # display output dir path in terminal 80 | log.info(f"Output dir: {cfg.paths.output_dir}") 81 | 82 | # always close wandb run (even if exception occurs so multirun won't fail) 83 | if find_spec("wandb"): # check if wandb is installed 84 | import wandb 85 | 86 | if wandb.run: 87 | log.info("Closing wandb!") 88 | wandb.finish() 89 | 90 | return metric_dict, object_dict 91 | 92 | return wrap 93 | 94 | 95 | def get_metric_value(metric_dict: dict, metric_name: str) -> float: 96 | """Safely retrieves value of the metric logged in LightningModule.""" 97 | 98 | if not metric_name: 99 | log.info("Metric name is None! Skipping metric value retrieval...") 100 | return None 101 | 102 | if metric_name not in metric_dict: 103 | raise Exception( 104 | f"Metric value not found! \n" 105 | "Make sure metric name logged in LightningModule is correct!\n" 106 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 107 | ) 108 | 109 | metric_value = metric_dict[metric_name].item() 110 | log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") 111 | 112 | return metric_value 113 | -------------------------------------------------------------------------------- /fish_vocoder/utils/viz.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | from matplotlib import pyplot as plt 3 | from torch import Tensor 4 | 5 | matplotlib.use("Agg") 6 | 7 | 8 | def plot_mel(data, titles=None): 9 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 10 | 11 | if titles is None: 12 | titles = [None for i in range(len(data))] 13 | 14 | plt.tight_layout() 15 | 16 | for i in range(len(data)): 17 | mel = data[i] 18 | 19 | if isinstance(mel, Tensor): 20 | mel = mel.detach().cpu().numpy() 21 | 22 | axes[i][0].imshow(mel, origin="lower") 23 | axes[i][0].set_aspect(2.5, adjustable="box") 24 | axes[i][0].set_ylim(0, mel.shape[0]) 25 | axes[i][0].set_title(titles[i], fontsize="medium") 26 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 27 | axes[i][0].set_anchor("W") 28 | 29 | return fig 30 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fish-vocoder" 3 | version = "1.0.0" 4 | description = "" 5 | license = {text = "MIT"} 6 | authors = [ 7 | {name = "Lengyue", email = "lengyue@lengyue.me"}, 8 | ] 9 | requires-python = ">=3.10,<4.0" 10 | dependencies = [ 11 | "librosa<1.0.0,>=0.9.1", 12 | "numba<1.0.0,>=0.56.4", 13 | "wandb>=0.15.4", 14 | "loguru>=0.7.0", 15 | "tensorboard<3.0.0,>=2.11.2", 16 | "natsort<9.0.0,>=8.3.1", 17 | "torch<3.0.0,>=2.0.0", 18 | "torchaudio<3.0.0,>=2.0.0", 19 | "lightning>=2.0.3", 20 | "hydra-core>=1.3.2", 21 | "pyrootutils>=1.0.4", 22 | "hydra-colorlog>=1.2.0", 23 | "torch-summary>=1.4.5", 24 | "matplotlib>=3.7.1", 25 | "encodec>=0.1.1", 26 | "vocos>=0.0.2", 27 | "transformers>=4.31.0", 28 | "pesq>=0.0.4", 29 | "alias-free-torch>=0.0.6", 30 | ] 31 | 32 | [tool.pdm] 33 | [tool.pdm.build] 34 | includes = ["fish_vocoder"] 35 | 36 | [tool.pdm.dev-dependencies] 37 | dev = [ 38 | "black>=22.12.0", 39 | "pytest>=7.3.1", 40 | "pre-commit>=3.3.3", 41 | "ruff>=0.0.280", 42 | ] 43 | 44 | [build-system] 45 | requires = ["pdm-backend"] 46 | build-backend = "pdm.backend" 47 | 48 | [tool.pdm.scripts] 49 | lint = { shell = "black . && ruff check --fix ." } 50 | check = { shell = "black --check . && ruff check ." } 51 | test = { shell = "PYTHONPATH=. 
pytest -n=auto -q tests" } 52 | 53 | [[tool.pdm.source]] 54 | type = "find_links" 55 | name = "torch-cu118" 56 | url = "https://download.pytorch.org/whl/cu118" 57 | verify_ssl = true 58 | 59 | [tool.isort] 60 | profile = "black" 61 | extend_skip = ["dataset", "logs"] 62 | 63 | [tool.ruff] 64 | line-length = 88 65 | select = [ 66 | # Pyflakes 67 | "F", 68 | # Pycodestyle 69 | "E", 70 | "W", 71 | # isort 72 | "I001" 73 | ] 74 | -------------------------------------------------------------------------------- /scripts/convert_diffsinger_mel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | data = torch.load("能解答一切的答案.mel.pt", map_location="cpu") 4 | all_mels = [i["mel"] for i in data] 5 | # all_mel = torch.cat(all_mels, dim=1) / 0.434294 6 | 7 | all_mel = ( 8 | torch.zeros( 9 | (1, int(data[-1]["offset"] * 44100 / 512) + data[-1]["mel"].shape[1], 128) 10 | ) 11 | - 11.512925 12 | ) 13 | 14 | for seg in data: 15 | offset = int(seg["offset"] * 44100 / 512) 16 | mel = seg["mel"] / 0.434294 17 | all_mel[:, offset : offset + mel.shape[1], :] = mel 18 | 19 | torch.save(all_mel, "other/mels/能解答一切的答案.mel.pt") 20 | -------------------------------------------------------------------------------- /scripts/random_copy.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | 4 | import click 5 | from loguru import logger 6 | 7 | 8 | @click.command() 9 | @click.argument("src", type=click.Path(exists=True, file_okay=False, dir_okay=True)) 10 | @click.argument("dst", type=click.Path(exists=False, file_okay=False, dir_okay=True)) 11 | @click.argument("num", type=int) 12 | @click.option("--seed", help="Random seed", type=int, default=42) 13 | def random_copy(src: Path, dst: Path, num: int, seed: int): 14 | """Copy random files from SRC to DST.""" 15 | 16 | src, dst = Path(src), Path(dst) 17 | 18 | files = [f for f in src.rglob("*") if f.is_file() and f.suffix in [".wav", ".flac"]] 19 | logger.info(f"Found {len(files)} files in {src}") 20 | 21 | generator = random.Random(seed) 22 | selected_files = generator.sample(files, num) 23 | 24 | logger.info(f"Copying {len(selected_files)} files to {dst}") 25 | 26 | for f in selected_files: 27 | f_dst = dst / f.relative_to(src) 28 | f_dst.parent.mkdir(parents=True, exist_ok=True) 29 | f_dst.write_bytes(f.read_bytes()) 30 | 31 | logger.info("Done") 32 | 33 | 34 | if __name__ == "__main__": 35 | random_copy() 36 | -------------------------------------------------------------------------------- /scripts/test_firefly_gan.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | python fish_vocoder/test.py task_name=firefly-gan-test \ 4 | model/generator=firefly-gan-base \ 5 | ckpt_path=checkpoints/firefly-gan-base.ckpt \ 6 | 'input_path="other"' \ 7 | 'output_path="results/test-audios"' 8 | -------------------------------------------------------------------------------- /scripts/test_hifigan.sh: -------------------------------------------------------------------------------- 1 | python fish_vocoder/test.py task_name=hifigan \ 2 | model/generator=hifigan \ 3 | ckpt_path=other/step_001235000.ckpt \ 4 | 'input_path="other/Track 10_014_1.wav"' 5 | -------------------------------------------------------------------------------- /scripts/test_vocos_huge.sh: -------------------------------------------------------------------------------- 1 | python fish_vocoder/test.py 
task_name=vocos-huge \ 2 | model/generator=vocos-huge \ 3 | ckpt_path=other/vocos_huge/step_000475000.ckpt \ 4 | 'input_path="other"' \ 5 | 'output_path="results/test-audios"' 6 | -------------------------------------------------------------------------------- /scripts/train_convnext_bigvgan_base.sh: -------------------------------------------------------------------------------- 1 | python fish_vocoder/train.py task_name=convnext-bigvgan-base \ 2 | model/generator=convnext-bigvgan-base \ 3 | model.num_frames=16 \ 4 | data.batch_size=32 \ 5 | logger=tensorboard 6 | -------------------------------------------------------------------------------- /scripts/train_convnext_hifigan_base.sh: -------------------------------------------------------------------------------- 1 | python fish_vocoder/train.py task_name=convnext-hifigan-base \ 2 | model/generator=convnext-hifigan-base \ 3 | model.num_frames=16 \ 4 | data.batch_size=32 \ 5 | logger=tensorboard 6 | -------------------------------------------------------------------------------- /scripts/train_convnext_hifigan_vae.sh: -------------------------------------------------------------------------------- 1 | python fish_vocoder/train.py task_name=convnext-hifigan-vae \ 2 | model=vae \ 3 | model/generator=convnext-hifigan-vae \ 4 | data.datasets.train.datasets.hifi-8000h.dataset.root=filelist.hifi-8000h.train 5 | -------------------------------------------------------------------------------- /scripts/train_vocos.sh: -------------------------------------------------------------------------------- 1 | python fish_vocoder/train.py task_name=vocos \ 2 | model/generator=vocos \ 3 | data.datasets.train.datasets.hifi-8000h.dataset.root=filelist.hifi-8000h.train \ 4 | logger=tensorboard 5 | -------------------------------------------------------------------------------- /scripts/train_vocos_huge.sh: -------------------------------------------------------------------------------- 1 | python fish_vocoder/train.py task_name=vocos-huge \ 2 | model/generator=vocos-huge \ 3 | data.batch_size=4 \ 4 | data.datasets.train.datasets.hifi-8000h.dataset.root=filelist.hifi-8000h.train \ 5 | logger=tensorboard 6 | -------------------------------------------------------------------------------- /scripts/train_vocos_huge_full.sh: -------------------------------------------------------------------------------- 1 | python fish_vocoder/train.py task_name=vocos-huge-full \ 2 | model/generator=vocos-huge \ 3 | model.num_mels=160 \ 4 | model.mel_transforms.modules.input.f_min=0 \ 5 | model.mel_transforms.modules.input.f_max=22050 \ 6 | data.batch_size=4 \ 7 | trainer.precision=16-mixed \ 8 | logger=tensorboard 9 | -------------------------------------------------------------------------------- /scripts/vocos_gen.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torchaudio 4 | from tqdm import tqdm 5 | from vocos import Vocos 6 | 7 | vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") 8 | source = Path("dataset/LibriTTS/test-other") 9 | target = Path("results/LibriTTS/test-other/vocos-official") 10 | 11 | for i in tqdm(list(source.rglob("*.wav"))): 12 | y, sr = torchaudio.load(i) 13 | if y.size(0) > 1: # mix to mono 14 | y = y.mean(dim=0, keepdim=True) 15 | y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=24000) 16 | y_hat = vocos(y) 17 | 18 | target_file = target / i.relative_to(source) 19 | target_file.parent.mkdir(parents=True, exist_ok=True) 20 | torchaudio.save(target_file, 
y_hat, sample_rate=24000) 21 | --------------------------------------------------------------------------------
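
A minimal usage sketch for two of the modules above: RefineGANGenerator (fish_vocoder/modules/generators/refinegan.py) and MultiResolutionSTFTLoss (fish_vocoder/modules/losses/stft.py). The defaults sampling_rate=44100, hop_length=256 and num_mels=128 come from the generator's signature shown above; the batch size, the 32-frame input, the random template tensor and the STFT resolutions are illustrative assumptions rather than values taken from the repository, and in real training the template signal is produced by the data pipeline, not sampled from noise.

import torch

from fish_vocoder.modules.generators.refinegan import RefineGANGenerator
from fish_vocoder.modules.losses.stft import MultiResolutionSTFTLoss

# Illustrative sizes: 32 mel frames at hop_length=256 -> 8192 waveform samples.
batch, num_mels, hop_length, frames = 1, 128, 256, 32

generator = RefineGANGenerator(
    sampling_rate=44100,
    hop_length=hop_length,
    num_mels=num_mels,
)

mel = torch.randn(batch, num_mels, frames)             # [B, mel_bin, T_frames]
template = torch.randn(batch, 1, frames * hop_length)  # [B, 1, T_samples]; noise stands in for the real template

with torch.no_grad():
    audio = generator(mel, template)

print(audio.shape)  # torch.Size([1, 1, 8192])

# Multi-resolution STFT loss between the generated audio and a reference waveform.
# The (fft_size, hop_size, win_length) triples are arbitrary example resolutions.
mr_stft = MultiResolutionSTFTLoss(
    resolutions=[(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]
)
sc_loss, mag_loss = mr_stft(audio.squeeze(1), template.squeeze(1))
print(sc_loss.item(), mag_loss.item())

With the default downsample rates (2, 2, 8, 8) and upsample rates (8, 8, 2, 2), the template length must be an exact multiple of hop_length (their product) so that the skip connections concatenated in the upsampling path line up with the stored downsampled activations; frames * hop_length satisfies this by construction.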