├── src
│   ├── utils
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   ├── scheduler.py
│   │   ├── torch_common.py
│   │   └── audio_utils.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── pretransform
│   │   │   ├── __init__.py
│   │   │   └── gemma_audio_feature.py
│   │   ├── degradation
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── lowpass.py
│   │   │   ├── clipping.py
│   │   │   ├── noise.py
│   │   │   └── reverb.py
│   │   ├── modification.py
│   │   ├── audio_io.py
│   │   └── dataset.py
│   ├── model
│   │   ├── wavefit
│   │   │   ├── loss
│   │   │   │   ├── __init__.py
│   │   │   │   └── mrstft.py
│   │   │   ├── __init__.py
│   │   │   ├── wavefit.py
│   │   │   ├── discriminator.py
│   │   │   └── generator.py
│   │   ├── __init__.py
│   │   ├── feature_cleaner
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── parallel_adapter.py
│   │   │   └── google_usm.py
│   │   └── miipher_2.py
│   ├── train.py
│   ├── inference.py
│   └── trainer.py
├── assets
│   └── fig
│       ├── miipher-2.png
│       └── compare_layers.png
├── container
│   ├── build_singularity.bash
│   └── Open-Miipher-2.def
├── configs
│   ├── default.yaml
│   ├── optimizer
│   │   ├── feature_cleaner.yaml
│   │   └── wavefit.yaml
│   ├── data
│   │   ├── deg_gemma_24khz_06sec_clean-only.yaml
│   │   ├── deg_gemma_24khz_10sec.yaml
│   │   └── deg_gemma_24khz_30sec.yaml
│   ├── model
│   │   ├── feature_cleaner
│   │   │   └── google-usm.yaml
│   │   ├── miipher-2_google-usm_wavefit-5_clean-input.yaml
│   │   └── miipher-2_google-usm_wavefit-5_noisy-input.yaml
│   └── trainer
│       └── default.yaml
├── LICENSE
├── _LICENSES
│   ├── descript-audio-codec.LICENSE
│   ├── transformers.LICENSE
│   └── unify-parameter-efficient-tuning.LICENSE
├── dataset
│   └── script
│       └── make_metadata_csv.py
├── .gitignore
└── README.md
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import DegradedAudioDataset
2 |
--------------------------------------------------------------------------------
/src/model/wavefit/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .mrstft import MRSTFTLoss, MELMAELoss
2 |
--------------------------------------------------------------------------------
/src/data/pretransform/__init__.py:
--------------------------------------------------------------------------------
1 | from .gemma_audio_feature import GemmaAudioFeature
2 |
--------------------------------------------------------------------------------
/assets/fig/miipher-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukara-ikemiya/Open-Miipher-2/HEAD/assets/fig/miipher-2.png
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .feature_cleaner import *
2 | from .wavefit import *
3 | from .miipher_2 import *
4 |
--------------------------------------------------------------------------------
/assets/fig/compare_layers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukara-ikemiya/Open-Miipher-2/HEAD/assets/fig/compare_layers.png
--------------------------------------------------------------------------------
/src/model/feature_cleaner/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import AudioEncoderAdapter
2 | from .google_usm import GoogleUSMAdapter
3 |
--------------------------------------------------------------------------------
/container/build_singularity.bash:
--------------------------------------------------------------------------------
1 | # Create `Open-Miipher-2.sif` file
2 | NAME="Open-Miipher-2"
3 | singularity build --fakeroot ~/$NAME.sif $NAME.def
--------------------------------------------------------------------------------
/src/model/wavefit/__init__.py:
--------------------------------------------------------------------------------
1 | from .wavefit import WaveFit
2 | from .discriminator import Discriminator
3 | from .loss import MRSTFTLoss, MELMAELoss
4 |
--------------------------------------------------------------------------------
/src/data/degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from .clipping import AudioClipping
2 | from .noise import NoiseAddition
3 | from .reverb import RIRReverb
4 | from .lowpass import AudioLowpass
5 |
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model: feature_cleaner/google-usm
3 | - data: deg_gemma_24khz_30sec
4 | - optimizer: feature_cleaner
5 | - trainer: default
6 | - _self_
7 |
8 | # no hydra logging
9 | hydra:
10 | output_subdir: null
11 | run:
12 | dir: .
13 | # multirun
14 | sweep:
15 | dir: .
--------------------------------------------------------------------------------
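
A minimal sketch of composing this configuration outside of `src/train.py`, e.g. to inspect the merged config before launching a run. The override values are placeholders, and it assumes the snippet is executed next to the `configs/` directory with Hydra installed.

# Sketch: compose configs/default.yaml and print the merged result (placeholder overrides).
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="default.yaml",
        overrides=[
            "data.train.dirs_audio=[/path/to/audio]",   # placeholder path
            "trainer.output_dir=./output",              # placeholder path
        ],
    )

print(OmegaConf.to_yaml(cfg))  # unresolved mandatory values are printed as '???'
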
/configs/optimizer/feature_cleaner.yaml:
--------------------------------------------------------------------------------
1 |
2 | optimizer:
3 | _partial_: true
4 | _target_: torch.optim.AdamW
5 | betas: [0.8, 0.99]
6 | lr: 0.001
7 | weight_decay: 0.001
8 |
9 | scheduler:
10 | _partial_: true
11 | _target_: utils.scheduler.WarmupCosineLR
12 | warmup_steps: 2000
13 | total_steps: 2000000
14 | min_lr: 0.00001
--------------------------------------------------------------------------------
/src/data/degradation/base.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 | """
4 | from abc import ABC, abstractmethod
5 |
6 |
7 | class Degradation(ABC):
8 | def __init__(self, sample_rate: int = 16000):
9 | super().__init__()
10 | self.sample_rate = sample_rate
11 |
12 | @abstractmethod
13 | def __call__(self, x):
14 | pass
15 |
--------------------------------------------------------------------------------
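
A minimal sketch of adding a new degradation that follows this interface; `GainDrop` is a hypothetical example and is not part of the repository.

# Sketch: a hypothetical Degradation subclass (random gain reduction).
import random

import torch

from .base import Degradation  # written as a module inside src/data/degradation/


class GainDrop(Degradation):
    def __init__(self, gain_range: tuple = (0.1, 0.5), **kwargs):
        super().__init__(**kwargs)
        self.gain_range = gain_range

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # attenuate the whole signal by a random factor
        gain = random.uniform(self.gain_range[0], self.gain_range[1])
        return x * gain
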
/configs/data/deg_gemma_24khz_06sec_clean-only.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | _target_: data.DegradedAudioDataset
3 | dirs_audio: ???
4 | dirs_noise: null
5 | sample_size: 14400 # 0.6 sec at 24 kHz
6 | sample_rate: 24000
7 | pretransform: 'gemma'
8 | exts: ['wav', 'flac']
9 | augment_shift: True
10 | augment_flip: True
11 | augment_volume: True
12 | volume_range: [0.25, 1.0]
13 | clean_only: True
14 |
--------------------------------------------------------------------------------
/configs/data/deg_gemma_24khz_10sec.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | _target_: data.DegradedAudioDataset
3 | dirs_audio: ???
4 | dirs_noise: ???
5 |   sample_size: 240000 # 10 sec at 24 kHz
6 | sample_rate: 24000
7 | pretransform: 'gemma'
8 | exts: ['wav', 'flac']
9 | augment_shift: True
10 | augment_flip: True
11 | augment_volume: True
12 | volume_range: [0.25, 1.0]
13 | deg_types: ['clipping', 'noise', 'reverb', 'lowpass']
14 | n_deg_comb: 3
15 | prob_no_deg: 0.05
16 | clean_only: False
17 |
--------------------------------------------------------------------------------
/configs/data/deg_gemma_24khz_30sec.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | _target_: data.DegradedAudioDataset
3 | dirs_audio: ???
4 | dirs_noise: ???
5 |   sample_size: 720000 # 30 sec at 24 kHz
6 | sample_rate: 24000
7 | pretransform: 'gemma'
8 | exts: ['wav', 'flac']
9 | augment_shift: True
10 | augment_flip: True
11 | augment_volume: True
12 | volume_range: [0.25, 1.0]
13 | deg_types: ['clipping', 'noise', 'reverb', 'lowpass']
14 | n_deg_comb: 3
15 | prob_no_deg: 0.05
16 | clean_only: False
17 |
--------------------------------------------------------------------------------
/configs/model/feature_cleaner/google-usm.yaml:
--------------------------------------------------------------------------------
1 | # Feature cleaner model using Google USM (Gemma3n) as the encoder with adapter layers.
2 | #
3 | # Input sample rate: 16000 Hz
4 | # Dim size: 1536
5 | # Output frame rate: 25 frames/sec
6 |
7 | _target_: model.GoogleUSMAdapter
8 |
9 | n_adaptive_layers: 6 # /12
10 | encoder_id: "Atotti/google-usm"
11 | adapter_config:
12 | "dim_bottleneck": 1024
13 | "init_option": "bert"
14 | "adapter_scalar": 1.0
15 | "pre_ln_class": "Gemma3nRMSNorm"
16 |
--------------------------------------------------------------------------------
/configs/optimizer/wavefit.yaml:
--------------------------------------------------------------------------------
1 |
2 | optimizer:
3 | _partial_: true
4 | _target_: torch.optim.AdamW
5 | betas: [0.8, 0.99]
6 | lr: 0.0001
7 | weight_decay: 0.001
8 |
9 | scheduler:
10 | _partial_: true
11 | _target_: utils.scheduler.WarmupCosineLR
12 | warmup_steps: 2000
13 | total_steps: 1000000
14 | min_lr: 0.00001
15 |
16 | optimizer_d:
17 | _partial_: true
18 | _target_: torch.optim.AdamW
19 | betas: [0.8, 0.99]
20 | lr: 0.0002
21 | weight_decay: 0.001
22 |
23 | scheduler_d:
24 | _partial_: true
25 | _target_: utils.scheduler.WarmupCosineLR
26 | warmup_steps: 2000
27 | total_steps: 1000000
28 | min_lr: 0.00001
--------------------------------------------------------------------------------
/container/Open-Miipher-2.def:
--------------------------------------------------------------------------------
1 | Bootstrap: docker
2 | From: pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel
3 |
4 | %post
5 | apt-get update \
6 | && DEBIAN_FRONTEND=noninteractive apt-get install -y ffmpeg \
7 | && rm -rf /var/lib/apt/lists/*
8 |
9 | python -m pip install --upgrade --no-cache-dir pip setuptools wheel
10 |
11 | python -m pip install transformers==4.53
12 | python -m pip install librosa einops
13 | python -m pip install pyroomacoustics
14 | python -m pip install torchaudio
15 | python -m pip install hydra-core wandb accelerate ema-pytorch
16 |
17 |
18 | %environment
19 | export PATH=/usr/local/bin:$PATH
20 | export PYTHONUNBUFFERED=1
--------------------------------------------------------------------------------
/src/data/degradation/lowpass.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 | """
4 | import random
5 |
6 | import torch
7 | import torchaudio
8 |
9 | from .base import Degradation
10 |
11 |
12 | class AudioLowpass(Degradation):
13 | def __init__(
14 | self,
15 | cutoff_range: tuple = (2000.0, 7000.0),
16 | sample_rate: int = 16000
17 | ):
18 | super().__init__(sample_rate=sample_rate)
19 | self.cutoff_range = cutoff_range
20 |
21 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
22 | cutoff = random.uniform(self.cutoff_range[0], self.cutoff_range[1])
23 | x = torchaudio.functional.lowpass_biquad(x, self.sample_rate, cutoff)
24 |
25 | # avoid clipping
26 | amp = x.abs().max().item()
27 | if amp > 1.0:
28 | x = x / (amp + 1e-8)
29 |
30 | return x
31 |
--------------------------------------------------------------------------------
/src/utils/logging.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2024 Yukara Ikemiya
3 |
4 | Convenient modules for logging metrics.
5 | """
6 |
7 | import typing as tp
8 |
9 | import torch
10 |
11 |
12 | class MetricsLogger:
13 | def __init__(self):
14 | self.counts = {}
15 | self.metrics = {}
16 |
17 | def add(self, metrics: tp.Dict[str, torch.Tensor]) -> None:
18 | for k, v in metrics.items():
19 | if k in self.counts.keys():
20 | self.counts[k] += 1
21 | self.metrics[k] += v.detach().clone()
22 | else:
23 | self.counts[k] = 1
24 | self.metrics[k] = v.detach().clone()
25 |
26 | def pop(self, mean: bool = True) -> tp.Dict[str, torch.Tensor]:
27 | metrics = {}
28 | for k, v in self.metrics.items():
29 | metrics[k] = v / self.counts[k] if mean else v
30 |
31 | # reset
32 | self.counts = {}
33 | self.metrics = {}
34 |
35 | return metrics
36 |
--------------------------------------------------------------------------------
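
A minimal usage sketch of `MetricsLogger` with dummy loss values (assumes `src/` is on `PYTHONPATH`, as in `train.py`):

import torch

from utils.logging import MetricsLogger  # assuming src/ is on PYTHONPATH

logger = MetricsLogger()
for step in range(4):
    logger.add({'loss': torch.tensor(1.0 / (step + 1))})

metrics = logger.pop(mean=True)  # averaged over the 4 added values; the logger is reset
print(metrics['loss'].item())    # ~0.52
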
/src/data/degradation/clipping.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 | """
4 | import typing as tp
5 | import random
6 |
7 | import torch
8 |
9 | from .base import Degradation
10 |
11 |
12 | class AudioClipping(Degradation):
13 | def __init__(
14 | self,
15 | amp_range: tp.Tuple[float, float] = (1.2, 2.0),
16 | prob_soft_clip: float = 0.5,
17 | **kwargs
18 | ):
19 | super().__init__(**kwargs)
20 | self.amp_range = amp_range
21 | self.prob_soft_clip = prob_soft_clip
22 |
23 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
24 | amp_org = x.abs().max().item()
25 | assert amp_org <= 1.0
26 | amp = random.uniform(self.amp_range[0], self.amp_range[1])
27 | scale = amp / (amp_org + 1e-9)
28 |
29 | # clipping
30 | x = x * scale
31 | if random.random() < self.prob_soft_clip:
32 | x = torch.tanh(x)
33 | else:
34 | x = x.clamp(-1.0, 1.0)
35 |
36 | # rescale
37 | x = x / scale
38 |
39 | return x
40 |
--------------------------------------------------------------------------------
/src/data/degradation/noise.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 | """
4 | import random
5 |
6 | import torch
7 |
8 | from .base import Degradation
9 |
10 |
11 | class NoiseAddition(Degradation):
12 | def __init__(
13 | self,
14 | snr_range: tuple = (5.0, 30.0),
15 | **kwargs
16 | ):
17 | super().__init__(**kwargs)
18 | self.snr_range = snr_range
19 |
20 | def __call__(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
21 | assert x.shape == noise.shape, f"Shapes of input and noise must be the same: {x.shape} != {noise.shape}"
22 | snr = random.uniform(self.snr_range[0], self.snr_range[1])
23 | p_x = x.pow(2.0).mean()
24 | p_n = noise.pow(2.0).mean()
25 | P_new = p_x / (10 ** (snr / 10))
26 | scale_n = (P_new / (p_n + 1e-9)).sqrt()
27 |
28 | # noise addition
29 | x = x + noise * scale_n
30 |
31 | # avoid clipping
32 | amp = x.abs().max().item()
33 | if amp > 1.0:
34 | x = x / (amp + 1e-8)
35 |
36 | return x
37 |
--------------------------------------------------------------------------------
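
A small usage sketch of `NoiseAddition` with a fixed SNR, followed by a check of the resulting SNR (dummy signals; assumes `src/` is on `PYTHONPATH`):

import torch

from data.degradation import NoiseAddition  # assuming src/ is on PYTHONPATH

deg = NoiseAddition(snr_range=(10.0, 10.0), sample_rate=24000)  # fix the target SNR at 10 dB

t = torch.arange(24000) / 24000.0
x = 0.1 * torch.sin(2 * torch.pi * 440.0 * t)   # dummy 1-second tone
noise = 0.05 * torch.randn(24000)

y = deg(x, noise)
added = y - x
snr_db = 10 * torch.log10(x.pow(2).mean() / added.pow(2).mean())
print(snr_db.item())  # ~10 dB (exact unless the clipping-avoidance rescaling was triggered)
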
/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | # Root directory for outputs
2 | output_dir: ???
3 |
4 | # Checkpoint directory for resuming training
5 | ckpt_dir: null
6 |
7 | # Batch size
8 | # NOTE: Batch size must be a multiple of the number of GPUs
9 | batch_size: 1
10 | num_workers: 2
11 |
12 | # Seed value used for rng initialization
13 | seed: 0
14 |
15 | # debug
16 | debug: false
17 |
18 | # Automatic mixed precision
19 | # Choose from 'no', 'fp16', 'bf16' or 'fp8'.
20 | amp: 'bf16' # 'no'
21 |
22 | # Max norm of gradient clipping
23 | max_grad_norm: 1.0
24 |
25 | # EMA (Exponential Moving Average)
26 | ema:
27 | beta: 0.999
28 | update_after_step: 100
29 | update_every: 10
30 |
31 | # Logging
32 |
33 | logger:
34 | project_name: 'project_name'
35 | run_name: 'run_name'
36 |
37 | logging:
38 | # Step interval for logging metrics / saving checkpoints
39 | # / generating samples / test (validation) / printing metrics
40 | n_step_log: 10
41 | n_step_ckpt: 2000
42 | n_step_sample: 2000
43 | n_step_print: 100
44 | # Number of generated samples
45 | n_samples_per_step: 3
46 |
47 | metrics_for_best_ckpt: ['loss']
--------------------------------------------------------------------------------
/src/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 |
4 | ------
5 | Learning rate schedulers.
6 | """
7 |
8 | import math
9 |
10 | import torch
11 |
12 |
13 | class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler):
14 | def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0.0, last_epoch=-1):
15 | self.warmup_steps = warmup_steps
16 | self.total_steps = total_steps
17 | self.min_lr = min_lr
18 | super(WarmupCosineLR, self).__init__(optimizer, last_epoch)
19 |
20 | def get_lr(self):
21 | step = self.last_epoch + 1
22 |
23 | if step < self.warmup_steps:
24 | # warmup
25 | return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]
26 | else:
27 | # cosine decay (base_lr -> min_lr)
28 | progress = min(1, (step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps))
29 |
30 | return [
31 | self.min_lr + (base_lr - self.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
32 | for base_lr in self.base_lrs
33 | ]
34 |
--------------------------------------------------------------------------------
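
A minimal usage sketch of `WarmupCosineLR` (step counts shortened for illustration; assumes `src/` is on `PYTHONPATH`):

import torch

from utils.scheduler import WarmupCosineLR  # assuming src/ is on PYTHONPATH

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
sche = WarmupCosineLR(opt, warmup_steps=10, total_steps=100, min_lr=1e-5)

lrs = []
for _ in range(100):
    opt.step()   # optimizer step first, then scheduler step
    sche.step()
    lrs.append(sche.get_last_lr()[0])

# lrs ramps linearly up to 1e-3 over the first ~10 steps,
# then follows a cosine decay towards min_lr=1e-5 at step 100.
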
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Yukara Ikemiya
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/_LICENSES/descript-audio-codec.LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023-present, Descript
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/utils/torch_common.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2024 Yukara Ikemiya
3 | """
4 |
5 | import os
6 | import random
7 |
8 | import numpy as np
9 | import torch
10 |
11 |
12 | def exists(x: torch.Tensor):
13 | return x is not None
14 |
15 |
16 | def get_world_size():
17 | if not torch.distributed.is_available() or not torch.distributed.is_initialized():
18 | return 1
19 | else:
20 | return torch.distributed.get_world_size()
21 |
22 |
23 | def get_rank():
24 | """Get rank of current process."""
25 |
26 | if not torch.distributed.is_available() or not torch.distributed.is_initialized():
27 | return 0
28 | else:
29 | return torch.distributed.get_rank()
30 |
31 |
32 | def print_once(*args):
33 | if get_rank() == 0:
34 | print(*args)
35 |
36 |
37 | def set_seed(seed: int = 0):
38 | torch.manual_seed(seed)
39 | if torch.cuda.is_available():
40 | torch.cuda.manual_seed_all(seed)
41 | np.random.seed(seed)
42 | random.seed(seed)
43 | os.environ["PYTHONHASHSEED"] = str(seed)
44 |
45 |
46 | def count_parameters(model: torch.nn.Module, include_buffers: bool = False):
47 | n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
48 | n_buffers = sum(p.numel() for p in model.buffers()) if include_buffers else 0
49 | return n_trainable_params + n_buffers
50 |
51 |
52 | def sort_dict(D: dict):
53 | s_keys = sorted(D.keys())
54 | return {k: D[k] for k in s_keys}
55 |
56 |
57 | def checkpoint(function, *args, **kwargs):
58 | """ Gradient checkpointing """
59 | kwargs.setdefault("use_reentrant", False)
60 | return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
61 |
--------------------------------------------------------------------------------
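
A small usage sketch of the helpers above with a dummy model (assumes `src/` is on `PYTHONPATH`):

import torch

from utils.torch_common import count_parameters, print_once, set_seed  # assuming src/ is on PYTHONPATH

set_seed(0)  # seeds torch / CUDA / numpy / random and sets PYTHONHASHSEED

model = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1))
print_once(f"Trainable parameters: {count_parameters(model) / 1e6:.4f} [million]")  # prints only on rank 0
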
/src/data/pretransform/gemma_audio_feature.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 |
4 | --------------
5 | A wrapper of Gemma audio feature extractor
6 | """
7 |
8 | import numpy as np
9 | import torch
10 | from torchaudio import transforms as T
11 | from transformers import Gemma3nAudioFeatureExtractor
12 |
13 |
14 | class GemmaAudioFeature:
15 | """
16 | A wrapper around the Gemma3nAudioFeatureExtractor.
17 |
18 | NOTE: Feature extraction is executed on CPU.
19 | """
20 |
21 | def __init__(self, model_id="google/gemma-3n-e2b-it"):
22 | self.extractor = Gemma3nAudioFeatureExtractor.from_pretrained(model_id)
23 | self.sr = self.extractor.sampling_rate
24 |
25 | def __call__(self, audio: torch.Tensor, sr_in: int) -> torch.Tensor:
26 | """
27 | Args:
28 | audio (torch.Tensor): (num_samples)
29 | Returns:
30 | (torch.Tensor): (num_frames, feature_dim)
31 | """
32 | if sr_in != self.sr:
33 | resample_tf = T.Resample(sr_in, self.sr)
34 | audio = resample_tf(audio)
35 |
36 | audio = audio.numpy()
37 | audio = audio.reshape(1, -1) # (1, L)
38 | output = self.extractor(audio, return_tensors="pt")
39 | audio_mel = output["input_features"]
40 | # The encoder expects a padding mask (True for padding), while the feature extractor
41 | # returns an attention mask (True for valid tokens). We must invert it.
42 | # NOTE: 'False' for valid frames, 'True' for padded frames
43 | # audio_mel_mask = ~output["input_features_mask"].to(torch.bool) # not used for now
44 |
45 | return audio_mel.squeeze(0) # (num_frames, feature_dim)
46 |
--------------------------------------------------------------------------------
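
A minimal usage sketch of `GemmaAudioFeature` with a dummy 1-second signal (assumes `src/` is on `PYTHONPATH`; note that the extractor configuration is fetched from the Hugging Face Hub, which may require accepting the Gemma license / Hub authentication):

import torch

from data.pretransform import GemmaAudioFeature  # assuming src/ is on PYTHONPATH

extractor = GemmaAudioFeature()            # fetches the feature-extractor config from the HF Hub
audio = 0.1 * torch.randn(24000)           # dummy 1-second mono signal at 24 kHz
feats = extractor(audio, sr_in=24000)      # resampled internally to the extractor's sampling rate
print(feats.shape)                         # (num_frames, feature_dim)
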
/src/data/modification.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2024 Yukara Ikemiya
3 | """
4 |
5 | import random
6 | import math
7 |
8 | import torch
9 | from torch import nn
10 |
11 |
12 | # Channels
13 |
14 | class Mono(nn.Module):
15 | def __call__(self, x: torch.Tensor):
16 | assert len(x.shape) <= 2
17 | return torch.mean(x, dim=0, keepdims=True) if len(x.shape) > 1 else x
18 |
19 |
20 | class Stereo(nn.Module):
21 | def __call__(self, x: torch.Tensor):
22 | x_shape = x.shape
23 | assert len(x_shape) <= 2
24 | # Check if it's mono
25 | if len(x_shape) == 1: # s -> 2, s
26 | x = x.unsqueeze(0).repeat(2, 1)
27 | elif len(x_shape) == 2:
28 | if x_shape[0] == 1: # 1, s -> 2, s
29 | x = x.repeat(2, 1)
30 | elif x_shape[0] > 2: # ?, s -> 2,s
31 | x = x[:2, :]
32 |
33 | return x
34 |
35 |
36 | # Augmentation
37 |
38 | class PhaseFlipper(nn.Module):
39 | """Randomly invert the phase of a signal"""
40 |
41 | def __init__(self, p=0.5):
42 | super().__init__()
43 | self.p = p
44 |
45 | def __call__(self, x: torch.Tensor):
46 | assert len(x.shape) <= 2
47 | return -x if (random.random() < self.p) else x
48 |
49 |
50 | class VolumeChanger(nn.Module):
51 | """Randomly change volume (amplitude) of a signal"""
52 |
53 | def __init__(self, min_amp: float = 0.25, max_amp: float = 1.0):
54 | super().__init__()
55 | self.min_amp = min_amp
56 | self.max_amp = max_amp
57 |
58 | def __call__(self, x: torch.Tensor):
59 | assert x.ndim <= 2
60 | amp_x = x.abs().max().item()
61 | if amp_x < 1e-5:
62 | return x
63 |
64 | min_db = 20 * math.log10(self.min_amp / amp_x)
65 | max_db = 20 * math.log10(self.max_amp / amp_x)
66 | scale_db = random.uniform(min_db, max_db)
67 | scale = 10 ** (scale_db / 20)
68 | x = x * scale
69 |
70 | return x
71 |
--------------------------------------------------------------------------------
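
A small usage sketch chaining the transforms above on a dummy stereo signal (assumes `src/` is on `PYTHONPATH`):

import torch

from data.modification import Mono, PhaseFlipper, VolumeChanger  # assuming src/ is on PYTHONPATH

transforms = [Mono(), PhaseFlipper(p=0.5), VolumeChanger(min_amp=0.25, max_amp=1.0)]

x = 0.1 * torch.randn(2, 24000)   # dummy stereo signal
for tf in transforms:
    x = tf(x)

print(x.shape, x.abs().max().item())  # torch.Size([1, 24000]), peak amplitude within [0.25, 1.0]
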
/configs/model/miipher-2_google-usm_wavefit-5_clean-input.yaml:
--------------------------------------------------------------------------------
1 | # The Miipher-2 model consists of a Google-USM feature cleaner and a WaveFit-5 vocoder.
2 | #
3 | # Input sample rate: 16000 Hz (-> mel-spectrogram)
4 | # Output sample rate: 24000 hz
5 |
6 | defaults:
7 | - feature_cleaner: google-usm
8 |
9 | _target_: model.Miipher2
10 |
11 | mode: clean_input
12 | upsample_factor: 4
13 | upsample_mode: "linear"
14 | gradient_checkpointing: false
15 |
16 | vocoder:
17 | _target_: model.WaveFit
18 |   # Number of WaveFit iterations (e.g. WaveFit-5 -> 5)
19 | num_iteration: 5
20 | # Gain of target audio
21 | target_gain: 0.9
22 |
23 | # Pre-network (Conformer blocks)
24 | num_conformer_blocks: 4
25 | args_conformer:
26 | # NOTE: These settings are based on Gemma3n audio encoder settings.
27 | # https://huggingface.co/google/gemma-3n-E2B-it/blob/main/config.json
28 | conf_attention_chunk_size: 12
29 | conf_attention_context_left: 13
30 | conf_attention_context_right: 0
31 | conf_attention_logit_cap: 50.0
32 | conf_conv_kernel_size: 5
33 | conf_num_attention_heads: 8
34 | conf_reduction_factor: 4
35 | conf_residual_weight: 0.5
36 | gradient_clipping: 10000000000.0
37 | hidden_size: 1536
38 | rms_norm_eps: 1e-06
39 |
40 | # Generator
41 | args_generator:
42 | dim_feat: 1536
43 | upsample_factors: [5, 4, 3, 2, 2]
44 | upsample_channels: [512, 512, 256, 128, 128]
45 | downsample_channels: [128, 128, 256, 512]
46 |
47 | discriminator:
48 | _target_: model.Discriminator
49 | msd_kwargs:
50 | num_D: 3
51 | ndf: 16
52 | n_layers: 4
53 | downsampling_factor: 4
54 | mpd_kwargs:
55 | periods: [2, 3, 5, 7, 11, 13, 17, 19]
56 |
57 | mrstft:
58 | # Parallel WaveGAN setting (Sec.2.3 in the Miipher paper)
59 | _target_: model.MRSTFTLoss
60 | n_ffts: [512, 1024, 2048]
61 | win_sizes: [240, 600, 1200]
62 | hop_sizes: [50, 120, 240]
63 |
64 | loss_lambdas:
65 | mrstft_sc_loss: 2.5
66 | mrstft_mag_loss: 2.5
67 | disc_gan_loss: 1.0
68 | disc_feat_loss: 10.0
--------------------------------------------------------------------------------
/src/data/audio_io.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2024 Yukara Ikemiya
3 | """
4 |
5 | import math
6 | import json
7 |
8 | from torch.nn import functional as F
9 | import torchaudio
10 | from torchaudio import transforms as T
11 |
12 |
13 | def get_audio_metadata(filepath, cache=True):
14 | try:
15 | with open(filepath + '.json', 'r') as f:
16 | info = json.load(f)
17 | return info
18 | except Exception:
19 | try:
20 | info_ = torchaudio.info(filepath)
21 | sample_rate = info_.sample_rate
22 | num_channels = info_.num_channels
23 | num_frames = info_.num_frames
24 |
25 | info = {
26 | 'sample_rate': sample_rate,
27 | 'num_frames': num_frames,
28 | 'num_channels': num_channels
29 | }
30 | except Exception as e:
31 | # error : cannot open an audio file
32 | print(f"Failed to load metadata for {filepath}: {e}")
33 | info = {'sample_rate': 0, 'num_frames': 0, 'num_channels': 0}
34 |
35 | if cache:
36 | with open(filepath + '.json', 'w') as f:
37 | json.dump(info, f, indent=2)
38 |
39 | return info
40 |
41 |
42 | def load_audio_with_pad(filepath, info: dict, sr: int, n_samples: int, offset: int):
43 | sr_in, num_frames = info['sample_rate'], info['num_frames']
44 | n_samples_in = int(math.ceil(n_samples * (sr_in / sr)))
45 |
46 | # load audio
47 | ext = filepath.split(".")[-1]
48 | out_frames = min(n_samples_in, num_frames - offset)
49 |
50 | audio, _ = torchaudio.load(
51 | filepath, frame_offset=offset, num_frames=out_frames,
52 | format=ext, backend='soundfile')
53 |
54 | # resample
55 | if sr_in != sr:
56 | resample_tf = T.Resample(sr_in, sr)
57 | audio = resample_tf(audio)[..., :n_samples]
58 |
59 | # zero pad
60 | L = audio.shape[-1]
61 | if L < n_samples:
62 | audio = F.pad(audio, (0, n_samples - L), value=0.)
63 |
64 | return audio
65 |
--------------------------------------------------------------------------------
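
A minimal usage sketch of the two helpers above; the audio path is a placeholder (assumes `src/` is on `PYTHONPATH` and soundfile is available for torchaudio):

from data.audio_io import get_audio_metadata, load_audio_with_pad  # assuming src/ is on PYTHONPATH

path = "/path/to/audio/example.flac"          # placeholder
info = get_audio_metadata(path, cache=False)  # {'sample_rate': ..., 'num_frames': ..., 'num_channels': ...}

# Load a 1-second chunk at 24 kHz starting at sample offset 0; shorter files are zero-padded.
audio = load_audio_with_pad(path, info, sr=24000, n_samples=24000, offset=0)
print(audio.shape)  # (num_channels, 24000)
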
/configs/model/miipher-2_google-usm_wavefit-5_noisy-input.yaml:
--------------------------------------------------------------------------------
1 | # The Miipher-2 model consists of a Google-USM feature cleaner and a WaveFit-5 vocoder.
2 | #
3 | # Input sample rate: 16000 Hz (-> mel-spectrogram)
4 | # Output sample rate: 24000 hz
5 |
6 | defaults:
7 | - feature_cleaner: google-usm
8 |
9 | _target_: model.Miipher2
10 |
11 | mode: noisy_input
12 | upsample_factor: 4
13 | upsample_mode: "linear"
14 | gradient_checkpointing: true
15 |
16 | feature_cleaner_ckpt_dir: ???
17 | vocoder_ckpt_dir: ???
18 |
19 | vocoder:
20 | _target_: model.WaveFit
21 |   # Number of WaveFit iterations (e.g. WaveFit-5 -> 5)
22 | num_iteration: 5
23 | # Gain of target audio
24 | target_gain: 0.9
25 |
26 | # Pre-network (Conformer blocks)
27 | num_conformer_blocks: 4
28 | args_conformer:
29 | # NOTE: These settings are based on Gemma3n audio encoder settings.
30 | # https://huggingface.co/google/gemma-3n-E2B-it/blob/main/config.json
31 | conf_attention_chunk_size: 12
32 | conf_attention_context_left: 13
33 | conf_attention_context_right: 0
34 | conf_attention_logit_cap: 50.0
35 | conf_conv_kernel_size: 5
36 | conf_num_attention_heads: 8
37 | conf_reduction_factor: 4
38 | conf_residual_weight: 0.5
39 | gradient_clipping: 10000000000.0
40 | hidden_size: 1536
41 | rms_norm_eps: 1e-06
42 |
43 | # Generator
44 | args_generator:
45 | dim_feat: 1536
46 | upsample_factors: [5, 4, 3, 2, 2]
47 | upsample_channels: [512, 512, 256, 128, 128]
48 | downsample_channels: [128, 128, 256, 512]
49 |
50 | discriminator:
51 | _target_: model.Discriminator
52 | msd_kwargs:
53 | num_D: 3
54 | ndf: 16
55 | n_layers: 4
56 | downsampling_factor: 4
57 | mpd_kwargs:
58 | periods: [2, 3, 5, 7, 11, 13, 17, 19]
59 |
60 | mrstft:
61 | # Parallel WaveGAN setting (Sec.2.3 in the Miipher paper)
62 | _target_: model.MRSTFTLoss
63 | n_ffts: [512, 1024, 2048]
64 | win_sizes: [240, 600, 1200]
65 | hop_sizes: [50, 120, 240]
66 |
67 | loss_lambdas:
68 | mrstft_sc_loss: 2.5
69 | mrstft_mag_loss: 2.5
70 | disc_gan_loss: 1.0
71 | disc_feat_loss: 10.0
--------------------------------------------------------------------------------
/src/model/feature_cleaner/base.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 |
4 | -----------------------------------------------------
5 | A base class of audio encoder adapters.
6 | """
7 | from abc import ABC, abstractmethod
8 |
9 | import torch
10 | from torch import nn
11 |
12 |
13 | class AudioEncoderAdapter(ABC, nn.Module):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | @abstractmethod
18 | def forward(self, x: torch.Tensor, encoder_only: bool = False) -> torch.Tensor:
19 | pass
20 |
21 | @abstractmethod
22 | def save_state_dict(self, path: str):
23 | pass
24 |
25 | @abstractmethod
26 | def load_state_dict(self, path: str):
27 | pass
28 |
29 | @abstractmethod
30 | def get_state_dict(self) -> dict:
31 | pass
32 |
33 | def train_step(
34 | self,
35 | x_tgt: torch.Tensor,
36 | x_deg: torch.Tensor,
37 | loss_lambda: dict = {
38 | 'l1': 1.0,
39 | 'l2': 1.0,
40 | 'spectral_convergence': 1.0
41 | },
42 | train: bool = True
43 | ) -> dict:
44 | """
45 | Loss computation for feature cleaners defined in the Miipher paper.
46 | https://arxiv.org/abs/2303.01664
47 | """
48 | self.train() if train else self.eval()
49 | assert x_tgt.shape == x_deg.shape, f"Shapes of target and degraded features must be the same: {x_tgt.shape} != {x_deg.shape}"
50 |
51 | with torch.no_grad():
52 | feats_tgt = self.forward(x_tgt, encoder_only=True)
53 |
54 | with torch.set_grad_enabled(train):
55 | feats_deg = self.forward(x_deg, encoder_only=False)
56 |
57 | # loss
58 | l1_loss = (feats_deg - feats_tgt).abs().mean() * loss_lambda['l1']
59 | l2_loss = (feats_deg - feats_tgt).pow(2.0).mean() * loss_lambda['l2']
60 | sc_loss = l2_loss / (feats_tgt.pow(2.0).mean() + 1e-9) * loss_lambda['spectral_convergence']
61 |
62 | loss = l1_loss + l2_loss + sc_loss
63 |
64 | return {
65 | 'loss': loss,
66 | 'l1_loss': l1_loss.detach(),
67 | 'l2_loss': l2_loss.detach(),
68 | 'sc_loss': sc_loss.detach()
69 | }
70 |
--------------------------------------------------------------------------------
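
For reference, the total loss computed in `train_step` above (with f_tgt / f_deg the encoder features of the clean and degraded inputs, the lambdas the entries of `loss_lambda`, and ε = 1e-9) can be written as:

\mathcal{L}
= \lambda_{\mathrm{l1}} \, \mathbb{E}\big[\,\lvert f_{\mathrm{deg}} - f_{\mathrm{tgt}} \rvert\,\big]
+ \lambda_{\mathrm{l2}} \, \mathbb{E}\big[(f_{\mathrm{deg}} - f_{\mathrm{tgt}})^{2}\big]
+ \lambda_{\mathrm{sc}} \, \frac{\mathbb{E}\big[(f_{\mathrm{deg}} - f_{\mathrm{tgt}})^{2}\big]}{\mathbb{E}\big[f_{\mathrm{tgt}}^{2}\big] + \varepsilon}

where E[.] denotes the mean over all feature elements (batch, frames and channels).
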
/src/model/feature_cleaner/parallel_adapter.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 |
4 | Adapted from the following repo's code under Apache-2.0 License.
5 | https://github.com/jxhe/unify-parameter-efficient-tuning/
6 |
7 | -----------------------------------------------------
8 | Parallel adapter for LLMs.
9 | """
10 | import math
11 | import typing as tp
12 |
13 | import torch
14 | import torch.nn as nn
15 |
16 |
17 | def init_bert_weights(module):
18 | """Initialize the weights."""
19 | if isinstance(module, (nn.Linear, nn.Embedding)):
20 | # std defaults to 0.02, this might need to be changed
21 | module.weight.data.normal_(mean=0.0, std=0.02)
22 | elif isinstance(module, nn.LayerNorm):
23 | module.bias.data.zero_()
24 | module.weight.data.fill_(1.0)
25 | if isinstance(module, nn.Linear) and module.bias is not None:
26 | module.bias.data.zero_()
27 |
28 |
29 | class AdapterLayer(nn.Module):
30 | def __init__(
31 | self,
32 | dim_in: int,
33 | dim_bottleneck: int,
34 | dropout: float = 0.0,
35 | init_option: str = "bert",
36 | adapter_scalar: tp.Union[float, str] = 1.0,
37 | pre_ln_class=None
38 | ):
39 | super().__init__()
40 |
41 | self.n_embd = dim_in
42 | self.down_size = dim_bottleneck
43 |
44 | # Layer normalization options
45 | self.use_pre_ln = pre_ln_class is not None
46 | if self.use_pre_ln:
47 | self.pre_ln = pre_ln_class(dim_in)
48 |
49 | # PA modules
50 | self.down_proj = nn.Linear(self.n_embd, self.down_size)
51 | self.non_linear_func = nn.ReLU()
52 | self.up_proj = nn.Linear(self.down_size, self.n_embd)
53 | self.scale = nn.Parameter(torch.ones(1)) if adapter_scalar == "learnable_scalar" else float(adapter_scalar)
54 | self.dropout = dropout
55 |
56 | # Initialization options
57 | if init_option == "bert":
58 | self.apply(init_bert_weights)
59 | elif init_option == "lora":
60 | with torch.no_grad():
61 | nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
62 | nn.init.zeros_(self.up_proj.weight)
63 | nn.init.zeros_(self.down_proj.bias)
64 | nn.init.zeros_(self.up_proj.bias)
65 | else:
66 | raise ValueError(f"Unknown initialization option: {init_option}")
67 |
68 | def forward(self, x):
69 | if self.use_pre_ln:
70 | x = self.pre_ln(x)
71 |
72 | down = self.down_proj(x)
73 | down = self.non_linear_func(down)
74 | down = nn.functional.dropout(down, p=self.dropout, training=self.training)
75 | up = self.up_proj(down)
76 | up = up * self.scale
77 |
78 | return up
79 |
--------------------------------------------------------------------------------
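
A minimal sketch of how an `AdapterLayer` output is intended to be combined in the parallel-adapter scheme (the actual wiring into the USM encoder lives in `google_usm.py`; the tensors below are dummies and `src/` is assumed to be on `PYTHONPATH`):

import torch

from model.feature_cleaner.parallel_adapter import AdapterLayer  # assuming src/ is on PYTHONPATH

adapter = AdapterLayer(dim_in=1536, dim_bottleneck=1024, init_option="bert", adapter_scalar=1.0)

hidden = torch.randn(2, 25, 1536)      # (batch, frames, dim): input to a frozen encoder sub-layer
block_out = torch.randn(2, 25, 1536)   # stand-in for that sub-layer's output
out = block_out + adapter(hidden)      # parallel adapter: added to the frozen branch's output
print(out.shape)                       # torch.Size([2, 25, 1536])
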
/src/utils/audio_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2024 Yukara Ikemiya
3 | """
4 |
5 | import math
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from librosa.filters import mel as librosa_mel
11 |
12 | EPS = 1e-10
13 |
14 |
15 | def get_amplitude_spec(x, n_fft, win_size, hop_size, window, return_power: bool = False):
16 | stft_spec = torch.stft(
17 | x, n_fft, hop_length=hop_size, win_length=win_size, window=window,
18 | center=True, normalized=False, onesided=True, return_complex=True)
19 |
20 | power_spec = torch.view_as_real(stft_spec).pow(2).sum(-1)
21 |
22 | return power_spec if return_power else torch.sqrt(power_spec + EPS)
23 |
24 |
25 | class MelSpectrogram(nn.Module):
26 | def __init__(
27 | self,
28 | sr: int,
29 | # STFT setting
30 | n_fft: int, win_size: int, hop_size: int,
31 | # MelSpec setting
32 | n_mels: int, fmin: float, fmax: float,
33 | ):
34 | super().__init__()
35 |
36 | self.sr = sr
37 | self.n_fft = n_fft
38 | self.win_size = win_size
39 | self.hop_size = hop_size
40 | self.n_mels = n_mels
41 | self.fmin = fmin
42 | self.fmax = fmax
43 |
44 | mel_basis = librosa_mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
45 | mel_basis = torch.from_numpy(mel_basis).float()
46 | mel_inv_basis = torch.linalg.pinv(mel_basis)
47 |
48 | self.register_buffer('fft_win', torch.hann_window(win_size))
49 | self.register_buffer('mel_basis', mel_basis)
50 | self.register_buffer('mel_inv_basis', mel_inv_basis)
51 |
52 | def compute_mel(self, x: torch.Tensor):
53 | """
54 | Compute Mel-spectrogram.
55 |
56 | Args:
57 | x: time_signal, (bs, length)
58 | Returns:
59 | mel_spec: Mel spectrogram, (bs, n_mels, num_frame)
60 | """
61 | assert x.dim() == 2
62 | L = x.shape[-1]
63 |         # NOTE : To avoid a length mismatch in the final STFT frame between training and inference time,
64 |         #        the input signal length must be a multiple of hop_size.
65 |         assert L % self.hop_size == 0, f"Input signal length must be a multiple of hop_size {self.hop_size}. " + \
66 | f"Input shape -> {x.shape}"
67 |
68 | num_frame = L // self.hop_size
69 |
70 | # STFT
71 | stft_spec = get_amplitude_spec(x, self.n_fft, self.win_size, self.hop_size, self.fft_win)
72 |
73 | # Mel Spec
74 | mel_spec = torch.matmul(self.mel_basis, stft_spec)
75 |
76 | # NOTE : The last frame is removed here.
77 | # When using center=True setting, output from torch.stft has frame length of (L//hopsize+1).
78 | # For training WaveGrad-based architecture, the frame length must be (L//hopsize).
79 | # There might be a better way, but I believe this has little to no impact on training
80 | # since the whole signal information is contained in the previous frames even when removing the last one.
81 | mel_spec = mel_spec[..., :num_frame]
82 |
83 | return mel_spec
84 |
--------------------------------------------------------------------------------
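
A small usage sketch of `MelSpectrogram` with dummy input (the settings match the `MELMAELoss` defaults in `src/model/wavefit/loss/mrstft.py`, and the length 24000 is a multiple of `hop_size`; assumes `src/` is on `PYTHONPATH`):

import torch

from utils.audio_utils import MelSpectrogram  # assuming src/ is on PYTHONPATH

mel = MelSpectrogram(sr=24000, n_fft=1024, win_size=900, hop_size=150,
                     n_mels=128, fmin=20., fmax=12000.)

x = 0.1 * torch.randn(2, 24000)   # (bs, length), length is a multiple of hop_size=150
spec = mel.compute_mel(x)
print(spec.shape)                 # torch.Size([2, 128, 160]) -> (bs, n_mels, length // hop_size)
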
/src/model/wavefit/loss/mrstft.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2024 Yukara Ikemiya
3 | """
4 |
5 | import typing as tp
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from utils.audio_utils import get_amplitude_spec, MelSpectrogram
12 |
13 |
14 | class MRSTFTLoss(nn.Module):
15 | """
16 |     Multi-resolution STFT loss, corresponding to Eq. (9)
17 | """
18 |
19 | def __init__(
20 | self,
21 | n_ffts: tp.List[int] = [512, 1024, 2048],
22 | win_sizes: tp.List[int] = [360, 900, 1800],
23 | hop_sizes: tp.List[int] = [80, 150, 300],
24 | EPS: float = 1e-5
25 | ):
26 | super().__init__()
27 | assert len(n_ffts) == len(win_sizes) == len(hop_sizes)
28 | self.n_ffts = n_ffts
29 | self.win_sizes = win_sizes
30 | self.hop_sizes = hop_sizes
31 |
32 | # NOTE: Since spectral convergence is quite sensitive to small values in the spectrum,
33 | # I believe setting a higher lower bound will result in more stable training.
34 | self.EPS = EPS
35 |
36 | def forward(
37 | self,
38 | pred: torch.Tensor,
39 | target: torch.Tensor
40 | ):
41 | losses = {
42 | 'mrstft_sc_loss': 0.,
43 | 'mrstft_mag_loss': 0.
44 | }
45 |
46 | for n_fft, win_size, hop_size in zip(self.n_ffts, self.win_sizes, self.hop_sizes):
47 | window = torch.hann_window(win_size, device=pred.device)
48 | spec_t = get_amplitude_spec(target.squeeze(1), n_fft, win_size, hop_size, window)
49 | spec_p = get_amplitude_spec(pred.squeeze(1), n_fft, win_size, hop_size, window)
50 |
51 | # spectral convergence
52 | sc_loss = (spec_t - spec_p).norm(p=2) / (spec_t.norm(p=2) + self.EPS)
53 |
54 | # magnitude loss
55 | mag_loss = F.l1_loss(torch.log(spec_t.clamp(min=self.EPS)), torch.log(spec_p.clamp(min=self.EPS)))
56 |
57 | losses['mrstft_sc_loss'] += sc_loss
58 | losses['mrstft_mag_loss'] += mag_loss
59 |
60 | losses['mrstft_sc_loss'] /= len(self.n_ffts)
61 | losses['mrstft_mag_loss'] /= len(self.n_ffts)
62 |
63 | return losses
64 |
65 |
66 | class MELMAELoss(nn.Module):
67 | """
68 |     MAE (L1) loss on the Mel spectrogram, corresponding to the second term of Eq. (19)
69 | """
70 |
71 | def __init__(
72 | self,
73 | sr: int = 24000,
74 | n_fft: int = 1024,
75 | win_size: int = 900,
76 | hop_size: int = 150,
77 | n_mels: int = 128,
78 | fmin: float = 20.,
79 | fmax: float = 12000.
80 | ):
81 | super().__init__()
82 |
83 | self.mel = MelSpectrogram(sr, n_fft, win_size, hop_size, n_mels, fmin, fmax)
84 |
85 | def forward(
86 | self,
87 | pred: torch.Tensor,
88 | target: torch.Tensor
89 | ):
90 | losses = {'mel_mae_loss': 0.}
91 |
92 | # Mel MAE (L1) loss
93 | mel_p = self.mel.compute_mel(pred.squeeze(1))
94 | mel_t = self.mel.compute_mel(target.squeeze(1))
95 |
96 | losses['mel_mae_loss'] = F.l1_loss(mel_t, mel_p)
97 |
98 | return losses
99 |
--------------------------------------------------------------------------------
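
A minimal usage sketch of the two losses with dummy waveforms (assumes `src/` is on `PYTHONPATH` and the repo dependencies such as transformers and librosa are installed):

import torch

from model.wavefit.loss import MRSTFTLoss, MELMAELoss  # assuming src/ is on PYTHONPATH

mrstft = MRSTFTLoss(n_ffts=[512, 1024, 2048], win_sizes=[240, 600, 1200], hop_sizes=[50, 120, 240])
mel_mae = MELMAELoss(sr=24000)

pred = 0.1 * torch.randn(2, 1, 24000)    # (bs, 1, L) dummy waveforms
target = 0.1 * torch.randn(2, 1, 24000)

losses = {**mrstft(pred, target), **mel_mae(pred, target)}
print({k: v.item() for k, v in losses.items()})  # mrstft_sc_loss, mrstft_mag_loss, mel_mae_loss
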
/dataset/script/make_metadata_csv.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 |
5 | import pandas as pd
6 | import torchaudio
7 |
8 |
9 | def fast_scandir(dir: str, exts: list = ['wav', 'flac']) -> tuple:
10 |     """ A very fast `glob` alternative, adapted from https://stackoverflow.com/a/59803793/4259243
11 |
12 | fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
13 |
14 | Args:
15 | dir (str): top-level directory at which to begin scanning.
16 | exts (tp.List[str]): list of allowed file extensions.
17 | """
18 | subfolders, files = [], []
19 | # add starting period to extensions if needed
20 | exts = ['.' + x if x[0] != '.' else x for x in exts]
21 |
22 | try: # hope to avoid 'permission denied' by this try
23 | for f in os.scandir(dir):
24 | try: # 'hope to avoid too many levels of symbolic links' error
25 | if f.is_dir():
26 | subfolders.append(f.path)
27 | elif f.is_file():
28 | is_hidden = os.path.basename(f.path).startswith(".")
29 | has_ext = os.path.splitext(f.name)[1].lower() in exts
30 |
31 | if has_ext and (not is_hidden):
32 | files.append(f.path)
33 | except Exception:
34 | pass
35 | except Exception:
36 | pass
37 |
38 | for dir in list(subfolders):
39 | sf, f = fast_scandir(dir, exts)
40 | subfolders.extend(sf)
41 | files.extend(f)
42 |
43 | return subfolders, files
44 |
45 |
46 | def get_audio_metadata(filepath):
47 | info_ = torchaudio.info(filepath)
48 | sample_rate = info_.sample_rate
49 | num_channels = info_.num_channels
50 | num_frames = info_.num_frames
51 |
52 | info = {
53 | 'sample_rate': sample_rate,
54 | 'num_frames': num_frames,
55 | 'num_channels': num_channels
56 | }
57 |
58 | return info
59 |
60 |
61 | def make_metadata_csv(dir: str, exts=['wav', 'flac']):
62 |
63 | csv_path = os.path.join(dir, "metadata.csv")
64 | rows = []
65 |
66 | _, files = fast_scandir(dir, exts=exts)
67 | print(f"Found {len(files)} audio files in {dir}")
68 |
69 | for p in tqdm(files):
70 | info = get_audio_metadata(p)
71 | rel_path = os.path.relpath(p, dir)
72 | row = {
73 | "file_path": rel_path,
74 | "sample_rate": info["sample_rate"],
75 | "num_frames": info["num_frames"],
76 | "num_channels": info["num_channels"]
77 | }
78 | rows.append(row)
79 |
80 | # print(row)
81 |
82 | rows.sort(key=lambda x: x["file_path"])
83 | df = pd.DataFrame(rows)
84 | df.to_csv(csv_path, index=False)
85 |
86 | print(f"Saved metadata to {csv_path}")
87 |
88 |
89 | def main():
90 | args = argparse.ArgumentParser()
91 | args.add_argument('--root-dir', type=str, required=True, help="A root directory of audio dataset.")
92 | args = args.parse_args()
93 |
94 | root_dir = args.root_dir
95 |
96 | print(root_dir)
97 | make_metadata_csv(root_dir)
98 |
99 |
100 | if __name__ == "__main__":
101 | main()
102 |
--------------------------------------------------------------------------------
/src/data/degradation/reverb.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 | """
4 | import typing as tp
5 | import random
6 |
7 | import torch
8 | import numpy as np
9 | from scipy.signal import butter, filtfilt, fftconvolve
10 | import pyroomacoustics as pra
11 |
12 | from .base import Degradation
13 |
14 |
15 | class RIRReverb(Degradation):
16 | def __init__(
17 | self,
18 | sample_rate: int = 16000,
19 | rt60_range: tuple = (0.2, 0.5),
20 | xyz_range: tuple = ((2., 10.), (2., 10.), (2., 5.)),
21 | mic_margin: float = 0.5,
22 | src_margin: float = 0.1,
23 | # cutoff freq for RIR filter (highpass)
24 | hp_cutoff_hz: float = 20.0
25 | ):
26 | super().__init__(sample_rate=sample_rate)
27 | self.rt60_range = rt60_range
28 | self.xyz_range = xyz_range
29 | self.mic_margin = mic_margin
30 | self.src_margin = src_margin
31 | self.hp_cutoff_hz = hp_cutoff_hz
32 |
33 | # pre-compute highpass filter
34 | nyq = sample_rate / 2.
35 | norm_cutoff = self.hp_cutoff_hz / nyq
36 | self.hp_b, self.hp_a = butter(4, norm_cutoff, btype="high")
37 |
38 | def _sample_room_params(self):
39 | rt60 = random.uniform(self.rt60_range[0], self.rt60_range[1])
40 | Lx = random.uniform(self.xyz_range[0][0], self.xyz_range[0][1])
41 | Ly = random.uniform(self.xyz_range[1][0], self.xyz_range[1][1])
42 | Lz = random.uniform(self.xyz_range[2][0], self.xyz_range[2][1])
43 | mic_pos = [
44 | random.uniform(self.mic_margin, Lx - self.mic_margin),
45 | random.uniform(self.mic_margin, Ly - self.mic_margin),
46 | random.uniform(1.2, min(2.0, Lz - self.mic_margin))
47 | ]
48 | src_pos = [
49 | random.uniform(self.src_margin, Lx - self.src_margin),
50 | random.uniform(self.src_margin, Ly - self.src_margin),
51 | random.uniform(self.src_margin, Lz - self.src_margin)
52 | ]
53 |
54 | return rt60, [Lx, Ly, Lz], mic_pos, src_pos
55 |
56 | def _generate_rir(
57 | self, rt60: float, room_size: tp.List[float],
58 | mic_pos: tp.List[float], src_pos: tp.List[float]
59 | ):
60 | e_absorption, max_order = pra.inverse_sabine(rt60, room_size)
61 | room = pra.ShoeBox(
62 | room_size,
63 | fs=self.sample_rate,
64 | materials=pra.Material(e_absorption),
65 | max_order=max_order,
66 | )
67 | room.add_microphone_array(
68 | pra.MicrophoneArray(np.array(mic_pos).reshape(3, 1), self.sample_rate)
69 | )
70 | room.add_source(src_pos)
71 | room.compute_rir()
72 |
73 | rir = room.rir[0][0]
74 | L_rir = len(rir)
75 |
76 | # highpass to RIR
77 | rir_hp = filtfilt(self.hp_b, self.hp_a, rir)
78 | rir_hp = rir_hp[:L_rir] # trim to original length
79 |
80 | # scale direct sound to 0 dB
81 | peak_idx = np.argmax(np.abs(rir_hp))
82 | if np.abs(rir_hp[peak_idx]) > 1e-9:
83 | rir_hp /= rir_hp[peak_idx]
84 |
85 | return rir_hp, peak_idx
86 |
87 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
88 | assert x.ndim == 1, f"Input audio must be 1D tensor: {x.shape}"
89 | dtype = x.dtype
90 | L = len(x)
91 |
92 | # sample room parameters
93 | rt60, room_size, mic_pos, src_pos = self._sample_room_params()
94 |
95 | # generate RIR
96 | rir, peak_idx = self._generate_rir(rt60, room_size, mic_pos, src_pos)
97 |
98 | # convolve
99 | x = fftconvolve(x.numpy(), rir, mode='full')
100 | x = torch.from_numpy(x).to(dtype)
101 |
102 | # fix signal shift and length
103 | offset = min(peak_idx, len(x) - L)
104 | x = x[offset:offset + L]
105 |
106 | # avoid clipping
107 | amp = x.abs().max().item()
108 | if amp > 1.0:
109 | x = x / (amp + 1e-8)
110 |
111 | return x
112 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 | """
4 |
5 | import sys
6 | sys.dont_write_bytecode = True
7 |
8 | # DDP
9 | from accelerate import Accelerator, DistributedDataParallelKwargs, DataLoaderConfiguration
10 | from accelerate.utils import ProjectConfiguration
11 |
12 | import torch
13 | import hydra
14 | from hydra.core.hydra_config import HydraConfig
15 | from omegaconf import DictConfig, OmegaConf
16 | from ema_pytorch import EMA
17 |
18 | from utils.torch_common import get_world_size, count_parameters, set_seed
19 | from trainer import Trainer
20 |
21 |
22 | @hydra.main(version_base=None, config_path='../configs/', config_name="default.yaml")
23 | def main(cfg: DictConfig):
24 |
25 | # Update config if ckpt_dir is specified (training resumption)
26 |
27 | if cfg.trainer.ckpt_dir is not None:
28 | overrides = HydraConfig.get().overrides.task
29 | overrides = [e for e in overrides if isinstance(e, str)]
30 | override_conf = OmegaConf.from_dotlist(overrides)
31 | cfg = OmegaConf.merge(cfg, override_conf)
32 |
33 | # Load checkpoint configuration
34 | cfg_ckpt = OmegaConf.load(f'{cfg.trainer.ckpt_dir}/config.yaml')
35 | cfg = OmegaConf.merge(cfg_ckpt, override_conf)
36 |
37 | # HuggingFace Accelerate for distributed training
38 |
39 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
40 | dl_config = DataLoaderConfiguration(split_batches=True)
41 | p_config = ProjectConfiguration(project_dir=cfg.trainer.output_dir)
42 | accel = Accelerator(
43 | mixed_precision=cfg.trainer.amp,
44 | dataloader_config=dl_config,
45 | project_config=p_config,
46 | kwargs_handlers=[ddp_kwargs],
47 | log_with='wandb'
48 | )
49 |
50 | accel.init_trackers(cfg.trainer.logger.project_name, config=OmegaConf.to_container(cfg),
51 | init_kwargs={"wandb": {"name": cfg.trainer.logger.run_name, "dir": cfg.trainer.output_dir}})
52 |
53 | if accel.is_main_process:
54 | print("->->-> DDP Initialized.")
55 | print(f"->->-> World size (Number of GPUs): {get_world_size()}")
56 |
57 | set_seed(cfg.trainer.seed)
58 |
59 | # Dataset
60 |
61 | batch_size = cfg.trainer.batch_size
62 | num_workers = cfg.trainer.num_workers
63 | train_dataset = hydra.utils.instantiate(cfg.data.train)
64 | train_dataloader = torch.utils.data.DataLoader(
65 | train_dataset, batch_size=batch_size, shuffle=True,
66 | num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers > 0))
67 |
68 | # Model
69 |
70 | model = hydra.utils.instantiate(cfg.model)
71 |
72 | # EMA
73 |
74 | ema = None
75 | if accel.is_main_process:
76 | ema = EMA(model, **cfg.trainer.ema)
77 | ema.to(accel.device)
78 |
79 | # Optimizer
80 |
81 | # check if cfg.optimizer has optimizer_d
82 | if 'optimizer_d' in cfg.optimizer:
83 | # discriminator exists
84 | opt = hydra.utils.instantiate(cfg.optimizer.optimizer)(params=model.vocoder.parameters())
85 | sche = hydra.utils.instantiate(cfg.optimizer.scheduler)(optimizer=opt)
86 | opt_d = hydra.utils.instantiate(cfg.optimizer.optimizer_d)(params=model.discriminator.parameters())
87 | sche_d = hydra.utils.instantiate(cfg.optimizer.scheduler_d)(optimizer=opt_d)
88 | else:
89 | opt = hydra.utils.instantiate(cfg.optimizer.optimizer)(params=model.parameters())
90 | sche = hydra.utils.instantiate(cfg.optimizer.scheduler)(optimizer=opt)
91 | opt_d = None
92 | sche_d = None
93 |
94 | # Log
95 |
96 | model.train()
97 | num_params = count_parameters(model) / 1e6
98 | if accel.is_main_process:
99 | print("=== Parameters ===")
100 | print(f"\tModel:\t{num_params:.2f} [million]")
101 | print("=== Dataset ===")
102 | print(f"\tBatch size: {cfg.trainer.batch_size}")
103 | print("\tTrain data:")
104 | print(f"\t\tChunks: {len(train_dataset)}")
105 | print(f"\t\tBatches: {len(train_dataset)//cfg.trainer.batch_size}")
106 |
107 | # Prepare for DDP
108 |
109 | train_dataloader, model, opt, sche, opt_d, sche_d = \
110 | accel.prepare(train_dataloader, model, opt, sche, opt_d, sche_d)
111 |
112 | # Start training
113 |
114 | trainer = Trainer(
115 | model=model,
116 | ema=ema,
117 | optimizer=opt,
118 | scheduler=sche,
119 | optimizer_d=opt_d,
120 | scheduler_d=sche_d,
121 | train_dataloader=train_dataloader,
122 | accel=accel,
123 | cfg=cfg,
124 | ckpt_dir=cfg.trainer.ckpt_dir
125 | )
126 |
127 | trainer.start_training()
128 |
129 |
130 | if __name__ == '__main__':
131 | main()
132 | print("[Training finished.]")
133 |
--------------------------------------------------------------------------------
/src/model/wavefit/wavefit.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 |
4 | -----------------
5 | WaveFit module in Miipher-2.
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from transformers.models.gemma3n.modeling_gemma3n import Gemma3nAudioConformerBlock
12 |
13 | from .generator import WaveFitGenerator
14 | from utils.torch_common import checkpoint
15 |
16 |
17 | class ConformerConfig:
18 | def __init__(
19 | self,
20 | conf_attention_chunk_size: int = 12,
21 | conf_attention_context_left: int = 13,
22 | conf_attention_context_right: int = 0,
23 | conf_attention_logit_cap: float = 50.0,
24 | conf_conv_kernel_size: int = 5,
25 | conf_num_attention_heads: int = 8,
26 | conf_reduction_factor: int = 4,
27 | conf_residual_weight: float = 0.5,
28 | gradient_clipping: float = 10000000000.0,
29 | hidden_size: int = 1536,
30 | rms_norm_eps: float = 1e-06
31 | ):
32 | self.conf_attention_chunk_size = conf_attention_chunk_size
33 | self.conf_attention_context_left = conf_attention_context_left
34 | self.conf_attention_context_right = conf_attention_context_right
35 | self.conf_attention_logit_cap = conf_attention_logit_cap
36 | self.conf_conv_kernel_size = conf_conv_kernel_size
37 | self.conf_num_attention_heads = conf_num_attention_heads
38 | self.conf_reduction_factor = conf_reduction_factor
39 | self.conf_residual_weight = conf_residual_weight
40 | self.gradient_clipping = gradient_clipping
41 | self.hidden_size = hidden_size
42 | self.rms_norm_eps = rms_norm_eps
43 |
44 |
45 | class WaveFit(nn.Module):
46 | def __init__(
47 | self,
48 | num_iteration: int,
49 | target_gain: float = 0.9,
50 | # Pre-network Conformer blocks (Sec.3.2)
51 | num_conformer_blocks: int = 4,
52 | args_conformer: dict = {
53 | "conf_attention_chunk_size": 12,
54 | "conf_attention_context_left": 13,
55 | "conf_attention_context_right": 0,
56 | "conf_attention_logit_cap": 50.0,
57 | "conf_conv_kernel_size": 5,
58 | "conf_num_attention_heads": 8,
59 | "conf_reduction_factor": 4,
60 | "conf_residual_weight": 0.5,
61 | "gradient_clipping": 10000000000.0,
62 | "hidden_size": 1536,
63 | "rms_norm_eps": 1e-06
64 | },
65 | # WaveFit generator
66 | args_generator: dict = {
67 | "dim_feat": 1536,
68 | "upsample_factors": [5, 4, 3, 2, 2],
69 | "upsample_channels": [512, 512, 256, 128, 128],
70 | "downsample_channels": [128, 128, 256, 512],
71 | }
72 | ):
73 | super().__init__()
74 |
75 | self.T = num_iteration
76 | self.target_gain = target_gain
77 |
78 | # Conformer blocks
79 | self.conformer_config = ConformerConfig(**args_conformer)
80 | self.conformer_blocks = nn.ModuleList(
81 | [Gemma3nAudioConformerBlock(self.conformer_config) for _ in range(num_conformer_blocks)]
82 | )
83 |
84 | # Generator
85 | self.generator = WaveFitGenerator(num_iteration, **args_generator)
86 | self.EPS = 1e-8
87 |
88 | @property
89 | def upsample_rate(self) -> int:
90 | return self.generator.upsample_rate
91 |
92 | def forward(
93 | self,
94 | initial_noise: torch.Tensor,
95 | audio_feats: torch.Tensor,
96 | # You can use this option at inference time
97 | return_only_last: bool = False,
98 | # training config
99 | gradient_checkpointing: bool = False
100 | ):
101 | """
102 | Args:
103 | initial_noise: Initial noise, (bs, L).
104 | audio_feats: Audio features, (bs, n_frame, dim).
105 | return_only_last: If true, only the last output (y_0) is returned.
106 | Returns:
107 | preds: List of predictions (y_t)
108 | """
109 | initial_noise = initial_noise.unsqueeze(1) # (bs, 1, L)
110 | assert initial_noise.dim() == audio_feats.dim() == 3
111 | assert initial_noise.size(0) == audio_feats.size(0)
112 |
113 | # Pre-network
114 | mask = torch.zeros(audio_feats.size(0), audio_feats.size(1), dtype=torch.bool, device=audio_feats.device)
115 | for block in self.conformer_blocks:
116 | if gradient_checkpointing and self.training:
117 | audio_feats = checkpoint(block, audio_feats, mask)
118 | else:
119 | audio_feats = block(audio_feats, mask)
120 |
121 | # (bs, n_frame, dim) -> (bs, dim, n_frame)
122 | audio_feats = audio_feats.transpose(1, 2).contiguous()
123 |
124 | preds = []
125 | y_t = initial_noise
126 | for t in range(self.T):
127 | # estimate noise
128 | if gradient_checkpointing and self.training:
129 | est = checkpoint(self.generator, y_t, audio_feats, t)
130 | else:
131 | est = self.generator(y_t, audio_feats, t)
132 |
133 | y_t = y_t - est
134 |
135 | # gain normalization (Sec.2.3 in the Miipher paper)
136 | y_t = self.normalize_gain(y_t)
137 |
138 | if (not return_only_last) or (t == self.T - 1):
139 | preds.append(y_t.squeeze(1))
140 |
141 |             # Detach so that gradients do not propagate across refinement iterations
142 | y_t = y_t.detach()
143 |
144 | return preds
145 |
146 | def normalize_gain(self, z_t: torch.Tensor):
147 | # z_t: (bs, 1, L)
148 | scale = self.target_gain / (z_t.squeeze(1).abs().max(dim=1, keepdim=True)[0][:, :, None] + self.EPS)
149 | return z_t * scale
150 |
--------------------------------------------------------------------------------
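
The class above pairs a small stack of Gemma3n conformer blocks with the iterative WaveFit generator, so the conditioning features must be 1536-dimensional and the waveform length must equal `n_frame * upsample_rate` (240 samples per feature frame with the default factors). A minimal shape-check sketch, not part of the repository, assuming `src/` is on `PYTHONPATH` and the default hyperparameters are kept:

    import torch
    from model.wavefit.wavefit import WaveFit  # assumed import path

    model = WaveFit(num_iteration=5).eval()

    bs, n_frame = 2, 100
    L = n_frame * model.upsample_rate             # 240 samples per frame with factors [5, 4, 3, 2, 2]
    initial_noise = torch.randn(bs, L)            # (bs, L)
    audio_feats = torch.randn(bs, n_frame, 1536)  # (bs, n_frame, dim)

    with torch.no_grad():
        preds = model(initial_noise, audio_feats, return_only_last=True)

    print(len(preds), preds[-1].shape)            # 1 torch.Size([2, 24000])
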
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[codz]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py.cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 | #poetry.toml
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115 | #pdm.lock
116 | #pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # pixi
121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122 | #pixi.lock
123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124 | # in the .venv directory. It is recommended not to include this directory in version control.
125 | .pixi
126 |
127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128 | __pypackages__/
129 |
130 | # Celery stuff
131 | celerybeat-schedule
132 | celerybeat.pid
133 |
134 | # SageMath parsed files
135 | *.sage.py
136 |
137 | # Environments
138 | .env
139 | .envrc
140 | .venv
141 | env/
142 | venv/
143 | ENV/
144 | env.bak/
145 | venv.bak/
146 |
147 | # Spyder project settings
148 | .spyderproject
149 | .spyproject
150 |
151 | # Rope project settings
152 | .ropeproject
153 |
154 | # mkdocs documentation
155 | /site
156 |
157 | # mypy
158 | .mypy_cache/
159 | .dmypy.json
160 | dmypy.json
161 |
162 | # Pyre type checker
163 | .pyre/
164 |
165 | # pytype static type analyzer
166 | .pytype/
167 |
168 | # Cython debug symbols
169 | cython_debug/
170 |
171 | # PyCharm
172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174 | # and can be added to the global gitignore or merged into this file. For a more nuclear
175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176 | #.idea/
177 |
178 | # Abstra
179 | # Abstra is an AI-powered process automation framework.
180 | # Ignore directories containing user credentials, local state, and settings.
181 | # Learn more at https://abstra.io/docs
182 | .abstra/
183 |
184 | # Visual Studio Code
185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187 | # and can be added to the global gitignore or merged into this file. However, if you prefer,
188 | # you could uncomment the following to ignore the entire vscode folder
189 | # .vscode/
190 |
191 | # Ruff stuff:
192 | .ruff_cache/
193 |
194 | # PyPI configuration file
195 | .pypirc
196 |
197 | # Cursor
198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200 | # refer to https://docs.cursor.com/context/ignore-files
201 | .cursorignore
202 | .cursorindexingignore
203 |
204 | # Marimo
205 | marimo/_static/
206 | marimo/_lsp/
207 | __marimo__/
208 |
209 | # other directories
210 | job/
211 | runs/
212 | audio/
213 |
214 | # other files
215 | *.sif
216 | *.log
217 | *.code-workspace
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 | """
4 |
5 | import os
6 | import sys
7 | sys.dont_write_bytecode = True
8 | import argparse
9 | import math
10 |
11 | import hydra
12 | import torch
13 | import torchaudio
14 | from accelerate import Accelerator
15 | from omegaconf import OmegaConf
16 |
17 | from model import Miipher2
18 | from utils.torch_common import get_rank, get_world_size, print_once
19 | from data.dataset import get_audio_info
20 |
21 |
22 | def make_audio_batch(audio, sample_size: int, overlap: int):
23 | """
24 | audio : (ch, L)
25 | """
26 | assert 0 <= overlap < sample_size
27 | L = audio.shape[-1]
28 | shift = sample_size - overlap
29 |
30 | n_split = math.ceil(max(L - sample_size, 0) / shift) + 1
31 | # to mono
32 | audio = audio.mean(0) # (L)
33 | batch = []
34 | for n in range(n_split):
35 | b = audio[n * shift: n * shift + sample_size]
36 | if n == n_split - 1:
37 | b = torch.nn.functional.pad(b, (0, sample_size - len(b)))
38 | batch.append(b)
39 |
40 | batch = torch.stack(batch, dim=0) # (n_split, sample_size)
41 | return batch, L
42 |
43 |
44 | def cross_fade(preds, overlap: int, L: int):
45 | """
46 | preds: (bs, sample_size)
47 | """
48 | bs, sample_size = preds.shape
49 | shift = sample_size - overlap
50 | full_L = sample_size + (bs - 1) * shift
51 | win = torch.bartlett_window(overlap * 2, device=preds.device)
52 |
53 | buf = torch.zeros(full_L, device=preds.device)
54 | pre_overlap = None
55 | for idx in range(bs):
56 | pred = preds[idx] # (sample_size)
57 | ofs = idx * shift
58 | if idx != 0:
59 | # Fix volume
60 |             # NOTE: Since the volume is changed by the gain normalization in the WaveFit module,
61 |             #       it has to be adjusted here to avoid discontinuities at chunk boundaries.
62 | cur_overlap = pred[:overlap]
63 | volume_rescale = (pre_overlap.pow(2).sum() / (cur_overlap.pow(2).sum() + 1e-10)).sqrt()
64 | pred *= volume_rescale
65 |
66 | pred[:overlap] *= win[:overlap]
67 |
68 | if idx != bs - 1:
69 | pre_overlap = pred[-overlap:].clone()
70 | pred[-overlap:] *= win[overlap:]
71 |
72 | buf[ofs:ofs + sample_size] += pred
73 |
74 | buf = buf[:L]
75 |
76 | return buf
77 |
78 |
79 | def main():
80 | parser = argparse.ArgumentParser()
81 | parser.add_argument('--ckpt-dir', type=str, help="Checkpoint directory.")
82 | parser.add_argument('--input-audio-dir', type=str, help="Root directory which contains input audio files.")
83 | parser.add_argument('--output-dir', type=str, help="Output directory.")
84 | parser.add_argument('--sample-size', type=int, default=160000, help="Input sample size.")
85 | parser.add_argument('--sr-in', type=int, default=16000, help="Input sample rate.")
86 | parser.add_argument('--sr-out', type=int, default=24000, help="Output sample rate.")
87 | parser.add_argument('--max-batch-size', type=int, default=10, help="Max batch size for inference.")
88 | parser.add_argument('--overlap-rate', type=float, default=0.05, help="Overlap rate for inference.")
89 |     parser.add_argument('--use-original-name', default=True, type=lambda s: str(s).lower() in ('true', '1', 'yes'), help="Whether to use an original file name as an output name.")
90 | args = parser.parse_args()
91 |
92 | ckpt_dir = args.ckpt_dir
93 | input_audio_dir = args.input_audio_dir
94 | output_dir = args.output_dir
95 | sample_size = args.sample_size
96 | sr_in = args.sr_in
97 | sr_out = args.sr_out
98 | max_batch_size = args.max_batch_size
99 | overlap_rate = args.overlap_rate
100 | use_original_name = args.use_original_name
101 |
102 | # Distributed inference
103 | accel = Accelerator()
104 | device = accel.device
105 | rank = get_rank()
106 | world_size = get_world_size()
107 |
108 | print_once(f"Checkpoint dir : {ckpt_dir}")
109 | print_once(f"Input audio dir : {input_audio_dir}")
110 | print_once(f"Output dir : {output_dir}")
111 |
112 | # Load Miipher-2 model
113 | cfg_ckpt = OmegaConf.load(f'{ckpt_dir}/config.yaml')
114 | # remove discriminator and MRSTFT modules
115 | cfg_ckpt.model.discriminator = None
116 | cfg_ckpt.model.mrstft = None
117 |
118 | model: Miipher2 = hydra.utils.instantiate(cfg_ckpt.model)
119 | model.load_state_dict(ckpt_dir)
120 | model.to(device)
121 | model.eval()
122 | print_once("->-> Successfully loaded Miipher-2 model from checkpoint.")
123 |
124 | overlap_in = int(sample_size * overlap_rate)
125 | overlap_out = int(overlap_in * sr_out / sr_in)
126 | print_once(f"->-> [Sample size] : {sample_size} samples")
127 | print_once(f"->-> [Overlap size]: {overlap_in} samples ({overlap_in / sample_size * 100:.1f} %)")
128 |
129 | # Get audio files
130 | files, _ = get_audio_info([input_audio_dir])
131 | print_once(f"->-> Found {len(files)} audio files from {input_audio_dir}.")
132 | os.makedirs(output_dir, exist_ok=True)
133 |
134 | # Split files for each process
135 | files = files[rank::world_size]
136 |
137 | print_once(f"--- Rank-{rank} : Start inference... ---")
138 |
139 | for idx, f_path in enumerate(files):
140 | # load and split audio
141 | audio, sr = torchaudio.load(f_path)
142 | if sr != sr_in:
143 | audio = torchaudio.functional.resample(audio, sr, sr_in)
144 | sr = sr_in
145 |
146 | audio_batch, L = make_audio_batch(audio, sample_size, overlap_in)
147 | n_iter = math.ceil(audio_batch.shape[0] / max_batch_size)
148 |
149 | audio_batch = audio_batch.to(device)
150 |
151 | # execute
152 | preds = []
153 | for n in range(n_iter):
154 | batch_ = audio_batch[n * max_batch_size:(n + 1) * max_batch_size]
155 | with torch.no_grad():
156 | pred = model.inference(batch_) # (bs, L)
157 |
158 | preds.append(pred)
159 |
160 | preds = torch.cat(preds, dim=0)
161 |
162 | # cross-fade
163 | L_out = int(L * sr_out / sr_in)
164 | pred_audio = cross_fade(preds, overlap_out, L_out).cpu()
165 |
166 | # rescale volume to avoid clipping
167 | pred_audio = pred_audio / (pred_audio.abs().max() + 1e-8) * 0.9
168 |
169 | # save audio
170 | out_name = os.path.splitext(os.path.basename(f_path))[0] if use_original_name else f"sample_{idx}"
171 | out_path = f"{output_dir}/{out_name}.wav"
172 | torchaudio.save(out_path, pred_audio.unsqueeze(0), sample_rate=sr_out, encoding="PCM_F")
173 |
174 | print(f"--- Rank-{rank} : Finished. ---")
175 |
176 |
177 | if __name__ == '__main__':
178 | main()
179 |
--------------------------------------------------------------------------------
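
`make_audio_batch` and `cross_fade` implement the overlap-add chunking used for long files: the audio is cut into `sample_size` windows with a small overlap, and the per-chunk outputs are volume-matched and cross-faded with a Bartlett window. A standalone sketch on a dummy signal, with the model call replaced by an identity; it assumes `src/` is on `PYTHONPATH` and the repository's dependencies are installed so that `inference` is importable:

    import torch
    from inference import make_audio_batch, cross_fade  # assumed importable from src/

    sr = 16000
    audio = torch.randn(2, 5 * sr)                  # (ch, L): 5 s stereo dummy signal
    sample_size = 2 * sr
    overlap = int(sample_size * 0.05)

    batch, L = make_audio_batch(audio, sample_size, overlap)   # (n_split, sample_size), original length
    print(batch.shape, L)                                       # torch.Size([3, 32000]) 80000

    # pretend the model is an identity and stitch the chunks back together
    restored = cross_fade(batch.clone(), overlap, L)            # cross_fade modifies its input in place
    print(restored.shape)                                       # torch.Size([80000])
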
/src/model/wavefit/discriminator.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 |
4 | Adapted from the following repos code under MIT License.
5 | https://github.com/descriptinc/melgan-neurips/
6 | https://github.com/descriptinc/descript-audio-codec
7 | """
8 | from abc import ABC, abstractmethod
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.nn.utils.parametrizations import weight_norm
14 | from einops import rearrange
15 |
16 |
17 | def weights_init(m):
18 | classname = m.__class__.__name__
19 | if classname.find("Conv") != -1:
20 | m.weight.data.normal_(0.0, 0.02)
21 | elif classname.find("BatchNorm2d") != -1:
22 | m.weight.data.normal_(1.0, 0.02)
23 | m.bias.data.fill_(0)
24 |
25 |
26 | def WNConv1d(*args, **kwargs):
27 | return weight_norm(nn.Conv1d(*args, **kwargs))
28 |
29 |
30 | def WNConv2dWithLeakyReLU(*args, **kwargs):
31 | act = kwargs.pop("act", True)
32 | conv = weight_norm(nn.Conv2d(*args, **kwargs))
33 | if not act:
34 | return conv
35 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
36 |
37 |
38 | class MultiDiscriminator(ABC, nn.Module):
39 | @abstractmethod
40 | def forward(self, x: torch.Tensor, return_feature: bool = True):
41 | pass
42 |
43 | def compute_G_loss(self, x_fake, x_real):
44 | """
45 | The eq.(18) loss
46 | """
47 | assert x_fake.shape == x_real.shape
48 |
49 | out_f = self(x_fake, return_feature=True)
50 | with torch.no_grad():
51 | out_r = self(x_real, return_feature=True)
52 |
53 | num_D = len(self.model)
54 | losses = {
55 | 'disc_gan_loss': 0.,
56 | 'disc_feat_loss': 0.
57 | }
58 |
59 | for i_d in range(num_D):
60 | n_layer = len(out_f[i_d])
61 |
62 | # GAN loss
63 | losses['disc_gan_loss'] += (1 - out_f[i_d][-1]).relu().mean()
64 |
65 | # Feature-matching loss
66 | # eq.(8)
67 | feat_loss = 0.
68 | for i_l in range(n_layer - 1):
69 | feat_loss += F.l1_loss(out_f[i_d][i_l], out_r[i_d][i_l])
70 |
71 | losses['disc_feat_loss'] += feat_loss / (n_layer - 1)
72 |
73 | losses['disc_gan_loss'] /= num_D
74 | losses['disc_feat_loss'] /= num_D
75 |
76 | return losses
77 |
78 | def compute_D_loss(self, x, mode: str):
79 | """
80 | The eq.(7) loss
81 | """
82 | assert mode in ['fake', 'real']
83 | sign = 1 if mode == 'fake' else -1
84 |
85 | out = self(x, return_feature=False)
86 |
87 | num_D = len(self.model)
88 | losses = {'loss': 0.}
89 |
90 | for i_d in range(num_D):
91 | # Hinge loss
92 | losses['loss'] += (1 + sign * out[i_d][-1]).relu().mean()
93 |
94 | losses['loss'] /= num_D
95 |
96 | return losses
97 |
98 |
99 | class MSDBlock(nn.Module):
100 | def __init__(self, ndf, n_layers, downsampling_factor):
101 | super().__init__()
102 | model = nn.ModuleDict()
103 |
104 | model["layer_0"] = nn.Sequential(
105 | nn.ReflectionPad1d(7),
106 | WNConv1d(1, ndf, kernel_size=15),
107 | nn.LeakyReLU(0.2, True),
108 | )
109 |
110 | nf = ndf
111 | stride = downsampling_factor
112 | for n in range(1, n_layers + 1):
113 | nf_prev = nf
114 | nf = min(nf * stride, 1024)
115 |
116 | model["layer_%d" % n] = nn.Sequential(
117 | WNConv1d(
118 | nf_prev,
119 | nf,
120 | kernel_size=stride * 10 + 1,
121 | stride=stride,
122 | padding=stride * 5,
123 | groups=nf_prev // 4,
124 | ),
125 | nn.LeakyReLU(0.2, True),
126 | )
127 |
128 | nf = min(nf * 2, 1024)
129 | model["layer_%d" % (n_layers + 1)] = nn.Sequential(
130 | WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2),
131 | nn.LeakyReLU(0.2, True),
132 | )
133 |
134 | model["layer_%d" % (n_layers + 2)] = WNConv1d(
135 | nf, 1, kernel_size=3, stride=1, padding=1
136 | )
137 |
138 | self.model = model
139 |
140 | def forward(self, x: torch.Tensor, return_feature: bool = True):
141 | """
142 | Args:
143 | x: input audio, (bs, 1, L)
144 | """
145 | n_layer = len(self.model)
146 | results = []
147 | for idx, (key, layer) in enumerate(self.model.items()):
148 | x = layer(x)
149 | if return_feature or (idx == n_layer - 1):
150 | results.append(x)
151 |
152 | return results
153 |
154 |
155 | class MSD(MultiDiscriminator):
156 | """ Multi-scale discriminator """
157 |
158 | def __init__(
159 | self,
160 | num_D: int = 3,
161 | ndf: int = 16,
162 | n_layers: int = 4,
163 | downsampling_factor: int = 4
164 | ):
165 | super().__init__()
166 | self.model = nn.ModuleDict()
167 | for i in range(num_D):
168 | self.model[f"disc_{i}"] = MSDBlock(
169 | ndf, n_layers, downsampling_factor
170 | )
171 |
172 | self.downsample = nn.AvgPool1d(4, stride=2, padding=1, count_include_pad=False)
173 | self.apply(weights_init)
174 |
175 | def forward(self, x: torch.Tensor, return_feature: bool = True):
176 | """
177 | Args:
178 | x: input audio, (bs, 1, L)
179 | """
180 | results = []
181 | for key, disc in self.model.items():
182 | results.append(disc(x, return_feature))
183 | x = self.downsample(x)
184 |
185 | return results
186 |
187 |
188 | class MPDBlock(nn.Module):
189 | def __init__(self, period: int):
190 | super().__init__()
191 | self.period = period
192 | self.convs = nn.ModuleList(
193 | [
194 | WNConv2dWithLeakyReLU(1, 64, (5, 1), (3, 1), padding=(2, 0)),
195 | WNConv2dWithLeakyReLU(64, 128, (5, 1), (3, 1), padding=(2, 0)),
196 | WNConv2dWithLeakyReLU(128, 256, (5, 1), (3, 1), padding=(2, 0)),
197 | WNConv2dWithLeakyReLU(256, 512, (5, 1), (3, 1), padding=(2, 0)),
198 | WNConv2dWithLeakyReLU(512, 1024, (5, 1), 1, padding=(2, 0)),
199 | ]
200 | )
201 | self.conv_post = WNConv2dWithLeakyReLU(
202 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
203 | )
204 |
205 | def pad_to_period(self, x):
206 | t = x.shape[-1]
207 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
208 | return x
209 |
210 | def forward(self, x, return_feature: bool = True):
211 | results = []
212 |
213 | x = self.pad_to_period(x)
214 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
215 |
216 | for layer in self.convs:
217 | x = layer(x)
218 | if return_feature:
219 | results.append(x)
220 |
221 | x = self.conv_post(x)
222 | results.append(x)
223 |
224 | return results
225 |
226 |
227 | class MPD(MultiDiscriminator):
228 | """ Multi-period discriminator """
229 |
230 | def __init__(self, periods=[2, 3, 5, 7, 11, 13, 17, 19]):
231 | super().__init__()
232 | self.model = nn.ModuleDict()
233 | for p in periods:
234 | self.model[f"disc_{p}"] = MPDBlock(p)
235 |
236 | self.apply(weights_init)
237 |
238 | def forward(self, x: torch.Tensor, return_feature: bool = True):
239 | """
240 | Args:
241 | x: input audio, (bs, 1, L)
242 | """
243 | results = []
244 | for key, disc in self.model.items():
245 | results.append(disc(x, return_feature))
246 |
247 | return results
248 |
249 |
250 | class Discriminator(nn.Module):
251 | def __init__(
252 | self,
253 | msd_kwargs: dict = {
254 | "num_D": 3,
255 | "ndf": 16,
256 | "n_layers": 4,
257 | "downsampling_factor": 4
258 | },
259 | mpd_kwargs: dict = {
260 | "periods": [2, 3, 5, 7, 11, 13, 17, 19]
261 | }
262 | ):
263 | super().__init__()
264 | self.msd = MSD(**msd_kwargs)
265 | self.mpd = MPD(**mpd_kwargs)
266 |
267 | def compute_G_loss(self, x_fake, x_real):
268 | losses = {}
269 | losses_msd = self.msd.compute_G_loss(x_fake, x_real)
270 | losses_mpd = self.mpd.compute_G_loss(x_fake, x_real)
271 | for k in losses_msd.keys():
272 | losses[k] = (losses_msd[k] + losses_mpd[k]) / 2
273 | losses[f"mpd-{k}"] = losses_mpd[k].detach()
274 | losses[f"msd-{k}"] = losses_msd[k].detach()
275 |
276 | return losses
277 |
278 | def compute_D_loss(self, x, mode: str):
279 | losses = {}
280 | losses_msd = self.msd.compute_D_loss(x, mode)
281 | losses_mpd = self.mpd.compute_D_loss(x, mode)
282 | for k in losses_msd.keys():
283 | losses[k] = (losses_msd[k] + losses_mpd[k]) / 2
284 | losses[f"mpd-{k}"] = losses_mpd[k].detach()
285 | losses[f"msd-{k}"] = losses_msd[k].detach()
286 |
287 | return losses
288 |
--------------------------------------------------------------------------------
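
`Discriminator` averages the hinge-GAN and feature-matching losses of the multi-scale (MSD) and multi-period (MPD) branches. A minimal usage sketch on random waveforms, assuming `src/` is on `PYTHONPATH`:

    import torch
    from model.wavefit.discriminator import Discriminator  # assumed import path

    disc = Discriminator()
    real = torch.randn(2, 1, 24000)   # (bs, 1, L) target waveform
    fake = torch.randn(2, 1, 24000)   # e.g. a WaveFit prediction of the same shape

    # Generator-side losses: hinge GAN term + feature-matching term
    g_losses = disc.compute_G_loss(fake, real)
    print(g_losses['disc_gan_loss'].item(), g_losses['disc_feat_loss'].item())

    # Discriminator-side hinge losses on real and (detached) fake audio
    d_loss = disc.compute_D_loss(real, mode='real')['loss'] \
        + disc.compute_D_loss(fake.detach(), mode='fake')['loss']
    print(d_loss.item())
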
/src/data/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 | """
4 |
5 | import os
6 | import random
7 | import typing as tp
8 | import csv
9 |
10 | import torch
11 | import numpy as np
12 |
13 | from utils.torch_common import print_once, exists
14 | from .audio_io import get_audio_metadata, load_audio_with_pad
15 | from .modification import Mono, PhaseFlipper, VolumeChanger
16 | from .pretransform import GemmaAudioFeature
17 | from .degradation import AudioClipping, NoiseAddition, RIRReverb, AudioLowpass
18 |
19 |
20 | def fast_scandir(dir: str, ext: tp.List[str]):
21 | """ Very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243
22 |
23 | fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
24 |
25 | Args:
26 | dir (str): top-level directory at which to begin scanning.
27 | ext (tp.List[str]): list of allowed file extensions.
28 | """
29 | subfolders, files = [], []
30 | # add starting period to extensions if needed
31 | ext = ['.' + x if x[0] != '.' else x for x in ext]
32 |
33 | try: # hope to avoid 'permission denied' by this try
34 | for f in os.scandir(dir):
35 | try: # 'hope to avoid too many levels of symbolic links' error
36 | if f.is_dir():
37 | subfolders.append(f.path)
38 | elif f.is_file():
39 | is_hidden = os.path.basename(f.path).startswith(".")
40 | has_ext = os.path.splitext(f.name)[1].lower() in ext
41 |
42 | if has_ext and (not is_hidden):
43 | files.append(f.path)
44 | except Exception:
45 | pass
46 | except Exception:
47 | pass
48 |
49 | for dir in list(subfolders):
50 | sf, f = fast_scandir(dir, ext)
51 | subfolders.extend(sf)
52 | files.extend(f)
53 |
54 | return subfolders, files
55 |
56 |
57 | def get_info_from_csv(csv_path: str, filepath_tag: str = 'file_path',
58 | other_info_tags: tp.List[str] = ['sample_rate', 'num_frames', 'num_channels']):
59 | file_paths = []
60 | meta_dicts = []
61 | with open(csv_path, 'r', newline='') as f:
62 |         reader = csv.DictReader(f)  # read each row as a dict
63 | for row in reader:
64 | file_paths.append(row[filepath_tag])
65 | meta = {k: int(row[k]) for k in other_info_tags}
66 | meta_dicts.append(meta)
67 |
68 | # sort by file path
69 | sorted_indices = np.argsort(file_paths)
70 | file_paths = [file_paths[i] for i in sorted_indices]
71 | meta_dicts = [meta_dicts[i] for i in sorted_indices]
72 |
73 | return file_paths, meta_dicts
74 |
75 |
76 | def get_audio_info(
77 | paths: tp.List[str], # directories in which to search
78 | exts: tp.List[str] = ['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
79 | ):
80 | """recursively get a list of audio filenames"""
81 | if isinstance(paths, str):
82 | paths = [paths]
83 |
84 | # get a list of relevant filenames
85 | filepaths = []
86 | metas = []
87 | for p in paths:
88 | metadata_csv_path = f"{p}/metadata.csv"
89 | if os.path.exists(metadata_csv_path):
90 |             # If metadata.csv exists, reading the metadata from it is faster than scanning and probing each file
91 | filepaths_, metas_ = get_info_from_csv(metadata_csv_path)
92 | else:
93 | _, filepaths_ = fast_scandir(p, exts)
94 | filepaths_.sort()
95 | metas_ = []
96 | for filepath in filepaths_:
97 | info = get_audio_metadata(filepath, cache=True)
98 | metas_.append(info)
99 |
100 | filepaths_ = [os.path.join(p, f) for f in filepaths_]
101 | filepaths.extend(filepaths_)
102 | metas.extend(metas_)
103 |
104 | return filepaths, metas
105 |
106 |
107 | class DegradedAudioDataset(torch.utils.data.Dataset):
108 | """
109 |     A dataset class for loading pairs of a target (clean) audio and its degraded counterpart.
110 |     NOTE: This dataset supports only monaural output.
111 | """
112 |
113 | def __init__(
114 | self,
115 | dirs_audio: tp.List[str],
116 | dirs_noise: tp.Optional[tp.List[str]] = None,
117 | sample_size: int = 120000,
118 | sample_rate: int = 24000,
119 | pretransform: tp.Optional[str] = None,
120 | exts: tp.List[str] = ['wav', 'flac'],
121 | # augmentation
122 | augment_shift: bool = True,
123 | augment_flip: bool = True,
124 | augment_volume: bool = True,
125 | volume_range: tp.Tuple[float, float] = (0.25, 1.0),
126 | # degradation
127 | deg_types: tp.List[str] = ['clipping', 'noise', 'reverb', 'lowpass'],
128 | n_deg_comb: int = 3, # maximum number of combined degradations
129 | prob_no_deg: float = 0.05, # probability of no degradation samples
130 | clean_only: bool = False, # If true, only clean audio is returned.
131 | ):
132 |
133 | super().__init__()
134 | self.sample_size = sample_size
135 | self.sr = sample_rate
136 | self.augment_shift = augment_shift
137 |
138 | self.deg_types = deg_types
139 | self.n_deg_comb = min(n_deg_comb, len(deg_types))
140 | self.prob_no_deg = prob_no_deg
141 | self.clean_only = clean_only
142 |
143 | print_once('[Dataset instantiation]')
144 |
145 | # Degradation modules
146 | self.degradations = {
147 | 'clipping': AudioClipping(sample_rate=sample_rate),
148 | 'noise': NoiseAddition(sample_rate=sample_rate),
149 | 'reverb': RIRReverb(sample_rate=sample_rate),
150 | 'lowpass': AudioLowpass(sample_rate=sample_rate)
151 | }
152 |
153 | # Audio augmentations
154 | self.ch_encoding = torch.nn.Sequential(Mono())
155 | self.augs = torch.nn.Sequential(
156 | PhaseFlipper() if augment_flip else torch.nn.Identity(),
157 | VolumeChanger(*volume_range) if augment_volume else torch.nn.Identity()
158 | )
159 |
160 | # Pre-transform
161 | if pretransform == "gemma":
162 | self.pretransform = GemmaAudioFeature()
163 | else:
164 | self.pretransform = None
165 |
166 | # find all audio files
167 | print_once('\t->-> Searching AUDIO files...')
168 | self.filepaths, self.metas = get_audio_info(dirs_audio, exts=exts)
169 | print_once(f'\t->-> Found {len(self.filepaths)} AUDIO files.')
170 |
171 | if exists(dirs_noise) and ('noise' in self.deg_types) and (not self.clean_only):
172 | print_once('\t->-> Searching NOISE files...')
173 | self.filepaths_noise, self.metas_noise = get_audio_info(dirs_noise, exts=exts)
174 | self.n_noise_files = len(self.filepaths_noise)
175 | print_once(f'\t->-> Found {self.n_noise_files} NOISE files.')
176 |
177 | def get_track_info(self, idx, noise: bool = False):
178 | if not noise:
179 | filepath = self.filepaths[idx]
180 | info = self.metas[idx]
181 | else:
182 | filepath = self.filepaths_noise[idx]
183 | info = self.metas_noise[idx]
184 |
185 | max_ofs = max(0, info['num_frames'] - self.sample_size)
186 | offset = random.randint(0, max_ofs) if (self.augment_shift and max_ofs) else 0
187 | return filepath, offset, info
188 |
189 | def __len__(self):
190 | return len(self.filepaths)
191 |
192 | def __getitem__(self, idx):
193 |
194 | filename, offset, info = self.get_track_info(idx)
195 | # Load audio
196 | audio = load_audio_with_pad(filename, info, self.sr, self.sample_size, offset)
197 | # To mono
198 | audio = self.ch_encoding(audio).squeeze(0) # (L,)
199 | # Audio augmentations
200 | audio = self.augs(audio)
201 | target = self.pretransform(audio, self.sr) if exists(self.pretransform) else audio
202 |
203 | # Degradation
204 | if self.clean_only:
205 | deg = deg_audio = 0.
206 | else:
207 | num_deg = random.randint(1, self.n_deg_comb) if random.random() > self.prob_no_deg else 0
208 | # randomly sample degradations
209 | deg_types = random.sample(self.deg_types, num_deg)
210 | deg_audio = audio.clone()
211 | # apply degradations
212 | for deg_type in deg_types:
213 | if deg_type == 'noise':
214 | # randomly sample a noise file
215 | if self.n_noise_files > 0:
216 | noise_idx = random.randint(0, self.n_noise_files - 1)
217 | filename, offset, info = self.get_track_info(noise_idx, noise=True)
218 | noise = load_audio_with_pad(filename, info, self.sr, self.sample_size, offset)
219 | noise = self.ch_encoding(noise).squeeze(0) # (L,)
220 | else:
221 | # if no noise files are found, use Gaussian noise
222 | print_once("No noise files found. Using Gaussian noise.")
223 | noise = torch.randn_like(deg_audio)
224 | deg_audio = self.degradations[deg_type](deg_audio, noise)
225 | else:
226 | deg_audio = self.degradations[deg_type](deg_audio)
227 |
228 | # print(f"Deg ({deg_type}): {deg_audio.abs().max().item()}")
229 |
230 | deg = self.pretransform(deg_audio, self.sr) if exists(self.pretransform) else deg_audio
231 | info['deg_types'] = deg_types + ['none'] * (self.n_deg_comb - len(deg_types))
232 |
233 | return target, deg, audio, deg_audio, info
234 |
--------------------------------------------------------------------------------
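
A minimal instantiation sketch for the dataset above; the directory paths and batch sizes are placeholders, and `pretransform="gemma"` makes `target`/`deg` Gemma/USM mel features rather than raw waveforms. It assumes `src/` is on `PYTHONPATH` and the directories contain audio files (or a `metadata.csv`):

    import torch
    from data.dataset import DegradedAudioDataset  # assumed import path

    dataset = DegradedAudioDataset(
        dirs_audio=["/path/to/clean_speech"],   # placeholder directories
        dirs_noise=["/path/to/noise"],
        sample_size=144000,                     # 6 s at 24 kHz
        sample_rate=24000,
        pretransform="gemma",                   # mel features for the USM feature cleaner
        deg_types=['clipping', 'noise', 'reverb', 'lowpass'],
    )

    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
    target, deg, audio, deg_audio, info = next(iter(loader))
    # target / deg : features of the clean and degraded audio (since pretransform is set)
    # audio / deg_audio : the corresponding clean and degraded 24 kHz waveforms
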
/src/model/wavefit/generator.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 |
4 | Adapted from the following repo's code under Apache License 2.0.
5 | https://github.com/lmnt-com/wavegrad/
6 | """
7 |
8 | import typing as tp
9 | import math
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | class Conv1d(nn.Conv1d):
17 | def __init__(self, *args, **kwargs):
18 | super().__init__(*args, **kwargs)
19 | self.reset_parameters()
20 |
21 | def reset_parameters(self):
22 | nn.init.orthogonal_(self.weight)
23 | nn.init.zeros_(self.bias)
24 |
25 |
26 | class SinusoidalPositionalEncoding(nn.Module):
27 | def __init__(self, dim: int, max_iter: int, use_conv: bool = True):
28 | super().__init__()
29 | self.dim = dim
30 | self.max_iter = max_iter
31 | self.use_conv = use_conv
32 | assert dim % 2 == 0
33 |
34 | if use_conv:
35 | # 1x1 conv
36 | self.conv = nn.Conv1d(dim, dim, 1)
37 | nn.init.xavier_uniform_(self.conv.weight)
38 | nn.init.zeros_(self.conv.bias)
39 |
40 | # pre-compute positional embedding
41 | pos_embs = self.prepare_embedding() # (max_iter, dim)
42 | self.register_buffer('pos_embs', pos_embs)
43 |
44 | def forward(self, x, t: int):
45 | """
46 | Args:
47 | x: (bs, dim, T)
48 | t: Step index
49 |
50 | Returns:
51 | x_with_pos: (bs, dim, T)
52 | """
53 | assert 0 <= t < self.max_iter, f"Invalid step index {t}. It must be 0 <= t < {self.max_iter} = max_iter."
54 | pos_emb = self.pos_embs[t][None, :, None]
55 | if self.use_conv:
56 | pos_emb = self.conv(pos_emb)
57 |
58 | return x + pos_emb
59 |
60 | def prepare_embedding(self, scale: float = 5000.):
61 | dim_h = self.dim // 2
62 | pos = torch.linspace(0., scale, self.max_iter)
63 | div_term = torch.exp(- math.log(10000.0) * torch.arange(dim_h) / dim_h)
64 | pos = pos[:, None] @ div_term[None, :] # (max_iter, dim_h)
65 | pos_embs = torch.cat([torch.sin(pos), torch.cos(pos)], dim=-1) # (max_iter, dim)
66 | return pos_embs
67 |
68 |
69 | class MemEfficientFiLM(nn.Module):
70 | def __init__(self, input_size: int, output_size: int, max_iter: int):
71 | super().__init__()
72 | self.step_condition = SinusoidalPositionalEncoding(input_size, max_iter, use_conv=True)
73 | self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
74 | self.output_conv_1 = nn.Conv1d(input_size, output_size, 3, padding=1)
75 | self.reset_parameters()
76 |
77 | def reset_parameters(self):
78 | nn.init.xavier_uniform_(self.input_conv.weight)
79 | nn.init.zeros_(self.input_conv.bias)
80 | nn.init.xavier_uniform_(self.output_conv_1.weight)
81 | nn.init.zeros_(self.output_conv_1.bias)
82 |
83 | def forward(self, x, t: int):
84 | x = self.input_conv(x)
85 | x = F.leaky_relu(x, 0.2)
86 | x = self.step_condition(x, t)
87 | shift = self.output_conv_1(x)
88 |
89 | return shift, None
90 |
91 |
92 | class EmptyFiLM(nn.Module):
93 | def __init__(self):
94 | super().__init__()
95 |
96 | def forward(self, x, t):
97 | return 0, 1
98 |
99 |
100 | class UBlock(nn.Module):
101 | def __init__(self, input_size, hidden_size, factor, dilation):
102 | super().__init__()
103 | assert isinstance(dilation, (list, tuple))
104 | assert len(dilation) == 4
105 |
106 | self.factor = factor
107 | self.block1 = Conv1d(input_size, hidden_size, 1)
108 | self.block2 = nn.ModuleList([
109 | Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]),
110 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1])
111 | ])
112 | self.block3 = nn.ModuleList([
113 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]),
114 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3])
115 | ])
116 |
117 | def forward(self, x, film_shift, film_scale: tp.Optional[torch.Tensor]):
118 | if film_scale is None:
119 | film_scale = 1.0
120 |
121 | block1 = F.interpolate(x, size=x.shape[-1] * self.factor)
122 | block1 = self.block1(block1)
123 |
124 | block2 = F.leaky_relu(x, 0.2)
125 | block2 = F.interpolate(block2, size=x.shape[-1] * self.factor)
126 | block2 = self.block2[0](block2)
127 | block2 = film_shift + film_scale * block2
128 | block2 = F.leaky_relu(block2, 0.2)
129 | block2 = self.block2[1](block2)
130 |
131 | x = block1 + block2
132 |
133 | block3 = film_shift + film_scale * x
134 | block3 = F.leaky_relu(block3, 0.2)
135 | block3 = self.block3[0](block3)
136 | block3 = film_shift + film_scale * block3
137 | block3 = F.leaky_relu(block3, 0.2)
138 | block3 = self.block3[1](block3)
139 |
140 | x = x + block3
141 |
142 | return x
143 |
144 |
145 | class DBlock(nn.Module):
146 | def __init__(self, input_size, hidden_size, factor):
147 | super().__init__()
148 | self.factor = factor
149 |
150 | # self.residual_dense = Conv1d(input_size, hidden_size, 1)
151 | # self.conv = nn.ModuleList([
152 | # Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
153 | # Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
154 | # Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
155 | # ])
156 |
157 |         # NOTE: This might be the correct architecture rather than the one above,
158 |         #       since its parameter count is much closer to the size reported in the WaveGrad paper (15M).
159 | self.residual_dense = Conv1d(input_size, input_size, 1)
160 | self.conv = nn.ModuleList([
161 | Conv1d(input_size, input_size, 3, dilation=1, padding=1),
162 | Conv1d(input_size, input_size, 3, dilation=2, padding=2),
163 | Conv1d(input_size, hidden_size, 3, dilation=4, padding=4),
164 | ])
165 |
166 | # downsampling module using Conv1d
167 | # NOTE: When using kernel_size=3 for all downsampling factors,
168 |         #       the parameter count of the generator is 15.12 million.
169 | kernel_size = factor // 2 * 2 + 1
170 | padding = kernel_size // 2
171 | self.down1 = Conv1d(input_size, hidden_size, kernel_size, padding=padding, stride=factor)
172 | self.down2 = Conv1d(input_size, input_size, kernel_size, padding=padding, stride=factor)
173 |
174 | def forward(self, x):
175 | residual = self.residual_dense(x)
176 | residual = self.down1(residual)
177 |
178 | x = self.down2(x)
179 | for layer in self.conv:
180 | x = F.leaky_relu(x, 0.2)
181 | x = layer(x)
182 |
183 | return x + residual
184 |
185 |
186 | class WaveFitGenerator(nn.Module):
187 | """
188 | WaveFit generator module based on WaveGrad.
189 | See https://arxiv.org/abs/2009.00713 for details.
190 | """
191 |
192 | def __init__(
193 | self,
194 | num_iteration: int,
195 | dim_feat: int = 1536,
196 | upsample_factors: tp.List[int] = [5, 4, 3, 2, 2],
197 | upsample_channels: tp.List[int] = [512, 512, 256, 128, 128],
198 | downsample_channels: tp.List[int] = [128, 128, 256, 512]
199 | ):
200 | super().__init__()
201 |
202 | self.dim_feat = dim_feat
203 | self.upsample_factors = upsample_factors
204 | self.upsample_channels = upsample_channels
205 | self.downsample_factors = upsample_factors[1:][::-1] # e.g. [2, 2, 3, 4]
206 | self.downsample_channels = downsample_channels
207 | assert len(upsample_factors) == len(upsample_channels) == len(downsample_channels) + 1
208 | self.upsample_rate = math.prod(upsample_factors)
209 |
210 | # Downsampling blocks
211 | ch_first_down = 32
212 | self.downsample = nn.ModuleList([Conv1d(1, ch_first_down, 5, padding=2)])
213 | for i, (factor, ch_out) in enumerate(zip(self.downsample_factors, self.downsample_channels)):
214 | ch_in = ch_first_down if i == 0 else self.downsample_channels[i - 1]
215 | self.downsample.append(DBlock(ch_in, ch_out, factor))
216 |
217 | # FiLM layers
218 | self.film = nn.ModuleList([EmptyFiLM()])
219 | for i in range(len(self.downsample_channels)):
220 | ch_in, ch_out = self.downsample_channels[i], self.upsample_channels[-(i + 2)]
221 | self.film.append(MemEfficientFiLM(ch_in, ch_out, num_iteration))
222 |
223 | # Upsampling blocks
224 |         # NOTE: The dilation factors for the 5-block case follow the implementation in the WaveGrad paper.
225 |         #       Configurations with a different number of blocks have not been verified.
226 | self.upsample = nn.ModuleList()
227 | for i, (factor, ch_out) in enumerate(zip(self.upsample_factors, self.upsample_channels)):
228 | ch_in = self.dim_feat if i == 0 else self.upsample_channels[i - 1]
229 | dilations = [1, 2, 4, 8] if i < 3 else [1, 2, 1, 2]
230 | self.upsample.append(UBlock(ch_in, ch_out, factor, dilations))
231 |
232 | self.last_conv = Conv1d(128, 1, 3, padding=1)
233 |
234 | def forward(self, y_t: torch.Tensor, audio_feats: torch.Tensor, t: int):
235 | """
236 | Args:
237 | y_t: Noisy input, (bs, 1, L)
238 | audio_feats: Audio features, (bs, dim_feat, num_frame)
239 | t: Step index
240 | Returns:
241 | n_hat: Estimated noise, (bs, 1, L)
242 | """
243 | bs, ch, L = y_t.size()
244 | bs_, dim, num_frame = audio_feats.size()
245 | assert bs == bs_ and dim == self.dim_feat and ch == 1
246 | assert L == num_frame * self.upsample_rate, f"Length mismatch: {L} != {num_frame} * {self.upsample_rate}"
247 |
248 | x = y_t
249 |
250 | # downsampling
251 | downsampled = []
252 | for film, layer in zip(self.film, self.downsample):
253 | x = layer(x)
254 | downsampled.append(film(x, t))
255 |
256 | # upsampling and FiLM
257 | x = audio_feats
258 | for layer, (film_shift, film_scale) in zip(self.upsample, reversed(downsampled)):
259 | x = layer(x, film_shift, film_scale)
260 |
261 |         # to monaural (single-channel) output
262 | x = self.last_conv(x)
263 |
264 | return x
265 |
--------------------------------------------------------------------------------
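
A shape-check sketch for the generator above with the default channel/factor settings, assuming `src/` is on `PYTHONPATH`: the conditioning features are upsampled by prod([5, 4, 3, 2, 2]) = 240, so the noisy input must carry exactly 240 samples per feature frame.

    import torch
    from model.wavefit.generator import WaveFitGenerator  # assumed import path

    gen = WaveFitGenerator(num_iteration=5).eval()
    print(gen.upsample_rate)                               # 240

    bs, n_frame = 2, 50
    audio_feats = torch.randn(bs, 1536, n_frame)           # (bs, dim_feat, num_frame)
    y_t = torch.randn(bs, 1, n_frame * gen.upsample_rate)  # (bs, 1, L) with L = 12000

    with torch.no_grad():
        n_hat = gen(y_t, audio_feats, t=0)                 # estimated noise for iteration t
    print(n_hat.shape)                                     # torch.Size([2, 1, 12000])
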
/src/model/feature_cleaner/google_usm.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 |
4 | Adapted from the following repo's code under Apache-2.0 License.
5 | https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py
6 |
7 | -----------------------------------------------------
8 | Universal Speech Model (USM) from Google.
9 | """
10 | import typing as tp
11 | import gc
12 | from pathlib import Path
13 |
14 | import torch
15 | import torch.nn as nn
16 | from transformers import Gemma3nAudioEncoder, Gemma3nAudioFeatureExtractor
17 | from transformers.models.gemma3n.modeling_gemma3n import Gemma3nRMSNorm
18 |
19 | from .parallel_adapter import AdapterLayer
20 | from .base import AudioEncoderAdapter
21 | from utils.torch_common import exists
22 |
23 |
24 | class GoogleUSMAdapter(AudioEncoderAdapter):
25 | """
26 | Parallel adapter for Google USM described in Fig.1 of the Miipher-2 paper.
27 |
28 | NOTE: The shared layer norm before the adapter layers seems to be missing in the Gemma3n implementation.
29 | Instead, I tentatively introduce a pre-layer norm here for each adapter layer.
30 | """
31 |
32 | def __init__(
33 | self,
34 | n_adaptive_layers: tp.Optional[int] = None,
35 | model_id: str = "google/gemma-3n-e2b-it",
36 | encoder_id: str = "Atotti/google-usm",
37 | adapter_config: dict = {
38 | "dim_bottleneck": 1024,
39 | "init_option": "bert",
40 | "adapter_scalar": 1.0,
41 | "pre_ln_class": Gemma3nRMSNorm
42 | }
43 | ):
44 | super().__init__()
45 | self.model_id = model_id
46 | self.encoder_id = encoder_id
47 |
48 | # Feature extractor (mel-spectrogram)
49 |         # NOTE: This feature extractor is not used when the input is already a mel-spectrogram (e.g. in the forward function).
50 |         #       It is used at inference time when the input is a waveform (see forward_waveform).
51 | self.feature_extractor = Gemma3nAudioFeatureExtractor.from_pretrained(model_id)
52 |
53 | # Main audio encoder
54 | self.audio_encoder = Gemma3nAudioEncoder.from_pretrained(encoder_id)
55 | self.n_adaptive_layers = min(n_adaptive_layers, self.n_layers) if exists(n_adaptive_layers) else self.n_layers
56 | self.dim = self.audio_encoder.config.hidden_size
57 | self.adapter_config = adapter_config
58 | self.adapter_config["dim_in"] = self.dim
59 | self.adapter_config["pre_ln_class"] = globals()[adapter_config["pre_ln_class"]] \
60 | if isinstance(adapter_config["pre_ln_class"], str) else adapter_config["pre_ln_class"]
61 |
62 | # Remove unused layers
63 | layers = self.audio_encoder.conformer
64 | self.audio_encoder.conformer = nn.ModuleList(layers[:self.n_adaptive_layers])
65 | del layers # Free memory
66 | gc.collect()
67 | torch.cuda.empty_cache()
68 |
69 | # Exclude encoder from learnable modules
70 | for param in self.audio_encoder.parameters():
71 | param.requires_grad = False
72 | self.audio_encoder.eval()
73 |
74 | # Adapter layers
75 | self.adapter_layers = nn.ModuleList()
76 | self.adapter_norms = nn.ModuleList()
77 | for _ in range(self.n_adaptive_layers):
78 | self.adapter_layers.append(AdapterLayer(**self.adapter_config))
79 | self.adapter_norms.append(Gemma3nRMSNorm(self.dim))
80 |
81 | @property
82 | def n_layers(self) -> int:
83 | return len(self.audio_encoder.conformer)
84 |
85 | def eval(self):
86 | super().eval()
87 | return self
88 |
89 | def train(self, mode=True):
90 | super().train(mode)
91 | self.audio_encoder.eval() # Keep the encoder in eval mode
92 | return self
93 |
94 | def forward(
95 | self,
96 | audio_mel: torch.Tensor,
97 | encoder_only: bool = False
98 | ) -> torch.Tensor:
99 | """
100 | Args:
101 | audio_mel (torch.Tensor): (bs, n_frame, mel_bins)
102 | encoder_only (bool): If True, only the encoder is used without adaptation.
103 | Returns:
104 | (torch.Tensor): (bs, n_frame', dim)
105 | """
106 | bs, n_frame, _ = audio_mel.shape
107 | # NOTE: 'False' for valid frames, 'True' for padded frames
108 | audio_mel_mask = torch.zeros((bs, n_frame), dtype=torch.bool, device=audio_mel.device)
109 |
110 | feats = self._encoder_forward(audio_mel, audio_mel_mask, encoder_only)
111 |
112 | return feats
113 |
114 | def _encoder_forward(
115 | self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, encoder_only: bool = False
116 |     ) -> torch.Tensor:
117 | """Encodes a batch of MELs.
118 |
119 | Args:
120 |             audio_mel: a torch.Tensor of shape [batch, num_frames, mel_bins].
121 | audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
122 |
123 | Returns:
124 | feats: a torch.Tensor of shape `[batch_size, frame_length, self.dim]`
125 | """
126 | audio_encodings = self.audio_encoder.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
127 |
128 | # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
129 | t_sub = audio_encodings.shape[1]
130 |
131 | time_stride_product = 1
132 | for stride_pair_idx in range(len(self.audio_encoder.config.sscp_conv_stride_size)):
133 | time_stride_product *= self.audio_encoder.config.sscp_conv_stride_size[stride_pair_idx][0]
134 |
135 | # Create indices for gathering from the original mask.
136 | # These indices map to original time steps corresponding to the start of each
137 | # receptive field in the subsampled output.
138 | indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
139 | indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid
140 |
141 | # Expand indices for batch compatibility if B > 1 and indices is 1D.
142 | if audio_mel_mask.ndim > 1 and indices.ndim == 1:
143 | indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub]
144 | elif (
145 | audio_mel_mask.ndim == indices.ndim
146 | and audio_mel_mask.shape[0] == 1
147 | and indices.shape[0] != 1
148 | and t_sub == indices.shape[0]
149 | ):
150 | # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
151 | indices = indices.unsqueeze(0)
152 |
153 | current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
154 |
155 | # Adaptation
156 | feats = audio_encodings
157 | for i, (conformer_block, adapter_layer, adapter_norm) \
158 | in enumerate(zip(self.audio_encoder.conformer, self.adapter_layers, self.adapter_norms)):
159 | feats = self._block_forward(
160 | conformer_block,
161 | adapter_layer,
162 | adapter_norm,
163 | feats,
164 | current_mask,
165 | encoder_only
166 | )
167 |
168 | return feats
169 |
170 | def _block_forward(
171 | self,
172 | c_block,
173 | adapter_layer,
174 | adapter_norm,
175 | feats: torch.Tensor,
176 | mask: torch.BoolTensor,
177 | encoder_only: bool = False
178 | ) -> torch.Tensor:
179 | """ Gemma3nAudioConformerBlock forward pass with adaptation (Fig.1) """
180 |
181 | feats = c_block.ffw_layer_start(feats)
182 | feats = c_block.attention(feats, mask)
183 | validity_mask_for_lconv = ~mask # True for valid
184 | feats_for_lconv_input = feats * validity_mask_for_lconv.unsqueeze(-1).to(feats.dtype)
185 | feats_conv = c_block.lconv1d(feats_for_lconv_input)
186 | feats = c_block.ffw_layer_end(feats_conv) # feats_conv + feats_mlp
187 | feats = torch.clamp(feats, -c_block.gradient_clipping, c_block.gradient_clipping)
188 |         # NOTE: This layer norm does not appear in Fig. 1 of the paper.
189 | feats = c_block.norm(feats)
190 |
191 | if not encoder_only:
192 | # Adapter layer
193 | feats_adapt = adapter_layer(feats_conv)
194 | feats = feats + feats_adapt
195 | # Post layer norm
196 | feats = adapter_norm(feats)
197 |
198 | return feats
199 |
200 | @torch.no_grad()
201 | def forward_waveform(
202 | self,
203 | audio: torch.Tensor,
204 | encoder_only: bool = False,
205 | device: tp.Optional[torch.device] = None
206 | ) -> torch.Tensor:
207 | """
208 | Args:
209 |             audio (torch.Tensor): Monaural audio, (bs, n_sample)
210 | encoder_only (bool): If True, only the encoder is used without adaptation.
211 | Returns:
212 | (torch.Tensor): (bs, n_frame', dim)
213 | """
214 | # Feature extraction
215 | audio_np = audio.cpu().numpy()
216 | output = self.feature_extractor(audio_np, return_tensors="pt")
217 | audio_mel = output["input_features"]
218 |
219 | device = audio.device if device is None else device
220 | audio_mel = audio_mel.to(device)
221 |
222 | # Forward
223 | feats = self(audio_mel, encoder_only)
224 |
225 | return feats
226 |
227 | def save_state_dict(self, dir_save: str):
228 | """
229 | Save only the adapter layer and adaptive norm parameters.
230 | """
231 | state = {
232 | 'adapter_layers': self.adapter_layers.state_dict(),
233 | 'adapter_norms': self.adapter_norms.state_dict()
234 | }
235 | torch.save(state, Path(dir_save) / "model.pth")
236 |
237 | def load_state_dict(self, dir_load: tp.Optional[str] = None, state: tp.Optional[dict] = None):
238 | """
239 | Load only the adapter layer and adaptive norm parameters.
240 | """
241 | assert exists(dir_load) or exists(state), "Either dir_load or state must be provided."
242 | state = state if exists(state) else torch.load(Path(dir_load) / "model.pth")
243 | self.adapter_layers.load_state_dict(state['adapter_layers'])
244 | self.adapter_norms.load_state_dict(state['adapter_norms'])
245 |
246 | def get_state_dict(self):
247 | state = {
248 | 'adapter_layers': self.adapter_layers.state_dict(),
249 | 'adapter_norms': self.adapter_norms.state_dict()
250 | }
251 |
252 | return state
253 |
--------------------------------------------------------------------------------
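
A minimal usage sketch for the adapter above; it downloads the pretrained Gemma3n feature extractor and the USM encoder from the Hugging Face Hub (the default `model_id`/`encoder_id` shown in `__init__`), so network access and a few GB of memory are assumed. `encoder_only=True` bypasses the adapter layers, which here are freshly (randomly) initialized:

    import torch
    from model.feature_cleaner.google_usm import GoogleUSMAdapter  # assumed import path

    cleaner = GoogleUSMAdapter(n_adaptive_layers=6).eval()

    audio = torch.randn(2, 2 * 16000)   # 2 s of 16 kHz audio, (bs, n_sample)
    with torch.no_grad():
        feats = cleaner.forward_waveform(audio, encoder_only=True)

    print(feats.shape)   # (bs, n_frame', 1536); roughly 25 feature frames per second
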
/src/model/miipher_2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 |
4 | --------------
5 | Miipher-2 model
6 | """
7 | import typing as tp
8 | from enum import Enum
9 | from pathlib import Path
10 |
11 | import torch
12 | from torch import nn
13 | from torch.nn import functional as F
14 |
15 | from utils.torch_common import exists, print_once
16 |
17 |
18 | class MiipherMode(Enum):
19 | """
20 | Miipher-2 processing modes.
21 |
22 | CLEAN_INPUT: Clean audio is processed by a non-adaptive feature cleaner (WaveFit pretraining).
23 | NOISY_INPUT: Noisy audio is processed by an adaptive feature cleaner (WaveFit finetuning / Inference).
24 | """
25 | CLEAN_INPUT = 'clean_input' # Clean input waveform
26 | NOISY_INPUT = 'noisy_input' # Noisy input waveform
27 |
28 |
29 | class Miipher2(nn.Module):
30 | """
31 | Miipher-2 model consists of a feature cleaner and a WaveFit vocoder.
32 |
33 | Args:
34 | feature_cleaner (nn.Module): Feature cleaner model (e.g., Google-USM with adapter layers).
35 | vocoder (nn.Module): Vocoder model (e.g., WaveFit-5).
36 | mode (MiipherMode): Processing mode. See MiipherMode for details.
37 | """
38 |
39 | def __init__(
40 | self,
41 | feature_cleaner: nn.Module,
42 | vocoder: nn.Module,
43 | mode: str = 'noisy_input',
44 | # modules for vocoder training
45 | discriminator: tp.Optional[nn.Module] = None,
46 | mrstft: tp.Optional[nn.Module] = None,
47 | loss_lambdas: dict = {},
48 | # upsampling before vocoder
49 | upsample_factor: int = 4,
50 | upsample_mode: str = 'linear',
51 | # pretrained checkpoints are required for finetuning
52 | feature_cleaner_ckpt_dir: tp.Optional[str] = None,
53 | vocoder_ckpt_dir: tp.Optional[str] = None,
54 | # training config
55 | gradient_checkpointing: bool = False
56 | ):
57 | super().__init__()
58 | self.feature_cleaner = feature_cleaner
59 | self.vocoder = vocoder
60 | self.discriminator = discriminator
61 | self.mrstft = mrstft
62 | self.loss_lambdas = loss_lambdas
63 | self._mode = MiipherMode(mode)
64 | self.upsample_factor = upsample_factor
65 | self.upsample_mode = upsample_mode
66 | self.gradient_checkpointing = gradient_checkpointing
67 |
68 | if exists(feature_cleaner_ckpt_dir):
69 | self.feature_cleaner.load_state_dict(feature_cleaner_ckpt_dir)
70 | print_once(f"[Miipher-2 class] Loaded feature cleaner state from {feature_cleaner_ckpt_dir}")
71 |
72 | if exists(vocoder_ckpt_dir):
73 | state_vocoder = torch.load(Path(vocoder_ckpt_dir) / "vocoder.pth")
74 | self.vocoder.load_state_dict(state_vocoder)
75 | print_once(f"[Miipher-2 class] Loaded vocoder state from {vocoder_ckpt_dir}")
76 | if exists(self.discriminator):
77 | state_discriminator = torch.load(Path(vocoder_ckpt_dir) / "discriminator.pth")
78 | self.discriminator.load_state_dict(state_discriminator)
79 | print_once(f"[Miipher-2 class] Loaded discriminator state from {vocoder_ckpt_dir}")
80 |
81 | # no gradient for feature cleaner
82 | for param in self.feature_cleaner.parameters():
83 | param.requires_grad = False
84 |
85 | def eval(self):
86 | super().eval()
87 | return self
88 |
89 | def train(self, mode=True):
90 | super().train(mode)
91 | self.feature_cleaner.eval() # Keep the feature cleaner in eval mode
92 | return self
93 |
94 | def set_mode(self, mode: tp.Union[MiipherMode, str]):
95 | mode = MiipherMode(mode) if isinstance(mode, str) else mode
96 | self._mode = mode
97 |
98 | @property
99 | def mode(self) -> MiipherMode:
100 | return self._mode
101 |
102 |     def forward(self, mel_spec: torch.Tensor, initial_noise: torch.Tensor) -> tp.List[torch.Tensor]:
103 | """
104 | Args:
105 | mel_spec (torch.Tensor): Input mel-spectrogram, (B, num_frames, feature_dim)
106 | initial_noise (torch.Tensor): Initial noise for the vocoder, (B, L)
107 | Returns:
108 |             vocoder_output (List[torch.Tensor]): List of predicted waveforms y_t, each of shape (B, L')
109 | """
110 | # Mel-spectrogram to audio encoder feature
111 | encoder_only = (self._mode == MiipherMode.CLEAN_INPUT)
112 | with torch.no_grad():
113 | feats = self.feature_cleaner(mel_spec, encoder_only=encoder_only) # (bs, num_frames, dim)
114 |
115 | # Upsample features to match the vocoder's input frame rate (Sec.2.3)
116 | feats = feats.transpose(1, 2) # (bs, dim, num_frames)
117 | feats = F.interpolate(feats, scale_factor=self.upsample_factor, mode=self.upsample_mode)
118 | feats = feats.transpose(1, 2) # (bs, num_frames, dim)
119 |
120 | # Audio encoder feature to waveform
121 | vocoder_output = self.vocoder(initial_noise, feats,
122 | gradient_checkpointing=self.gradient_checkpointing) # List of (B, L)
123 |
124 | return vocoder_output
125 |
126 | @torch.no_grad()
127 | def inference(
128 | self,
129 | input_waveform: torch.Tensor,
130 | initial_noise: tp.Optional[torch.Tensor] = None
131 | ) -> torch.Tensor:
132 | """
133 | Inference with waveform input.
134 |
135 | Args:
136 |             input_waveform (torch.Tensor): Input waveform, 16 kHz, (B, L)
137 |         Returns:
138 |             decoded_waveform (torch.Tensor): Output waveform, 24 kHz, (B, L')
139 | """
140 |
141 |         # 16 kHz input -> 24 kHz output (1.5x as many samples)
142 | if exists(initial_noise):
143 | assert initial_noise.shape[-1] == int(input_waveform.shape[-1] * 1.5)
144 | else:
145 | initial_noise = torch.randn(input_waveform.shape[0], int(input_waveform.shape[-1] * 1.5), device=input_waveform.device)
146 |
147 | # Waveform to audio encoder feature
148 | feats = self.feature_cleaner.forward_waveform(input_waveform, encoder_only=False)
149 |
150 | # Upsample features to match the vocoder's input frame rate (Sec.2.3)
151 | feats = feats.transpose(1, 2) # (bs, dim, num_frames)
152 | feats = F.interpolate(feats, scale_factor=self.upsample_factor, mode=self.upsample_mode)
153 | feats = feats.transpose(1, 2) # (bs, num_frames, dim)
154 |
155 | # Audio encoder feature to waveform
156 | decoded_waveform = self.vocoder(initial_noise, feats, return_only_last=True)[-1] # (B, L)
157 |
158 | return decoded_waveform
159 |
160 | def train_step(
161 | self,
162 | target_audio: torch.Tensor,
163 | input_mel_spec: torch.Tensor,
164 | train: bool = True
165 | ) -> dict:
166 | """
167 | Training step for the Miipher-2 model.
168 |
169 | Args:
170 | target_audio (torch.Tensor): Target audio waveform, (B, L)
171 | input_mel_spec (torch.Tensor): Input mel-spectrogram, (B, num_frames, feature_dim)
172 | train (bool): Whether in training mode.
173 |
174 | Returns:
175 | dict: Losses and metrics for the training step.
176 | """
177 | assert exists(self.discriminator) and exists(self.mrstft), "Discriminator and MRSTFT must be provided for training."
178 |
179 | self.train() if train else self.eval()
180 |
181 | # Fix the gain of target audio
182 |         # NOTE: The gain of the target audio is dictated by the WaveFit model,
183 |         #       since its predictions are gain-normalized to `target_gain`.
184 | scale = self.vocoder.target_gain / (target_audio.abs().max(dim=1, keepdim=True)[0] + 1e-8)
185 | target_audio = target_audio * scale
186 |
187 | # initial noise
188 | initial_noise = torch.randn_like(target_audio)
189 |
190 | target_audio = target_audio.unsqueeze(1) # (bs, 1, L)
191 | assert target_audio.dim() == 3 and input_mel_spec.dim() == 3
192 | assert target_audio.size(0) == input_mel_spec.size(0)
193 |
194 | # Forward pass
195 | preds = self(input_mel_spec, initial_noise)
196 |
197 | # Vocoder losses
198 | losses = {}
199 | for pred in preds:
200 | pred = pred.unsqueeze(1) # (bs, 1, L)
201 | losses_i = {}
202 | losses_i.update(self.mrstft(pred, target_audio))
203 | losses_i.update(self.discriminator.compute_G_loss(pred, target_audio))
204 | for k, v in losses_i.items():
205 | losses[k] = losses.get(k, 0.) + v / len(preds)
206 |
207 | loss = 0.
208 | for k in self.loss_lambdas.keys():
209 | losses[k] = losses[k] * self.loss_lambdas[k]
210 | loss += losses[k]
211 |
212 | # Discriminator loss
213 | out_real = self.discriminator.compute_D_loss(target_audio, mode='real')
214 | loss_d_real = out_real.pop('loss')
215 | losses.update({f"D/{k}_real": v for k, v in out_real.items()})
216 | # NOTE: Discriminator loss is also computed for all intermediate predictions (Sec.4.2)
217 | loss_d_fake = 0.
218 | out_fake = {}
219 | for pred in preds:
220 | pred = pred.unsqueeze(1)
221 | out_fake_ = self.discriminator.compute_D_loss(pred.detach(), mode='fake')
222 | loss_d_fake += out_fake_.pop('loss') / len(preds)
223 | for k, v in out_fake_.items():
224 | out_fake[f"{k}_fake"] = out_fake.get(f"{k}_fake", 0.) + v / len(preds)
225 |
226 | loss_d = loss_d_real + loss_d_fake
227 |
228 | losses.update({f"D/{k}": v for k, v in out_fake.items()})
229 | output = {'loss': loss}
230 | output.update({k: v.detach() for k, v in losses.items()})
231 | output['D/loss_d'] = loss_d
232 | output['D/loss_d_real'] = loss_d_real.detach()
233 | output['D/loss_d_fake'] = loss_d_fake.detach()
234 |
235 | return output
236 |
237 | def save_state_dict(self, dir_save: str):
238 | state_feature_cleaner = self.feature_cleaner.get_state_dict()
239 | state_vocoder = self.vocoder.state_dict()
240 | torch.save(state_feature_cleaner, Path(dir_save) / "feature_cleaner.pth")
241 | torch.save(state_vocoder, Path(dir_save) / "vocoder.pth")
242 |
243 | if exists(self.discriminator):
244 | state_discriminator = self.discriminator.state_dict()
245 | torch.save(state_discriminator, Path(dir_save) / "discriminator.pth")
246 |
247 | def load_state_dict(self, dir_load: str):
248 | state_feature_cleaner = torch.load(Path(dir_load) / "feature_cleaner.pth")
249 | state_vocoder = torch.load(Path(dir_load) / "vocoder.pth")
250 | self.feature_cleaner.load_state_dict(state=state_feature_cleaner)
251 | self.vocoder.load_state_dict(state_vocoder)
252 |
253 | if exists(self.discriminator):
254 | state_discriminator = torch.load(Path(dir_load) / "discriminator.pth")
255 | self.discriminator.load_state_dict(state_discriminator)
256 |
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2025 Yukara Ikemiya
3 | """
4 |
5 | import os
6 |
7 | import torch
8 | import wandb
9 | import hydra
10 | from einops import rearrange
11 |
12 | from utils.logging import MetricsLogger
13 | from utils.torch_common import exists, sort_dict, print_once
14 |
15 | from model import AudioEncoderAdapter, Miipher2, MiipherMode
16 |
17 |
18 | class Trainer:
19 | def __init__(
20 | self,
21 | model, # model
22 | ema, # exponential moving average
23 | optimizer, # optimizer
24 | scheduler, # scheduler
25 | train_dataloader,
26 | accel, # Accelerator object
27 | cfg, # Configurations
28 | # Discriminator options (for WaveFit training)
29 | optimizer_d=None,
30 | scheduler_d=None,
31 | # Resume training from a checkpoint directory
32 | ckpt_dir=None
33 | ):
34 | self.model = accel.unwrap_model(model)
35 | self.ema = ema
36 | self.opt = optimizer
37 | self.sche = scheduler
38 | self.train_dataloader = train_dataloader
39 | self.accel = accel
40 | self.cfg = cfg
41 | self.cfg_t = cfg.trainer
42 | self.EPS = 1e-8
43 |
44 | # discriminator
45 | self.have_disc = exists(optimizer_d) and exists(scheduler_d)
46 | self.opt_d = optimizer_d
47 | self.sche_d = scheduler_d
48 |
49 | self.logger = MetricsLogger() # Logger for WandB
50 | self.logger_print = MetricsLogger() # Logger for printing
51 | self.logger_test = MetricsLogger() # Logger for test
52 |
53 | self.states = {'global_step': 0, 'best_metrics': float('inf'), 'latest_metrics': float('inf')}
54 |
55 | # time measurement
56 | self.s_event = torch.cuda.Event(enable_timing=True)
57 | self.e_event = torch.cuda.Event(enable_timing=True)
58 |
59 | # resume training
60 | if ckpt_dir is not None:
61 | self.__load_ckpt(ckpt_dir)
62 |
63 | def start_training(self):
64 | """
65 |         Start training with an infinite loop over the dataloader.
66 | """
67 | self.model.train()
68 | self.s_event.record()
69 |
70 | print_once("\n[ Started training ]\n")
71 |
72 | while True:
73 | for batch in self.train_dataloader:
74 | # Update
75 | metrics = self.run_step(batch)
76 |
77 | if self.accel.is_main_process:
78 | self.logger.add(metrics)
79 | self.logger_print.add(metrics)
80 |
81 | # Log
82 | if self.__its_time(self.cfg_t.logging.n_step_log):
83 | self.__log_metrics()
84 |
85 | # Print
86 | if self.__its_time(self.cfg_t.logging.n_step_print):
87 | self.__print_metrics()
88 |
89 | # Save checkpoint
90 | if self.__its_time(self.cfg_t.logging.n_step_ckpt):
91 | self.__save_ckpt()
92 |
93 | # Sample
94 | if not isinstance(self.model, AudioEncoderAdapter):
95 | if self.__its_time(self.cfg_t.logging.n_step_sample):
96 | self.__sampling()
97 |
98 | self.states['global_step'] += 1
99 |
100 | def run_step(self, batch, train: bool = True):
101 | """ One training step """
102 |
103 | # target and degraded audios: (bs, ch, sample_length)
104 | x_tgt, x_deg, clean_audio, noisy_audio, _ = batch
105 |
106 | # Update
107 |
108 | if train:
109 | self.opt.zero_grad()
110 | if self.have_disc:
111 | self.opt_d.zero_grad()
112 |
113 | if isinstance(self.model, AudioEncoderAdapter):
114 | output = self.model.train_step(x_tgt, x_deg, train=train)
115 | elif isinstance(self.model, Miipher2):
116 | input = x_tgt if self.model.mode == MiipherMode.CLEAN_INPUT else x_deg
117 | output = self.model.train_step(clean_audio, input, train=train)
118 | else:
119 | raise NotImplementedError(f"Model class '{self.model.__class__.__name__}' is not supported.")
120 |
121 | if train:
122 | self.accel.backward(output['loss'])
123 | if self.accel.sync_gradients:
124 | self.accel.clip_grad_norm_(self.model.parameters(), self.cfg_t.max_grad_norm)
125 | self.opt.step()
126 | self.sche.step()
127 |
128 | if self.have_disc:
129 | self.accel.backward(output['D/loss_d'])
130 | if self.accel.sync_gradients:
131 | self.accel.clip_grad_norm_(self.model.discriminator.parameters(), self.cfg_t.max_grad_norm)
132 | self.opt_d.step()
133 | self.sche_d.step()
134 |
135 | # EMA
136 | if exists(self.ema):
137 | self.ema.update()
138 |
139 | return {k: v.detach() for k, v in output.items()}
140 |
141 | @torch.no_grad()
142 | def __sampling(self):
143 | # Restoration / Reconstruction samples from Miipher-2
144 | self.model.eval()
145 |
146 | n_sample: int = self.cfg_t.logging.n_samples_per_step
147 |
148 | # randomly select samples
149 | dataset = self.train_dataloader.dataset
150 | idxs = torch.randint(len(dataset), size=(n_sample,))
151 | x_tgt, x_deg, clean_audio, noisy_audio, deg_types = [], [], [], [], []
152 | for idx in idxs:
153 | x_tgt_, x_deg_, clean_audio_, noisy_audio_, info_ = dataset[idx]
154 | x_tgt.append(x_tgt_)
155 | x_deg.append(x_deg_)
156 | clean_audio.append(clean_audio_)
157 | noisy_audio.append(noisy_audio_)
158 | deg_types.append(info_.get('deg_types', ['none']))
159 |
160 | x_tgt = torch.stack(x_tgt, dim=0).to(self.accel.device)
161 | x_deg = torch.stack(x_deg, dim=0).to(self.accel.device) if self.model.mode != MiipherMode.CLEAN_INPUT else None
162 | clean_audio = torch.stack(clean_audio, dim=0).to(self.accel.device)
163 | noisy_audio = torch.stack(noisy_audio, dim=0).to(self.accel.device) if self.model.mode != MiipherMode.CLEAN_INPUT else None
164 |
165 | columns = ['clean (audio)', 'decoded (audio)'] if self.model.mode == MiipherMode.CLEAN_INPUT \
166 |             else ['clean (audio)', 'degraded (audio)', 'Degradations', 'restored (audio)']
167 | table_audio = wandb.Table(columns=columns)
168 |
169 | # sampling
170 | x_input = x_tgt if self.model.mode == MiipherMode.CLEAN_INPUT else x_deg
171 | initial_noise = torch.randn_like(clean_audio)
172 | with torch.no_grad():
173 | x_preds = self.model(x_input, initial_noise)
174 | x_pred = x_preds[-1] # (n_sample, L)
175 |
176 | for idx in range(n_sample):
177 | # clean audio
178 | data = [wandb.Audio(clean_audio[idx].cpu().numpy(), sample_rate=dataset.sr)]
179 |
180 | # degraded audio
181 | if self.model.mode == MiipherMode.NOISY_INPUT:
182 | data += [wandb.Audio(noisy_audio[idx].cpu().numpy(), sample_rate=dataset.sr)]
183 | data += ['/'.join([d for d in deg_types[idx] if d != 'none'])]
184 |
185 | # decoded audio
186 | data += [wandb.Audio(x_pred[idx].cpu().numpy().T, sample_rate=dataset.sr)]
187 |
188 | table_audio.add_data(*data)
189 |
190 | self.accel.log({'Samples': table_audio}, step=self.states['global_step'])
191 |
192 | self.model.train()
193 |
194 | print("\t->->-> Sampled.")
195 |
196 | def __save_ckpt(self):
197 | import shutil
198 | import json
199 | from omegaconf import OmegaConf
200 |
201 | out_dir = self.cfg_t.output_dir + '/ckpt'
202 |
203 | # save latest ckpt
204 | latest_dir = out_dir + '/latest'
205 | os.makedirs(latest_dir, exist_ok=True)
206 |
207 | # save optimizer/scheduler states
208 | torch.save(self.opt.state_dict(), f"{latest_dir}/optimizer.pth")
209 | torch.save(self.sche.state_dict(), f"{latest_dir}/scheduler.pth")
210 |
211 | if self.have_disc:
212 | torch.save(self.opt_d.state_dict(), f"{latest_dir}/optimizer_d.pth")
213 | torch.save(self.sche_d.state_dict(), f"{latest_dir}/scheduler_d.pth")
214 |
215 | # save model states
216 | self.model.save_state_dict(latest_dir)
217 |
218 | # save states and configuration
219 | OmegaConf.save(self.cfg, f"{latest_dir}/config.yaml")
220 | with open(f"{latest_dir}/states.json", mode="wt", encoding="utf-8") as f:
221 | json.dump(self.states, f, indent=2)
222 |
223 | # save best ckpt
224 | if self.states['latest_metrics'] == self.states['best_metrics']:
225 | shutil.copytree(latest_dir, out_dir + '/best', dirs_exist_ok=True)
226 |
227 | print("\t->->-> Saved checkpoints.")
228 |
229 | def __load_ckpt(self, dir: str):
230 | import json
231 |
232 | print_once(f"\n[Resuming training from the checkpoint directory] -> {dir}")
233 |
234 | self.opt.load_state_dict(torch.load(f"{dir}/optimizer.pth", weights_only=False))
235 | self.sche.load_state_dict(torch.load(f"{dir}/scheduler.pth", weights_only=False))
236 |
237 | if self.have_disc:
238 | self.opt_d.load_state_dict(torch.load(f"{dir}/optimizer_d.pth", weights_only=False))
239 | self.sche_d.load_state_dict(torch.load(f"{dir}/scheduler_d.pth", weights_only=False))
240 |
241 | self.model.load_state_dict(dir)
242 |
243 | with open(f"{dir}/states.json", mode="rt", encoding="utf-8") as f:
244 | self.states.update(json.load(f))
245 |
246 | def __log_metrics(self, sort_by_key: bool = False):
247 | metrics = self.logger.pop()
248 | # learning rate
249 | metrics['lr'] = self.sche.get_last_lr()[0]
250 | if sort_by_key:
251 | metrics = sort_dict(metrics)
252 |
253 | self.accel.log(metrics, step=self.states['global_step'])
254 |
255 | # update states
256 | m_for_ckpt = self.cfg_t.logging.metrics_for_best_ckpt
257 | m_latest = float(sum([metrics[k].detach() for k in m_for_ckpt]))
258 | self.states['latest_metrics'] = m_latest
259 | if m_latest < self.states['best_metrics']:
260 | self.states['best_metrics'] = m_latest
261 |
262 | def __print_metrics(self, sort_by_key: bool = False):
263 | self.e_event.record()
264 | torch.cuda.synchronize()
265 | p_time = self.s_event.elapsed_time(self.e_event) / 1000. # [sec]
266 |
267 | metrics = self.logger_print.pop()
268 | # tensor to scalar
269 | metrics = {k: v.item() for k, v in metrics.items()}
270 | if sort_by_key:
271 | metrics = sort_dict(metrics)
272 |
273 | step = self.states['global_step']
274 | s = f"Step {step} ({p_time:.1e} [sec]): " + ' / '.join([f"[{k}] - {v:.3e}" for k, v in metrics.items()])
275 | print(s)
276 |
277 | self.s_event.record()
278 |
279 | def __its_time(self, itv: int):
280 | return (self.states['global_step'] - 1) % itv == 0
281 |
--------------------------------------------------------------------------------
/_LICENSES/transformers.LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2018- The Hugging Face team. All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
--------------------------------------------------------------------------------
/_LICENSES/unify-parameter-efficient-tuning.LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2018- The Hugging Face team. All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🎵 Open-Miipher-2 | A Universal Speech Restoration Model
2 |
3 | This is an unofficial implementation of **`Miipher-2`**[1]
4 | which is **a state-of-the-art universal speech restoration model** from Google Research.
5 |
6 | 
7 |
8 | This repository supports:
9 | - 🔥 Full implementation and training code for the `Miipher-2` model
10 | - 🔥 Distributed training with multiple GPUs / multiple Nodes
11 |
12 | ## What Google Can Do vs. What You Can Do
13 |
14 | Google's Miipher-2 leverages proprietary datasets and large-scale infrastructure that are not publicly accessible. While you cannot use Google's internal data or resources, this repository enables you to train and experiment with open datasets and models. The implementation is designed to be flexible and reproducible for academic and general-purpose use.
15 |
16 | ### Google USM
17 |
18 | Google uses the Universal Speech Model (USM) [4] both for feature extraction and as the base (pretrained) model for the feature cleaner in Miipher-2. Specifically, Miipher-2 is trained using the first 13 layers of a 32-layer Conformer with 2 billion parameters. However, this pretrained model is not publicly available.
19 |
20 | On the other hand, Google has open-sourced [Gemma 3](https://huggingface.co/docs/transformers/main/model_doc/gemma3), a multimodal LLM that includes a 0.6 billion parameter (12-layer) USM module. In this repository, the default configuration uses up to the 6th layer of this model as the base for the feature cleaner. Naturally, differences in base model size may lead to variations in restoration performance between the Google version and this repository.
21 | Please also note that the USM (0.6B) in Gemma 3 differs in architectural configuration from the 0.6B model described in the USM paper [4].
22 |
23 | The key differences between the Google version and this repository are summarized below:
24 |
25 | |            | Base model | Model size | Conformer layers | Dimension | Frame rate [Hz] | Parallel Adapter size |
26 | |:-----------|:----------:|:----------:|:----------------:|:---------:|:----------:|:---:|
27 | | **Google** | USM | 2B | 13th / 32 layers | 1536 | 25 | 41M |
28 | | **Open-Miipher-2** | USM | 0.6B | 6th / 12 layers | 1536 | 25 | 19M |
29 |
30 | Currently, Gemma 3's USM is loaded from the following model card on HuggingFace.
31 | - [Atotti/Google-USM](https://huggingface.co/Atotti/Google-USM)
32 |
33 | For more details on the selection of Conformer layers and related considerations, please refer to the [Tips](#-tips) section.
34 |
35 | ### Audio dataset
36 | According to the paper, Google used `3,195 hours of speech from 1,642 speakers across 44 languages` as speech data and `internally collected audio snippets from environments such as cafes, kitchens, and automobiles` as noise data for training Miipher-2. These datasets are internal to Google and are not publicly available.
37 |
38 | For general-purpose use, it is preferable to utilize larger and more diverse speech/noise datasets. However, for experiments or academic purposes, you can use open datasets such as those listed below.
39 |
40 | | Type | Dataset name | Link | Hours |
41 | |-------|-------------------------|------|---|
42 | | Speech | LibriTTS-R [5] | [https://www.openslr.org/141/](https://www.openslr.org/141/) | 585 |
43 | | Noise | TAU Urban Acoustic Scenes 2020 Mobile, Development dataset | [https://zenodo.org/records/3670167](https://zenodo.org/records/3670167) | 64 |
44 | | Noise | TAU Urban Audio-Visual Scenes 2021, Development dataset | [https://zenodo.org/records/4477542](https://zenodo.org/records/4477542) | 34 |
45 |
46 |
47 | # Requirements
48 |
49 | - Python 3.8.10 or later
50 | - PyTorch 2.1 or later
51 | - transformers>=4.53 (NOTE: `transformers` must contain Gemma implementations from Google.)
52 |
53 | ## Building a training environment
54 |
55 | To simplify setting up the training environment, I recommend using a container system such as `Docker` or `Singularity` instead of installing dependencies on each GPU machine. Below are the steps for creating a `Singularity` container.
56 |
57 | All example scripts are stored in the [container](container/) folder.
58 |
59 | ### 1. Install Singularity
60 |
61 | Install the latest Singularity by following the official instructions.
62 | - https://docs.sylabs.io/guides/main/user-guide/quick_start.html#quick-installation-steps
63 |
64 | ### 2. Create a Singularity image file
65 |
66 | Create (build) a Singularity image file with a definition file.
67 | ```bash
68 | singularity build --fakeroot Open-Miipher-2.sif Open-Miipher-2.def
69 | ```
70 |
71 | **NOTE**: You might need to change the NVIDIA base image in the definition file to match your GPU machine.
72 |
73 | You now have a container file for training and inference with Open-Miipher-2.
74 |
75 | ## Setting a WandB account for logging
76 |
77 | The training code also requires a Weights & Biases account to log the training outputs and demos.
78 | Please create an account and follow the instructions below.
79 |
80 | Once you have created your WandB account,
81 | you can obtain the API key from https://wandb.ai/authorize after logging in to your account.
82 | The API key can then be passed to a training job as the environment variable `WANDB_API_KEY`
83 | to enable logging of training information.
84 |
85 | ```bash
86 | $ WANDB_API_KEY="12345x6789y..."
87 | ```
88 |
89 | ## Authentication of HuggingFace
90 |
91 | This repository uses the `google/gemma-3n-e2b-it` model from HuggingFace.
92 | To load this model, you must authenticate with HuggingFace in advance.
93 |
94 | - https://huggingface.co/google/gemma-3n-E2B-it
95 |
96 | ### Login setting
97 |
98 | 1. Create a HuggingFace token for accessing the weights from https://huggingface.co/settings/tokens.
99 | 2. Set the token to the `HF_TOKEN` environment variable when running your script.
100 |
101 |
102 | # Data preparation
103 |
104 | By default, Miipher-2 performs decoding on 24kHz audio signals. Therefore, if audio with a sample rate other than 24kHz is loaded, it will be automatically resampled to 24kHz within the Dataset class. To avoid this extra computation, it is recommended to pre-process all audio files in each dataset to 24kHz before training.
105 |
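If your audio files are not already at 24kHz, a minimal pre-processing sketch along the following lines can be used. This snippet is not part of the repository; it assumes `torchaudio` is installed and that the input/output paths are adjusted to your environment.

```python
# Minimal sketch: resample all WAV files under IN_DIR to 24kHz mono files in OUT_DIR.
# Not part of this repository; paths are placeholders to be adjusted.
from pathlib import Path

import torchaudio

IN_DIR, OUT_DIR, TARGET_SR = Path("/path/to/original"), Path("/path/to/24khz"), 24000

for wav_path in IN_DIR.rglob("*.wav"):
    waveform, sr = torchaudio.load(str(wav_path))      # (channels, samples)
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)   # downmix to mono
    if sr != TARGET_SR:
        waveform = torchaudio.functional.resample(waveform, sr, TARGET_SR)
    out_path = OUT_DIR / wav_path.relative_to(IN_DIR)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(str(out_path), waveform, TARGET_SR)
```
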
106 | For fast initialization of the dataset class, this repository uses a metadata file (`metadata.csv`) placed in the root directory of each audio dataset. You can generate this file by running the following script for each dataset directory:
107 |
108 | ```bash
109 | AUDIO_DIR=/path/to/audio-root-directory/
110 | python dataset/script/make_metadata_csv.py --root-dir ${AUDIO_DIR}
111 | ```
112 |
113 | ## Pre-computation of degraded speech signals [Not supported]
114 |
115 | For clarity and simplicity, this repository applies random degradations to audio samples within a Dataset class during loading (online processing). However, some degradation methods (e.g. convolution filtering) can be computationally intensive on the CPU, potentially becoming a bottleneck and preventing full utilization of GPU resources during training. In such cases, consider pre-computing and saving multiple degraded versions of each clean speech sample before training, so that the Dataset can load these pre-processed files directly.
116 |
117 | # Training
118 |
119 | Miipher-2 training consists of the following three stages:
120 |
121 | 1. Training of the feature cleaner module
122 | 2. Pretraining of the WaveFit module
123 | 3. Finetuning of the WaveFit module
124 |
125 | Stage 1 trains a module that converts noisy SSL features into clean ones, serving as the main part of audio restoration. Stage 2 pre-trains the WaveFit speech vocoder using only clean speech. Stage 3 fine-tunes the WaveFit module using features restored by the feature cleaner from Stage 1.
126 |
127 | For Stage 2, short audio signals of 0.6 seconds are used to accelerate training (see Sec.2.4 in [2]). In contrast, Stages 1 and 3 use longer audio signals (typically 10–30 seconds) to capture more global information. The original paper uses 30-second inputs, the maximum input length of the USM Conformer block, but this requires significant memory and may be impractical for most users. Therefore, the sample code in this repository defaults to a 10-second input length.
128 |
129 | In summary, the input signal lengths for each training stage are as follows:
130 |
131 | | Stage | Module | Input length | Input audio | Purpose |
132 | |-------|-----------------------|--------------|------|----------------------------------------------|
133 | | 1 | Feature Cleaner | 10 sec | Noisy speech | Restore noisy SSL features |
134 | | 2 | WaveFit Pretraining | 0.6 sec | Clean speech | Pretrain vocoder with clean features |
135 | | 3 | WaveFit Finetuning | 10 sec | Noisy speech | Finetune vocoder with restored features |
136 |
137 |
138 | ## Stage 1: Feature Cleaner Training
139 |
140 | Sample script for Feature Cleaner Training
141 |
142 | ```bash
143 | ROOT_DIR="/path/to/this/repository/"
144 | CONTAINER_PATH="/path/to/Open-Miipher-2.sif"
145 | DATASET_ROOT="/path/to/dataset/root/"
146 | JOB_ID="your_job_id"
147 |
148 | # Dataset
149 | DIRS_AUDIO=${DATASET_ROOT}/LibriTTS_R/train-clean-100/
150 | DIRS_NOISE=${DATASET_ROOT}/TAU-urban-audio-visual-scenes-2021-development_24k-mono/,${DATASET_ROOT}/TAU-urban-acoustic-scenes-2020-mobile-development_24k-mono/
151 |
152 | # Configuration
153 | PROJECT_NAME="cleaner_google-usm"
154 | MODEL="feature_cleaner/google-usm"
155 | DATA="deg_gemma_24khz_10sec"
156 | OPTIMIZER="feature_cleaner"
157 | BS_PER_GPU=20
158 | NUM_WORKERS=4
159 | EXTRA_ARGS="model=${MODEL} data=${DATA} optimizer=${OPTIMIZER}"
160 |
161 | WANDB_API_KEY="your_wandb_api_key"
162 | HF_TOKEN="your_huggingface_token"
163 | PORT=50000
164 |
165 | OUTPUT_DIR=${ROOT_DIR}/runs/train/${MODEL}/${JOB_ID}
166 | mkdir -p ${OUTPUT_DIR}
167 |
168 | # Calculate total batch size based on number of GPUs
169 | NUM_GPUS=2 # Adjust based on your setup
170 | BATCH_SIZE=$((${NUM_GPUS}*${BS_PER_GPU}))
171 |
172 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_ROOT \
173 | --env HYDRA_FULL_ERROR=1 --env MASTER_PORT=${PORT} \
174 | --env WANDB_API_KEY=$WANDB_API_KEY --env HF_TOKEN=$HF_TOKEN \
175 | ${CONTAINER_PATH} \
176 | torchrun --nproc_per_node gpu --master_port ${PORT} \
177 | ${ROOT_DIR}/src/train.py \
178 | data.train.dirs_audio=[${DIRS_AUDIO}] \
179 | data.train.dirs_noise=[${DIRS_NOISE}] \
180 | trainer.output_dir=${OUTPUT_DIR} \
181 | trainer.batch_size=${BATCH_SIZE} \
182 | trainer.num_workers=${NUM_WORKERS} \
183 | trainer.logger.project_name=${PROJECT_NAME} \
184 | trainer.logger.run_name=job-${JOB_ID} \
185 | ${EXTRA_ARGS}
186 | ```
187 |
188 |
189 |
190 | ## Stage 2: WaveFit Pretraining
191 |
192 | Sample script for WaveFit Pretraining
193 |
194 | ```bash
195 | ROOT_DIR="/path/to/this/repository/"
196 | CONTAINER_PATH="/path/to/Open-Miipher-2.sif"
197 | DATASET_ROOT="/path/to/dataset/root/"
198 | JOB_ID="your_job_id"
199 |
200 | # Dataset
201 | DIRS_AUDIO=${DATASET_ROOT}/LibriTTS_R/train-clean-100/
202 |
203 | # Configuration
204 | PROJECT_NAME="wavefit_pretrain"
205 | MODEL="miipher-2_google-usm_wavefit-5_clean-input"
206 | DATA="deg_gemma_24khz_06sec_clean-only"
207 | OPTIMIZER="wavefit"
208 | BS_PER_GPU=30
209 | NUM_WORKERS=4
210 | EXTRA_ARGS="model=${MODEL} data=${DATA} optimizer=${OPTIMIZER}"
211 |
212 | WANDB_API_KEY="your_wandb_api_key"
213 | HF_TOKEN="your_huggingface_token"
214 | PORT=50000
215 |
216 | OUTPUT_DIR=${ROOT_DIR}/runs/train/${MODEL}/${JOB_ID}
217 | mkdir -p ${OUTPUT_DIR}
218 |
219 | # Calculate total batch size based on number of GPUs
220 | NUM_GPUS=2 # Adjust based on your setup
221 | BATCH_SIZE=$((${NUM_GPUS}*${BS_PER_GPU}))
222 |
223 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_ROOT \
224 | --env HYDRA_FULL_ERROR=1 --env MASTER_PORT=${PORT} \
225 | --env WANDB_API_KEY=$WANDB_API_KEY --env HF_TOKEN=$HF_TOKEN \
226 | ${CONTAINER_PATH} \
227 | torchrun --nproc_per_node gpu --master_port ${PORT} \
228 | ${ROOT_DIR}/src/train.py \
229 | data.train.dirs_audio=[${DIRS_AUDIO}] \
230 | trainer.output_dir=${OUTPUT_DIR} \
231 | trainer.batch_size=${BATCH_SIZE} \
232 | trainer.num_workers=${NUM_WORKERS} \
233 | trainer.logger.project_name=${PROJECT_NAME} \
234 | trainer.logger.run_name=job-${JOB_ID} \
235 | ${EXTRA_ARGS}
236 | ```
237 |
238 |
239 |
240 | ## Stage 3: WaveFit Finetuning
241 |
242 | Please note that for WaveFit finetuning, you must specify the checkpoint directories for the modules trained in Stage 1 and Stage 2.
243 |
244 | Sample script for WaveFit Finetuning
245 |
246 | ```bash
247 | ROOT_DIR="/path/to/this/repository/"
248 | CONTAINER_PATH="/path/to/Open-Miipher-2.sif"
249 | DATASET_ROOT="/path/to/dataset/root/"
250 | JOB_ID="your_job_id"
251 |
252 | # Dataset
253 | DIRS_AUDIO=${DATASET_ROOT}/LibriTTS_R/train-clean-100/
254 | DIRS_NOISE=${DATASET_ROOT}/TAU-urban-audio-visual-scenes-2021-development_24k-mono/,${DATASET_ROOT}/TAU-urban-acoustic-scenes-2020-mobile-development_24k-mono/
255 |
256 | # Configuration
257 | PROJECT_NAME="wavefit_finetune"
258 | MODEL="miipher-2_google-usm_wavefit-5_noisy-input"
259 | DATA="deg_gemma_24khz_10sec"
260 | OPTIMIZER="wavefit"
261 | BS_PER_GPU=5
262 | NUM_WORKERS=4
263 |
264 | # Pre-trained model checkpoints
265 | FEATURE_CLEANER_CKPT_DIR=/path/to/feature_cleaner_ckpt_dir/
266 | VOCODER_CKPT_DIR=/path/to/vocoder_ckpt_dir/
267 |
268 | EXTRA_ARGS="model=${MODEL} data=${DATA} optimizer=${OPTIMIZER}"
269 | EXTRA_ARGS="${EXTRA_ARGS} model.feature_cleaner_ckpt_dir=${FEATURE_CLEANER_CKPT_DIR} model.vocoder_ckpt_dir=${VOCODER_CKPT_DIR}"
270 |
271 | WANDB_API_KEY="your_wandb_api_key"
272 | HF_TOKEN="your_huggingface_token"
273 | PORT=50000
274 |
275 | OUTPUT_DIR=${ROOT_DIR}/runs/train/${MODEL}/${JOB_ID}
276 | mkdir -p ${OUTPUT_DIR}
277 |
278 | # Calculate total batch size based on number of GPUs
279 | NUM_GPUS=2 # Adjust based on your setup
280 | BATCH_SIZE=$((${NUM_GPUS}*${BS_PER_GPU}))
281 |
282 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_ROOT \
283 | --env HYDRA_FULL_ERROR=1 --env MASTER_PORT=${PORT} \
284 | --env WANDB_API_KEY=$WANDB_API_KEY --env HF_TOKEN=$HF_TOKEN \
285 | ${CONTAINER_PATH} \
286 | torchrun --nproc_per_node gpu --master_port ${PORT} \
287 | ${ROOT_DIR}/src/train.py \
288 | data.train.dirs_audio=[${DIRS_AUDIO}] \
289 | data.train.dirs_noise=[${DIRS_NOISE}] \
290 | trainer.output_dir=${OUTPUT_DIR} \
291 | trainer.batch_size=${BATCH_SIZE} \
292 | trainer.num_workers=${NUM_WORKERS} \
293 | trainer.logger.project_name=${PROJECT_NAME} \
294 | trainer.logger.run_name=job-${JOB_ID} \
295 | trainer.debug=0 \
296 | ${EXTRA_ARGS}
297 | ```
298 |
299 |
300 |
301 | ## Resume training from a checkpoint
302 |
303 | During training, checkpoints (state_dicts) of the model, optimizers, and schedulers are saved under the output directory specified in the configuration, as shown below.
304 | ```
305 | output_dir/
306 | ├─ ckpt/
307 | │ ├─ latest/
308 | │ │ ├─ model.pth
309 | │ │ ├─ optimizer.pth
310 | │ │ ├─ scheduler.pth
311 | │ │ ├─ ...
312 | ```
313 |
314 | By specifying the checkpoint directory, you can easily resume your training from the checkpoint.
315 |
316 | Sample script for training resumption from saved checkpoints
317 |
318 | ```bash
319 | CKPT_DIR="output_dir/ckpt/latest/"
320 | OUTPUT_DIR="another/directory/"
321 |
322 | # Execution
323 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_ROOT \
324 | --env MASTER_PORT=${PORT} --env WANDB_API_KEY=$WANDB_API_KEY \
325 | ${CONTAINER_PATH} \
326 | torchrun --nproc_per_node gpu ${ROOT_DIR}/src/train.py \
327 | trainer.ckpt_dir=${CKPT_DIR} \
328 | trainer.output_dir=${OUTPUT_DIR}
329 | ```
330 |
331 |
332 |
333 | # Inference
334 |
335 | Using pre-trained Open-Miipher-2 models, you can perform inference with audio signals as input
336 | (e.g. for speech restoration evaluation).
337 |
338 | The [`inference.py`](src/inference.py) script performs inference for all audio files in a target directory.
339 | To check other options for the script, please use the `-h` option.
340 |
341 | Sample script for Inference (speech restoration)
342 |
343 | ```bash
344 | ROOT_DIR="/path/to/this/repository/"
345 | CONTAINER_PATH="/path/to/Open-Miipher-2.sif"
346 | JOB_ID="your_job_id"
347 |
348 | CKPT_DIR="/path/to/feature_cleaner_ckpt_dir/"
349 | AUDIO_DIR="/path/to/target/speech/directory/"
350 | OUTPUT_DIR="${ROOT_DIR}/runs/inference/${JOB_ID}"
351 |
352 | HF_TOKEN="your_huggingface_token"
353 | PORT=50000
354 |
355 | mkdir -p ${OUTPUT_DIR}
356 |
357 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $AUDIO_DIR \
358 | --env MASTER_PORT=${PORT} --env HF_TOKEN=$HF_TOKEN \
359 | ${CONTAINER_PATH} \
360 | torchrun --nproc_per_node=1 --master_port=${PORT} \
361 | ${ROOT_DIR}/src/inference.py \
362 | --ckpt-dir ${CKPT_DIR} \
363 | --input-audio-dir ${AUDIO_DIR} \
364 | --output-dir ${OUTPUT_DIR} \
365 | --sample-size 160000 \
366 | --sr-in 16000 \
367 | --sr-out 24000
368 | ```
369 |
370 |
371 |
372 | # 💡 Tips
373 |
374 | ## Number of Conformer layers
375 |
376 | The number of Conformer layers used from USM is a crucial parameter that affects the overall performance of Miipher-2. Since Conformer layers are repeatedly applied to Mel-spectrogram inputs, using fewer layers may make it easier for WaveFit to restore the speech signal. However, it is also necessary to ensure a sufficient number of layers for effective feature cleaning. In the original Miipher-2, 13 out of 32 layers are utilized.
377 |
378 | To determine the optimal number of layers for this repository, small-scale WaveFit training experiments were conducted using SSL features obtained from different numbers of layers. The figure below shows the progression of `STFT spectral convergence loss`, `STFT magnitude loss`, and `GAN loss` when using up to 1, 2, 5, 7, 9, and 11 layers.
379 |
380 | 
381 |
382 | As expected, using fewer Conformer layers generally results in lower STFT loss, indicating that decoding to speech is easier. Notably, when using 9 or more layers, performance drops sharply, and using all 12 layers fails entirely. Based on these results, this repository defaults to using `6 layers`.
383 |
384 | This setting can be easily changed via the configuration.
385 |
386 | ## Degradation types
387 |
388 | The methods for degrading speech used for model training differ between those described in the paper and those implemented in this repository, as summarized below.
389 |
390 | | | Background noise | Room reverb | Codec (MP3, Vorbis, A-law, AMR-WB, OPUS) | Soft/Hard clipping | Lowpass |
391 | |:--------------:|:---------------:|:-----------:|:----------------------------------------:|:------------------:|:-------:|
392 | | **Google** | ✓ | ✓ | ✓ | | |
393 | | **Open-Miipher-2** | ✓ | ✓ | | ✓ | ✓ |
394 |
395 | Codec processing is considered too computationally intensive for online processing, so it is excluded from this repository.
396 |
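For illustration only, the snippet below applies two of the degradation types added in this repository (soft clipping and lowpass filtering) to a waveform tensor. It is a simplified sketch using `torchaudio`, not the implementation in `src/data/degradation/`, and the parameter values are arbitrary examples.

```python
# Simplified illustration of two degradation types (soft clipping, lowpass).
# Not the repository's implementation; parameter values are arbitrary examples.
import torch
import torchaudio.functional as AF

def degrade(waveform: torch.Tensor, sr: int = 24000) -> torch.Tensor:
    # Soft clipping: drive the signal into tanh and renormalize.
    drive = 4.0
    clipped = torch.tanh(drive * waveform) / torch.tanh(torch.tensor(drive))
    # Lowpass filtering with a biquad (a real pipeline would randomize the cutoff).
    return AF.lowpass_biquad(clipped, sample_rate=sr, cutoff_freq=4000.0)

# Example: degrade one second of a 440 Hz tone.
t = torch.arange(24000) / 24000.0
degraded = degrade(torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0))
```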
397 |
400 |
401 | # 🤔 Unclear points in the implementation
402 |
403 | ## 1. Parameter size
404 |
405 | According to the paper, the Parallel Adapter (PA) is described as having `20 million` learnable parameters (Sec.2.2). However, the PA used in Miipher-2 takes a 1536-dimensional input and has a 1024-dimensional bottleneck structure, which is applied to 13 layers of Conformers. The approximate parameter size of the PA can be calculated from two linear layers, resulting in a total parameter count of about `1536 x 1024 x 2 x 13 = 40.9M`. This is likely a typo in the paper.
406 |
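As a quick sanity check of this estimate, the count can be reproduced with a toy bottleneck adapter made of two linear layers per Conformer layer. This is a sketch only; the actual Parallel Adapter in `src/model/feature_cleaner/parallel_adapter.py` may contain biases, normalization, or other components.

```python
# Rough parameter count for a two-linear-layer bottleneck adapter applied to 13 layers.
# A sketch only; the real Parallel Adapter may contain additional components.
import torch.nn as nn

dim, bottleneck, n_layers = 1536, 1024, 13
adapter = nn.Sequential(
    nn.Linear(dim, bottleneck, bias=False),   # down projection
    nn.ReLU(),
    nn.Linear(bottleneck, dim, bias=False),   # up projection
)
per_layer = sum(p.numel() for p in adapter.parameters())
print(per_layer * n_layers)  # 40,894,464 ≈ 40.9M, matching the estimate above
```
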
407 | ## 2. Upsampling method of USM feature
408 |
409 | When feeding SSL features into WaveFit in Miipher-2, upsampling is performed along the time axis to match the frame rate expected by the vocoder (Sec.2.3). The specific upsampling method (e.g., 'nearest', 'linear', or 'bilinear') is not described in the paper, but the choice is unlikely to significantly affect performance. Therefore, this repository uses `linear` interpolation as the default.
410 |
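This corresponds to the interpolation step in `src/model/miipher_2.py`; a rough sketch is shown below. The scale factor and frame rates here are illustrative examples, not the repository's actual values.

```python
# Upsampling SSL features along the time axis before feeding them to WaveFit.
# Mirrors the interpolation used in src/model/miipher_2.py; values are examples.
import torch
import torch.nn.functional as F

feats = torch.randn(2, 250, 1536)                              # (batch, num_frames, dim)
feats = feats.transpose(1, 2)                                  # (batch, dim, num_frames)
feats = F.interpolate(feats, scale_factor=4.0, mode="linear")  # e.g. 4x higher frame rate
feats = feats.transpose(1, 2)                                  # (batch, num_frames * 4, dim)
```
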
411 | ## 3. Loss function of feature cleaner
412 |
413 | The loss function for training the feature cleaner is defined in the Miipher paper [2] as follows.
414 |
415 | ```math
416 | \mathcal{L} = \| S - \hat{S} \|_{1}
417 | + \| S - \hat{S} \|_{2}^{2}
418 | + \frac{\| S - \hat{S} \|_{2}^{2}}{\| S \|_{2}^{2}},
419 | \quad \text{where }
420 | \| S \|_{p} = \left( \sum_{k} \sum_{d} | S_{k,d} |^{p} \right)^{1/p}.
421 | ```
422 |
423 | In the paper, the first term is referred to as "mean-absolute-error" and the second term as "mean-squared-error," but the formulas do not actually compute the mean. In practice, when calculating this loss, the first and second terms become disproportionately large compared to the third term (spectral convergence loss). Therefore, it is reasonable to compute the loss as follows:
424 |
425 | ```math
426 | \mathcal{L} = \frac{1}{KD}\| S - \hat{S} \|_{1}
427 | + \frac{1}{KD}\| S - \hat{S} \|_{2}^{2}
428 | + \frac{\| S - \hat{S} \|_{2}^{2}}{\| S \|_{2}^{2}}
429 | ```
430 |
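A minimal PyTorch sketch of this normalized loss is shown below, assuming `S` and `S_hat` are feature tensors of shape `(batch, K, D)`; it is not necessarily identical to the loss implemented in this repository.

```python
# Sketch of the normalized feature-cleaner loss described above.
# Not necessarily identical to this repository's implementation.
import torch

def feature_cleaner_loss(S_hat: torch.Tensor, S: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # S, S_hat: (batch, K frames, D dims)
    diff = S - S_hat
    mae = diff.abs().mean(dim=(1, 2))                  # (1/KD) * ||S - S_hat||_1
    mse = diff.pow(2).mean(dim=(1, 2))                 # (1/KD) * ||S - S_hat||_2^2
    sc = diff.pow(2).sum(dim=(1, 2)) / (S.pow(2).sum(dim=(1, 2)) + eps)  # spectral convergence term
    return (mae + mse + sc).mean()

loss = feature_cleaner_loss(torch.randn(4, 250, 1536), torch.randn(4, 250, 1536))
```
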
431 | # References
432 |
433 | 1. "Miipher-2: A Universal Speech Restoration Model for Million-Hour Scale Data Restoration", S. Karita et al., WASPAA 2025
434 | 1. "Miipher: A Robust Speech Restoration Model Integrating Self-Supervised Speech and Text Representations", Y. Koizumi et al., WASPAA 2023
435 | 1. "WaveFit: An Iterative and Non-autoregressive Neural Vocoder based on Fixed-Point Iteration", Y. Koizumi et al., IEEE SLT 2022
436 | 1. "Google USM: Scaling Automatic Speech Recognition Beyond 100 Languages", Y. Zhang et al., arXiv, 2023
437 | 1. "LibriTTS-R: A Restored Multi-Speaker Text-to-Speech Corpus", Y. Koizumi et al., INTERSPEECH 2023
--------------------------------------------------------------------------------