├── vocos ├── __init__.py ├── __pycache__ │ ├── loss.cpython-39.pyc │ ├── heads.cpython-39.pyc │ ├── models.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── dataset.cpython-39.pyc │ ├── helpers.cpython-39.pyc │ ├── modules.cpython-39.pyc │ ├── experiment.cpython-39.pyc │ ├── pretrained.cpython-39.pyc │ ├── discriminators.cpython-39.pyc │ ├── spectral_ops.cpython-39.pyc │ └── feature_extractors.cpython-39.pyc ├── helpers.py ├── dataset.py ├── feature_extractors.py ├── models.py ├── loss.py ├── pretrained.py ├── modules.py ├── spectral_ops.py ├── heads.py ├── discriminators.py └── experiment.py ├── .gitignore ├── wavenext_architecture.png ├── requirements.txt ├── metrics ├── __pycache__ │ ├── UTMOS.cpython-39.pyc │ └── periodicity.cpython-39.pyc ├── periodicity.py └── UTMOS.py ├── requirements-train.txt ├── train.py ├── LICENSE ├── setup.py ├── inference.py ├── configs ├── vocos-resnet.yaml ├── vocos-imdct.yaml ├── vocos-encodec.yaml ├── vocos.yaml ├── wavenext_export.yaml ├── wavenext-encodec.yaml ├── vocos-matcha.yaml └── wavenext.yaml ├── README.md ├── export_onnx.py ├── README_vocos.md └── notebooks └── Bark+Vocos.ipynb /vocos/__init__.py: -------------------------------------------------------------------------------- 1 | from vocos.pretrained import Vocos 2 | 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | SLURM 2 | filelists 3 | configs/*_local.yaml 4 | *.pth 5 | *.pt 6 | *.ckpt 7 | *.pyc 8 | -------------------------------------------------------------------------------- /wavenext_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/wavenext_architecture.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchaudio 3 | numpy 4 | scipy 5 | einops 6 | pyyaml 7 | huggingface_hub 8 | encodec==0.1.1 -------------------------------------------------------------------------------- /vocos/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /vocos/__pycache__/heads.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/heads.cpython-39.pyc -------------------------------------------------------------------------------- /vocos/__pycache__/models.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/models.cpython-39.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/UTMOS.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/metrics/__pycache__/UTMOS.cpython-39.pyc -------------------------------------------------------------------------------- /vocos/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /vocos/__pycache__/dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/dataset.cpython-39.pyc -------------------------------------------------------------------------------- /vocos/__pycache__/helpers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/helpers.cpython-39.pyc -------------------------------------------------------------------------------- /vocos/__pycache__/modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/modules.cpython-39.pyc -------------------------------------------------------------------------------- /requirements-train.txt: -------------------------------------------------------------------------------- 1 | pytorch_lightning==1.8.6 2 | jsonargparse[signatures] 3 | transformers 4 | matplotlib 5 | torchcrepe 6 | pesq 7 | fairseq 8 | -------------------------------------------------------------------------------- /vocos/__pycache__/experiment.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/experiment.cpython-39.pyc -------------------------------------------------------------------------------- /vocos/__pycache__/pretrained.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/pretrained.cpython-39.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/periodicity.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/metrics/__pycache__/periodicity.cpython-39.pyc -------------------------------------------------------------------------------- /vocos/__pycache__/discriminators.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/discriminators.cpython-39.pyc -------------------------------------------------------------------------------- /vocos/__pycache__/spectral_ops.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/spectral_ops.cpython-39.pyc -------------------------------------------------------------------------------- /vocos/__pycache__/feature_extractors.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/wavenext_pytorch/HEAD/vocos/__pycache__/feature_extractors.cpython-39.pyc -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from 
pytorch_lightning.cli import LightningCLI 2 | 3 | 4 | if __name__ == "__main__": 5 | cli = LightningCLI(run=False) 6 | cli.trainer.fit(model=cli.model, datamodule=cli.datamodule) 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Charactr Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | from setuptools import find_packages, setup 5 | 6 | for line in open("vocos/__init__.py"): 7 | line = line.strip() 8 | if "__version__" in line: 9 | context = {} 10 | exec(line, context) 11 | VERSION = context["__version__"] 12 | 13 | 14 | def read(*paths, **kwargs): 15 | content = "" 16 | with io.open( 17 | os.path.join(os.path.dirname(__file__), *paths), encoding=kwargs.get("encoding", "utf8"), 18 | ) as open_file: 19 | content = open_file.read().strip() 20 | return content 21 | 22 | 23 | def read_requirements(path): 24 | return [line.strip() for line in read(path).split("\n") if not line.startswith(('"', "#", "-", "git+"))] 25 | 26 | 27 | setup( 28 | name="vocos", 29 | version=VERSION, 30 | author="Hubert Siuzdak", 31 | author_email="huberts@charactr.com", 32 | description="Fourier-based neural vocoder for high-quality audio synthesis", 33 | url="https://github.com/charactr-platform/vocos", 34 | long_description=read("README.md"), 35 | long_description_content_type="text/markdown", 36 | packages=find_packages(), 37 | install_requires=read_requirements("requirements.txt"), 38 | extras_require={"train": read_requirements("requirements-train.txt")}, 39 | ) 40 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from glob import glob 4 | import numpy as np 5 | import os 6 | import torchaudio.functional as F 7 | from vocos import Vocos 8 | 9 | import argparse 10 | 11 | def main() -> None: 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--model", required=True, help="Path to model ckpt") 14 | parser.add_argument("--config_path", required=True, help="Path to model config (.yaml)") 15 | 
parser.add_argument("--output_path", required=True, help="Path to write WAV file") 16 | parser.add_argument("--mel_input", required=False, type=str, help="mel input") 17 | parser.add_argument("--audio_input", required=False, type=str, help="audio input") 18 | args = parser.parse_args() 19 | 20 | checkpoint_path = args.model 21 | config_path = args.config_path 22 | audio_path = args.audio_input 23 | mel_path = args.mel_input 24 | 25 | ## load model for inference 26 | model = Vocos.from_hparams(config_path) 27 | raw_model = torch.load(checkpoint_path, map_location="cpu") 28 | model.load_state_dict(raw_model['state_dict'], strict=False) 29 | model.eval() 30 | 31 | # read soruce audio 32 | if audio_path: 33 | src_audio, fs = torchaudio.load(audio_path) 34 | if fs != 22050: 35 | src_audio = F.resample(src_audio, orig_freq=fs, new_freq=22050) 36 | 37 | # inference 38 | audio = model(src_audio) 39 | # read mel spectrogram 40 | elif mel_path: 41 | mel = torch.tensor(np.load(mel_path)) 42 | audio = model.decode(mel) 43 | 44 | wav_file = f'{os.path.basename(checkpoint_path)}_{os.path.basename(audio_path)}_mod.wav' 45 | torchaudio.save(wav_file, audio.cpu(), 22050, ) 46 | 47 | if __name__=="__main__": 48 | main() -------------------------------------------------------------------------------- /vocos/helpers.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import torch 4 | from matplotlib import pyplot as plt 5 | from pytorch_lightning import Callback 6 | 7 | matplotlib.use("Agg") 8 | 9 | 10 | def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray: 11 | """ 12 | Save a matplotlib figure to a numpy array. 13 | 14 | Args: 15 | fig (Figure): Matplotlib figure object. 16 | 17 | Returns: 18 | ndarray: Numpy array representing the figure. 19 | """ 20 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 21 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 22 | return data 23 | 24 | 25 | def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray: 26 | """ 27 | Plot a spectrogram and convert it to a numpy array. 28 | 29 | Args: 30 | spectrogram (ndarray): Spectrogram data. 31 | 32 | Returns: 33 | ndarray: Numpy array representing the plotted spectrogram. 34 | """ 35 | spectrogram = spectrogram.astype(np.float32) 36 | fig, ax = plt.subplots(figsize=(12, 3)) 37 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 38 | plt.colorbar(im, ax=ax) 39 | plt.xlabel("Frames") 40 | plt.ylabel("Channels") 41 | plt.tight_layout() 42 | 43 | fig.canvas.draw() 44 | data = save_figure_to_numpy(fig) 45 | plt.close() 46 | return data 47 | 48 | 49 | class GradNormCallback(Callback): 50 | """ 51 | Callback to log the gradient norm. 52 | """ 53 | 54 | def on_after_backward(self, trainer, model): 55 | model.log("grad_norm", gradient_norm(model)) 56 | 57 | 58 | def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor: 59 | """ 60 | Compute the gradient norm. 61 | 62 | Args: 63 | model (Module): PyTorch model. 64 | norm_type (float, optional): Type of the norm. Defaults to 2.0. 65 | 66 | Returns: 67 | Tensor: Gradient norm. 
68 | """ 69 | grads = [p.grad for p in model.parameters() if p.grad is not None] 70 | total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type) 71 | return total_norm 72 | -------------------------------------------------------------------------------- /configs/vocos-resnet.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 24000 10 | num_samples: 16384 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 16 | sampling_rate: 24000 17 | num_samples: 48384 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosExp 23 | init_args: 24 | sample_rate: 24000 25 | initial_learning_rate: 5e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 0.1 28 | num_warmup_steps: 0 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.MelSpectrogramFeatures 38 | init_args: 39 | sample_rate: 24000 40 | n_fft: 1024 41 | hop_length: 256 42 | n_mels: 100 43 | padding: center 44 | 45 | backbone: 46 | class_path: vocos.models.VocosResNetBackbone 47 | init_args: 48 | input_channels: 100 49 | dim: 512 50 | num_blocks: 3 51 | 52 | head: 53 | class_path: vocos.heads.ISTFTHead 54 | init_args: 55 | dim: 512 56 | n_fft: 1024 57 | hop_length: 256 58 | padding: center 59 | 60 | trainer: 61 | logger: 62 | class_path: pytorch_lightning.loggers.TensorBoardLogger 63 | init_args: 64 | save_dir: logs/ 65 | callbacks: 66 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 67 | - class_path: pytorch_lightning.callbacks.ModelSummary 68 | init_args: 69 | max_depth: 2 70 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 71 | init_args: 72 | monitor: val_loss 73 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 74 | save_top_k: 3 75 | save_last: true 76 | - class_path: vocos.helpers.GradNormCallback 77 | 78 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 79 | # This equals to 1M steps per generator and 1M per discriminator 80 | max_steps: 2000000 81 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 82 | limit_val_batches: 100 83 | accelerator: gpu 84 | strategy: ddp 85 | devices: [0] 86 | log_every_n_steps: 100 87 | -------------------------------------------------------------------------------- /configs/vocos-imdct.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 24000 10 | num_samples: 16384 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 
16 | sampling_rate: 24000 17 | num_samples: 48384 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosExp 23 | init_args: 24 | sample_rate: 24000 25 | initial_learning_rate: 5e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 0.1 28 | num_warmup_steps: 0 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.MelSpectrogramFeatures 38 | init_args: 39 | sample_rate: 24000 40 | n_fft: 1024 41 | hop_length: 256 42 | n_mels: 100 43 | padding: center 44 | 45 | backbone: 46 | class_path: vocos.models.VocosBackbone 47 | init_args: 48 | input_channels: 100 49 | dim: 512 50 | intermediate_dim: 1536 51 | num_layers: 8 52 | 53 | head: 54 | class_path: vocos.heads.IMDCTCosHead 55 | init_args: 56 | dim: 512 57 | mdct_frame_len: 512 # mel-spec hop_length * 2 58 | padding: center 59 | 60 | trainer: 61 | logger: 62 | class_path: pytorch_lightning.loggers.TensorBoardLogger 63 | init_args: 64 | save_dir: logs/ 65 | callbacks: 66 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 67 | - class_path: pytorch_lightning.callbacks.ModelSummary 68 | init_args: 69 | max_depth: 2 70 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 71 | init_args: 72 | monitor: val_loss 73 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 74 | save_top_k: 3 75 | save_last: true 76 | - class_path: vocos.helpers.GradNormCallback 77 | 78 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 79 | # This equals to 1M steps per generator and 1M per discriminator 80 | max_steps: 2000000 81 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 82 | limit_val_batches: 100 83 | accelerator: gpu 84 | strategy: ddp 85 | devices: [0] 86 | log_every_n_steps: 100 87 | -------------------------------------------------------------------------------- /configs/vocos-encodec.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 24000 10 | num_samples: 24000 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 
16 | sampling_rate: 24000 17 | num_samples: 24000 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosEncodecExp 23 | init_args: 24 | sample_rate: 24000 25 | initial_learning_rate: 5e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 1.0 28 | num_warmup_steps: 0 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.EncodecFeatures 38 | init_args: 39 | encodec_model: encodec_24khz 40 | bandwidths: [1.5, 3.0, 6.0, 12.0] 41 | train_codebooks: false 42 | 43 | backbone: 44 | class_path: vocos.models.VocosBackbone 45 | init_args: 46 | input_channels: 128 47 | dim: 384 48 | intermediate_dim: 1152 49 | num_layers: 8 50 | adanorm_num_embeddings: 4 # len(bandwidths) 51 | 52 | head: 53 | class_path: vocos.heads.ISTFTHead 54 | init_args: 55 | dim: 384 56 | n_fft: 1280 57 | hop_length: 320 58 | padding: same 59 | 60 | trainer: 61 | logger: 62 | class_path: pytorch_lightning.loggers.TensorBoardLogger 63 | init_args: 64 | save_dir: logs/ 65 | callbacks: 66 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 67 | - class_path: pytorch_lightning.callbacks.ModelSummary 68 | init_args: 69 | max_depth: 2 70 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 71 | init_args: 72 | monitor: val_loss 73 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 74 | save_top_k: 3 75 | save_last: true 76 | - class_path: vocos.helpers.GradNormCallback 77 | 78 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 79 | # This equals to 1M steps per generator and 1M per discriminator 80 | max_steps: 2000000 81 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 82 | limit_val_batches: 100 83 | accelerator: gpu 84 | strategy: ddp 85 | devices: [0] 86 | log_every_n_steps: 100 87 | -------------------------------------------------------------------------------- /vocos/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy as np 4 | import torch 5 | import torchaudio 6 | from pytorch_lightning import LightningDataModule 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | torch.set_num_threads(1) 10 | 11 | 12 | @dataclass 13 | class DataConfig: 14 | filelist_path: str 15 | sampling_rate: int 16 | num_samples: int 17 | batch_size: int 18 | num_workers: int 19 | 20 | 21 | class VocosDataModule(LightningDataModule): 22 | def __init__(self, train_params: DataConfig, val_params: DataConfig): 23 | super().__init__() 24 | self.train_config = train_params 25 | self.val_config = val_params 26 | 27 | def _get_dataloder(self, cfg: DataConfig, train: bool): 28 | dataset = VocosDataset(cfg, train=train) 29 | dataloader = DataLoader( 30 | dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True, 31 | ) 32 | return dataloader 33 | 34 | def train_dataloader(self) -> DataLoader: 35 | return self._get_dataloder(self.train_config, train=True) 36 | 37 | def val_dataloader(self) -> DataLoader: 38 | return self._get_dataloder(self.val_config, train=False) 39 | 40 | 41 | class VocosDataset(Dataset): 42 | def __init__(self, cfg: DataConfig, train: bool): 43 | with open(cfg.filelist_path) as f: 44 | self.filelist = 
f.read().splitlines() 45 | self.sampling_rate = cfg.sampling_rate 46 | self.num_samples = cfg.num_samples 47 | self.train = train 48 | 49 | def __len__(self) -> int: 50 | return len(self.filelist) 51 | 52 | def __getitem__(self, index: int) -> torch.Tensor: 53 | audio_path = self.filelist[index] 54 | y, sr = torchaudio.load(audio_path) 55 | if y.size(0) > 1: 56 | # mix to mono 57 | y = y.mean(dim=0, keepdim=True) 58 | gain = np.random.uniform(-1, -6) if self.train else -3 59 | y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]]) 60 | if sr != self.sampling_rate: 61 | y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate) 62 | if y.size(-1) < self.num_samples: 63 | pad_length = self.num_samples - y.size(-1) 64 | padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1)) 65 | y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1) 66 | elif self.train: 67 | start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1) 68 | y = y[:, start : start + self.num_samples] 69 | else: 70 | # During validation, take always the first segment for determinism 71 | y = y[:, : self.num_samples] 72 | 73 | return y[0] 74 | -------------------------------------------------------------------------------- /configs/vocos.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 24000 10 | num_samples: 16384 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 16 | sampling_rate: 24000 17 | num_samples: 48384 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosExp 23 | init_args: 24 | sample_rate: 24000 25 | initial_learning_rate: 5e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 0.1 28 | num_warmup_steps: 0 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.MelSpectrogramFeatures 38 | init_args: 39 | sample_rate: 24000 40 | n_fft: 1024 41 | hop_length: 256 42 | n_mels: 100 43 | padding: center 44 | 45 | backbone: 46 | class_path: vocos.models.VocosBackbone 47 | init_args: 48 | input_channels: 100 49 | dim: 512 50 | intermediate_dim: 1536 51 | num_layers: 8 52 | 53 | head: 54 | class_path: vocos.heads.ISTFTHead 55 | init_args: 56 | dim: 512 57 | n_fft: 1024 58 | hop_length: 256 59 | padding: center 60 | 61 | melspec_loss: 62 | class_path: vocos.loss.MelSpecReconstructionLoss 63 | init_args: 64 | sample_rate: 24000 65 | n_fft: 1024 66 | hop_length: 256 67 | n_mels: 100 68 | 69 | trainer: 70 | logger: 71 | class_path: pytorch_lightning.loggers.TensorBoardLogger 72 | init_args: 73 | save_dir: logs/ 74 | callbacks: 75 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 76 | - class_path: pytorch_lightning.callbacks.ModelSummary 77 | init_args: 78 | max_depth: 2 79 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 80 | init_args: 81 | monitor: val_loss 82 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 83 | save_top_k: 3 84 | save_last: true 85 | - class_path: vocos.helpers.GradNormCallback 86 | 87 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 
88 | # This equals to 1M steps per generator and 1M per discriminator 89 | max_steps: 2000000 90 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 91 | limit_val_batches: 100 92 | accelerator: gpu 93 | strategy: ddp 94 | devices: [0] 95 | log_every_n_steps: 100 96 | -------------------------------------------------------------------------------- /configs/wavenext_export.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 22050 10 | num_samples: 16384 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 16 | sampling_rate: 22050 17 | num_samples: 48384 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosExp 23 | init_args: 24 | sample_rate: 22050 25 | initial_learning_rate: 1e-3 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 0.1 # original value 0.1 28 | num_warmup_steps: 500 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.MelSpectrogramFeatures 38 | init_args: 39 | sample_rate: 22050 40 | n_fft: 1024 41 | hop_length: 256 42 | n_mels: 80 43 | padding: same 44 | f_min: 0 45 | f_max: 8000 46 | norm: "slaney" 47 | mel_scale: "slaney" 48 | clip_val: 1e-5 49 | 50 | 51 | backbone: 52 | class_path: vocos.models.VocosBackbone 53 | init_args: 54 | input_channels: 80 55 | dim: 512 56 | intermediate_dim: 1536 57 | num_layers: 8 58 | 59 | head: 60 | class_path: vocos.heads.WaveNextHead 61 | init_args: 62 | dim: 512 63 | n_fft: 1024 64 | hop_length: 256 65 | padding: same 66 | 67 | melspec_loss: 68 | class_path: vocos.loss.MelSpecReconstructionLoss 69 | init_args: 70 | sample_rate: 22050 71 | n_fft: 1024 72 | hop_length: 256 73 | n_mels: 128 74 | f_min: 0 75 | f_max: 11000 76 | norm: "slaney" 77 | mel_scale: "slaney" 78 | clip_val: 1e-5 79 | 80 | 81 | trainer: 82 | logger: 83 | class_path: pytorch_lightning.loggers.TensorBoardLogger 84 | init_args: 85 | save_dir: ??? 
86 | callbacks: 87 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 88 | - class_path: pytorch_lightning.callbacks.ModelSummary 89 | init_args: 90 | max_depth: 2 91 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 92 | init_args: 93 | monitor: val_loss 94 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 95 | save_top_k: 3 96 | save_last: true 97 | - class_path: vocos.helpers.GradNormCallback 98 | 99 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 100 | # This equals to 1M steps per generator and 1M per discriminator 101 | max_steps: 2000000 102 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 103 | limit_val_batches: 50 104 | accelerator: gpu 105 | strategy: ddp 106 | devices: [0] 107 | log_every_n_steps: 250 108 | -------------------------------------------------------------------------------- /configs/wavenext-encodec.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 24000 10 | num_samples: 24000 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 16 | sampling_rate: 24000 17 | num_samples: 24000 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosEncodecExp 23 | init_args: 24 | sample_rate: 24000 25 | initial_learning_rate: 5e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 1.0 28 | num_warmup_steps: 0 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.EncodecFeatures 38 | init_args: 39 | encodec_model: encodec_24khz 40 | bandwidths: [1.5, 3.0, 6.0, 12.0] 41 | train_codebooks: false 42 | 43 | backbone: 44 | class_path: vocos.models.VocosBackbone 45 | init_args: 46 | input_channels: 128 47 | dim: 384 48 | intermediate_dim: 1152 49 | num_layers: 8 50 | adanorm_num_embeddings: 4 # len(bandwidths) 51 | 52 | head: 53 | class_path: vocos.heads.WaveNextHead 54 | init_args: 55 | dim: 384 56 | n_fft: 1280 57 | hop_length: 320 58 | padding: same 59 | 60 | melspec_loss: 61 | class_path: vocos.loss.MelSpecReconstructionLoss 62 | init_args: 63 | sample_rate: 22050 64 | n_fft: 1024 65 | hop_length: 256 66 | n_mels: 128 67 | f_min: 0 68 | f_max: 11025 69 | norm: "slaney" 70 | mel_scale: "slaney" 71 | clip_val: 1e-5 72 | 73 | trainer: 74 | logger: 75 | class_path: pytorch_lightning.loggers.TensorBoardLogger 76 | init_args: 77 | save_dir: logs/ 78 | callbacks: 79 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 80 | - class_path: pytorch_lightning.callbacks.ModelSummary 81 | init_args: 82 | max_depth: 2 83 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 84 | init_args: 85 | monitor: val_loss 86 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 87 | save_top_k: 3 88 | save_last: true 89 | - class_path: vocos.helpers.GradNormCallback 90 | 91 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 92 | # This equals to 1M steps per generator and 1M per discriminator 93 | max_steps: 2000000 94 | # You might want to limit val batches when evaluating all the metrics, as 
they are time-consuming 95 | limit_val_batches: 100 96 | accelerator: gpu 97 | strategy: ddp 98 | devices: [0] 99 | log_every_n_steps: 100 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wavenext_pytorch 2 | ## WaveNeXt: ConvNeXt-Based Fast Neural Vocoder Without ISTFT Layer 3 | ### Authors: Takuma Okamoto, Haruki Yamashita, Yamato Ohtani, Tomoki Toda and Hisashi Kawai 4 | 5 | Unofficial implementation of the WaveNeXt neural vocoder (WIP). 6 | 7 | [WaveNeXt](https://ieeexplore.ieee.org/document/10389765) proposes replacing the final ISTFT layer of Vocos with a linear layer without bias followed by a reshape op. Since this is only a slight modification of Vocos, this repository builds on the [official vocos implementation](https://github.com/gemelo-ai/vocos) and adds the WaveNeXt head in wavenext_pytorch/vocos/heads.py. 8 | 9 | ![WaveNeXt](wavenext_architecture.png) 10 | 11 | We also modified the feature extraction and mel-spectrogram loss so that they are compatible with HiFi-GAN features; however, you can still use the original Vocos features. 12 | 13 | ## Installation 14 | 15 | To run inference only, install the base requirements: 16 | 17 | ```bash 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | If you wish to train the model, install the additional training dependencies as well: 22 | 23 | ```bash 24 | pip install -r requirements-train.txt 25 | ``` 26 | 27 | 28 | 29 | ## Training 30 | 31 | 32 | Prepare a filelist of audio files for the training and validation set: 33 | 34 | ```bash 35 | find $TRAIN_DATASET_DIR -name "*.wav" > filelist.train 36 | find $VAL_DATASET_DIR -name "*.wav" > filelist.val 37 | ``` 38 | 39 | Fill a config file, e.g. [wavenext.yaml](configs%2Fwavenext.yaml), with your filelist paths and start training with: 40 | 41 | ```bash 42 | python train.py -c configs/wavenext.yaml 43 | ``` 44 | 45 | Refer to the [PyTorch Lightning documentation](https://lightning.ai/docs/pytorch/stable/) for details about customizing the 46 | training pipeline. 47 | 48 | ## Trained checkpoints 49 | 50 | Pre-trained models: 51 | 52 | | Model Name | Dataset | Training Iterations | Parameters | 53 | |------------------------------------------------------------------------|---------------|-------------------|------------| 54 | | [BSC-LT/wavenext-mel](https://huggingface.co/BSC-LT/wavenext-mel)| LibriTTS + LJSpeech + openslr69 + festcat | 1M | 13.68M | 55 | 56 | 57 | 58 | ## Todo 59 | 60 | - [X] Add tensorboards. 61 | - [X] Add encodec config. 62 | 63 | ## Citation 64 | 65 | If this code contributes to your research, please cite the work: 66 | ``` 67 | @INPROCEEDINGS{10389765, 68 | author={Okamoto, Takuma and Yamashita, Haruki and Ohtani, Yamato and Toda, Tomoki and Kawai, Hisashi}, 69 | booktitle={2023 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU)}, 70 | title={WaveNeXt: ConvNeXt-Based Fast Neural Vocoder Without ISTFT layer}, 71 | year={2023}, 72 | volume={}, 73 | number={}, 74 | pages={1-8}, 75 | keywords={Fourier transforms;Vocoders;Conferences;Automatic speech recognition;ConvNext;end-to-end text-to-speech;linear layer-based upsampling;neural vocoder;Vocos}, 76 | doi={10.1109/ASRU57964.2023.10389765}} 77 | ``` 78 | 79 | 80 | ## License 81 | 82 | The code in this repository is released under the MIT license as found in the 83 | [LICENSE](LICENSE) file.
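For intuition, here is a minimal sketch of the idea described in the README above: a bias-free linear projection per frame followed by a reshape to waveform samples. This is an illustration only, not the code in `vocos/heads.py`; the class name `LinearWaveHead` is hypothetical, and the actual `WaveNextHead` also takes `n_fft` and `padding` arguments (see `configs/wavenext.yaml`) and may differ in detail.

```python
import torch
from torch import nn


class LinearWaveHead(nn.Module):
    """Hypothetical sketch of a WaveNeXt-style head: linear layer without bias + reshape."""

    def __init__(self, dim: int, hop_length: int):
        super().__init__()
        # one frame of backbone features -> hop_length waveform samples, no bias term
        self.proj = nn.Linear(dim, hop_length, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, L, dim), backbone output at frame rate
        frames = self.proj(x)                        # (B, L, hop_length)
        audio = frames.reshape(frames.size(0), -1)   # (B, L * hop_length)
        return audio


# e.g. dim=512 and hop_length=256 as in configs/wavenext.yaml
head = LinearWaveHead(dim=512, hop_length=256)
wave = head(torch.randn(1, 64, 512))  # -> shape (1, 16384)
```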
84 | -------------------------------------------------------------------------------- /configs/vocos-matcha.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 22050 10 | num_samples: 16384 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 16 | sampling_rate: 22050 17 | num_samples: 48384 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosExp 23 | init_args: 24 | sample_rate: 22050 25 | initial_learning_rate: 5e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 0.1 28 | num_warmup_steps: 0 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.MelSpectrogramFeatures 38 | init_args: 39 | sample_rate: 22050 40 | n_fft: 1024 41 | hop_length: 256 42 | n_mels: 80 43 | padding: same 44 | f_min: 0 45 | f_max: 8000 46 | norm: "slaney" 47 | mel_scale: "slaney" 48 | clip_val: 1e-5 49 | 50 | 51 | backbone: 52 | class_path: vocos.models.VocosBackbone 53 | init_args: 54 | input_channels: 80 55 | dim: 512 56 | intermediate_dim: 1536 57 | num_layers: 8 58 | 59 | head: 60 | class_path: vocos.heads.ISTFTHead 61 | init_args: 62 | dim: 512 63 | n_fft: 1024 64 | hop_length: 256 65 | padding: same 66 | 67 | melspec_loss: 68 | class_path: vocos.loss.MelSpecReconstructionLoss 69 | init_args: 70 | sample_rate: 22050 71 | n_fft: 1024 72 | hop_length: 256 73 | n_mels: 80 74 | f_min: 0 75 | f_max: 8000 76 | norm: "slaney" 77 | mel_scale: "slaney" 78 | clip_val: 1e-5 79 | 80 | 81 | trainer: 82 | logger: 83 | class_path: pytorch_lightning.loggers.TensorBoardLogger 84 | init_args: 85 | save_dir: ?? 86 | callbacks: 87 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 88 | - class_path: pytorch_lightning.callbacks.ModelSummary 89 | init_args: 90 | max_depth: 2 91 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 92 | init_args: 93 | monitor: val_loss 94 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 95 | save_top_k: 3 96 | save_last: true 97 | - class_path: vocos.helpers.GradNormCallback 98 | 99 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 100 | # This equals to 1M steps per generator and 1M per discriminator 101 | max_steps: 2000000 102 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 103 | limit_val_batches: 100 104 | accelerator: gpu 105 | strategy: ddp 106 | devices: [0] 107 | log_every_n_steps: 100 108 | -------------------------------------------------------------------------------- /configs/wavenext.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 22050 10 | num_samples: 16384 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 
16 | sampling_rate: 22050 17 | num_samples: 48384 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosExp 23 | init_args: 24 | sample_rate: 22050 25 | initial_learning_rate: 1e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 0.1 # original value 0.1 28 | num_warmup_steps: 500 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.MelSpectrogramFeatures 38 | init_args: 39 | sample_rate: 22050 40 | n_fft: 1024 41 | hop_length: 256 42 | n_mels: 80 43 | padding: same 44 | f_min: 0 45 | f_max: 8000 46 | norm: "slaney" 47 | mel_scale: "slaney" 48 | clip_val: 1e-5 49 | 50 | 51 | backbone: 52 | class_path: vocos.models.VocosBackbone 53 | init_args: 54 | input_channels: 80 55 | dim: 512 56 | intermediate_dim: 1536 57 | num_layers: 8 58 | 59 | head: 60 | class_path: vocos.heads.WaveNextHead 61 | init_args: 62 | dim: 512 63 | n_fft: 1024 64 | hop_length: 256 65 | padding: same 66 | 67 | melspec_loss: 68 | class_path: vocos.loss.MelSpecReconstructionLoss 69 | init_args: 70 | sample_rate: 22050 71 | n_fft: 1024 72 | hop_length: 256 73 | n_mels: 128 74 | f_min: 0 75 | f_max: 11025 76 | norm: "slaney" 77 | mel_scale: "slaney" 78 | clip_val: 1e-5 79 | 80 | 81 | trainer: 82 | logger: 83 | class_path: pytorch_lightning.loggers.TensorBoardLogger 84 | init_args: 85 | save_dir: ??? 86 | callbacks: 87 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 88 | - class_path: pytorch_lightning.callbacks.ModelSummary 89 | init_args: 90 | max_depth: 2 91 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 92 | init_args: 93 | monitor: val_loss 94 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 95 | save_top_k: 3 96 | save_last: true 97 | - class_path: vocos.helpers.GradNormCallback 98 | 99 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 100 | # This equals to 1M steps per generator and 1M per discriminator 101 | max_steps: 1000000 102 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 103 | limit_val_batches: 50 104 | accelerator: gpu 105 | strategy: ddp 106 | devices: [0] 107 | log_every_n_steps: 250 108 | -------------------------------------------------------------------------------- /metrics/periodicity.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | import torchaudio 5 | import torchcrepe 6 | from torchcrepe.loudness import REF_DB 7 | 8 | SILENCE_THRESHOLD = -60 9 | UNVOICED_THRESHOLD = 0.21 10 | 11 | """ 12 | Periodicity metrics adapted from https://github.com/descriptinc/cargan 13 | """ 14 | 15 | 16 | def predict_pitch( 17 | audio: torch.Tensor, silence_threshold: float = SILENCE_THRESHOLD, unvoiced_treshold: float = UNVOICED_THRESHOLD 18 | ): 19 | """ 20 | Predicts pitch and periodicity for the given audio. 21 | 22 | Args: 23 | audio (Tensor): The audio waveform. 24 | silence_threshold (float): The threshold for silence detection. 25 | unvoiced_treshold (float): The threshold for unvoiced detection. 26 | 27 | Returns: 28 | pitch (ndarray): The predicted pitch. 29 | periodicity (ndarray): The predicted periodicity. 
30 | """ 31 | # torchcrepe inference 32 | pitch, periodicity = torchcrepe.predict( 33 | audio, 34 | fmin=50.0, 35 | fmax=550, 36 | sample_rate=torchcrepe.SAMPLE_RATE, 37 | model="full", 38 | return_periodicity=True, 39 | device=audio.device, 40 | pad=False, 41 | ) 42 | pitch = pitch.cpu().numpy() 43 | periodicity = periodicity.cpu().numpy() 44 | 45 | # Calculate dB-scaled spectrogram and set low energy frames to unvoiced 46 | hop_length = torchcrepe.SAMPLE_RATE // 100 # default CREPE 47 | stft = torchaudio.functional.spectrogram( 48 | audio, 49 | window=torch.hann_window(torchcrepe.WINDOW_SIZE, device=audio.device), 50 | n_fft=torchcrepe.WINDOW_SIZE, 51 | hop_length=hop_length, 52 | win_length=torchcrepe.WINDOW_SIZE, 53 | power=2, 54 | normalized=False, 55 | pad=0, 56 | center=False, 57 | ) 58 | 59 | # Perceptual weighting 60 | freqs = librosa.fft_frequencies(sr=torchcrepe.SAMPLE_RATE, n_fft=torchcrepe.WINDOW_SIZE) 61 | perceptual_stft = librosa.perceptual_weighting(stft.cpu().numpy(), freqs) - REF_DB 62 | silence = perceptual_stft.mean(axis=1) < silence_threshold 63 | 64 | periodicity[silence] = 0 65 | pitch[periodicity < unvoiced_treshold] = torchcrepe.UNVOICED 66 | 67 | return pitch, periodicity 68 | 69 | 70 | def calculate_periodicity_metrics(y: torch.Tensor, y_hat: torch.Tensor): 71 | """ 72 | Calculates periodicity metrics for the predicted and true audio data. 73 | 74 | Args: 75 | y (Tensor): The true audio data. 76 | y_hat (Tensor): The predicted audio data. 77 | 78 | Returns: 79 | periodicity_loss (float): The periodicity loss. 80 | pitch_loss (float): The pitch loss. 81 | f1 (float): The F1 score for voiced/unvoiced classification 82 | """ 83 | true_pitch, true_periodicity = predict_pitch(y) 84 | pred_pitch, pred_periodicity = predict_pitch(y_hat) 85 | 86 | true_voiced = ~np.isnan(true_pitch) 87 | pred_voiced = ~np.isnan(pred_pitch) 88 | 89 | periodicity_loss = np.sqrt(((pred_periodicity - true_periodicity) ** 2).mean(axis=1)).mean() 90 | 91 | # Update pitch rmse 92 | voiced = true_voiced & pred_voiced 93 | difference_cents = 1200 * (np.log2(true_pitch[voiced]) - np.log2(pred_pitch[voiced])) 94 | pitch_loss = np.sqrt((difference_cents ** 2).mean()) 95 | 96 | # voiced/unvoiced precision and recall 97 | true_positives = (true_voiced & pred_voiced).sum() 98 | false_positives = (~true_voiced & pred_voiced).sum() 99 | false_negatives = (true_voiced & ~pred_voiced).sum() 100 | 101 | precision = true_positives / (true_positives + false_positives) 102 | recall = true_positives / (true_positives + false_negatives) 103 | f1 = 2 * precision * recall / (precision + recall) 104 | 105 | return periodicity_loss, pitch_loss, f1 106 | -------------------------------------------------------------------------------- /export_onnx.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # src https://github.com/gemelo-ai/vocos/issues/38 3 | import argparse 4 | import logging 5 | import os 6 | import random 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import torch 11 | import yaml 12 | from torch import nn 13 | 14 | from vocos.pretrained import Vocos 15 | from vocos.loss import MelSpecReconstructionLoss 16 | 17 | DEFAULT_OPSET_VERSION = 15 18 | _LOGGER = logging.getLogger("export_onnx") 19 | 20 | 21 | class VocosGen(nn.Module): 22 | def __init__(self, vocos): 23 | super().__init__() 24 | self.vocos = vocos 25 | 26 | def forward(self, mels): 27 | x = self.vocos.backbone(mels) 28 | waveform = self.vocos.head(x) 29 | return waveform 30 | 
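# Note: VocosGen wraps only the backbone + head, so the exported graph takes mel
# features directly ("mels") and returns audio ("waveform"); the feature extractor
# is not part of the ONNX graph.
#
# Sketch of how the exported file could be consumed afterwards (assumption:
# onnxruntime is installed separately, it is not listed in requirements.txt):
#
#   import numpy as np
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("mel_spec_22khz_wavenext.onnx")
#   mels = np.random.randn(1, 80, 64).astype(np.float32)  # (batch, n_mels, frames)
#   (waveform,) = sess.run(["waveform"], {"mels": mels})   # (batch, num_samples)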
31 | 32 | def export_generator(config_path, checkpoint_path, output_dir, opset_version): 33 | 34 | with open(config_path, "r") as f: 35 | config = yaml.safe_load(f) 36 | 37 | class_module, class_name = config["model"]["class_path"].rsplit(".", 1) 38 | module = __import__(class_module, fromlist=[class_name]) 39 | vocos_cls = getattr(module, class_name) 40 | 41 | components = Vocos.from_hparams(config_path) 42 | print(module, class_module) 43 | params = config["model"]["init_args"] 44 | 45 | vocos = vocos_cls( 46 | feature_extractor=components.feature_extractor, 47 | backbone=components.backbone, 48 | head=components.head, 49 | sample_rate=params["sample_rate"], 50 | initial_learning_rate=params["initial_learning_rate"], 51 | num_warmup_steps=params["num_warmup_steps"], 52 | mel_loss_coeff=params["mel_loss_coeff"], 53 | mrd_loss_coeff=params["mrd_loss_coeff"], 54 | melspec_loss=MelSpecReconstructionLoss 55 | 56 | ) 57 | 58 | if checkpoint_path.endswith(".bin"): 59 | state_dict = torch.load(checkpoint_path, map_location="cpu") 60 | vocos.load_state_dict(state_dict, strict=False) 61 | 62 | elif checkpoint_path.endswith(".ckpt"): 63 | raw_model = torch.load(checkpoint_path, map_location="cpu") 64 | vocos.load_state_dict(raw_model['state_dict'], strict=False) 65 | 66 | model = VocosGen(vocos) 67 | model.eval() 68 | 69 | Path(output_dir).mkdir(parents=True, exist_ok=True) 70 | onnx_filename = f"mel_spec_22khz_wavenext.onnx" 71 | onnx_path = os.path.join(output_dir, onnx_filename) 72 | 73 | dummy_input = torch.rand(1, vocos.backbone.input_channels, 64) 74 | dynamic_axes = { 75 | "mels": {0: "batch_size", 2: "time"}, 76 | } 77 | 78 | #Conventional ONNX export 79 | torch.onnx.export( 80 | model=model, 81 | args=dummy_input, 82 | f=onnx_path, 83 | input_names=["mels"], 84 | output_names=["waveform"], 85 | dynamic_axes=dynamic_axes, 86 | opset_version=opset_version, 87 | export_params=True, 88 | do_constant_folding=True, 89 | ) 90 | 91 | return onnx_path 92 | 93 | 94 | def main(): 95 | logging.basicConfig(level=logging.DEBUG) 96 | 97 | parser = argparse.ArgumentParser( 98 | prog="export_onnx", 99 | description="Export a wavenext checkpoint to onnx", 100 | ) 101 | 102 | parser.add_argument("--config", type=str, required=True) 103 | parser.add_argument("--checkpoint", type=str, required=True) 104 | parser.add_argument("--output-dir", type=str, required=True) 105 | parser.add_argument("--seed", type=int, default=1234, help="random seed") 106 | parser.add_argument("--opset", type=int, default=DEFAULT_OPSET_VERSION) 107 | 108 | args = parser.parse_args() 109 | 110 | random.seed(args.seed) 111 | np.random.seed(args.seed) 112 | torch.manual_seed(args.seed) 113 | torch.cuda.manual_seed(args.seed) 114 | torch.backends.cudnn.deterministic = True 115 | torch.backends.cudnn.benchmark = False 116 | 117 | _LOGGER.info("Exporting model to ONNX") 118 | _LOGGER.info(f"Config path: `{args.config}`") 119 | _LOGGER.info(f"Using checkpoint: `{args.checkpoint}`") 120 | onnx_path = export_generator( 121 | config_path=args.config, 122 | checkpoint_path=args.checkpoint, 123 | output_dir=args.output_dir, 124 | opset_version=args.opset 125 | ) 126 | _LOGGER.info(f"Exported ONNX model to: `{onnx_path}`") 127 | 128 | 129 | if __name__ == '__main__': 130 | main() -------------------------------------------------------------------------------- /README_vocos.md: -------------------------------------------------------------------------------- 1 | # Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for 
high-quality audio synthesis 2 | 3 | [Audio samples](https://gemelo-ai.github.io/vocos/) | 4 | Paper [[abs]](https://arxiv.org/abs/2306.00814) [[pdf]](https://arxiv.org/pdf/2306.00814.pdf) 5 | 6 | Vocos is a fast neural vocoder designed to synthesize audio waveforms from acoustic features. Trained using a Generative 7 | Adversarial Network (GAN) objective, Vocos can generate waveforms in a single forward pass. Unlike other typical 8 | GAN-based vocoders, Vocos does not model audio samples in the time domain. Instead, it generates spectral 9 | coefficients, facilitating rapid audio reconstruction through inverse Fourier transform. 10 | 11 | ## Installation 12 | 13 | To use Vocos only in inference mode, install it using: 14 | 15 | ```bash 16 | pip install vocos 17 | ``` 18 | 19 | If you wish to train the model, install it with additional dependencies: 20 | 21 | ```bash 22 | pip install vocos[train] 23 | ``` 24 | 25 | ## Usage 26 | 27 | ### Reconstruct audio from mel-spectrogram 28 | 29 | ```python 30 | import torch 31 | 32 | from vocos import Vocos 33 | 34 | vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") 35 | 36 | mel = torch.randn(1, 100, 256) # B, C, T 37 | audio = vocos.decode(mel) 38 | ``` 39 | 40 | Copy-synthesis from a file: 41 | 42 | ```python 43 | import torchaudio 44 | 45 | y, sr = torchaudio.load(YOUR_AUDIO_FILE) 46 | if y.size(0) > 1: # mix to mono 47 | y = y.mean(dim=0, keepdim=True) 48 | y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=24000) 49 | y_hat = vocos(y) 50 | ``` 51 | 52 | ### Reconstruct audio from EnCodec tokens 53 | 54 | Additionally, you need to provide a `bandwidth_id` which corresponds to the embedding for bandwidth from the 55 | list: `[1.5, 3.0, 6.0, 12.0]`. 56 | 57 | ```python 58 | vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz") 59 | 60 | audio_tokens = torch.randint(low=0, high=1024, size=(8, 200)) # 8 codeboooks, 200 frames 61 | features = vocos.codes_to_features(audio_tokens) 62 | bandwidth_id = torch.tensor([2]) # 6 kbps 63 | 64 | audio = vocos.decode(features, bandwidth_id=bandwidth_id) 65 | ``` 66 | 67 | Copy-synthesis from a file: It extracts and quantizes features with EnCodec, then reconstructs them with Vocos in a 68 | single forward pass. 69 | 70 | ```python 71 | y, sr = torchaudio.load(YOUR_AUDIO_FILE) 72 | if y.size(0) > 1: # mix to mono 73 | y = y.mean(dim=0, keepdim=True) 74 | y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=24000) 75 | 76 | y_hat = vocos(y, bandwidth_id=bandwidth_id) 77 | ``` 78 | 79 | ### Integrate with 🐶 [Bark](https://github.com/suno-ai/bark) text-to-audio model 80 | 81 | See [example notebook](notebooks%2FBark%2BVocos.ipynb). 82 | 83 | ## Pre-trained models 84 | 85 | | Model Name | Dataset | Training Iterations | Parameters 86 | |-------------------------------------------------------------------------------------|---------------|-------------------|------------| 87 | | [charactr/vocos-mel-24khz](https://huggingface.co/charactr/vocos-mel-24khz) | LibriTTS | 1M | 13.5M 88 | | [charactr/vocos-encodec-24khz](https://huggingface.co/charactr/vocos-encodec-24khz) | DNS Challenge | 2M | 7.9M 89 | 90 | ## Training 91 | 92 | Prepare a filelist of audio files for the training and validation set: 93 | 94 | ```bash 95 | find $TRAIN_DATASET_DIR -name *.wav > filelist.train 96 | find $VAL_DATASET_DIR -name *.wav > filelist.val 97 | ``` 98 | 99 | Fill a config file, e.g. 
[vocos.yaml](configs%2Fvocos.yaml), with your filelist paths and start training with: 100 | 101 | ```bash 102 | python train.py -c configs/vocos.yaml 103 | ``` 104 | 105 | Refer to [Pytorch Lightning documentation](https://lightning.ai/docs/pytorch/stable/) for details about customizing the 106 | training pipeline. 107 | 108 | ## Citation 109 | 110 | If this code contributes to your research, please cite our work: 111 | 112 | ``` 113 | @article{siuzdak2023vocos, 114 | title={Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for high-quality audio synthesis}, 115 | author={Siuzdak, Hubert}, 116 | journal={arXiv preprint arXiv:2306.00814}, 117 | year={2023} 118 | } 119 | ``` 120 | 121 | ## License 122 | 123 | The code in this repository is released under the MIT license as found in the 124 | [LICENSE](LICENSE) file. 125 | -------------------------------------------------------------------------------- /vocos/feature_extractors.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torchaudio 5 | from encodec import EncodecModel 6 | from torch import nn 7 | 8 | from vocos.modules import safe_log 9 | 10 | 11 | class FeatureExtractor(nn.Module): 12 | """Base class for feature extractors.""" 13 | 14 | def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor: 15 | """ 16 | Extract features from the given audio. 17 | 18 | Args: 19 | audio (Tensor): Input audio waveform. 20 | 21 | Returns: 22 | Tensor: Extracted features of shape (B, C, L), where B is the batch size, 23 | C denotes output features, and L is the sequence length. 24 | """ 25 | raise NotImplementedError("Subclasses must implement the forward method.") 26 | 27 | 28 | class MelSpectrogramFeatures(FeatureExtractor): 29 | def __init__(self, 30 | sample_rate=24000, 31 | n_fft=1024, 32 | hop_length=256, 33 | n_mels=100, 34 | padding="center", 35 | f_min=0, # to match matcha :X 36 | f_max=None, 37 | norm=None, 38 | mel_scale="htk", 39 | clip_val=1e-7): 40 | super().__init__() 41 | if padding not in ["center", "same"]: 42 | raise ValueError("Padding must be 'center' or 'same'.") 43 | self.padding = padding 44 | self.clip_val = clip_val 45 | self.mel_spec = torchaudio.transforms.MelSpectrogram( 46 | sample_rate=sample_rate, 47 | n_fft=n_fft, 48 | hop_length=hop_length, 49 | n_mels=n_mels, 50 | center=padding == "center", 51 | power=1, 52 | f_min=f_min, # to match matcha :X 53 | f_max=f_max, 54 | norm=norm, 55 | mel_scale=mel_scale 56 | ) 57 | 58 | def forward(self, audio, **kwargs): 59 | if self.padding == "same": 60 | pad = self.mel_spec.win_length - self.mel_spec.hop_length 61 | audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect") 62 | mel = self.mel_spec(audio) 63 | features = safe_log(mel, clip_val=self.clip_val) 64 | return features 65 | 66 | 67 | class EncodecFeatures(FeatureExtractor): 68 | def __init__( 69 | self, 70 | encodec_model: str = "encodec_24khz", 71 | bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0], 72 | train_codebooks: bool = False, 73 | ): 74 | super().__init__() 75 | if encodec_model == "encodec_24khz": 76 | encodec = EncodecModel.encodec_model_24khz 77 | elif encodec_model == "encodec_48khz": 78 | encodec = EncodecModel.encodec_model_48khz 79 | else: 80 | raise ValueError( 81 | f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'." 
82 | ) 83 | self.encodec = encodec(pretrained=True) 84 | for param in self.encodec.parameters(): 85 | param.requires_grad = False 86 | self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth( 87 | self.encodec.frame_rate, bandwidth=max(bandwidths) 88 | ) 89 | codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0) 90 | self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks) 91 | self.bandwidths = bandwidths 92 | 93 | @torch.no_grad() 94 | def get_encodec_codes(self, audio): 95 | audio = audio.unsqueeze(1) 96 | emb = self.encodec.encoder(audio) 97 | codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth) 98 | return codes 99 | 100 | def forward(self, audio: torch.Tensor, **kwargs): 101 | bandwidth_id = kwargs.get("bandwidth_id") 102 | if bandwidth_id is None: 103 | raise ValueError("The 'bandwidth_id' argument is required") 104 | self.encodec.eval() # Force eval mode as Pytorch Lightning automatically sets child modules to training mode 105 | self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id]) 106 | codes = self.get_encodec_codes(audio) 107 | # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights` 108 | # with offsets given by the number of bins, and finally summed in a vectorized operation. 109 | offsets = torch.arange( 110 | 0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device 111 | ) 112 | embeddings_idxs = codes + offsets.view(-1, 1, 1) 113 | features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0) 114 | return features.transpose(1, 2) 115 | -------------------------------------------------------------------------------- /vocos/models.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.utils import weight_norm 6 | 7 | from vocos.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm 8 | 9 | 10 | class Backbone(nn.Module): 11 | """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" 12 | 13 | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: 14 | """ 15 | Args: 16 | x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, 17 | C denotes output features, and L is the sequence length. 18 | 19 | Returns: 20 | Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, 21 | and H denotes the model dimension. 22 | """ 23 | raise NotImplementedError("Subclasses must implement the forward method.") 24 | 25 | 26 | class VocosBackbone(Backbone): 27 | """ 28 | Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization 29 | 30 | Args: 31 | input_channels (int): Number of input features channels. 32 | dim (int): Hidden dimension of the model. 33 | intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. 34 | num_layers (int): Number of ConvNeXtBlock layers. 35 | layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. 36 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. 37 | None means non-conditional model. Defaults to None. 
38 | """ 39 | 40 | def __init__( 41 | self, 42 | input_channels: int, 43 | dim: int, 44 | intermediate_dim: int, 45 | num_layers: int, 46 | layer_scale_init_value: Optional[float] = None, 47 | adanorm_num_embeddings: Optional[int] = None, 48 | ): 49 | super().__init__() 50 | self.input_channels = input_channels 51 | self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) 52 | self.adanorm = adanorm_num_embeddings is not None 53 | if adanorm_num_embeddings: 54 | self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) 55 | else: 56 | self.norm = nn.LayerNorm(dim, eps=1e-6) 57 | layer_scale_init_value = layer_scale_init_value or 1 / num_layers 58 | self.convnext = nn.ModuleList( 59 | [ 60 | ConvNeXtBlock( 61 | dim=dim, 62 | intermediate_dim=intermediate_dim, 63 | layer_scale_init_value=layer_scale_init_value, 64 | adanorm_num_embeddings=adanorm_num_embeddings, 65 | ) 66 | for _ in range(num_layers) 67 | ] 68 | ) 69 | self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) 70 | self.apply(self._init_weights) 71 | 72 | def _init_weights(self, m): 73 | if isinstance(m, (nn.Conv1d, nn.Linear)): 74 | nn.init.trunc_normal_(m.weight, std=0.02) 75 | nn.init.constant_(m.bias, 0) 76 | 77 | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: 78 | bandwidth_id = kwargs.get('bandwidth_id', None) 79 | x = self.embed(x) 80 | if self.adanorm: 81 | assert bandwidth_id is not None 82 | x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) 83 | else: 84 | x = self.norm(x.transpose(1, 2)) 85 | x = x.transpose(1, 2) 86 | for conv_block in self.convnext: 87 | x = conv_block(x, cond_embedding_id=bandwidth_id) 88 | x = self.final_layer_norm(x.transpose(1, 2)) 89 | return x 90 | 91 | 92 | class VocosResNetBackbone(Backbone): 93 | """ 94 | Vocos backbone module built with ResBlocks. 95 | 96 | Args: 97 | input_channels (int): Number of input features channels. 98 | dim (int): Hidden dimension of the model. 99 | num_blocks (int): Number of ResBlock1 blocks. 100 | layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. 
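    Example (editor's illustrative sketch, arbitrary sizes):

        >>> import torch
        >>> backbone = VocosResNetBackbone(input_channels=100, dim=512, num_blocks=3)
        >>> features = torch.randn(1, 100, 200)   # (B, C, L) input features
        >>> backbone(features).shape              # -> (B, L, H) == torch.Size([1, 200, 512])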
101 | """ 102 | 103 | def __init__( 104 | self, input_channels, dim, num_blocks, layer_scale_init_value=None, 105 | ): 106 | super().__init__() 107 | self.input_channels = input_channels 108 | self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)) 109 | layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 110 | self.resnet = nn.Sequential( 111 | *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)] 112 | ) 113 | 114 | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: 115 | x = self.embed(x) 116 | x = self.resnet(x) 117 | x = x.transpose(1, 2) 118 | return x 119 | -------------------------------------------------------------------------------- /vocos/loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | import torchaudio 5 | from torch import nn 6 | 7 | from vocos.modules import safe_log 8 | 9 | 10 | class MelSpecReconstructionLoss(nn.Module): 11 | """ 12 | L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample 13 | """ 14 | 15 | def __init__(self, 16 | sample_rate: int = 24000, 17 | n_fft: int = 1024, 18 | hop_length: int = 256, 19 | n_mels: int = 100, 20 | f_min: float = 0, 21 | f_max: float = None, 22 | norm: str = None, 23 | mel_scale: str = "htk", 24 | clip_val: float = 1e-7 25 | ): 26 | super().__init__() 27 | self.clip_val = clip_val 28 | self.mel_spec = torchaudio.transforms.MelSpectrogram( 29 | sample_rate=sample_rate, 30 | n_fft=n_fft, 31 | hop_length=hop_length, 32 | n_mels=n_mels, 33 | center=True, 34 | power=1, 35 | f_min=f_min, 36 | f_max=f_max, 37 | norm=norm, 38 | mel_scale=mel_scale 39 | ) 40 | 41 | def forward(self, y_hat, y) -> torch.Tensor: 42 | """ 43 | Args: 44 | y_hat (Tensor): Predicted audio waveform. 45 | y (Tensor): Ground truth audio waveform. 46 | 47 | Returns: 48 | Tensor: L1 loss between the mel-scaled magnitude spectrograms. 49 | """ 50 | mel_hat = safe_log(self.mel_spec(y_hat), clip_val=self.clip_val) 51 | mel = safe_log(self.mel_spec(y), clip_val=self.clip_val) 52 | 53 | loss = torch.nn.functional.l1_loss(mel, mel_hat) 54 | 55 | return loss 56 | 57 | 58 | class GeneratorLoss(nn.Module): 59 | """ 60 | Generator Loss module. Calculates the loss for the generator based on discriminator outputs. 61 | """ 62 | 63 | def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]: 64 | """ 65 | Args: 66 | disc_outputs (List[Tensor]): List of discriminator outputs. 67 | 68 | Returns: 69 | Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from 70 | the sub-discriminators 71 | """ 72 | loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype) 73 | gen_losses = [] 74 | for dg in disc_outputs: 75 | l = torch.mean(torch.clamp(1 - dg, min=0)) 76 | gen_losses.append(l) 77 | loss += l 78 | 79 | return loss, gen_losses 80 | 81 | 82 | class DiscriminatorLoss(nn.Module): 83 | """ 84 | Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. 85 | """ 86 | 87 | def forward( 88 | self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] 89 | ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: 90 | """ 91 | Args: 92 | disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. 
93 | disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. 94 | 95 | Returns: 96 | Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from 97 | the sub-discriminators for real outputs, and a list of 98 | loss values for generated outputs. 99 | """ 100 | loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype) 101 | r_losses = [] 102 | g_losses = [] 103 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 104 | r_loss = torch.mean(torch.clamp(1 - dr, min=0)) 105 | g_loss = torch.mean(torch.clamp(1 + dg, min=0)) 106 | loss += r_loss + g_loss 107 | r_losses.append(r_loss) 108 | g_losses.append(g_loss) 109 | 110 | return loss, r_losses, g_losses 111 | 112 | 113 | class FeatureMatchingLoss(nn.Module): 114 | """ 115 | Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators. 116 | """ 117 | 118 | def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor: 119 | """ 120 | Args: 121 | fmap_r (List[List[Tensor]]): List of feature maps from real samples. 122 | fmap_g (List[List[Tensor]]): List of feature maps from generated samples. 123 | 124 | Returns: 125 | Tensor: The calculated feature matching loss. 126 | """ 127 | loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype) 128 | for dr, dg in zip(fmap_r, fmap_g): 129 | for rl, gl in zip(dr, dg): 130 | loss += torch.mean(torch.abs(rl - gl)) 131 | 132 | return loss 133 | -------------------------------------------------------------------------------- /vocos/pretrained.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Dict, Tuple, Union, Optional 4 | 5 | import torch 6 | import yaml 7 | from huggingface_hub import hf_hub_download 8 | from torch import nn 9 | from vocos.feature_extractors import FeatureExtractor, EncodecFeatures 10 | from vocos.heads import FourierHead 11 | from vocos.models import Backbone 12 | 13 | 14 | def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: 15 | """Instantiates a class with the given args and init. 16 | 17 | Args: 18 | args: Positional arguments required for instantiation. 19 | init: Dict of the form {"class_path":...,"init_args":...}. 20 | 21 | Returns: 22 | The instantiated class object. 23 | """ 24 | kwargs = init.get("init_args", {}) 25 | if not isinstance(args, tuple): 26 | args = (args,) 27 | class_module, class_name = init["class_path"].rsplit(".", 1) 28 | module = __import__(class_module, fromlist=[class_name]) 29 | args_class = getattr(module, class_name) 30 | return args_class(*args, **kwargs) 31 | 32 | 33 | class Vocos(nn.Module): 34 | """ 35 | The Vocos class represents a Fourier-based neural vocoder for audio synthesis. 36 | This class is primarily designed for inference, with support for loading from pretrained 37 | model checkpoints. It consists of three main components: a feature extractor, 38 | a backbone, and a head. 
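    Example (editor's illustrative sketch; the checkpoint id is the one used in the bundled
    Bark+Vocos notebook, and passing `bandwidth_id` mirrors the usage shown there):

        >>> import torch
        >>> vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
        >>> audio = torch.randn(1, 24000)                              # (B, T) waveform at 24 kHz
        >>> audio_hat = vocos(audio, bandwidth_id=torch.tensor([2]))   # copy-synthesis at 6 kbps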
39 | """ 40 | 41 | def __init__( 42 | self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead, 43 | ): 44 | super().__init__() 45 | self.feature_extractor = feature_extractor 46 | self.backbone = backbone 47 | self.head = head 48 | 49 | @classmethod 50 | def from_hparams(cls, config_path: str) -> Vocos: 51 | """ 52 | Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. 53 | """ 54 | with open(config_path, "r") as f: 55 | config = yaml.safe_load(f) 56 | feature_extractor = instantiate_class(args=(), init=config["feature_extractor"]) 57 | backbone = instantiate_class(args=(), init=config["backbone"]) 58 | head = instantiate_class(args=(), init=config["head"]) 59 | model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) 60 | return model 61 | 62 | @classmethod 63 | def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos: 64 | """ 65 | Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. 66 | """ 67 | config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml", revision=revision) 68 | model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", revision=revision) 69 | model = cls.from_hparams(config_path) 70 | state_dict = torch.load(model_path, map_location="cpu") 71 | if isinstance(model.feature_extractor, EncodecFeatures): 72 | encodec_parameters = { 73 | "feature_extractor.encodec." + key: value 74 | for key, value in model.feature_extractor.encodec.state_dict().items() 75 | } 76 | state_dict.update(encodec_parameters) 77 | model.load_state_dict(state_dict) 78 | model.eval() 79 | return model 80 | 81 | @torch.inference_mode() 82 | def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: 83 | """ 84 | Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input, 85 | which is then passed through the backbone and the head to reconstruct the audio output. 86 | 87 | Args: 88 | audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T), 89 | where B is the batch size and L is the waveform length. 90 | 91 | 92 | Returns: 93 | Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). 94 | """ 95 | features = self.feature_extractor(audio_input, **kwargs) 96 | audio_output = self.decode(features, **kwargs) 97 | return audio_output 98 | 99 | @torch.inference_mode() 100 | def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: 101 | """ 102 | Method to decode audio waveform from already calculated features. The features input is passed through 103 | the backbone and the head to reconstruct the audio output. 104 | 105 | Args: 106 | features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size, 107 | C denotes the feature dimension, and L is the sequence length. 108 | 109 | Returns: 110 | Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). 111 | """ 112 | x = self.backbone(features_input, **kwargs) 113 | audio_output = self.head(x) 114 | return audio_output 115 | 116 | @torch.inference_mode() 117 | def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor: 118 | """ 119 | Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's 120 | codebook weights. 121 | 122 | Args: 123 | codes (Tensor): The input tensor. 
Expected shape is (K, L) or (K, B, L), 124 | where K is the number of codebooks, B is the batch size and L is the sequence length. 125 | 126 | Returns: 127 | Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension, 128 | and L is the sequence length. 129 | """ 130 | assert isinstance( 131 | self.feature_extractor, EncodecFeatures 132 | ), "Feature extractor should be an instance of EncodecFeatures" 133 | 134 | if codes.dim() == 2: 135 | codes = codes.unsqueeze(1) 136 | 137 | n_bins = self.feature_extractor.encodec.quantizer.bins 138 | offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device) 139 | embeddings_idxs = codes + offsets.view(-1, 1, 1) 140 | features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0) 141 | features = features.transpose(1, 2) 142 | 143 | return features 144 | -------------------------------------------------------------------------------- /metrics/UTMOS.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import fairseq 4 | import pytorch_lightning as pl 5 | import requests 6 | import torch 7 | import torch.nn as nn 8 | from tqdm import tqdm 9 | 10 | UTMOS_CKPT_URL = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt" 11 | WAV2VEC_URL = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt" 12 | 13 | """ 14 | UTMOS score, automatic Mean Opinion Score (MOS) prediction system, 15 | adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo 16 | """ 17 | 18 | 19 | class UTMOSScore: 20 | """Predicting score for each audio clip.""" 21 | 22 | def __init__(self, device, ckpt_path="epoch=3-step=7459.ckpt"): 23 | self.device = device 24 | filepath = os.path.join(os.path.dirname(__file__), ckpt_path) 25 | if not os.path.exists(filepath): 26 | download_file(UTMOS_CKPT_URL, filepath) 27 | self.model = BaselineLightningModule.load_from_checkpoint(filepath).eval().to(device) 28 | 29 | def score(self, wavs: torch.Tensor) -> torch.Tensor: 30 | """ 31 | Args: 32 | wavs: audio waveform to be evaluated. When len(wavs) == 1 or 2, 33 | the model processes the input as a single audio clip. The model 34 | performs batch processing when len(wavs) == 3. 35 | """ 36 | if len(wavs.shape) == 1: 37 | out_wavs = wavs.unsqueeze(0).unsqueeze(0) 38 | elif len(wavs.shape) == 2: 39 | out_wavs = wavs.unsqueeze(0) 40 | elif len(wavs.shape) == 3: 41 | out_wavs = wavs 42 | else: 43 | raise ValueError("Dimension of input tensor needs to be <= 3.") 44 | bs = out_wavs.shape[0] 45 | batch = { 46 | "wav": out_wavs, 47 | "domains": torch.zeros(bs, dtype=torch.int).to(self.device), 48 | "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288, 49 | } 50 | with torch.no_grad(): 51 | output = self.model(batch) 52 | 53 | return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3 54 | 55 | 56 | def download_file(url, filename): 57 | """ 58 | Downloads a file from the given URL 59 | 60 | Args: 61 | url (str): The URL of the file to download. 62 | filename (str): The name to save the file as. 
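    Example (illustrative; the URL constant and checkpoint filename are the ones defined
    at the top of this module):

        >>> download_file(UTMOS_CKPT_URL, "epoch=3-step=7459.ckpt")   # streams the file to disk with a progress bar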
63 | """ 64 | print(f"Downloading file {filename}...") 65 | response = requests.get(url, stream=True) 66 | response.raise_for_status() 67 | 68 | total_size_in_bytes = int(response.headers.get("content-length", 0)) 69 | progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) 70 | 71 | with open(filename, "wb") as f: 72 | for chunk in response.iter_content(chunk_size=8192): 73 | progress_bar.update(len(chunk)) 74 | f.write(chunk) 75 | 76 | progress_bar.close() 77 | 78 | 79 | def load_ssl_model(ckpt_path="wav2vec_small.pt"): 80 | filepath = os.path.join(os.path.dirname(__file__), ckpt_path) 81 | if not os.path.exists(filepath): 82 | download_file(WAV2VEC_URL, filepath) 83 | SSL_OUT_DIM = 768 84 | model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([filepath]) 85 | ssl_model = model[0] 86 | ssl_model.remove_pretraining_modules() 87 | return SSL_model(ssl_model, SSL_OUT_DIM) 88 | 89 | 90 | class BaselineLightningModule(pl.LightningModule): 91 | def __init__(self, cfg): 92 | super().__init__() 93 | self.cfg = cfg 94 | self.construct_model() 95 | self.save_hyperparameters() 96 | 97 | def construct_model(self): 98 | self.feature_extractors = nn.ModuleList( 99 | [load_ssl_model(ckpt_path="wav2vec_small.pt"), DomainEmbedding(3, 128),] 100 | ) 101 | output_dim = sum([feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors]) 102 | output_layers = [LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)] 103 | output_dim = output_layers[-1].get_output_dim() 104 | output_layers.append( 105 | Projection(hidden_dim=2048, activation=torch.nn.ReLU(), range_clipping=False, input_dim=output_dim) 106 | ) 107 | 108 | self.output_layers = nn.ModuleList(output_layers) 109 | 110 | def forward(self, inputs): 111 | outputs = {} 112 | for feature_extractor in self.feature_extractors: 113 | outputs.update(feature_extractor(inputs)) 114 | x = outputs 115 | for output_layer in self.output_layers: 116 | x = output_layer(x, inputs) 117 | return x 118 | 119 | 120 | class SSL_model(nn.Module): 121 | def __init__(self, ssl_model, ssl_out_dim) -> None: 122 | super(SSL_model, self).__init__() 123 | self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim 124 | 125 | def forward(self, batch): 126 | wav = batch["wav"] 127 | wav = wav.squeeze(1) # [batches, audio_len] 128 | res = self.ssl_model(wav, mask=False, features_only=True) 129 | x = res["x"] 130 | return {"ssl-feature": x} 131 | 132 | def get_output_dim(self): 133 | return self.ssl_out_dim 134 | 135 | 136 | class DomainEmbedding(nn.Module): 137 | def __init__(self, n_domains, domain_dim) -> None: 138 | super().__init__() 139 | self.embedding = nn.Embedding(n_domains, domain_dim) 140 | self.output_dim = domain_dim 141 | 142 | def forward(self, batch): 143 | return {"domain-feature": self.embedding(batch["domains"])} 144 | 145 | def get_output_dim(self): 146 | return self.output_dim 147 | 148 | 149 | class LDConditioner(nn.Module): 150 | """ 151 | Conditions ssl output by listener embedding 152 | """ 153 | 154 | def __init__(self, input_dim, judge_dim, num_judges=None): 155 | super().__init__() 156 | self.input_dim = input_dim 157 | self.judge_dim = judge_dim 158 | self.num_judges = num_judges 159 | assert num_judges != None 160 | self.judge_embedding = nn.Embedding(num_judges, self.judge_dim) 161 | # concat [self.output_layer, phoneme features] 162 | 163 | self.decoder_rnn = nn.LSTM( 164 | input_size=self.input_dim + self.judge_dim, 165 | hidden_size=512, 166 | num_layers=1, 167 | 
batch_first=True, 168 | bidirectional=True, 169 | ) # linear? 170 | self.out_dim = self.decoder_rnn.hidden_size * 2 171 | 172 | def get_output_dim(self): 173 | return self.out_dim 174 | 175 | def forward(self, x, batch): 176 | judge_ids = batch["judge_id"] 177 | if "phoneme-feature" in x.keys(): 178 | concatenated_feature = torch.cat( 179 | (x["ssl-feature"], x["phoneme-feature"].unsqueeze(1).expand(-1, x["ssl-feature"].size(1), -1)), dim=2 180 | ) 181 | else: 182 | concatenated_feature = x["ssl-feature"] 183 | if "domain-feature" in x.keys(): 184 | concatenated_feature = torch.cat( 185 | (concatenated_feature, x["domain-feature"].unsqueeze(1).expand(-1, concatenated_feature.size(1), -1),), 186 | dim=2, 187 | ) 188 | if judge_ids != None: 189 | concatenated_feature = torch.cat( 190 | ( 191 | concatenated_feature, 192 | self.judge_embedding(judge_ids).unsqueeze(1).expand(-1, concatenated_feature.size(1), -1), 193 | ), 194 | dim=2, 195 | ) 196 | decoder_output, (h, c) = self.decoder_rnn(concatenated_feature) 197 | return decoder_output 198 | 199 | 200 | class Projection(nn.Module): 201 | def __init__(self, input_dim, hidden_dim, activation, range_clipping=False): 202 | super(Projection, self).__init__() 203 | self.range_clipping = range_clipping 204 | output_dim = 1 205 | if range_clipping: 206 | self.proj = nn.Tanh() 207 | 208 | self.net = nn.Sequential( 209 | nn.Linear(input_dim, hidden_dim), activation, nn.Dropout(0.3), nn.Linear(hidden_dim, output_dim), 210 | ) 211 | self.output_dim = output_dim 212 | 213 | def forward(self, x, batch): 214 | output = self.net(x) 215 | 216 | # range clipping 217 | if self.range_clipping: 218 | return self.proj(output) * 2.0 + 3 219 | else: 220 | return output 221 | 222 | def get_output_dim(self): 223 | return self.output_dim 224 | -------------------------------------------------------------------------------- /vocos/modules.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | 8 | class ConvNeXtBlock(nn.Module): 9 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. 10 | 11 | Args: 12 | dim (int): Number of input channels. 13 | intermediate_dim (int): Dimensionality of the intermediate layer. 14 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 15 | Defaults to None. 16 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. 17 | None means non-conditional LayerNorm. Defaults to None. 
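    Example (editor's illustrative sketch, arbitrary sizes):

        >>> import torch
        >>> block = ConvNeXtBlock(dim=512, intermediate_dim=1536, layer_scale_init_value=1e-6)
        >>> x = torch.randn(1, 512, 200)   # (B, C, T)
        >>> block(x).shape                 # shape is preserved: torch.Size([1, 512, 200])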
18 | """ 19 | 20 | def __init__( 21 | self, 22 | dim: int, 23 | intermediate_dim: int, 24 | layer_scale_init_value: float, 25 | adanorm_num_embeddings: Optional[int] = None, 26 | ): 27 | super().__init__() 28 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 29 | self.adanorm = adanorm_num_embeddings is not None 30 | if adanorm_num_embeddings: 31 | self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) 32 | else: 33 | self.norm = nn.LayerNorm(dim, eps=1e-6) 34 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 35 | self.act = nn.GELU() 36 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 37 | self.gamma = ( 38 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 39 | if layer_scale_init_value > 0 40 | else None 41 | ) 42 | 43 | def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: 44 | residual = x 45 | x = self.dwconv(x) 46 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 47 | if self.adanorm: 48 | assert cond_embedding_id is not None 49 | x = self.norm(x, cond_embedding_id) 50 | else: 51 | x = self.norm(x) 52 | x = self.pwconv1(x) 53 | x = self.act(x) 54 | x = self.pwconv2(x) 55 | if self.gamma is not None: 56 | x = self.gamma * x 57 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 58 | 59 | x = residual + x 60 | return x 61 | 62 | 63 | class AdaLayerNorm(nn.Module): 64 | """ 65 | Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes 66 | 67 | Args: 68 | num_embeddings (int): Number of embeddings. 69 | embedding_dim (int): Dimension of the embeddings. 70 | """ 71 | 72 | def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): 73 | super().__init__() 74 | self.eps = eps 75 | self.dim = embedding_dim 76 | self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 77 | self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 78 | torch.nn.init.ones_(self.scale.weight) 79 | torch.nn.init.zeros_(self.shift.weight) 80 | 81 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: 82 | scale = self.scale(cond_embedding_id) 83 | shift = self.shift(cond_embedding_id) 84 | x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) 85 | x = x * scale + shift 86 | return x 87 | 88 | 89 | class ResBlock1(nn.Module): 90 | """ 91 | ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, 92 | but without upsampling layers. 93 | 94 | Args: 95 | dim (int): Number of input channels. 96 | kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. 97 | dilation (tuple[int], optional): Dilation factors for the dilated convolutions. 98 | Defaults to (1, 3, 5). 99 | lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. 100 | Defaults to 0.1. 101 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 102 | Defaults to None. 
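    Example (editor's illustrative sketch, arbitrary sizes):

        >>> import torch
        >>> block = ResBlock1(dim=512)
        >>> x = torch.randn(1, 512, 200)   # (B, C, T)
        >>> block(x).shape                 # shape is preserved: torch.Size([1, 512, 200])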
103 | """ 104 | 105 | def __init__( 106 | self, 107 | dim: int, 108 | kernel_size: int = 3, 109 | dilation: Tuple[int, int, int] = (1, 3, 5), 110 | lrelu_slope: float = 0.1, 111 | layer_scale_init_value: Optional[float] = None, 112 | ): 113 | super().__init__() 114 | self.lrelu_slope = lrelu_slope 115 | self.convs1 = nn.ModuleList( 116 | [ 117 | weight_norm( 118 | nn.Conv1d( 119 | dim, 120 | dim, 121 | kernel_size, 122 | 1, 123 | dilation=dilation[0], 124 | padding=self.get_padding(kernel_size, dilation[0]), 125 | ) 126 | ), 127 | weight_norm( 128 | nn.Conv1d( 129 | dim, 130 | dim, 131 | kernel_size, 132 | 1, 133 | dilation=dilation[1], 134 | padding=self.get_padding(kernel_size, dilation[1]), 135 | ) 136 | ), 137 | weight_norm( 138 | nn.Conv1d( 139 | dim, 140 | dim, 141 | kernel_size, 142 | 1, 143 | dilation=dilation[2], 144 | padding=self.get_padding(kernel_size, dilation[2]), 145 | ) 146 | ), 147 | ] 148 | ) 149 | 150 | self.convs2 = nn.ModuleList( 151 | [ 152 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 153 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 154 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 155 | ] 156 | ) 157 | 158 | self.gamma = nn.ParameterList( 159 | [ 160 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 161 | if layer_scale_init_value is not None 162 | else None, 163 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 164 | if layer_scale_init_value is not None 165 | else None, 166 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 167 | if layer_scale_init_value is not None 168 | else None, 169 | ] 170 | ) 171 | 172 | def forward(self, x: torch.Tensor) -> torch.Tensor: 173 | for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): 174 | xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) 175 | xt = c1(xt) 176 | xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) 177 | xt = c2(xt) 178 | if gamma is not None: 179 | xt = gamma * xt 180 | x = xt + x 181 | return x 182 | 183 | def remove_weight_norm(self): 184 | for l in self.convs1: 185 | remove_weight_norm(l) 186 | for l in self.convs2: 187 | remove_weight_norm(l) 188 | 189 | @staticmethod 190 | def get_padding(kernel_size: int, dilation: int = 1) -> int: 191 | return int((kernel_size * dilation - dilation) / 2) 192 | 193 | 194 | def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: 195 | """ 196 | Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. 197 | 198 | Args: 199 | x (Tensor): Input tensor. 200 | clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. 201 | 202 | Returns: 203 | Tensor: Element-wise logarithm of the input tensor with clipping applied. 
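    Example (illustrative):

        >>> safe_log(torch.tensor([0.0, 1.0]))   # 0.0 is clipped to 1e-7, so the first entry is log(1e-7) ≈ -16.12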
204 | """ 205 | return torch.log(torch.clip(x, min=float(clip_val))) 206 | 207 | 208 | def symlog(x: torch.Tensor) -> torch.Tensor: 209 | return torch.sign(x) * torch.log1p(x.abs()) 210 | 211 | 212 | def symexp(x: torch.Tensor) -> torch.Tensor: 213 | return torch.sign(x) * (torch.exp(x.abs()) - 1) 214 | -------------------------------------------------------------------------------- /notebooks/Bark+Vocos.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "private_outputs": true, 7 | "provenance": [], 8 | "gpuType": "T4", 9 | "authorship_tag": "ABX9TyMC53IsYoVJIVijVzw3ADvX", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "source": [ 35 | "# Text-to-Audio Synthesis using Bark and Vocos" 36 | ], 37 | "metadata": { 38 | "id": "NuRzVtHDZ_Gl" 39 | } 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "source": [ 44 | "In this notebook, we use [Bark](https://github.com/suno-ai/bark) generative model to turn a text prompt into EnCodec audio tokens. These tokens then go through two decoders, EnCodec and Vocos, to reconstruct the audio waveform. Compare the results to discover the differences in audio quality and characteristics." 45 | ], 46 | "metadata": { 47 | "id": "zJFDte0daDAz" 48 | } 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "source": [ 53 | "Make sure you have Bark and Vocos installed:" 54 | ], 55 | "metadata": { 56 | "id": "c9omqGDYnajY" 57 | } 58 | }, 59 | { 60 | "cell_type": "code", 61 | "source": [ 62 | "!pip install git+https://github.com/suno-ai/bark.git\n", 63 | "!pip install vocos" 64 | ], 65 | "metadata": { 66 | "id": "voH44g90NvtV" 67 | }, 68 | "execution_count": null, 69 | "outputs": [] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "source": [ 74 | "Download and load Bark models" 75 | ], 76 | "metadata": { 77 | "id": "s3cEjOIuj6tq" 78 | } 79 | }, 80 | { 81 | "cell_type": "code", 82 | "source": [ 83 | "from bark import preload_models\n", 84 | "\n", 85 | "preload_models()" 86 | ], 87 | "metadata": { 88 | "id": "1H7XtXRMjxUM" 89 | }, 90 | "execution_count": null, 91 | "outputs": [] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "source": [ 96 | "Download and load Vocos." 97 | ], 98 | "metadata": { 99 | "id": "YO1m0dJ1j-F5" 100 | } 101 | }, 102 | { 103 | "cell_type": "code", 104 | "source": [ 105 | "from vocos import Vocos\n", 106 | "import torch\n", 107 | "\n", 108 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 109 | "vocos = Vocos.from_pretrained(\"charactr/vocos-encodec-24khz\").to(device)" 110 | ], 111 | "metadata": { 112 | "id": "COQYTDDFkBCq" 113 | }, 114 | "execution_count": null, 115 | "outputs": [] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "source": [ 120 | "We are going to reuse `text_to_semantic` from Bark API, but to reconstruct audio waveform with a custom vododer, we need to slightly redefine the API to return `fine_tokens`." 
121 | ], 122 | "metadata": { 123 | "id": "--RjqW0rk5JQ" 124 | } 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": { 130 | "id": "OiUsuN2DNl5S" 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "from typing import Optional, Union, Dict\n", 135 | "\n", 136 | "import numpy as np\n", 137 | "from bark.generation import generate_coarse, generate_fine\n", 138 | "\n", 139 | "\n", 140 | "def semantic_to_audio_tokens(\n", 141 | " semantic_tokens: np.ndarray,\n", 142 | " history_prompt: Optional[Union[Dict, str]] = None,\n", 143 | " temp: float = 0.7,\n", 144 | " silent: bool = False,\n", 145 | " output_full: bool = False,\n", 146 | "):\n", 147 | " coarse_tokens = generate_coarse(\n", 148 | " semantic_tokens, history_prompt=history_prompt, temp=temp, silent=silent, use_kv_caching=True\n", 149 | " )\n", 150 | " fine_tokens = generate_fine(coarse_tokens, history_prompt=history_prompt, temp=0.5)\n", 151 | "\n", 152 | " if output_full:\n", 153 | " full_generation = {\n", 154 | " \"semantic_prompt\": semantic_tokens,\n", 155 | " \"coarse_prompt\": coarse_tokens,\n", 156 | " \"fine_prompt\": fine_tokens,\n", 157 | " }\n", 158 | " return full_generation\n", 159 | " return fine_tokens" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "source": [ 165 | "Let's create a text prompt and generate audio tokens:" 166 | ], 167 | "metadata": { 168 | "id": "Cv8KCzXlmoF9" 169 | } 170 | }, 171 | { 172 | "cell_type": "code", 173 | "source": [ 174 | "from bark import text_to_semantic\n", 175 | "\n", 176 | "history_prompt = None\n", 177 | "text_prompt = \"So, you've heard about neural vocoding? [laughs] We've been messing around with this new model called Vocos.\"\n", 178 | "semantic_tokens = text_to_semantic(text_prompt, history_prompt=history_prompt, temp=0.7, silent=False,)\n", 179 | "audio_tokens = semantic_to_audio_tokens(\n", 180 | " semantic_tokens, history_prompt=history_prompt, temp=0.7, silent=False, output_full=False,\n", 181 | ")" 182 | ], 183 | "metadata": { 184 | "id": "pDmSTutoOH_G" 185 | }, 186 | "execution_count": null, 187 | "outputs": [] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "source": [ 192 | "Reconstruct audio waveform with EnCodec:" 193 | ], 194 | "metadata": { 195 | "id": "UYMzI8svTNqI" 196 | } 197 | }, 198 | { 199 | "cell_type": "code", 200 | "source": [ 201 | "from bark.generation import codec_decode\n", 202 | "from IPython.display import Audio\n", 203 | "\n", 204 | "encodec_output = codec_decode(audio_tokens)\n", 205 | "\n", 206 | "import torchaudio\n", 207 | "# Upsample to 44100 Hz for better reproduction on audio hardware\n", 208 | "encodec_output = torchaudio.functional.resample(torch.from_numpy(encodec_output), orig_freq=24000, new_freq=44100)\n", 209 | "Audio(encodec_output, rate=44100)" 210 | ], 211 | "metadata": { 212 | "id": "PzdytlXFTNQ2" 213 | }, 214 | "execution_count": null, 215 | "outputs": [] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "source": [ 220 | "Reconstruct with Vocos:" 221 | ], 222 | "metadata": { 223 | "id": "BhUxBuP9TTTw" 224 | } 225 | }, 226 | { 227 | "cell_type": "code", 228 | "source": [ 229 | "audio_tokens_torch = torch.from_numpy(audio_tokens).to(device)\n", 230 | "features = vocos.codes_to_features(audio_tokens_torch)\n", 231 | "vocos_output = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device)) # 6 kbps\n", 232 | "# Upsample to 44100 Hz for better reproduction on audio hardware\n", 233 | "vocos_output = torchaudio.functional.resample(vocos_output, orig_freq=24000, 
new_freq=44100).cpu()\n", 234 | "Audio(vocos_output.numpy(), rate=44100)" 235 | ], 236 | "metadata": { 237 | "id": "8hzSWQ5-nBlV" 238 | }, 239 | "execution_count": null, 240 | "outputs": [] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "source": [ 245 | "Optionally save to mp3 files:" 246 | ], 247 | "metadata": { 248 | "id": "RjVXQIZRb1Re" 249 | } 250 | }, 251 | { 252 | "cell_type": "code", 253 | "source": [ 254 | "torchaudio.save(\"encodec.mp3\", encodec_output[None, :], 44100, compression=128)\n", 255 | "torchaudio.save(\"vocos.mp3\", vocos_output, 44100, compression=128)" 256 | ], 257 | "metadata": { 258 | "id": "PLFXpjUKb3WX" 259 | }, 260 | "execution_count": null, 261 | "outputs": [] 262 | } 263 | ] 264 | } -------------------------------------------------------------------------------- /vocos/spectral_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import torch 4 | from torch import nn, view_as_real, view_as_complex 5 | 6 | 7 | class ISTFT(nn.Module): 8 | """ 9 | Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with 10 | windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. 11 | See issue: https://github.com/pytorch/pytorch/issues/62323 12 | Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. 13 | The NOLA constraint is met as we trim padded samples anyway. 14 | 15 | Args: 16 | n_fft (int): Size of Fourier transform. 17 | hop_length (int): The distance between neighboring sliding window frames. 18 | win_length (int): The size of window frame and STFT filter. 19 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 20 | """ 21 | 22 | def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): 23 | super().__init__() 24 | if padding not in ["center", "same"]: 25 | raise ValueError("Padding must be 'center' or 'same'.") 26 | self.padding = padding 27 | self.n_fft = n_fft 28 | self.hop_length = hop_length 29 | self.win_length = win_length 30 | window = torch.hann_window(win_length) 31 | self.register_buffer("window", window) 32 | 33 | def forward(self, spec: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. 36 | 37 | Args: 38 | spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, 39 | N is the number of frequency bins, and T is the number of time frames. 40 | 41 | Returns: 42 | Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
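        Example (editor's illustrative sketch; a random spectrogram with the default "same" padding):

            >>> import torch
            >>> istft = ISTFT(n_fft=1024, hop_length=256, win_length=1024)
            >>> spec = torch.randn(1, 513, 100, dtype=torch.cfloat)   # (B, N, T) with N = n_fft // 2 + 1
            >>> istft(spec).shape                                     # (B, L) == torch.Size([1, 25600]), i.e. L = T * hop_length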
43 | """ 44 | if self.padding == "center": 45 | # Fallback to pytorch native implementation 46 | return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) 47 | elif self.padding == "same": 48 | pad = (self.win_length - self.hop_length) // 2 49 | else: 50 | raise ValueError("Padding must be 'center' or 'same'.") 51 | 52 | assert spec.dim() == 3, "Expected a 3D tensor as input" 53 | B, N, T = spec.shape 54 | 55 | # Inverse FFT 56 | ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") 57 | ifft = ifft * self.window[None, :, None] 58 | 59 | # Overlap and Add 60 | output_size = (T - 1) * self.hop_length + self.win_length 61 | y = torch.nn.functional.fold( 62 | ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), 63 | )[:, 0, 0, pad:-pad] 64 | 65 | # Window envelope 66 | window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) 67 | window_envelope = torch.nn.functional.fold( 68 | window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), 69 | ).squeeze()[pad:-pad] 70 | 71 | # Normalize 72 | assert (window_envelope > 1e-11).all() 73 | y = y / window_envelope 74 | 75 | return y 76 | 77 | 78 | class MDCT(nn.Module): 79 | """ 80 | Modified Discrete Cosine Transform (MDCT) module. 81 | 82 | Args: 83 | frame_len (int): Length of the MDCT frame. 84 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 85 | """ 86 | 87 | def __init__(self, frame_len: int, padding: str = "same"): 88 | super().__init__() 89 | if padding not in ["center", "same"]: 90 | raise ValueError("Padding must be 'center' or 'same'.") 91 | self.padding = padding 92 | self.frame_len = frame_len 93 | N = frame_len // 2 94 | n0 = (N + 1) / 2 95 | window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() 96 | self.register_buffer("window", window) 97 | 98 | pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) 99 | post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) 100 | # view_as_real: NCCL Backend does not support ComplexFloat data type 101 | # https://github.com/pytorch/pytorch/issues/71613 102 | self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) 103 | self.register_buffer("post_twiddle", view_as_real(post_twiddle)) 104 | 105 | def forward(self, audio: torch.Tensor) -> torch.Tensor: 106 | """ 107 | Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. 108 | 109 | Args: 110 | audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size 111 | and T is the length of the audio. 112 | 113 | Returns: 114 | Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames 115 | and N is the number of frequency bins. 
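        Example (editor's illustrative sketch; with "same" padding the number of frames is T / (frame_len // 2)):

            >>> import torch
            >>> mdct = MDCT(frame_len=512)
            >>> audio = torch.randn(1, 2560)   # (B, T)
            >>> mdct(audio).shape              # (B, L, N) == torch.Size([1, 10, 256])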
116 | """ 117 | if self.padding == "center": 118 | audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2)) 119 | elif self.padding == "same": 120 | # hop_length is 1/2 frame_len 121 | audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4)) 122 | else: 123 | raise ValueError("Padding must be 'center' or 'same'.") 124 | 125 | x = audio.unfold(-1, self.frame_len, self.frame_len // 2) 126 | N = self.frame_len // 2 127 | x = x * self.window.expand(x.shape) 128 | X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N] 129 | res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) 130 | return torch.real(res) * np.sqrt(2) 131 | 132 | 133 | class IMDCT(nn.Module): 134 | """ 135 | Inverse Modified Discrete Cosine Transform (IMDCT) module. 136 | 137 | Args: 138 | frame_len (int): Length of the MDCT frame. 139 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 140 | """ 141 | 142 | def __init__(self, frame_len: int, padding: str = "same"): 143 | super().__init__() 144 | if padding not in ["center", "same"]: 145 | raise ValueError("Padding must be 'center' or 'same'.") 146 | self.padding = padding 147 | self.frame_len = frame_len 148 | N = frame_len // 2 149 | n0 = (N + 1) / 2 150 | window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() 151 | self.register_buffer("window", window) 152 | 153 | pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) 154 | post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) 155 | self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) 156 | self.register_buffer("post_twiddle", view_as_real(post_twiddle)) 157 | 158 | def forward(self, X: torch.Tensor) -> torch.Tensor: 159 | """ 160 | Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. 161 | 162 | Args: 163 | X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, 164 | L is the number of frames, and N is the number of frequency bins. 165 | 166 | Returns: 167 | Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. 
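        Example (editor's illustrative sketch; the inverse of the MDCT example above):

            >>> import torch
            >>> imdct = IMDCT(frame_len=512)
            >>> coeffs = torch.randn(1, 10, 256)   # (B, L, N)
            >>> imdct(coeffs).shape               # (B, T) == torch.Size([1, 2560])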
168 | """ 169 | B, L, N = X.shape 170 | Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) 171 | Y[..., :N] = X 172 | Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) 173 | y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1) 174 | y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2) 175 | result = y * self.window.expand(y.shape) 176 | output_size = (1, (L + 1) * N) 177 | audio = torch.nn.functional.fold( 178 | result.transpose(1, 2), 179 | output_size=output_size, 180 | kernel_size=(1, self.frame_len), 181 | stride=(1, self.frame_len // 2), 182 | )[:, 0, 0, :] 183 | 184 | if self.padding == "center": 185 | pad = self.frame_len // 2 186 | elif self.padding == "same": 187 | pad = self.frame_len // 4 188 | else: 189 | raise ValueError("Padding must be 'center' or 'same'.") 190 | 191 | audio = audio[:, pad:-pad] 192 | return audio 193 | -------------------------------------------------------------------------------- /vocos/heads.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz 6 | 7 | from vocos.spectral_ops import IMDCT, ISTFT 8 | from vocos.modules import symexp 9 | 10 | 11 | class FourierHead(nn.Module): 12 | """Base class for inverse fourier modules.""" 13 | 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | """ 16 | Args: 17 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 18 | L is the sequence length, and H denotes the model dimension. 19 | 20 | Returns: 21 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 22 | """ 23 | raise NotImplementedError("Subclasses must implement the forward method.") 24 | 25 | 26 | class ISTFTHead(FourierHead): 27 | """ 28 | ISTFT Head module for predicting STFT complex coefficients. 29 | 30 | Args: 31 | dim (int): Hidden dimension of the model. 32 | n_fft (int): Size of Fourier transform. 33 | hop_length (int): The distance between neighboring sliding window frames, which should align with 34 | the resolution of the input features. 35 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 36 | """ 37 | 38 | def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): 39 | super().__init__() 40 | out_dim = n_fft + 2 41 | self.out = torch.nn.Linear(dim, out_dim) 42 | self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) 43 | 44 | def forward(self, x: torch.Tensor) -> torch.Tensor: 45 | """ 46 | Forward pass of the ISTFTHead module. 47 | 48 | Args: 49 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 50 | L is the sequence length, and H denotes the model dimension. 51 | 52 | Returns: 53 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 54 | """ 55 | x = self.out(x).transpose(1, 2) 56 | mag, p = x.chunk(2, dim=1) 57 | mag = torch.exp(mag) 58 | mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes 59 | # wrapping happens here. 
These two lines produce real and imaginary value 60 | x = torch.cos(p) 61 | y = torch.sin(p) 62 | # recalculating phase here does not produce anything new 63 | # only costs time 64 | # phase = torch.atan2(y, x) 65 | # S = mag * torch.exp(phase * 1j) 66 | # better directly produce the complex value 67 | S = mag * (x + 1j * y) 68 | audio = self.istft(S) 69 | return audio 70 | 71 | class WaveNextHead(FourierHead): 72 | """ 73 | WaveNext Head module for predicting waveform samples. 74 | 75 | Args: 76 | dim (int): Hidden dimension of the model. 77 | n_fft (int): Size of Fourier transform. 78 | hop_length (int): The distance between neighboring sliding window frames, which should align with 79 | the resolution of the input features. 80 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 81 | """ 82 | 83 | def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): 84 | super().__init__() 85 | l_fft = n_fft + 2 86 | l_shift = hop_length 87 | self.linear_1 = torch.nn.Linear(dim, l_fft) 88 | self.linear_2 = torch.nn.Linear(l_fft, l_shift, bias=False) 89 | 90 | # W init 91 | nn.init.trunc_normal_(self.linear_1.weight, std=0.02) 92 | nn.init.trunc_normal_(self.linear_2.weight, std=0.02) 93 | 94 | #self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) 95 | 96 | def forward(self, x: torch.Tensor) -> torch.Tensor: 97 | """ 98 | Forward pass of the WaveNextHead module . 99 | 100 | Args: 101 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 102 | L is the sequence length, and H denotes the model dimension. 103 | 104 | Returns: 105 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 106 | """ 107 | B, C , T = x.shape 108 | x = self.linear_1(x) 109 | x = self.linear_2(x) 110 | audio = x.view(B,-1) # / 100 111 | #print("max amplitude: ", audio.max().item()) 112 | audio = torch.clip(audio, min=-1.0, max=1.0) 113 | return audio 114 | 115 | 116 | 117 | class IMDCTSymExpHead(FourierHead): 118 | """ 119 | IMDCT Head module for predicting MDCT coefficients with symmetric exponential function 120 | 121 | Args: 122 | dim (int): Hidden dimension of the model. 123 | mdct_frame_len (int): Length of the MDCT frame. 124 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 125 | sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized 126 | based on perceptual scaling. Defaults to None. 127 | clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. 
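    Example (editor's illustrative sketch, arbitrary sizes; `sample_rate` is left unset so no
    perceptual initialization is applied):

        >>> import torch
        >>> head = IMDCTSymExpHead(dim=512, mdct_frame_len=512)
        >>> x = torch.randn(1, 100, 512)   # (B, L, H)
        >>> head(x).shape                  # (B, T) == torch.Size([1, 25600])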
128 | """ 129 | 130 | def __init__( 131 | self, 132 | dim: int, 133 | mdct_frame_len: int, 134 | padding: str = "same", 135 | sample_rate: Optional[int] = None, 136 | clip_audio: bool = False, 137 | ): 138 | super().__init__() 139 | out_dim = mdct_frame_len // 2 140 | self.out = nn.Linear(dim, out_dim) 141 | self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) 142 | self.clip_audio = clip_audio 143 | 144 | if sample_rate is not None: 145 | # optionally init the last layer following mel-scale 146 | m_max = _hz_to_mel(sample_rate // 2) 147 | m_pts = torch.linspace(0, m_max, out_dim) 148 | f_pts = _mel_to_hz(m_pts) 149 | scale = 1 - (f_pts / f_pts.max()) 150 | 151 | with torch.no_grad(): 152 | self.out.weight.mul_(scale.view(-1, 1)) 153 | 154 | def forward(self, x: torch.Tensor) -> torch.Tensor: 155 | """ 156 | Forward pass of the IMDCTSymExpHead module. 157 | 158 | Args: 159 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 160 | L is the sequence length, and H denotes the model dimension. 161 | 162 | Returns: 163 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 164 | """ 165 | x = self.out(x) 166 | x = symexp(x) 167 | x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes 168 | audio = self.imdct(x) 169 | if self.clip_audio: 170 | audio = torch.clip(x, min=-1.0, max=1.0) 171 | 172 | return audio 173 | 174 | 175 | class IMDCTCosHead(FourierHead): 176 | """ 177 | IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p) 178 | 179 | Args: 180 | dim (int): Hidden dimension of the model. 181 | mdct_frame_len (int): Length of the MDCT frame. 182 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 183 | clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. 184 | """ 185 | 186 | def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False): 187 | super().__init__() 188 | self.clip_audio = clip_audio 189 | self.out = nn.Linear(dim, mdct_frame_len) 190 | self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) 191 | 192 | def forward(self, x: torch.Tensor) -> torch.Tensor: 193 | """ 194 | Forward pass of the IMDCTCosHead module. 195 | 196 | Args: 197 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 198 | L is the sequence length, and H denotes the model dimension. 199 | 200 | Returns: 201 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 202 | """ 203 | x = self.out(x) 204 | m, p = x.chunk(2, dim=2) 205 | m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes 206 | audio = self.imdct(m * torch.cos(p)) 207 | if self.clip_audio: 208 | audio = torch.clip(x, min=-1.0, max=1.0) 209 | return audio 210 | -------------------------------------------------------------------------------- /vocos/discriminators.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from einops import rearrange 5 | from torch import nn 6 | from torch.nn import Conv2d 7 | from torch.nn.utils import weight_norm 8 | from torchaudio.transforms import Spectrogram 9 | 10 | 11 | class MultiPeriodDiscriminator(nn.Module): 12 | """ 13 | Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan. 
14 | Additionally, it allows incorporating conditional information with a learned embeddings table. 15 | 16 | Args: 17 | periods (tuple[int]): Tuple of periods for each discriminator. 18 | num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. 19 | Defaults to None. 20 | """ 21 | 22 | def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None): 23 | super().__init__() 24 | self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]) 25 | 26 | def forward( 27 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None 28 | ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: 29 | y_d_rs = [] 30 | y_d_gs = [] 31 | fmap_rs = [] 32 | fmap_gs = [] 33 | for d in self.discriminators: 34 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) 35 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) 36 | y_d_rs.append(y_d_r) 37 | fmap_rs.append(fmap_r) 38 | y_d_gs.append(y_d_g) 39 | fmap_gs.append(fmap_g) 40 | 41 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 42 | 43 | 44 | class DiscriminatorP(nn.Module): 45 | def __init__( 46 | self, 47 | period: int, 48 | in_channels: int = 1, 49 | kernel_size: int = 5, 50 | stride: int = 3, 51 | lrelu_slope: float = 0.1, 52 | num_embeddings: Optional[int] = None, 53 | ): 54 | super().__init__() 55 | self.period = period 56 | self.convs = nn.ModuleList( 57 | [ 58 | weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 59 | weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 60 | weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 61 | weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 62 | weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))), 63 | ] 64 | ) 65 | if num_embeddings is not None: 66 | self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024) 67 | torch.nn.init.zeros_(self.emb.weight) 68 | 69 | self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 70 | self.lrelu_slope = lrelu_slope 71 | 72 | def forward( 73 | self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None 74 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]: 75 | x = x.unsqueeze(1) 76 | fmap = [] 77 | # 1d to 2d 78 | b, c, t = x.shape 79 | if t % self.period != 0: # pad first 80 | n_pad = self.period - (t % self.period) 81 | x = torch.nn.functional.pad(x, (0, n_pad), "reflect") 82 | t = t + n_pad 83 | x = x.view(b, c, t // self.period, self.period) 84 | 85 | for i, l in enumerate(self.convs): 86 | x = l(x) 87 | x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) 88 | if i > 0: 89 | fmap.append(x) 90 | if cond_embedding_id is not None: 91 | emb = self.emb(cond_embedding_id) 92 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) 93 | else: 94 | h = 0 95 | x = self.conv_post(x) 96 | fmap.append(x) 97 | x += h 98 | x = torch.flatten(x, 1, -1) 99 | 100 | return x, fmap 101 | 102 | 103 | class MultiResolutionDiscriminator(nn.Module): 104 | def __init__( 105 | self, 106 | fft_sizes: Tuple[int, ...] = (2048, 1024, 512), 107 | num_embeddings: Optional[int] = None, 108 | ): 109 | """ 110 | Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. 
111 | Additionally, it allows incorporating conditional information with a learned embeddings table. 112 | 113 | Args: 114 | fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). 115 | num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. 116 | Defaults to None. 117 | """ 118 | 119 | super().__init__() 120 | self.discriminators = nn.ModuleList( 121 | [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes] 122 | ) 123 | 124 | def forward( 125 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None 126 | ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: 127 | y_d_rs = [] 128 | y_d_gs = [] 129 | fmap_rs = [] 130 | fmap_gs = [] 131 | 132 | for d in self.discriminators: 133 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) 134 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) 135 | y_d_rs.append(y_d_r) 136 | fmap_rs.append(fmap_r) 137 | y_d_gs.append(y_d_g) 138 | fmap_gs.append(fmap_g) 139 | 140 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 141 | 142 | 143 | class DiscriminatorR(nn.Module): 144 | def __init__( 145 | self, 146 | window_length: int, 147 | num_embeddings: Optional[int] = None, 148 | channels: int = 32, 149 | hop_factor: float = 0.25, 150 | bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)), 151 | ): 152 | super().__init__() 153 | self.window_length = window_length 154 | self.hop_factor = hop_factor 155 | self.spec_fn = Spectrogram( 156 | n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None 157 | ) 158 | n_fft = window_length // 2 + 1 159 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 160 | self.bands = bands 161 | convs = lambda: nn.ModuleList( 162 | [ 163 | weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), 164 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 165 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 166 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 167 | weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), 168 | ] 169 | ) 170 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 171 | 172 | if num_embeddings is not None: 173 | self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) 174 | torch.nn.init.zeros_(self.emb.weight) 175 | 176 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) 177 | 178 | def spectrogram(self, x): 179 | # Remove DC offset 180 | x = x - x.mean(dim=-1, keepdims=True) 181 | # Peak normalize the volume of input audio 182 | x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 183 | x = self.spec_fn(x) 184 | x = torch.view_as_real(x) 185 | x = rearrange(x, "b f t c -> b c t f") 186 | # Split into bands 187 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 188 | return x_bands 189 | 190 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): 191 | x_bands = self.spectrogram(x) 192 | fmap = [] 193 | x = [] 194 | for band, stack in zip(x_bands, self.band_convs): 195 | for i, layer in enumerate(stack): 196 | band = layer(band) 197 | band = torch.nn.functional.leaky_relu(band, 0.1) 198 | if i > 0: 199 | fmap.append(band) 200 | x.append(band) 201 | x = torch.cat(x, dim=-1) 202 | if 
cond_embedding_id is not None: 203 | emb = self.emb(cond_embedding_id) 204 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) 205 | else: 206 | h = 0 207 | x = self.conv_post(x) 208 | fmap.append(x) 209 | x += h 210 | 211 | return x, fmap 212 | -------------------------------------------------------------------------------- /vocos/experiment.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | import torchaudio 7 | import transformers 8 | 9 | from vocos.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator 10 | from vocos.feature_extractors import FeatureExtractor 11 | from vocos.heads import FourierHead 12 | from vocos.helpers import plot_spectrogram_to_numpy 13 | from vocos.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss 14 | from vocos.models import Backbone 15 | from vocos.modules import safe_log 16 | 17 | 18 | class VocosExp(pl.LightningModule): 19 | # noinspection PyUnusedLocal 20 | def __init__( 21 | self, 22 | feature_extractor: FeatureExtractor, 23 | backbone: Backbone, 24 | head: FourierHead, 25 | melspec_loss: MelSpecReconstructionLoss, 26 | sample_rate: int, 27 | initial_learning_rate: float, 28 | num_warmup_steps: int = 0, 29 | mel_loss_coeff: float = 45, 30 | mrd_loss_coeff: float = 1.0, 31 | pretrain_mel_steps: int = 0, 32 | decay_mel_coeff: bool = False, 33 | evaluate_utmos: bool = False, 34 | evaluate_pesq: bool = False, 35 | evaluate_periodicty: bool = False, 36 | ): 37 | """ 38 | Args: 39 | feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals. 40 | backbone (Backbone): An instance of Backbone model. 41 | head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform. 42 | sample_rate (int): Sampling rate of the audio signals. 43 | initial_learning_rate (float): Initial learning rate for the optimizer. 44 | num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0. 45 | mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45. 46 | mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0. 47 | pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0. 48 | decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False. 49 | evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run. 50 | evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run. 51 | evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run. 
52 | """ 53 | super().__init__() 54 | self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"]) 55 | 56 | self.feature_extractor = feature_extractor 57 | self.backbone = backbone 58 | self.head = head 59 | 60 | self.multiperioddisc = MultiPeriodDiscriminator() 61 | self.multiresddisc = MultiResolutionDiscriminator() 62 | 63 | self.disc_loss = DiscriminatorLoss() 64 | self.gen_loss = GeneratorLoss() 65 | self.feat_matching_loss = FeatureMatchingLoss() 66 | self.melspec_loss = melspec_loss 67 | 68 | self.train_discriminator = False 69 | self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff 70 | 71 | def configure_optimizers(self): 72 | disc_params = [ 73 | {"params": self.multiperioddisc.parameters()}, 74 | {"params": self.multiresddisc.parameters()}, 75 | ] 76 | gen_params = [ 77 | {"params": self.feature_extractor.parameters()}, 78 | {"params": self.backbone.parameters()}, 79 | {"params": self.head.parameters()}, 80 | ] 81 | 82 | opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate , betas=(0.8, 0.9)) 83 | opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate , betas=(0.8, 0.9)) 84 | 85 | max_steps = self.trainer.max_steps // 2 # Max steps per optimizer 86 | scheduler_disc = transformers.get_cosine_schedule_with_warmup( 87 | opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps, 88 | ) 89 | scheduler_gen = transformers.get_cosine_schedule_with_warmup( 90 | opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps, 91 | ) 92 | 93 | return ( 94 | [opt_disc, opt_gen], 95 | [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}], 96 | ) 97 | 98 | def forward(self, audio_input, **kwargs): 99 | features = self.feature_extractor(audio_input, **kwargs) 100 | x = self.backbone(features, **kwargs) 101 | audio_output = self.head(x) 102 | return audio_output 103 | 104 | def training_step(self, batch, batch_idx, optimizer_idx, **kwargs): 105 | audio_input = batch 106 | 107 | # train discriminator 108 | if optimizer_idx == 0 and self.train_discriminator: #and self.global_step % 5 == 0: TTUR 109 | with torch.no_grad(): 110 | audio_hat = self(audio_input, **kwargs) 111 | 112 | real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,) 113 | real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,) 114 | loss_mp, loss_mp_real, _ = self.disc_loss( 115 | disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp 116 | ) 117 | loss_mrd, loss_mrd_real, _ = self.disc_loss( 118 | disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd 119 | ) 120 | loss_mp /= len(loss_mp_real) 121 | loss_mrd /= len(loss_mrd_real) 122 | loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd 123 | 124 | self.log("discriminator/total", loss, prog_bar=True) 125 | self.log("discriminator/multi_period_loss", loss_mp) 126 | self.log("discriminator/multi_res_loss", loss_mrd) 127 | return loss 128 | 129 | # train generator 130 | if optimizer_idx == 1: 131 | audio_hat = self(audio_input, **kwargs) 132 | if self.train_discriminator: 133 | _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc( 134 | y=audio_input, y_hat=audio_hat, **kwargs, 135 | ) 136 | _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc( 137 | y=audio_input, y_hat=audio_hat, **kwargs, 138 | ) 139 | loss_gen_mp, list_loss_gen_mp = 
self.gen_loss(disc_outputs=gen_score_mp) 140 | loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd) 141 | loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp) 142 | loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd) 143 | loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp) 144 | loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd) 145 | 146 | self.log("generator/multi_period_loss", loss_gen_mp) 147 | self.log("generator/multi_res_loss", loss_gen_mrd) 148 | self.log("generator/feature_matching_mp", loss_fm_mp) 149 | self.log("generator/feature_matching_mrd", loss_fm_mrd) 150 | else: 151 | loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0 152 | 153 | mel_loss = self.melspec_loss(audio_hat, audio_input) 154 | loss = ( 155 | loss_gen_mp 156 | + self.hparams.mrd_loss_coeff * loss_gen_mrd 157 | + loss_fm_mp 158 | + self.hparams.mrd_loss_coeff * loss_fm_mrd 159 | + self.mel_loss_coeff * mel_loss 160 | ) 161 | 162 | self.log("generator/total_loss", loss, prog_bar=True) 163 | self.log("mel_loss_coeff", self.mel_loss_coeff) 164 | self.log("generator/mel_loss", mel_loss) 165 | 166 | if self.global_step % 1000 == 0 and self.global_rank == 0: 167 | self.logger.experiment.add_audio( 168 | "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate 169 | ) 170 | self.logger.experiment.add_audio( 171 | "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate 172 | ) 173 | with torch.no_grad(): 174 | mel = safe_log(self.melspec_loss.mel_spec(audio_input[0])) 175 | mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0])) 176 | self.logger.experiment.add_image( 177 | "train/mel_target", 178 | plot_spectrogram_to_numpy(mel.data.cpu().numpy()), 179 | self.global_step, 180 | dataformats="HWC", 181 | ) 182 | self.logger.experiment.add_image( 183 | "train/mel_pred", 184 | plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()), 185 | self.global_step, 186 | dataformats="HWC", 187 | ) 188 | 189 | return loss 190 | 191 | def on_validation_epoch_start(self): 192 | if self.hparams.evaluate_utmos: 193 | from metrics.UTMOS import UTMOSScore 194 | 195 | if not hasattr(self, "utmos_model"): 196 | self.utmos_model = UTMOSScore(device=self.device) 197 | 198 | def validation_step(self, batch, batch_idx, **kwargs): 199 | audio_input = batch 200 | audio_hat = self(audio_input, **kwargs) 201 | 202 | audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000) 203 | audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000) 204 | 205 | if self.hparams.evaluate_periodicty: 206 | from metrics.periodicity import calculate_periodicity_metrics 207 | 208 | periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz) 209 | else: 210 | periodicity_loss = pitch_loss = f1_score = 0 211 | 212 | if self.hparams.evaluate_utmos: 213 | utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean() 214 | else: 215 | utmos_score = torch.zeros(1, device=self.device) 216 | 217 | if self.hparams.evaluate_pesq: 218 | from pesq import pesq 219 | 220 | pesq_score = 0 221 | for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()): 222 | pesq_score += pesq(16000, ref, deg, "wb", on_error=1) 223 | pesq_score /= len(audio_16_khz) 224 | pesq_score = torch.tensor(pesq_score) 225 | else: 226 | pesq_score = 
torch.zeros(1, device=self.device) 227 | 228 | mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1)) 229 | total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) 230 | 231 | return { 232 | "val_loss": total_loss, 233 | "mel_loss": mel_loss, 234 | "utmos_score": utmos_score, 235 | "pesq_score": pesq_score, 236 | "periodicity_loss": periodicity_loss, 237 | "pitch_loss": pitch_loss, 238 | "f1_score": f1_score, 239 | "audio_input": audio_input[0], 240 | "audio_pred": audio_hat[0], 241 | } 242 | 243 | def validation_epoch_end(self, outputs): 244 | if self.global_rank == 0: 245 | *_, audio_in, audio_pred = outputs[0].values() 246 | self.logger.experiment.add_audio( 247 | "val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate 248 | ) 249 | self.logger.experiment.add_audio( 250 | "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate 251 | ) 252 | mel_target = safe_log(self.melspec_loss.mel_spec(audio_in)) 253 | mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred)) 254 | self.logger.experiment.add_image( 255 | "val_mel_target", 256 | plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()), 257 | self.global_step, 258 | dataformats="HWC", 259 | ) 260 | self.logger.experiment.add_image( 261 | "val_mel_hat", 262 | plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()), 263 | self.global_step, 264 | dataformats="HWC", 265 | ) 266 | avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 267 | mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean() 268 | utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean() 269 | pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean() 270 | periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean() 271 | pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean() 272 | f1_score = np.array([x["f1_score"] for x in outputs]).mean() 273 | 274 | self.log("val_loss", avg_loss, sync_dist=True) 275 | self.log("val/mel_loss", mel_loss, sync_dist=True) 276 | self.log("val/utmos_score", utmos_score, sync_dist=True) 277 | self.log("val/pesq_score", pesq_score, sync_dist=True) 278 | self.log("val/periodicity_loss", periodicity_loss, sync_dist=True) 279 | self.log("val/pitch_loss", pitch_loss, sync_dist=True) 280 | self.log("val/f1_score", f1_score, sync_dist=True) 281 | 282 | @property 283 | def global_step(self): 284 | """ 285 | Override global_step so that it returns the total number of batches processed 286 | """ 287 | return self.trainer.fit_loop.epoch_loop.total_batch_idx 288 | 289 | def on_train_batch_start(self, *args): 290 | if self.global_step >= self.hparams.pretrain_mel_steps: 291 | self.train_discriminator = True 292 | else: 293 | self.train_discriminator = False 294 | 295 | def on_train_batch_end(self, *args): 296 | def mel_loss_coeff_decay(current_step, num_cycles=0.5): 297 | max_steps = self.trainer.max_steps // 2 298 | if current_step < self.hparams.num_warmup_steps: 299 | return 1.0 300 | progress = float(current_step - self.hparams.num_warmup_steps) / float( 301 | max(1, max_steps - self.hparams.num_warmup_steps) 302 | ) 303 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 304 | 305 | if self.hparams.decay_mel_coeff: 306 | self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1) 307 | 308 | 309 | class VocosEncodecExp(VocosExp): 310 | """ 311 | VocosEncodecExp is a subclass of VocosExp that overrides the parent experiment to function as a 
conditional GAN. 312 | It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to 313 | a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step, 314 | while during validation, a fixed bandwidth_id is used. 315 | """ 316 | 317 | def __init__( 318 | self, 319 | feature_extractor: FeatureExtractor, 320 | backbone: Backbone, 321 | head: FourierHead, 322 | melspec_loss: MelSpecReconstructionLoss, 323 | sample_rate: int, 324 | initial_learning_rate: float, 325 | num_warmup_steps: int, 326 | mel_loss_coeff: float = 45, 327 | mrd_loss_coeff: float = 1.0, 328 | pretrain_mel_steps: int = 0, 329 | decay_mel_coeff: bool = False, 330 | evaluate_utmos: bool = False, 331 | evaluate_pesq: bool = False, 332 | evaluate_periodicty: bool = False, 333 | ): 334 | super().__init__( 335 | feature_extractor, 336 | backbone, 337 | head, 338 | melspec_loss, 339 | sample_rate, 340 | initial_learning_rate, 341 | num_warmup_steps, 342 | mel_loss_coeff, 343 | mrd_loss_coeff, 344 | pretrain_mel_steps, 345 | decay_mel_coeff, 346 | evaluate_utmos, 347 | evaluate_pesq, 348 | evaluate_periodicty, 349 | ) 350 | # Override with conditional discriminators 351 | self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths)) 352 | self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths)) 353 | 354 | def training_step(self, *args): 355 | bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,) 356 | output = super().training_step(*args, bandwidth_id=bandwidth_id) 357 | return output 358 | 359 | def validation_step(self, *args): 360 | bandwidth_id = torch.tensor([0], device=self.device) 361 | output = super().validation_step(*args, bandwidth_id=bandwidth_id) 362 | return output 363 | 364 | def validation_epoch_end(self, outputs): 365 | if self.global_rank == 0: 366 | *_, audio_in, _ = outputs[0].values() 367 | # Resynthesis with encodec for reference 368 | self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0]) 369 | encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :]) 370 | self.logger.experiment.add_audio( 371 | "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate, 372 | ) 373 | 374 | super().validation_epoch_end(outputs) 375 | --------------------------------------------------------------------------------
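
DiscriminatorP folds the 1-d waveform into a 2-d (time // period, period) grid before applying its strided 2-d convolutions. The tiny worked example below is not part of the repository; it uses a made-up 8-sample tensor and period 3 purely to show the reflect padding and reshape in isolation:

import torch
import torch.nn.functional as F

period = 3
x = torch.arange(8.0).view(1, 1, 8)       # (batch, channels, time)
n_pad = period - x.shape[-1] % period     # 1 sample of reflect padding
x = F.pad(x, (0, n_pad), "reflect")       # last frame is completed by mirroring
x = x.view(1, 1, -1, period)              # (batch, channels, time // period, period)
# x[0, 0] == [[0., 1., 2.], [3., 4., 5.], [6., 7., 6.]]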
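The two discriminators can be smoke-tested in isolation, driven the same way VocosExp.training_step drives them. The sketch below is not taken from the repository; the batch size, audio length, num_embeddings and bandwidth_id values are illustrative assumptions:

import torch

from vocos.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator

real = torch.randn(4, 16384)   # (batch, samples) of reference audio
fake = torch.randn(4, 16384)   # (batch, samples) of generated audio

mpd = MultiPeriodDiscriminator()                       # unconditional
mrd = MultiResolutionDiscriminator(num_embeddings=4)   # conditional variant

# Unconditional call: one logit tensor and one feature-map list per period.
logits_r, logits_g, fmaps_r, fmaps_g = mpd(y=real, y_hat=fake)
assert len(logits_r) == 5   # default periods (2, 3, 5, 7, 11)

# Conditional call: bandwidth_id indexes the learned embedding table.
bandwidth_id = torch.tensor([2])
logits_r, logits_g, fmaps_r, fmaps_g = mrd(y=real, y_hat=fake, bandwidth_id=bandwidth_id)
assert len(logits_r) == 3   # default fft_sizes (2048, 1024, 512)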
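When decay_mel_coeff is True, on_train_batch_end rescales mel_loss_coeff with a half-cosine schedule. Below is the same formula rewritten as a standalone function with explicit arguments for illustration; the step counts and warmup value are hypothetical (in the experiment, max_steps is trainer.max_steps // 2 and the warmup comes from num_warmup_steps):

import math

def mel_loss_coeff_decay(current_step: int, max_steps: int, num_warmup_steps: int,
                         num_cycles: float = 0.5) -> float:
    """Multiplier applied to the base mel-loss coefficient at a given step."""
    if current_step < num_warmup_steps:
        return 1.0
    progress = float(current_step - num_warmup_steps) / float(max(1, max_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

base_mel_coeff = 45.0        # default mel_loss_coeff
max_steps, warmup = 1_000_000, 0   # hypothetical values
for step in (0, 250_000, 500_000, 750_000, 1_000_000):
    print(step, base_mel_coeff * mel_loss_coeff_decay(step, max_steps, warmup))
# -> 45.0, ~38.4, 22.5, ~6.6, 0.0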