├── src ├── utils │ ├── __init__.py │ ├── logging.py │ ├── scheduler.py │ ├── torch_common.py │ └── audio_utils.py ├── data │ ├── __init__.py │ ├── pretransform │ │ ├── __init__.py │ │ └── gemma_audio_feature.py │ ├── degradation │ │ ├── __init__.py │ │ ├── base.py │ │ ├── lowpass.py │ │ ├── clipping.py │ │ ├── noise.py │ │ └── reverb.py │ ├── modification.py │ ├── audio_io.py │ └── dataset.py ├── model │ ├── wavefit │ │ ├── loss │ │ │ ├── __init__.py │ │ │ └── mrstft.py │ │ ├── __init__.py │ │ ├── wavefit.py │ │ ├── discriminator.py │ │ └── generator.py │ ├── __init__.py │ ├── feature_cleaner │ │ ├── __init__.py │ │ ├── base.py │ │ ├── parallel_adapter.py │ │ └── google_usm.py │ └── miipher_2.py ├── train.py ├── inference.py └── trainer.py ├── assets └── fig │ ├── miipher-2.png │ └── compare_layers.png ├── container ├── build_singularity.bash └── Open-Miipher-2.def ├── configs ├── default.yaml ├── optimizer │ ├── feature_cleaner.yaml │ └── wavefit.yaml ├── data │ ├── deg_gemma_24khz_06sec_clean-only.yaml │ ├── deg_gemma_24khz_10sec.yaml │ └── deg_gemma_24khz_30sec.yaml ├── model │ ├── feature_cleaner │ │ └── google-usm.yaml │ ├── miipher-2_google-usm_wavefit-5_clean-input.yaml │ └── miipher-2_google-usm_wavefit-5_noisy-input.yaml └── trainer │ └── default.yaml ├── LICENSE ├── _LICENSES ├── descript-audio-codec.LICENSE ├── transformers.LICENSE └── unify-parameter-efficient-tuning.LICENSE ├── dataset └── script │ └── make_metadata_csv.py ├── .gitignore └── README.md /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import DegradedAudioDataset 2 | -------------------------------------------------------------------------------- /src/model/wavefit/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .mrstft import MRSTFTLoss, MELMAELoss 2 | -------------------------------------------------------------------------------- /src/data/pretransform/__init__.py: -------------------------------------------------------------------------------- 1 | from .gemma_audio_feature import GemmaAudioFeature 2 | -------------------------------------------------------------------------------- /assets/fig/miipher-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukara-ikemiya/Open-Miipher-2/HEAD/assets/fig/miipher-2.png -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .feature_cleaner import * 2 | from .wavefit import * 3 | from .miipher_2 import * 4 | -------------------------------------------------------------------------------- /assets/fig/compare_layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukara-ikemiya/Open-Miipher-2/HEAD/assets/fig/compare_layers.png -------------------------------------------------------------------------------- /src/model/feature_cleaner/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import AudioEncoderAdapter 2 | from .google_usm import GoogleUSMAdapter 3 | 
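
Illustrative sketch (not a file in this repository): the feature cleaner exported above can be instantiated from its Hydra config the same way src/train.py instantiates cfg.model, and fed Gemma audio features produced by the wrapper in src/data/pretransform/gemma_audio_feature.py. This assumes src/ is on PYTHONPATH, network access for the "Atotti/google-usm" and "google/gemma-3n-e2b-it" checkpoints, and 24 kHz mono input as in the data configs; the forward-pass signature of GoogleUSMAdapter itself lives in google_usm.py and is not shown here.

import torch
import hydra
from omegaconf import OmegaConf

from data.pretransform import GemmaAudioFeature   # assumes src/ is on PYTHONPATH

# Instantiate the feature cleaner from its config, as src/train.py does via hydra.utils.instantiate.
cfg = OmegaConf.load("configs/model/feature_cleaner/google-usm.yaml")
cleaner = hydra.utils.instantiate(cfg)            # -> model.GoogleUSMAdapter (USM/Gemma3n encoder with adapter layers)

# Extract Gemma audio features for a dummy 3-second mono signal at 24 kHz.
# GemmaAudioFeature resamples internally to the extractor's sampling rate (16 kHz).
extractor = GemmaAudioFeature()                   # loads the google/gemma-3n-e2b-it feature extractor
audio = torch.randn(24000 * 3).clamp(-1.0, 1.0)   # dummy waveform, (num_samples,)
feats = extractor(audio, sr_in=24000)             # (num_frames, feature_dim)

The data configs set pretransform: 'gemma', so features of this form appear to be what the cleaner consumes during training.
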
-------------------------------------------------------------------------------- /container/build_singularity.bash: -------------------------------------------------------------------------------- 1 | # Create `Open-Miipher-2.sif` file 2 | NAME="Open-Miipher-2" 3 | singularity build --fakeroot ~/$NAME.sif $NAME.def -------------------------------------------------------------------------------- /src/model/wavefit/__init__.py: -------------------------------------------------------------------------------- 1 | from .wavefit import WaveFit 2 | from .discriminator import Discriminator 3 | from .loss import MRSTFTLoss, MELMAELoss 4 | -------------------------------------------------------------------------------- /src/data/degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from .clipping import AudioClipping 2 | from .noise import NoiseAddition 3 | from .reverb import RIRReverb 4 | from .lowpass import AudioLowpass 5 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: feature_cleaner/google-usm 3 | - data: deg_gemma_24khz_30sec 4 | - optimizer: feature_cleaner 5 | - trainer: default 6 | - _self_ 7 | 8 | # no hydra logging 9 | hydra: 10 | output_subdir: null 11 | run: 12 | dir: . 13 | # multirun 14 | sweep: 15 | dir: . -------------------------------------------------------------------------------- /configs/optimizer/feature_cleaner.yaml: -------------------------------------------------------------------------------- 1 | 2 | optimizer: 3 | _partial_: true 4 | _target_: torch.optim.AdamW 5 | betas: [0.8, 0.99] 6 | lr: 0.001 7 | weight_decay: 0.001 8 | 9 | scheduler: 10 | _partial_: true 11 | _target_: utils.scheduler.WarmupCosineLR 12 | warmup_steps: 2000 13 | total_steps: 2000000 14 | min_lr: 0.00001 -------------------------------------------------------------------------------- /src/data/degradation/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | """ 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | class Degradation(ABC): 8 | def __init__(self, sample_rate: int = 16000): 9 | super().__init__() 10 | self.sample_rate = sample_rate 11 | 12 | @abstractmethod 13 | def __call__(self, x): 14 | pass 15 | -------------------------------------------------------------------------------- /configs/data/deg_gemma_24khz_06sec_clean-only.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | _target_: data.DegradedAudioDataset 3 | dirs_audio: ??? 4 | dirs_noise: null 5 | sample_size: 14400 # 0.6 sec at 24 kHz 6 | sample_rate: 24000 7 | pretransform: 'gemma' 8 | exts: ['wav', 'flac'] 9 | augment_shift: True 10 | augment_flip: True 11 | augment_volume: True 12 | volume_range: [0.25, 1.0] 13 | clean_only: True 14 | -------------------------------------------------------------------------------- /configs/data/deg_gemma_24khz_10sec.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | _target_: data.DegradedAudioDataset 3 | dirs_audio: ??? 4 | dirs_noise: ??? 
5 | sample_size: 240000 6 | sample_rate: 24000 7 | pretransform: 'gemma' 8 | exts: ['wav', 'flac'] 9 | augment_shift: True 10 | augment_flip: True 11 | augment_volume: True 12 | volume_range: [0.25, 1.0] 13 | deg_types: ['clipping', 'noise', 'reverb', 'lowpass'] 14 | n_deg_comb: 3 15 | prob_no_deg: 0.05 16 | clean_only: False 17 | -------------------------------------------------------------------------------- /configs/data/deg_gemma_24khz_30sec.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | _target_: data.DegradedAudioDataset 3 | dirs_audio: ??? 4 | dirs_noise: ??? 5 | sample_size: 720000 6 | sample_rate: 24000 7 | pretransform: 'gemma' 8 | exts: ['wav', 'flac'] 9 | augment_shift: True 10 | augment_flip: True 11 | augment_volume: True 12 | volume_range: [0.25, 1.0] 13 | deg_types: ['clipping', 'noise', 'reverb', 'lowpass'] 14 | n_deg_comb: 3 15 | prob_no_deg: 0.05 16 | clean_only: False 17 | -------------------------------------------------------------------------------- /configs/model/feature_cleaner/google-usm.yaml: -------------------------------------------------------------------------------- 1 | # Feature cleaner model using Google USM (Gemma3n) as the encoder with adapter layers. 2 | # 3 | # Input sample rate: 16000 hz 4 | # Dim size: 1536 5 | # Output frame rate: 25 frames/sec 6 | 7 | _target_: model.GoogleUSMAdapter 8 | 9 | n_adaptive_layers: 6 # /12 10 | encoder_id: "Atotti/google-usm" 11 | adapter_config: 12 | "dim_bottleneck": 1024 13 | "init_option": "bert" 14 | "adapter_scalar": 1.0 15 | "pre_ln_class": "Gemma3nRMSNorm" 16 | -------------------------------------------------------------------------------- /configs/optimizer/wavefit.yaml: -------------------------------------------------------------------------------- 1 | 2 | optimizer: 3 | _partial_: true 4 | _target_: torch.optim.AdamW 5 | betas: [0.8, 0.99] 6 | lr: 0.0001 7 | weight_decay: 0.001 8 | 9 | scheduler: 10 | _partial_: true 11 | _target_: utils.scheduler.WarmupCosineLR 12 | warmup_steps: 2000 13 | total_steps: 1000000 14 | min_lr: 0.00001 15 | 16 | optimizer_d: 17 | _partial_: true 18 | _target_: torch.optim.AdamW 19 | betas: [0.8, 0.99] 20 | lr: 0.0002 21 | weight_decay: 0.001 22 | 23 | scheduler_d: 24 | _partial_: true 25 | _target_: utils.scheduler.WarmupCosineLR 26 | warmup_steps: 2000 27 | total_steps: 1000000 28 | min_lr: 0.00001 -------------------------------------------------------------------------------- /container/Open-Miipher-2.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel 3 | 4 | %post 5 | apt-get update \ 6 | && DEBIAN_FRONTEND=noninteractive apt-get install -y ffmpeg \ 7 | && rm -rf /var/lib/apt/lists/* 8 | 9 | python -m pip install --upgrade --no-cache-dir pip setuptools wheel 10 | 11 | python -m pip install transformers==4.53 12 | python -m pip install librosa einops 13 | python -m pip install pyroomacoustics 14 | python -m pip install torchaudio 15 | python -m pip install hydra-core wandb accelerate ema-pytorch 16 | 17 | 18 | %environment 19 | export PATH=/usr/local/bin:$PATH 20 | export PYTHONUNBUFFERED=1 -------------------------------------------------------------------------------- /src/data/degradation/lowpass.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | """ 4 | import random 5 | 6 | import torch 7 | import torchaudio 8 | 9 | from 
.base import Degradation 10 | 11 | 12 | class AudioLowpass(Degradation): 13 | def __init__( 14 | self, 15 | cutoff_range: tuple = (2000.0, 7000.0), 16 | sample_rate: int = 16000 17 | ): 18 | super().__init__(sample_rate=sample_rate) 19 | self.cutoff_range = cutoff_range 20 | 21 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 22 | cutoff = random.uniform(self.cutoff_range[0], self.cutoff_range[1]) 23 | x = torchaudio.functional.lowpass_biquad(x, self.sample_rate, cutoff) 24 | 25 | # avoid clipping 26 | amp = x.abs().max().item() 27 | if amp > 1.0: 28 | x = x / (amp + 1e-8) 29 | 30 | return x 31 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | 4 | Convenient modules for logging metrics. 5 | """ 6 | 7 | import typing as tp 8 | 9 | import torch 10 | 11 | 12 | class MetricsLogger: 13 | def __init__(self): 14 | self.counts = {} 15 | self.metrics = {} 16 | 17 | def add(self, metrics: tp.Dict[str, torch.Tensor]) -> None: 18 | for k, v in metrics.items(): 19 | if k in self.counts.keys(): 20 | self.counts[k] += 1 21 | self.metrics[k] += v.detach().clone() 22 | else: 23 | self.counts[k] = 1 24 | self.metrics[k] = v.detach().clone() 25 | 26 | def pop(self, mean: bool = True) -> tp.Dict[str, torch.Tensor]: 27 | metrics = {} 28 | for k, v in self.metrics.items(): 29 | metrics[k] = v / self.counts[k] if mean else v 30 | 31 | # reset 32 | self.counts = {} 33 | self.metrics = {} 34 | 35 | return metrics 36 | -------------------------------------------------------------------------------- /src/data/degradation/clipping.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | """ 4 | import typing as tp 5 | import random 6 | 7 | import torch 8 | 9 | from .base import Degradation 10 | 11 | 12 | class AudioClipping(Degradation): 13 | def __init__( 14 | self, 15 | amp_range: tp.Tuple[float, float] = (1.2, 2.0), 16 | prob_soft_clip: float = 0.5, 17 | **kwargs 18 | ): 19 | super().__init__(**kwargs) 20 | self.amp_range = amp_range 21 | self.prob_soft_clip = prob_soft_clip 22 | 23 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 24 | amp_org = x.abs().max().item() 25 | assert amp_org <= 1.0 26 | amp = random.uniform(self.amp_range[0], self.amp_range[1]) 27 | scale = amp / (amp_org + 1e-9) 28 | 29 | # clipping 30 | x = x * scale 31 | if random.random() < self.prob_soft_clip: 32 | x = torch.tanh(x) 33 | else: 34 | x = x.clamp(-1.0, 1.0) 35 | 36 | # rescale 37 | x = x / scale 38 | 39 | return x 40 | -------------------------------------------------------------------------------- /src/data/degradation/noise.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | """ 4 | import random 5 | 6 | import torch 7 | 8 | from .base import Degradation 9 | 10 | 11 | class NoiseAddition(Degradation): 12 | def __init__( 13 | self, 14 | snr_range: tuple = (5.0, 30.0), 15 | **kwargs 16 | ): 17 | super().__init__(**kwargs) 18 | self.snr_range = snr_range 19 | 20 | def __call__(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: 21 | assert x.shape == noise.shape, f"Shapes of input and noise must be the same: {x.shape} != {noise.shape}" 22 | snr = random.uniform(self.snr_range[0], self.snr_range[1]) 23 | p_x = x.pow(2.0).mean() 24 | p_n = noise.pow(2.0).mean() 25 | P_new = 
p_x / (10 ** (snr / 10)) 26 | scale_n = (P_new / (p_n + 1e-9)).sqrt() 27 | 28 | # noise addition 29 | x = x + noise * scale_n 30 | 31 | # avoid clipping 32 | amp = x.abs().max().item() 33 | if amp > 1.0: 34 | x = x / (amp + 1e-8) 35 | 36 | return x 37 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | # Root directory for outputs 2 | output_dir: ??? 3 | 4 | # Checkpoint directory for resuming training 5 | ckpt_dir: null 6 | 7 | # Batch size 8 | # NOTE: Batch size must be a multiple of GPU number 9 | batch_size: 1 10 | num_workers: 2 11 | 12 | # Seed value used for rng initialization 13 | seed: 0 14 | 15 | # debug 16 | debug: false 17 | 18 | # Automatic mixed precision 19 | # Choose from ‘no’,‘fp16’,‘bf16’ or ‘fp8’. 20 | amp: 'bf16' # 'no' 21 | 22 | # Max norm of gradient clipping 23 | max_grad_norm: 1.0 24 | 25 | # EMA (Exponential Moving Average) 26 | ema: 27 | beta: 0.999 28 | update_after_step: 100 29 | update_every: 10 30 | 31 | # Logging 32 | 33 | logger: 34 | project_name: 'project_name' 35 | run_name: 'run_name' 36 | 37 | logging: 38 | # Step interval for logging metrics / saving checkpoints 39 | # / generating samples / test (validation) / printing metrics 40 | n_step_log: 10 41 | n_step_ckpt: 2000 42 | n_step_sample: 2000 43 | n_step_print: 100 44 | # Number of generated samples 45 | n_samples_per_step: 3 46 | 47 | metrics_for_best_ckpt: ['loss'] -------------------------------------------------------------------------------- /src/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | 4 | ------ 5 | Learning rate schedulers. 6 | """ 7 | 8 | import math 9 | 10 | import torch 11 | 12 | 13 | class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): 14 | def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0.0, last_epoch=-1): 15 | self.warmup_steps = warmup_steps 16 | self.total_steps = total_steps 17 | self.min_lr = min_lr 18 | super(WarmupCosineLR, self).__init__(optimizer, last_epoch) 19 | 20 | def get_lr(self): 21 | step = self.last_epoch + 1 22 | 23 | if step < self.warmup_steps: 24 | # warmup 25 | return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs] 26 | else: 27 | # cosine decay (base_lr -> min_lr) 28 | progress = min(1, (step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)) 29 | 30 | return [ 31 | self.min_lr + (base_lr - self.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress)) 32 | for base_lr in self.base_lrs 33 | ] 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Yukara Ikemiya 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /_LICENSES/descript-audio-codec.LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-present, Descript 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/utils/torch_common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def exists(x: torch.Tensor): 13 | return x is not None 14 | 15 | 16 | def get_world_size(): 17 | if not torch.distributed.is_available() or not torch.distributed.is_initialized(): 18 | return 1 19 | else: 20 | return torch.distributed.get_world_size() 21 | 22 | 23 | def get_rank(): 24 | """Get rank of current process.""" 25 | 26 | if not torch.distributed.is_available() or not torch.distributed.is_initialized(): 27 | return 0 28 | else: 29 | return torch.distributed.get_rank() 30 | 31 | 32 | def print_once(*args): 33 | if get_rank() == 0: 34 | print(*args) 35 | 36 | 37 | def set_seed(seed: int = 0): 38 | torch.manual_seed(seed) 39 | if torch.cuda.is_available(): 40 | torch.cuda.manual_seed_all(seed) 41 | np.random.seed(seed) 42 | random.seed(seed) 43 | os.environ["PYTHONHASHSEED"] = str(seed) 44 | 45 | 46 | def count_parameters(model: torch.nn.Module, include_buffers: bool = False): 47 | n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 48 | n_buffers = sum(p.numel() for p in model.buffers()) if include_buffers else 0 49 | return n_trainable_params + n_buffers 50 | 51 | 52 | def sort_dict(D: dict): 53 | s_keys = sorted(D.keys()) 54 | return {k: D[k] for k in s_keys} 55 | 56 | 57 | def checkpoint(function, *args, **kwargs): 58 | """ Gradient checkpointing """ 59 | kwargs.setdefault("use_reentrant", False) 60 | return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) 61 | -------------------------------------------------------------------------------- /src/data/pretransform/gemma_audio_feature.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | 4 | -------------- 5 | A wrapper of Gemma audio feature extractor 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | from torchaudio import transforms as T 11 | from transformers import Gemma3nAudioFeatureExtractor 12 | 13 | 14 | class GemmaAudioFeature: 15 | """ 16 | A wrapper around the Gemma3nAudioFeatureExtractor. 17 | 18 | NOTE: Feature extraction is executed on CPU. 19 | """ 20 | 21 | def __init__(self, model_id="google/gemma-3n-e2b-it"): 22 | self.extractor = Gemma3nAudioFeatureExtractor.from_pretrained(model_id) 23 | self.sr = self.extractor.sampling_rate 24 | 25 | def __call__(self, audio: torch.Tensor, sr_in: int) -> torch.Tensor: 26 | """ 27 | Args: 28 | audio (torch.Tensor): (num_samples) 29 | Returns: 30 | (torch.Tensor): (num_frames, feature_dim) 31 | """ 32 | if sr_in != self.sr: 33 | resample_tf = T.Resample(sr_in, self.sr) 34 | audio = resample_tf(audio) 35 | 36 | audio = audio.numpy() 37 | audio = audio.reshape(1, -1) # (1, L) 38 | output = self.extractor(audio, return_tensors="pt") 39 | audio_mel = output["input_features"] 40 | # The encoder expects a padding mask (True for padding), while the feature extractor 41 | # returns an attention mask (True for valid tokens). We must invert it. 
42 | # NOTE: 'False' for valid frames, 'True' for padded frames 43 | # audio_mel_mask = ~output["input_features_mask"].to(torch.bool) # not used for now 44 | 45 | return audio_mel.squeeze(0) # (num_frames, feature_dim) 46 | -------------------------------------------------------------------------------- /src/data/modification.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import random 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | # Channels 13 | 14 | class Mono(nn.Module): 15 | def __call__(self, x: torch.Tensor): 16 | assert len(x.shape) <= 2 17 | return torch.mean(x, dim=0, keepdims=True) if len(x.shape) > 1 else x 18 | 19 | 20 | class Stereo(nn.Module): 21 | def __call__(self, x: torch.Tensor): 22 | x_shape = x.shape 23 | assert len(x_shape) <= 2 24 | # Check if it's mono 25 | if len(x_shape) == 1: # s -> 2, s 26 | x = x.unsqueeze(0).repeat(2, 1) 27 | elif len(x_shape) == 2: 28 | if x_shape[0] == 1: # 1, s -> 2, s 29 | x = x.repeat(2, 1) 30 | elif x_shape[0] > 2: # ?, s -> 2,s 31 | x = x[:2, :] 32 | 33 | return x 34 | 35 | 36 | # Augmentation 37 | 38 | class PhaseFlipper(nn.Module): 39 | """Randomly invert the phase of a signal""" 40 | 41 | def __init__(self, p=0.5): 42 | super().__init__() 43 | self.p = p 44 | 45 | def __call__(self, x: torch.Tensor): 46 | assert len(x.shape) <= 2 47 | return -x if (random.random() < self.p) else x 48 | 49 | 50 | class VolumeChanger(nn.Module): 51 | """Randomly change volume (amplitude) of a signal""" 52 | 53 | def __init__(self, min_amp: float = 0.25, max_amp: float = 1.0): 54 | super().__init__() 55 | self.min_amp = min_amp 56 | self.max_amp = max_amp 57 | 58 | def __call__(self, x: torch.Tensor): 59 | assert x.ndim <= 2 60 | amp_x = x.abs().max().item() 61 | if amp_x < 1e-5: 62 | return x 63 | 64 | min_db = 20 * math.log10(self.min_amp / amp_x) 65 | max_db = 20 * math.log10(self.max_amp / amp_x) 66 | scale_db = random.uniform(min_db, max_db) 67 | scale = 10 ** (scale_db / 20) 68 | x = x * scale 69 | 70 | return x 71 | -------------------------------------------------------------------------------- /configs/model/miipher-2_google-usm_wavefit-5_clean-input.yaml: -------------------------------------------------------------------------------- 1 | # Miipher-2 model consists of a Google-USM feature cleaner and a WaveFit-5 vocoder. 2 | # 3 | # Input sample rate: 16000 hz (-> mel-spectrogram) 4 | # Output sample rate: 24000 hz 5 | 6 | defaults: 7 | - feature_cleaner: google-usm 8 | 9 | _target_: model.Miipher2 10 | 11 | mode: clean_input 12 | upsample_factor: 4 13 | upsample_mode: "linear" 14 | gradient_checkpointing: false 15 | 16 | vocoder: 17 | _target_: model.WaveFit 18 | # Number of WaveFit iteration (e.g. WaveFit-5 -> 5) 19 | num_iteration: 5 20 | # Gain of target audio 21 | target_gain: 0.9 22 | 23 | # Pre-network (Conformer blocks) 24 | num_conformer_blocks: 4 25 | args_conformer: 26 | # NOTE: These settings are based on Gemma3n audio encoder settings. 
27 | # https://huggingface.co/google/gemma-3n-E2B-it/blob/main/config.json 28 | conf_attention_chunk_size: 12 29 | conf_attention_context_left: 13 30 | conf_attention_context_right: 0 31 | conf_attention_logit_cap: 50.0 32 | conf_conv_kernel_size: 5 33 | conf_num_attention_heads: 8 34 | conf_reduction_factor: 4 35 | conf_residual_weight: 0.5 36 | gradient_clipping: 10000000000.0 37 | hidden_size: 1536 38 | rms_norm_eps: 1e-06 39 | 40 | # Generator 41 | args_generator: 42 | dim_feat: 1536 43 | upsample_factors: [5, 4, 3, 2, 2] 44 | upsample_channels: [512, 512, 256, 128, 128] 45 | downsample_channels: [128, 128, 256, 512] 46 | 47 | discriminator: 48 | _target_: model.Discriminator 49 | msd_kwargs: 50 | num_D: 3 51 | ndf: 16 52 | n_layers: 4 53 | downsampling_factor: 4 54 | mpd_kwargs: 55 | periods: [2, 3, 5, 7, 11, 13, 17, 19] 56 | 57 | mrstft: 58 | # Parallel WaveGAN setting (Sec.2.3 in the Miipher paper) 59 | _target_: model.MRSTFTLoss 60 | n_ffts: [512, 1024, 2048] 61 | win_sizes: [240, 600, 1200] 62 | hop_sizes: [50, 120, 240] 63 | 64 | loss_lambdas: 65 | mrstft_sc_loss: 2.5 66 | mrstft_mag_loss: 2.5 67 | disc_gan_loss: 1.0 68 | disc_feat_loss: 10.0 -------------------------------------------------------------------------------- /src/data/audio_io.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import math 6 | import json 7 | 8 | from torch.nn import functional as F 9 | import torchaudio 10 | from torchaudio import transforms as T 11 | 12 | 13 | def get_audio_metadata(filepath, cache=True): 14 | try: 15 | with open(filepath + '.json', 'r') as f: 16 | info = json.load(f) 17 | return info 18 | except Exception: 19 | try: 20 | info_ = torchaudio.info(filepath) 21 | sample_rate = info_.sample_rate 22 | num_channels = info_.num_channels 23 | num_frames = info_.num_frames 24 | 25 | info = { 26 | 'sample_rate': sample_rate, 27 | 'num_frames': num_frames, 28 | 'num_channels': num_channels 29 | } 30 | except Exception as e: 31 | # error : cannot open an audio file 32 | print(f"Failed to load metadata for {filepath}: {e}") 33 | info = {'sample_rate': 0, 'num_frames': 0, 'num_channels': 0} 34 | 35 | if cache: 36 | with open(filepath + '.json', 'w') as f: 37 | json.dump(info, f, indent=2) 38 | 39 | return info 40 | 41 | 42 | def load_audio_with_pad(filepath, info: dict, sr: int, n_samples: int, offset: int): 43 | sr_in, num_frames = info['sample_rate'], info['num_frames'] 44 | n_samples_in = int(math.ceil(n_samples * (sr_in / sr))) 45 | 46 | # load audio 47 | ext = filepath.split(".")[-1] 48 | out_frames = min(n_samples_in, num_frames - offset) 49 | 50 | audio, _ = torchaudio.load( 51 | filepath, frame_offset=offset, num_frames=out_frames, 52 | format=ext, backend='soundfile') 53 | 54 | # resample 55 | if sr_in != sr: 56 | resample_tf = T.Resample(sr_in, sr) 57 | audio = resample_tf(audio)[..., :n_samples] 58 | 59 | # zero pad 60 | L = audio.shape[-1] 61 | if L < n_samples: 62 | audio = F.pad(audio, (0, n_samples - L), value=0.) 63 | 64 | return audio 65 | -------------------------------------------------------------------------------- /configs/model/miipher-2_google-usm_wavefit-5_noisy-input.yaml: -------------------------------------------------------------------------------- 1 | # Miipher-2 model consists of a Google-USM feature cleaner and a WaveFit-5 vocoder. 
2 | # 3 | # Input sample rate: 16000 hz (-> mel-spectrogram) 4 | # Output sample rate: 24000 hz 5 | 6 | defaults: 7 | - feature_cleaner: google-usm 8 | 9 | _target_: model.Miipher2 10 | 11 | mode: noisy_input 12 | upsample_factor: 4 13 | upsample_mode: "linear" 14 | gradient_checkpointing: true 15 | 16 | feature_cleaner_ckpt_dir: ??? 17 | vocoder_ckpt_dir: ??? 18 | 19 | vocoder: 20 | _target_: model.WaveFit 21 | # Number of WaveFit iteration (e.g. WaveFit-5 -> 5) 22 | num_iteration: 5 23 | # Gain of target audio 24 | target_gain: 0.9 25 | 26 | # Pre-network (Conformer blocks) 27 | num_conformer_blocks: 4 28 | args_conformer: 29 | # NOTE: These settings are based on Gemma3n audio encoder settings. 30 | # https://huggingface.co/google/gemma-3n-E2B-it/blob/main/config.json 31 | conf_attention_chunk_size: 12 32 | conf_attention_context_left: 13 33 | conf_attention_context_right: 0 34 | conf_attention_logit_cap: 50.0 35 | conf_conv_kernel_size: 5 36 | conf_num_attention_heads: 8 37 | conf_reduction_factor: 4 38 | conf_residual_weight: 0.5 39 | gradient_clipping: 10000000000.0 40 | hidden_size: 1536 41 | rms_norm_eps: 1e-06 42 | 43 | # Generator 44 | args_generator: 45 | dim_feat: 1536 46 | upsample_factors: [5, 4, 3, 2, 2] 47 | upsample_channels: [512, 512, 256, 128, 128] 48 | downsample_channels: [128, 128, 256, 512] 49 | 50 | discriminator: 51 | _target_: model.Discriminator 52 | msd_kwargs: 53 | num_D: 3 54 | ndf: 16 55 | n_layers: 4 56 | downsampling_factor: 4 57 | mpd_kwargs: 58 | periods: [2, 3, 5, 7, 11, 13, 17, 19] 59 | 60 | mrstft: 61 | # Parallel WaveGAN setting (Sec.2.3 in the Miipher paper) 62 | _target_: model.MRSTFTLoss 63 | n_ffts: [512, 1024, 2048] 64 | win_sizes: [240, 600, 1200] 65 | hop_sizes: [50, 120, 240] 66 | 67 | loss_lambdas: 68 | mrstft_sc_loss: 2.5 69 | mrstft_mag_loss: 2.5 70 | disc_gan_loss: 1.0 71 | disc_feat_loss: 10.0 -------------------------------------------------------------------------------- /src/model/feature_cleaner/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | 4 | ----------------------------------------------------- 5 | A base class of audio encoder adapters. 6 | """ 7 | from abc import ABC, abstractmethod 8 | 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class AudioEncoderAdapter(ABC, nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | @abstractmethod 18 | def forward(self, x: torch.Tensor, encoder_only: bool = False) -> torch.Tensor: 19 | pass 20 | 21 | @abstractmethod 22 | def save_state_dict(self, path: str): 23 | pass 24 | 25 | @abstractmethod 26 | def load_state_dict(self, path: str): 27 | pass 28 | 29 | @abstractmethod 30 | def get_state_dict(self) -> dict: 31 | pass 32 | 33 | def train_step( 34 | self, 35 | x_tgt: torch.Tensor, 36 | x_deg: torch.Tensor, 37 | loss_lambda: dict = { 38 | 'l1': 1.0, 39 | 'l2': 1.0, 40 | 'spectral_convergence': 1.0 41 | }, 42 | train: bool = True 43 | ) -> dict: 44 | """ 45 | Loss computation for feature cleaners defined in the Miipher paper. 
46 | https://arxiv.org/abs/2303.01664 47 | """ 48 | self.train() if train else self.eval() 49 | assert x_tgt.shape == x_deg.shape, f"Shapes of target and degraded features must be the same: {x_tgt.shape} != {x_deg.shape}" 50 | 51 | with torch.no_grad(): 52 | feats_tgt = self.forward(x_tgt, encoder_only=True) 53 | 54 | with torch.set_grad_enabled(train): 55 | feats_deg = self.forward(x_deg, encoder_only=False) 56 | 57 | # loss 58 | l1_loss = (feats_deg - feats_tgt).abs().mean() * loss_lambda['l1'] 59 | l2_loss = (feats_deg - feats_tgt).pow(2.0).mean() * loss_lambda['l2'] 60 | sc_loss = l2_loss / (feats_tgt.pow(2.0).mean() + 1e-9) * loss_lambda['spectral_convergence'] 61 | 62 | loss = l1_loss + l2_loss + sc_loss 63 | 64 | return { 65 | 'loss': loss, 66 | 'l1_loss': l1_loss.detach(), 67 | 'l2_loss': l2_loss.detach(), 68 | 'sc_loss': sc_loss.detach() 69 | } 70 | -------------------------------------------------------------------------------- /src/model/feature_cleaner/parallel_adapter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | 4 | Adapted from the following repo's code under Apache-2.0 License. 5 | https://github.com/jxhe/unify-parameter-efficient-tuning/ 6 | 7 | ----------------------------------------------------- 8 | Parallel adapter for LLMs. 9 | """ 10 | import math 11 | import typing as tp 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | def init_bert_weights(module): 18 | """Initialize the weights.""" 19 | if isinstance(module, (nn.Linear, nn.Embedding)): 20 | # std defaults to 0.02, this might need to be changed 21 | module.weight.data.normal_(mean=0.0, std=0.02) 22 | elif isinstance(module, nn.LayerNorm): 23 | module.bias.data.zero_() 24 | module.weight.data.fill_(1.0) 25 | if isinstance(module, nn.Linear) and module.bias is not None: 26 | module.bias.data.zero_() 27 | 28 | 29 | class AdapterLayer(nn.Module): 30 | def __init__( 31 | self, 32 | dim_in: int, 33 | dim_bottleneck: int, 34 | dropout: float = 0.0, 35 | init_option: str = "bert", 36 | adapter_scalar: tp.Union[float, str] = 1.0, 37 | pre_ln_class=None 38 | ): 39 | super().__init__() 40 | 41 | self.n_embd = dim_in 42 | self.down_size = dim_bottleneck 43 | 44 | # Layer normalization options 45 | self.use_pre_ln = pre_ln_class is not None 46 | if self.use_pre_ln: 47 | self.pre_ln = pre_ln_class(dim_in) 48 | 49 | # PA modules 50 | self.down_proj = nn.Linear(self.n_embd, self.down_size) 51 | self.non_linear_func = nn.ReLU() 52 | self.up_proj = nn.Linear(self.down_size, self.n_embd) 53 | self.scale = nn.Parameter(torch.ones(1)) if adapter_scalar == "learnable_scalar" else float(adapter_scalar) 54 | self.dropout = dropout 55 | 56 | # Initialization options 57 | if init_option == "bert": 58 | self.apply(init_bert_weights) 59 | elif init_option == "lora": 60 | with torch.no_grad(): 61 | nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5)) 62 | nn.init.zeros_(self.up_proj.weight) 63 | nn.init.zeros_(self.down_proj.bias) 64 | nn.init.zeros_(self.up_proj.bias) 65 | else: 66 | raise ValueError(f"Unknown initialization option: {init_option}") 67 | 68 | def forward(self, x): 69 | if self.use_pre_ln: 70 | x = self.pre_ln(x) 71 | 72 | down = self.down_proj(x) 73 | down = self.non_linear_func(down) 74 | down = nn.functional.dropout(down, p=self.dropout, training=self.training) 75 | up = self.up_proj(down) 76 | up = up * self.scale 77 | 78 | return up 79 | 
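
Self-contained sketch (not part of the repository): the AdapterLayer above can be exercised on dummy activations. The sizes mirror configs/model/feature_cleaner/google-usm.yaml (hidden size 1536, bottleneck 1024), while pre_ln_class is left at its default None so no Gemma3nRMSNorm import is needed.

import torch

from model.feature_cleaner.parallel_adapter import AdapterLayer  # assumes src/ is on PYTHONPATH

adapter = AdapterLayer(dim_in=1536, dim_bottleneck=1024, init_option="bert", adapter_scalar=1.0)
x = torch.randn(2, 25, 1536)   # (batch, frames, hidden) dummy encoder activations
delta = adapter(x)             # down-project -> ReLU -> up-project, i.e. the parallel-adapter residual
assert delta.shape == x.shape  # the adapter output is added back to the host branch

Note the initialization choice: with init_option="lora", up_proj starts at zero, so delta is initially zero and the adapter contributes nothing to the host layer's output until training moves it away from that point.
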
-------------------------------------------------------------------------------- /src/utils/audio_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import math 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from librosa.filters import mel as librosa_mel 11 | 12 | EPS = 1e-10 13 | 14 | 15 | def get_amplitude_spec(x, n_fft, win_size, hop_size, window, return_power: bool = False): 16 | stft_spec = torch.stft( 17 | x, n_fft, hop_length=hop_size, win_length=win_size, window=window, 18 | center=True, normalized=False, onesided=True, return_complex=True) 19 | 20 | power_spec = torch.view_as_real(stft_spec).pow(2).sum(-1) 21 | 22 | return power_spec if return_power else torch.sqrt(power_spec + EPS) 23 | 24 | 25 | class MelSpectrogram(nn.Module): 26 | def __init__( 27 | self, 28 | sr: int, 29 | # STFT setting 30 | n_fft: int, win_size: int, hop_size: int, 31 | # MelSpec setting 32 | n_mels: int, fmin: float, fmax: float, 33 | ): 34 | super().__init__() 35 | 36 | self.sr = sr 37 | self.n_fft = n_fft 38 | self.win_size = win_size 39 | self.hop_size = hop_size 40 | self.n_mels = n_mels 41 | self.fmin = fmin 42 | self.fmax = fmax 43 | 44 | mel_basis = librosa_mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) 45 | mel_basis = torch.from_numpy(mel_basis).float() 46 | mel_inv_basis = torch.linalg.pinv(mel_basis) 47 | 48 | self.register_buffer('fft_win', torch.hann_window(win_size)) 49 | self.register_buffer('mel_basis', mel_basis) 50 | self.register_buffer('mel_inv_basis', mel_inv_basis) 51 | 52 | def compute_mel(self, x: torch.Tensor): 53 | """ 54 | Compute Mel-spectrogram. 55 | 56 | Args: 57 | x: time_signal, (bs, length) 58 | Returns: 59 | mel_spec: Mel spectrogram, (bs, n_mels, num_frame) 60 | """ 61 | assert x.dim() == 2 62 | L = x.shape[-1] 63 | # NOTE : To prevent different signal length in the final frame of the STFT between training and inference time, 64 | # input signal length must be a multiple of hop_size. 65 | assert L % self.hop_size == 0, f"Input signal length must be a multiple of hop_size {self.hop_size}." + \ 66 | f"Input shape -> {x.shape}" 67 | 68 | num_frame = L // self.hop_size 69 | 70 | # STFT 71 | stft_spec = get_amplitude_spec(x, self.n_fft, self.win_size, self.hop_size, self.fft_win) 72 | 73 | # Mel Spec 74 | mel_spec = torch.matmul(self.mel_basis, stft_spec) 75 | 76 | # NOTE : The last frame is removed here. 77 | # When using center=True setting, output from torch.stft has frame length of (L//hopsize+1). 78 | # For training WaveGrad-based architecture, the frame length must be (L//hopsize). 79 | # There might be a better way, but I believe this has little to no impact on training 80 | # since the whole signal information is contained in the previous frames even when removing the last one. 
81 | mel_spec = mel_spec[..., :num_frame] 82 | 83 | return mel_spec 84 | -------------------------------------------------------------------------------- /src/model/wavefit/loss/mrstft.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import typing as tp 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from utils.audio_utils import get_amplitude_spec, MelSpectrogram 12 | 13 | 14 | class MRSTFTLoss(nn.Module): 15 | """ 16 | Multi-resolution STFT loss corresponding to the eq.(9) 17 | """ 18 | 19 | def __init__( 20 | self, 21 | n_ffts: tp.List[int] = [512, 1024, 2048], 22 | win_sizes: tp.List[int] = [360, 900, 1800], 23 | hop_sizes: tp.List[int] = [80, 150, 300], 24 | EPS: float = 1e-5 25 | ): 26 | super().__init__() 27 | assert len(n_ffts) == len(win_sizes) == len(hop_sizes) 28 | self.n_ffts = n_ffts 29 | self.win_sizes = win_sizes 30 | self.hop_sizes = hop_sizes 31 | 32 | # NOTE: Since spectral convergence is quite sensitive to small values in the spectrum, 33 | # I believe setting a higher lower bound will result in more stable training. 34 | self.EPS = EPS 35 | 36 | def forward( 37 | self, 38 | pred: torch.Tensor, 39 | target: torch.Tensor 40 | ): 41 | losses = { 42 | 'mrstft_sc_loss': 0., 43 | 'mrstft_mag_loss': 0. 44 | } 45 | 46 | for n_fft, win_size, hop_size in zip(self.n_ffts, self.win_sizes, self.hop_sizes): 47 | window = torch.hann_window(win_size, device=pred.device) 48 | spec_t = get_amplitude_spec(target.squeeze(1), n_fft, win_size, hop_size, window) 49 | spec_p = get_amplitude_spec(pred.squeeze(1), n_fft, win_size, hop_size, window) 50 | 51 | # spectral convergence 52 | sc_loss = (spec_t - spec_p).norm(p=2) / (spec_t.norm(p=2) + self.EPS) 53 | 54 | # magnitude loss 55 | mag_loss = F.l1_loss(torch.log(spec_t.clamp(min=self.EPS)), torch.log(spec_p.clamp(min=self.EPS))) 56 | 57 | losses['mrstft_sc_loss'] += sc_loss 58 | losses['mrstft_mag_loss'] += mag_loss 59 | 60 | losses['mrstft_sc_loss'] /= len(self.n_ffts) 61 | losses['mrstft_mag_loss'] /= len(self.n_ffts) 62 | 63 | return losses 64 | 65 | 66 | class MELMAELoss(nn.Module): 67 | """ 68 | MAE(L1) loss of Mel spectrogram corresponding to the second term of the eq.(19) 69 | """ 70 | 71 | def __init__( 72 | self, 73 | sr: int = 24000, 74 | n_fft: int = 1024, 75 | win_size: int = 900, 76 | hop_size: int = 150, 77 | n_mels: int = 128, 78 | fmin: float = 20., 79 | fmax: float = 12000. 80 | ): 81 | super().__init__() 82 | 83 | self.mel = MelSpectrogram(sr, n_fft, win_size, hop_size, n_mels, fmin, fmax) 84 | 85 | def forward( 86 | self, 87 | pred: torch.Tensor, 88 | target: torch.Tensor 89 | ): 90 | losses = {'mel_mae_loss': 0.} 91 | 92 | # Mel MAE (L1) loss 93 | mel_p = self.mel.compute_mel(pred.squeeze(1)) 94 | mel_t = self.mel.compute_mel(target.squeeze(1)) 95 | 96 | losses['mel_mae_loss'] = F.l1_loss(mel_t, mel_p) 97 | 98 | return losses 99 | -------------------------------------------------------------------------------- /dataset/script/make_metadata_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | 5 | import pandas as pd 6 | import torchaudio 7 | 8 | 9 | def fast_scandir(dir: str, exts: list = ['wav', 'flac']) -> tuple: 10 | """ Very fast `glob` alternative. 
from https://stackoverflow.com/a/59803793/4259243 11 | 12 | fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py 13 | 14 | Args: 15 | dir (str): top-level directory at which to begin scanning. 16 | exts (tp.List[str]): list of allowed file extensions. 17 | """ 18 | subfolders, files = [], [] 19 | # add starting period to extensions if needed 20 | exts = ['.' + x if x[0] != '.' else x for x in exts] 21 | 22 | try: # hope to avoid 'permission denied' by this try 23 | for f in os.scandir(dir): 24 | try: # 'hope to avoid too many levels of symbolic links' error 25 | if f.is_dir(): 26 | subfolders.append(f.path) 27 | elif f.is_file(): 28 | is_hidden = os.path.basename(f.path).startswith(".") 29 | has_ext = os.path.splitext(f.name)[1].lower() in exts 30 | 31 | if has_ext and (not is_hidden): 32 | files.append(f.path) 33 | except Exception: 34 | pass 35 | except Exception: 36 | pass 37 | 38 | for dir in list(subfolders): 39 | sf, f = fast_scandir(dir, exts) 40 | subfolders.extend(sf) 41 | files.extend(f) 42 | 43 | return subfolders, files 44 | 45 | 46 | def get_audio_metadata(filepath): 47 | info_ = torchaudio.info(filepath) 48 | sample_rate = info_.sample_rate 49 | num_channels = info_.num_channels 50 | num_frames = info_.num_frames 51 | 52 | info = { 53 | 'sample_rate': sample_rate, 54 | 'num_frames': num_frames, 55 | 'num_channels': num_channels 56 | } 57 | 58 | return info 59 | 60 | 61 | def make_metadata_csv(dir: str, exts=['wav', 'flac']): 62 | 63 | csv_path = os.path.join(dir, "metadata.csv") 64 | rows = [] 65 | 66 | _, files = fast_scandir(dir, exts=exts) 67 | print(f"Found {len(files)} audio files in {dir}") 68 | 69 | for p in tqdm(files): 70 | info = get_audio_metadata(p) 71 | rel_path = os.path.relpath(p, dir) 72 | row = { 73 | "file_path": rel_path, 74 | "sample_rate": info["sample_rate"], 75 | "num_frames": info["num_frames"], 76 | "num_channels": info["num_channels"] 77 | } 78 | rows.append(row) 79 | 80 | # print(row) 81 | 82 | rows.sort(key=lambda x: x["file_path"]) 83 | df = pd.DataFrame(rows) 84 | df.to_csv(csv_path, index=False) 85 | 86 | print(f"Saved metadata to {csv_path}") 87 | 88 | 89 | def main(): 90 | args = argparse.ArgumentParser() 91 | args.add_argument('--root-dir', type=str, required=True, help="A root directory of audio dataset.") 92 | args = args.parse_args() 93 | 94 | root_dir = args.root_dir 95 | 96 | print(root_dir) 97 | make_metadata_csv(root_dir) 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /src/data/degradation/reverb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | """ 4 | import typing as tp 5 | import random 6 | 7 | import torch 8 | import numpy as np 9 | from scipy.signal import butter, filtfilt, fftconvolve 10 | import pyroomacoustics as pra 11 | 12 | from .base import Degradation 13 | 14 | 15 | class RIRReverb(Degradation): 16 | def __init__( 17 | self, 18 | sample_rate: int = 16000, 19 | rt60_range: tuple = (0.2, 0.5), 20 | xyz_range: tuple = ((2., 10.), (2., 10.), (2., 5.)), 21 | mic_margin: float = 0.5, 22 | src_margin: float = 0.1, 23 | # cutoff freq for RIR filter (highpass) 24 | hp_cutoff_hz: float = 20.0 25 | ): 26 | super().__init__(sample_rate=sample_rate) 27 | self.rt60_range = rt60_range 28 | self.xyz_range = xyz_range 29 | self.mic_margin = mic_margin 30 | self.src_margin = src_margin 
31 | self.hp_cutoff_hz = hp_cutoff_hz 32 | 33 | # pre-compute highpass filter 34 | nyq = sample_rate / 2. 35 | norm_cutoff = self.hp_cutoff_hz / nyq 36 | self.hp_b, self.hp_a = butter(4, norm_cutoff, btype="high") 37 | 38 | def _sample_room_params(self): 39 | rt60 = random.uniform(self.rt60_range[0], self.rt60_range[1]) 40 | Lx = random.uniform(self.xyz_range[0][0], self.xyz_range[0][1]) 41 | Ly = random.uniform(self.xyz_range[1][0], self.xyz_range[1][1]) 42 | Lz = random.uniform(self.xyz_range[2][0], self.xyz_range[2][1]) 43 | mic_pos = [ 44 | random.uniform(self.mic_margin, Lx - self.mic_margin), 45 | random.uniform(self.mic_margin, Ly - self.mic_margin), 46 | random.uniform(1.2, min(2.0, Lz - self.mic_margin)) 47 | ] 48 | src_pos = [ 49 | random.uniform(self.src_margin, Lx - self.src_margin), 50 | random.uniform(self.src_margin, Ly - self.src_margin), 51 | random.uniform(self.src_margin, Lz - self.src_margin) 52 | ] 53 | 54 | return rt60, [Lx, Ly, Lz], mic_pos, src_pos 55 | 56 | def _generate_rir( 57 | self, rt60: float, room_size: tp.List[float], 58 | mic_pos: tp.List[float], src_pos: tp.List[float] 59 | ): 60 | e_absorption, max_order = pra.inverse_sabine(rt60, room_size) 61 | room = pra.ShoeBox( 62 | room_size, 63 | fs=self.sample_rate, 64 | materials=pra.Material(e_absorption), 65 | max_order=max_order, 66 | ) 67 | room.add_microphone_array( 68 | pra.MicrophoneArray(np.array(mic_pos).reshape(3, 1), self.sample_rate) 69 | ) 70 | room.add_source(src_pos) 71 | room.compute_rir() 72 | 73 | rir = room.rir[0][0] 74 | L_rir = len(rir) 75 | 76 | # highpass to RIR 77 | rir_hp = filtfilt(self.hp_b, self.hp_a, rir) 78 | rir_hp = rir_hp[:L_rir] # trim to original length 79 | 80 | # scale direct sound to 0 dB 81 | peak_idx = np.argmax(np.abs(rir_hp)) 82 | if np.abs(rir_hp[peak_idx]) > 1e-9: 83 | rir_hp /= rir_hp[peak_idx] 84 | 85 | return rir_hp, peak_idx 86 | 87 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 88 | assert x.ndim == 1, f"Input audio must be 1D tensor: {x.shape}" 89 | dtype = x.dtype 90 | L = len(x) 91 | 92 | # sample room parameters 93 | rt60, room_size, mic_pos, src_pos = self._sample_room_params() 94 | 95 | # generate RIR 96 | rir, peak_idx = self._generate_rir(rt60, room_size, mic_pos, src_pos) 97 | 98 | # convolve 99 | x = fftconvolve(x.numpy(), rir, mode='full') 100 | x = torch.from_numpy(x).to(dtype) 101 | 102 | # fix signal shift and length 103 | offset = min(peak_idx, len(x) - L) 104 | x = x[offset:offset + L] 105 | 106 | # avoid clipping 107 | amp = x.abs().max().item() 108 | if amp > 1.0: 109 | x = x / (amp + 1e-8) 110 | 111 | return x 112 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | """ 4 | 5 | import sys 6 | sys.dont_write_bytecode = True 7 | 8 | # DDP 9 | from accelerate import Accelerator, DistributedDataParallelKwargs, DataLoaderConfiguration 10 | from accelerate.utils import ProjectConfiguration 11 | 12 | import torch 13 | import hydra 14 | from hydra.core.hydra_config import HydraConfig 15 | from omegaconf import DictConfig, OmegaConf 16 | from ema_pytorch import EMA 17 | 18 | from utils.torch_common import get_world_size, count_parameters, set_seed 19 | from trainer import Trainer 20 | 21 | 22 | @hydra.main(version_base=None, config_path='../configs/', config_name="default.yaml") 23 | def main(cfg: DictConfig): 24 | 25 | # Update config if ckpt_dir is specified (training 
resumption) 26 | 27 | if cfg.trainer.ckpt_dir is not None: 28 | overrides = HydraConfig.get().overrides.task 29 | overrides = [e for e in overrides if isinstance(e, str)] 30 | override_conf = OmegaConf.from_dotlist(overrides) 31 | cfg = OmegaConf.merge(cfg, override_conf) 32 | 33 | # Load checkpoint configuration 34 | cfg_ckpt = OmegaConf.load(f'{cfg.trainer.ckpt_dir}/config.yaml') 35 | cfg = OmegaConf.merge(cfg_ckpt, override_conf) 36 | 37 | # HuggingFace Accelerate for distributed training 38 | 39 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 40 | dl_config = DataLoaderConfiguration(split_batches=True) 41 | p_config = ProjectConfiguration(project_dir=cfg.trainer.output_dir) 42 | accel = Accelerator( 43 | mixed_precision=cfg.trainer.amp, 44 | dataloader_config=dl_config, 45 | project_config=p_config, 46 | kwargs_handlers=[ddp_kwargs], 47 | log_with='wandb' 48 | ) 49 | 50 | accel.init_trackers(cfg.trainer.logger.project_name, config=OmegaConf.to_container(cfg), 51 | init_kwargs={"wandb": {"name": cfg.trainer.logger.run_name, "dir": cfg.trainer.output_dir}}) 52 | 53 | if accel.is_main_process: 54 | print("->->-> DDP Initialized.") 55 | print(f"->->-> World size (Number of GPUs): {get_world_size()}") 56 | 57 | set_seed(cfg.trainer.seed) 58 | 59 | # Dataset 60 | 61 | batch_size = cfg.trainer.batch_size 62 | num_workers = cfg.trainer.num_workers 63 | train_dataset = hydra.utils.instantiate(cfg.data.train) 64 | train_dataloader = torch.utils.data.DataLoader( 65 | train_dataset, batch_size=batch_size, shuffle=True, 66 | num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers > 0)) 67 | 68 | # Model 69 | 70 | model = hydra.utils.instantiate(cfg.model) 71 | 72 | # EMA 73 | 74 | ema = None 75 | if accel.is_main_process: 76 | ema = EMA(model, **cfg.trainer.ema) 77 | ema.to(accel.device) 78 | 79 | # Optimizer 80 | 81 | # check if cfg.optimizer has optimizer_d 82 | if 'optimizer_d' in cfg.optimizer: 83 | # discriminator exists 84 | opt = hydra.utils.instantiate(cfg.optimizer.optimizer)(params=model.vocoder.parameters()) 85 | sche = hydra.utils.instantiate(cfg.optimizer.scheduler)(optimizer=opt) 86 | opt_d = hydra.utils.instantiate(cfg.optimizer.optimizer_d)(params=model.discriminator.parameters()) 87 | sche_d = hydra.utils.instantiate(cfg.optimizer.scheduler_d)(optimizer=opt_d) 88 | else: 89 | opt = hydra.utils.instantiate(cfg.optimizer.optimizer)(params=model.parameters()) 90 | sche = hydra.utils.instantiate(cfg.optimizer.scheduler)(optimizer=opt) 91 | opt_d = None 92 | sche_d = None 93 | 94 | # Log 95 | 96 | model.train() 97 | num_params = count_parameters(model) / 1e6 98 | if accel.is_main_process: 99 | print("=== Parameters ===") 100 | print(f"\tModel:\t{num_params:.2f} [million]") 101 | print("=== Dataset ===") 102 | print(f"\tBatch size: {cfg.trainer.batch_size}") 103 | print("\tTrain data:") 104 | print(f"\t\tChunks: {len(train_dataset)}") 105 | print(f"\t\tBatches: {len(train_dataset)//cfg.trainer.batch_size}") 106 | 107 | # Prepare for DDP 108 | 109 | train_dataloader, model, opt, sche, opt_d, sche_d = \ 110 | accel.prepare(train_dataloader, model, opt, sche, opt_d, sche_d) 111 | 112 | # Start training 113 | 114 | trainer = Trainer( 115 | model=model, 116 | ema=ema, 117 | optimizer=opt, 118 | scheduler=sche, 119 | optimizer_d=opt_d, 120 | scheduler_d=sche_d, 121 | train_dataloader=train_dataloader, 122 | accel=accel, 123 | cfg=cfg, 124 | ckpt_dir=cfg.trainer.ckpt_dir 125 | ) 126 | 127 | trainer.start_training() 128 | 129 | 130 | if __name__ == 
'__main__': 131 | main() 132 | print("[Training finished.]") 133 | -------------------------------------------------------------------------------- /src/model/wavefit/wavefit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | 4 | ----------------- 5 | WaveFit module in Miipher-2. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from transformers.models.gemma3n.modeling_gemma3n import Gemma3nAudioConformerBlock 12 | 13 | from .generator import WaveFitGenerator 14 | from utils.torch_common import checkpoint 15 | 16 | 17 | class ConformerConfig: 18 | def __init__( 19 | self, 20 | conf_attention_chunk_size: int = 12, 21 | conf_attention_context_left: int = 13, 22 | conf_attention_context_right: int = 0, 23 | conf_attention_logit_cap: float = 50.0, 24 | conf_conv_kernel_size: int = 5, 25 | conf_num_attention_heads: int = 8, 26 | conf_reduction_factor: int = 4, 27 | conf_residual_weight: float = 0.5, 28 | gradient_clipping: float = 10000000000.0, 29 | hidden_size: int = 1536, 30 | rms_norm_eps: float = 1e-06 31 | ): 32 | self.conf_attention_chunk_size = conf_attention_chunk_size 33 | self.conf_attention_context_left = conf_attention_context_left 34 | self.conf_attention_context_right = conf_attention_context_right 35 | self.conf_attention_logit_cap = conf_attention_logit_cap 36 | self.conf_conv_kernel_size = conf_conv_kernel_size 37 | self.conf_num_attention_heads = conf_num_attention_heads 38 | self.conf_reduction_factor = conf_reduction_factor 39 | self.conf_residual_weight = conf_residual_weight 40 | self.gradient_clipping = gradient_clipping 41 | self.hidden_size = hidden_size 42 | self.rms_norm_eps = rms_norm_eps 43 | 44 | 45 | class WaveFit(nn.Module): 46 | def __init__( 47 | self, 48 | num_iteration: int, 49 | target_gain: float = 0.9, 50 | # Pre-network Conformer blocks (Sec.3.2) 51 | num_conformer_blocks: int = 4, 52 | args_conformer: dict = { 53 | "conf_attention_chunk_size": 12, 54 | "conf_attention_context_left": 13, 55 | "conf_attention_context_right": 0, 56 | "conf_attention_logit_cap": 50.0, 57 | "conf_conv_kernel_size": 5, 58 | "conf_num_attention_heads": 8, 59 | "conf_reduction_factor": 4, 60 | "conf_residual_weight": 0.5, 61 | "gradient_clipping": 10000000000.0, 62 | "hidden_size": 1536, 63 | "rms_norm_eps": 1e-06 64 | }, 65 | # WaveFit generator 66 | args_generator: dict = { 67 | "dim_feat": 1536, 68 | "upsample_factors": [5, 4, 3, 2, 2], 69 | "upsample_channels": [512, 512, 256, 128, 128], 70 | "downsample_channels": [128, 128, 256, 512], 71 | } 72 | ): 73 | super().__init__() 74 | 75 | self.T = num_iteration 76 | self.target_gain = target_gain 77 | 78 | # Conformer blocks 79 | self.conformer_config = ConformerConfig(**args_conformer) 80 | self.conformer_blocks = nn.ModuleList( 81 | [Gemma3nAudioConformerBlock(self.conformer_config) for _ in range(num_conformer_blocks)] 82 | ) 83 | 84 | # Generator 85 | self.generator = WaveFitGenerator(num_iteration, **args_generator) 86 | self.EPS = 1e-8 87 | 88 | @property 89 | def upsample_rate(self) -> int: 90 | return self.generator.upsample_rate 91 | 92 | def forward( 93 | self, 94 | initial_noise: torch.Tensor, 95 | audio_feats: torch.Tensor, 96 | # You can use this option at inference time 97 | return_only_last: bool = False, 98 | # training config 99 | gradient_checkpointing: bool = False 100 | ): 101 | """ 102 | Args: 103 | initial_noise: Initial noise, (bs, L). 104 | audio_feats: Audio features, (bs, n_frame, dim). 
105 | return_only_last: If true, only the last output (y_0) is returned. 106 | Returns: 107 | preds: List of predictions (y_t) 108 | """ 109 | initial_noise = initial_noise.unsqueeze(1) # (bs, 1, L) 110 | assert initial_noise.dim() == audio_feats.dim() == 3 111 | assert initial_noise.size(0) == audio_feats.size(0) 112 | 113 | # Pre-network 114 | mask = torch.zeros(audio_feats.size(0), audio_feats.size(1), dtype=torch.bool, device=audio_feats.device) 115 | for block in self.conformer_blocks: 116 | if gradient_checkpointing and self.training: 117 | audio_feats = checkpoint(block, audio_feats, mask) 118 | else: 119 | audio_feats = block(audio_feats, mask) 120 | 121 | # (bs, n_frame, dim) -> (bs, dim, n_frame) 122 | audio_feats = audio_feats.transpose(1, 2).contiguous() 123 | 124 | preds = [] 125 | y_t = initial_noise 126 | for t in range(self.T): 127 | # estimate noise 128 | if gradient_checkpointing and self.training: 129 | est = checkpoint(self.generator, y_t, audio_feats, t) 130 | else: 131 | est = self.generator(y_t, audio_feats, t) 132 | 133 | y_t = y_t - est 134 | 135 | # gain normalization (Sec.2.3 in the Miipher paper) 136 | y_t = self.normalize_gain(y_t) 137 | 138 | if (not return_only_last) or (t == self.T - 1): 139 | preds.append(y_t.squeeze(1)) 140 | 141 | # To avoid gradient loop 142 | y_t = y_t.detach() 143 | 144 | return preds 145 | 146 | def normalize_gain(self, z_t: torch.Tensor): 147 | # z_t: (bs, 1, L) 148 | scale = self.target_gain / (z_t.squeeze(1).abs().max(dim=1, keepdim=True)[0][:, :, None] + self.EPS) 149 | return z_t * scale 150 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codz] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | #poetry.toml 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 115 | #pdm.lock 116 | #pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # pixi 121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 122 | #pixi.lock 123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 124 | # in the .venv directory. It is recommended not to include this directory in version control. 125 | .pixi 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .envrc 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 176 | #.idea/ 177 | 178 | # Abstra 179 | # Abstra is an AI-powered process automation framework. 180 | # Ignore directories containing user credentials, local state, and settings. 181 | # Learn more at https://abstra.io/docs 182 | .abstra/ 183 | 184 | # Visual Studio Code 185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 187 | # and can be added to the global gitignore or merged into this file. 
However, if you prefer, 188 | # you could uncomment the following to ignore the entire vscode folder 189 | # .vscode/ 190 | 191 | # Ruff stuff: 192 | .ruff_cache/ 193 | 194 | # PyPI configuration file 195 | .pypirc 196 | 197 | # Cursor 198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 199 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data 200 | # refer to https://docs.cursor.com/context/ignore-files 201 | .cursorignore 202 | .cursorindexingignore 203 | 204 | # Marimo 205 | marimo/_static/ 206 | marimo/_lsp/ 207 | __marimo__/ 208 | 209 | # other directories 210 | job/ 211 | runs/ 212 | audio/ 213 | 214 | # other files 215 | *.sif 216 | *.log 217 | *.code-workspace -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | """ 4 | 5 | import os 6 | import sys 7 | sys.dont_write_bytecode = True 8 | import argparse 9 | import math 10 | 11 | import hydra 12 | import torch 13 | import torchaudio 14 | from accelerate import Accelerator 15 | from omegaconf import OmegaConf 16 | 17 | from model import Miipher2 18 | from utils.torch_common import get_rank, get_world_size, print_once 19 | from data.dataset import get_audio_info 20 | 21 | 22 | def make_audio_batch(audio, sample_size: int, overlap: int): 23 | """ 24 | audio : (ch, L) 25 | """ 26 | assert 0 <= overlap < sample_size 27 | L = audio.shape[-1] 28 | shift = sample_size - overlap 29 | 30 | n_split = math.ceil(max(L - sample_size, 0) / shift) + 1 31 | # to mono 32 | audio = audio.mean(0) # (L) 33 | batch = [] 34 | for n in range(n_split): 35 | b = audio[n * shift: n * shift + sample_size] 36 | if n == n_split - 1: 37 | b = torch.nn.functional.pad(b, (0, sample_size - len(b))) 38 | batch.append(b) 39 | 40 | batch = torch.stack(batch, dim=0) # (n_split, sample_size) 41 | return batch, L 42 | 43 | 44 | def cross_fade(preds, overlap: int, L: int): 45 | """ 46 | preds: (bs, sample_size) 47 | """ 48 | bs, sample_size = preds.shape 49 | shift = sample_size - overlap 50 | full_L = sample_size + (bs - 1) * shift 51 | win = torch.bartlett_window(overlap * 2, device=preds.device) 52 | 53 | buf = torch.zeros(full_L, device=preds.device) 54 | pre_overlap = None 55 | for idx in range(bs): 56 | pred = preds[idx] # (sample_size) 57 | ofs = idx * shift 58 | if idx != 0: 59 | # Fix volume 60 | # NOTE: Since volume is changed by gain normalization in WaveFit module, 61 | # it have to be adjusted not to be discontinuous. 
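            # Concretely, the code below rescales the current chunk so that the
            # RMS energy of its leading overlap matches the trailing overlap of
            # the previous chunk:
            #   volume_rescale = sqrt( sum(pre_overlap^2) / sum(cur_overlap^2) )
            # before the Bartlett-window cross-fade is applied.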
62 | cur_overlap = pred[:overlap] 63 | volume_rescale = (pre_overlap.pow(2).sum() / (cur_overlap.pow(2).sum() + 1e-10)).sqrt() 64 | pred *= volume_rescale 65 | 66 | pred[:overlap] *= win[:overlap] 67 | 68 | if idx != bs - 1: 69 | pre_overlap = pred[-overlap:].clone() 70 | pred[-overlap:] *= win[overlap:] 71 | 72 | buf[ofs:ofs + sample_size] += pred 73 | 74 | buf = buf[:L] 75 | 76 | return buf 77 | 78 | 79 | def main(): 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--ckpt-dir', type=str, help="Checkpoint directory.") 82 | parser.add_argument('--input-audio-dir', type=str, help="Root directory which contains input audio files.") 83 | parser.add_argument('--output-dir', type=str, help="Output directory.") 84 | parser.add_argument('--sample-size', type=int, default=160000, help="Input sample size.") 85 | parser.add_argument('--sr-in', type=int, default=16000, help="Input sample rate.") 86 | parser.add_argument('--sr-out', type=int, default=24000, help="Output sample rate.") 87 | parser.add_argument('--max-batch-size', type=int, default=10, help="Max batch size for inference.") 88 | parser.add_argument('--overlap-rate', type=float, default=0.05, help="Overlap rate for inference.") 89 | parser.add_argument('--use-original-name', default=True, type=bool, help="Whether to use an original file name as an output name.") 90 | args = parser.parse_args() 91 | 92 | ckpt_dir = args.ckpt_dir 93 | input_audio_dir = args.input_audio_dir 94 | output_dir = args.output_dir 95 | sample_size = args.sample_size 96 | sr_in = args.sr_in 97 | sr_out = args.sr_out 98 | max_batch_size = args.max_batch_size 99 | overlap_rate = args.overlap_rate 100 | use_original_name = args.use_original_name 101 | 102 | # Distributed inference 103 | accel = Accelerator() 104 | device = accel.device 105 | rank = get_rank() 106 | world_size = get_world_size() 107 | 108 | print_once(f"Checkpoint dir : {ckpt_dir}") 109 | print_once(f"Input audio dir : {input_audio_dir}") 110 | print_once(f"Output dir : {output_dir}") 111 | 112 | # Load Miipher-2 model 113 | cfg_ckpt = OmegaConf.load(f'{ckpt_dir}/config.yaml') 114 | # remove discriminator and MRSTFT modules 115 | cfg_ckpt.model.discriminator = None 116 | cfg_ckpt.model.mrstft = None 117 | 118 | model: Miipher2 = hydra.utils.instantiate(cfg_ckpt.model) 119 | model.load_state_dict(ckpt_dir) 120 | model.to(device) 121 | model.eval() 122 | print_once("->-> Successfully loaded Miipher-2 model from checkpoint.") 123 | 124 | overlap_in = int(sample_size * overlap_rate) 125 | overlap_out = int(overlap_in * sr_out / sr_in) 126 | print_once(f"->-> [Sample size] : {sample_size} samples") 127 | print_once(f"->-> [Overlap size]: {overlap_in} samples ({overlap_in / sample_size * 100:.1f} %)") 128 | 129 | # Get audio files 130 | files, _ = get_audio_info([input_audio_dir]) 131 | print_once(f"->-> Found {len(files)} audio files from {input_audio_dir}.") 132 | os.makedirs(output_dir, exist_ok=True) 133 | 134 | # Split files for each process 135 | files = files[rank::world_size] 136 | 137 | print_once(f"--- Rank-{rank} : Start inference... 
---") 138 | 139 | for idx, f_path in enumerate(files): 140 | # load and split audio 141 | audio, sr = torchaudio.load(f_path) 142 | if sr != sr_in: 143 | audio = torchaudio.functional.resample(audio, sr, sr_in) 144 | sr = sr_in 145 | 146 | audio_batch, L = make_audio_batch(audio, sample_size, overlap_in) 147 | n_iter = math.ceil(audio_batch.shape[0] / max_batch_size) 148 | 149 | audio_batch = audio_batch.to(device) 150 | 151 | # execute 152 | preds = [] 153 | for n in range(n_iter): 154 | batch_ = audio_batch[n * max_batch_size:(n + 1) * max_batch_size] 155 | with torch.no_grad(): 156 | pred = model.inference(batch_) # (bs, L) 157 | 158 | preds.append(pred) 159 | 160 | preds = torch.cat(preds, dim=0) 161 | 162 | # cross-fade 163 | L_out = int(L * sr_out / sr_in) 164 | pred_audio = cross_fade(preds, overlap_out, L_out).cpu() 165 | 166 | # rescale volume to avoid clipping 167 | pred_audio = pred_audio / (pred_audio.abs().max() + 1e-8) * 0.9 168 | 169 | # save audio 170 | out_name = os.path.splitext(os.path.basename(f_path))[0] if use_original_name else f"sample_{idx}" 171 | out_path = f"{output_dir}/{out_name}.wav" 172 | torchaudio.save(out_path, pred_audio.unsqueeze(0), sample_rate=sr_out, encoding="PCM_F") 173 | 174 | print(f"--- Rank-{rank} : Finished. ---") 175 | 176 | 177 | if __name__ == '__main__': 178 | main() 179 | -------------------------------------------------------------------------------- /src/model/wavefit/discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | 4 | Adapted from the following repos code under MIT License. 5 | https://github.com/descriptinc/melgan-neurips/ 6 | https://github.com/descriptinc/descript-audio-codec 7 | """ 8 | from abc import ABC, abstractmethod 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn.utils.parametrizations import weight_norm 14 | from einops import rearrange 15 | 16 | 17 | def weights_init(m): 18 | classname = m.__class__.__name__ 19 | if classname.find("Conv") != -1: 20 | m.weight.data.normal_(0.0, 0.02) 21 | elif classname.find("BatchNorm2d") != -1: 22 | m.weight.data.normal_(1.0, 0.02) 23 | m.bias.data.fill_(0) 24 | 25 | 26 | def WNConv1d(*args, **kwargs): 27 | return weight_norm(nn.Conv1d(*args, **kwargs)) 28 | 29 | 30 | def WNConv2dWithLeakyReLU(*args, **kwargs): 31 | act = kwargs.pop("act", True) 32 | conv = weight_norm(nn.Conv2d(*args, **kwargs)) 33 | if not act: 34 | return conv 35 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 36 | 37 | 38 | class MultiDiscriminator(ABC, nn.Module): 39 | @abstractmethod 40 | def forward(self, x: torch.Tensor, return_feature: bool = True): 41 | pass 42 | 43 | def compute_G_loss(self, x_fake, x_real): 44 | """ 45 | The eq.(18) loss 46 | """ 47 | assert x_fake.shape == x_real.shape 48 | 49 | out_f = self(x_fake, return_feature=True) 50 | with torch.no_grad(): 51 | out_r = self(x_real, return_feature=True) 52 | 53 | num_D = len(self.model) 54 | losses = { 55 | 'disc_gan_loss': 0., 56 | 'disc_feat_loss': 0. 57 | } 58 | 59 | for i_d in range(num_D): 60 | n_layer = len(out_f[i_d]) 61 | 62 | # GAN loss 63 | losses['disc_gan_loss'] += (1 - out_f[i_d][-1]).relu().mean() 64 | 65 | # Feature-matching loss 66 | # eq.(8) 67 | feat_loss = 0. 
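            # Feature-matching term: mean L1 distance between the fake and real
            # intermediate feature maps of every layer except the final logit,
            # averaged over layers here and, below, over the individual discriminators.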
68 | for i_l in range(n_layer - 1): 69 | feat_loss += F.l1_loss(out_f[i_d][i_l], out_r[i_d][i_l]) 70 | 71 | losses['disc_feat_loss'] += feat_loss / (n_layer - 1) 72 | 73 | losses['disc_gan_loss'] /= num_D 74 | losses['disc_feat_loss'] /= num_D 75 | 76 | return losses 77 | 78 | def compute_D_loss(self, x, mode: str): 79 | """ 80 | The eq.(7) loss 81 | """ 82 | assert mode in ['fake', 'real'] 83 | sign = 1 if mode == 'fake' else -1 84 | 85 | out = self(x, return_feature=False) 86 | 87 | num_D = len(self.model) 88 | losses = {'loss': 0.} 89 | 90 | for i_d in range(num_D): 91 | # Hinge loss 92 | losses['loss'] += (1 + sign * out[i_d][-1]).relu().mean() 93 | 94 | losses['loss'] /= num_D 95 | 96 | return losses 97 | 98 | 99 | class MSDBlock(nn.Module): 100 | def __init__(self, ndf, n_layers, downsampling_factor): 101 | super().__init__() 102 | model = nn.ModuleDict() 103 | 104 | model["layer_0"] = nn.Sequential( 105 | nn.ReflectionPad1d(7), 106 | WNConv1d(1, ndf, kernel_size=15), 107 | nn.LeakyReLU(0.2, True), 108 | ) 109 | 110 | nf = ndf 111 | stride = downsampling_factor 112 | for n in range(1, n_layers + 1): 113 | nf_prev = nf 114 | nf = min(nf * stride, 1024) 115 | 116 | model["layer_%d" % n] = nn.Sequential( 117 | WNConv1d( 118 | nf_prev, 119 | nf, 120 | kernel_size=stride * 10 + 1, 121 | stride=stride, 122 | padding=stride * 5, 123 | groups=nf_prev // 4, 124 | ), 125 | nn.LeakyReLU(0.2, True), 126 | ) 127 | 128 | nf = min(nf * 2, 1024) 129 | model["layer_%d" % (n_layers + 1)] = nn.Sequential( 130 | WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2), 131 | nn.LeakyReLU(0.2, True), 132 | ) 133 | 134 | model["layer_%d" % (n_layers + 2)] = WNConv1d( 135 | nf, 1, kernel_size=3, stride=1, padding=1 136 | ) 137 | 138 | self.model = model 139 | 140 | def forward(self, x: torch.Tensor, return_feature: bool = True): 141 | """ 142 | Args: 143 | x: input audio, (bs, 1, L) 144 | """ 145 | n_layer = len(self.model) 146 | results = [] 147 | for idx, (key, layer) in enumerate(self.model.items()): 148 | x = layer(x) 149 | if return_feature or (idx == n_layer - 1): 150 | results.append(x) 151 | 152 | return results 153 | 154 | 155 | class MSD(MultiDiscriminator): 156 | """ Multi-scale discriminator """ 157 | 158 | def __init__( 159 | self, 160 | num_D: int = 3, 161 | ndf: int = 16, 162 | n_layers: int = 4, 163 | downsampling_factor: int = 4 164 | ): 165 | super().__init__() 166 | self.model = nn.ModuleDict() 167 | for i in range(num_D): 168 | self.model[f"disc_{i}"] = MSDBlock( 169 | ndf, n_layers, downsampling_factor 170 | ) 171 | 172 | self.downsample = nn.AvgPool1d(4, stride=2, padding=1, count_include_pad=False) 173 | self.apply(weights_init) 174 | 175 | def forward(self, x: torch.Tensor, return_feature: bool = True): 176 | """ 177 | Args: 178 | x: input audio, (bs, 1, L) 179 | """ 180 | results = [] 181 | for key, disc in self.model.items(): 182 | results.append(disc(x, return_feature)) 183 | x = self.downsample(x) 184 | 185 | return results 186 | 187 | 188 | class MPDBlock(nn.Module): 189 | def __init__(self, period: int): 190 | super().__init__() 191 | self.period = period 192 | self.convs = nn.ModuleList( 193 | [ 194 | WNConv2dWithLeakyReLU(1, 64, (5, 1), (3, 1), padding=(2, 0)), 195 | WNConv2dWithLeakyReLU(64, 128, (5, 1), (3, 1), padding=(2, 0)), 196 | WNConv2dWithLeakyReLU(128, 256, (5, 1), (3, 1), padding=(2, 0)), 197 | WNConv2dWithLeakyReLU(256, 512, (5, 1), (3, 1), padding=(2, 0)), 198 | WNConv2dWithLeakyReLU(512, 1024, (5, 1), 1, padding=(2, 0)), 199 | ] 200 | ) 201 | self.conv_post = 
WNConv2dWithLeakyReLU( 202 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False 203 | ) 204 | 205 | def pad_to_period(self, x): 206 | t = x.shape[-1] 207 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect") 208 | return x 209 | 210 | def forward(self, x, return_feature: bool = True): 211 | results = [] 212 | 213 | x = self.pad_to_period(x) 214 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 215 | 216 | for layer in self.convs: 217 | x = layer(x) 218 | if return_feature: 219 | results.append(x) 220 | 221 | x = self.conv_post(x) 222 | results.append(x) 223 | 224 | return results 225 | 226 | 227 | class MPD(MultiDiscriminator): 228 | """ Multi-period discriminator """ 229 | 230 | def __init__(self, periods=[2, 3, 5, 7, 11, 13, 17, 19]): 231 | super().__init__() 232 | self.model = nn.ModuleDict() 233 | for p in periods: 234 | self.model[f"disc_{p}"] = MPDBlock(p) 235 | 236 | self.apply(weights_init) 237 | 238 | def forward(self, x: torch.Tensor, return_feature: bool = True): 239 | """ 240 | Args: 241 | x: input audio, (bs, 1, L) 242 | """ 243 | results = [] 244 | for key, disc in self.model.items(): 245 | results.append(disc(x, return_feature)) 246 | 247 | return results 248 | 249 | 250 | class Discriminator(nn.Module): 251 | def __init__( 252 | self, 253 | msd_kwargs: dict = { 254 | "num_D": 3, 255 | "ndf": 16, 256 | "n_layers": 4, 257 | "downsampling_factor": 4 258 | }, 259 | mpd_kwargs: dict = { 260 | "periods": [2, 3, 5, 7, 11, 13, 17, 19] 261 | } 262 | ): 263 | super().__init__() 264 | self.msd = MSD(**msd_kwargs) 265 | self.mpd = MPD(**mpd_kwargs) 266 | 267 | def compute_G_loss(self, x_fake, x_real): 268 | losses = {} 269 | losses_msd = self.msd.compute_G_loss(x_fake, x_real) 270 | losses_mpd = self.mpd.compute_G_loss(x_fake, x_real) 271 | for k in losses_msd.keys(): 272 | losses[k] = (losses_msd[k] + losses_mpd[k]) / 2 273 | losses[f"mpd-{k}"] = losses_mpd[k].detach() 274 | losses[f"msd-{k}"] = losses_msd[k].detach() 275 | 276 | return losses 277 | 278 | def compute_D_loss(self, x, mode: str): 279 | losses = {} 280 | losses_msd = self.msd.compute_D_loss(x, mode) 281 | losses_mpd = self.mpd.compute_D_loss(x, mode) 282 | for k in losses_msd.keys(): 283 | losses[k] = (losses_msd[k] + losses_mpd[k]) / 2 284 | losses[f"mpd-{k}"] = losses_mpd[k].detach() 285 | losses[f"msd-{k}"] = losses_msd[k].detach() 286 | 287 | return losses 288 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | """ 4 | 5 | import os 6 | import random 7 | import typing as tp 8 | import csv 9 | 10 | import torch 11 | import numpy as np 12 | 13 | from utils.torch_common import print_once, exists 14 | from .audio_io import get_audio_metadata, load_audio_with_pad 15 | from .modification import Mono, PhaseFlipper, VolumeChanger 16 | from .pretransform import GemmaAudioFeature 17 | from .degradation import AudioClipping, NoiseAddition, RIRReverb, AudioLowpass 18 | 19 | 20 | def fast_scandir(dir: str, ext: tp.List[str]): 21 | """ Very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243 22 | 23 | fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py 24 | 25 | Args: 26 | dir (str): top-level directory at which to begin scanning. 27 | ext (tp.List[str]): list of allowed file extensions. 
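    Returns:
        (subfolders, files): lists of every sub-directory path and of every
            matching, non-hidden audio file path found under `dir`.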
28 | """ 29 | subfolders, files = [], [] 30 | # add starting period to extensions if needed 31 | ext = ['.' + x if x[0] != '.' else x for x in ext] 32 | 33 | try: # hope to avoid 'permission denied' by this try 34 | for f in os.scandir(dir): 35 | try: # 'hope to avoid too many levels of symbolic links' error 36 | if f.is_dir(): 37 | subfolders.append(f.path) 38 | elif f.is_file(): 39 | is_hidden = os.path.basename(f.path).startswith(".") 40 | has_ext = os.path.splitext(f.name)[1].lower() in ext 41 | 42 | if has_ext and (not is_hidden): 43 | files.append(f.path) 44 | except Exception: 45 | pass 46 | except Exception: 47 | pass 48 | 49 | for dir in list(subfolders): 50 | sf, f = fast_scandir(dir, ext) 51 | subfolders.extend(sf) 52 | files.extend(f) 53 | 54 | return subfolders, files 55 | 56 | 57 | def get_info_from_csv(csv_path: str, filepath_tag: str = 'file_path', 58 | other_info_tags: tp.List[str] = ['sample_rate', 'num_frames', 'num_channels']): 59 | file_paths = [] 60 | meta_dicts = [] 61 | with open(csv_path, 'r', newline='') as f: 62 | reader = csv.DictReader(f) # 各行を dict として読み込む 63 | for row in reader: 64 | file_paths.append(row[filepath_tag]) 65 | meta = {k: int(row[k]) for k in other_info_tags} 66 | meta_dicts.append(meta) 67 | 68 | # sort by file path 69 | sorted_indices = np.argsort(file_paths) 70 | file_paths = [file_paths[i] for i in sorted_indices] 71 | meta_dicts = [meta_dicts[i] for i in sorted_indices] 72 | 73 | return file_paths, meta_dicts 74 | 75 | 76 | def get_audio_info( 77 | paths: tp.List[str], # directories in which to search 78 | exts: tp.List[str] = ['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus'] 79 | ): 80 | """recursively get a list of audio filenames""" 81 | if isinstance(paths, str): 82 | paths = [paths] 83 | 84 | # get a list of relevant filenames 85 | filepaths = [] 86 | metas = [] 87 | for p in paths: 88 | metadata_csv_path = f"{p}/metadata.csv" 89 | if os.path.exists(metadata_csv_path): 90 | # If metadata.csv exists, it's faster to get info 91 | filepaths_, metas_ = get_info_from_csv(metadata_csv_path) 92 | else: 93 | _, filepaths_ = fast_scandir(p, exts) 94 | filepaths_.sort() 95 | metas_ = [] 96 | for filepath in filepaths_: 97 | info = get_audio_metadata(filepath, cache=True) 98 | metas_.append(info) 99 | 100 | filepaths_ = [os.path.join(p, f) for f in filepaths_] 101 | filepaths.extend(filepaths_) 102 | metas.extend(metas_) 103 | 104 | return filepaths, metas 105 | 106 | 107 | class DegradedAudioDataset(torch.utils.data.Dataset): 108 | """ 109 | A dataset class for loading a pair of a target audio and a degraded audio. 110 | NOTE: This dataset supports only monoral output. 111 | """ 112 | 113 | def __init__( 114 | self, 115 | dirs_audio: tp.List[str], 116 | dirs_noise: tp.Optional[tp.List[str]] = None, 117 | sample_size: int = 120000, 118 | sample_rate: int = 24000, 119 | pretransform: tp.Optional[str] = None, 120 | exts: tp.List[str] = ['wav', 'flac'], 121 | # augmentation 122 | augment_shift: bool = True, 123 | augment_flip: bool = True, 124 | augment_volume: bool = True, 125 | volume_range: tp.Tuple[float, float] = (0.25, 1.0), 126 | # degradation 127 | deg_types: tp.List[str] = ['clipping', 'noise', 'reverb', 'lowpass'], 128 | n_deg_comb: int = 3, # maximum number of combined degradations 129 | prob_no_deg: float = 0.05, # probability of no degradation samples 130 | clean_only: bool = False, # If true, only clean audio is returned. 
131 | ): 132 | 133 | super().__init__() 134 | self.sample_size = sample_size 135 | self.sr = sample_rate 136 | self.augment_shift = augment_shift 137 | 138 | self.deg_types = deg_types 139 | self.n_deg_comb = min(n_deg_comb, len(deg_types)) 140 | self.prob_no_deg = prob_no_deg 141 | self.clean_only = clean_only 142 | 143 | print_once('[Dataset instantiation]') 144 | 145 | # Degradation modules 146 | self.degradations = { 147 | 'clipping': AudioClipping(sample_rate=sample_rate), 148 | 'noise': NoiseAddition(sample_rate=sample_rate), 149 | 'reverb': RIRReverb(sample_rate=sample_rate), 150 | 'lowpass': AudioLowpass(sample_rate=sample_rate) 151 | } 152 | 153 | # Audio augmentations 154 | self.ch_encoding = torch.nn.Sequential(Mono()) 155 | self.augs = torch.nn.Sequential( 156 | PhaseFlipper() if augment_flip else torch.nn.Identity(), 157 | VolumeChanger(*volume_range) if augment_volume else torch.nn.Identity() 158 | ) 159 | 160 | # Pre-transform 161 | if pretransform == "gemma": 162 | self.pretransform = GemmaAudioFeature() 163 | else: 164 | self.pretransform = None 165 | 166 | # find all audio files 167 | print_once('\t->-> Searching AUDIO files...') 168 | self.filepaths, self.metas = get_audio_info(dirs_audio, exts=exts) 169 | print_once(f'\t->-> Found {len(self.filepaths)} AUDIO files.') 170 | 171 | if exists(dirs_noise) and ('noise' in self.deg_types) and (not self.clean_only): 172 | print_once('\t->-> Searching NOISE files...') 173 | self.filepaths_noise, self.metas_noise = get_audio_info(dirs_noise, exts=exts) 174 | self.n_noise_files = len(self.filepaths_noise) 175 | print_once(f'\t->-> Found {self.n_noise_files} NOISE files.') 176 | 177 | def get_track_info(self, idx, noise: bool = False): 178 | if not noise: 179 | filepath = self.filepaths[idx] 180 | info = self.metas[idx] 181 | else: 182 | filepath = self.filepaths_noise[idx] 183 | info = self.metas_noise[idx] 184 | 185 | max_ofs = max(0, info['num_frames'] - self.sample_size) 186 | offset = random.randint(0, max_ofs) if (self.augment_shift and max_ofs) else 0 187 | return filepath, offset, info 188 | 189 | def __len__(self): 190 | return len(self.filepaths) 191 | 192 | def __getitem__(self, idx): 193 | 194 | filename, offset, info = self.get_track_info(idx) 195 | # Load audio 196 | audio = load_audio_with_pad(filename, info, self.sr, self.sample_size, offset) 197 | # To mono 198 | audio = self.ch_encoding(audio).squeeze(0) # (L,) 199 | # Audio augmentations 200 | audio = self.augs(audio) 201 | target = self.pretransform(audio, self.sr) if exists(self.pretransform) else audio 202 | 203 | # Degradation 204 | if self.clean_only: 205 | deg = deg_audio = 0. 206 | else: 207 | num_deg = random.randint(1, self.n_deg_comb) if random.random() > self.prob_no_deg else 0 208 | # randomly sample degradations 209 | deg_types = random.sample(self.deg_types, num_deg) 210 | deg_audio = audio.clone() 211 | # apply degradations 212 | for deg_type in deg_types: 213 | if deg_type == 'noise': 214 | # randomly sample a noise file 215 | if self.n_noise_files > 0: 216 | noise_idx = random.randint(0, self.n_noise_files - 1) 217 | filename, offset, info = self.get_track_info(noise_idx, noise=True) 218 | noise = load_audio_with_pad(filename, info, self.sr, self.sample_size, offset) 219 | noise = self.ch_encoding(noise).squeeze(0) # (L,) 220 | else: 221 | # if no noise files are found, use Gaussian noise 222 | print_once("No noise files found. 
Using Gaussian noise.") 223 | noise = torch.randn_like(deg_audio) 224 | deg_audio = self.degradations[deg_type](deg_audio, noise) 225 | else: 226 | deg_audio = self.degradations[deg_type](deg_audio) 227 | 228 | # print(f"Deg ({deg_type}): {deg_audio.abs().max().item()}") 229 | 230 | deg = self.pretransform(deg_audio, self.sr) if exists(self.pretransform) else deg_audio 231 | info['deg_types'] = deg_types + ['none'] * (self.n_deg_comb - len(deg_types)) 232 | 233 | return target, deg, audio, deg_audio, info 234 | -------------------------------------------------------------------------------- /src/model/wavefit/generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | 4 | Adapted from the following repo's code under Apache License 2.0. 5 | https://github.com/lmnt-com/wavegrad/ 6 | """ 7 | 8 | import typing as tp 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class Conv1d(nn.Conv1d): 17 | def __init__(self, *args, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | self.reset_parameters() 20 | 21 | def reset_parameters(self): 22 | nn.init.orthogonal_(self.weight) 23 | nn.init.zeros_(self.bias) 24 | 25 | 26 | class SinusoidalPositionalEncoding(nn.Module): 27 | def __init__(self, dim: int, max_iter: int, use_conv: bool = True): 28 | super().__init__() 29 | self.dim = dim 30 | self.max_iter = max_iter 31 | self.use_conv = use_conv 32 | assert dim % 2 == 0 33 | 34 | if use_conv: 35 | # 1x1 conv 36 | self.conv = nn.Conv1d(dim, dim, 1) 37 | nn.init.xavier_uniform_(self.conv.weight) 38 | nn.init.zeros_(self.conv.bias) 39 | 40 | # pre-compute positional embedding 41 | pos_embs = self.prepare_embedding() # (max_iter, dim) 42 | self.register_buffer('pos_embs', pos_embs) 43 | 44 | def forward(self, x, t: int): 45 | """ 46 | Args: 47 | x: (bs, dim, T) 48 | t: Step index 49 | 50 | Returns: 51 | x_with_pos: (bs, dim, T) 52 | """ 53 | assert 0 <= t < self.max_iter, f"Invalid step index {t}. It must be 0 <= t < {self.max_iter} = max_iter." 
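        # self.pos_embs has shape (max_iter, dim); the embedding of step t is
        # reshaped to (1, dim, 1) below so it broadcasts over batch and time,
        # optionally mixed by the 1x1 conv, and added to x.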
54 | pos_emb = self.pos_embs[t][None, :, None] 55 | if self.use_conv: 56 | pos_emb = self.conv(pos_emb) 57 | 58 | return x + pos_emb 59 | 60 | def prepare_embedding(self, scale: float = 5000.): 61 | dim_h = self.dim // 2 62 | pos = torch.linspace(0., scale, self.max_iter) 63 | div_term = torch.exp(- math.log(10000.0) * torch.arange(dim_h) / dim_h) 64 | pos = pos[:, None] @ div_term[None, :] # (max_iter, dim_h) 65 | pos_embs = torch.cat([torch.sin(pos), torch.cos(pos)], dim=-1) # (max_iter, dim) 66 | return pos_embs 67 | 68 | 69 | class MemEfficientFiLM(nn.Module): 70 | def __init__(self, input_size: int, output_size: int, max_iter: int): 71 | super().__init__() 72 | self.step_condition = SinusoidalPositionalEncoding(input_size, max_iter, use_conv=True) 73 | self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1) 74 | self.output_conv_1 = nn.Conv1d(input_size, output_size, 3, padding=1) 75 | self.reset_parameters() 76 | 77 | def reset_parameters(self): 78 | nn.init.xavier_uniform_(self.input_conv.weight) 79 | nn.init.zeros_(self.input_conv.bias) 80 | nn.init.xavier_uniform_(self.output_conv_1.weight) 81 | nn.init.zeros_(self.output_conv_1.bias) 82 | 83 | def forward(self, x, t: int): 84 | x = self.input_conv(x) 85 | x = F.leaky_relu(x, 0.2) 86 | x = self.step_condition(x, t) 87 | shift = self.output_conv_1(x) 88 | 89 | return shift, None 90 | 91 | 92 | class EmptyFiLM(nn.Module): 93 | def __init__(self): 94 | super().__init__() 95 | 96 | def forward(self, x, t): 97 | return 0, 1 98 | 99 | 100 | class UBlock(nn.Module): 101 | def __init__(self, input_size, hidden_size, factor, dilation): 102 | super().__init__() 103 | assert isinstance(dilation, (list, tuple)) 104 | assert len(dilation) == 4 105 | 106 | self.factor = factor 107 | self.block1 = Conv1d(input_size, hidden_size, 1) 108 | self.block2 = nn.ModuleList([ 109 | Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]), 110 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]) 111 | ]) 112 | self.block3 = nn.ModuleList([ 113 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]), 114 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]) 115 | ]) 116 | 117 | def forward(self, x, film_shift, film_scale: tp.Optional[torch.Tensor]): 118 | if film_scale is None: 119 | film_scale = 1.0 120 | 121 | block1 = F.interpolate(x, size=x.shape[-1] * self.factor) 122 | block1 = self.block1(block1) 123 | 124 | block2 = F.leaky_relu(x, 0.2) 125 | block2 = F.interpolate(block2, size=x.shape[-1] * self.factor) 126 | block2 = self.block2[0](block2) 127 | block2 = film_shift + film_scale * block2 128 | block2 = F.leaky_relu(block2, 0.2) 129 | block2 = self.block2[1](block2) 130 | 131 | x = block1 + block2 132 | 133 | block3 = film_shift + film_scale * x 134 | block3 = F.leaky_relu(block3, 0.2) 135 | block3 = self.block3[0](block3) 136 | block3 = film_shift + film_scale * block3 137 | block3 = F.leaky_relu(block3, 0.2) 138 | block3 = self.block3[1](block3) 139 | 140 | x = x + block3 141 | 142 | return x 143 | 144 | 145 | class DBlock(nn.Module): 146 | def __init__(self, input_size, hidden_size, factor): 147 | super().__init__() 148 | self.factor = factor 149 | 150 | # self.residual_dense = Conv1d(input_size, hidden_size, 1) 151 | # self.conv = nn.ModuleList([ 152 | # Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), 153 | # Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), 154 | # Conv1d(hidden_size, hidden_size, 3, dilation=4, 
padding=4), 155 | # ]) 156 | 157 | # NOTE : This might be the correct architecture rather than the above one 158 | # since parameter size is quite closer to the reported size in the WaveGrad paper (15M). 159 | self.residual_dense = Conv1d(input_size, input_size, 1) 160 | self.conv = nn.ModuleList([ 161 | Conv1d(input_size, input_size, 3, dilation=1, padding=1), 162 | Conv1d(input_size, input_size, 3, dilation=2, padding=2), 163 | Conv1d(input_size, hidden_size, 3, dilation=4, padding=4), 164 | ]) 165 | 166 | # downsampling module using Conv1d 167 | # NOTE: When using kernel_size=3 for all downsampling factors, 168 | # the parameter size of generator is 15.12 millions. 169 | kernel_size = factor // 2 * 2 + 1 170 | padding = kernel_size // 2 171 | self.down1 = Conv1d(input_size, hidden_size, kernel_size, padding=padding, stride=factor) 172 | self.down2 = Conv1d(input_size, input_size, kernel_size, padding=padding, stride=factor) 173 | 174 | def forward(self, x): 175 | residual = self.residual_dense(x) 176 | residual = self.down1(residual) 177 | 178 | x = self.down2(x) 179 | for layer in self.conv: 180 | x = F.leaky_relu(x, 0.2) 181 | x = layer(x) 182 | 183 | return x + residual 184 | 185 | 186 | class WaveFitGenerator(nn.Module): 187 | """ 188 | WaveFit generator module based on WaveGrad. 189 | See https://arxiv.org/abs/2009.00713 for details. 190 | """ 191 | 192 | def __init__( 193 | self, 194 | num_iteration: int, 195 | dim_feat: int = 1536, 196 | upsample_factors: tp.List[int] = [5, 4, 3, 2, 2], 197 | upsample_channels: tp.List[int] = [512, 512, 256, 128, 128], 198 | downsample_channels: tp.List[int] = [128, 128, 256, 512] 199 | ): 200 | super().__init__() 201 | 202 | self.dim_feat = dim_feat 203 | self.upsample_factors = upsample_factors 204 | self.upsample_channels = upsample_channels 205 | self.downsample_factors = upsample_factors[1:][::-1] # e.g. [2, 2, 3, 4] 206 | self.downsample_channels = downsample_channels 207 | assert len(upsample_factors) == len(upsample_channels) == len(downsample_channels) + 1 208 | self.upsample_rate = math.prod(upsample_factors) 209 | 210 | # Downsampling blocks 211 | ch_first_down = 32 212 | self.downsample = nn.ModuleList([Conv1d(1, ch_first_down, 5, padding=2)]) 213 | for i, (factor, ch_out) in enumerate(zip(self.downsample_factors, self.downsample_channels)): 214 | ch_in = ch_first_down if i == 0 else self.downsample_channels[i - 1] 215 | self.downsample.append(DBlock(ch_in, ch_out, factor)) 216 | 217 | # FiLM layers 218 | self.film = nn.ModuleList([EmptyFiLM()]) 219 | for i in range(len(self.downsample_channels)): 220 | ch_in, ch_out = self.downsample_channels[i], self.upsample_channels[-(i + 2)] 221 | self.film.append(MemEfficientFiLM(ch_in, ch_out, num_iteration)) 222 | 223 | # Upsampling blocks 224 | # NOTE: Dilation factors in a 5-block case follows an implementation in the WaveGrad paper. 225 | # Cases other than 5-block are not verified. 
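        # With the default upsample_factors [5, 4, 3, 2, 2], upsample_rate is
        # 5 * 4 * 3 * 2 * 2 = 240, i.e. each conditioning feature frame expands
        # to 240 output samples (100 feature frames per second at the 24 kHz
        # output rate used in this repository).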
226 | self.upsample = nn.ModuleList() 227 | for i, (factor, ch_out) in enumerate(zip(self.upsample_factors, self.upsample_channels)): 228 | ch_in = self.dim_feat if i == 0 else self.upsample_channels[i - 1] 229 | dilations = [1, 2, 4, 8] if i < 3 else [1, 2, 1, 2] 230 | self.upsample.append(UBlock(ch_in, ch_out, factor, dilations)) 231 | 232 | self.last_conv = Conv1d(128, 1, 3, padding=1) 233 | 234 | def forward(self, y_t: torch.Tensor, audio_feats: torch.Tensor, t: int): 235 | """ 236 | Args: 237 | y_t: Noisy input, (bs, 1, L) 238 | audio_feats: Audio features, (bs, dim_feat, num_frame) 239 | t: Step index 240 | Returns: 241 | n_hat: Estimated noise, (bs, 1, L) 242 | """ 243 | bs, ch, L = y_t.size() 244 | bs_, dim, num_frame = audio_feats.size() 245 | assert bs == bs_ and dim == self.dim_feat and ch == 1 246 | assert L == num_frame * self.upsample_rate, f"Length mismatch: {L} != {num_frame} * {self.upsample_rate}" 247 | 248 | x = y_t 249 | 250 | # downsampling 251 | downsampled = [] 252 | for film, layer in zip(self.film, self.downsample): 253 | x = layer(x) 254 | downsampled.append(film(x, t)) 255 | 256 | # upsampling and FiLM 257 | x = audio_feats 258 | for layer, (film_shift, film_scale) in zip(self.upsample, reversed(downsampled)): 259 | x = layer(x, film_shift, film_scale) 260 | 261 | # to monoral 262 | x = self.last_conv(x) 263 | 264 | return x 265 | -------------------------------------------------------------------------------- /src/model/feature_cleaner/google_usm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | 4 | Adapted from the following repo's code under Apache-2.0 License. 5 | https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py 6 | 7 | ----------------------------------------------------- 8 | Universal Speech Model (USM) from Google. 9 | """ 10 | import typing as tp 11 | import gc 12 | from pathlib import Path 13 | 14 | import torch 15 | import torch.nn as nn 16 | from transformers import Gemma3nAudioEncoder, Gemma3nAudioFeatureExtractor 17 | from transformers.models.gemma3n.modeling_gemma3n import Gemma3nRMSNorm 18 | 19 | from .parallel_adapter import AdapterLayer 20 | from .base import AudioEncoderAdapter 21 | from utils.torch_common import exists 22 | 23 | 24 | class GoogleUSMAdapter(AudioEncoderAdapter): 25 | """ 26 | Parallel adapter for Google USM described in Fig.1 of the Miipher-2 paper. 27 | 28 | NOTE: The shared layer norm before the adapter layers seems to be missing in the Gemma3n implementation. 29 | Instead, I tentatively introduce a pre-layer norm here for each adapter layer. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | n_adaptive_layers: tp.Optional[int] = None, 35 | model_id: str = "google/gemma-3n-e2b-it", 36 | encoder_id: str = "Atotti/google-usm", 37 | adapter_config: dict = { 38 | "dim_bottleneck": 1024, 39 | "init_option": "bert", 40 | "adapter_scalar": 1.0, 41 | "pre_ln_class": Gemma3nRMSNorm 42 | } 43 | ): 44 | super().__init__() 45 | self.model_id = model_id 46 | self.encoder_id = encoder_id 47 | 48 | # Feature extractor (mel-spectrogram) 49 | # NOTE: This feature extractor is not used when an input is a mel-spectrogram (e.g. forward function). 50 | # This could be used at an inference time when the input is a waveform. 
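        # Usage sketch (not part of the model): given a 16 kHz mono waveform
        # batch `wav` of shape (bs, n_sample), features can be obtained with
        #   feats = GoogleUSMAdapter().forward_waveform(wav)  # (bs, n_frame', dim)
        # `wav` is a placeholder name; see `forward_waveform` below.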
51 | self.feature_extractor = Gemma3nAudioFeatureExtractor.from_pretrained(model_id) 52 | 53 | # Main audio encoder 54 | self.audio_encoder = Gemma3nAudioEncoder.from_pretrained(encoder_id) 55 | self.n_adaptive_layers = min(n_adaptive_layers, self.n_layers) if exists(n_adaptive_layers) else self.n_layers 56 | self.dim = self.audio_encoder.config.hidden_size 57 | self.adapter_config = adapter_config 58 | self.adapter_config["dim_in"] = self.dim 59 | self.adapter_config["pre_ln_class"] = globals()[adapter_config["pre_ln_class"]] \ 60 | if isinstance(adapter_config["pre_ln_class"], str) else adapter_config["pre_ln_class"] 61 | 62 | # Remove unused layers 63 | layers = self.audio_encoder.conformer 64 | self.audio_encoder.conformer = nn.ModuleList(layers[:self.n_adaptive_layers]) 65 | del layers # Free memory 66 | gc.collect() 67 | torch.cuda.empty_cache() 68 | 69 | # Exclude encoder from learnable modules 70 | for param in self.audio_encoder.parameters(): 71 | param.requires_grad = False 72 | self.audio_encoder.eval() 73 | 74 | # Adapter layers 75 | self.adapter_layers = nn.ModuleList() 76 | self.adapter_norms = nn.ModuleList() 77 | for _ in range(self.n_adaptive_layers): 78 | self.adapter_layers.append(AdapterLayer(**self.adapter_config)) 79 | self.adapter_norms.append(Gemma3nRMSNorm(self.dim)) 80 | 81 | @property 82 | def n_layers(self) -> int: 83 | return len(self.audio_encoder.conformer) 84 | 85 | def eval(self): 86 | super().eval() 87 | return self 88 | 89 | def train(self, mode=True): 90 | super().train(mode) 91 | self.audio_encoder.eval() # Keep the encoder in eval mode 92 | return self 93 | 94 | def forward( 95 | self, 96 | audio_mel: torch.Tensor, 97 | encoder_only: bool = False 98 | ) -> torch.Tensor: 99 | """ 100 | Args: 101 | audio_mel (torch.Tensor): (bs, n_frame, mel_bins) 102 | encoder_only (bool): If True, only the encoder is used without adaptation. 103 | Returns: 104 | (torch.Tensor): (bs, n_frame', dim) 105 | """ 106 | bs, n_frame, _ = audio_mel.shape 107 | # NOTE: 'False' for valid frames, 'True' for padded frames 108 | audio_mel_mask = torch.zeros((bs, n_frame), dtype=torch.bool, device=audio_mel.device) 109 | 110 | feats = self._encoder_forward(audio_mel, audio_mel_mask, encoder_only) 111 | 112 | return feats 113 | 114 | def _encoder_forward( 115 | self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, encoder_only: bool = False 116 | ) -> tuple[torch.Tensor, torch.BoolTensor]: 117 | """Encodes a batch of MELs. 118 | 119 | Args: 120 | audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels, mel_bins]. 121 | audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames]. 122 | 123 | Returns: 124 | feats: a torch.Tensor of shape `[batch_size, frame_length, self.dim]` 125 | """ 126 | audio_encodings = self.audio_encoder.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D] 127 | 128 | # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub) 129 | t_sub = audio_encodings.shape[1] 130 | 131 | time_stride_product = 1 132 | for stride_pair_idx in range(len(self.audio_encoder.config.sscp_conv_stride_size)): 133 | time_stride_product *= self.audio_encoder.config.sscp_conv_stride_size[stride_pair_idx][0] 134 | 135 | # Create indices for gathering from the original mask. 136 | # These indices map to original time steps corresponding to the start of each 137 | # receptive field in the subsampled output. 
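        # For example, if each of the two subsampling convolutions uses a time
        # stride of 2, time_stride_product is 4 and subsampled frame i is mapped
        # back to original mel frame 4 * i (clamped to the last valid frame).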
138 | indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product 139 | indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid 140 | 141 | # Expand indices for batch compatibility if B > 1 and indices is 1D. 142 | if audio_mel_mask.ndim > 1 and indices.ndim == 1: 143 | indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub] 144 | elif ( 145 | audio_mel_mask.ndim == indices.ndim 146 | and audio_mel_mask.shape[0] == 1 147 | and indices.shape[0] != 1 148 | and t_sub == indices.shape[0] 149 | ): 150 | # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub] 151 | indices = indices.unsqueeze(0) 152 | 153 | current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub] 154 | 155 | # Adaptation 156 | feats = audio_encodings 157 | for i, (conformer_block, adapter_layer, adapter_norm) \ 158 | in enumerate(zip(self.audio_encoder.conformer, self.adapter_layers, self.adapter_norms)): 159 | feats = self._block_forward( 160 | conformer_block, 161 | adapter_layer, 162 | adapter_norm, 163 | feats, 164 | current_mask, 165 | encoder_only 166 | ) 167 | 168 | return feats 169 | 170 | def _block_forward( 171 | self, 172 | c_block, 173 | adapter_layer, 174 | adapter_norm, 175 | feats: torch.Tensor, 176 | mask: torch.BoolTensor, 177 | encoder_only: bool = False 178 | ) -> torch.Tensor: 179 | """ Gemma3nAudioConformerBlock forward pass with adaptation (Fig.1) """ 180 | 181 | feats = c_block.ffw_layer_start(feats) 182 | feats = c_block.attention(feats, mask) 183 | validity_mask_for_lconv = ~mask # True for valid 184 | feats_for_lconv_input = feats * validity_mask_for_lconv.unsqueeze(-1).to(feats.dtype) 185 | feats_conv = c_block.lconv1d(feats_for_lconv_input) 186 | feats = c_block.ffw_layer_end(feats_conv) # feats_conv + feats_mlp 187 | feats = torch.clamp(feats, -c_block.gradient_clipping, c_block.gradient_clipping) 188 | # NOTE: This layer norm doesn't exist in the Fig.1 of the paper. 189 | feats = c_block.norm(feats) 190 | 191 | if not encoder_only: 192 | # Adapter layer 193 | feats_adapt = adapter_layer(feats_conv) 194 | feats = feats + feats_adapt 195 | # Post layer norm 196 | feats = adapter_norm(feats) 197 | 198 | return feats 199 | 200 | @torch.no_grad() 201 | def forward_waveform( 202 | self, 203 | audio: torch.Tensor, 204 | encoder_only: bool = False, 205 | device: tp.Optional[torch.device] = None 206 | ) -> torch.Tensor: 207 | """ 208 | Args: 209 | audio (torch.Tensor): Monoral audio, (bs, n_sample) 210 | encoder_only (bool): If True, only the encoder is used without adaptation. 211 | Returns: 212 | (torch.Tensor): (bs, n_frame', dim) 213 | """ 214 | # Feature extraction 215 | audio_np = audio.cpu().numpy() 216 | output = self.feature_extractor(audio_np, return_tensors="pt") 217 | audio_mel = output["input_features"] 218 | 219 | device = audio.device if device is None else device 220 | audio_mel = audio_mel.to(device) 221 | 222 | # Forward 223 | feats = self(audio_mel, encoder_only) 224 | 225 | return feats 226 | 227 | def save_state_dict(self, dir_save: str): 228 | """ 229 | Save only the adapter layer and adaptive norm parameters. 
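        The state is written to `<dir_save>/model.pth` and can be restored later
        with `load_state_dict`.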
230 | """ 231 | state = { 232 | 'adapter_layers': self.adapter_layers.state_dict(), 233 | 'adapter_norms': self.adapter_norms.state_dict() 234 | } 235 | torch.save(state, Path(dir_save) / "model.pth") 236 | 237 | def load_state_dict(self, dir_load: tp.Optional[str] = None, state: tp.Optional[dict] = None): 238 | """ 239 | Load only the adapter layer and adaptive norm parameters. 240 | """ 241 | assert exists(dir_load) or exists(state), "Either dir_load or state must be provided." 242 | state = state if exists(state) else torch.load(Path(dir_load) / "model.pth") 243 | self.adapter_layers.load_state_dict(state['adapter_layers']) 244 | self.adapter_norms.load_state_dict(state['adapter_norms']) 245 | 246 | def get_state_dict(self): 247 | state = { 248 | 'adapter_layers': self.adapter_layers.state_dict(), 249 | 'adapter_norms': self.adapter_norms.state_dict() 250 | } 251 | 252 | return state 253 | -------------------------------------------------------------------------------- /src/model/miipher_2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | 4 | -------------- 5 | Miipher-2 model 6 | """ 7 | import typing as tp 8 | from enum import Enum 9 | from pathlib import Path 10 | 11 | import torch 12 | from torch import nn 13 | from torch.nn import functional as F 14 | 15 | from utils.torch_common import exists, print_once 16 | 17 | 18 | class MiipherMode(Enum): 19 | """ 20 | Miipher-2 processing modes. 21 | 22 | CLEAN_INPUT: Clean audio is processed by a non-adaptive feature cleaner (WaveFit pretraining). 23 | NOISY_INPUT: Noisy audio is processed by an adaptive feature cleaner (WaveFit finetuning / Inference). 24 | """ 25 | CLEAN_INPUT = 'clean_input' # Clean input waveform 26 | NOISY_INPUT = 'noisy_input' # Noisy input waveform 27 | 28 | 29 | class Miipher2(nn.Module): 30 | """ 31 | Miipher-2 model consists of a feature cleaner and a WaveFit vocoder. 32 | 33 | Args: 34 | feature_cleaner (nn.Module): Feature cleaner model (e.g., Google-USM with adapter layers). 35 | vocoder (nn.Module): Vocoder model (e.g., WaveFit-5). 36 | mode (MiipherMode): Processing mode. See MiipherMode for details. 
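        discriminator (nn.Module, optional): Discriminator, used only for vocoder training.
        mrstft (nn.Module, optional): Multi-resolution STFT loss module, used only for vocoder training.
        loss_lambdas (dict): Weights of the generator loss terms summed in `train_step`.
        upsample_factor (int): Temporal upsampling factor applied to the cleaned features before the vocoder (Sec.2.3).
        upsample_mode (str): Interpolation mode used for that upsampling.
        feature_cleaner_ckpt_dir (str, optional): Pretrained feature-cleaner checkpoint directory (required for finetuning).
        vocoder_ckpt_dir (str, optional): Pretrained vocoder (and discriminator) checkpoint directory.
        gradient_checkpointing (bool): If True, gradient checkpointing is used inside the vocoder during training.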
37 | """ 38 | 39 | def __init__( 40 | self, 41 | feature_cleaner: nn.Module, 42 | vocoder: nn.Module, 43 | mode: str = 'noisy_input', 44 | # modules for vocoder training 45 | discriminator: tp.Optional[nn.Module] = None, 46 | mrstft: tp.Optional[nn.Module] = None, 47 | loss_lambdas: dict = {}, 48 | # upsampling before vocoder 49 | upsample_factor: int = 4, 50 | upsample_mode: str = 'linear', 51 | # pretrained checkpoints are required for finetuning 52 | feature_cleaner_ckpt_dir: tp.Optional[str] = None, 53 | vocoder_ckpt_dir: tp.Optional[str] = None, 54 | # training config 55 | gradient_checkpointing: bool = False 56 | ): 57 | super().__init__() 58 | self.feature_cleaner = feature_cleaner 59 | self.vocoder = vocoder 60 | self.discriminator = discriminator 61 | self.mrstft = mrstft 62 | self.loss_lambdas = loss_lambdas 63 | self._mode = MiipherMode(mode) 64 | self.upsample_factor = upsample_factor 65 | self.upsample_mode = upsample_mode 66 | self.gradient_checkpointing = gradient_checkpointing 67 | 68 | if exists(feature_cleaner_ckpt_dir): 69 | self.feature_cleaner.load_state_dict(feature_cleaner_ckpt_dir) 70 | print_once(f"[Miipher-2 class] Loaded feature cleaner state from {feature_cleaner_ckpt_dir}") 71 | 72 | if exists(vocoder_ckpt_dir): 73 | state_vocoder = torch.load(Path(vocoder_ckpt_dir) / "vocoder.pth") 74 | self.vocoder.load_state_dict(state_vocoder) 75 | print_once(f"[Miipher-2 class] Loaded vocoder state from {vocoder_ckpt_dir}") 76 | if exists(self.discriminator): 77 | state_discriminator = torch.load(Path(vocoder_ckpt_dir) / "discriminator.pth") 78 | self.discriminator.load_state_dict(state_discriminator) 79 | print_once(f"[Miipher-2 class] Loaded discriminator state from {vocoder_ckpt_dir}") 80 | 81 | # no gradient for feature cleaner 82 | for param in self.feature_cleaner.parameters(): 83 | param.requires_grad = False 84 | 85 | def eval(self): 86 | super().eval() 87 | return self 88 | 89 | def train(self, mode=True): 90 | super().train(mode) 91 | self.feature_cleaner.eval() # Keep the feature cleaner in eval mode 92 | return self 93 | 94 | def set_mode(self, mode: tp.Union[MiipherMode, str]): 95 | mode = MiipherMode(mode) if isinstance(mode, str) else mode 96 | self._mode = mode 97 | 98 | @property 99 | def mode(self) -> MiipherMode: 100 | return self._mode 101 | 102 | def forward(self, mel_spec: torch.Tensor, initial_noise: torch.Tensor) -> torch.Tensor: 103 | """ 104 | Args: 105 | mel_spec (torch.Tensor): Input mel-spectrogram, (B, num_frames, feature_dim) 106 | initial_noise (torch.Tensor): Initial noise for the vocoder, (B, L) 107 | Returns: 108 | vocoder_output (List[torch.Tensor]): Output waveform, (B, L') 109 | """ 110 | # Mel-spectrogram to audio encoder feature 111 | encoder_only = (self._mode == MiipherMode.CLEAN_INPUT) 112 | with torch.no_grad(): 113 | feats = self.feature_cleaner(mel_spec, encoder_only=encoder_only) # (bs, num_frames, dim) 114 | 115 | # Upsample features to match the vocoder's input frame rate (Sec.2.3) 116 | feats = feats.transpose(1, 2) # (bs, dim, num_frames) 117 | feats = F.interpolate(feats, scale_factor=self.upsample_factor, mode=self.upsample_mode) 118 | feats = feats.transpose(1, 2) # (bs, num_frames, dim) 119 | 120 | # Audio encoder feature to waveform 121 | vocoder_output = self.vocoder(initial_noise, feats, 122 | gradient_checkpointing=self.gradient_checkpointing) # List of (B, L) 123 | 124 | return vocoder_output 125 | 126 | @torch.no_grad() 127 | def inference( 128 | self, 129 | input_waveform: torch.Tensor, 130 | initial_noise: 
tp.Optional[torch.Tensor] = None 131 | ) -> torch.Tensor: 132 | """ 133 | Inference with waveform input. 134 | 135 | Args: 136 | input_waveform (torch.Tensor): Input waveform, 16kh, (B, L) 137 | Returns: 138 | decoded_waveform (torch.Tensor): Output waveform, 24khz, (B, L') 139 | """ 140 | 141 | # 16khz -> 24khz 142 | if exists(initial_noise): 143 | assert initial_noise.shape[-1] == int(input_waveform.shape[-1] * 1.5) 144 | else: 145 | initial_noise = torch.randn(input_waveform.shape[0], int(input_waveform.shape[-1] * 1.5), device=input_waveform.device) 146 | 147 | # Waveform to audio encoder feature 148 | feats = self.feature_cleaner.forward_waveform(input_waveform, encoder_only=False) 149 | 150 | # Upsample features to match the vocoder's input frame rate (Sec.2.3) 151 | feats = feats.transpose(1, 2) # (bs, dim, num_frames) 152 | feats = F.interpolate(feats, scale_factor=self.upsample_factor, mode=self.upsample_mode) 153 | feats = feats.transpose(1, 2) # (bs, num_frames, dim) 154 | 155 | # Audio encoder feature to waveform 156 | decoded_waveform = self.vocoder(initial_noise, feats, return_only_last=True)[-1] # (B, L) 157 | 158 | return decoded_waveform 159 | 160 | def train_step( 161 | self, 162 | target_audio: torch.Tensor, 163 | input_mel_spec: torch.Tensor, 164 | train: bool = True 165 | ) -> dict: 166 | """ 167 | Training step for the Miipher-2 model. 168 | 169 | Args: 170 | target_audio (torch.Tensor): Target audio waveform, (B, L) 171 | input_mel_spec (torch.Tensor): Input mel-spectrogram, (B, num_frames, feature_dim) 172 | train (bool): Whether in training mode. 173 | 174 | Returns: 175 | dict: Losses and metrics for the training step. 176 | """ 177 | assert exists(self.discriminator) and exists(self.mrstft), "Discriminator and MRSTFT must be provided for training." 178 | 179 | self.train() if train else self.eval() 180 | 181 | # Fix the gain of target audio 182 | # NOTE: Note that the gain of target audio is specified by the WaveFit model 183 | # due to the gain normalization. 184 | scale = self.vocoder.target_gain / (target_audio.abs().max(dim=1, keepdim=True)[0] + 1e-8) 185 | target_audio = target_audio * scale 186 | 187 | # initial noise 188 | initial_noise = torch.randn_like(target_audio) 189 | 190 | target_audio = target_audio.unsqueeze(1) # (bs, 1, L) 191 | assert target_audio.dim() == 3 and input_mel_spec.dim() == 3 192 | assert target_audio.size(0) == input_mel_spec.size(0) 193 | 194 | # Forward pass 195 | preds = self(input_mel_spec, initial_noise) 196 | 197 | # Vocoder losses 198 | losses = {} 199 | for pred in preds: 200 | pred = pred.unsqueeze(1) # (bs, 1, L) 201 | losses_i = {} 202 | losses_i.update(self.mrstft(pred, target_audio)) 203 | losses_i.update(self.discriminator.compute_G_loss(pred, target_audio)) 204 | for k, v in losses_i.items(): 205 | losses[k] = losses.get(k, 0.) + v / len(preds) 206 | 207 | loss = 0. 208 | for k in self.loss_lambdas.keys(): 209 | losses[k] = losses[k] * self.loss_lambdas[k] 210 | loss += losses[k] 211 | 212 | # Discriminator loss 213 | out_real = self.discriminator.compute_D_loss(target_audio, mode='real') 214 | loss_d_real = out_real.pop('loss') 215 | losses.update({f"D/{k}_real": v for k, v in out_real.items()}) 216 | # NOTE: Discriminator loss is also computed for all intermediate predictions (Sec.4.2) 217 | loss_d_fake = 0. 
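        # Hinge discriminator loss: relu(1 - D(x)) on the real audio above and
        # relu(1 + D(y_t)) on each detached intermediate prediction below,
        # averaged over the T WaveFit iterations.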
218 | out_fake = {} 219 | for pred in preds: 220 | pred = pred.unsqueeze(1) 221 | out_fake_ = self.discriminator.compute_D_loss(pred.detach(), mode='fake') 222 | loss_d_fake += out_fake_.pop('loss') / len(preds) 223 | for k, v in out_fake_.items(): 224 | out_fake[f"{k}_fake"] = out_fake.get(f"{k}_fake", 0.) + v / len(preds) 225 | 226 | loss_d = loss_d_real + loss_d_fake 227 | 228 | losses.update({f"D/{k}": v for k, v in out_fake.items()}) 229 | output = {'loss': loss} 230 | output.update({k: v.detach() for k, v in losses.items()}) 231 | output['D/loss_d'] = loss_d 232 | output['D/loss_d_real'] = loss_d_real.detach() 233 | output['D/loss_d_fake'] = loss_d_fake.detach() 234 | 235 | return output 236 | 237 | def save_state_dict(self, dir_save: str): 238 | state_feature_cleaner = self.feature_cleaner.get_state_dict() 239 | state_vocoder = self.vocoder.state_dict() 240 | torch.save(state_feature_cleaner, Path(dir_save) / "feature_cleaner.pth") 241 | torch.save(state_vocoder, Path(dir_save) / "vocoder.pth") 242 | 243 | if exists(self.discriminator): 244 | state_discriminator = self.discriminator.state_dict() 245 | torch.save(state_discriminator, Path(dir_save) / "discriminator.pth") 246 | 247 | def load_state_dict(self, dir_load: str): 248 | state_feature_cleaner = torch.load(Path(dir_load) / "feature_cleaner.pth") 249 | state_vocoder = torch.load(Path(dir_load) / "vocoder.pth") 250 | self.feature_cleaner.load_state_dict(state=state_feature_cleaner) 251 | self.vocoder.load_state_dict(state_vocoder) 252 | 253 | if exists(self.discriminator): 254 | state_discriminator = torch.load(Path(dir_load) / "discriminator.pth") 255 | self.discriminator.load_state_dict(state_discriminator) 256 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025 Yukara Ikemiya 3 | """ 4 | 5 | import os 6 | 7 | import torch 8 | import wandb 9 | import hydra 10 | from einops import rearrange 11 | 12 | from utils.logging import MetricsLogger 13 | from utils.torch_common import exists, sort_dict, print_once 14 | 15 | from model import AudioEncoderAdapter, Miipher2, MiipherMode 16 | 17 | 18 | class Trainer: 19 | def __init__( 20 | self, 21 | model, # model 22 | ema, # exponential moving average 23 | optimizer, # optimizer 24 | scheduler, # scheduler 25 | train_dataloader, 26 | accel, # Accelerator object 27 | cfg, # Configurations 28 | # Discriminator options (for WaveFit training) 29 | optimizer_d=None, 30 | scheduler_d=None, 31 | # Resume training from a checkpoint directory 32 | ckpt_dir=None 33 | ): 34 | self.model = accel.unwrap_model(model) 35 | self.ema = ema 36 | self.opt = optimizer 37 | self.sche = scheduler 38 | self.train_dataloader = train_dataloader 39 | self.accel = accel 40 | self.cfg = cfg 41 | self.cfg_t = cfg.trainer 42 | self.EPS = 1e-8 43 | 44 | # discriminator 45 | self.have_disc = exists(optimizer_d) and exists(scheduler_d) 46 | self.opt_d = optimizer_d 47 | self.sche_d = scheduler_d 48 | 49 | self.logger = MetricsLogger() # Logger for WandB 50 | self.logger_print = MetricsLogger() # Logger for printing 51 | self.logger_test = MetricsLogger() # Logger for test 52 | 53 | self.states = {'global_step': 0, 'best_metrics': float('inf'), 'latest_metrics': float('inf')} 54 | 55 | # time measurement 56 | self.s_event = torch.cuda.Event(enable_timing=True) 57 | self.e_event = torch.cuda.Event(enable_timing=True) 58 | 59 | # resume training 60 | if ckpt_dir is 
not None: 61 | self.__load_ckpt(ckpt_dir) 62 | 63 | def start_training(self): 64 | """ 65 | Start training with infinite loops 66 | """ 67 | self.model.train() 68 | self.s_event.record() 69 | 70 | print_once("\n[ Started training ]\n") 71 | 72 | while True: 73 | for batch in self.train_dataloader: 74 | # Update 75 | metrics = self.run_step(batch) 76 | 77 | if self.accel.is_main_process: 78 | self.logger.add(metrics) 79 | self.logger_print.add(metrics) 80 | 81 | # Log 82 | if self.__its_time(self.cfg_t.logging.n_step_log): 83 | self.__log_metrics() 84 | 85 | # Print 86 | if self.__its_time(self.cfg_t.logging.n_step_print): 87 | self.__print_metrics() 88 | 89 | # Save checkpoint 90 | if self.__its_time(self.cfg_t.logging.n_step_ckpt): 91 | self.__save_ckpt() 92 | 93 | # Sample 94 | if not isinstance(self.model, AudioEncoderAdapter): 95 | if self.__its_time(self.cfg_t.logging.n_step_sample): 96 | self.__sampling() 97 | 98 | self.states['global_step'] += 1 99 | 100 | def run_step(self, batch, train: bool = True): 101 | """ One training step """ 102 | 103 | # target and degraded audios: (bs, ch, sample_length) 104 | x_tgt, x_deg, clean_audio, noisy_audio, _ = batch 105 | 106 | # Update 107 | 108 | if train: 109 | self.opt.zero_grad() 110 | if self.have_disc: 111 | self.opt_d.zero_grad() 112 | 113 | if isinstance(self.model, AudioEncoderAdapter): 114 | output = self.model.train_step(x_tgt, x_deg, train=train) 115 | elif isinstance(self.model, Miipher2): 116 | input = x_tgt if self.model.mode == MiipherMode.CLEAN_INPUT else x_deg 117 | output = self.model.train_step(clean_audio, input, train=train) 118 | else: 119 | raise NotImplementedError(f"Model class '{self.model.__class__.__name__}' is not supported.") 120 | 121 | if train: 122 | self.accel.backward(output['loss']) 123 | if self.accel.sync_gradients: 124 | self.accel.clip_grad_norm_(self.model.parameters(), self.cfg_t.max_grad_norm) 125 | self.opt.step() 126 | self.sche.step() 127 | 128 | if self.have_disc: 129 | self.accel.backward(output['D/loss_d']) 130 | if self.accel.sync_gradients: 131 | self.accel.clip_grad_norm_(self.model.discriminator.parameters(), self.cfg_t.max_grad_norm) 132 | self.opt_d.step() 133 | self.sche_d.step() 134 | 135 | # EMA 136 | if exists(self.ema): 137 | self.ema.update() 138 | 139 | return {k: v.detach() for k, v in output.items()} 140 | 141 | @torch.no_grad() 142 | def __sampling(self): 143 | # Restoration / Reconstruction samples from Miipher-2 144 | self.model.eval() 145 | 146 | n_sample: int = self.cfg_t.logging.n_samples_per_step 147 | 148 | # randomly select samples 149 | dataset = self.train_dataloader.dataset 150 | idxs = torch.randint(len(dataset), size=(n_sample,)) 151 | x_tgt, x_deg, clean_audio, noisy_audio, deg_types = [], [], [], [], [] 152 | for idx in idxs: 153 | x_tgt_, x_deg_, clean_audio_, noisy_audio_, info_ = dataset[idx] 154 | x_tgt.append(x_tgt_) 155 | x_deg.append(x_deg_) 156 | clean_audio.append(clean_audio_) 157 | noisy_audio.append(noisy_audio_) 158 | deg_types.append(info_.get('deg_types', ['none'])) 159 | 160 | x_tgt = torch.stack(x_tgt, dim=0).to(self.accel.device) 161 | x_deg = torch.stack(x_deg, dim=0).to(self.accel.device) if self.model.mode != MiipherMode.CLEAN_INPUT else None 162 | clean_audio = torch.stack(clean_audio, dim=0).to(self.accel.device) 163 | noisy_audio = torch.stack(noisy_audio, dim=0).to(self.accel.device) if self.model.mode != MiipherMode.CLEAN_INPUT else None 164 | 165 | columns = ['clean (audio)', 'decoded (audio)'] if self.model.mode == 
MiipherMode.CLEAN_INPUT \ 166 | else ['clean (audio)', 'degraded (audio)', 'Degrations', 'restored (audio)'] 167 | table_audio = wandb.Table(columns=columns) 168 | 169 | # sampling 170 | x_input = x_tgt if self.model.mode == MiipherMode.CLEAN_INPUT else x_deg 171 | initial_noise = torch.randn_like(clean_audio) 172 | with torch.no_grad(): 173 | x_preds = self.model(x_input, initial_noise) 174 | x_pred = x_preds[-1] # (n_sample, L) 175 | 176 | for idx in range(n_sample): 177 | # clean audio 178 | data = [wandb.Audio(clean_audio[idx].cpu().numpy(), sample_rate=dataset.sr)] 179 | 180 | # degraded audio 181 | if self.model.mode == MiipherMode.NOISY_INPUT: 182 | data += [wandb.Audio(noisy_audio[idx].cpu().numpy(), sample_rate=dataset.sr)] 183 | data += ['/'.join([d for d in deg_types[idx] if d != 'none'])] 184 | 185 | # decoded audio 186 | data += [wandb.Audio(x_pred[idx].cpu().numpy().T, sample_rate=dataset.sr)] 187 | 188 | table_audio.add_data(*data) 189 | 190 | self.accel.log({'Samples': table_audio}, step=self.states['global_step']) 191 | 192 | self.model.train() 193 | 194 | print("\t->->-> Sampled.") 195 | 196 | def __save_ckpt(self): 197 | import shutil 198 | import json 199 | from omegaconf import OmegaConf 200 | 201 | out_dir = self.cfg_t.output_dir + '/ckpt' 202 | 203 | # save latest ckpt 204 | latest_dir = out_dir + '/latest' 205 | os.makedirs(latest_dir, exist_ok=True) 206 | 207 | # save optimizer/scheduler states 208 | torch.save(self.opt.state_dict(), f"{latest_dir}/optimizer.pth") 209 | torch.save(self.sche.state_dict(), f"{latest_dir}/scheduler.pth") 210 | 211 | if self.have_disc: 212 | torch.save(self.opt_d.state_dict(), f"{latest_dir}/optimizer_d.pth") 213 | torch.save(self.sche_d.state_dict(), f"{latest_dir}/scheduler_d.pth") 214 | 215 | # save model states 216 | self.model.save_state_dict(latest_dir) 217 | 218 | # save states and configuration 219 | OmegaConf.save(self.cfg, f"{latest_dir}/config.yaml") 220 | with open(f"{latest_dir}/states.json", mode="wt", encoding="utf-8") as f: 221 | json.dump(self.states, f, indent=2) 222 | 223 | # save best ckpt 224 | if self.states['latest_metrics'] == self.states['best_metrics']: 225 | shutil.copytree(latest_dir, out_dir + '/best', dirs_exist_ok=True) 226 | 227 | print("\t->->-> Saved checkpoints.") 228 | 229 | def __load_ckpt(self, dir: str): 230 | import json 231 | 232 | print_once(f"\n[Resuming training from the checkpoint directory] -> {dir}") 233 | 234 | self.opt.load_state_dict(torch.load(f"{dir}/optimizer.pth", weights_only=False)) 235 | self.sche.load_state_dict(torch.load(f"{dir}/scheduler.pth", weights_only=False)) 236 | 237 | if self.have_disc: 238 | self.opt_d.load_state_dict(torch.load(f"{dir}/optimizer_d.pth", weights_only=False)) 239 | self.sche_d.load_state_dict(torch.load(f"{dir}/scheduler_d.pth", weights_only=False)) 240 | 241 | self.model.load_state_dict(dir) 242 | 243 | with open(f"{dir}/states.json", mode="rt", encoding="utf-8") as f: 244 | self.states.update(json.load(f)) 245 | 246 | def __log_metrics(self, sort_by_key: bool = False): 247 | metrics = self.logger.pop() 248 | # learning rate 249 | metrics['lr'] = self.sche.get_last_lr()[0] 250 | if sort_by_key: 251 | metrics = sort_dict(metrics) 252 | 253 | self.accel.log(metrics, step=self.states['global_step']) 254 | 255 | # update states 256 | m_for_ckpt = self.cfg_t.logging.metrics_for_best_ckpt 257 | m_latest = float(sum([metrics[k].detach() for k in m_for_ckpt])) 258 | self.states['latest_metrics'] = m_latest 259 | if m_latest < self.states['best_metrics']: 260 
| self.states['best_metrics'] = m_latest 261 | 262 | def __print_metrics(self, sort_by_key: bool = False): 263 | self.e_event.record() 264 | torch.cuda.synchronize() 265 | p_time = self.s_event.elapsed_time(self.e_event) / 1000. # [sec] 266 | 267 | metrics = self.logger_print.pop() 268 | # tensor to scalar 269 | metrics = {k: v.item() for k, v in metrics.items()} 270 | if sort_by_key: 271 | metrics = sort_dict(metrics) 272 | 273 | step = self.states['global_step'] 274 | s = f"Step {step} ({p_time:.1e} [sec]): " + ' / '.join([f"[{k}] - {v:.3e}" for k, v in metrics.items()]) 275 | print(s) 276 | 277 | self.s_event.record() 278 | 279 | def __its_time(self, itv: int): 280 | return (self.states['global_step'] - 1) % itv == 0 281 | -------------------------------------------------------------------------------- /_LICENSES/transformers.LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018- The Hugging Face team. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 
49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /_LICENSES/unify-parameter-efficient-tuning.LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018- The Hugging Face team. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 
10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 
181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🎵 Open-Miipher-2 | A Universal Speech Restoration Model 2 | 3 | This is an unofficial implementation of **`Miipher-2`**[1] 4 | which is **a state-of-the-art universal speech restoration model** from Google Research. 5 | 6 | ![Miipher-2](./assets/fig/miipher-2.png) 7 | 8 | This repository supports: 9 | - 🔥 Full implementation and training code for the `Miipher-2` model 10 | - 🔥 Distributed training with multiple GPUs / multiple Nodes 11 | 12 | ## What Google Can Do vs. What You Can Do 13 | 14 | Google's Miipher-2 leverages proprietary datasets and large-scale infrastructure that are not publicly accessible. While you cannot use Google's internal data or resources, this repository enables you to train and experiment with open datasets and models. The implementation is designed to be flexible and reproducible for academic and general-purpose use. 15 | 16 | ### Google USM 17 | 18 | Google uses the Universal Speech Model (USM) [4] for both feature extraction and as the base (pretrained) model for the feature cleaner in Miipher-2. Specifically, Miipher-2 is trained using the first 13 layers of a 32-layer Conformer with 2 billion parameters. However, this pretrained model is not publicly available. 19 | 20 | On the other hand, Google has open-sourced [Gemma 3](https://huggingface.co/docs/transformers/main/model_doc/gemma3), a multimodal LLM that includes a 0.6 billion parameter (12-layer) USM module. In this repository, the default configuration uses up to the 6th layer of this model as the base for the feature cleaner. Naturally, differences in base model size may lead to variations in restoration performance between the Google version and this repository. 21 | And please note that the USM (0.6B) in Gemma 3 differs in architectural configuration from the 0.6B model described in the USM paper [4]. 
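For illustration, the snippet below is a minimal sketch (not this repository's actual loading code) of how the Gemma 3 USM audio encoder could be pulled from the HuggingFace Hub and truncated to its first 6 Conformer layers. The checkpoint name is the model card referenced below, and the `layers` attribute path is an assumption that may differ from the real module structure.

```python
import torch.nn as nn
from transformers import AutoModel

N_LAYERS = 6  # keep only the first 6 of the 12 Conformer layers

# Assumption: the model card loads through AutoModel
# (trust_remote_code may or may not be required).
encoder = AutoModel.from_pretrained("Atotti/Google-USM", trust_remote_code=True)

# Hypothetical attribute path: inspect the loaded module (e.g. `print(encoder)`)
# to find where the Conformer blocks actually live before truncating them.
layers = getattr(encoder, "layers", None)
if isinstance(layers, nn.ModuleList):
    encoder.layers = nn.ModuleList(list(layers)[:N_LAYERS])
```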
22 | 
23 | The key differences between the Google version and this repository are summarized below:
24 | 
25 | | | Base model | Model size | Conformer layers | Dimension | Frame rate [Hz] | Parallel Adapter size |
26 | |:-----------|:----------:|:----------:|:----------------:|:---------:|:----------:|:---:|
27 | | **Google** | USM | 2B | 13th / 32 layers | 1536 | 25 | 41M |
28 | | **Open-Miipher-2** | USM | 0.6B | 6th / 12 layers | 1536 | 25 | 19M |
29 | 
30 | Currently, Gemma 3's USM is loaded from the following model card on HuggingFace.
31 | - [Atotti/Google-USM](https://huggingface.co/Atotti/Google-USM)
32 | 
33 | For more details on the selection of Conformer layers and related considerations, please refer to the [Tips](#-tips) section.
34 | 
35 | ### Audio dataset
36 | According to the paper, Google used `3,195 hours of speech from 1,642 speakers across 44 languages` as speech data and `internally collected audio snippets from environments such as cafes, kitchens, and automobiles` as noise data for training Miipher-2. These datasets are internal to Google and are not publicly available.
37 | 
38 | For general-purpose use, it is preferable to utilize larger and more diverse speech/noise datasets. However, for experiments or academic purposes, you can use open datasets such as those listed below.
39 | 
40 | | Type | Dataset name | Link | Hours |
41 | |-------|-------------------------|------|---|
42 | | Speech | LibriTTS-R [5] | [https://www.openslr.org/141/](https://www.openslr.org/141/) | 585 |
43 | | Noise | TAU Urban Acoustic Scenes 2020 Mobile, Development dataset | [https://zenodo.org/records/3670167](https://zenodo.org/records/3670167) | 64 |
44 | | Noise | TAU Urban Audio-Visual Scenes 2021, Development dataset | [https://zenodo.org/records/4477542](https://zenodo.org/records/4477542) | 34 |
45 | 
46 | 
47 | # Requirements
48 | 
49 | - Python 3.8.10 or later
50 | - PyTorch 2.1 or later
51 | - transformers>=4.53 (NOTE: `transformers` must contain the Gemma implementations from Google.)
52 | 
53 | ## Building a training environment
54 | 
55 | To simplify setting up the training environment, I recommend using container systems like `Docker` or `Singularity` instead of installing dependencies on each GPU machine. Below are the steps for creating `Singularity` containers.
56 | 
57 | All example scripts are stored in the [container](container/) folder.
58 | 
59 | ### 1. Install Singularity
60 | 
61 | Install the latest Singularity by following the official instructions.
62 | - https://docs.sylabs.io/guides/main/user-guide/quick_start.html#quick-installation-steps
63 | 
64 | ### 2. Create a Singularity image file
65 | 
66 | Create (build) a Singularity image file from the provided definition file.
67 | ```bash
68 | singularity build --fakeroot Open-Miipher-2.sif Open-Miipher-2.def
69 | ```
70 | 
71 | **NOTE:** You might need to change the NVIDIA base image in the definition file to match your GPU machine.
72 | 
73 | You now have a container file for training and inference of Open-Miipher-2.
74 | 
75 | ## Setting a WandB account for logging
76 | 
77 | The training code also requires a Weights & Biases account to log the training outputs and demos.
78 | Please create an account and follow the instructions below.
79 | 
80 | Once you have created your WandB account,
81 | you can obtain an API key from https://wandb.ai/authorize after logging in to your account.
82 | The API key can then be passed to a training job as the environment variable `WANDB_API_KEY`
83 | to enable logging of training information.
84 | 
85 | ```bash
86 | $ WANDB_API_KEY="12345x6789y..."
87 | ```
88 | 
89 | ## Authentication of HuggingFace
90 | 
91 | This repository uses the `google/gemma-3n-e2b-it` model from HuggingFace.
92 | To load this model, you must authenticate with HuggingFace in advance.
93 | 
94 | - https://huggingface.co/google/gemma-3n-E2B-it
95 | 
96 | ### Login setting
97 | 
98 | 1. Create a HuggingFace token for accessing the weights from https://huggingface.co/settings/tokens.
99 | 2. Set the token to the `HF_TOKEN` environment variable when running your script.
100 | 
101 | 
102 | # Data preparation
103 | 
104 | By default, Miipher-2 performs decoding on 24kHz audio signals. Therefore, if audio with a sample rate other than 24kHz is loaded, it will be automatically resampled to 24kHz within the Dataset class. To avoid this extra computation, it is recommended to pre-process all audio files in each dataset to 24kHz before training.
105 | 
106 | For fast initialization of the dataset class, this repository uses a metadata file (`metadata.csv`) placed in the root directory of each audio dataset. You can generate this file by running the following script for each dataset directory:
107 | 
108 | ```bash
109 | AUDIO_DIR=/path/to/audio-root-directory/
110 | python dataset/script/make_metadata_csv.py --root-dir ${AUDIO_DIR}
111 | ```
112 | 
113 | ## Pre-computation of degraded speech signals [Not supported]
114 | 
115 | For clarity and simplicity, this repository applies random degradations to audio samples within a Dataset class during loading (online processing). However, some degradation methods (e.g. convolution filtering) can be computationally intensive on the CPU, potentially becoming a bottleneck and preventing full utilization of GPU resources during training. In such cases, consider pre-computing and saving multiple degraded versions of each clean speech sample before training, so that the Dataset can load these pre-processed files directly.
116 | 
117 | # Training
118 | 
119 | Miipher-2 training consists of the following three stages:
120 | 
121 | 1. Training of the feature cleaner module
122 | 2. Pretraining of the WaveFit module
123 | 3. Finetuning of the WaveFit module
124 | 
125 | Stage 1 trains a module that converts noisy SSL features into clean ones, serving as the main part of audio restoration. Stage 2 pre-trains the WaveFit speech vocoder using only clean speech. Stage 3 fine-tunes the WaveFit module using features restored by the feature cleaner from Stage 1.
126 | 
127 | For Stage 2, short audio signals of 0.6 seconds are used to accelerate training (see Sec.2.4 in [2]). In contrast, Stages 1 and 3 use longer audio signals (typically 10–30 seconds) to capture more global information. The original paper suggests that 30-second inputs, the maximum input length of the USM Conformer block, are used for training, but this requires significant memory and may be impractical for most users. Therefore, the sample code in this repository defaults to a 10-second input length.
128 | 129 | In summary, the input signal lengths for each training stage are as follows: 130 | 131 | | Stage | Module | Input length | Input audio | Purpose | 132 | |-------|-----------------------|--------------|------|----------------------------------------------| 133 | | 1 | Feature Cleaner | 10 sec | Noisy speech | Restore noisy SSL features | 134 | | 2 | WaveFit Pretraining | 0.6 sec | Clean speech | Pretrain vocoder with clean features | 135 | | 3 | WaveFit Finetuning | 10 sec | Noisy speech | Finetune vocoder with restored features | 136 | 137 | 138 | ## Stage 1: Feature Cleaner Training 139 | 140 |
Sample script for Feature Cleaner Training 141 | 142 | ```bash 143 | ROOT_DIR="/path/to/this/repository/" 144 | CONTAINER_PATH="/path/to/Open-Miipher-2.sif" 145 | DATASET_ROOT="/path/to/dataset/root/" 146 | JOB_ID="your_job_id" 147 | 148 | # Dataset 149 | DIRS_AUDIO=${DATASET_ROOT}/LibriTTS_R/train-clean-100/ 150 | DIRS_NOISE=${DATASET_ROOT}/TAU-urban-audio-visual-scenes-2021-development_24k-mono/,${DATASET_ROOT}/TAU-urban-acoustic-scenes-2020-mobile-development_24k-mono/ 151 | 152 | # Configuration 153 | PROJECT_NAME="cleaner_google-usm" 154 | MODEL="feature_cleaner/google-usm" 155 | DATA="deg_gemma_24khz_10sec" 156 | OPTIMIZER="feature_cleaner" 157 | BS_PER_GPU=20 158 | NUM_WORKERS=4 159 | EXTRA_ARGS="model=${MODEL} data=${DATA} optimizer=${OPTIMIZER}" 160 | 161 | WANDB_API_KEY="your_wandb_api_key" 162 | HF_TOKEN="your_huggingface_token" 163 | PORT=50000 164 | 165 | OUTPUT_DIR=${ROOT_DIR}/runs/train/${MODEL}/${JOB_ID} 166 | mkdir -p ${OUTPUT_DIR} 167 | 168 | # Calculate total batch size based on number of GPUs 169 | NUM_GPUS=2 # Adjust based on your setup 170 | BATCH_SIZE=$((${NUM_GPUS}*${BS_PER_GPU})) 171 | 172 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_ROOT \ 173 | --env HYDRA_FULL_ERROR=1 --env MASTER_PORT=${PORT} \ 174 | --env WANDB_API_KEY=$WANDB_API_KEY --env HF_TOKEN=$HF_TOKEN \ 175 | ${CONTAINER_PATH} \ 176 | torchrun --nproc_per_node gpu --master_port ${PORT} \ 177 | ${ROOT_DIR}/src/train.py \ 178 | data.train.dirs_audio=[${DIRS_AUDIO}] \ 179 | data.train.dirs_noise=[${DIRS_NOISE}] \ 180 | trainer.output_dir=${OUTPUT_DIR} \ 181 | trainer.batch_size=${BATCH_SIZE} \ 182 | trainer.num_workers=${NUM_WORKERS} \ 183 | trainer.logger.project_name=${PROJECT_NAME} \ 184 | trainer.logger.run_name=job-${JOB_ID} \ 185 | ${EXTRA_ARGS} 186 | ``` 187 | 188 |
189 | 190 | ## Stage 2: WaveFit Pretraining 191 | 192 |
Sample script for WaveFit Pretraining 193 | 194 | ```bash 195 | ROOT_DIR="/path/to/this/repository/" 196 | CONTAINER_PATH="/path/to/Open-Miipher-2.sif" 197 | DATASET_ROOT="/path/to/dataset/root/" 198 | JOB_ID="your_job_id" 199 | 200 | # Dataset 201 | DIRS_AUDIO=${DATASET_ROOT}/LibriTTS_R/train-clean-100/ 202 | 203 | # Configuration 204 | PROJECT_NAME="wavefit_pretrain" 205 | MODEL="miipher-2_google-usm_wavefit-5_clean-input" 206 | DATA="deg_gemma_24khz_06sec_clean-only" 207 | OPTIMIZER="wavefit" 208 | BS_PER_GPU=30 209 | NUM_WORKERS=4 210 | EXTRA_ARGS="model=${MODEL} data=${DATA} optimizer=${OPTIMIZER}" 211 | 212 | WANDB_API_KEY="your_wandb_api_key" 213 | HF_TOKEN="your_huggingface_token" 214 | PORT=50000 215 | 216 | OUTPUT_DIR=${ROOT_DIR}/runs/train/${MODEL}/${JOB_ID} 217 | mkdir -p ${OUTPUT_DIR} 218 | 219 | # Calculate total batch size based on number of GPUs 220 | NUM_GPUS=2 # Adjust based on your setup 221 | BATCH_SIZE=$((${NUM_GPUS}*${BS_PER_GPU})) 222 | 223 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_ROOT \ 224 | --env HYDRA_FULL_ERROR=1 --env MASTER_PORT=${PORT} \ 225 | --env WANDB_API_KEY=$WANDB_API_KEY --env HF_TOKEN=$HF_TOKEN \ 226 | ${CONTAINER_PATH} \ 227 | torchrun --nproc_per_node gpu --master_port ${PORT} \ 228 | ${ROOT_DIR}/src/train.py \ 229 | data.train.dirs_audio=[${DIRS_AUDIO}] \ 230 | trainer.output_dir=${OUTPUT_DIR} \ 231 | trainer.batch_size=${BATCH_SIZE} \ 232 | trainer.num_workers=${NUM_WORKERS} \ 233 | trainer.logger.project_name=${PROJECT_NAME} \ 234 | trainer.logger.run_name=job-${JOB_ID} \ 235 | ${EXTRA_ARGS} 236 | ``` 237 | 238 |
239 | 240 | ## Stage 3: WaveFit Finetuning 241 | 242 | Please note that for WaveFit finetuning, you must specify the checkpoint directories for the modules trained in Stage 1 and Stage 2. 243 | 244 |
Sample script for WaveFit Finetuning 245 | 246 | ```bash 247 | ROOT_DIR="/path/to/this/repository/" 248 | CONTAINER_PATH="/path/to/Open-Miipher-2.sif" 249 | DATASET_ROOT="/path/to/dataset/root/" 250 | JOB_ID="your_job_id" 251 | 252 | # Dataset 253 | DIRS_AUDIO=${DATASET_ROOT}/LibriTTS_R/train-clean-100/ 254 | DIRS_NOISE=${DATASET_ROOT}/TAU-urban-audio-visual-scenes-2021-development_24k-mono/,${DATASET_ROOT}/TAU-urban-acoustic-scenes-2020-mobile-development_24k-mono/ 255 | 256 | # Configuration 257 | PROJECT_NAME="wavefit_finetune" 258 | MODEL="miipher-2_google-usm_wavefit-5_noisy-input" 259 | DATA="deg_gemma_24khz_10sec" 260 | OPTIMIZER="wavefit" 261 | BS_PER_GPU=5 262 | NUM_WORKERS=4 263 | 264 | # Pre-trained model checkpoints 265 | FEATURE_CLEANER_CKPT_DIR=/path/to/feature_cleaner_ckpt_dir/ 266 | VOCODER_CKPT_DIR=/path/to/vocoder_ckpt_dir/ 267 | 268 | EXTRA_ARGS="model=${MODEL} data=${DATA} optimizer=${OPTIMIZER}" 269 | EXTRA_ARGS="${EXTRA_ARGS} model.feature_cleaner_ckpt_dir=${FEATURE_CLEANER_CKPT_DIR} model.vocoder_ckpt_dir=${VOCODER_CKPT_DIR}" 270 | 271 | WANDB_API_KEY="your_wandb_api_key" 272 | HF_TOKEN="your_huggingface_token" 273 | PORT=50000 274 | 275 | OUTPUT_DIR=${ROOT_DIR}/runs/train/${MODEL}/${JOB_ID} 276 | mkdir -p ${OUTPUT_DIR} 277 | 278 | # Calculate total batch size based on number of GPUs 279 | NUM_GPUS=2 # Adjust based on your setup 280 | BATCH_SIZE=$((${NUM_GPUS}*${BS_PER_GPU})) 281 | 282 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_ROOT \ 283 | --env HYDRA_FULL_ERROR=1 --env MASTER_PORT=${PORT} \ 284 | --env WANDB_API_KEY=$WANDB_API_KEY --env HF_TOKEN=$HF_TOKEN \ 285 | ${CONTAINER_PATH} \ 286 | torchrun --nproc_per_node gpu --master_port ${PORT} \ 287 | ${ROOT_DIR}/src/train.py \ 288 | data.train.dirs_audio=[${DIRS_AUDIO}] \ 289 | data.train.dirs_noise=[${DIRS_NOISE}] \ 290 | trainer.output_dir=${OUTPUT_DIR} \ 291 | trainer.batch_size=${BATCH_SIZE} \ 292 | trainer.num_workers=${NUM_WORKERS} \ 293 | trainer.logger.project_name=${PROJECT_NAME} \ 294 | trainer.logger.run_name=job-${JOB_ID} \ 295 | trainer.debug=0 \ 296 | ${EXTRA_ARGS} 297 | ``` 298 | 299 |
300 | 
301 | ## Resume training from a checkpoint
302 | 
303 | During training, checkpoints (state dicts) of the models, optimizers, and schedulers are saved under the output directory specified in the configuration, as follows.
304 | ```
305 | output_dir/
306 | ├─ ckpt/
307 | │  ├─ latest/
308 | │  │  ├─ model.pth
309 | │  │  ├─ optimizer.pth
310 | │  │  ├─ scheduler.pth
311 | │  │  ├─ ...
312 | ```
313 | 
314 | By specifying the checkpoint directory, you can easily resume training from that checkpoint.
315 | 
316 | 
Sample script for training resumption from saved checkpoints 317 | 318 | ```bash 319 | CKPT_DIR="output_dir/ckpt/latest/" 320 | OUTPUT_DIR="another/directory/" 321 | 322 | # Execution 323 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_ROOT \ 324 | --env MASTER_PORT=${PORT} --env WANDB_API_KEY=$WANDB_API_KEY \ 325 | ${CONTAINER_PATH} \ 326 | torchrun --nproc_per_node gpu ${ROOT_DIR}/src/train.py \ 327 | trainer.ckpt_dir=${CKPT_DIR} \ 328 | trainer.output_dir=${OUTPUT_DIR} 329 | ``` 330 | 331 |
332 | 
333 | # Inference
334 | 
335 | Using pre-trained Open-Miipher-2 models, you can perform inference on audio signals
336 | (e.g. for speech restoration evaluation).
337 | 
338 | The [`inference.py`](src/inference.py) script performs inference on all audio files in a target directory.
339 | To check the other options of the script, please use the `-h` option.
340 | 
Sample script for Inference (speech restoration) 342 | 343 | ```bash 344 | ROOT_DIR="/path/to/this/repository/" 345 | CONTAINER_PATH="/path/to/Open-Miipher-2.sif" 346 | JOB_ID="your_job_id" 347 | 348 | CKPT_DIR="/path/to/feature_cleaner_ckpt_dir/" 349 | AUDIO_DIR="/path/to/target/speech/directory/" 350 | OUTPUT_DIR="${ROOT_DIR}/runs/inference/${JOB_ID}" 351 | 352 | HF_TOKEN="your_huggingface_token" 353 | PORT=50000 354 | 355 | mkdir -p ${OUTPUT_DIR} 356 | 357 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $AUDIO_DIR \ 358 | --env MASTER_PORT=${PORT} --env HF_TOKEN=$HF_TOKEN \ 359 | ${CONTAINER_PATH} \ 360 | torchrun --nproc_per_node=1 --master_port=${PORT} \ 361 | ${ROOT_DIR}/src/inference.py \ 362 | --ckpt-dir ${CKPT_DIR} \ 363 | --input-audio-dir ${AUDIO_DIR} \ 364 | --output-dir ${OUTPUT_DIR} \ 365 | --sample-size 160000 \ 366 | --sr-in 16000 \ 367 | --sr-out 24000 368 | ``` 369 | 370 |
371 | 372 | # 💡 Tips 373 | 374 | ## Number of Conformer layers 375 | 376 | The number of Conformer layers used from USM is a crucial parameter that affects the overall performance of Miipher-2. Since Conformer layers are repeatedly applied to Mel-spectrogram inputs, using fewer layers may make it easier for WaveFit to restore the speech signal. However, it is also necessary to ensure a sufficient number of layers for effective feature cleaning. In the original Miipher-2, 13 out of 32 layers are utilized. 377 | 378 | To determine the optimal number of layers for this repository, small-scale WaveFit training experiments were conducted using SSL features obtained from different numbers of layers. The figure below shows the progression of `STFT spectral convergence loss`, `STFT magnitude loss`, and `GAN loss` when using up to 1, 2, 5, 7, 9, and 11 layers. 379 | 380 | ![Layer comparison](./assets/fig/compare_layers.png) 381 | 382 | As expected, using fewer Conformer layers generally results in lower STFT loss, indicating that decoding to speech is easier. Notably, when using 9 or more layers, performance drops sharply, and using all 12 layers fails entirely. Based on these results, this repository defaults to using `6 layers`. 383 | 384 | This setting can be easily changed via the configuration. 385 | 386 | ## Degradation types 387 | 388 | The methods for degrading speech used for model training differ between those described in the paper and those implemented in this repository, as summarized below. 389 | 390 | | | Background noise | Room reverb | Codec (MP3, Vorbis, A-law, AMR-WB, OPUS) | Soft/Hard clipping | Lowpass | 391 | |:--------------:|:---------------:|:-----------:|:----------------------------------------:|:------------------:|:-------:| 392 | | **Google** | ✓ | ✓ | ✓ | | | 393 | | **Open-Miipher-2** | ✓ | ✓ | | ✓ | ✓ | 394 | 395 | Codec processing is considered too computationally intensive for online processing, so it is excluded from this repository. 396 | 397 | 400 | 401 | # 🤔 Unclear points in the implementation 402 | 403 | ## 1. Parameter size 404 | 405 | According to the paper, the Parallel Adapter (PA) is described as having `20 million` learnable parameters (Sec.2.2). However, the PA used in Miipher-2 takes a 1536-dimensional input and has a 1024-dimensional bottleneck structure, which is applied to 13 layers of Conformers. The approximate parameter size of the PA can be calculated from two linear layers, resulting in a total parameter count of about `1536 x 1024 x 2 x 13 = 40.9M`. This is likely a typo in the paper. 406 | 407 | ## 2. Upsampling method of USM feature 408 | 409 | When inputting SSL features into WaveFit in Miipher-2, upsampling is performed along the time axis to fit the frame rate to the appropriate input length (Sec.2.3). The specific upsampling method (e.g., 'nearest', 'linear', or 'bilinear') is not described in the paper, but it is likely that the choice does not significantly affect performance. Therefore, this repository uses `linear` interpolation as the default. 410 | 411 | ## 3. Loss function of feature cleaner 412 | 413 | The loss function for training feature cleaner is defined in the Miipher paper [2] as below. 414 | 415 | ```math 416 | \mathcal{L} = \| S - \hat{S} \|_{1} 417 | + \| S - \hat{S} \|_{2}^{2} 418 | + \frac{\| S - \hat{S} \|_{2}^{2}}{\| S \|_{2}^{2}}, 419 | \quad \text{where } 420 | \| S \|_{p} = \left( \sum_{k} \sum_{d} | S_{k,d} |^{p} \right)^{1/p}. 
421 | ```
422 | 
423 | In the paper, the first term is referred to as "mean-absolute-error" and the second term as "mean-squared-error," but the formulas do not actually compute the mean. In practice, when calculating this loss, the first and second terms become disproportionately large compared to the third term (spectral convergence loss). Therefore, it is reasonable to compute the loss as follows:
424 | 
425 | ```math
426 | \mathcal{L} = \frac{1}{KD}\| S - \hat{S} \|_{1}
427 | + \frac{1}{KD}\| S - \hat{S} \|_{2}^{2}
428 | + \frac{\| S - \hat{S} \|_{2}^{2}}{\| S \|_{2}^{2}}
429 | ```
430 | 
431 | # References
432 | 
433 | 1. "Miipher-2: A Universal Speech Restoration Model for Million-Hour Scale Data Restoration", S. Karita et al., WASPAA 2025
434 | 1. "Miipher: A Robust Speech Restoration Model Integrating Self-Supervised Speech and Text Representations", Y. Koizumi et al., WASPAA 2023
435 | 1. "WaveFit: An Iterative and Non-autoregressive Neural Vocoder based on Fixed-Point Iteration", Y. Koizumi et al., IEEE SLT 2022
436 | 1. "Google USM: Scaling Automatic Speech Recognition Beyond 100 Languages", Y. Zhang et al., arXiv 2023
437 | 1. "LibriTTS-R: A Restored Multi-Speaker Text-to-Speech Corpus", Y. Koizumi et al., INTERSPEECH 2023
--------------------------------------------------------------------------------