├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── fish_speech ├── callbacks │ ├── __init__.py │ └── grad_norm.py ├── configs │ ├── base.yaml │ ├── lora │ │ └── r_8_alpha_16.yaml │ ├── model │ │ ├── dual_ar_2_codebook_large.yaml │ │ ├── dual_ar_2_codebook_medium.yaml │ │ ├── dual_ar_2_codebook_small.yaml │ │ └── naive_2_codebook_small.yaml │ ├── text2semantic_finetune.yaml │ ├── text2semantic_pretrain.yaml │ ├── text2semantic_sft.yaml │ ├── vits_decoder_finetune.yaml │ ├── vits_decoder_pretrain.yaml │ ├── vqgan_finetune.yaml │ └── vqgan_pretrain.yaml ├── datasets │ ├── concat_repeat.py │ ├── protos │ │ ├── text-data.proto │ │ ├── text_data_pb2.py │ │ └── text_data_stream.py │ ├── text.py │ ├── vits.py │ └── vqgan.py ├── i18n │ ├── __init__.py │ ├── core.py │ ├── locale │ │ ├── en_US.json │ │ ├── es_ES.json │ │ ├── ja_JP.json │ │ └── zh_CN.json │ └── scan.py ├── models │ ├── text2semantic │ │ ├── __init__.py │ │ ├── lit_module.py │ │ ├── llama.py │ │ └── lora_utils.py │ ├── vits_decoder │ │ ├── __init__.py │ │ ├── lit_module.py │ │ ├── losses.py │ │ └── modules │ │ │ ├── attentions.py │ │ │ ├── commons.py │ │ │ ├── models.py │ │ │ ├── modules.py │ │ │ ├── mrte.py │ │ │ └── vq_encoder.py │ └── vqgan │ │ ├── __init__.py │ │ ├── lit_module.py │ │ ├── modules │ │ ├── discriminator.py │ │ ├── firefly.py │ │ ├── fsq.py │ │ ├── reference.py │ │ └── wavenet.py │ │ └── utils.py ├── scheduler.py ├── text │ ├── __init__.py │ └── clean.py ├── train.py ├── utils │ ├── __init__.py │ ├── braceexpand.py │ ├── file.py │ ├── instantiators.py │ ├── logger.py │ ├── logging_utils.py │ ├── rich_utils.py │ ├── spectrogram.py │ └── utils.py └── webui │ ├── css │ └── style.css │ ├── html │ └── footer.html │ ├── js │ └── animate.js │ ├── launch_utils.py │ └── manage.py ├── nodes.py ├── requirements.txt ├── tools ├── api.py ├── extract_model.py ├── llama │ ├── build_dataset.py │ ├── generate.py │ ├── merge_lora.py │ ├── quantize.py │ └── rebuild_tokenizer.py ├── merge_asr_files.py ├── vqgan │ ├── create_train_split.py │ ├── extract_vq.py │ └── inference.py ├── webui.py └── whisper_asr.py └── web └── js ├── previewAudio.js ├── uploadAudio.js └── uploadSRT.js /.gitignore: -------------------------------------------------------------------------------- 1 | .pgx.* 2 | .pdm-python 3 | __pycache__ 4 | /results 5 | /data 6 | /*.test.sh 7 | *.filelist 8 | filelists 9 | /fish_speech/text/cmudict_cache.pickle 10 | /checkpoints 11 | /.vscode 12 | /data_server/target 13 | /*.npy 14 | /*.wav 15 | /*.mp3 16 | /results 17 | /data 18 | /.idea 19 | ffmpeg.exe 20 | ffprobe.exe 21 | asr-label-win-x64.exe 22 | /.cache 23 | /fishenv 24 | /.locale 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, AIFSH 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. 
Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-FishSpeech 2 | A custom ComfyUI node for [fish-speech](https://github.com/fishaudio/fish-speech.git). 3 | 4 | ## Disclaimer / 免责声明 5 | We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws. 6 | 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规. 7 | 8 | ## How to use 9 | Make sure `ffmpeg` works from your command line. 10 | For Linux: 11 | ``` 12 | apt update 13 | apt install ffmpeg 14 | ``` 15 | For Windows, you can install `ffmpeg` automatically with [WingetUI](https://github.com/marticliment/WingetUI). 16 | 17 | Then: 18 | ``` 19 | git clone https://github.com/AIFSH/ComfyUI-FishSpeech.git 20 | cd ComfyUI-FishSpeech 21 | pip install -r requirements.txt 22 | ``` 23 | Weights will be downloaded from Hugging Face automatically. If you are in China, make sure your network can reach Hugging Face, 24 | or, if you still struggle with Hugging Face, you can follow [hf-mirror](https://hf-mirror.com/) to configure your environment. 25 | 26 | [This repository is publicly accessible, but you have to accept the conditions to access its files and content.](https://huggingface.co/fishaudio/fish-speech-1) 27 | 28 | You may encounter 29 | ``` 30 | subprocess.CalledProcessError: Command '['cmake', '/tmp/pip-install-y_5f2bue/samplerate_578130ccaceb41abb26587e96f64988e', '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=/tmp/pip-install-y_5f2bue/samplerate_578130ccaceb41abb26587e96f64988e/build/lib.linux-x86_64-cpython-310/', '-DPYTHON_EXECUTABLE=/usr/local/miniconda3/bin/python', '-DCMAKE_BUILD_TYPE=Release', '-DPACKAGE_VERSION_INFO=0.2.1']' returned non-zero exit status 1.
31 | ``` 32 | when installing `samplerate`. 33 | 34 | If so, try: 35 | 36 | ``` 37 | pip -q install git+https://github.com/tuxu/python-samplerate.git@fix_cmake_dep 38 | ``` 39 | 40 | If you see 41 | ``` 42 | "cannot import name 'weight_norm' from 'torch.nn.utils.parametrizations' 43 | ``` 44 | please update your `torch`. 45 | 46 | ## Tutorial 47 | [Demo](https://www.bilibili.com/video/BV1Tx4y1B7zE/) 48 | 49 | ## Thanks 50 | - [fish-speech](https://github.com/fishaudio/fish-speech.git) 51 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os,site,sys 2 | now_dir = os.path.dirname(os.path.abspath(__file__)) 3 | site_packages_roots = [] 4 | for path in site.getsitepackages(): 5 | if "packages" in path: 6 | site_packages_roots.append(path) 7 | if(site_packages_roots==[]):site_packages_roots=["%s/runtime/Lib/site-packages" % now_dir] 8 | #os.environ["OPENBLAS_NUM_THREADS"] = "4" 9 | for site_packages_root in site_packages_roots: 10 | if os.path.exists(site_packages_root): 11 | try: 12 | with open("%s/fish_speech.pth" % (site_packages_root), "w") as f: 13 | f.write( 14 | "%s\n%s/fish_speech\n" 15 | % (now_dir,now_dir) 16 | ) 17 | break 18 | except PermissionError: 19 | raise PermissionError 20 | 21 | if os.path.isfile("%s/fish_speech.pth" % (site_packages_root)): 22 | print("!!!fish_speech path was added to " + "%s/fish_speech.pth" % (site_packages_root) 23 | + "\n if you meet a `No module` error, try `python main.py` again") 24 | 25 | 26 | WEB_DIRECTORY = "./web" 27 | from .nodes import LoadAudio,PreViewAudio,LoadSRT,FishSpeech_INFER,FishSpeech_INFER_SRT 28 | 29 | NODE_CLASS_MAPPINGS = { 30 | "LoadAudio": LoadAudio, 31 | "PreViewAudio": PreViewAudio, 32 | "LoadSRT": LoadSRT, 33 | "FishSpeech_INFER": FishSpeech_INFER, 34 | "FishSpeech_INFER_SRT": FishSpeech_INFER_SRT 35 | } 36 | 37 | # A dictionary that contains the friendly/humanly readable titles for the nodes 38 | NODE_DISPLAY_NAME_MAPPINGS = { 39 | "LoadAudio": "AudioLoader", 40 | "PreViewAudio": "PreView Audio", 41 | "LoadSRT": "SRT FILE Loader", 42 | "FishSpeech_INFER": "FishSpeech Inference", 43 | "FishSpeech_INFER_SRT": "FishSpeech Voice Clone" 44 | } -------------------------------------------------------------------------------- /fish_speech/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .grad_norm import GradNormMonitor 2 | 3 | __all__ = ["GradNormMonitor"] 4 | -------------------------------------------------------------------------------- /fish_speech/callbacks/grad_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import lightning.pytorch as pl 4 | import torch 5 | from lightning import LightningModule, Trainer 6 | from lightning.pytorch.callbacks import Callback 7 | from torch import Tensor, nn 8 | from torch.utils._foreach_utils import ( 9 | _group_tensors_by_device_and_dtype, 10 | _has_foreach_support, 11 | ) 12 | 13 | 14 | @torch.no_grad() 15 | def grad_norm( 16 | parameters: Union[Tensor, list[Tensor]], 17 | norm_type: float = 2.0, 18 | ) -> float: 19 | """ 20 | Returns the norm of the gradients of the given parameters. 21 | 22 | Args: 23 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 24 | single Tensor that will have gradients normalized 25 | norm_type (float): type of the used p-norm.
26 | 27 | Returns: 28 | Total norm of the parameter gradients (viewed as a single vector). 29 | """ # noqa: E501 30 | 31 | if isinstance(parameters, Tensor): 32 | parameters = [parameters] 33 | 34 | grads = [p.grad for p in parameters if p.grad is not None] 35 | if len(grads) == 0: 36 | return None 37 | 38 | first_device = grads[0].device 39 | grouped_grads: dict[ 40 | tuple[torch.device, torch.dtype], list[list[Tensor]] 41 | ] = _group_tensors_by_device_and_dtype( 42 | [[g.detach() for g in grads]] 43 | ) # type: ignore[assignment] 44 | 45 | norms = [] 46 | for (device, _), ([grads], _) in grouped_grads.items(): 47 | if _has_foreach_support(grads, device=device): 48 | norms.extend(torch._foreach_norm(grads, norm_type)) 49 | else: 50 | norms.extend([torch.norm(g, norm_type) for g in grads]) 51 | 52 | return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) 53 | 54 | 55 | class GradNormMonitor(Callback): 56 | """ 57 | Callback that computes the gradient norm of the model parameters. 58 | """ 59 | 60 | def __init__( 61 | self, 62 | norm_type: float = 2.0, 63 | logging_interval: str = "step", 64 | sub_module: Optional[Union[str, list[str]]] = None, 65 | ) -> None: 66 | """ 67 | Args: 68 | norm_type (float): type of the used p-norm. 69 | logging_interval (str): "step" or "epoch". 70 | """ 71 | super().__init__() 72 | 73 | self.norm_type = norm_type 74 | self.logging_interval = logging_interval 75 | self.sub_module = sub_module 76 | 77 | def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None: 78 | """ 79 | Computes the gradient norm of the model parameters and logs it to the logger. 80 | 81 | Args: 82 | trainer (Trainer): The trainer object 83 | model (LightningModule): The current lightningModule 84 | """ 85 | 86 | lightning_model = model 87 | 88 | if self.sub_module is None: 89 | return self.log_sub_module_grad_norm(lightning_model, model, "") 90 | 91 | sub_modules = self.sub_module 92 | if isinstance(sub_modules, str): 93 | sub_modules = [sub_modules] 94 | 95 | for sub_module in sub_modules: 96 | self.log_sub_module_grad_norm( 97 | lightning_model, getattr(model, sub_module), f"/{sub_module}" 98 | ) 99 | 100 | def log_sub_module_grad_norm( 101 | self, lightning_model: LightningModule, model: nn.Module, path: str 102 | ) -> None: 103 | grad_norm_val = grad_norm(model.parameters(), self.norm_type) 104 | if grad_norm_val is None: 105 | return 106 | 107 | on_step = self.logging_interval == "step" 108 | lightning_model.log( 109 | f"train{path}/grad_norm", 110 | grad_norm_val, 111 | on_step=on_step, 112 | on_epoch=not on_step, 113 | ) 114 | -------------------------------------------------------------------------------- /fish_speech/configs/base.yaml: -------------------------------------------------------------------------------- 1 | # Base configuration for training a model 2 | paths: 3 | run_dir: results/${project} 4 | ckpt_dir: ${paths.run_dir}/checkpoints 5 | 6 | hydra: 7 | run: 8 | dir: ${paths.run_dir} 9 | 10 | # Lightning Trainer 11 | trainer: 12 | _target_: lightning.pytorch.trainer.Trainer 13 | 14 | default_root_dir: ${paths.run_dir} 15 | accelerator: gpu 16 | num_nodes: 1 17 | devices: auto 18 | strategy: 19 | _target_: lightning.pytorch.strategies.DDPStrategy 20 | process_group_backend: nccl # This should be override when training on windows 21 | 22 | precision: bf16-mixed 23 | 24 | # disable validation by epoch end 25 | check_val_every_n_epoch: null 26 | val_check_interval: 5000 27 | max_steps: 100_000 28 | 29 | # Use 
torch.backends.cudnn.benchmark to speed up training 30 | benchmark: true 31 | 32 | # Callbacks 33 | callbacks: 34 | model_checkpoint: 35 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 36 | dirpath: ${paths.ckpt_dir} 37 | filename: "step_{step:09d}" 38 | save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt 39 | save_top_k: 5 # save 5 latest checkpoints 40 | monitor: step # use step to monitor checkpoints 41 | mode: max # save the latest checkpoint with the highest global_step 42 | every_n_epochs: null # don't save checkpoints by epoch end 43 | every_n_train_steps: 5000 # save checkpoints every 5000 steps 44 | auto_insert_metric_name: false 45 | 46 | model_summary: 47 | _target_: lightning.pytorch.callbacks.ModelSummary 48 | max_depth: 2 # the maximum depth of layer nesting that the summary will include 49 | 50 | learning_rate_monitor: 51 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 52 | logging_interval: step 53 | log_momentum: false 54 | 55 | grad_norm_monitor: 56 | _target_: fish_speech.callbacks.GradNormMonitor 57 | norm_type: 2 58 | logging_interval: step 59 | 60 | # Logger 61 | logger: 62 | tensorboard: 63 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 64 | save_dir: "${paths.run_dir}/tensorboard/" 65 | name: null 66 | log_graph: false 67 | default_hp_metric: true 68 | prefix: "" 69 | 70 | # wandb: 71 | # _target_: lightning.pytorch.loggers.wandb.WandbLogger 72 | # # name: "" # name of the run (normally generated by wandb) 73 | # save_dir: "${paths.run_dir}" 74 | # offline: False 75 | # id: null # pass correct id to resume experiment! 76 | # anonymous: null # enable anonymous logging 77 | # project: "fish-speech" 78 | # log_model: False # upload lightning ckpts 79 | # prefix: "" # a string to put at the beginning of metric keys 80 | # # entity: "" # set to name of your wandb team 81 | # group: "" 82 | # tags: ["vq", "hq", "finetune"] 83 | # job_type: "" 84 | 85 | # Loop 86 | train: true 87 | test: false 88 | -------------------------------------------------------------------------------- /fish_speech/configs/lora/r_8_alpha_16.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_speech.models.text2semantic.lora_utils.LoraConfig 2 | r: 8 3 | lora_alpha: 16 4 | -------------------------------------------------------------------------------- /fish_speech/configs/model/dual_ar_2_codebook_large.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dual_ar_2_codebook_small 3 | - _self_ 4 | 5 | config: 6 | n_layer: 30 7 | n_fast_layer: 6 8 | n_head: 24 9 | dim: 1536 10 | -------------------------------------------------------------------------------- /fish_speech/configs/model/dual_ar_2_codebook_medium.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dual_ar_2_codebook_small 3 | - _self_ 4 | 5 | config: 6 | n_layer: 24 7 | n_fast_layer: 6 8 | n_head: 16 9 | dim: 1024 10 | -------------------------------------------------------------------------------- /fish_speech/configs/model/dual_ar_2_codebook_small.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_speech.models.text2semantic.llama.DualARTransformer 2 | config: 3 | _target_: fish_speech.models.text2semantic.llama.DualARModelArgs 4 | max_seq_len: ${max_length} 5 | vocab_size: 264 # pad 262 to 8x 6 | n_layer: 12 7 | n_fast_layer: 4 8 | 
n_head: 12 9 | dim: 768 10 | rope_base: 10000 11 | norm_eps: 1e-5 12 | num_codebooks: 2 # input/output codebook size 13 | codebook_size: 1032 # codebook size 1024 + 2 special tokens 14 | -------------------------------------------------------------------------------- /fish_speech/configs/model/naive_2_codebook_small.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_speech.models.text2semantic.llama.NaiveTransformer 2 | config: 3 | _target_: fish_speech.models.text2semantic.llama.NaiveModelArgs 4 | max_seq_len: ${max_length} 5 | vocab_size: 36408 6 | n_layer: 12 7 | n_head: 12 8 | dim: 768 9 | rope_base: 10000 10 | norm_eps: 1e-5 11 | num_codebooks: 2 # input/output codebook size 12 | codebook_size: 1032 # codebook size 1024 + 2 special tokens 13 | -------------------------------------------------------------------------------- /fish_speech/configs/text2semantic_finetune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - model@model.model: dual_ar_2_codebook_small 4 | - _self_ 5 | 6 | project: text2semantic_finetune_dual_ar 7 | max_length: 2048 8 | ckpt_path: checkpoints/text2semantic-sft-medium-v1.1-4k.pth 9 | resume_weights_only: true 10 | 11 | # Lightning Trainer 12 | trainer: 13 | accumulate_grad_batches: 1 14 | gradient_clip_val: 1.0 15 | gradient_clip_algorithm: "norm" 16 | max_steps: 1000 17 | precision: bf16-true 18 | limit_val_batches: 10 19 | val_check_interval: 100 20 | 21 | # Dataset Configuration 22 | tokenizer: 23 | _target_: transformers.AutoTokenizer.from_pretrained 24 | pretrained_model_name_or_path: fishaudio/fish-speech-1 25 | 26 | # Dataset Configuration 27 | train_dataset: 28 | _target_: fish_speech.datasets.text.AutoAugTextDataset 29 | proto_files: 30 | - data/protos 31 | tokenizer: ${tokenizer} 32 | max_length: ${max_length} 33 | num_codebooks: ${model.model.config.num_codebooks} 34 | use_speaker: 0.5 35 | interactive_prob: 0.7 36 | 37 | val_dataset: 38 | _target_: fish_speech.datasets.text.AutoAugTextDataset 39 | proto_files: 40 | - data/protos 41 | tokenizer: ${tokenizer} 42 | max_length: ${max_length} 43 | num_codebooks: ${model.model.config.num_codebooks} 44 | use_speaker: 0.5 45 | interactive_prob: 0.7 46 | 47 | data: 48 | _target_: fish_speech.datasets.text.TextDataModule 49 | train_dataset: ${train_dataset} 50 | val_dataset: ${val_dataset} 51 | num_workers: 4 52 | batch_size: 8 53 | tokenizer: ${tokenizer} 54 | max_length: ${max_length} 55 | 56 | # Model Configuration 57 | model: 58 | _target_: fish_speech.models.text2semantic.TextToSemantic 59 | model: {} 60 | 61 | optimizer: 62 | _target_: torch.optim.AdamW 63 | _partial_: true 64 | lr: 1e-5 65 | weight_decay: 0 66 | betas: [0.9, 0.95] 67 | eps: 1e-5 68 | 69 | lr_scheduler: 70 | _target_: torch.optim.lr_scheduler.LambdaLR 71 | _partial_: true 72 | lr_lambda: 73 | _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda 74 | _partial_: true 75 | num_warmup_steps: 0.1 76 | num_training_steps: ${trainer.max_steps} 77 | 78 | # Callbacks 79 | callbacks: 80 | model_checkpoint: 81 | every_n_train_steps: 100 82 | -------------------------------------------------------------------------------- /fish_speech/configs/text2semantic_pretrain.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - model@model.model: dual_ar_2_codebook_small 4 | - _self_ 5 | 6 | project: text2semantic_pretrain_dual_ar_debug 7 | max_length: 
2048 8 | 9 | # Lightning Trainer 10 | trainer: 11 | accumulate_grad_batches: 1 12 | gradient_clip_val: 1.0 13 | gradient_clip_algorithm: 'norm' 14 | max_steps: 1_000_000 15 | precision: bf16-true 16 | limit_val_batches: 10 17 | 18 | # Dataset Configuration 19 | tokenizer: 20 | _target_: transformers.AutoTokenizer.from_pretrained 21 | pretrained_model_name_or_path: fishaudio/fish-speech-1 22 | 23 | # Dataset Configuration 24 | train_dataset: 25 | _target_: fish_speech.datasets.text.AutoAugTextDataset 26 | proto_files: 27 | - data/protos/train 28 | tokenizer: ${tokenizer} 29 | max_length: ${max_length} 30 | num_codebooks: ${model.model.config.num_codebooks} 31 | use_speaker: false 32 | interactive_prob: 0.5 33 | 34 | val_dataset: 35 | _target_: fish_speech.datasets.text.AutoAugTextDataset 36 | proto_files: 37 | - data/protos/test 38 | tokenizer: ${tokenizer} 39 | max_length: ${max_length} 40 | num_codebooks: ${model.model.config.num_codebooks} 41 | use_speaker: false 42 | interactive_prob: 0.5 43 | 44 | data: 45 | _target_: fish_speech.datasets.text.TextDataModule 46 | train_dataset: ${train_dataset} 47 | val_dataset: ${val_dataset} 48 | num_workers: 4 49 | batch_size: 8 50 | tokenizer: ${tokenizer} 51 | max_length: ${max_length} 52 | 53 | # Model Configuration 54 | model: 55 | _target_: fish_speech.models.text2semantic.TextToSemantic 56 | model: {} 57 | 58 | optimizer: 59 | _target_: torch.optim.AdamW 60 | _partial_: true 61 | lr: 3e-4 62 | weight_decay: 0.01 63 | betas: [0.9, 0.95] 64 | eps: 1e-5 65 | 66 | lr_scheduler: 67 | _target_: torch.optim.lr_scheduler.LambdaLR 68 | _partial_: true 69 | lr_lambda: 70 | _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda 71 | _partial_: true 72 | num_warmup_steps: 2000 73 | num_training_steps: ${trainer.max_steps} 74 | final_lr_ratio: 0.1 75 | -------------------------------------------------------------------------------- /fish_speech/configs/text2semantic_sft.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - model@model.model: dual_ar_2_codebook_small 4 | - _self_ 5 | 6 | project: text2semantic_sft_dual_ar 7 | max_length: 4096 8 | ckpt_path: checkpoints/text2semantic-medium-v1-2k.pth 9 | resume_weights_only: true 10 | 11 | # Lightning Trainer 12 | trainer: 13 | accumulate_grad_batches: 1 14 | gradient_clip_val: 1.0 15 | gradient_clip_algorithm: 'norm' 16 | max_steps: 10_000 17 | precision: bf16-true 18 | limit_val_batches: 10 19 | val_check_interval: 500 20 | 21 | # Dataset Configuration 22 | tokenizer: 23 | _target_: transformers.AutoTokenizer.from_pretrained 24 | pretrained_model_name_or_path: fishaudio/fish-speech-1 25 | 26 | # Dataset Configuration 27 | train_dataset: 28 | _target_: fish_speech.datasets.text.AutoAugTextDataset 29 | proto_files: 30 | - data/protos/sft 31 | tokenizer: ${tokenizer} 32 | max_length: ${max_length} 33 | num_codebooks: ${model.model.config.num_codebooks} 34 | use_speaker: 0.5 35 | interactive_prob: 0.7 36 | 37 | val_dataset: 38 | _target_: fish_speech.datasets.text.AutoAugTextDataset 39 | proto_files: 40 | - data/protos/sft 41 | tokenizer: ${tokenizer} 42 | max_length: ${max_length} 43 | num_codebooks: ${model.model.config.num_codebooks} 44 | use_speaker: 0.5 45 | interactive_prob: 0.7 46 | 47 | data: 48 | _target_: fish_speech.datasets.text.TextDataModule 49 | train_dataset: ${train_dataset} 50 | val_dataset: ${val_dataset} 51 | num_workers: 4 52 | batch_size: 8 53 | tokenizer: ${tokenizer} 54 | max_length: ${max_length} 55 | 56 
| # Model Configuration 57 | model: 58 | _target_: fish_speech.models.text2semantic.TextToSemantic 59 | model: {} 60 | 61 | optimizer: 62 | _target_: torch.optim.AdamW 63 | _partial_: true 64 | lr: 4e-5 65 | weight_decay: 0 66 | betas: [0.9, 0.95] 67 | eps: 1e-5 68 | 69 | lr_scheduler: 70 | _target_: torch.optim.lr_scheduler.LambdaLR 71 | _partial_: true 72 | lr_lambda: 73 | _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda 74 | _partial_: true 75 | num_warmup_steps: 100 76 | num_training_steps: ${trainer.max_steps} 77 | final_lr_ratio: 0 78 | 79 | callbacks: 80 | model_checkpoint: 81 | every_n_train_steps: 1000 82 | save_top_k: 10 83 | -------------------------------------------------------------------------------- /fish_speech/configs/vits_decoder_finetune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - _self_ 4 | 5 | project: vits_decoder 6 | ckpt_path: checkpoints/vits_decoder_v1.1.ckpt 7 | resume_weights_only: true 8 | 9 | # Lightning Trainer 10 | trainer: 11 | accelerator: gpu 12 | devices: auto 13 | strategy: 14 | find_unused_parameters: true 15 | precision: 32 16 | max_steps: 100_000 17 | val_check_interval: 100 18 | benchmark: false 19 | 20 | sample_rate: 44100 21 | hop_length: 512 22 | num_mels: 128 23 | n_fft: 2048 24 | win_length: 2048 25 | 26 | # Dataset Configuration 27 | tokenizer: 28 | _target_: transformers.AutoTokenizer.from_pretrained 29 | pretrained_model_name_or_path: fishaudio/fish-speech-1 30 | 31 | # Dataset Configuration 32 | train_dataset: 33 | _target_: fish_speech.datasets.vits.VITSDataset 34 | filelist: data/source/Genshin/filelist.train.txt 35 | sample_rate: ${sample_rate} 36 | hop_length: ${hop_length} 37 | suffix: ".lab" 38 | tokenizer: ${tokenizer} 39 | sentence_mask_ratio: 0.2 40 | 41 | val_dataset: 42 | _target_: fish_speech.datasets.vits.VITSDataset 43 | filelist: data/source/Genshin/filelist.test.txt 44 | sample_rate: ${sample_rate} 45 | hop_length: ${hop_length} 46 | suffix: ".lab" 47 | tokenizer: ${tokenizer} 48 | 49 | data: 50 | _target_: fish_speech.datasets.vits.VITSDataModule 51 | train_dataset: ${train_dataset} 52 | val_dataset: ${val_dataset} 53 | num_workers: 4 54 | batch_size: 8 55 | val_batch_size: 4 56 | tokenizer: ${tokenizer} 57 | 58 | # Model Configuration 59 | model: 60 | _target_: fish_speech.models.vits_decoder.VITSDecoder 61 | sample_rate: ${sample_rate} 62 | hop_length: ${hop_length} 63 | freeze_discriminator: false 64 | 65 | weight_mel: 45.0 66 | weight_kl: 1.0 67 | 68 | generator: 69 | _target_: fish_speech.models.vits_decoder.modules.models.SynthesizerTrn 70 | spec_channels: 1025 71 | segment_size: 32 72 | inter_channels: 192 73 | hidden_channels: 192 74 | filter_channels: 768 75 | n_heads: 2 76 | n_layers: 6 77 | kernel_size: 3 78 | p_dropout: 0.1 79 | resblock: "1" 80 | resblock_kernel_sizes: [3, 7, 11] 81 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 82 | upsample_rates: [8, 8, 2, 2, 2] 83 | upsample_initial_channel: 512 84 | upsample_kernel_sizes: [16, 16, 8, 2, 2] 85 | gin_channels: 512 86 | vq_mask_ratio: 0.2 87 | ref_mask_ratio: 0.2 88 | 89 | discriminator: 90 | _target_: fish_speech.models.vits_decoder.modules.models.EnsembledDiscriminator 91 | periods: [2, 3, 5, 7, 11] 92 | 93 | mel_transform: 94 | _target_: fish_speech.utils.spectrogram.LogMelSpectrogram 95 | sample_rate: ${sample_rate} 96 | n_fft: ${n_fft} 97 | hop_length: ${hop_length} 98 | win_length: ${win_length} 99 | n_mels: ${num_mels} 100 | 101 | 
spec_transform: 102 | _target_: fish_speech.utils.spectrogram.LinearSpectrogram 103 | n_fft: ${n_fft} 104 | hop_length: ${hop_length} 105 | win_length: ${win_length} 106 | mode: pow2_sqrt 107 | 108 | optimizer: 109 | _target_: torch.optim.AdamW 110 | _partial_: true 111 | lr: 1e-4 112 | betas: [0.8, 0.99] 113 | eps: 1e-6 114 | 115 | lr_scheduler: 116 | _target_: torch.optim.lr_scheduler.ExponentialLR 117 | _partial_: true 118 | gamma: 0.999999 119 | 120 | callbacks: 121 | grad_norm_monitor: 122 | sub_module: 123 | - generator 124 | - discriminator 125 | 126 | model_checkpoint: 127 | every_n_train_steps: ${trainer.val_check_interval} 128 | save_top_k: 10 129 | -------------------------------------------------------------------------------- /fish_speech/configs/vits_decoder_pretrain.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - _self_ 4 | 5 | project: vits_decoder 6 | ckpt_path: checkpoints/Bert-VITS2/ensemble.pth 7 | resume_weights_only: true 8 | 9 | # Lightning Trainer 10 | trainer: 11 | accelerator: gpu 12 | devices: auto 13 | strategy: ddp_find_unused_parameters_true 14 | precision: 32 15 | max_steps: 1_000_000 16 | val_check_interval: 1000 17 | benchmark: false 18 | 19 | sample_rate: 44100 20 | hop_length: 512 21 | num_mels: 128 22 | n_fft: 2048 23 | win_length: 2048 24 | 25 | # Dataset Configuration 26 | tokenizer: 27 | _target_: transformers.AutoTokenizer.from_pretrained 28 | pretrained_model_name_or_path: fishaudio/fish-speech-1 29 | 30 | # Dataset Configuration 31 | train_dataset: 32 | _target_: fish_speech.datasets.vits.VITSDataset 33 | filelist: data/source/Genshin/filelist.train.txt 34 | sample_rate: ${sample_rate} 35 | hop_length: ${hop_length} 36 | suffix: ".lab" 37 | tokenizer: ${tokenizer} 38 | sentence_mask_ratio: 0.2 39 | 40 | val_dataset: 41 | _target_: fish_speech.datasets.vits.VITSDataset 42 | filelist: data/source/Genshin/filelist.test.txt 43 | sample_rate: ${sample_rate} 44 | hop_length: ${hop_length} 45 | suffix: ".lab" 46 | tokenizer: ${tokenizer} 47 | 48 | data: 49 | _target_: fish_speech.datasets.vits.VITSDataModule 50 | train_dataset: ${train_dataset} 51 | val_dataset: ${val_dataset} 52 | num_workers: 4 53 | batch_size: 8 54 | val_batch_size: 4 55 | tokenizer: ${tokenizer} 56 | 57 | # Model Configuration 58 | model: 59 | _target_: fish_speech.models.vits_decoder.VITSDecoder 60 | sample_rate: ${sample_rate} 61 | hop_length: ${hop_length} 62 | freeze_discriminator: false 63 | 64 | weight_mel: 45.0 65 | weight_kl: 1.0 66 | 67 | generator: 68 | _target_: fish_speech.models.vits_decoder.modules.models.SynthesizerTrn 69 | spec_channels: 1025 70 | segment_size: 32 71 | inter_channels: 192 72 | hidden_channels: 192 73 | filter_channels: 768 74 | n_heads: 2 75 | n_layers: 6 76 | kernel_size: 3 77 | p_dropout: 0.1 78 | resblock: "1" 79 | resblock_kernel_sizes: [3, 7, 11] 80 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 81 | upsample_rates: [8, 8, 2, 2, 2] 82 | upsample_initial_channel: 512 83 | upsample_kernel_sizes: [16, 16, 8, 2, 2] 84 | gin_channels: 512 85 | vq_mask_ratio: 0.2 86 | ref_mask_ratio: 0.2 87 | 88 | discriminator: 89 | _target_: fish_speech.models.vits_decoder.modules.models.EnsembledDiscriminator 90 | periods: [2, 3, 5, 7, 11] 91 | 92 | mel_transform: 93 | _target_: fish_speech.utils.spectrogram.LogMelSpectrogram 94 | sample_rate: ${sample_rate} 95 | n_fft: ${n_fft} 96 | hop_length: ${hop_length} 97 | win_length: ${win_length} 98 | n_mels: ${num_mels} 99 | 100 | 
spec_transform: 101 | _target_: fish_speech.utils.spectrogram.LinearSpectrogram 102 | n_fft: ${n_fft} 103 | hop_length: ${hop_length} 104 | win_length: ${win_length} 105 | mode: pow2_sqrt 106 | 107 | optimizer: 108 | _target_: torch.optim.AdamW 109 | _partial_: true 110 | lr: 1e-4 111 | betas: [0.8, 0.99] 112 | eps: 1e-6 113 | 114 | lr_scheduler: 115 | _target_: torch.optim.lr_scheduler.ExponentialLR 116 | _partial_: true 117 | gamma: 0.999999 118 | 119 | callbacks: 120 | grad_norm_monitor: 121 | sub_module: 122 | - generator 123 | - discriminator 124 | 125 | model_checkpoint: 126 | every_n_train_steps: 1000 127 | save_top_k: 10 128 | -------------------------------------------------------------------------------- /fish_speech/configs/vqgan_finetune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - _self_ 4 | 5 | project: vq-gan-finetune 6 | ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth 7 | resume_weights_only: true 8 | 9 | # Lightning Trainer 10 | trainer: 11 | accelerator: gpu 12 | devices: auto 13 | precision: bf16-mixed 14 | max_steps: 100_000 15 | val_check_interval: 5000 16 | strategy: 17 | find_unused_parameters: true 18 | 19 | sample_rate: 44100 20 | hop_length: 512 21 | num_mels: 128 22 | n_fft: 2048 23 | win_length: 2048 24 | 25 | # Dataset Configuration 26 | train_dataset: 27 | _target_: fish_speech.datasets.vqgan.VQGANDataset 28 | filelist: data/vq_train_filelist.txt 29 | sample_rate: ${sample_rate} 30 | hop_length: ${hop_length} 31 | slice_frames: 512 32 | 33 | val_dataset: 34 | _target_: fish_speech.datasets.vqgan.VQGANDataset 35 | filelist: data/vq_val_filelist.txt 36 | sample_rate: ${sample_rate} 37 | hop_length: ${hop_length} 38 | 39 | data: 40 | _target_: fish_speech.datasets.vqgan.VQGANDataModule 41 | train_dataset: ${train_dataset} 42 | val_dataset: ${val_dataset} 43 | num_workers: 4 44 | batch_size: 16 45 | val_batch_size: 16 46 | 47 | # Model Configuration 48 | model: 49 | _target_: fish_speech.models.vqgan.VQGAN 50 | 51 | sampling_rate: ${sample_rate} 52 | weight_adv: 0.2 53 | weight_vq: 1.0 54 | weight_mel: 1.0 55 | 56 | # Important: Set the freeze_encoder to true to only train the decoder 57 | freeze_encoder: true 58 | 59 | encoder: 60 | _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet 61 | input_channels: ${num_mels} 62 | residual_channels: 768 63 | residual_layers: 20 64 | dilation_cycle: 4 65 | 66 | quantizer: 67 | _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize 68 | input_dim: 768 69 | n_codebooks: 1 70 | n_groups: 2 71 | levels: [8, 5, 5, 5] 72 | 73 | decoder: 74 | _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet 75 | output_channels: ${num_mels} 76 | residual_channels: 768 77 | residual_layers: 20 78 | dilation_cycle: 4 79 | condition_channels: 768 80 | 81 | discriminator: 82 | _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator 83 | 84 | vocoder: 85 | _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase 86 | ckpt_path: null # You may download the pretrained vocoder and set the path here 87 | 88 | encode_mel_transform: 89 | _target_: fish_speech.utils.spectrogram.LogMelSpectrogram 90 | sample_rate: ${sample_rate} 91 | n_fft: ${n_fft} 92 | hop_length: ${hop_length} 93 | win_length: ${win_length} 94 | n_mels: ${num_mels} 95 | f_min: 0.0 96 | f_max: 8000.0 97 | 98 | gt_mel_transform: 99 | _target_: fish_speech.utils.spectrogram.LogMelSpectrogram 100 | sample_rate: ${sample_rate} 101 | n_fft: ${n_fft} 102 | 
hop_length: ${hop_length} 103 | win_length: ${win_length} 104 | n_mels: ${num_mels} 105 | 106 | optimizer: 107 | _target_: torch.optim.AdamW 108 | _partial_: true 109 | lr: 4e-5 110 | betas: [0.8, 0.99] 111 | eps: 1e-5 112 | weight_decay: 0.01 113 | 114 | lr_scheduler: 115 | _target_: torch.optim.lr_scheduler.LambdaLR 116 | _partial_: true 117 | lr_lambda: 118 | _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda 119 | _partial_: true 120 | num_warmup_steps: 0 121 | num_training_steps: ${trainer.max_steps} 122 | final_lr_ratio: 0 123 | 124 | callbacks: 125 | model_summary: 126 | _target_: lightning.pytorch.callbacks.ModelSummary 127 | max_depth: 1 128 | 129 | model_checkpoint: 130 | every_n_train_steps: ${trainer.val_check_interval} 131 | 132 | grad_norm_monitor: 133 | sub_module: 134 | - encoder 135 | - decoder 136 | - quantizer 137 | - discriminator 138 | -------------------------------------------------------------------------------- /fish_speech/configs/vqgan_pretrain.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - _self_ 4 | 5 | project: vq-gan-pretrain 6 | 7 | # Lightning Trainer 8 | trainer: 9 | accelerator: gpu 10 | devices: auto 11 | precision: bf16-mixed 12 | max_steps: 1_000_000 13 | val_check_interval: 5000 14 | strategy: 15 | find_unused_parameters: true 16 | 17 | sample_rate: 44100 18 | hop_length: 512 19 | num_mels: 128 20 | n_fft: 2048 21 | win_length: 2048 22 | 23 | # Dataset Configuration 24 | train_dataset: 25 | _target_: torch.utils.data.ConcatDataset 26 | datasets: 27 | - _target_: fish_speech.datasets.vqgan.VQGANDataset 28 | filelist: data/gigaspeech/vq_train_filelist.txt 29 | sample_rate: ${sample_rate} 30 | hop_length: ${hop_length} 31 | slice_frames: 512 32 | - _target_: fish_speech.datasets.vqgan.VQGANDataset 33 | filelist: data/sft/vq_train_filelist.txt 34 | sample_rate: ${sample_rate} 35 | hop_length: ${hop_length} 36 | slice_frames: 512 37 | 38 | val_dataset: 39 | _target_: fish_speech.datasets.vqgan.VQGANDataset 40 | filelist: data/sft/vq_val_filelist.txt 41 | sample_rate: ${sample_rate} 42 | hop_length: ${hop_length} 43 | 44 | data: 45 | _target_: fish_speech.datasets.vqgan.VQGANDataModule 46 | train_dataset: ${train_dataset} 47 | val_dataset: ${val_dataset} 48 | num_workers: 4 49 | batch_size: 32 50 | val_batch_size: 32 51 | 52 | # Model Configuration 53 | model: 54 | _target_: fish_speech.models.vqgan.VQGAN 55 | 56 | sampling_rate: ${sample_rate} 57 | weight_adv: 0.2 58 | weight_vq: 1.0 59 | weight_mel: 1.0 60 | freeze_encoder: false 61 | 62 | encoder: 63 | _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet 64 | input_channels: ${num_mels} 65 | residual_channels: 768 66 | residual_layers: 20 67 | dilation_cycle: 4 68 | 69 | quantizer: 70 | _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize 71 | input_dim: 768 72 | n_codebooks: 1 73 | n_groups: 2 74 | levels: [8, 5, 5, 5] 75 | 76 | decoder: 77 | _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet 78 | output_channels: ${num_mels} 79 | residual_channels: 768 80 | residual_layers: 20 81 | dilation_cycle: 4 82 | condition_channels: 768 83 | 84 | discriminator: 85 | _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator 86 | 87 | vocoder: 88 | _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase 89 | ckpt_path: null # You may download the pretrained vocoder and set the path here 90 | 91 | encode_mel_transform: 92 | _target_: 
fish_speech.utils.spectrogram.LogMelSpectrogram 93 | sample_rate: ${sample_rate} 94 | n_fft: ${n_fft} 95 | hop_length: ${hop_length} 96 | win_length: ${win_length} 97 | n_mels: ${num_mels} 98 | f_min: 0.0 99 | f_max: 8000.0 100 | 101 | gt_mel_transform: 102 | _target_: fish_speech.utils.spectrogram.LogMelSpectrogram 103 | sample_rate: ${sample_rate} 104 | n_fft: ${n_fft} 105 | hop_length: ${hop_length} 106 | win_length: ${win_length} 107 | n_mels: ${num_mels} 108 | 109 | optimizer: 110 | _target_: torch.optim.AdamW 111 | _partial_: true 112 | lr: 1e-4 113 | betas: [0.8, 0.99] 114 | eps: 1e-5 115 | weight_decay: 0.01 116 | 117 | lr_scheduler: 118 | _target_: torch.optim.lr_scheduler.LambdaLR 119 | _partial_: true 120 | lr_lambda: 121 | _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda 122 | _partial_: true 123 | num_warmup_steps: 100 124 | num_training_steps: ${trainer.max_steps} 125 | final_lr_ratio: 0 126 | 127 | callbacks: 128 | model_summary: 129 | _target_: lightning.pytorch.callbacks.ModelSummary 130 | max_depth: 1 131 | 132 | model_checkpoint: 133 | every_n_train_steps: ${trainer.val_check_interval} 134 | 135 | grad_norm_monitor: 136 | sub_module: 137 | - encoder 138 | - decoder 139 | - quantizer 140 | - discriminator 141 | -------------------------------------------------------------------------------- /fish_speech/datasets/concat_repeat.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | from typing import Iterable 3 | 4 | from torch.utils.data import Dataset, IterableDataset 5 | 6 | 7 | class ConcatRepeatDataset(Dataset): 8 | datasets: list[Dataset] 9 | cumulative_sizes: list[int] 10 | repeats: list[int] 11 | 12 | @staticmethod 13 | def cumsum(sequence, repeats): 14 | r, s = [], 0 15 | for dataset, repeat in zip(sequence, repeats): 16 | l = len(dataset) * repeat 17 | r.append(l + s) 18 | s += l 19 | return r 20 | 21 | def __init__(self, datasets: Iterable[Dataset], repeats: list[int]): 22 | super().__init__() 23 | 24 | self.datasets = list(datasets) 25 | self.repeats = repeats 26 | 27 | assert len(self.datasets) > 0, "datasets should not be an empty iterable" 28 | assert len(self.datasets) == len( 29 | repeats 30 | ), "datasets and repeats should have the same length" 31 | 32 | for d in self.datasets: 33 | assert not isinstance( 34 | d, IterableDataset 35 | ), "ConcatDataset does not support IterableDataset" 36 | 37 | self.cumulative_sizes = self.cumsum(self.datasets, self.repeats) 38 | 39 | def __len__(self): 40 | return self.cumulative_sizes[-1] 41 | 42 | def __getitem__(self, idx): 43 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 44 | 45 | if dataset_idx == 0: 46 | sample_idx = idx 47 | else: 48 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 49 | 50 | dataset = self.datasets[dataset_idx] 51 | 52 | return dataset[sample_idx % len(dataset)] 53 | -------------------------------------------------------------------------------- /fish_speech/datasets/protos/text-data.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package text_data; 4 | 5 | message Semantics { 6 | repeated uint32 values = 1; 7 | } 8 | 9 | message Sentence { 10 | repeated string texts = 1; 11 | repeated Semantics semantics = 3; 12 | } 13 | 14 | message TextData { 15 | string source = 1; 16 | string name = 2; 17 | repeated Sentence sentences = 4; 18 | } 19 | 20 | message SampledData { 21 | string source = 1; 22 | string name = 2; 23 | 
repeated Sentence samples = 3; 24 | } 25 | -------------------------------------------------------------------------------- /fish_speech/datasets/protos/text_data_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: text-data.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 17 | b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3' 18 | ) 19 | 20 | _globals = globals() 21 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 22 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals) 23 | if _descriptor._USE_C_DESCRIPTORS == False: 24 | DESCRIPTOR._options = None 25 | _globals["_SEMANTICS"]._serialized_start = 30 26 | _globals["_SEMANTICS"]._serialized_end = 57 27 | _globals["_SENTENCE"]._serialized_start = 59 28 | _globals["_SENTENCE"]._serialized_end = 125 29 | _globals["_TEXTDATA"]._serialized_start = 127 30 | _globals["_TEXTDATA"]._serialized_end = 207 31 | _globals["_SAMPLEDDATA"]._serialized_start = 209 32 | _globals["_SAMPLEDDATA"]._serialized_end = 290 33 | # @@protoc_insertion_point(module_scope) 34 | -------------------------------------------------------------------------------- /fish_speech/datasets/protos/text_data_stream.py: -------------------------------------------------------------------------------- 1 | import struct 2 | 3 | from .text_data_pb2 import TextData 4 | 5 | 6 | def read_pb_stream(f): 7 | while True: 8 | buf = f.read(4) 9 | if len(buf) == 0: 10 | break 11 | size = struct.unpack("I", buf)[0] 12 | buf = f.read(size) 13 | text_data = TextData() 14 | text_data.ParseFromString(buf) 15 | yield text_data 16 | 17 | 18 | def write_pb_stream(f, text_data): 19 | buf = text_data.SerializeToString() 20 | f.write(struct.pack("I", len(buf))) 21 | f.write(buf) 22 | 23 | 24 | def pack_pb_stream(text_data): 25 | buf = text_data.SerializeToString() 26 | return struct.pack("I", len(buf)) + buf 27 | 28 | 29 | def split_pb_stream(f): 30 | while True: 31 | head = f.read(4) 32 | if len(head) == 0: 33 | break 34 | size = struct.unpack("I", head)[0] 35 | buf = f.read(size) 36 | yield head + buf 37 | -------------------------------------------------------------------------------- /fish_speech/datasets/vits.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import librosa 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | 
from lightning import LightningDataModule 11 | from torch.utils.data import DataLoader, Dataset 12 | from torch.utils.data.distributed import DistributedSampler 13 | from transformers import AutoTokenizer 14 | 15 | from fish_speech.utils import RankedLogger 16 | 17 | logger = RankedLogger(__name__, rank_zero_only=False) 18 | 19 | 20 | class VITSDataset(Dataset): 21 | def __init__( 22 | self, 23 | filelist: str, 24 | tokenizer: AutoTokenizer, 25 | sample_rate: int = 44100, 26 | hop_length: int = 512, 27 | min_duration: float = 1.5, 28 | max_duration: float = 30.0, 29 | suffix: str = ".lab", 30 | sentence_mask_ratio: float = 0.0, 31 | ): 32 | super().__init__() 33 | 34 | filelist = Path(filelist) 35 | root = filelist.parent 36 | 37 | self.files = [] 38 | for line in filelist.read_text(encoding="utf-8").splitlines(): 39 | path = root / line 40 | self.files.append(path) 41 | 42 | self.sample_rate = sample_rate 43 | self.hop_length = hop_length 44 | self.min_duration = min_duration 45 | self.max_duration = max_duration 46 | self.tokenizer = tokenizer 47 | self.suffix = suffix 48 | self.sentence_mask_ratio = sentence_mask_ratio 49 | 50 | def __len__(self): 51 | return len(self.files) 52 | 53 | def get_item(self, idx): 54 | audio_file = self.files[idx] 55 | text_file = audio_file.with_suffix(self.suffix) 56 | 57 | if text_file.exists() is False or audio_file.exists() is False: 58 | return None 59 | 60 | audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True) 61 | duration = len(audio) / self.sample_rate 62 | 63 | # Pad to minimum duration 64 | if duration < self.min_duration: 65 | pad_duration = self.min_duration - duration 66 | pad_samples = int(pad_duration * self.sample_rate) 67 | audio = np.pad(audio, (0, pad_samples)) 68 | 69 | # Truncate to maximum duration 70 | if duration > self.max_duration: 71 | random_start = random.randint( 72 | 0, len(audio) - int(self.max_duration * self.sample_rate) - 1 73 | ) 74 | audio = audio[ 75 | random_start : random_start + int(self.max_duration * self.sample_rate) 76 | ] 77 | 78 | max_value = np.abs(audio).max() 79 | if max_value > 1.0: 80 | audio = audio / max_value 81 | 82 | if random.random() < self.sentence_mask_ratio: 83 | text = "-" 84 | else: 85 | text = text_file.read_text(encoding="utf-8") 86 | 87 | input_ids = self.tokenizer(text, return_tensors="pt").input_ids.squeeze(0) 88 | 89 | return { 90 | "audio": torch.from_numpy(audio), 91 | "text": input_ids, 92 | } 93 | 94 | def __getitem__(self, idx): 95 | try: 96 | return self.get_item(idx) 97 | except Exception as e: 98 | import traceback 99 | 100 | traceback.print_exc() 101 | logger.error(f"Error loading {self.files[idx]}: {e}") 102 | return None 103 | 104 | 105 | @dataclass 106 | class VITSCollator: 107 | tokenizer: AutoTokenizer 108 | 109 | def __call__(self, batch): 110 | batch = [x for x in batch if x is not None] 111 | 112 | audio_lengths = torch.tensor([len(x["audio"]) for x in batch]) 113 | audio_maxlen = audio_lengths.max() 114 | 115 | text_lengths = torch.tensor([len(x["text"]) for x in batch]) 116 | text_maxlen = text_lengths.max() 117 | 118 | # Rounds up to nearest multiple of 2 (audio_lengths) 119 | audios = [] 120 | texts = [] 121 | for x in batch: 122 | audios.append( 123 | torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"]))) 124 | ) 125 | 126 | texts.append( 127 | torch.nn.functional.pad( 128 | x["text"], 129 | (0, text_maxlen - len(x["text"])), 130 | value=self.tokenizer.eos_token_id, 131 | ) 132 | ) 133 | 134 | return { 135 | "audios": 
torch.stack(audios), 136 | "audio_lengths": audio_lengths, 137 | "texts": torch.stack(texts), 138 | "text_lengths": text_lengths, 139 | } 140 | 141 | 142 | class VITSDataModule(LightningDataModule): 143 | def __init__( 144 | self, 145 | train_dataset: VITSDataset, 146 | val_dataset: VITSDataset, 147 | tokenizer: AutoTokenizer, 148 | batch_size: int = 32, 149 | num_workers: int = 4, 150 | val_batch_size: Optional[int] = None, 151 | ): 152 | super().__init__() 153 | 154 | self.train_dataset = train_dataset 155 | self.val_dataset = val_dataset 156 | self.batch_size = batch_size 157 | self.val_batch_size = val_batch_size or batch_size 158 | self.num_workers = num_workers 159 | self.tokenizer = tokenizer 160 | 161 | def train_dataloader(self): 162 | return DataLoader( 163 | self.train_dataset, 164 | batch_size=self.batch_size, 165 | collate_fn=VITSCollator(self.tokenizer), 166 | num_workers=self.num_workers, 167 | shuffle=False, 168 | persistent_workers=True, 169 | ) 170 | 171 | def val_dataloader(self): 172 | return DataLoader( 173 | self.val_dataset, 174 | batch_size=self.val_batch_size, 175 | collate_fn=VITSCollator(self.tokenizer), 176 | num_workers=self.num_workers, 177 | persistent_workers=True, 178 | ) 179 | 180 | 181 | if __name__ == "__main__": 182 | tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1") 183 | dataset = VITSDataset( 184 | "data/source/Genshin/filelist.train.txt", tokenizer=tokenizer, suffix=".lab" 185 | ) 186 | dataloader = DataLoader( 187 | dataset, batch_size=4, shuffle=False, collate_fn=VITSCollator(tokenizer) 188 | ) 189 | 190 | for batch in dataloader: 191 | print(batch["audios"].shape) 192 | print(batch["audio_lengths"]) 193 | print(batch["texts"].shape) 194 | print(batch["text_lengths"]) 195 | break 196 | -------------------------------------------------------------------------------- /fish_speech/datasets/vqgan.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import librosa 6 | import numpy as np 7 | import torch 8 | from lightning import LightningDataModule 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from fish_speech.utils import RankedLogger 12 | 13 | logger = RankedLogger(__name__, rank_zero_only=False) 14 | 15 | 16 | class VQGANDataset(Dataset): 17 | def __init__( 18 | self, 19 | filelist: str, 20 | sample_rate: int = 32000, 21 | hop_length: int = 640, 22 | slice_frames: Optional[int] = None, 23 | ): 24 | super().__init__() 25 | 26 | filelist = Path(filelist) 27 | root = filelist.parent 28 | 29 | self.files = [ 30 | root / line.strip() 31 | for line in filelist.read_text(encoding="utf-8").splitlines() 32 | if line.strip() 33 | ] 34 | self.sample_rate = sample_rate 35 | self.hop_length = hop_length 36 | self.slice_frames = slice_frames 37 | 38 | def __len__(self): 39 | return len(self.files) 40 | 41 | def get_item(self, idx): 42 | file = self.files[idx] 43 | 44 | audio, _ = librosa.load(file, sr=self.sample_rate, mono=True) 45 | 46 | # Slice audio and features 47 | if ( 48 | self.slice_frames is not None 49 | and audio.shape[0] > self.slice_frames * self.hop_length 50 | ): 51 | start = np.random.randint( 52 | 0, audio.shape[0] - self.slice_frames * self.hop_length 53 | ) 54 | audio = audio[start : start + self.slice_frames * self.hop_length] 55 | 56 | if len(audio) == 0: 57 | return None 58 | 59 | max_value = np.abs(audio).max() 60 | if max_value > 1.0: 61 | audio = audio / max_value 
62 | 63 | return { 64 | "audio": torch.from_numpy(audio), 65 | } 66 | 67 | def __getitem__(self, idx): 68 | try: 69 | return self.get_item(idx) 70 | except Exception as e: 71 | import traceback 72 | 73 | traceback.print_exc() 74 | logger.error(f"Error loading {self.files[idx]}: {e}") 75 | return None 76 | 77 | 78 | @dataclass 79 | class VQGANCollator: 80 | def __call__(self, batch): 81 | batch = [x for x in batch if x is not None] 82 | 83 | audio_lengths = torch.tensor([len(x["audio"]) for x in batch]) 84 | audio_maxlen = audio_lengths.max() 85 | 86 | # Rounds up to nearest multiple of 2 (audio_lengths) 87 | audios = [] 88 | for x in batch: 89 | audios.append( 90 | torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"]))) 91 | ) 92 | 93 | return { 94 | "audios": torch.stack(audios), 95 | "audio_lengths": audio_lengths, 96 | } 97 | 98 | 99 | class VQGANDataModule(LightningDataModule): 100 | def __init__( 101 | self, 102 | train_dataset: VQGANDataset, 103 | val_dataset: VQGANDataset, 104 | batch_size: int = 32, 105 | num_workers: int = 4, 106 | val_batch_size: Optional[int] = None, 107 | ): 108 | super().__init__() 109 | 110 | self.train_dataset = train_dataset 111 | self.val_dataset = val_dataset 112 | self.batch_size = batch_size 113 | self.val_batch_size = val_batch_size or batch_size 114 | self.num_workers = num_workers 115 | 116 | def train_dataloader(self): 117 | return DataLoader( 118 | self.train_dataset, 119 | batch_size=self.batch_size, 120 | collate_fn=VQGANCollator(), 121 | num_workers=self.num_workers, 122 | shuffle=True, 123 | persistent_workers=True, 124 | ) 125 | 126 | def val_dataloader(self): 127 | return DataLoader( 128 | self.val_dataset, 129 | batch_size=self.val_batch_size, 130 | collate_fn=VQGANCollator(), 131 | num_workers=self.num_workers, 132 | persistent_workers=True, 133 | ) 134 | 135 | 136 | if __name__ == "__main__": 137 | dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt") 138 | dataloader = DataLoader( 139 | dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator() 140 | ) 141 | 142 | for batch in dataloader: 143 | # The collator only returns "audios" and "audio_lengths" 144 | print(batch["audios"].shape) 145 | print(batch["audio_lengths"]) 146 | break 147 | -------------------------------------------------------------------------------- /fish_speech/i18n/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import i18n 2 | 3 | __all__ = ["i18n"] 4 | -------------------------------------------------------------------------------- /fish_speech/i18n/core.py: -------------------------------------------------------------------------------- 1 | import json 2 | import locale 3 | from pathlib import Path 4 | 5 | I18N_FILE_PATH = Path(__file__).parent / "locale" 6 | DEFAULT_LANGUAGE = "en_US" 7 | 8 | 9 | def load_language_list(language): 10 | with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: 11 | language_list = json.load(f) 12 | 13 | return language_list 14 | 15 | 16 | class I18nAuto: 17 | def __init__(self): 18 | i18n_file = Path(".locale") 19 | 20 | if i18n_file.exists(): 21 | with open(i18n_file, "r", encoding="utf-8") as f: 22 | language = f.read().strip() 23 | else: 24 | # getlocale can't identify the system's language ((None, None)) 25 | language = locale.getdefaultlocale()[0] 26 | 27 | if (I18N_FILE_PATH / f"{language}.json").exists() is False: 28 | language = DEFAULT_LANGUAGE 29 | 30 | self.language = language 31 | self.language_map =
load_language_list(language) 32 | 33 | def __call__(self, key): 34 | return self.language_map.get(key, key) 35 | 36 | def __repr__(self): 37 | return "Use Language: " + self.language 38 | 39 | 40 | i18n = I18nAuto() 41 | -------------------------------------------------------------------------------- /fish_speech/i18n/locale/en_US.json: -------------------------------------------------------------------------------- 1 | { 2 | "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.", 3 | "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).", 4 | "Accumulate Gradient Batches": "Accumulate Gradient Batches", 5 | "Add to Processing Area": "Add to Processing Area", 6 | "Added path successfully!": "Added path successfully!", 7 | "Advanced Config": "Advanced Config", 8 | "Base LLAMA Model": "Base LLAMA Model", 9 | "Batch Size": "Batch Size", 10 | "Chinese": "Chinese", 11 | "Compile Model": "Compile Model", 12 | "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time", 13 | "Copy": "Copy", 14 | "Data Preprocessing": "Data Preprocessing", 15 | "Data Preprocessing Path": "Data Preprocessing Path", 16 | "Data Source": "Data Source", 17 | "Decoder Model Config": "Decoder Model Config", 18 | "Decoder Model Path": "Decoder Model Path", 19 | "Disabled": "Disabled", 20 | "Enable Reference Audio": "Enable Reference Audio", 21 | "English": "English", 22 | "Error Message": "Error Message", 23 | "File Preprocessing": "File Preprocessing", 24 | "Generate": "Generate", 25 | "Generated Audio": "Generated Audio", 26 | "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format", 27 | "Infer interface is closed": "Infer interface is closed", 28 | "Inference Configuration": "Inference Configuration", 29 | "Inference Server Configuration": "Inference Server Configuration", 30 | "Inference Server Error": "Inference Server Error", 31 | "Inferring interface is launched at {}": "Inferring interface is launched at {}", 32 | "Initial Learning Rate": "Initial Learning Rate", 33 | "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription", 34 | "Input Text": "Input Text", 35 | "Invalid path: {}": "Invalid path: {}", 36 | "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU", 37 | "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off", 38 | "Japanese": "Japanese", 39 | "LLAMA Configuration": "LLAMA Configuration", 40 | "LLAMA Model Config": "LLAMA Model Config", 41 | "LLAMA Model Path": "LLAMA Model Path", 42 | "Labeling Device": "Labeling Device", 43 | "LoRA Model to be merged": "LoRA Model to be merged", 44 | "Maximum Audio Duration": "Maximum Audio Duration", 45 | "Maximum Length per Sample": "Maximum Length per Sample", 46 | "Maximum Training Steps": "Maximum Training Steps", 47 | "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit", 48 | "Merge": "Merge", 49 | "Merge LoRA": "Merge LoRA", 50 | "Merge successfully": "Merge 
successfully", 51 | "Minimum Audio Duration": "Minimum Audio Duration", 52 | "Model Output Path": "Model Output Path", 53 | "Model Size": "Model Size", 54 | "Move": "Move", 55 | "Move files successfully": "Move files successfully", 56 | "No audio generated, please check the input text.": "No audio generated, please check the input text.", 57 | "No selected options": "No selected options", 58 | "Number of Workers": "Number of Workers", 59 | "Open Inference Server": "Open Inference Server", 60 | "Open Labeler WebUI": "Open Labeler WebUI", 61 | "Open Tensorboard": "Open Tensorboard", 62 | "Opened labeler in browser": "Opened labeler in browser", 63 | "Optional Label Language": "Optional Label Language", 64 | "Optional online ver": "Optional online ver", 65 | "Output Path": "Output Path", 66 | "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path", 67 | "Precision": "Precision", 68 | "Probability of applying Speaker Condition": "Probability of applying Speaker Condition", 69 | "Put your text here.": "Put your text here.", 70 | "Reference Audio": "Reference Audio", 71 | "Reference Text": "Reference Text", 72 | "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.", 73 | "Remove Selected Data": "Remove Selected Data", 74 | "Removed path successfully!": "Removed path successfully!", 75 | "Repetition Penalty": "Repetition Penalty", 76 | "Save model every n steps": "Save model every n steps", 77 | "Select LLAMA ckpt": "Select LLAMA ckpt", 78 | "Select VITS ckpt": "Select VITS ckpt", 79 | "Select VQGAN ckpt": "Select VQGAN ckpt", 80 | "Select source file processing method": "Select source file processing method", 81 | "Select the model to be trained": "Select the model to be trained", 82 | "Selected: {}": "Selected: {}", 83 | "Speaker": "Speaker", 84 | "Speaker is identified by the folder name": "Speaker is identified by the folder name", 85 | "Start Training": "Start Training", 86 | "Streaming Audio": "Streaming Audio", 87 | "Streaming Generate": "Streaming Generate", 88 | "Tensorboard Host": "Tensorboard Host", 89 | "Tensorboard Log Path": "Tensorboard Log Path", 90 | "Tensorboard Port": "Tensorboard Port", 91 | "Tensorboard interface is closed": "Tensorboard interface is closed", 92 | "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}", 93 | "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.", 94 | "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.", 95 | "Training Configuration": "Training Configuration", 96 | "Training Error": "Training Error", 97 | "Training stopped": "Training stopped", 98 | "Type name of the speaker": "Type name of the speaker", 99 | "Type the path or select from the dropdown": "Type the path or select from the dropdown", 100 | "Use LoRA": "Use LoRA", 101 | "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model", 102 | "Use filelist": "Use filelist", 103 | "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G", 104 | "VITS Configuration": "VITS Configuration", 105 | "VQGAN Configuration": "VQGAN Configuration", 106 | "Validation Batch Size": "Validation Batch Size", 107 | "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)", 108 | "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.", 109 | "WebUI Host": "WebUI Host", 110 | "WebUI Port": "WebUI Port", 111 | "Whisper Model": "Whisper Model", 112 | "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).", 113 | "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU", 114 | "latest": "latest", 115 | "new": "new" 116 | } 117 | -------------------------------------------------------------------------------- /fish_speech/i18n/locale/es_ES.json: -------------------------------------------------------------------------------- 1 | { 2 | "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.", 3 | "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).", 4 | "Accumulate Gradient Batches": "Acumular lotes de gradientes", 5 | "Add to Processing Area": "Agregar al Área de Procesamiento", 6 | "Added path successfully!": "¡Ruta agregada exitosamente!", 7 | "Advanced Config": "Configuración Avanzada", 8 | "Base LLAMA Model": "Modelo Base LLAMA", 9 | "Batch Size": "Tamaño del Lote", 10 | "Chinese": "Chino", 11 | "Compile Model": "Compilar Modelo", 12 | "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío", 13 | "Copy": "Copiar", 14 | "Data Preprocessing": "Preprocesamiento de Datos", 15 | "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos", 16 | "Data Source": "Fuente de Datos", 17 | "Decoder Model Config": "Configuración del modelo decodificador", 18 | "Decoder Model Path": "Ruta del modelo decodificador", 19 | "Disabled": "Desactivado", 20 | "Enable Reference Audio": "Habilitar 
Audio de Referencia", 21 | "English": "Inglés", 22 | "Error Message": "Mensaje de Error", 23 | "File Preprocessing": "Preprocesamiento de Archivos", 24 | "Generate": "Generar", 25 | "Generated Audio": "Audio Generado", 26 | "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab", 27 | "Infer interface is closed": "La interfaz de inferencia está cerrada", 28 | "Inference Configuration": "Configuración de Inferencia", 29 | "Inference Server Configuration": "Configuración del Servidor de Inferencia", 30 | "Inference Server Error": "Error del Servidor de Inferencia", 31 | "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}", 32 | "Initial Learning Rate": "Tasa de Aprendizaje Inicial", 33 | "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción", 34 | "Input Text": "Texto de Entrada", 35 | "Invalid path: {}": "Ruta inválida: {}", 36 | "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU", 37 | "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado", 38 | "Japanese": "Japonés", 39 | "LLAMA Configuration": "Configuración de LLAMA", 40 | "LLAMA Model Config": "Configuración del Modelo LLAMA", 41 | "LLAMA Model Path": "Ruta del Modelo LLAMA", 42 | "Labeling Device": "Dispositivo de Etiquetado", 43 | "LoRA Model to be merged": "Modelo LoRA a fusionar", 44 | "Maximum Audio Duration": "Duración máxima de audio", 45 | "Maximum Length per Sample": "Longitud Máxima por Muestra", 46 | "Maximum Training Steps": "Pasos Máximos de Entrenamiento", 47 | "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite", 48 | "Merge": "Fusionar", 49 | "Merge LoRA": "Fusionar LoRA", 50 | "Merge successfully": "Fusionado exitosamente", 51 | "Minimum Audio Duration": "Duración mínima de audio", 52 | "Model Output Path": "Ruta de Salida del Modelo", 53 | "Model Size": "Tamaño del Modelo", 54 | "Move": "Mover", 55 | "Move files successfully": "Archivos movidos exitosamente", 56 | "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.", 57 | "No selected options": "No hay opciones seleccionadas", 58 | "Number of Workers": "Número de Trabajadores", 59 | "Open Inference Server": "Abrir Servidor de Inferencia", 60 | "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador", 61 | "Open Tensorboard": "Abrir Tensorboard", 62 | "Opened labeler in browser": "Se abrió el etiquetador en el navegador", 63 | "Optional Label Language": "Idioma de Etiquetado Opcional", 64 | "Optional online ver": "Ver en línea opcional", 65 | "Output Path": "Ruta de Salida", 66 | "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente", 67 | "Precision": "Precisión", 68 | "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante", 69 | "Put your text here.": "Ponga su texto aquí.", 70 | "Reference Audio": "Audio de Referencia", 71 | "Reference Text": "Texto de Referencia", 72 | "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado se publica bajo la Licencia 
BSD-3-Clause, y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.", 73 | "Remove Selected Data": "Eliminar Datos Seleccionados", 74 | "Removed path successfully!": "¡Ruta eliminada exitosamente!", 75 | "Repetition Penalty": "Penalización por Repetición", 76 | "Save model every n steps": "Guardar modelo cada n pasos", 77 | "Select LLAMA ckpt": "Seleccionar punto de control LLAMA", 78 | "Select VITS ckpt": "Seleccionar punto de control VITS", 79 | "Select VQGAN ckpt": "Seleccionar punto de control VQGAN", 80 | "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente", 81 | "Select the model to be trained": "Seleccione el modelo a ser entrenado", 82 | "Selected: {}": "Seleccionado: {}", 83 | "Speaker": "Hablante", 84 | "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta", 85 | "Start Training": "Iniciar Entrenamiento", 86 | "Streaming Audio": "transmisión de audio", 87 | "Streaming Generate": "síntesis en flujo", 88 | "Tensorboard Host": "Host de Tensorboard", 89 | "Tensorboard Log Path": "Ruta de Registro de Tensorboard", 90 | "Tensorboard Port": "Puerto de Tensorboard", 91 | "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada", 92 | "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}", 93 | "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.", 94 | "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.", 95 | "Training Configuration": "Configuración de Entrenamiento", 96 | "Training Error": "Error de Entrenamiento", 97 | "Training stopped": "Entrenamiento detenido", 98 | "Type name of the speaker": "Escriba el nombre del hablante", 99 | "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable", 100 | "Use LoRA": "Usar LoRA", 101 | "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo", 102 | "Use filelist": "Usar lista de archivos", 103 | "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G", 104 | "VITS Configuration": "Configuración de VITS", 105 | "VQGAN Configuration": "Configuración de VQGAN", 106 | "Validation Batch Size": "Tamaño del Lote de Validación", 107 | "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)", 108 | "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.", 109 | "WebUI Host": "Host de WebUI", 110 | "WebUI Port": "Puerto de WebUI", 111 | "Whisper Model": "Modelo Whisper", 112 | "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos 
[aquí](https://huggingface.co/fishaudio/fish-speech-1).", 113 | "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+", 114 | "latest": "más reciente", 115 | "new": "nuevo" 116 | } 117 | -------------------------------------------------------------------------------- /fish_speech/i18n/locale/ja_JP.json: -------------------------------------------------------------------------------- 1 | { 2 | "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。", 3 | "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。", 4 | "Accumulate Gradient Batches": "勾配バッチの累積", 5 | "Add to Processing Area": "処理エリアに追加", 6 | "Added path successfully!": "パスの追加に成功しました!", 7 | "Advanced Config": "詳細設定", 8 | "Base LLAMA Model": "基本LLAMAモデル", 9 | "Batch Size": "バッチサイズ", 10 | "Chinese": "中国語", 11 | "Compile Model": "モデルのコンパイル", 12 | "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります", 13 | "Copy": "コピー", 14 | "Data Preprocessing": "データ前処理", 15 | "Data Preprocessing Path": "データ前処理パス", 16 | "Data Source": "データソース", 17 | "Decoder Model Config": "デコーダーモデルの構成", 18 | "Decoder Model Path": "デコーダーモデルのパス", 19 | "Disabled": "無効", 20 | "Enable Reference Audio": "リファレンスオーディオを有効にする", 21 | "English": "英語", 22 | "Error Message": "エラーメッセージ", 23 | "File Preprocessing": "文書前处理", 24 | "Generate": "生成", 25 | "Generated Audio": "生成されたオーディオ", 26 | "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています", 27 | "Infer interface is closed": "推論インターフェースが閉じられています", 28 | "Inference Configuration": "推論設定", 29 | "Inference Server Configuration": "推論サーバー設定", 30 | "Inference Server Error": "推論サーバーエラー", 31 | "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました", 32 | "Initial Learning Rate": "初期学習率", 33 | "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス", 34 | "Input Text": "入力テキスト", 35 | "Invalid path: {}": "無効なパス: {}", 36 | "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください", 37 | "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します", 38 | "Japanese": "日本語", 39 | "LLAMA Configuration": "LLAMA設定", 40 | "LLAMA Model Config": "LLAMAモデル設定", 41 | "LLAMA Model Path": "LLAMAモデルパス", 42 | "Labeling Device": "ラベリングデバイス", 43 | "LoRA Model to be merged": "マージするLoRAモデル", 44 | "Maximum Audio Duration": "最大オーディオの長さ", 45 | "Maximum Length per Sample": "サンプルあたりの最大長", 46 | "Maximum Training Steps": "最大トレーニングステップ数", 47 | "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します", 48 | "Merge": "マージ", 49 | "Merge LoRA": "LoRAのマージ", 50 | "Merge successfully": "マージに成功しました", 51 | "Minimum Audio Duration": "最小オーディオの長さ", 52 | "Model Output Path": "モデル出力パス", 53 | "Model Size": "モデルサイズ", 54 | "Move": "移動", 55 | "Move files successfully": "ファイルの移動に成功しました", 56 | "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。", 57 | "No selected options": "選択されたオプションはありません", 58 | "Number of Workers": "ワーカー数", 59 | "Open Inference Server": "推論サーバーを開く", 60 | "Open Labeler WebUI": "ラベラーWebUIを開く", 61 | "Open Tensorboard": "Tensorboardを開く", 62 | "Opened labeler in 
browser": "ブラウザでラベラーを開きました", 63 | "Optional Label Language": "オプションのラベル言語", 64 | "Optional online ver": "オプションのオンラインバージョン", 65 | "Output Path": "出力パス", 66 | "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください", 67 | "Precision": "精度", 68 | "Probability of applying Speaker Condition": "話者条件を適用する確率", 69 | "Put your text here.": "ここにテキストを入力してください。", 70 | "Reference Audio": "リファレンスオーディオ", 71 | "Reference Text": "リファレンステキスト", 72 | "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "関連コードはBSD-3-Clauseライセンスの下でリリースされ、重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。", 73 | "Remove Selected Data": "選択したデータを削除", 74 | "Removed path successfully!": "パスの削除に成功しました!", 75 | "Repetition Penalty": "反復ペナルティ", 76 | "Save model every n steps": "nステップごとにモデルを保存", 77 | "Select LLAMA ckpt": " LLAMA チェックポイントを選択", 78 | "Select VITS ckpt": "VITS チェックポイントを選択", 79 | "Select VQGAN ckpt": "VQGAN チェックポイントを選択", 80 | "Select source file processing method": "ソースファイルの処理方法を選択", 81 | "Select the model to be trained": "トレーニングするモデルを選択", 82 | "Selected: {}": "選択済み: {}", 83 | "Speaker": "話者", 84 | "Speaker is identified by the folder name": "話者はフォルダ名で識別されます", 85 | "Start Training": "トレーニング開始", 86 | "Streaming Audio": "ストリーミングオーディオ", 87 | "Streaming Generate": "ストリーミング合成", 88 | "Tensorboard Host": "Tensorboardホスト", 89 | "Tensorboard Log Path": "Tensorboardログパス", 90 | "Tensorboard Port": "Tensorboardポート", 91 | "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています", 92 | "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました", 93 | "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。", 94 | "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。", 95 | "Training Configuration": "トレーニング設定", 96 | "Training Error": "トレーニングエラー", 97 | "Training stopped": "トレーニングが停止しました", 98 | "Type name of the speaker": "話者の名前を入力", 99 | "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください", 100 | "Use LoRA": "LoRAを使用", 101 | "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります", 102 | "Use filelist": "ファイルリストを使用", 103 | "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください", 104 | "VITS Configuration": "VITS の構成", 105 | "VQGAN Configuration": "VQGAN の構成", 106 | "Validation Batch Size": "検証バッチサイズ", 107 | "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)", 108 | "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。", 109 | "WebUI Host": "WebUIホスト", 110 | "WebUI Port": "WebUIポート", 111 | "Whisper Model": "Whisperモデル", 112 | "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。", 113 | "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします", 114 | "latest": "最新", 115 | "new": "新規" 116 | } 117 | -------------------------------------------------------------------------------- /fish_speech/i18n/locale/zh_CN.json: -------------------------------------------------------------------------------- 1 | { 2 | "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。", 3 | "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.", 4 | "Accumulate Gradient Batches": "梯度累积批次", 5 | "Add to Processing Area": "加入处理区", 6 | "Added path successfully!": "添加路径成功!", 7 | "Advanced Config": "高级参数", 8 | "Base LLAMA Model": "基础 LLAMA 模型", 9 | "Batch Size": "批次大小", 10 | "Chinese": "中文", 11 | "Compile Model": "编译模型", 12 | "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间", 13 | "Copy": "复制", 14 | "Data Preprocessing": "数据预处理", 15 | "Data Preprocessing Path": "数据预处理路径", 16 | "Data Source": "数据源", 17 | "Decoder Model Config": "解码器模型配置", 18 | "Decoder Model Path": "解码器模型路径", 19 | "Disabled": "禁用", 20 | "Enable Reference Audio": "启用参考音频", 21 | "English": "英文", 22 | "Error Message": "错误信息", 23 | "File Preprocessing": "文件预处理", 24 | "Generate": "生成", 25 | "Generated Audio": "音频", 26 | "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式", 27 | "Infer interface is closed": "推理界面已关闭", 28 | "Inference Configuration": "推理配置", 29 | "Inference Server Configuration": "推理服务器配置", 30 | "Inference Server Error": "推理服务器错误", 31 | "Inferring interface is launched at {}": "推理界面已在 {} 上启动", 32 | "Initial Learning Rate": "初始学习率", 33 | "Input Audio & Source Path for Transcription": "输入音频和转录源路径", 34 | "Input Text": "输入文本", 35 | "Invalid path: {}": 
"无效路径: {}", 36 | "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU", 37 | "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭", 38 | "Japanese": "日文", 39 | "LLAMA Configuration": "LLAMA 配置", 40 | "LLAMA Model Config": "LLAMA 模型配置", 41 | "LLAMA Model Path": "LLAMA 模型路径", 42 | "Labeling Device": "标注加速设备", 43 | "LoRA Model to be merged": "要合并的 LoRA 模型", 44 | "Maximum Audio Duration": "最大音频时长", 45 | "Maximum Length per Sample": "每个样本的最大长度", 46 | "Maximum Training Steps": "最大训练步数", 47 | "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制", 48 | "Merge": "合并", 49 | "Merge LoRA": "合并 LoRA", 50 | "Merge successfully": "合并成功", 51 | "Minimum Audio Duration": "最小音频时长", 52 | "Model Output Path": "模型输出路径", 53 | "Model Size": "模型规模", 54 | "Move": "移动", 55 | "Move files successfully": "移动文件成功", 56 | "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.", 57 | "No selected options": "没有选择的选项", 58 | "Number of Workers": "数据加载进程数", 59 | "Open Inference Server": "打开推理服务器", 60 | "Open Labeler WebUI": "打开标注工具", 61 | "Open Tensorboard": "打开 Tensorboard", 62 | "Opened labeler in browser": "在浏览器中打开标注工具", 63 | "Optional Label Language": "[可选] 标注语言", 64 | "Optional online ver": "[可选] 使用在线版", 65 | "Output Path": "输出路径", 66 | "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径", 67 | "Precision": "精度", 68 | "Probability of applying Speaker Condition": "应用说话人条件的概率", 69 | "Put your text here.": "在此处输入文本.", 70 | "Reference Audio": "参考音频", 71 | "Reference Text": "参考文本", 72 | "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "相关代码使用 BSD-3-Clause 许可证发布,权重使用 CC BY-NC-SA 4.0 许可证发布.", 73 | "Remove Selected Data": "移除选中数据", 74 | "Removed path successfully!": "移除路径成功!", 75 | "Repetition Penalty": "重复惩罚", 76 | "Save model every n steps": "每 n 步保存模型", 77 | "Select LLAMA ckpt": "选择 LLAMA 检查点", 78 | "Select VITS ckpt": "选择 VITS 检查点", 79 | "Select VQGAN ckpt": "选择 VQGAN 检查点", 80 | "Select source file processing method": "选择源文件处理方法", 81 | "Select the model to be trained": "选择要训练的模型", 82 | "Selected: {}": "已选择: {}", 83 | "Speaker": "说话人", 84 | "Speaker is identified by the folder name": "自动根据父目录名称识别说话人", 85 | "Start Training": "开始训练", 86 | "Streaming Audio": "流式音频", 87 | "Streaming Generate": "流式合成", 88 | "Tensorboard Host": "Tensorboard 监听地址", 89 | "Tensorboard Log Path": "Tensorboard 日志路径", 90 | "Tensorboard Port": "Tensorboard 端口", 91 | "Tensorboard interface is closed": "Tensorboard 界面已关闭", 92 | "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动", 93 | "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.", 94 | "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.", 95 | "Training Configuration": "训练配置", 96 | "Training Error": "训练错误", 97 | "Training stopped": "训练已停止", 98 | "Type name of the speaker": "输入说话人的名称", 99 | "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择", 100 | "Use LoRA": "使用 LoRA", 101 | "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量", 102 | "Use filelist": "使用文件列表", 103 | "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small", 104 | "VITS Configuration": "VITS 配置", 105 | "VQGAN Configuration": "VQGAN 配置", 106 | "Validation Batch Size": "验证批次大小", 107 | "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)", 108 | "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.", 109 | "WebUI Host": "WebUI 监听地址", 110 | "WebUI Port": "WebUI 端口", 111 | "Whisper Model": "Whisper 模型", 112 | "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.", 113 | "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed", 114 | "latest": "最近的检查点", 115 | "new": "创建新的检查点" 116 | } 117 | -------------------------------------------------------------------------------- /fish_speech/i18n/scan.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import glob 3 | import json 4 | from collections import OrderedDict 5 | from pathlib import Path 6 | 7 | from loguru import logger 8 | 9 | from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH 10 | 11 | 12 | def extract_i18n_strings(node): 13 | i18n_strings = [] 14 | 15 | if ( 16 | isinstance(node, ast.Call) 17 | and isinstance(node.func, ast.Name) 18 | and node.func.id == "i18n" 19 | ): 20 | for arg in node.args: 21 | if isinstance(arg, ast.Str): 22 | i18n_strings.append(arg.s) 23 | 24 | for child_node in ast.iter_child_nodes(node): 25 | i18n_strings.extend(extract_i18n_strings(child_node)) 26 | 27 | return i18n_strings 28 | 29 | 30 | # scan the directory for all .py files (recursively) 31 | # for each file, parse the code into an AST 32 | # for each AST, extract the i18n strings 33 | 34 | strings = [] 35 | folders = ["fish_speech", "tools"] 36 | # for filename in glob.iglob("**/*.py", recursive=True): 37 | for folder in folders: 38 | for f in Path(folder).rglob("*.py"): 39 | code = f.read_text(encoding="utf-8") 40 | if "i18n(" in code: 41 | tree = ast.parse(code) 42 | i18n_strings = extract_i18n_strings(tree) 43 | logger.info(f"Found {len(i18n_strings)} i18n strings in {f}") 44 | strings.extend(i18n_strings) 45 | 46 | code_keys = set(strings) 47 | logger.info(f"Total unique: {len(code_keys)}") 48 | 49 | 50 | standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" 51 | with open(standard_file, "r", encoding="utf-8") as f: 52 | standard_data = json.load(f, object_pairs_hook=OrderedDict) 53 | standard_keys = set(standard_data.keys()) 54 | 55 | # Define the standard file name 56 | unused_keys = standard_keys - code_keys 57 | logger.info(f"Found {len(unused_keys)} unused keys in 
{standard_file}") 58 | for unused_key in unused_keys: 59 | logger.info(f"\t{unused_key}") 60 | 61 | missing_keys = code_keys - standard_keys 62 | logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}") 63 | for missing_key in missing_keys: 64 | logger.info(f"\t{missing_key}") 65 | 66 | code_keys_dict = OrderedDict() 67 | for s in strings: 68 | code_keys_dict[s] = s 69 | 70 | # write back 71 | with open(standard_file, "w", encoding="utf-8") as f: 72 | json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True) 73 | f.write("\n") 74 | 75 | logger.info(f"Updated {standard_file}") 76 | 77 | 78 | # Define the standard file name 79 | standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" 80 | 81 | # Find all JSON files in the directory 82 | dir_path = I18N_FILE_PATH 83 | languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE] 84 | 85 | # Load the standard file 86 | with open(standard_file, "r", encoding="utf-8") as f: 87 | standard_data = json.load(f, object_pairs_hook=OrderedDict) 88 | 89 | # Loop through each language file 90 | for lang_file in languages: 91 | # Load the language file 92 | with open(lang_file, "r", encoding="utf-8") as f: 93 | lang_data = json.load(f, object_pairs_hook=OrderedDict) 94 | 95 | # Find the difference between the language file and the standard file 96 | diff = set(standard_data.keys()) - set(lang_data.keys()) 97 | 98 | miss = set(lang_data.keys()) - set(standard_data.keys()) 99 | 100 | # Add any missing keys to the language file 101 | for key in diff: 102 | lang_data[key] = "#!" + key 103 | logger.info(f"Added missing key: {key} to {lang_file}") 104 | 105 | # Del any extra keys to the language file 106 | for key in miss: 107 | del lang_data[key] 108 | logger.info(f"Del extra key: {key} from {lang_file}") 109 | 110 | # Sort the keys of the language file to match the order of the standard file 111 | lang_data = OrderedDict( 112 | sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0])) 113 | ) 114 | 115 | # Save the updated language file 116 | with open(lang_file, "w", encoding="utf-8") as f: 117 | json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True) 118 | f.write("\n") 119 | 120 | logger.info(f"Updated {lang_file}") 121 | 122 | logger.info("Done") 123 | -------------------------------------------------------------------------------- /fish_speech/models/text2semantic/__init__.py: -------------------------------------------------------------------------------- 1 | from .lit_module import TextToSemantic 2 | 3 | __all__ = ["TextToSemantic"] 4 | -------------------------------------------------------------------------------- /fish_speech/models/text2semantic/lora_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import loralib as lora 4 | 5 | 6 | @dataclass 7 | class LoraConfig: 8 | r: int 9 | lora_alpha: float 10 | lora_dropout: float = 0.0 11 | 12 | 13 | def setup_lora(model, lora_config): 14 | # Replace the embedding layer with a LoRA layer 15 | model.embeddings = lora.Embedding( 16 | num_embeddings=model.embeddings.num_embeddings, 17 | embedding_dim=model.embeddings.embedding_dim, 18 | padding_idx=model.embeddings.padding_idx, 19 | r=lora_config.r, 20 | lora_alpha=lora_config.lora_alpha, 21 | ) 22 | 23 | # Replace output layer with a LoRA layer 24 | linears = [(model, "output")] 25 | 26 | # Replace all linear layers with LoRA layers 27 | for layer in model.layers: 28 | 
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) 29 | linears.extend( 30 | [ 31 | (layer.feed_forward, "w1"), 32 | (layer.feed_forward, "w2"), 33 | (layer.feed_forward, "w3"), 34 | ] 35 | ) 36 | 37 | if hasattr(model, "fast_layers"): 38 | model.fast_embeddings = lora.Embedding( 39 | num_embeddings=model.fast_embeddings.num_embeddings, 40 | embedding_dim=model.fast_embeddings.embedding_dim, 41 | padding_idx=model.fast_embeddings.padding_idx, 42 | r=lora_config.r, 43 | lora_alpha=lora_config.lora_alpha, 44 | ) 45 | 46 | # Dual-AR model 47 | linears.append((model, "fast_output")) 48 | 49 | for layer in model.fast_layers: 50 | linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) 51 | linears.extend( 52 | [ 53 | (layer.feed_forward, "w1"), 54 | (layer.feed_forward, "w2"), 55 | (layer.feed_forward, "w3"), 56 | ] 57 | ) 58 | 59 | for module, layer in linears: 60 | updated_linear = lora.Linear( 61 | in_features=getattr(module, layer).in_features, 62 | out_features=getattr(module, layer).out_features, 63 | bias=getattr(module, layer).bias, 64 | r=lora_config.r, 65 | lora_alpha=lora_config.lora_alpha, 66 | lora_dropout=lora_config.lora_dropout, 67 | ) 68 | setattr(module, layer, updated_linear) 69 | 70 | # Mark only the LoRA layers as trainable 71 | lora.mark_only_lora_as_trainable(model, bias="none") 72 | 73 | 74 | def get_merged_state_dict(model): 75 | # This line will merge the state dict of the model and the LoRA parameters 76 | model.eval() 77 | 78 | # Then we need to remove the LoRA parameters from the state dict 79 | state_dict = model.state_dict() 80 | for name in list(state_dict.keys()): 81 | if "lora" in name: 82 | state_dict.pop(name) 83 | 84 | return state_dict 85 | -------------------------------------------------------------------------------- /fish_speech/models/vits_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .lit_module import VITSDecoder 2 | 3 | __all__ = ["VITSDecoder"] 4 | -------------------------------------------------------------------------------- /fish_speech/models/vits_decoder/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]): 7 | loss = 0 8 | for dr, dg in zip(fmap_r, fmap_g): 9 | dr = dr.float().detach() 10 | dg = dg.float() 11 | loss += torch.mean(torch.abs(dr - dg)) 12 | 13 | return loss * 2 14 | 15 | 16 | def discriminator_loss( 17 | disc_real_outputs: list[torch.Tensor], disc_generated_outputs: list[torch.Tensor] 18 | ): 19 | loss = 0 20 | r_losses = [] 21 | g_losses = [] 22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 23 | dr = dr.float() 24 | dg = dg.float() 25 | r_loss = torch.mean((1 - dr) ** 2) 26 | g_loss = torch.mean(dg**2) 27 | loss += r_loss + g_loss 28 | r_losses.append(r_loss.item()) 29 | g_losses.append(g_loss.item()) 30 | 31 | return loss, r_losses, g_losses 32 | 33 | 34 | def generator_loss(disc_outputs: list[torch.Tensor]): 35 | loss = 0 36 | gen_losses = [] 37 | for dg in disc_outputs: 38 | dg = dg.float() 39 | l = torch.mean((1 - dg) ** 2) 40 | gen_losses.append(l) 41 | loss += l 42 | 43 | return loss, gen_losses 44 | 45 | 46 | def kl_loss( 47 | z_p: torch.Tensor, 48 | logs_q: torch.Tensor, 49 | m_p: torch.Tensor, 50 | logs_p: torch.Tensor, 51 | z_mask: torch.Tensor, 52 | ): 53 | """ 54 | z_p, logs_q: [b, h, t_t] 55 | m_p, 
logs_p: [b, h, t_t] 56 | """ 57 | z_p = z_p.float() 58 | logs_q = logs_q.float() 59 | m_p = m_p.float() 60 | logs_p = logs_p.float() 61 | z_mask = z_mask.float() 62 | 63 | kl = logs_p - logs_q - 0.5 64 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) 65 | kl = torch.sum(kl * z_mask) 66 | l = kl / torch.sum(z_mask) 67 | return l 68 | -------------------------------------------------------------------------------- /fish_speech/models/vits_decoder/modules/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def init_weights(m, mean=0.0, std=0.01): 8 | classname = m.__class__.__name__ 9 | if classname.find("Conv") != -1: 10 | m.weight.data.normal_(mean, std) 11 | 12 | 13 | def get_padding(kernel_size, dilation=1): 14 | return int((kernel_size * dilation - dilation) / 2) 15 | 16 | 17 | def convert_pad_shape(pad_shape): 18 | l = pad_shape[::-1] 19 | pad_shape = [item for sublist in l for item in sublist] 20 | return pad_shape 21 | 22 | 23 | def intersperse(lst, item): 24 | result = [item] * (len(lst) * 2 + 1) 25 | result[1::2] = lst 26 | return result 27 | 28 | 29 | def kl_divergence(m_p, logs_p, m_q, logs_q): 30 | """KL(P||Q)""" 31 | kl = (logs_q - logs_p) - 0.5 32 | kl += ( 33 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 34 | ) 35 | return kl 36 | 37 | 38 | def rand_gumbel(shape): 39 | """Sample from the Gumbel distribution, protect from overflows.""" 40 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 41 | return -torch.log(-torch.log(uniform_samples)) 42 | 43 | 44 | def rand_gumbel_like(x): 45 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 46 | return g 47 | 48 | 49 | def slice_segments(x, ids_str, segment_size=4): 50 | ret = torch.zeros_like(x[:, :, :segment_size]) 51 | for i in range(x.size(0)): 52 | idx_str = ids_str[i] 53 | idx_end = idx_str + segment_size 54 | ret[i] = x[i, :, idx_str:idx_end] 55 | return ret 56 | 57 | 58 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 59 | b, d, t = x.size() 60 | if x_lengths is None: 61 | x_lengths = t 62 | ids_str_max = x_lengths - segment_size + 1 63 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 64 | ret = slice_segments(x, ids_str, segment_size) 65 | return ret, ids_str 66 | 67 | 68 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 69 | position = torch.arange(length, dtype=torch.float) 70 | num_timescales = channels // 2 71 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 72 | num_timescales - 1 73 | ) 74 | inv_timescales = min_timescale * torch.exp( 75 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 76 | ) 77 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 78 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 79 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 80 | signal = signal.view(1, channels, length) 81 | return signal 82 | 83 | 84 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 85 | b, channels, length = x.size() 86 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 87 | return x + signal.to(dtype=x.dtype, device=x.device) 88 | 89 | 90 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 91 | b, channels, length = x.size() 92 | signal = get_timing_signal_1d(length, 
channels, min_timescale, max_timescale) 93 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 94 | 95 | 96 | def subsequent_mask(length): 97 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 98 | return mask 99 | 100 | 101 | @torch.jit.script 102 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 103 | n_channels_int = n_channels[0] 104 | in_act = input_a + input_b 105 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 106 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 107 | acts = t_act * s_act 108 | return acts 109 | 110 | 111 | def convert_pad_shape(pad_shape): 112 | l = pad_shape[::-1] 113 | pad_shape = [item for sublist in l for item in sublist] 114 | return pad_shape 115 | 116 | 117 | def shift_1d(x): 118 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 119 | return x 120 | 121 | 122 | def sequence_mask(length, max_length=None): 123 | if max_length is None: 124 | max_length = length.max() 125 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 126 | return x.unsqueeze(0) < length.unsqueeze(1) 127 | 128 | 129 | def generate_path(duration, mask): 130 | """ 131 | duration: [b, 1, t_x] 132 | mask: [b, 1, t_y, t_x] 133 | """ 134 | device = duration.device 135 | 136 | b, _, t_y, t_x = mask.shape 137 | cum_duration = torch.cumsum(duration, -1) 138 | 139 | cum_duration_flat = cum_duration.view(b * t_x) 140 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 141 | path = path.view(b, t_x, t_y) 142 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 143 | path = path.unsqueeze(1).transpose(2, 3) * mask 144 | return path 145 | 146 | 147 | def clip_grad_value_(parameters, clip_value, norm_type=2): 148 | if isinstance(parameters, torch.Tensor): 149 | parameters = [parameters] 150 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 151 | norm_type = float(norm_type) 152 | if clip_value is not None: 153 | clip_value = float(clip_value) 154 | 155 | total_norm = 0 156 | for p in parameters: 157 | param_norm = p.grad.data.norm(norm_type) 158 | total_norm += param_norm.item() ** norm_type 159 | if clip_value is not None: 160 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 161 | total_norm = total_norm ** (1.0 / norm_type) 162 | return total_norm 163 | 164 | 165 | def squeeze(x, x_mask=None, n_sqz=2): 166 | b, c, t = x.size() 167 | 168 | t = (t // n_sqz) * n_sqz 169 | x = x[:, :, :t] 170 | x_sqz = x.view(b, c, t // n_sqz, n_sqz) 171 | x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) 172 | 173 | if x_mask is not None: 174 | x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz] 175 | else: 176 | x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype) 177 | return x_sqz * x_mask, x_mask 178 | 179 | 180 | def unsqueeze(x, x_mask=None, n_sqz=2): 181 | b, c, t = x.size() 182 | 183 | x_unsqz = x.view(b, n_sqz, c // n_sqz, t) 184 | x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz) 185 | 186 | if x_mask is not None: 187 | x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz) 188 | else: 189 | x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype) 190 | return x_unsqz * x_mask, x_mask 191 | -------------------------------------------------------------------------------- /fish_speech/models/vits_decoder/modules/mrte.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch 
import nn 3 | from torch.nn.utils import remove_weight_norm, weight_norm 4 | 5 | from fish_speech.models.vits_decoder.modules.attentions import MultiHeadAttention 6 | 7 | 8 | class MRTE(nn.Module): 9 | def __init__( 10 | self, 11 | content_enc_channels=192, 12 | hidden_size=512, 13 | out_channels=192, 14 | n_heads=4, 15 | ): 16 | super(MRTE, self).__init__() 17 | self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads) 18 | self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1) 19 | self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1) 20 | self.c_post = nn.Conv1d(hidden_size, out_channels, 1) 21 | 22 | def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None): 23 | if ge == None: 24 | ge = 0 25 | attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1) 26 | 27 | ssl_enc = self.c_pre(ssl_enc * ssl_mask) 28 | text_enc = self.text_pre(text * text_mask) 29 | if test != None: 30 | if test == 0: 31 | x = ( 32 | self.cross_attention( 33 | ssl_enc * ssl_mask, text_enc * text_mask, attn_mask 34 | ) 35 | + ssl_enc 36 | + ge 37 | ) 38 | elif test == 1: 39 | x = ssl_enc + ge 40 | elif test == 2: 41 | x = ( 42 | self.cross_attention( 43 | ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask 44 | ) 45 | + ge 46 | ) 47 | else: 48 | raise ValueError("test should be 0,1,2") 49 | else: 50 | x = ( 51 | self.cross_attention( 52 | ssl_enc * ssl_mask, text_enc * text_mask, attn_mask 53 | ) 54 | + ssl_enc 55 | + ge 56 | ) 57 | x = self.c_post(x * ssl_mask) 58 | return x 59 | -------------------------------------------------------------------------------- /fish_speech/models/vits_decoder/modules/vq_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize 7 | from fish_speech.models.vqgan.modules.wavenet import WaveNet 8 | from fish_speech.models.vqgan.utils import sequence_mask 9 | from fish_speech.utils.spectrogram import LogMelSpectrogram 10 | 11 | 12 | class VQEncoder(nn.Module): 13 | def __init__( 14 | self, 15 | ): 16 | super().__init__() 17 | 18 | self.encoder = WaveNet( 19 | input_channels=128, 20 | residual_channels=768, 21 | residual_layers=20, 22 | dilation_cycle=4, 23 | ) 24 | 25 | self.quantizer = DownsampleFiniteScalarQuantize( 26 | input_dim=768, n_codebooks=1, n_groups=2, levels=[8, 5, 5, 5] 27 | ) 28 | 29 | self.spec = LogMelSpectrogram( 30 | sample_rate=44100, 31 | n_fft=2048, 32 | win_length=2048, 33 | hop_length=512, 34 | n_mels=128, 35 | f_min=0.0, 36 | f_max=8000.0, 37 | ) 38 | 39 | self.eval() 40 | e = self.load_state_dict( 41 | torch.load("checkpoints/vq-gan-group-fsq-2x1024.pth", map_location="cpu"), 42 | strict=False, 43 | ) 44 | 45 | assert len(e.missing_keys) == 0, e.missing_keys 46 | assert all( 47 | k.startswith("decoder.") 48 | or k.startswith("quality_projection.") 49 | or k.startswith("discriminator.") 50 | for k in e.unexpected_keys 51 | ), e.unexpected_keys 52 | 53 | @torch.no_grad() 54 | def forward(self, audios, audio_lengths, sr=None): 55 | mel_spec = self.spec(audios, sample_rate=sr) 56 | 57 | if sr is not None: 58 | audio_lengths = audio_lengths * 44100 // sr 59 | 60 | mel_lengths = audio_lengths // self.spec.hop_length 61 | mel_masks = ( 62 | torch.arange(mel_spec.shape[2], device=mel_spec.device) 63 | < mel_lengths[:, None] 64 | ) 65 | mel_masks_float_conv = mel_masks[:, None, :].float() 66 | mels = mel_spec * mel_masks_float_conv 67 | 
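        # `mels` is now a masked log-mel spectrogram of shape [B, 128, T_mel]
        # (128 mel bins, hop length 512 at 44.1 kHz); frames past each item's
        # mel_lengths are zeroed. The WaveNet encoder below lifts it to
        # 768-channel features, and the grouped FSQ quantizer returns an
        # FSQResult whose `.z` field is the dequantized, length-restored
        # representation, which is masked once more before being returned.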
68 | # Encode 69 | encoded_features = self.encoder(mels) * mel_masks_float_conv 70 | encoded_features = self.quantizer(encoded_features).z * mel_masks_float_conv 71 | 72 | return encoded_features 73 | 74 | @torch.no_grad() 75 | def indicies_to_vq_features( 76 | self, 77 | indices, 78 | feature_lengths, 79 | ): 80 | factor = math.prod(self.quantizer.downsample_factor) 81 | mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor) 82 | mel_masks_float_conv = mel_masks[:, None, :].float() 83 | z = self.quantizer.decode(indices) * mel_masks_float_conv 84 | 85 | return z 86 | 87 | @torch.no_grad() 88 | def encode(self, audios, audio_lengths, sr=None): 89 | audios = audios.float() 90 | 91 | mels = self.spec(audios, sample_rate=sr) 92 | mel_lengths = audio_lengths // self.spec.hop_length 93 | mel_masks = sequence_mask(mel_lengths, mels.shape[2]) 94 | mel_masks_float_conv = mel_masks[:, None, :].float() 95 | mels = mels * mel_masks_float_conv 96 | 97 | # Encode 98 | encoded_features = self.encoder(mels) * mel_masks_float_conv 99 | feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor) 100 | 101 | return self.quantizer.encode(encoded_features), feature_lengths 102 | -------------------------------------------------------------------------------- /fish_speech/models/vqgan/__init__.py: -------------------------------------------------------------------------------- 1 | from .lit_module import VQGAN 2 | 3 | __all__ = ["VQGAN"] 4 | -------------------------------------------------------------------------------- /fish_speech/models/vqgan/modules/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils.parametrizations import weight_norm 4 | 5 | 6 | class Discriminator(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | blocks = [] 11 | convs = [ 12 | (1, 64, (3, 9), 1, (1, 4)), 13 | (64, 128, (3, 9), (1, 2), (1, 4)), 14 | (128, 256, (3, 9), (1, 2), (1, 4)), 15 | (256, 512, (3, 9), (1, 2), (1, 4)), 16 | (512, 1024, (3, 3), 1, (1, 1)), 17 | (1024, 1, (3, 3), 1, (1, 1)), 18 | ] 19 | 20 | for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate( 21 | convs 22 | ): 23 | blocks.append( 24 | weight_norm( 25 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 26 | ) 27 | ) 28 | 29 | if idx != len(convs) - 1: 30 | blocks.append(nn.SiLU(inplace=True)) 31 | 32 | self.blocks = nn.Sequential(*blocks) 33 | 34 | def forward(self, x): 35 | return self.blocks(x[:, None])[:, 0] 36 | 37 | 38 | if __name__ == "__main__": 39 | model = Discriminator() 40 | print(sum(p.numel() for p in model.parameters()) / 1_000_000) 41 | x = torch.randn(1, 128, 1024) 42 | y = model(x) 43 | print(y.shape) 44 | print(y) 45 | -------------------------------------------------------------------------------- /fish_speech/models/vqgan/modules/fsq.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | from vector_quantize_pytorch import GroupedResidualFSQ 8 | 9 | from .firefly import ConvNeXtBlock 10 | 11 | 12 | @dataclass 13 | class FSQResult: 14 | z: torch.Tensor 15 | codes: torch.Tensor 16 | latents: torch.Tensor 17 | 18 | 19 | class DownsampleFiniteScalarQuantize(nn.Module): 20 | def __init__( 21 | self, 22 | input_dim: int = 512, 23 | n_codebooks: int = 9, 
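        # `n_groups` splits the channel dimension into independent FSQ groups
        # (via GroupedResidualFSQ below), `n_codebooks` is the number of
        # residual quantizers per group, and `levels` gives the per-dimension
        # grid sizes: 8 * 5 * 5 * 5 = 1000 codes, roughly 2**10 per codebook.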
24 | n_groups: int = 1, 25 | levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10 26 | downsample_factor: tuple[int] = (2, 2), 27 | downsample_dims: tuple[int] | None = None, 28 | ): 29 | super().__init__() 30 | 31 | if downsample_dims is None: 32 | downsample_dims = [input_dim for _ in range(len(downsample_factor))] 33 | 34 | all_dims = (input_dim,) + tuple(downsample_dims) 35 | 36 | self.residual_fsq = GroupedResidualFSQ( 37 | dim=all_dims[-1], 38 | levels=levels, 39 | num_quantizers=n_codebooks, 40 | groups=n_groups, 41 | ) 42 | 43 | self.downsample_factor = downsample_factor 44 | self.downsample_dims = downsample_dims 45 | 46 | self.downsample = nn.Sequential( 47 | *[ 48 | nn.Sequential( 49 | nn.Conv1d( 50 | all_dims[idx], 51 | all_dims[idx + 1], 52 | kernel_size=factor, 53 | stride=factor, 54 | ), 55 | ConvNeXtBlock(dim=all_dims[idx + 1]), 56 | ) 57 | for idx, factor in enumerate(downsample_factor) 58 | ] 59 | ) 60 | 61 | self.upsample = nn.Sequential( 62 | *[ 63 | nn.Sequential( 64 | nn.ConvTranspose1d( 65 | all_dims[idx + 1], 66 | all_dims[idx], 67 | kernel_size=factor, 68 | stride=factor, 69 | ), 70 | ConvNeXtBlock(dim=all_dims[idx]), 71 | ) 72 | for idx, factor in reversed(list(enumerate(downsample_factor))) 73 | ] 74 | ) 75 | 76 | self.apply(self._init_weights) 77 | 78 | def _init_weights(self, m): 79 | if isinstance(m, (nn.Conv1d, nn.Linear)): 80 | nn.init.trunc_normal_(m.weight, std=0.02) 81 | nn.init.constant_(m.bias, 0) 82 | 83 | def forward(self, z) -> FSQResult: 84 | original_shape = z.shape 85 | z = self.downsample(z) 86 | quantized, indices = self.residual_fsq(z.mT) 87 | result = FSQResult( 88 | z=quantized.mT, 89 | codes=indices.mT, 90 | latents=z, 91 | ) 92 | result.z = self.upsample(result.z) 93 | 94 | # Pad or crop z to match original shape 95 | diff = original_shape[-1] - result.z.shape[-1] 96 | left = diff // 2 97 | right = diff - left 98 | 99 | if diff > 0: 100 | result.z = F.pad(result.z, (left, right)) 101 | elif diff < 0: 102 | result.z = result.z[..., left:-right] 103 | 104 | return result 105 | 106 | def encode(self, z): 107 | z = self.downsample(z) 108 | _, indices = self.residual_fsq(z.mT) 109 | indices = rearrange(indices, "g b l r -> b (g r) l") 110 | return indices 111 | 112 | def decode(self, indices: torch.Tensor): 113 | indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups) 114 | z_q = self.residual_fsq.get_output_from_indices(indices) 115 | z_q = self.upsample(z_q.mT) 116 | return z_q 117 | 118 | # def from_latents(self, latents: torch.Tensor): 119 | # z_q, z_p, codes = super().from_latents(latents) 120 | # z_q = self.upsample(z_q) 121 | # return z_q, z_p, codes 122 | 123 | 124 | if __name__ == "__main__": 125 | rvq = DownsampleFiniteScalarQuantize( 126 | n_codebooks=1, 127 | downsample_factor=(2, 2), 128 | ) 129 | x = torch.randn(16, 512, 80) 130 | 131 | result = rvq(x) 132 | print(rvq) 133 | print(result.latents.shape, result.codes.shape, result.z.shape) 134 | 135 | # y = rvq.from_codes(result.codes) 136 | # print(y[0].shape) 137 | 138 | # y = rvq.from_latents(result.latents) 139 | # print(y[0].shape) 140 | -------------------------------------------------------------------------------- /fish_speech/models/vqgan/modules/reference.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from .wavenet import WaveNet 8 | 9 | 10 | class ReferenceEncoder(WaveNet): 11 | def 
__init__( 12 | self, 13 | input_channels: Optional[int] = None, 14 | output_channels: Optional[int] = None, 15 | residual_channels: int = 512, 16 | residual_layers: int = 20, 17 | dilation_cycle: Optional[int] = 4, 18 | num_heads: int = 8, 19 | latent_len: int = 4, 20 | ): 21 | super().__init__( 22 | input_channels=input_channels, 23 | residual_channels=residual_channels, 24 | residual_layers=residual_layers, 25 | dilation_cycle=dilation_cycle, 26 | ) 27 | 28 | self.head_dim = residual_channels // num_heads 29 | self.num_heads = num_heads 30 | 31 | self.latent_len = latent_len 32 | self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels)) 33 | 34 | self.q = nn.Linear(residual_channels, residual_channels, bias=True) 35 | self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True) 36 | self.q_norm = nn.LayerNorm(self.head_dim) 37 | self.k_norm = nn.LayerNorm(self.head_dim) 38 | self.proj = nn.Linear(residual_channels, residual_channels) 39 | self.proj_drop = nn.Dropout(0.1) 40 | 41 | self.norm = nn.LayerNorm(residual_channels) 42 | self.mlp = nn.Sequential( 43 | nn.Linear(residual_channels, residual_channels * 4), 44 | nn.SiLU(), 45 | nn.Linear(residual_channels * 4, residual_channels), 46 | ) 47 | self.output_projection_attn = nn.Linear(residual_channels, output_channels) 48 | 49 | torch.nn.init.trunc_normal_(self.latent, std=0.02) 50 | self.apply(self.init_weights) 51 | 52 | def init_weights(self, m): 53 | if isinstance(m, nn.Linear): 54 | torch.nn.init.trunc_normal_(m.weight, std=0.02) 55 | if m.bias is not None: 56 | torch.nn.init.constant_(m.bias, 0) 57 | 58 | def forward(self, x, attn_mask=None): 59 | x = super().forward(x).mT 60 | B, N, C = x.shape 61 | 62 | # Calculate mask 63 | if attn_mask is not None: 64 | assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool 65 | 66 | attn_mask = attn_mask[:, None, None, :].expand( 67 | B, self.num_heads, self.latent_len, N 68 | ) 69 | 70 | q_latent = self.latent.expand(B, -1, -1) 71 | q = ( 72 | self.q(q_latent) 73 | .reshape(B, self.latent_len, self.num_heads, self.head_dim) 74 | .transpose(1, 2) 75 | ) 76 | 77 | kv = ( 78 | self.kv(x) 79 | .reshape(B, N, 2, self.num_heads, self.head_dim) 80 | .permute(2, 0, 3, 1, 4) 81 | ) 82 | k, v = kv.unbind(0) 83 | 84 | q, k = self.q_norm(q), self.k_norm(k) 85 | x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) 86 | 87 | x = x.transpose(1, 2).reshape(B, self.latent_len, C) 88 | x = self.proj(x) 89 | x = self.proj_drop(x) 90 | 91 | x = x + self.mlp(self.norm(x)) 92 | x = self.output_projection_attn(x) 93 | x = x.mean(1) 94 | 95 | return x 96 | 97 | 98 | if __name__ == "__main__": 99 | with torch.autocast(device_type="cpu", dtype=torch.bfloat16): 100 | model = ReferenceEncoder( 101 | input_channels=128, 102 | output_channels=64, 103 | residual_channels=384, 104 | residual_layers=20, 105 | dilation_cycle=4, 106 | num_heads=8, 107 | ) 108 | x = torch.randn(4, 128, 64) 109 | mask = torch.ones(4, 64, dtype=torch.bool) 110 | y = model(x, mask) 111 | print(y.shape) 112 | loss = F.mse_loss(y, torch.randn(4, 64)) 113 | loss.backward() 114 | -------------------------------------------------------------------------------- /fish_speech/models/vqgan/modules/wavenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class Mish(nn.Module): 10 | def forward(self, x): 11 | return x * 
torch.tanh(F.softplus(x)) 12 | 13 | 14 | class DiffusionEmbedding(nn.Module): 15 | """Diffusion Step Embedding""" 16 | 17 | def __init__(self, d_denoiser): 18 | super(DiffusionEmbedding, self).__init__() 19 | self.dim = d_denoiser 20 | 21 | def forward(self, x): 22 | device = x.device 23 | half_dim = self.dim // 2 24 | emb = math.log(10000) / (half_dim - 1) 25 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 26 | emb = x[:, None] * emb[None, :] 27 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 28 | return emb 29 | 30 | 31 | class LinearNorm(nn.Module): 32 | """LinearNorm Projection""" 33 | 34 | def __init__(self, in_features, out_features, bias=False): 35 | super(LinearNorm, self).__init__() 36 | self.linear = nn.Linear(in_features, out_features, bias) 37 | 38 | nn.init.xavier_uniform_(self.linear.weight) 39 | if bias: 40 | nn.init.constant_(self.linear.bias, 0.0) 41 | 42 | def forward(self, x): 43 | x = self.linear(x) 44 | return x 45 | 46 | 47 | class ConvNorm(nn.Module): 48 | """1D Convolution""" 49 | 50 | def __init__( 51 | self, 52 | in_channels, 53 | out_channels, 54 | kernel_size=1, 55 | stride=1, 56 | padding=None, 57 | dilation=1, 58 | bias=True, 59 | w_init_gain="linear", 60 | ): 61 | super(ConvNorm, self).__init__() 62 | 63 | if padding is None: 64 | assert kernel_size % 2 == 1 65 | padding = int(dilation * (kernel_size - 1) / 2) 66 | 67 | self.conv = nn.Conv1d( 68 | in_channels, 69 | out_channels, 70 | kernel_size=kernel_size, 71 | stride=stride, 72 | padding=padding, 73 | dilation=dilation, 74 | bias=bias, 75 | ) 76 | nn.init.kaiming_normal_(self.conv.weight) 77 | 78 | def forward(self, signal): 79 | conv_signal = self.conv(signal) 80 | 81 | return conv_signal 82 | 83 | 84 | class ResidualBlock(nn.Module): 85 | """Residual Block""" 86 | 87 | def __init__( 88 | self, 89 | residual_channels, 90 | use_linear_bias=False, 91 | dilation=1, 92 | condition_channels=None, 93 | ): 94 | super(ResidualBlock, self).__init__() 95 | self.conv_layer = ConvNorm( 96 | residual_channels, 97 | 2 * residual_channels, 98 | kernel_size=3, 99 | stride=1, 100 | padding=dilation, 101 | dilation=dilation, 102 | ) 103 | 104 | if condition_channels is not None: 105 | self.diffusion_projection = LinearNorm( 106 | residual_channels, residual_channels, use_linear_bias 107 | ) 108 | self.condition_projection = ConvNorm( 109 | condition_channels, 2 * residual_channels, kernel_size=1 110 | ) 111 | 112 | self.output_projection = ConvNorm( 113 | residual_channels, 2 * residual_channels, kernel_size=1 114 | ) 115 | 116 | def forward(self, x, condition=None, diffusion_step=None): 117 | y = x 118 | 119 | if diffusion_step is not None: 120 | diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) 121 | y = y + diffusion_step 122 | 123 | y = self.conv_layer(y) 124 | 125 | if condition is not None: 126 | condition = self.condition_projection(condition) 127 | y = y + condition 128 | 129 | gate, filter = torch.chunk(y, 2, dim=1) 130 | y = torch.sigmoid(gate) * torch.tanh(filter) 131 | 132 | y = self.output_projection(y) 133 | residual, skip = torch.chunk(y, 2, dim=1) 134 | 135 | return (x + residual) / math.sqrt(2.0), skip 136 | 137 | 138 | class WaveNet(nn.Module): 139 | def __init__( 140 | self, 141 | input_channels: Optional[int] = None, 142 | output_channels: Optional[int] = None, 143 | residual_channels: int = 512, 144 | residual_layers: int = 20, 145 | dilation_cycle: Optional[int] = 4, 146 | is_diffusion: bool = False, 147 | condition_channels: Optional[int] = None, 148 | ): 149 
| super().__init__() 150 | 151 | # Input projection 152 | self.input_projection = None 153 | if input_channels is not None and input_channels != residual_channels: 154 | self.input_projection = ConvNorm( 155 | input_channels, residual_channels, kernel_size=1 156 | ) 157 | 158 | if input_channels is None: 159 | input_channels = residual_channels 160 | 161 | self.input_channels = input_channels 162 | 163 | # Residual layers 164 | self.residual_layers = nn.ModuleList( 165 | [ 166 | ResidualBlock( 167 | residual_channels=residual_channels, 168 | use_linear_bias=False, 169 | dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1, 170 | condition_channels=condition_channels, 171 | ) 172 | for i in range(residual_layers) 173 | ] 174 | ) 175 | 176 | # Skip projection 177 | self.skip_projection = ConvNorm( 178 | residual_channels, residual_channels, kernel_size=1 179 | ) 180 | 181 | # Output projection 182 | self.output_projection = None 183 | if output_channels is not None and output_channels != residual_channels: 184 | self.output_projection = ConvNorm( 185 | residual_channels, output_channels, kernel_size=1 186 | ) 187 | 188 | if is_diffusion: 189 | self.diffusion_embedding = DiffusionEmbedding(residual_channels) 190 | self.mlp = nn.Sequential( 191 | LinearNorm(residual_channels, residual_channels * 4, False), 192 | Mish(), 193 | LinearNorm(residual_channels * 4, residual_channels, False), 194 | ) 195 | 196 | self.apply(self._init_weights) 197 | 198 | def _init_weights(self, m): 199 | if isinstance(m, (nn.Conv1d, nn.Linear)): 200 | nn.init.trunc_normal_(m.weight, std=0.02) 201 | if getattr(m, "bias", None) is not None: 202 | nn.init.constant_(m.bias, 0) 203 | 204 | def forward(self, x, t=None, condition=None): 205 | if self.input_projection is not None: 206 | x = self.input_projection(x) 207 | x = F.silu(x) 208 | 209 | if t is not None: 210 | t = self.diffusion_embedding(t) 211 | t = self.mlp(t) 212 | 213 | skip = [] 214 | for layer in self.residual_layers: 215 | x, skip_connection = layer(x, condition, t) 216 | skip.append(skip_connection) 217 | 218 | x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers)) 219 | x = self.skip_projection(x) 220 | 221 | if self.output_projection is not None: 222 | x = F.silu(x) 223 | x = self.output_projection(x) 224 | 225 | return x 226 | -------------------------------------------------------------------------------- /fish_speech/models/vqgan/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import torch 3 | from matplotlib import pyplot as plt 4 | 5 | matplotlib.use("Agg") 6 | 7 | 8 | def convert_pad_shape(pad_shape): 9 | l = pad_shape[::-1] 10 | pad_shape = [item for sublist in l for item in sublist] 11 | return pad_shape 12 | 13 | 14 | def sequence_mask(length, max_length=None): 15 | if max_length is None: 16 | max_length = length.max() 17 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 18 | return x.unsqueeze(0) < length.unsqueeze(1) 19 | 20 | 21 | def init_weights(m, mean=0.0, std=0.01): 22 | classname = m.__class__.__name__ 23 | if classname.find("Conv") != -1: 24 | m.weight.data.normal_(mean, std) 25 | 26 | 27 | def get_padding(kernel_size, dilation=1): 28 | return int((kernel_size * dilation - dilation) / 2) 29 | 30 | 31 | def plot_mel(data, titles=None): 32 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 33 | 34 | if titles is None: 35 | titles = [None for i in range(len(data))] 36 | 37 | plt.tight_layout() 38 | 39 | for i in 
range(len(data)): 40 | mel = data[i] 41 | 42 | if isinstance(mel, torch.Tensor): 43 | mel = mel.float().detach().cpu().numpy() 44 | 45 | axes[i][0].imshow(mel, origin="lower") 46 | axes[i][0].set_aspect(2.5, adjustable="box") 47 | axes[i][0].set_ylim(0, mel.shape[0]) 48 | axes[i][0].set_title(titles[i], fontsize="medium") 49 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 50 | axes[i][0].set_anchor("W") 51 | 52 | return fig 53 | 54 | 55 | def slice_segments(x, ids_str, segment_size=4): 56 | ret = torch.zeros_like(x[:, :, :segment_size]) 57 | for i in range(x.size(0)): 58 | idx_str = ids_str[i] 59 | idx_end = idx_str + segment_size 60 | ret[i] = x[i, :, idx_str:idx_end] 61 | 62 | return ret 63 | 64 | 65 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 66 | b, d, t = x.size() 67 | if x_lengths is None: 68 | x_lengths = t 69 | ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) 70 | ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) 71 | ret = slice_segments(x, ids_str, segment_size) 72 | return ret, ids_str 73 | 74 | 75 | @torch.jit.script 76 | def fused_add_tanh_sigmoid_multiply(in_act, n_channels): 77 | n_channels_int = n_channels[0] 78 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 79 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 80 | acts = t_act * s_act 81 | 82 | return acts 83 | 84 | 85 | def avg_with_mask(x, mask): 86 | assert mask.dtype == torch.float, "Mask should be float" 87 | 88 | if mask.ndim == 2: 89 | mask = mask.unsqueeze(1) 90 | 91 | if mask.shape[1] == 1: 92 | mask = mask.expand_as(x) 93 | 94 | return (x * mask).sum() / mask.sum() 95 | -------------------------------------------------------------------------------- /fish_speech/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def get_cosine_schedule_with_warmup_lr_lambda( 5 | current_step: int, 6 | *, 7 | num_warmup_steps: int | float, 8 | num_training_steps: int, 9 | num_cycles: float = 0.5, 10 | final_lr_ratio: float = 0.0, 11 | ): 12 | if 0 < num_warmup_steps < 1: # float mode 13 | num_warmup_steps = int(num_warmup_steps * num_training_steps) 14 | 15 | if current_step < num_warmup_steps: 16 | return float(current_step) / float(max(1, num_warmup_steps)) 17 | 18 | progress = float(current_step - num_warmup_steps) / float( 19 | max(1, num_training_steps - num_warmup_steps) 20 | ) 21 | 22 | return max( 23 | final_lr_ratio, 24 | 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), 25 | ) 26 | -------------------------------------------------------------------------------- /fish_speech/text/__init__.py: -------------------------------------------------------------------------------- 1 | from .clean import clean_text 2 | 3 | __all__ = ["clean_text"] 4 | -------------------------------------------------------------------------------- /fish_speech/text/clean.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import re 3 | import string 4 | 5 | LANGUAGE_UNICODE_RANGE_MAP = { 6 | "ZH": [(0x4E00, 0x9FFF)], 7 | "JP": [(0x4E00, 0x9FFF), (0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF)], 8 | "EN": [(0x0000, 0x007F)], 9 | } 10 | 11 | SYMBOLS_MAPPING = { 12 | ":": ",", 13 | ";": ",", 14 | ",": ",", 15 | "。": ".", 16 | "!": "!", 17 | "?": "?", 18 | "\n": ".", 19 | "·": ",", 20 | "、": ",", 21 | "...": "…", 22 | "“": "'", 23 | "”": "'", 24 | "‘": "'", 25 | "’": "'", 26 | "(": "'", 27 | ")": "'", 
28 | "(": "'", 29 | ")": "'", 30 | "《": "'", 31 | "》": "'", 32 | "【": "'", 33 | "】": "'", 34 | "[": "'", 35 | "]": "'", 36 | "—": "-", 37 | "~": "-", 38 | "~": "-", 39 | "・": "-", 40 | "「": "'", 41 | "」": "'", 42 | ";": ",", 43 | ":": ",", 44 | } 45 | 46 | REPLACE_SYMBOL_REGEX = re.compile( 47 | "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) 48 | ) 49 | ALL_KNOWN_UTF8_RANGE = list( 50 | itertools.chain.from_iterable(LANGUAGE_UNICODE_RANGE_MAP.values()) 51 | ) 52 | REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile( 53 | "[^" 54 | + "".join( 55 | f"{re.escape(chr(start))}-{re.escape(chr(end))}" 56 | for start, end in ALL_KNOWN_UTF8_RANGE 57 | ) 58 | + "]" 59 | ) 60 | 61 | 62 | def clean_text(text): 63 | # Clean the text 64 | text = text.strip() 65 | 66 | # Replace all chinese symbols with their english counterparts 67 | text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) 68 | text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text) 69 | 70 | return text 71 | -------------------------------------------------------------------------------- /fish_speech/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Optional 4 | 5 | import hydra 6 | import lightning as L 7 | import pyrootutils 8 | import torch 9 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 10 | from lightning.pytorch.loggers import Logger 11 | from lightning.pytorch.strategies import DDPStrategy 12 | from omegaconf import DictConfig, OmegaConf 13 | 14 | os.environ.pop("SLURM_NTASKS", None) 15 | os.environ.pop("SLURM_JOB_NAME", None) 16 | os.environ.pop("SLURM_NTASKS_PER_NODE", None) 17 | 18 | # register eval resolver and root 19 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 20 | 21 | # Allow TF32 on Ampere GPUs 22 | torch.set_float32_matmul_precision("high") 23 | torch.backends.cudnn.allow_tf32 = True 24 | 25 | # register eval resolver 26 | OmegaConf.register_new_resolver("eval", eval) 27 | 28 | import fish_speech.utils as utils 29 | 30 | log = utils.RankedLogger(__name__, rank_zero_only=True) 31 | 32 | 33 | @utils.task_wrapper 34 | def train(cfg: DictConfig) -> tuple[dict, dict]: 35 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 36 | training. 37 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 38 | failure. Useful for multiruns, saving info about the crash, etc. 39 | Args: 40 | cfg (DictConfig): Configuration composed by Hydra. 41 | Returns: 42 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. 
43 | """ # noqa: E501 44 | 45 | # set seed for random number generators in pytorch, numpy and python.random 46 | if cfg.get("seed"): 47 | L.seed_everything(cfg.seed, workers=False) 48 | 49 | if cfg.get("deterministic"): 50 | torch.use_deterministic_algorithms(True) 51 | 52 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") 53 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 54 | 55 | log.info(f"Instantiating model <{cfg.model._target_}>") 56 | model: LightningModule = hydra.utils.instantiate(cfg.model) 57 | 58 | log.info("Instantiating callbacks...") 59 | callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 60 | 61 | log.info("Instantiating loggers...") 62 | logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger")) 63 | 64 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 65 | trainer: Trainer = hydra.utils.instantiate( 66 | cfg.trainer, 67 | callbacks=callbacks, 68 | logger=logger, 69 | ) 70 | 71 | object_dict = { 72 | "cfg": cfg, 73 | "datamodule": datamodule, 74 | "model": model, 75 | "callbacks": callbacks, 76 | "logger": logger, 77 | "trainer": trainer, 78 | } 79 | 80 | if logger: 81 | log.info("Logging hyperparameters!") 82 | utils.log_hyperparameters(object_dict) 83 | 84 | if cfg.get("train"): 85 | log.info("Starting training!") 86 | 87 | ckpt_path = cfg.get("ckpt_path") 88 | auto_resume = False 89 | 90 | resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir) 91 | if resume_ckpt_path is not None: 92 | ckpt_path = resume_ckpt_path 93 | auto_resume = True 94 | 95 | if ckpt_path is not None: 96 | log.info(f"Resuming from checkpoint: {ckpt_path}") 97 | 98 | # resume weights only is disabled for auto-resume 99 | if cfg.get("resume_weights_only") and auto_resume is False: 100 | log.info("Resuming weights only!") 101 | ckpt = torch.load(ckpt_path, map_location=model.device) 102 | if "state_dict" in ckpt: 103 | ckpt = ckpt["state_dict"] 104 | err = model.load_state_dict(ckpt, strict=False) 105 | log.info(f"Error loading state dict: {err}") 106 | ckpt_path = None 107 | 108 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 109 | 110 | train_metrics = trainer.callback_metrics 111 | 112 | if cfg.get("test"): 113 | log.info("Starting testing!") 114 | ckpt_path = trainer.checkpoint_callback.best_model_path 115 | if ckpt_path == "": 116 | log.warning("Best ckpt not found! 
Using current weights for testing...") 117 | ckpt_path = cfg.get("ckpt_path") 118 | 119 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 120 | log.info(f"Best ckpt path: {ckpt_path}") 121 | 122 | test_metrics = trainer.callback_metrics 123 | 124 | # merge train and test metrics 125 | metric_dict = {**train_metrics, **test_metrics} 126 | 127 | return metric_dict, object_dict 128 | 129 | 130 | @hydra.main( 131 | version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml" 132 | ) 133 | def main(cfg: DictConfig) -> Optional[float]: 134 | # train the model 135 | train(cfg) 136 | 137 | 138 | if __name__ == "__main__": 139 | main() 140 | -------------------------------------------------------------------------------- /fish_speech/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .braceexpand import braceexpand 2 | from .file import get_latest_checkpoint 3 | from .instantiators import instantiate_callbacks, instantiate_loggers 4 | from .logger import RankedLogger 5 | from .logging_utils import log_hyperparameters 6 | from .rich_utils import enforce_tags, print_config_tree 7 | from .utils import extras, get_metric_value, task_wrapper 8 | 9 | __all__ = [ 10 | "enforce_tags", 11 | "extras", 12 | "get_metric_value", 13 | "RankedLogger", 14 | "instantiate_callbacks", 15 | "instantiate_loggers", 16 | "log_hyperparameters", 17 | "print_config_tree", 18 | "task_wrapper", 19 | "braceexpand", 20 | "get_latest_checkpoint", 21 | ] 22 | -------------------------------------------------------------------------------- /fish_speech/utils/braceexpand.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bash-style brace expansion 3 | Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py 4 | License: MIT 5 | """ 6 | 7 | import re 8 | import string 9 | from itertools import chain, product 10 | from typing import Iterable, Iterator, Optional 11 | 12 | __all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"] 13 | 14 | 15 | class UnbalancedBracesError(ValueError): 16 | pass 17 | 18 | 19 | alphabet = string.ascii_uppercase + string.ascii_lowercase 20 | 21 | int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$") 22 | char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$") 23 | escape_re = re.compile(r"\\(.)") 24 | 25 | 26 | def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]: 27 | """braceexpand(pattern) -> iterator over generated strings 28 | 29 | Returns an iterator over the strings resulting from brace expansion 30 | of pattern. This function implements Brace Expansion as described in 31 | bash(1), with the following limitations: 32 | 33 | * A pattern containing unbalanced braces will raise an 34 | UnbalancedBracesError exception. In bash, unbalanced braces will either 35 | be partly expanded or ignored. 36 | 37 | * A mixed-case character range like '{Z..a}' or '{a..Z}' will not 38 | include the characters '[]^_`' between 'Z' and 'a'. 39 | 40 | When escape is True (the default), characters in pattern can be 41 | prefixed with a backslash to cause them not to be interpreted as 42 | special characters for brace expansion (such as '{', '}', ','). 43 | To pass through a literal backslash, double it ('\\\\'). 44 | 45 | When escape is False, backslashes in pattern have no special 46 | meaning and will be preserved in the output.
47 | 48 | Examples: 49 | 50 | >>> from braceexpand import braceexpand 51 | 52 | # Integer range 53 | >>> list(braceexpand('item{1..3}')) 54 | ['item1', 'item2', 'item3'] 55 | 56 | # Character range 57 | >>> list(braceexpand('{a..c}')) 58 | ['a', 'b', 'c'] 59 | 60 | # Sequence 61 | >>> list(braceexpand('index.html{,.backup}')) 62 | ['index.html', 'index.html.backup'] 63 | 64 | # Nested patterns 65 | >>> list(braceexpand('python{2.{5..7},3.{2,3}}')) 66 | ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3'] 67 | 68 | # Prefixing an integer with zero causes all numbers to be padded to 69 | # the same width. 70 | >>> list(braceexpand('{07..10}')) 71 | ['07', '08', '09', '10'] 72 | 73 | # An optional increment can be specified for ranges. 74 | >>> list(braceexpand('{a..g..2}')) 75 | ['a', 'c', 'e', 'g'] 76 | 77 | # Ranges can go in both directions. 78 | >>> list(braceexpand('{4..1}')) 79 | ['4', '3', '2', '1'] 80 | 81 | # Numbers can be negative 82 | >>> list(braceexpand('{2..-1}')) 83 | ['2', '1', '0', '-1'] 84 | 85 | # Unbalanced braces raise an exception. 86 | >>> list(braceexpand('{1{2,3}')) 87 | Traceback (most recent call last): 88 | ... 89 | UnbalancedBracesError: Unbalanced braces: '{1{2,3}' 90 | 91 | # By default, the backslash is the escape character. 92 | >>> list(braceexpand(r'{1\\{2,3}')) 93 | ['1{2', '3'] 94 | 95 | # Setting 'escape' to False disables backslash escaping. 96 | >>> list(braceexpand(r'\\{1,2}', escape=False)) 97 | ['\\\\1', '\\\\2'] 98 | 99 | """ 100 | return ( 101 | escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape) 102 | ) 103 | 104 | 105 | def parse_pattern(pattern: str, escape: bool) -> Iterator[str]: 106 | start = 0 107 | pos = 0 108 | bracketdepth = 0 109 | items: list[Iterable[str]] = [] 110 | 111 | # print 'pattern:', pattern 112 | while pos < len(pattern): 113 | if escape and pattern[pos] == "\\": 114 | pos += 2 115 | continue 116 | elif pattern[pos] == "{": 117 | if bracketdepth == 0 and pos > start: 118 | # print 'literal:', pattern[start:pos] 119 | items.append([pattern[start:pos]]) 120 | start = pos 121 | bracketdepth += 1 122 | elif pattern[pos] == "}": 123 | bracketdepth -= 1 124 | if bracketdepth == 0: 125 | # print 'expression:', pattern[start+1:pos] 126 | expr = pattern[start + 1 : pos] 127 | item = parse_expression(expr, escape) 128 | if item is None: # not a range or sequence 129 | items.extend([["{"], parse_pattern(expr, escape), ["}"]]) 130 | else: 131 | items.append(item) 132 | start = pos + 1 # skip the closing brace 133 | pos += 1 134 | 135 | if bracketdepth != 0: # unbalanced braces 136 | raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern) 137 | 138 | if start < pos: 139 | items.append([pattern[start:]]) 140 | 141 | return ("".join(item) for item in product(*items)) 142 | 143 | 144 | def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]: 145 | int_range_match = int_range_re.match(expr) 146 | if int_range_match: 147 | return make_int_range(*int_range_match.groups()) 148 | 149 | char_range_match = char_range_re.match(expr) 150 | if char_range_match: 151 | return make_char_range(*char_range_match.groups()) 152 | 153 | return parse_sequence(expr, escape) 154 | 155 | 156 | def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]: 157 | # sequence -> chain(*sequence_items) 158 | start = 0 159 | pos = 0 160 | bracketdepth = 0 161 | items: list[Iterable[str]] = [] 162 | 163 | # print 'sequence:', seq 164 | while pos < len(seq): 165 | if escape and seq[pos] == 
"\\": 166 | pos += 2 167 | continue 168 | elif seq[pos] == "{": 169 | bracketdepth += 1 170 | elif seq[pos] == "}": 171 | bracketdepth -= 1 172 | elif seq[pos] == "," and bracketdepth == 0: 173 | items.append(parse_pattern(seq[start:pos], escape)) 174 | start = pos + 1 # skip the comma 175 | pos += 1 176 | 177 | if bracketdepth != 0: 178 | raise UnbalancedBracesError 179 | if not items: 180 | return None 181 | 182 | # part after the last comma (may be the empty string) 183 | items.append(parse_pattern(seq[start:], escape)) 184 | return chain(*items) 185 | 186 | 187 | def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]: 188 | if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]): 189 | padding = max(len(left), len(right)) 190 | else: 191 | padding = 0 192 | step = (int(incr) or 1) if incr else 1 193 | start = int(left) 194 | end = int(right) 195 | r = range(start, end + 1, step) if start < end else range(start, end - 1, -step) 196 | fmt = "%0{}d".format(padding) 197 | return (fmt % i for i in r) 198 | 199 | 200 | def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str: 201 | step = (int(incr) or 1) if incr else 1 202 | start = alphabet.index(left) 203 | end = alphabet.index(right) 204 | if start < end: 205 | return alphabet[start : end + 1 : step] 206 | else: 207 | end = end or -len(alphabet) 208 | return alphabet[start : end - 1 : -step] 209 | 210 | 211 | if __name__ == "__main__": 212 | import doctest 213 | import sys 214 | 215 | failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL) 216 | if failed: 217 | sys.exit(1) 218 | -------------------------------------------------------------------------------- /fish_speech/utils/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | from loguru import logger 7 | from natsort import natsorted 8 | 9 | AUDIO_EXTENSIONS = { 10 | ".mp3", 11 | ".wav", 12 | ".flac", 13 | ".ogg", 14 | ".m4a", 15 | ".wma", 16 | ".aac", 17 | ".aiff", 18 | ".aif", 19 | ".aifc", 20 | } 21 | 22 | 23 | def list_files( 24 | path: Union[Path, str], 25 | extensions: set[str] = None, 26 | recursive: bool = False, 27 | sort: bool = True, 28 | ) -> list[Path]: 29 | """List files in a directory. 30 | 31 | Args: 32 | path (Path): Path to the directory. 33 | extensions (set, optional): Extensions to filter. Defaults to None. 34 | recursive (bool, optional): Whether to search recursively. Defaults to False. 35 | sort (bool, optional): Whether to sort the files. Defaults to True. 36 | 37 | Returns: 38 | list: List of files. 39 | """ 40 | 41 | if isinstance(path, str): 42 | path = Path(path) 43 | 44 | if not path.exists(): 45 | raise FileNotFoundError(f"Directory {path} does not exist.") 46 | 47 | files = [file for ext in extensions for file in path.rglob(f"*{ext}")] 48 | 49 | if sort: 50 | files = natsorted(files) 51 | 52 | return files 53 | 54 | 55 | def get_latest_checkpoint(path: Path | str) -> Path | None: 56 | # Find the latest checkpoint 57 | ckpt_dir = Path(path) 58 | 59 | if ckpt_dir.exists() is False: 60 | return None 61 | 62 | ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) 63 | if len(ckpts) == 0: 64 | return None 65 | 66 | return ckpts[-1] 67 | 68 | 69 | def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]: 70 | """ 71 | Load a Bert-VITS2 style filelist. 
72 | """ 73 | 74 | files = set() 75 | results = [] 76 | count_duplicated, count_not_found = 0, 0 77 | 78 | LANGUAGE_TO_LANGUAGES = { 79 | "zh": ["zh", "en"], 80 | "jp": ["jp", "en"], 81 | "en": ["en"], 82 | } 83 | 84 | with open(path, "r", encoding="utf-8") as f: 85 | for line in f.readlines(): 86 | splits = line.strip().split("|", maxsplit=3) 87 | if len(splits) != 4: 88 | logger.warning(f"Invalid line: {line}") 89 | continue 90 | 91 | filename, speaker, language, text = splits 92 | file = Path(filename) 93 | language = language.strip().lower() 94 | 95 | if language == "ja": 96 | language = "jp" 97 | 98 | assert language in ["zh", "jp", "en"], f"Invalid language {language}" 99 | languages = LANGUAGE_TO_LANGUAGES[language] 100 | 101 | if file in files: 102 | logger.warning(f"Duplicated file: {file}") 103 | count_duplicated += 1 104 | continue 105 | 106 | if not file.exists(): 107 | logger.warning(f"File not found: {file}") 108 | count_not_found += 1 109 | continue 110 | 111 | results.append((file, speaker, languages, text)) 112 | 113 | if count_duplicated > 0: 114 | logger.warning(f"Total duplicated files: {count_duplicated}") 115 | 116 | if count_not_found > 0: 117 | logger.warning(f"Total files not found: {count_not_found}") 118 | 119 | return results 120 | -------------------------------------------------------------------------------- /fish_speech/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | from pytorch_lightning import Callback 6 | from pytorch_lightning.loggers import Logger 7 | 8 | from .logger import RankedLogger 9 | 10 | log = RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config.""" 15 | 16 | callbacks: List[Callback] = [] 17 | 18 | if not callbacks_cfg: 19 | log.warning("No callback configs found! Skipping..") 20 | return callbacks 21 | 22 | if not isinstance(callbacks_cfg, DictConfig): 23 | raise TypeError("Callbacks config must be a DictConfig!") 24 | 25 | for _, cb_conf in callbacks_cfg.items(): 26 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 27 | log.info(f"Instantiating callback <{cb_conf._target_}>") 28 | callbacks.append(hydra.utils.instantiate(cb_conf)) 29 | 30 | return callbacks 31 | 32 | 33 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 34 | """Instantiates loggers from config.""" 35 | 36 | logger: List[Logger] = [] 37 | 38 | if not logger_cfg: 39 | log.warning("No logger configs found! 
Skipping...") 40 | return logger 41 | 42 | if not isinstance(logger_cfg, DictConfig): 43 | raise TypeError("Logger config must be a DictConfig!") 44 | 45 | for _, lg_conf in logger_cfg.items(): 46 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 47 | log.info(f"Instantiating logger <{lg_conf._target_}>") 48 | logger.append(hydra.utils.instantiate(lg_conf)) 49 | 50 | return logger 51 | -------------------------------------------------------------------------------- /fish_speech/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Mapping, Optional 3 | 4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only 5 | 6 | 7 | class RankedLogger(logging.LoggerAdapter): 8 | """A multi-GPU-friendly python command line logger.""" 9 | 10 | def __init__( 11 | self, 12 | name: str = __name__, 13 | rank_zero_only: bool = True, 14 | extra: Optional[Mapping[str, object]] = None, 15 | ) -> None: 16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes 17 | with their rank prefixed in the log message. 18 | 19 | :param name: The name of the logger. Default is ``__name__``. 20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. 21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 22 | """ 23 | logger = logging.getLogger(name) 24 | super().__init__(logger=logger, extra=extra) 25 | self.rank_zero_only = rank_zero_only 26 | 27 | def log( 28 | self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs 29 | ) -> None: 30 | """Delegate a log call to the underlying logger, after prefixing its message with the rank 31 | of the process it's being logged from. If `'rank'` is provided, then the log will only 32 | occur on that rank/process. 33 | 34 | :param level: The level to log at. Look at `logging.__init__.py` for more information. 35 | :param msg: The message to log. 36 | :param rank: The rank to log at. 37 | :param args: Additional args to pass to the underlying logging function. 38 | :param kwargs: Any additional keyword args to pass to the underlying logging function. 39 | """ 40 | if self.isEnabledFor(level): 41 | msg, kwargs = self.process(msg, kwargs) 42 | current_rank = getattr(rank_zero_only, "rank", None) 43 | if current_rank is None: 44 | raise RuntimeError( 45 | "The `rank_zero_only.rank` needs to be set before use" 46 | ) 47 | msg = rank_prefixed_message(msg, current_rank) 48 | if self.rank_zero_only: 49 | if current_rank == 0: 50 | self.logger.log(level, msg, *args, **kwargs) 51 | else: 52 | if rank is None: 53 | self.logger.log(level, msg, *args, **kwargs) 54 | elif current_rank == rank: 55 | self.logger.log(level, msg, *args, **kwargs) 56 | -------------------------------------------------------------------------------- /fish_speech/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.utilities import rank_zero_only 2 | 3 | from fish_speech.utils import logger as log 4 | 5 | 6 | @rank_zero_only 7 | def log_hyperparameters(object_dict: dict) -> None: 8 | """Controls which config parts are saved by lightning loggers. 
9 | 10 | Additionally saves: 11 | - Number of model parameters 12 | """ 13 | 14 | hparams = {} 15 | 16 | cfg = object_dict["cfg"] 17 | model = object_dict["model"] 18 | trainer = object_dict["trainer"] 19 | 20 | if not trainer.logger: 21 | log.warning("Logger not found! Skipping hyperparameter logging...") 22 | return 23 | 24 | hparams["model"] = cfg["model"] 25 | 26 | # save number of model parameters 27 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 28 | hparams["model/params/trainable"] = sum( 29 | p.numel() for p in model.parameters() if p.requires_grad 30 | ) 31 | hparams["model/params/non_trainable"] = sum( 32 | p.numel() for p in model.parameters() if not p.requires_grad 33 | ) 34 | 35 | hparams["data"] = cfg["data"] 36 | hparams["trainer"] = cfg["trainer"] 37 | 38 | hparams["callbacks"] = cfg.get("callbacks") 39 | hparams["extras"] = cfg.get("extras") 40 | 41 | hparams["task_name"] = cfg.get("task_name") 42 | hparams["tags"] = cfg.get("tags") 43 | hparams["ckpt_path"] = cfg.get("ckpt_path") 44 | hparams["seed"] = cfg.get("seed") 45 | 46 | # send hparams to all loggers 47 | for logger in trainer.loggers: 48 | logger.log_hyperparams(hparams) 49 | -------------------------------------------------------------------------------- /fish_speech/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from fish_speech.utils import logger as log 13 | 14 | 15 | @rank_zero_only 16 | def print_config_tree( 17 | cfg: DictConfig, 18 | print_order: Sequence[str] = ( 19 | "data", 20 | "model", 21 | "callbacks", 22 | "logger", 23 | "trainer", 24 | "paths", 25 | "extras", 26 | ), 27 | resolve: bool = False, 28 | save_to_file: bool = False, 29 | ) -> None: 30 | """Prints content of DictConfig using Rich library and its tree structure. 31 | 32 | Args: 33 | cfg (DictConfig): Configuration composed by Hydra. 34 | print_order (Sequence[str], optional): Determines in what order config components are printed. 35 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 36 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 37 | """ # noqa: E501 38 | 39 | style = "dim" 40 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 41 | 42 | queue = [] 43 | 44 | # add fields from `print_order` to queue 45 | for field in print_order: 46 | ( 47 | queue.append(field) 48 | if field in cfg 49 | else log.warning( 50 | f"Field '{field}' not found in config. " 51 | + f"Skipping '{field}' config printing..." 
52 | ) 53 | ) 54 | 55 | # add all the other fields to queue (not specified in `print_order`) 56 | for field in cfg: 57 | if field not in queue: 58 | queue.append(field) 59 | 60 | # generate config tree from queue 61 | for field in queue: 62 | branch = tree.add(field, style=style, guide_style=style) 63 | 64 | config_group = cfg[field] 65 | if isinstance(config_group, DictConfig): 66 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 67 | else: 68 | branch_content = str(config_group) 69 | 70 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 71 | 72 | # print config tree 73 | rich.print(tree) 74 | 75 | # save config tree to file 76 | if save_to_file: 77 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 78 | rich.print(tree, file=file) 79 | 80 | 81 | @rank_zero_only 82 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 83 | """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 84 | 85 | if not cfg.get("tags"): 86 | if "id" in HydraConfig().cfg.hydra.job: 87 | raise ValueError("Specify tags before launching a multirun!") 88 | 89 | log.warning("No tags provided in config. Prompting user to input tags...") 90 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 91 | tags = [t.strip() for t in tags.split(",") if t != ""] 92 | 93 | with open_dict(cfg): 94 | cfg.tags = tags 95 | 96 | log.info(f"Tags: {cfg.tags}") 97 | 98 | if save_to_file: 99 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 100 | rich.print(cfg.tags, file=file) 101 | -------------------------------------------------------------------------------- /fish_speech/utils/spectrogram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio.functional as F 3 | from torch import Tensor, nn 4 | from torchaudio.transforms import MelScale 5 | 6 | 7 | class LinearSpectrogram(nn.Module): 8 | def __init__( 9 | self, 10 | n_fft=2048, 11 | win_length=2048, 12 | hop_length=512, 13 | center=False, 14 | mode="pow2_sqrt", 15 | ): 16 | super().__init__() 17 | 18 | self.n_fft = n_fft 19 | self.win_length = win_length 20 | self.hop_length = hop_length 21 | self.center = center 22 | self.mode = mode 23 | 24 | self.register_buffer("window", torch.hann_window(win_length), persistent=False) 25 | 26 | def forward(self, y: Tensor) -> Tensor: 27 | if y.ndim == 3: 28 | y = y.squeeze(1) 29 | 30 | y = torch.nn.functional.pad( 31 | y.unsqueeze(1), 32 | ( 33 | (self.win_length - self.hop_length) // 2, 34 | (self.win_length - self.hop_length + 1) // 2, 35 | ), 36 | mode="reflect", 37 | ).squeeze(1) 38 | 39 | spec = torch.stft( 40 | y, 41 | self.n_fft, 42 | hop_length=self.hop_length, 43 | win_length=self.win_length, 44 | window=self.window, 45 | center=self.center, 46 | pad_mode="reflect", 47 | normalized=False, 48 | onesided=True, 49 | return_complex=True, 50 | ) 51 | 52 | spec = torch.view_as_real(spec) 53 | 54 | if self.mode == "pow2_sqrt": 55 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 56 | 57 | return spec 58 | 59 | 60 | class LogMelSpectrogram(nn.Module): 61 | def __init__( 62 | self, 63 | sample_rate=44100, 64 | n_fft=2048, 65 | win_length=2048, 66 | hop_length=512, 67 | n_mels=128, 68 | center=False, 69 | f_min=0.0, 70 | f_max=None, 71 | ): 72 | super().__init__() 73 | 74 | self.sample_rate = sample_rate 75 | self.n_fft = n_fft 76 | self.win_length = win_length 77 | self.hop_length = hop_length 78 | self.center = center 79 | self.n_mels = 
n_mels 80 | self.f_min = f_min 81 | self.f_max = f_max or float(sample_rate // 2) 82 | 83 | self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) 84 | 85 | fb = F.melscale_fbanks( 86 | n_freqs=self.n_fft // 2 + 1, 87 | f_min=self.f_min, 88 | f_max=self.f_max, 89 | n_mels=self.n_mels, 90 | sample_rate=self.sample_rate, 91 | norm="slaney", 92 | mel_scale="slaney", 93 | ) 94 | self.register_buffer( 95 | "fb", 96 | fb, 97 | persistent=False, 98 | ) 99 | 100 | def compress(self, x: Tensor) -> Tensor: 101 | return torch.log(torch.clamp(x, min=1e-5)) 102 | 103 | def decompress(self, x: Tensor) -> Tensor: 104 | return torch.exp(x) 105 | 106 | def apply_mel_scale(self, x: Tensor) -> Tensor: 107 | return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) 108 | 109 | def forward( 110 | self, x: Tensor, return_linear: bool = False, sample_rate: int = None 111 | ) -> Tensor: 112 | if sample_rate is not None and sample_rate != self.sample_rate: 113 | x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) 114 | 115 | linear = self.spectrogram(x) 116 | x = self.apply_mel_scale(linear) 117 | x = self.compress(x) 118 | 119 | if return_linear: 120 | return x, self.compress(linear) 121 | 122 | return x 123 | -------------------------------------------------------------------------------- /fish_speech/utils/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from importlib.util import find_spec 3 | from typing import Callable 4 | 5 | from omegaconf import DictConfig 6 | 7 | from .logger import RankedLogger 8 | from .rich_utils import enforce_tags, print_config_tree 9 | 10 | log = RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def extras(cfg: DictConfig) -> None: 14 | """Applies optional utilities before the task is started. 15 | 16 | Utilities: 17 | - Ignoring python warnings 18 | - Setting tags from command line 19 | - Rich config printing 20 | """ 21 | 22 | # return if no `extras` config 23 | if not cfg.get("extras"): 24 | log.warning("Extras config not found! ") 25 | return 26 | 27 | # disable python warnings 28 | if cfg.extras.get("ignore_warnings"): 29 | log.info("Disabling python warnings! ") 30 | warnings.filterwarnings("ignore") 31 | 32 | # prompt user to input tags from command line if none are provided in the config 33 | if cfg.extras.get("enforce_tags"): 34 | log.info("Enforcing tags! ") 35 | enforce_tags(cfg, save_to_file=True) 36 | 37 | # pretty print config tree using Rich library 38 | if cfg.extras.get("print_config"): 39 | log.info("Printing config tree with Rich! ") 40 | print_config_tree(cfg, resolve=True, save_to_file=True) 41 | 42 | 43 | def task_wrapper(task_func: Callable) -> Callable: 44 | """Optional decorator that controls the failure behavior when executing the task function. 45 | 46 | This wrapper can be used to: 47 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 48 | - save the exception to a `.log` file 49 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 50 | - etc. (adjust depending on your needs) 51 | 52 | Example: 53 | ``` 54 | @utils.task_wrapper 55 | def train(cfg: DictConfig) -> Tuple[dict, dict]: 56 | 57 | ... 
58 | 59 | return metric_dict, object_dict 60 | ``` 61 | """ # noqa: E501 62 | 63 | def wrap(cfg: DictConfig): 64 | # execute the task 65 | try: 66 | metric_dict, object_dict = task_func(cfg=cfg) 67 | 68 | # things to do if exception occurs 69 | except Exception as ex: 70 | # save exception to `.log` file 71 | log.exception("") 72 | 73 | # some hyperparameter combinations might be invalid or 74 | # cause out-of-memory errors so when using hparam search 75 | # plugins like Optuna, you might want to disable 76 | # raising the below exception to avoid multirun failure 77 | raise ex 78 | 79 | # things to always do after either success or exception 80 | finally: 81 | # display output dir path in terminal 82 | log.info(f"Output dir: {cfg.paths.run_dir}") 83 | 84 | # always close wandb run (even if exception occurs so multirun won't fail) 85 | if find_spec("wandb"): # check if wandb is installed 86 | import wandb 87 | 88 | if wandb.run: 89 | log.info("Closing wandb!") 90 | wandb.finish() 91 | 92 | return metric_dict, object_dict 93 | 94 | return wrap 95 | 96 | 97 | def get_metric_value(metric_dict: dict, metric_name: str) -> float: 98 | """Safely retrieves value of the metric logged in LightningModule.""" 99 | 100 | if not metric_name: 101 | log.info("Metric name is None! Skipping metric value retrieval...") 102 | return None 103 | 104 | if metric_name not in metric_dict: 105 | raise Exception( 106 | f"Metric value not found! \n" 107 | "Make sure metric name logged in LightningModule is correct!\n" 108 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 109 | ) 110 | 111 | metric_value = metric_dict[metric_name].item() 112 | log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") 113 | 114 | return metric_value 115 | -------------------------------------------------------------------------------- /fish_speech/webui/css/style.css: -------------------------------------------------------------------------------- 1 | :root { 2 | --my-200: #80eeee; 3 | --my-50: #ecfdf5; 4 | --water-width: 300px; 5 | --water-heigh: 300px; 6 | } 7 | 8 | 9 | /* general styled components */ 10 | .tools { 11 | align-items: center; 12 | justify-content: center; 13 | } 14 | 15 | .gradio-button { 16 | max-width: 2.2em; 17 | min-width: 2.2em !important; 18 | height: 2.4em; 19 | align-self: end; 20 | line-height: 1em; 21 | border-radius: 0.5em; 22 | 23 | } 24 | 25 | .gradio-button.secondary-down, .gradio-button.secondary-down:hover{ 26 | box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; 27 | } 28 | 29 | /* replace original footer with ours */ 30 | a{ 31 | font-weight: bold; 32 | cursor: pointer; 33 | color: #030C14 !important; 34 | } 35 | 36 | footer { 37 | display: none !important; 38 | } 39 | 40 | #footer{ 41 | text-align: center; 42 | } 43 | 44 | #footer div{ 45 | display: inline-block; 46 | } 47 | 48 | #footer .versions{ 49 | font-size: 85%; 50 | opacity: 0.85; 51 | } 52 | 53 | /*@keyframes moveBackground {*/ 54 | /* 0% {*/ 55 | /* background-position: 0 0;*/ 56 | /* }*/ 57 | /* 100% {*/ 58 | /* background-position: -100px 100px;*/ 59 | /* }*/ 60 | /*}*/ 61 | @keyframes moveJellyBackground { 62 | 0% { 63 | background-position: 0% 50%; 64 | } 65 | 50% { 66 | background-position: 100% 50%; 67 | } 68 | 100% { 69 | background-position: 0% 50%; 70 | } 71 | } 72 | 73 | .gradio-container { 74 | position: absolute; 75 | z-index: 10; 76 | } 77 | 78 | 79 | .quan { 80 | position: absolute; 81 | bottom: 0; 82 | width: var(--water-width); 83 | height: 
var(--water-heigh); 84 | border-radius: 0; 85 | /*border: 3px solid rgb(246, 247, 248);*/ 86 | /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/ 87 | z-index: 0; 88 | 89 | } 90 | 91 | .quan:last-child { 92 | margin-right: 0; 93 | } 94 | 95 | .shui { 96 | position: absolute; 97 | top: 0; 98 | left: 0; 99 | width: 100%; 100 | height: 100%; 101 | background-color: rgb(23, 106, 201); 102 | border-radius: 0; 103 | overflow: hidden; 104 | z-index: 0; 105 | } 106 | 107 | .shui::after { 108 | 109 | content: ''; 110 | position: absolute; 111 | top: 20%; 112 | left: 50%; 113 | width: 150%; 114 | height: 150%; 115 | border-radius: 40%; 116 | background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%); 117 | animation: shi 5s linear infinite; 118 | } 119 | 120 | @keyframes shi { 121 | 0% { 122 | transform: translate(-50%, -65%) rotate(0deg); 123 | } 124 | 100% { 125 | transform: translate(-50%, -65%) rotate(360deg); 126 | } 127 | } 128 | 129 | .shui::before { 130 | content: ''; 131 | position: absolute; 132 | top: 20%; 133 | left: 50%; 134 | width: 150%; 135 | height: 150%; 136 | border-radius: 42%; 137 | background-color: rgb(240, 228, 228, 0.2); 138 | animation: xu 7s linear infinite; 139 | } 140 | 141 | @keyframes xu { 142 | 0% { 143 | transform: translate(-50%, -60%) rotate(0deg); 144 | } 145 | 100% { 146 | transform: translate(-50%, -60%) rotate(360deg); 147 | } 148 | } 149 | 150 | fieldset.data_src div.wrap label { 151 | background: #f8bffee0 !important; 152 | } 153 | 154 | .scrollable-component { 155 | max-height: 100px; 156 | overflow-y: auto; 157 | } 158 | 159 | #file_accordion { 160 | max-height: 220px !important; 161 | } 162 | -------------------------------------------------------------------------------- /fish_speech/webui/html/footer.html: -------------------------------------------------------------------------------- 1 |
2 | API 3 |  •  4 | Github 5 |  •  6 | Gradio 7 |
8 |
9 |
10 | {versions} 11 |
12 | -------------------------------------------------------------------------------- /fish_speech/webui/js/animate.js: -------------------------------------------------------------------------------- 1 | 2 | function createGradioAnimation() { 3 | const params = new URLSearchParams(window.location.search); 4 | if (!params.has('__theme')) { 5 | params.set('__theme', 'light'); 6 | window.location.search = params.toString(); 7 | } 8 | 9 | var gradioApp = document.querySelector('gradio-app'); 10 | if (gradioApp) { 11 | 12 | document.documentElement.style.setProperty('--my-200', '#80eeee'); 13 | document.documentElement.style.setProperty('--my-50', '#ecfdf5'); 14 | 15 | // gradioApp.style.position = 'relative'; 16 | // gradioApp.style.backgroundSize = '200% 200%'; 17 | // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite'; 18 | // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)'; 19 | // gradioApp.style.display = 'flex'; 20 | // gradioApp.style.justifyContent = 'flex-start'; 21 | // gradioApp.style.flexWrap = 'nowrap'; 22 | // gradioApp.style.overflowX = 'auto'; 23 | 24 | // for (let i = 0; i < 6; i++) { 25 | // var quan = document.createElement('div'); 26 | // quan.className = 'quan'; 27 | // gradioApp.insertBefore(quan, gradioApp.firstChild); 28 | // quan.id = 'quan' + i.toString(); 29 | // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')'; 30 | // var quanContainer = document.querySelector('.quan'); 31 | // if (quanContainer) { 32 | // var shui = document.createElement('div'); 33 | // shui.className = 'shui'; 34 | // quanContainer.insertBefore(shui, quanContainer.firstChild) 35 | // } 36 | // } 37 | } 38 | 39 | var container = document.createElement('div'); 40 | container.id = 'gradio-animation'; 41 | container.style.fontSize = '2em'; 42 | container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace'; 43 | container.style.fontWeight = 'bold'; 44 | container.style.textAlign = 'center'; 45 | container.style.marginBottom = '20px'; 46 | 47 | var text = 'Welcome to Fish-Speech!'; 48 | for (var i = 0; i < text.length; i++) { 49 | (function(i){ 50 | setTimeout(function(){ 51 | var letter = document.createElement('span'); 52 | letter.style.opacity = '0'; 53 | letter.style.transition = 'opacity 0.5s'; 54 | letter.innerText = text[i]; 55 | 56 | container.appendChild(letter); 57 | 58 | setTimeout(function() { 59 | letter.style.opacity = '1'; 60 | }, 50); 61 | }, i * 200); 62 | })(i); 63 | } 64 | 65 | var gradioContainer = document.querySelector('.gradio-container'); 66 | gradioContainer.insertBefore(container, gradioContainer.firstChild); 67 | 68 | return 'Animation created'; 69 | } 70 | -------------------------------------------------------------------------------- /fish_speech/webui/launch_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import os 3 | import subprocess 4 | import sys 5 | from functools import lru_cache 6 | from pathlib import Path 7 | from typing import Iterable 8 | 9 | import gradio as gr 10 | from gradio.themes.base import Base 11 | from gradio.themes.utils import colors, fonts, sizes 12 | 13 | GIT = ( 14 | (Path(os.environ.get("GIT_HOME", "")) / "git").resolve() 15 | if sys.platform == "win32" 16 | else "git" 17 | ) 18 | GIT = str(GIT) 19 | 20 | 21 | def is_module_installed(module_name: str) -> bool: 22 | spec = importlib.util.find_spec(module_name) 23 | return spec is not None 24 | 25 | 26 | @lru_cache() 27 | def 
commit_hash(): 28 | try: 29 | return subprocess.check_output( 30 | [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8" 31 | ).strip() 32 | except Exception: 33 | return "" 34 | 35 | 36 | def versions_html(): 37 | import torch 38 | 39 | python_version = ".".join([str(x) for x in sys.version_info[0:3]]) 40 | commit = commit_hash() 41 | hash = commit.strip("'").split(" ")[0] 42 | 43 | return f""" 44 | version: {hash} 45 |  •  46 | python: {python_version} 47 |  •  48 | torch: {getattr(torch, '__long_version__',torch.__version__)} 49 |  •  50 | gradio: {gr.__version__} 51 |  •  52 | author: fishaudio 53 | """ 54 | 55 | 56 | def version_check(commit): 57 | try: 58 | import requests 59 | 60 | commits = requests.get( 61 | "https://api.github.com/repos/fishaudio/fish-speech/branches/main" 62 | ).json() 63 | if commit != "" and commits["commit"]["sha"] != commit: 64 | print("--------------------------------------------------------") 65 | print("| You are not up to date with the most recent release. |") 66 | print("| Consider running `git pull` to update. |") 67 | print("--------------------------------------------------------") 68 | elif commits["commit"]["sha"] == commit: 69 | print("You are up to date with the most recent release.") 70 | else: 71 | print("Not a git clone, can't perform version check.") 72 | except Exception as e: 73 | print("version check failed", e) 74 | 75 | 76 | class Seafoam(Base): 77 | def __init__( 78 | self, 79 | *, 80 | primary_hue: colors.Color | str = colors.emerald, 81 | secondary_hue: colors.Color | str = colors.blue, 82 | neutral_hue: colors.Color | str = colors.blue, 83 | spacing_size: sizes.Size | str = sizes.spacing_md, 84 | radius_size: sizes.Size | str = sizes.radius_md, 85 | text_size: sizes.Size | str = sizes.text_lg, 86 | font: fonts.Font | str | Iterable[fonts.Font | str] = ( 87 | fonts.GoogleFont("Quicksand"), 88 | "ui-sans-serif", 89 | "sans-serif", 90 | ), 91 | font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( 92 | fonts.GoogleFont("IBM Plex Mono"), 93 | "ui-monospace", 94 | "monospace", 95 | ), 96 | ): 97 | super().__init__( 98 | primary_hue=primary_hue, 99 | secondary_hue=secondary_hue, 100 | neutral_hue=neutral_hue, 101 | spacing_size=spacing_size, 102 | radius_size=radius_size, 103 | text_size=text_size, 104 | font=font, 105 | font_mono=font_mono, 106 | ) 107 | super().set( 108 | button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)", 109 | button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)", 110 | button_primary_text_color="white", 111 | button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)", 112 | slider_color="*secondary_300", 113 | slider_color_dark="*secondary_600", 114 | block_title_text_weight="600", 115 | block_border_width="3px", 116 | block_shadow="*shadow_drop_lg", 117 | button_shadow="*shadow_drop_lg", 118 | button_small_padding="0px", 119 | button_large_padding="3px", 120 | ) 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.35.2 2 | datasets>=2.14.5 3 | lightning>=2.1.0 4 | hydra-core>=1.3.2 5 | tensorboard>=2.14.1 6 | natsort>=8.4.0 7 | einops>=0.7.0 8 | librosa>=0.10.1 9 | rich>=13.5.3 10 | wandb>=0.15.11 11 | grpcio>=1.58.0 12 | kui>=1.6.0 13 | zibai-server>=0.9.0 14 | loguru>=0.6.0 15 | loralib>=0.1.2 16 | natsort>=8.4.0 17 | pyrootutils>=1.0.4 18 
| vector_quantize_pytorch>=1.14.7 19 | resampy>=0.4.3 20 | einx[torch]==0.2.2 21 | srt 22 | pydub 23 | audiotsm -------------------------------------------------------------------------------- /tools/extract_model.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | from loguru import logger 4 | 5 | 6 | @click.command() 7 | @click.argument("model_path") 8 | @click.argument("output_path") 9 | def main(model_path, output_path): 10 | if model_path == output_path: 11 | logger.error("Model path and output path are the same") 12 | return 13 | 14 | logger.info(f"Loading model from {model_path}") 15 | state_dict = torch.load(model_path, map_location="cpu")["state_dict"] 16 | torch.save(state_dict, output_path) 17 | logger.info(f"Model saved to {output_path}") 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /tools/llama/build_dataset.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import re 4 | from collections import defaultdict 5 | from functools import partial 6 | from multiprocessing import Pool 7 | from pathlib import Path 8 | 9 | import click 10 | import numpy as np 11 | from loguru import logger 12 | from tqdm import tqdm 13 | 14 | from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData 15 | from fish_speech.datasets.protos.text_data_stream import pack_pb_stream 16 | from fish_speech.utils.file import load_filelist 17 | 18 | # To avoid CPU overload 19 | os.environ["MKL_NUM_THREADS"] = "1" 20 | os.environ["OMP_NUM_THREADS"] = "1" 21 | 22 | 23 | def task_generator_folder(root: Path, text_extension: str): 24 | files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}")) 25 | files = sorted(files) 26 | 27 | grouped_files = defaultdict(list) 28 | for file in tqdm(files, desc=f"Grouping {root}"): 29 | p = str(file.parent) 30 | speaker = file.parent.name 31 | 32 | try: 33 | if isinstance(text_extension, str): 34 | texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")] 35 | else: 36 | texts = [ 37 | file.with_suffix(ext).read_text(encoding="utf-8") 38 | for ext in text_extension 39 | ] 40 | except Exception as e: 41 | logger.error(f"Failed to read text {file}: {e}") 42 | continue 43 | 44 | grouped_files[p].append((speaker, file, texts)) 45 | 46 | logger.info( 47 | f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..." 
48 | ) 49 | 50 | for i in grouped_files.values(): 51 | subset = [(f, t) for _, f, t in i] 52 | yield i[0][0], subset, "folder" 53 | 54 | 55 | def task_generator_filelist(filelist): 56 | grouped_files = defaultdict(list) 57 | for filename, speaker, _, text in load_filelist(filelist): 58 | grouped_files[speaker].append((Path(filename), [text])) 59 | 60 | logger.info(f"Found {len(grouped_files)} groups in {filelist}") 61 | for speaker, values in grouped_files.items(): 62 | yield speaker, values, "filelist" 63 | 64 | 65 | def run_task(task): 66 | name, subset, source = task 67 | 68 | # Parse the files 69 | sentences = [] 70 | for file, texts in subset: 71 | np_file = file.with_suffix(".npy") 72 | if np_file.exists() is False: 73 | logger.warning(f"Can't find {np_file}") 74 | continue 75 | 76 | new_texts = [] 77 | 78 | for text in texts: 79 | # Simple cleaning: replace { xxx } and < xxx > with space 80 | text = re.sub(r"\{.*?\}", " ", text) 81 | text = re.sub(r"<.*?>", " ", text) 82 | text = re.sub(r"\s+", " ", text) 83 | new_texts.append(text) 84 | 85 | try: 86 | semantics = np.load(np_file) 87 | except Exception as e: 88 | logger.error(f"Failed to parse {file}: {e}") 89 | continue 90 | 91 | if isinstance(semantics, np.ndarray): 92 | semantics = semantics.tolist() 93 | 94 | sentences.append( 95 | Sentence( 96 | texts=new_texts, 97 | semantics=[Semantics(values=s) for s in semantics], 98 | ) 99 | ) 100 | 101 | # Pack the sentences 102 | return pack_pb_stream( 103 | TextData( 104 | source=source, 105 | name=name, 106 | sentences=sentences, 107 | ) 108 | ) 109 | 110 | 111 | @click.command() 112 | @click.option( 113 | "--input", 114 | type=click.Path(path_type=Path), 115 | required=True, 116 | help="A folder containing the dataset or a filelist", 117 | multiple=True, 118 | ) 119 | @click.option( 120 | "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft" 121 | ) 122 | @click.option("--num-workers", type=int, default=16) 123 | @click.option("--text-extension", type=str, default=[".txt"], multiple=True) 124 | @click.option( 125 | "--shard-size", type=int, default=10, help="The maximum size of each shard in mb" 126 | ) 127 | def main(input, output, num_workers, text_extension, shard_size): 128 | generator_fns = [] 129 | 130 | for f in input: 131 | assert f.exists(), f"{f} not found" 132 | 133 | if f.is_dir(): 134 | generator_fn = task_generator_folder(f, text_extension) 135 | else: 136 | generator_fn = task_generator_filelist(f) 137 | 138 | generator_fns.append(generator_fn) 139 | 140 | generator_fn = itertools.chain(*generator_fns) 141 | output.mkdir(parents=True, exist_ok=True) 142 | 143 | dataset_fp = None 144 | tar_idx = 0 145 | written_size = 0 146 | 147 | with Pool(num_workers) as p: 148 | for result in tqdm(p.imap_unordered(run_task, generator_fn)): 149 | if dataset_fp is None: 150 | dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb") 151 | 152 | dataset_fp.write(result) 153 | written_size += len(result) 154 | 155 | if written_size > shard_size * 1024 * 1024: 156 | logger.info(f"Finished writing {tar_idx} shards to {output}") 157 | dataset_fp.close() 158 | dataset_fp = None 159 | written_size = 0 160 | tar_idx += 1 161 | 162 | if dataset_fp is not None: 163 | dataset_fp.close() 164 | 165 | logger.info(f"Finished writing {tar_idx + 1} shards to {output}") 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /tools/llama/merge_lora.py: 
-------------------------------------------------------------------------------- 1 | import click 2 | import hydra 3 | import torch 4 | from hydra import compose, initialize 5 | from hydra.utils import instantiate 6 | from loguru import logger 7 | 8 | from fish_speech.models.text2semantic.lora_utils import ( 9 | get_merged_state_dict, 10 | setup_lora, 11 | ) 12 | 13 | 14 | @click.command() 15 | @click.option("--llama-config", type=str, default="dual_ar_2_codebook_large") 16 | @click.option("--lora-config", type=str, default="r_8_alpha_16") 17 | @click.option( 18 | "--llama-weight", type=str, default="checkpoints/text2semantic-sft-medium-v1-4k.pth" 19 | ) 20 | @click.option("--lora-weight", type=str, required=True) 21 | @click.option("--output", type=str, required=True) 22 | def merge(llama_config, lora_config, llama_weight, lora_weight, output): 23 | logger.info( 24 | f"Merging {llama_weight} and {lora_weight} into {output} with configs {llama_config} and {lora_config}" 25 | ) 26 | 27 | hydra.core.global_hydra.GlobalHydra.instance().clear() 28 | with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"): 29 | # The max_seq_len here doesn't matter. 30 | cfg = compose(config_name=llama_config, overrides=[f"config.max_seq_len=2048"]) 31 | 32 | llama_model = instantiate(cfg) 33 | logger.info(f"Loaded llama model with config {llama_config}") 34 | 35 | hydra.core.global_hydra.GlobalHydra.instance().clear() 36 | with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"): 37 | cfg = compose(config_name=lora_config) 38 | 39 | lora_config = instantiate(cfg) 40 | logger.info(f"Loaded lora model with config {lora_config}") 41 | 42 | setup_lora(llama_model, lora_config) 43 | logger.info(f"Merged model setup complete") 44 | 45 | llama_state_dict = torch.load(llama_weight, map_location="cpu") 46 | lora_state_dict = torch.load(lora_weight, map_location="cpu") 47 | 48 | if "state_dict" in llama_state_dict: 49 | llama_state_dict = llama_state_dict["state_dict"] 50 | 51 | if "state_dict" in lora_state_dict: 52 | lora_state_dict = lora_state_dict["state_dict"] 53 | 54 | # remove prefix model. 
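# The Lightning checkpoints used here store parameters under a "model." prefix;
# a key such as "model.layers.0.attention.wqkv.weight" (name illustrative) becomes
# "layers.0.attention.wqkv.weight" after the stripping below, and keys without the
# prefix (optimizer state, etc.) are dropped. The `llama_state_dict | lora_state_dict`
# union further down lets LoRA entries override base weights that share a key.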
55 | llama_state_dict = { 56 | k.replace("model.", ""): v 57 | for k, v in llama_state_dict.items() 58 | if k.startswith("model.") 59 | } 60 | lora_state_dict = { 61 | k.replace("model.", ""): v 62 | for k, v in lora_state_dict.items() 63 | if k.startswith("model.") 64 | } 65 | 66 | logger.info(f"Found {len(llama_state_dict)} keys in llama model") 67 | logger.info(f"Found {len(lora_state_dict)} keys in lora model") 68 | 69 | merged_state_dict = llama_state_dict | lora_state_dict 70 | llama_model.load_state_dict(merged_state_dict, strict=True) 71 | logger.info(f"Merged model loaded") 72 | 73 | state_dict = get_merged_state_dict(llama_model) 74 | torch.save(state_dict, output) 75 | logger.info(f"Merged model saved to {output}") 76 | 77 | 78 | if __name__ == "__main__": 79 | merge() 80 | -------------------------------------------------------------------------------- /tools/llama/rebuild_tokenizer.py: -------------------------------------------------------------------------------- 1 | from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers 2 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 3 | 4 | # Initialize a tokenizer 5 | tokenizer = Tokenizer(models.BPE()) 6 | 7 | # Customize pre-tokenization and decoding 8 | tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) 9 | tokenizer.decoder = decoders.ByteLevel() 10 | tokenizer.post_processor = processors.ByteLevel(trim_offsets=False) 11 | 12 | # Don't train the tokenizer 13 | trainer = trainers.BpeTrainer( 14 | vocab_size=0, 15 | min_frequency=2, 16 | initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), 17 | special_tokens=[ 18 | "<|begin_of_sequence|>", 19 | "<|end_of_sequence|>", 20 | "<|im_start|>", 21 | "<|im_sep|>", # system, user, assistant, etc. 22 | "<|im_end|>", 23 | "<|semantic|>", # audio features 24 | "<|pad|>", 25 | ], 26 | ) 27 | 28 | # <|im_start|>user<|im_sep|>...<|im_end|> 29 | # <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|> 30 | tokenizer.train_from_iterator([], trainer=trainer) 31 | 32 | print(len(tokenizer.get_vocab())) 33 | x = tokenizer.encode( 34 | "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>" 35 | ).ids 36 | print(x, len(x)) 37 | print(tokenizer.decode(x, skip_special_tokens=True)) 38 | 39 | 40 | tokenizer = PreTrainedTokenizerFast( 41 | tokenizer_object=tokenizer, 42 | pad_token="<|pad|>", 43 | bos_token="<|begin_of_sequence|>", 44 | eos_token="<|end_of_sequence|>", 45 | ) 46 | 47 | # Try tokenizing a new sequence 48 | sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 
测试中文, 你好世界 🈶<|semantic|>" 49 | encoded = tokenizer(sequence).input_ids 50 | 51 | print("Test encoding....") 52 | print(f"\tSentence: {sequence}") 53 | print(f"\tEncoded: {encoded}") 54 | print(f"\tDecoded: {tokenizer.batch_decode(encoded)}") 55 | print(f"\tDecoded: {tokenizer.decode(encoded)}") 56 | 57 | tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True) 58 | -------------------------------------------------------------------------------- /tools/merge_asr_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from pydub import AudioSegment 5 | from tqdm import tqdm 6 | 7 | from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files 8 | 9 | 10 | def merge_and_delete_files(save_dir, original_files): 11 | save_path = Path(save_dir) 12 | audio_slice_files = list_files( 13 | path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True 14 | ) 15 | audio_files = {} 16 | label_files = {} 17 | for file_path in tqdm(audio_slice_files, desc="Merging audio files"): 18 | rel_path = Path(file_path).relative_to(save_path) 19 | (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) 20 | if file_path.suffix == ".wav": 21 | prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0] 22 | if prefix == rel_path.parent / file_path.stem: 23 | continue 24 | audio = AudioSegment.from_wav(file_path) 25 | if prefix in audio_files.keys(): 26 | audio_files[prefix] = audio_files[prefix] + audio 27 | else: 28 | audio_files[prefix] = audio 29 | 30 | elif file_path.suffix == ".lab": 31 | prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0] 32 | if prefix == rel_path.parent / file_path.stem: 33 | continue 34 | with open(file_path, "r", encoding="utf-8") as f: 35 | label = f.read() 36 | if prefix in label_files.keys(): 37 | label_files[prefix] = label_files[prefix] + ", " + label 38 | else: 39 | label_files[prefix] = label 40 | 41 | for prefix, audio in audio_files.items(): 42 | output_audio_path = save_path / f"{prefix}.wav" 43 | audio.export(output_audio_path, format="wav") 44 | 45 | for prefix, label in label_files.items(): 46 | output_label_path = save_path / f"{prefix}.lab" 47 | with open(output_label_path, "w", encoding="utf-8") as f: 48 | f.write(label) 49 | 50 | for file_path in original_files: 51 | os.remove(file_path) 52 | 53 | 54 | if __name__ == "__main__": 55 | merge_and_delete_files("/made/by/spicysama/laziman", [__file__]) 56 | -------------------------------------------------------------------------------- /tools/vqgan/create_train_split.py: -------------------------------------------------------------------------------- 1 | import math 2 | from pathlib import Path 3 | from random import Random 4 | 5 | import click 6 | from loguru import logger 7 | from tqdm import tqdm 8 | 9 | from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist 10 | 11 | 12 | @click.command() 13 | @click.argument("root", type=click.Path(exists=True, path_type=Path)) 14 | @click.option("--val-ratio", type=float, default=None) 15 | @click.option("--val-count", type=int, default=None) 16 | @click.option("--filelist", default=None, type=Path) 17 | def main(root, val_ratio, val_count, filelist): 18 | if filelist: 19 | files = [i[0] for i in load_filelist(filelist)] 20 | else: 21 | files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True) 22 | 23 | logger.info(f"Found {len(files)} files") 24 | files = [str(file.relative_to(root)) for file in tqdm(files)] 25 | 26 | 
Random(42).shuffle(files) 27 |  28 |     if val_count is None and val_ratio is None: 29 |         logger.info("Validation ratio and count not specified, using min(20%, 100)") 30 |         val_size = min(100, math.ceil(len(files) * 0.2)) 31 |     elif val_count is not None and val_ratio is not None: 32 |         logger.error("Cannot specify both val_count and val_ratio") 33 |         return 34 |     elif val_count is not None: 35 |         if val_count < 1 or val_count > len(files): 36 |             logger.error("val_count must be between 1 and number of files") 37 |             return 38 |         val_size = val_count 39 |     else: 40 |         val_size = math.ceil(len(files) * val_ratio) 41 |  42 |     logger.info(f"Using {val_size} files for validation") 43 |  44 |     with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f: 45 |         f.write("\n".join(files[val_size:])) 46 |  47 |     with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f: 48 |         f.write("\n".join(files[:val_size])) 49 |  50 |     logger.info("Done") 51 |  52 |  53 | if __name__ == "__main__": 54 |     main() 55 | -------------------------------------------------------------------------------- /tools/vqgan/extract_vq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess as sp 3 | import sys 4 | import time 5 | from datetime import timedelta 6 | from functools import lru_cache 7 | from pathlib import Path 8 | from random import Random 9 |  10 | import click 11 | import numpy as np 12 | import torch 13 | import torchaudio 14 | from hydra import compose, initialize 15 | from hydra.utils import instantiate 16 | from lightning import LightningModule 17 | from loguru import logger 18 | from omegaconf import OmegaConf 19 |  20 | from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist 21 |  22 | # register eval resolver 23 | OmegaConf.register_new_resolver("eval", eval) 24 | # This file is used to extract VQ (quantized semantic) indices from audio files. 25 | # The resulting .npy files are mainly used as training data for the text2semantic model.
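# Example invocation (paths are illustrative; the argument and options are defined
# on main() at the bottom of this file):
#   python tools/vqgan/extract_vq.py data/demo --num-workers 2 --batch-size 16 \
#       --config-name vqgan_pretrain --checkpoint-path checkpoints/vq-gan-group-fsq-2x1024.pth
# Each audio file gets a sibling .npy file containing its VQ indices; with
# --num-workers > 1 the script re-launches itself and shards the files across the
# visible GPUs via SLURM-style environment variables.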
26 | 27 | 28 | RANK = int(os.environ.get("SLURM_PROCID", 0)) 29 | WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1)) 30 | 31 | logger_format = ( 32 | "{time:YYYY-MM-DD HH:mm:ss.SSS} | " 33 | "{level: <8} | " 34 | "{name}:{function}:{line} | " 35 | "{extra[rank]} - {message}" 36 | ) 37 | logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"}) 38 | logger.remove() 39 | logger.add(sys.stderr, format=logger_format) 40 | 41 | 42 | @lru_cache(maxsize=1) 43 | def get_model( 44 | config_name: str = "vqgan_pretrain", 45 | checkpoint_path: str = "checkpoints/vq-gan-group-fsq-2x1024.pth", 46 | ): 47 | with initialize(version_base="1.3", config_path="../../fish_speech/configs"): 48 | cfg = compose(config_name=config_name) 49 | 50 | model: LightningModule = instantiate(cfg.model) 51 | state_dict = torch.load( 52 | checkpoint_path, 53 | map_location=model.device, 54 | ) 55 | if "state_dict" in state_dict: 56 | state_dict = state_dict["state_dict"] 57 | 58 | model.load_state_dict(state_dict, strict=False) 59 | model.eval() 60 | model.cuda() 61 | 62 | logger.info(f"Loaded model") 63 | return model 64 | 65 | 66 | @torch.inference_mode() 67 | def process_batch(files: list[Path], model) -> float: 68 | wavs = [] 69 | audio_lengths = [] 70 | new_files = [] 71 | max_length = total_time = 0 72 | 73 | for file in files: 74 | try: 75 | wav, sr = torchaudio.load( 76 | str(file), backend="sox" if sys.platform == "linux" else "soundfile" 77 | ) # Need to install libsox-dev 78 | except Exception as e: 79 | logger.error(f"Error reading {file}: {e}") 80 | continue 81 | 82 | if wav.shape[0] > 1: 83 | wav = wav.mean(dim=0, keepdim=True) 84 | 85 | wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0] 86 | total_time += len(wav) / model.sampling_rate 87 | max_length = max(max_length, len(wav)) 88 | 89 | wavs.append(wav) 90 | audio_lengths.append(len(wav)) 91 | new_files.append(file) 92 | 93 | files = new_files 94 | 95 | # Pad to max length 96 | for i, wav in enumerate(wavs): 97 | wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant") 98 | 99 | audios = torch.stack(wavs, dim=0)[:, None] 100 | audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long) 101 | 102 | # Calculate lengths 103 | indices, feature_lengths = model.encode(audios, audio_lengths) 104 | 105 | # Save to disk 106 | outputs = indices.cpu().numpy() 107 | 108 | for file, length, feature, audio_length in zip( 109 | files, feature_lengths, outputs, audio_lengths 110 | ): 111 | feature = feature[:, :length] 112 | 113 | # (T,) 114 | with open(file.with_suffix(".npy"), "wb") as f: 115 | np.save(f, feature) 116 | 117 | return total_time 118 | 119 | 120 | @click.command() 121 | @click.argument("folder") 122 | @click.option("--num-workers", default=1) 123 | @click.option("--config-name", default="vqgan_pretrain") 124 | @click.option( 125 | "--checkpoint-path", 126 | default="checkpoints/vq-gan-group-fsq-2x1024.pth", 127 | ) 128 | @click.option("--batch-size", default=64) 129 | @click.option("--filelist", default=None, type=Path) 130 | def main( 131 | folder: str, 132 | num_workers: int, 133 | config_name: str, 134 | checkpoint_path: str, 135 | batch_size: int, 136 | filelist: Path, 137 | ): 138 | if num_workers > 1 and WORLD_SIZE != num_workers: 139 | assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both" 140 | 141 | logger.info(f"Spawning {num_workers} workers") 142 | 143 | if torch.cuda.is_available(): 144 | visible_devices = 
os.environ.get("CUDA_VISIBLE_DEVICES", None) 145 | if visible_devices is None: 146 | visible_devices = list(range(torch.cuda.device_count())) 147 | else: 148 | visible_devices = visible_devices.split(",") 149 | else: 150 | # Set to empty string to avoid using GPU 151 | visible_devices = [""] 152 | 153 | processes = [] 154 | for i in range(num_workers): 155 | env = os.environ.copy() 156 | env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)]) 157 | env["SLURM_PROCID"] = str(i) 158 | env["SLURM_NTASKS"] = str(num_workers) 159 | 160 | processes.append( 161 | sp.Popen( 162 | [sys.executable] + sys.argv.copy(), 163 | env=env, 164 | ) 165 | ) 166 | 167 | for p in processes: 168 | p.wait() 169 | 170 | logger.info(f"All workers finished") 171 | return 172 | 173 | # This is a worker 174 | logger.info(f"Starting worker") 175 | if filelist: 176 | files = [i[0] for i in load_filelist(filelist)] 177 | else: 178 | files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False) 179 | 180 | print(f"Found {len(files)} files") 181 | files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()] 182 | 183 | total_files = len(files) 184 | files = files[RANK::WORLD_SIZE] 185 | logger.info(f"Processing {len(files)}/{total_files} files") 186 | 187 | # Batch processing 188 | total_time = 0 189 | begin_time = time.time() 190 | processed_files = 0 191 | model = get_model(config_name, checkpoint_path) 192 | 193 | for n_batch, idx in enumerate(range(0, len(files), batch_size)): 194 | batch = files[idx : idx + batch_size] 195 | batch_time = process_batch(batch, model) 196 | 197 | total_time += batch_time 198 | processed_files += len(batch) 199 | 200 | if (n_batch + 1) % 10 == 0: 201 | eta = ( 202 | (time.time() - begin_time) 203 | / processed_files 204 | * (len(files) - processed_files) 205 | ) 206 | logger.info( 207 | f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, " 208 | + f"ETA: {timedelta(seconds=round(eta))}s" 209 | ) 210 | 211 | logger.info( 212 | f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio" 213 | ) 214 | 215 | 216 | if __name__ == "__main__": 217 | main() 218 | -------------------------------------------------------------------------------- /tools/vqgan/inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import click 4 | import hydra 5 | import numpy as np 6 | import soundfile as sf 7 | import torch 8 | import torchaudio 9 | from hydra import compose, initialize 10 | from hydra.utils import instantiate 11 | from lightning import LightningModule 12 | from loguru import logger 13 | from omegaconf import OmegaConf 14 | 15 | from fish_speech.utils.file import AUDIO_EXTENSIONS 16 | 17 | # register eval resolver 18 | OmegaConf.register_new_resolver("eval", eval) 19 | 20 | 21 | def load_model(config_name, checkpoint_path, device="cuda"): 22 | hydra.core.global_hydra.GlobalHydra.instance().clear() 23 | with initialize(version_base="1.3", config_path="../../fish_speech/configs"): 24 | cfg = compose(config_name=config_name) 25 | 26 | model: LightningModule = instantiate(cfg.model) 27 | state_dict = torch.load( 28 | checkpoint_path, 29 | map_location=model.device, 30 | ) 31 | 32 | if "state_dict" in state_dict: 33 | state_dict = state_dict["state_dict"] 34 | 35 | model.load_state_dict(state_dict, strict=False) 36 | model.eval() 37 | model.to(device) 38 | logger.info("Restored model from checkpoint") 39 | 40 | return model 41 | 42 | 
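# Example invocations (paths are illustrative; options are defined on main() below):
#   python tools/vqgan/inference.py -i test.wav -o fake.wav \
#       -cfg vqgan_pretrain -ckpt checkpoints/vq-gan-group-fsq-2x1024.pth
#   python tools/vqgan/inference.py -i codes.npy -o restored.wav
# An audio input is encoded to VQ indices (also written to the output path with a
# .npy suffix) and then decoded back; a .npy input is decoded directly.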
43 | @torch.no_grad() 44 | @click.command() 45 | @click.option( 46 |     "--input-path", 47 |     "-i", 48 |     default="test.wav", 49 |     type=click.Path(exists=True, path_type=Path), 50 | ) 51 | @click.option( 52 |     "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path) 53 | ) 54 | @click.option("--config-name", "-cfg", default="vqgan_pretrain") 55 | @click.option( 56 |     "--checkpoint-path", 57 |     "-ckpt", 58 |     default="checkpoints/vq-gan-group-fsq-2x1024.pth", 59 | ) 60 | @click.option( 61 |     "--device", 62 |     "-d", 63 |     default="cuda", 64 | ) 65 | def main(input_path, output_path, config_name, checkpoint_path, device): 66 |     model = load_model(config_name, checkpoint_path, device=device) 67 |  68 |     if input_path.suffix.lower() in AUDIO_EXTENSIONS: 69 |         logger.info(f"Processing in-place reconstruction of {input_path}") 70 |  71 |         # Load audio 72 |         audio, sr = torchaudio.load(input_path) 73 |         if audio.shape[0] > 1: 74 |             audio = audio.mean(0, keepdim=True) 75 |         audio = torchaudio.functional.resample(audio, sr, model.sampling_rate) 76 |  77 |         audios = audio[None].to(model.device) 78 |         logger.info( 79 |             f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds" 80 |         ) 81 |  82 |         # VQ Encoder 83 |         audio_lengths = torch.tensor( 84 |             [audios.shape[2]], device=model.device, dtype=torch.long 85 |         ) 86 |         indices = model.encode(audios, audio_lengths)[0][0] 87 |  88 |         logger.info(f"Generated indices of shape {indices.shape}") 89 |  90 |         # Save indices 91 |         np.save(output_path.with_suffix(".npy"), indices.cpu().numpy()) 92 |     elif input_path.suffix == ".npy": 93 |         logger.info(f"Processing precomputed indices from {input_path}") 94 |         indices = np.load(input_path) 95 |         indices = torch.from_numpy(indices).to(model.device).long() 96 |         assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}" 97 |     else: 98 |         raise ValueError(f"Unknown input type: {input_path}") 99 |  100 |     # Restore 101 |     feature_lengths = torch.tensor([indices.shape[1]], device=model.device) 102 |     fake_audios = model.decode( 103 |         indices=indices[None], feature_lengths=feature_lengths, return_audios=True 104 |     ) 105 |     audio_time = fake_audios.shape[-1] / model.sampling_rate 106 |  107 |     logger.info( 108 |         f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}" 109 |     ) 110 |  111 |     # Save audio 112 |     fake_audio = fake_audios[0, 0].float().cpu().numpy() 113 |     sf.write(output_path, fake_audio, model.sampling_rate) 114 |     logger.info(f"Saved audio to {output_path}") 115 |  116 |  117 | if __name__ == "__main__": 118 |     main() 119 | -------------------------------------------------------------------------------- /tools/whisper_asr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used to transcribe all audio files in one folder into another folder. 3 | e.g. 4 | Directory structure: 5 | --pre_data_root 6 | ----SP_1 7 | ------01.wav 8 | ------02.wav 9 | ------...... 10 | ----SP_2 11 | ------01.wav 12 | ------02.wav 13 | ------...... 14 | Use 15 | python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1 16 | to transcribe the first speaker. 17 |  18 | Use 19 | python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2 20 | to transcribe the second speaker. 21 |  22 | Note: unless --sample-rate is given, the exported audio slices keep the sample rate of the input files.
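Use
python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1 --sample-rate 44100
to also resample the exported slices to 44.1 kHz; --model-size, --device and
--language (default ZH) are further options defined on main() below.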
23 | """ 24 | from pathlib import Path 25 | 26 | import click 27 | import librosa 28 | import soundfile as sf 29 | import whisper 30 | from loguru import logger 31 | from merge_asr_files import merge_and_delete_files 32 | from tqdm import tqdm 33 | 34 | from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files 35 | 36 | 37 | @click.command() 38 | @click.option("--model-size", default="large", help="Size of the Whisper model") 39 | @click.option("--audio-dir", required=True, help="Directory containing audio files") 40 | @click.option( 41 | "--save-dir", required=True, help="Directory to save processed audio files" 42 | ) 43 | @click.option( 44 | "--sample-rate", 45 | default=None, 46 | type=int, 47 | help="Output sample rate, default to input sample rate", 48 | ) 49 | @click.option("--device", default="cuda", help="Device to use") 50 | @click.option("--language", default="ZH", help="Language of the transcription") 51 | def main(model_size, audio_dir, save_dir, sample_rate, device, language): 52 | logger.info("Loading / Downloading OpenAI Whisper model...") 53 | model = whisper.load_model( 54 | name=model_size, 55 | device=device, 56 | download_root=str(Path(".cache/whisper").resolve()), 57 | ) 58 | logger.info("Model loaded.") 59 | 60 | save_path = Path(save_dir) 61 | save_path.mkdir(parents=True, exist_ok=True) 62 | original_files = [] 63 | audio_files = list_files( 64 | path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True 65 | ) 66 | for file_path in tqdm(audio_files, desc="Processing audio file"): 67 | file_stem = file_path.stem 68 | file_suffix = file_path.suffix 69 | 70 | rel_path = Path(file_path).relative_to(audio_dir) 71 | (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) 72 | 73 | if (save_path / rel_path.parent / f"{rel_path.stem}.wav").exists() and ( 74 | save_path / rel_path.parent / f"{rel_path.stem}.lab" 75 | ).exists(): 76 | continue 77 | 78 | audio, sr = librosa.load(file_path, sr=sample_rate, mono=False) 79 | transcription = model.transcribe(str(file_path), language=language) 80 | 81 | for segment in transcription.get("segments", []): 82 | id, text, start, end = ( 83 | segment["id"], 84 | segment["text"], 85 | segment["start"], 86 | segment["end"], 87 | ) 88 | 89 | extract = audio[..., int(start * sr) : int(end * sr)] 90 | audio_save_path = ( 91 | save_path / rel_path.parent / f"{file_stem}-{id}{file_suffix}" 92 | ) 93 | sf.write( 94 | audio_save_path, 95 | extract, 96 | samplerate=sr, 97 | ) 98 | original_files.append(audio_save_path) 99 | 100 | transcript_save_path = save_path / rel_path.parent / f"{file_stem}-{id}.lab" 101 | with open( 102 | transcript_save_path, 103 | "w", 104 | encoding="utf-8", 105 | ) as f: 106 | f.write(text) 107 | original_files.append(transcript_save_path) 108 | 109 | merge_and_delete_files(save_dir, original_files) 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /web/js/previewAudio.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from '../../../scripts/api.js' 3 | 4 | function fitHeight(node) { 5 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) 6 | node?.graph?.setDirtyCanvas(true); 7 | } 8 | function chainCallback(object, property, callback) { 9 | if (object == undefined) { 10 | //This should not happen. 
11 | console.error("Tried to add callback to non-existant object") 12 | return; 13 | } 14 | if (property in object) { 15 | const callback_orig = object[property] 16 | object[property] = function () { 17 | const r = callback_orig.apply(this, arguments); 18 | callback.apply(this, arguments); 19 | return r 20 | }; 21 | } else { 22 | object[property] = callback; 23 | } 24 | } 25 | 26 | function addPreviewOptions(nodeType) { 27 | chainCallback(nodeType.prototype, "getExtraMenuOptions", function(_, options) { 28 | // The intended way of appending options is returning a list of extra options, 29 | // but this isn't used in widgetInputs.js and would require 30 | // less generalization of chainCallback 31 | let optNew = [] 32 | try { 33 | const previewWidget = this.widgets.find((w) => w.name === "audiopreview"); 34 | 35 | let url = null 36 | if (previewWidget.audioEl?.hidden == false && previewWidget.audioEl.src) { 37 | //Use full quality audio 38 | //url = api.apiURL('/view?' + new URLSearchParams(previewWidget.value.params)); 39 | url = previewWidget.audioEl.src 40 | } 41 | if (url) { 42 | optNew.push( 43 | { 44 | content: "Open preview", 45 | callback: () => { 46 | window.open(url, "_blank") 47 | }, 48 | }, 49 | { 50 | content: "Save preview", 51 | callback: () => { 52 | const a = document.createElement("a"); 53 | a.href = url; 54 | a.setAttribute("download", new URLSearchParams(previewWidget.value.params).get("filename")); 55 | document.body.append(a); 56 | a.click(); 57 | requestAnimationFrame(() => a.remove()); 58 | }, 59 | } 60 | ); 61 | } 62 | if(options.length > 0 && options[0] != null && optNew.length > 0) { 63 | optNew.push(null); 64 | } 65 | options.unshift(...optNew); 66 | 67 | } catch (error) { 68 | console.log(error); 69 | } 70 | 71 | }); 72 | } 73 | function previewAudio(node,file,type){ 74 | var element = document.createElement("div"); 75 | const previewNode = node; 76 | var previewWidget = node.addDOMWidget("audiopreview", "preview", element, { 77 | serialize: false, 78 | hideOnZoom: false, 79 | getValue() { 80 | return element.value; 81 | }, 82 | setValue(v) { 83 | element.value = v; 84 | }, 85 | }); 86 | previewWidget.computeSize = function(width) { 87 | if (this.aspectRatio && !this.parentEl.hidden) { 88 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10; 89 | if (!(height > 0)) { 90 | height = 0; 91 | } 92 | this.computedHeight = height + 10; 93 | return [width, height]; 94 | } 95 | return [width, -4];//no loaded src, widget should not display 96 | } 97 | // element.style['pointer-events'] = "none" 98 | previewWidget.value = {hidden: false, paused: false, params: {}} 99 | previewWidget.parentEl = document.createElement("div"); 100 | previewWidget.parentEl.className = "audio_preview"; 101 | previewWidget.parentEl.style['width'] = "100%" 102 | element.appendChild(previewWidget.parentEl); 103 | previewWidget.audioEl = document.createElement("audio"); 104 | previewWidget.audioEl.controls = true; 105 | previewWidget.audioEl.loop = false; 106 | previewWidget.audioEl.muted = false; 107 | previewWidget.audioEl.style['width'] = "100%" 108 | previewWidget.audioEl.addEventListener("loadedmetadata", () => { 109 | 110 | previewWidget.aspectRatio = previewWidget.audioEl.audioWidth / previewWidget.audioEl.audioHeight; 111 | fitHeight(this); 112 | }); 113 | previewWidget.audioEl.addEventListener("error", () => { 114 | //TODO: consider a way to properly notify the user why a preview isn't shown. 
115 | previewWidget.parentEl.hidden = true; 116 | fitHeight(this); 117 | }); 118 | 119 | let params = { 120 | "filename": file, 121 | "type": type, 122 | } 123 | 124 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 125 | previewWidget.audioEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden; 126 | let target_width = 256 127 | if (element.style?.width) { 128 | //overscale to allow scrolling. Endpoint won't return higher than native 129 | target_width = element.style.width.slice(0,-2)*2; 130 | } 131 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") { 132 | params.force_size = target_width+"x?" 133 | } else { 134 | let size = params.force_size.split("x") 135 | let ar = parseInt(size[0])/parseInt(size[1]) 136 | params.force_size = target_width+"x"+(target_width/ar) 137 | } 138 | 139 | previewWidget.audioEl.src = api.apiURL('/view?' + new URLSearchParams(params)); 140 | 141 | previewWidget.audioEl.hidden = false; 142 | previewWidget.parentEl.appendChild(previewWidget.audioEl) 143 | } 144 | 145 | app.registerExtension({ 146 | name: "FishSpeech.AudioPreviewer", 147 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 148 | if (nodeData?.name == "PreViewAudio") { 149 | nodeType.prototype.onExecuted = function (data) { 150 | previewAudio(this, data.audio[0], data.audio[1]); 151 | } 152 | addPreviewOptions(nodeType) 153 | } 154 | } 155 | }); -------------------------------------------------------------------------------- /web/js/uploadAudio.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from '../../../scripts/api.js' 3 | import { ComfyWidgets } from "../../../scripts/widgets.js" 4 | 5 | function fitHeight(node) { 6 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) 7 | node?.graph?.setDirtyCanvas(true); 8 | } 9 | 10 | function previewAudio(node,file){ 11 | while (node.widgets.length > 2){ 12 | node.widgets.pop(); 13 | } 14 | try { 15 | var el = document.getElementById("uploadAudio"); 16 | el.remove(); 17 | } catch (error) { 18 | console.log(error); 19 | } 20 | var element = document.createElement("div"); 21 | element.id = "uploadAudio"; 22 | const previewNode = node; 23 | var previewWidget = node.addDOMWidget("audiopreview", "preview", element, { 24 | serialize: false, 25 | hideOnZoom: false, 26 | getValue() { 27 | return element.value; 28 | }, 29 | setValue(v) { 30 | element.value = v; 31 | }, 32 | }); 33 | previewWidget.computeSize = function(width) { 34 | if (this.aspectRatio && !this.parentEl.hidden) { 35 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10; 36 | if (!(height > 0)) { 37 | height = 0; 38 | } 39 | this.computedHeight = height + 10; 40 | return [width, height]; 41 | } 42 | return [width, -4];//no loaded src, widget should not display 43 | } 44 | // element.style['pointer-events'] = "none" 45 | previewWidget.value = {hidden: false, paused: false, params: {}} 46 | previewWidget.parentEl = document.createElement("div"); 47 | previewWidget.parentEl.className = "audio_preview"; 48 | previewWidget.parentEl.style['width'] = "100%" 49 | element.appendChild(previewWidget.parentEl); 50 | previewWidget.audioEl = document.createElement("audio"); 51 | previewWidget.audioEl.controls = true; 52 | previewWidget.audioEl.loop = false; 53 | previewWidget.audioEl.muted = false; 54 | previewWidget.audioEl.style['width'] = "100%" 55 | 
previewWidget.audioEl.addEventListener("loadedmetadata", () => { 56 | 57 | previewWidget.aspectRatio = previewWidget.audioEl.audioWidth / previewWidget.audioEl.audioHeight; 58 | fitHeight(this); 59 | }); 60 | previewWidget.audioEl.addEventListener("error", () => { 61 | //TODO: consider a way to properly notify the user why a preview isn't shown. 62 | previewWidget.parentEl.hidden = true; 63 | fitHeight(this); 64 | }); 65 | 66 | let params = { 67 | "filename": file, 68 | "type": "input", 69 | } 70 | 71 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 72 | previewWidget.audioEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden; 73 | let target_width = 256 74 | if (element.style?.width) { 75 | //overscale to allow scrolling. Endpoint won't return higher than native 76 | target_width = element.style.width.slice(0,-2)*2; 77 | } 78 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") { 79 | params.force_size = target_width+"x?" 80 | } else { 81 | let size = params.force_size.split("x") 82 | let ar = parseInt(size[0])/parseInt(size[1]) 83 | params.force_size = target_width+"x"+(target_width/ar) 84 | } 85 | 86 | previewWidget.audioEl.src = api.apiURL('/view?' + new URLSearchParams(params)); 87 | 88 | previewWidget.audioEl.hidden = false; 89 | previewWidget.parentEl.appendChild(previewWidget.audioEl) 90 | } 91 | 92 | function audioUpload(node, inputName, inputData, app) { 93 | const audioWidget = node.widgets.find((w) => w.name === "audio"); 94 | let uploadWidget; 95 | /* 96 | A method that returns the required style for the html 97 | */ 98 | var default_value = audioWidget.value; 99 | Object.defineProperty(audioWidget, "value", { 100 | set : function(value) { 101 | this._real_value = value; 102 | }, 103 | 104 | get : function() { 105 | let value = ""; 106 | if (this._real_value) { 107 | value = this._real_value; 108 | } else { 109 | return default_value; 110 | } 111 | 112 | if (value.filename) { 113 | let real_value = value; 114 | value = ""; 115 | if (real_value.subfolder) { 116 | value = real_value.subfolder + "/"; 117 | } 118 | 119 | value += real_value.filename; 120 | 121 | if(real_value.type && real_value.type !== "input") 122 | value += ` [${real_value.type}]`; 123 | } 124 | return value; 125 | } 126 | }); 127 | async function uploadFile(file, updateNode, pasted = false) { 128 | try { 129 | // Wrap file in formdata so it includes filename 130 | const body = new FormData(); 131 | body.append("image", file); 132 | if (pasted) body.append("subfolder", "pasted"); 133 | const resp = await api.fetchApi("/upload/image", { 134 | method: "POST", 135 | body, 136 | }); 137 | 138 | if (resp.status === 200) { 139 | const data = await resp.json(); 140 | // Add the file to the dropdown list and update the widget value 141 | let path = data.name; 142 | if (data.subfolder) path = data.subfolder + "/" + path; 143 | 144 | if (!audioWidget.options.values.includes(path)) { 145 | audioWidget.options.values.push(path); 146 | } 147 | 148 | if (updateNode) { 149 | audioWidget.value = path; 150 | previewAudio(node,path) 151 | 152 | } 153 | } else { 154 | alert(resp.status + " - " + resp.statusText); 155 | } 156 | } catch (error) { 157 | alert(error); 158 | } 159 | } 160 | 161 | const fileInput = document.createElement("input"); 162 | Object.assign(fileInput, { 163 | type: "file", 164 | accept: "audio/mp3,audio/wav,audio/flac,audio/m4a", 165 | style: "display: none", 166 | onchange: async () => { 167 | if (fileInput.files.length) { 168 | await 
uploadFile(fileInput.files[0], true); 169 | } 170 | }, 171 | }); 172 | document.body.append(fileInput); 173 | 174 | // Create the button widget for selecting the files 175 | uploadWidget = node.addWidget("button", "choose audio file to upload", "Audio", () => { 176 | fileInput.click(); 177 | }); 178 | 179 | uploadWidget.serialize = false; 180 | 181 | previewAudio(node, audioWidget.value); 182 | const cb = node.callback; 183 | audioWidget.callback = function () { 184 | previewAudio(node,audioWidget.value); 185 | if (cb) { 186 | return cb.apply(this, arguments); 187 | } 188 | }; 189 | 190 | return { widget: uploadWidget }; 191 | } 192 | 193 | ComfyWidgets.AUDIOPLOAD = audioUpload; 194 | 195 | app.registerExtension({ 196 | name: "FishSpeech.UploadAudio", 197 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 198 | if (nodeData?.name == "LoadAudio") { 199 | nodeData.input.required.upload = ["AUDIOPLOAD"]; 200 | } 201 | }, 202 | }); 203 | 204 | -------------------------------------------------------------------------------- /web/js/uploadSRT.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from '../../../scripts/api.js' 3 | import { ComfyWidgets } from "../../../scripts/widgets.js" 4 | 5 | function srtUpload(node, inputName, inputData, app) { 6 | const srtWidget = node.widgets.find((w) => w.name === "srt"); 7 | let uploadWidget; 8 | /* 9 | A method that returns the required style for the html 10 | */ 11 | var default_value = srtWidget.value; 12 | Object.defineProperty(srtWidget, "value", { 13 | set : function(value) { 14 | this._real_value = value; 15 | }, 16 | 17 | get : function() { 18 | let value = ""; 19 | if (this._real_value) { 20 | value = this._real_value; 21 | } else { 22 | return default_value; 23 | } 24 | 25 | if (value.filename) { 26 | let real_value = value; 27 | value = ""; 28 | if (real_value.subfolder) { 29 | value = real_value.subfolder + "/"; 30 | } 31 | 32 | value += real_value.filename; 33 | 34 | if(real_value.type && real_value.type !== "input") 35 | value += ` [${real_value.type}]`; 36 | } 37 | return value; 38 | } 39 | }); 40 | async function uploadFile(file, updateNode, pasted = false) { 41 | try { 42 | // Wrap file in formdata so it includes filename 43 | const body = new FormData(); 44 | body.append("image", file); 45 | if (pasted) body.append("subfolder", "pasted"); 46 | const resp = await api.fetchApi("/upload/image", { 47 | method: "POST", 48 | body, 49 | }); 50 | 51 | if (resp.status === 200) { 52 | const data = await resp.json(); 53 | // Add the file to the dropdown list and update the widget value 54 | let path = data.name; 55 | if (data.subfolder) path = data.subfolder + "/" + path; 56 | 57 | if (!srtWidget.options.values.includes(path)) { 58 | srtWidget.options.values.push(path); 59 | } 60 | 61 | if (updateNode) { 62 | srtWidget.value = path; 63 | } 64 | } else { 65 | alert(resp.status + " - " + resp.statusText); 66 | } 67 | } catch (error) { 68 | alert(error); 69 | } 70 | } 71 | 72 | const fileInput = document.createElement("input"); 73 | Object.assign(fileInput, { 74 | type: "file", 75 | accept: "file/srt,file/txt", 76 | style: "display: none", 77 | onchange: async () => { 78 | if (fileInput.files.length) { 79 | await uploadFile(fileInput.files[0], true); 80 | } 81 | }, 82 | }); 83 | document.body.append(fileInput); 84 | 85 | // Create the button widget for selecting the files 86 | uploadWidget = node.addWidget("button", "choose srt file to upload", 
"Audio", () => { 87 | fileInput.click(); 88 | }); 89 | 90 | uploadWidget.serialize = false; 91 | 92 | const cb = node.callback; 93 | srtWidget.callback = function () { 94 | if (cb) { 95 | return cb.apply(this, arguments); 96 | } 97 | }; 98 | 99 | return { widget: uploadWidget }; 100 | } 101 | 102 | ComfyWidgets.SRTPLOAD = srtUpload; 103 | 104 | app.registerExtension({ 105 | name: "FishSpeech.UploadSRT", 106 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 107 | if (nodeData?.name == "LoadSRT") { 108 | nodeData.input.required.upload = ["SRTPLOAD"]; 109 | } 110 | }, 111 | }); 112 | 113 | --------------------------------------------------------------------------------