├── docs ├── README.ja.md ├── note.md ├── image.png ├── 評価指標.md ├── LIGHTNING_VOCODER_MIGRATION.md └── checkpoint_guide.md ├── configs ├── evaluate.yaml ├── wavlm_large_l2.yaml ├── infer_file.yaml ├── infer_mhubert_l6.yaml ├── infer_wavlm_base_l2.yaml ├── infer_wav2vec2_base_l2.yaml ├── infer_hubert_large_l2.yaml ├── adapter_l2.yaml ├── wavlm_base_l2.yaml ├── wav2vec2_base_l2.yaml ├── wav2vec2_large_l2.yaml ├── ums.yaml ├── preprocess_fleurs_r.yaml ├── adapter_layer_6_mhubert_147.yaml ├── preprocess_jvs.yaml └── preprocess_libritts_r.yaml ├── scripts ├── __init__.py ├── check_webdataset.py ├── decompress_fleurs_r.sh ├── check_hf_datasets_head.py ├── count_webdataset_samples.py ├── download_libri_tts_r.sh ├── check_hf_datasets_stats.py ├── fix_checkpoint.py ├── eval_figure.py ├── upload_to_hf.py ├── generate_latex_table.py └── aggregate_results.py ├── .python-version ├── cmds ├── pre_train_vocoder.py ├── __init__.py ├── train_adapter.py ├── inference_dir.py ├── inference.py ├── preprocess.py ├── degrade.py ├── check_adapter_arch.py ├── evaluate.py └── evaluate_simple.py ├── src └── miipher_2 │ ├── data │ ├── __init__.py │ ├── dataloader.py │ └── webdataset_loader.py │ ├── dataset │ ├── __init__.py │ ├── jvs_corpus.py │ ├── libritts_r.py │ └── fleurs_r.py │ ├── model │ ├── __init__.py │ └── feature_cleaner.py │ ├── train │ ├── __init__.py │ └── adapter.py │ ├── adapters │ ├── __init__.py │ └── parallel_adapter.py │ ├── extractors │ ├── __init__.py │ ├── hubert.py │ └── ssl_extractor.py │ ├── lightning_vocoders │ ├── __init__.py │ ├── lightning_module.py │ └── _hifigan.py │ ├── utils │ ├── __init__.py │ ├── audio_utils.py │ ├── audio.py │ ├── ema.py │ ├── infer.py │ ├── checkpoint.py │ └── eval_utils.py │ └── preprocess │ ├── __init__.py │ ├── preprocessor.py │ └── noise_augmentation.py ├── demo ├── .gitignore ├── requirements.txt ├── pyproject.toml ├── models │ └── miipher2 │ │ ├── config.json │ │ └── README.md ├── README_spaces.md ├── README.md └── app.py ├── samples ├── sample.wav ├── miipher1.wav ├── miipher2_p.wav ├── miipher2_hubert_base_l6.wav ├── miipher2_wavlm_base_l2.wav ├── miipher2_hubert_large_l2.wav └── miipher2_wav2vec2_base_l2.wav ├── .gitmodules ├── .gitignore ├── results ├── summary_original.csv ├── summary_8khz.csv ├── summary_degrade.csv └── results_table.tex ├── CLAUDE.md ├── pyproject.toml ├── auto_resume_training.sh └── README.md /docs/README.ja.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/evaluate.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /cmds/pre_train_vocoder.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/wavlm_large_l2.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/cmds/__init__.py: -------------------------------------------------------------------------------- 1 | # noqa: A005 2 | -------------------------------------------------------------------------------- /src/miipher_2/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/miipher_2/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/miipher_2/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/miipher_2/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/miipher_2/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/miipher_2/extractors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/.gitignore: -------------------------------------------------------------------------------- 1 | models/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /docs/note.md: -------------------------------------------------------------------------------- 1 | ``` 2 | CUDA_VISIBLE_DEVICES=1 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atotti/miipher-2/HEAD/docs/image.png -------------------------------------------------------------------------------- /src/miipher_2/lightning_vocoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Lightning vocoders package.""" 2 | -------------------------------------------------------------------------------- /src/miipher_2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .audio_utils import DataCollatorAudioPad 2 | -------------------------------------------------------------------------------- /samples/sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atotti/miipher-2/HEAD/samples/sample.wav -------------------------------------------------------------------------------- /samples/miipher1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atotti/miipher-2/HEAD/samples/miipher1.wav -------------------------------------------------------------------------------- /samples/miipher2_p.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atotti/miipher-2/HEAD/samples/miipher2_p.wav -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "miipher2-hf"] 2 | path = miipher2-hf 3 | url = 
https://github.com/Atotti/miipher2-hf.git 4 | -------------------------------------------------------------------------------- /samples/miipher2_hubert_base_l6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atotti/miipher-2/HEAD/samples/miipher2_hubert_base_l6.wav -------------------------------------------------------------------------------- /samples/miipher2_wavlm_base_l2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atotti/miipher-2/HEAD/samples/miipher2_wavlm_base_l2.wav -------------------------------------------------------------------------------- /samples/miipher2_hubert_large_l2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atotti/miipher-2/HEAD/samples/miipher2_hubert_large_l2.wav -------------------------------------------------------------------------------- /samples/miipher2_wav2vec2_base_l2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atotti/miipher-2/HEAD/samples/miipher2_wav2vec2_base_l2.wav -------------------------------------------------------------------------------- /src/miipher_2/extractors/hubert.py: -------------------------------------------------------------------------------- 1 | from .ssl_extractor import HubertExtractor, SSLExtractor 2 | 3 | __all__ = ["HubertExtractor", "SSLExtractor"] 4 | -------------------------------------------------------------------------------- /demo/requirements.txt: -------------------------------------------------------------------------------- 1 | # UI framework 2 | gradio>=4.0.0 3 | # Hugging Face Hub 4 | huggingface_hub>=0.16.0 5 | # miipher-2 implementation 6 | git+https://github.com/Atotti/miipher-2.git 7 | 8 | -------------------------------------------------------------------------------- /demo/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "miipher-demo" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [] 8 | -------------------------------------------------------------------------------- /src/miipher_2/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from miipher_2.preprocess.noise_augmentation import DegradationApplier 2 | from miipher_2.preprocess.preprocessor import Preprocessor 3 | 4 | __all__ = [ 5 | "DegradationApplier", 6 | "Preprocessor", 7 | ] 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python-generated files 2 | __pycache__/ 3 | *.py[oc] 4 | build/ 5 | dist/ 6 | wheels/ 7 | .ruff_cache/ 8 | *.egg-info 9 | 10 | # Virtual environments 11 | .venv 12 | 13 | 14 | wavefit-pytorch/ 15 | tmp/ 16 | download/ 17 | datasets/ 18 | 19 | wandb/ 20 | exp/ 21 | assets/ 22 | outputs/ 23 | models/ 24 | -------------------------------------------------------------------------------- /cmds/train_adapter.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | 4 | from miipher_2.train.adapter import train_adapter 5 | 6 | 7 | @hydra.main(version_base=None, config_path="../configs", 
config_name="adapter") 8 | def main(cfg: DictConfig) -> None: # pylint:disable=no-value-for-parameter 9 | train_adapter(cfg) 10 | 11 | 12 | if __name__ == "__main__": 13 | main() 14 | -------------------------------------------------------------------------------- /results/summary_original.csv: -------------------------------------------------------------------------------- 1 | name,ecapa_cos_mean,ecapa_cos_std,count,dnsmos_p808_mean,dnsmos_p808_std,dnsmos_sig_mean,dnsmos_sig_std,dnsmos_bak_mean,dnsmos_bak_std,dnsmos_ovr_mean,dnsmos_ovr_std 2 | original,0.9999999868869781,7.436952758635848e-08,300,3.050243421100435,0.23513381104495876,3.2309267350593416,0.3713587652836165,3.2253586782567814,0.4590828290237858,2.596128871731708,0.37026601122320457 3 | -------------------------------------------------------------------------------- /cmds/inference_dir.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | 4 | # 新しく作成する一括処理用の関数をインポート 5 | from miipher_2.utils.infer import run_inference_dir 6 | 7 | 8 | @hydra.main(version_base=None, config_path="../configs", config_name="infer_dir") 9 | def main(cfg: DictConfig) -> None: 10 | """ 11 | Hydra経由で設定を読み込み、ディレクトリ単位の音声修復を実行する 12 | """ 13 | run_inference_dir(cfg) 14 | 15 | 16 | if __name__ == "__main__": 17 | main() 18 | -------------------------------------------------------------------------------- /scripts/check_webdataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import webdataset as wds 4 | 5 | 6 | def main(path: str) -> None: 7 | ds = wds.WebDataset( 8 | path, 9 | ) 10 | data = next(iter(ds)) 11 | print(data) 12 | 13 | 14 | if __name__ == "__main__": 15 | ap = argparse.ArgumentParser() 16 | ap.add_argument("--data", required=True, help="path used in wds.WebDataset()") 17 | args = ap.parse_args() 18 | main(args.data) 19 | -------------------------------------------------------------------------------- /cmds/inference.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | 4 | # `main`関数を`run_inference`にリネームして、より責務を明確にします 5 | from miipher_2.utils.infer import run_inference 6 | 7 | 8 | @hydra.main(version_base=None, config_path="../configs", config_name="infer") 9 | def main(cfg: DictConfig) -> None: 10 | """ 11 | Hydra経由で設定を読み込み、推論処理を実行するエントリーポイント 12 | """ 13 | run_inference(cfg) 14 | 15 | 16 | if __name__ == "__main__": 17 | main() 18 | -------------------------------------------------------------------------------- /cmds/preprocess.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from lightning.pytorch import seed_everything 3 | from omegaconf import DictConfig 4 | 5 | from miipher_2.preprocess import Preprocessor 6 | 7 | 8 | @hydra.main(version_base=None, config_path="../configs/", config_name="preprocess") # type: ignore 9 | def main(cfg: DictConfig) -> None: 10 | seed_everything(172957) 11 | preprocessor = Preprocessor(cfg=cfg) 12 | preprocessor.build_from_path() 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /configs/infer_file.yaml: -------------------------------------------------------------------------------- 1 | input_wav: "samples/sample.wav" 2 | output_wav: "samples/miipher2_p.wav" 3 | device: "cuda" 4 | 5 | 6 | adapter_ckpt: 
"/home/ayu/GitHub/open-miipher-2/exp/adapter_layer_6_mhubert_147/checkpoint_199k_fixed.pt" 7 | 8 | vocoder_ckpt: /home/ayu/GitHub/open-miipher-2/exp/ssl-vocoder/epoch=77-step=137108.ckpt 9 | 10 | model: 11 | hubert_model_name: "utter-project/mHuBERT-147" 12 | hubert_layer: 6 13 | adapter_hidden_dim: 768 14 | 15 | 16 | # --- 出力設定 --- 17 | output_sampling_rate: 22050 18 | -------------------------------------------------------------------------------- /scripts/decompress_fleurs_r.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | cd /home/ayu/datasets 6 | cd fleurs-r/data 7 | 8 | find . -type f -path "*/audio/*.tar.gz" -print0 | 9 | while IFS= read -r -d '' tgz; do 10 | aud_dir=$(dirname "$tgz") # .../audio 11 | split=$(basename "$tgz" .tar.gz) # train / dev / test 12 | out_dir="$aud_dir/$split" # .../audio/train など 13 | mkdir -p "$out_dir" 14 | echo "⇢ extracting $tgz → $out_dir" 15 | tar -xzf "$tgz" -C "$out_dir" # 展開 16 | done 17 | -------------------------------------------------------------------------------- /scripts/check_hf_datasets_head.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import datasets 4 | 5 | 6 | def main(path: str, n_head: int) -> None: 7 | ds = datasets.load_from_disk(path) 8 | 9 | first_examples = ds.take(n_head) 10 | for i, ex in enumerate(first_examples, 1): 11 | print(f"[{i}] {ex}\n") 12 | 13 | 14 | if __name__ == "__main__": 15 | ap = argparse.ArgumentParser() 16 | ap.add_argument("--data", required=True, help="path used in datasets.load_from_disk()") 17 | ap.add_argument("--head", type=int, default=10, help="show this many examples and exit (default: 10)") 18 | args = ap.parse_args() 19 | main(args.data, args.head) 20 | -------------------------------------------------------------------------------- /src/miipher_2/adapters/parallel_adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ParallelAdapter(nn.Module): 6 | def __init__(self, dim: int, hidden: int = 1024) -> None: 7 | super().__init__() 8 | self.ff = nn.Sequential( 9 | nn.LayerNorm(dim), 10 | nn.Linear(dim, hidden), 11 | nn.GELU(), 12 | nn.Linear(hidden, dim), 13 | ) 14 | # Xavier 初期化を 0.01 スケールで 15 | for m in self.ff: 16 | if isinstance(m, nn.Linear): 17 | nn.init.xavier_uniform_(m.weight, gain=0.01) 18 | if m.bias is not None: 19 | nn.init.zeros_(m.bias) 20 | 21 | def forward(self, x: torch.Tensor) -> torch.Tensor: # (B, T, C) 22 | return self.ff(x) 23 | -------------------------------------------------------------------------------- /src/miipher_2/utils/audio_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F # noqa: N812 3 | 4 | 5 | class DataCollatorAudioPad: 6 | """ 7 | Pad 1-D waveforms to the max length in batch (zero-padding). 8 | Returns dict ready for Trainer. 
9 | """ 10 | 11 | def __call__(self, features: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: 12 | inputs = [f["input_values"] for f in features] 13 | targets = [f["labels"] for f in features] 14 | 15 | max_len = max(x.size(-1) for x in inputs) 16 | batch_in = torch.stack([F.pad(x, (0, max_len - x.size(-1))) for x in inputs]) 17 | batch_lbl = torch.stack([F.pad(y, (0, max_len - y.size(-1))) for y in targets]) 18 | 19 | return {"input_values": batch_in, "labels": batch_lbl} 20 | -------------------------------------------------------------------------------- /scripts/count_webdataset_samples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import webdataset as wds 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser(description="Count samples in a WebDataset") 8 | parser.add_argument( 9 | "path_pattern", 10 | type=str, 11 | help="Path pattern for the WebDataset shards (e.g., 'data/jvs-train-{000000..000020}.tar.gz')", 12 | ) 13 | args = parser.parse_args() 14 | 15 | # resampled=False にして、1周でループが終了するように設定 16 | dataset = wds.WebDataset(args.path_pattern, resampled=False) 17 | 18 | count = 0 19 | for _ in dataset: 20 | count += 1 21 | if count % 1000 == 0: 22 | print(f"Counted {count} samples...") 23 | 24 | print(f"Total samples found: {count}") 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /configs/infer_mhubert_l6.yaml: -------------------------------------------------------------------------------- 1 | # --- 一括処理の基本設定 --- 2 | input_dir: "/home/ayu/GitHub/miipher-plaoground/PA_E3" # PA_E3, samples_8khz_16khz, degrade_samples 3 | output_dir: "/home/ayu/GitHub/miipher-plaoground/mhubert_l6/PA_E3" 4 | 5 | # 処理対象とする音声ファイルの拡張子リスト 6 | extensions: 7 | - ".wav" 8 | 9 | # --- 推論の基本設定 --- 10 | device: "cuda" # "cuda" または "cpu" 11 | 12 | # --- チェックポイントのパス --- 13 | # Adapter学習で生成された最終モデル 14 | adapter_ckpt: "/home/ayu/GitHub/open-miipher-2/exp/adapter_layer_6_mhubert_147/checkpoint_199k_fixed.pt" 15 | 16 | vocoder_ckpt: /home/ayu/GitHub/open-miipher-2/exp/ssl-vocoder/epoch=77-step=137108.ckpt 17 | 18 | 19 | # --- モデル設定 (Fine-tuning時と一致させる) --- 20 | model: 21 | hubert_model_name: "utter-project/mHuBERT-147" 22 | hubert_layer: 6 23 | adapter_hidden_dim: 768 24 | 25 | 26 | # --- 出力設定 --- 27 | output_sampling_rate: 22050 28 | -------------------------------------------------------------------------------- /configs/infer_wavlm_base_l2.yaml: -------------------------------------------------------------------------------- 1 | # --- 一括処理の基本設定 --- 2 | input_dir: "/home/ayu/GitHub/miipher-plaoground/PA_E3" # PA_E3, samples_8khz_16khz, degrade_samples 3 | output_dir: "/home/ayu/GitHub/miipher-plaoground/wavlm_base_l2/PA_E3" 4 | 5 | # 処理対象とする音声ファイルの拡張子リスト 6 | extensions: 7 | - ".wav" 8 | 9 | # --- 推論の基本設定 --- 10 | device: "cuda" # "cuda" または "cpu" 11 | 12 | # --- チェックポイントのパス --- 13 | # Adapter学習で生成された最終モデル 14 | adapter_ckpt: "/home/ayu/GitHub/open-miipher-2/exp/wavlm_base_l2/checkpoint_199k.pt" 15 | 16 | vocoder_ckpt: /home/ayu/GitHub/ssl-vocoders/tb_logs/lightning_logs/version_28/checkpoints/epoch=219-step=404920.ckpt 17 | 18 | 19 | # --- モデル設定 (Fine-tuning時と一致させる) --- 20 | model: 21 | hubert_model_name: "microsoft/wavlm-base" 22 | hubert_layer: 2 23 | adapter_hidden_dim: 768 24 | 25 | 26 | # --- 出力設定 --- 27 | output_sampling_rate: 22050 28 | -------------------------------------------------------------------------------- 
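All of the `infer_*.yaml` files in `configs/` share the same schema: checkpoint paths, a `model` block that must mirror the settings used when the adapter was trained, and an output sampling rate. Below is a minimal sketch of inspecting one of these configs outside of Hydra, using OmegaConf (already a project dependency); the printed values come from the file shown above.

```python
from omegaconf import OmegaConf

# Load one of the inference configs listed in this repository.
cfg = OmegaConf.load("configs/infer_wavlm_base_l2.yaml")

# The `model` block must match the adapter's training configuration.
print(cfg.model.hubert_model_name)  # "microsoft/wavlm-base"
print(cfg.model.hubert_layer)       # 2
print(cfg.output_sampling_rate)     # 22050
```

At runtime, `cmds/inference.py` and `cmds/inference_dir.py` receive a config with this structure as a Hydra `DictConfig` and pass it straight to `run_inference` / `run_inference_dir`.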
/configs/infer_wav2vec2_base_l2.yaml: -------------------------------------------------------------------------------- 1 | # --- 一括処理の基本設定 --- 2 | input_dir: "/home/ayu/GitHub/miipher-plaoground/PA_E3" # PA_E3, samples_8khz_16khz, degrade_samples 3 | output_dir: "/home/ayu/GitHub/miipher-plaoground/wav2vec2_base_l2/PA_E3" 4 | 5 | # 処理対象とする音声ファイルの拡張子リスト 6 | extensions: 7 | - ".wav" 8 | 9 | # --- 推論の基本設定 --- 10 | device: "cuda" # "cuda" または "cpu" 11 | 12 | # --- チェックポイントのパス --- 13 | # Adapter学習で生成された最終モデル 14 | adapter_ckpt: "/home/ayu/GitHub/open-miipher-2/exp/wav2vec2_base_l2/checkpoint_199k.pt" 15 | 16 | vocoder_ckpt: /home/ayu/GitHub/ssl-vocoders/tb_logs/lightning_logs/version_29/checkpoints/epoch=125-step=227636.ckpt 17 | 18 | 19 | # --- モデル設定 (Fine-tuning時と一致させる) --- 20 | model: 21 | hubert_model_name: "rinna/japanese-wav2vec2-base" 22 | hubert_layer: 2 23 | adapter_hidden_dim: 768 24 | 25 | 26 | # --- 出力設定 --- 27 | output_sampling_rate: 22050 28 | -------------------------------------------------------------------------------- /configs/infer_hubert_large_l2.yaml: -------------------------------------------------------------------------------- 1 | # --- 一括処理の基本設定 --- 2 | input_dir: "/home/ayu/GitHub/miipher-plaoground/degrade_samples" # PA_E3, samples_8khz_16khz, degrade_samples 3 | output_dir: "/home/ayu/GitHub/miipher-plaoground/hubert_large_l2/degrade_samples" 4 | 5 | # 処理対象とする音声ファイルの拡張子リスト 6 | extensions: 7 | - ".wav" 8 | 9 | # --- 推論の基本設定 --- 10 | device: "cuda" # "cuda" または "cpu" 11 | 12 | # --- チェックポイントのパス --- 13 | # Adapter学習で生成された最終モデル 14 | adapter_ckpt: "/home/ayu/GitHub/open-miipher-2/exp/adapter_l2/checkpoint_199k.pt" 15 | 16 | vocoder_ckpt: /home/ayu/GitHub/ssl-vocoders/tb_logs/lightning_logs/version_22/checkpoints/epoch=240-step=445008.ckpt 17 | 18 | 19 | # --- モデル設定 (Fine-tuning時と一致させる) --- 20 | model: 21 | hubert_model_name: "rinna/japanese-hubert-large" 22 | hubert_layer: 2 23 | adapter_hidden_dim: 1024 24 | 25 | 26 | # --- 出力設定 --- 27 | output_sampling_rate: 22050 28 | -------------------------------------------------------------------------------- /demo/models/miipher2/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "miipher2", 3 | "architecture": "speech_enhancement", 4 | "components": { 5 | "adapter": { 6 | "architecture": "parallel_adapter", 7 | "base_model": "utter-project/mHuBERT-147", 8 | "hubert_layer": 6, 9 | "adapter_hidden_dim": 768, 10 | "checkpoint_file": "checkpoint_199k_fixed.pt", 11 | "training_steps": "199k" 12 | }, 13 | "vocoder": { 14 | "architecture": "lightning_ssl_vocoder", 15 | "base_architecture": "hifigan", 16 | "checkpoint_file": "epoch=77-step=137108.ckpt", 17 | "training_epoch": 77, 18 | "training_step": 137108 19 | } 20 | }, 21 | "model_description": "Miipher-2: Complete speech enhancement system with Parallel Adapter and SSL-Vocoder", 22 | "output_sampling_rate": 22050, 23 | "version": "1.0.0", 24 | "paper": "Miipher-2: Speech Enhancement with Parallel Adapters", 25 | "license": "Apache-2.0", 26 | "authors": "Miipher-2 Team" 27 | } -------------------------------------------------------------------------------- /scripts/download_libri_tts_r.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e # エラー時に即停止 4 | 5 | mkdir -p data/download 6 | cd data/download 7 | 8 | BASE_URL="https://www.openslr.org/resources/141" 9 | FILES=( 10 | "doc.tar.gz" 11 | "dev_clean.tar.gz" 12 | "dev_other.tar.gz" 13 | 
"test_clean.tar.gz" 14 | "test_other.tar.gz" 15 | "train_clean_100.tar.gz" 16 | "train_clean_360.tar.gz" 17 | "train_other_500.tar.gz" 18 | "libritts_r_failed_speech_restoration_examples.tar.gz" 19 | "md5sum.txt" 20 | ) 21 | 22 | # 保存先ディレクトリを作成 23 | mkdir -p libritts_r 24 | cd libritts_r 25 | 26 | # ダウンロード 27 | for FILE in "${FILES[@]}"; do 28 | echo "Downloading $FILE..." 29 | curl -L -O "$BASE_URL/$FILE" 30 | done 31 | 32 | # チェックサム検証 33 | echo "Verifying checksums..." 34 | if command -v md5sum &> /dev/null; then 35 | md5sum -c md5sum.txt 36 | elif command -v md5 &> /dev/null; then 37 | while read -r CHECKSUM FILE; do 38 | echo "$CHECKSUM $FILE" | md5 -r -c - 39 | done < md5sum.txt 40 | else 41 | echo "Error: md5sum Not found." 42 | exit 1 43 | fi 44 | 45 | echo "All files downloaded and verified successfully." 46 | -------------------------------------------------------------------------------- /scripts/check_hf_datasets_stats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | 4 | import datasets 5 | import tqdm 6 | 7 | SR = 24_000 8 | 9 | 10 | def main(path: str) -> None: 11 | ds = datasets.load_from_disk(path).cast_column("audio", datasets.Audio(sampling_rate=SR)) 12 | utterances = len(ds) 13 | langs = collections.defaultdict(float) 14 | tot_samples = 0 15 | 16 | for ex in tqdm.tqdm(ds, desc="scanning"): 17 | n = len(ex["audio"]["array"]) 18 | tot_samples += n 19 | langs[ex["language"]] += n / SR 20 | 21 | tot_hours = tot_samples / SR / 3600 22 | 23 | print("=== Dataset Stats ===") 24 | print(f" utterances : {utterances:,}") 25 | print(f" total duration : {tot_hours:,.2f} h") 26 | print(" duration by language (h):") 27 | for lg, sec in sorted(langs.items(), key=lambda x: -x[1]): 28 | print(f" {lg:<5} : {sec / 3600:6.1f}") 29 | 30 | 31 | if __name__ == "__main__": 32 | ap = argparse.ArgumentParser() 33 | ap.add_argument("--data", required=True, help="path passed to datasets.load_from_disk()") 34 | main(ap.parse_args().data) 35 | -------------------------------------------------------------------------------- /scripts/fix_checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # 元のチェックポイントファイルを指定 4 | checkpoint_path = 'exp/adapter_layer_6_mhubert_147/checkpoint_199k.pt' 5 | 6 | # weights_only=False を追加して読み込む 7 | print(f"Loading checkpoint: {checkpoint_path}") 8 | checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) 9 | print("Checkpoint loaded successfully.") 10 | 11 | # 新しいstate_dictを作成 12 | model_state_dict = checkpoint['model_state_dict'] 13 | new_state_dict = {} 14 | 15 | # 'extractor.hubert.' を 'extractor.model.' に置換する 16 | for key, value in model_state_dict.items(): 17 | if key.startswith('extractor.hubert.'): 18 | new_key = key.replace('extractor.hubert.', 'extractor.model.', 1) 19 | new_state_dict[new_key] = value 20 | else: 21 | new_state_dict[key] = value # その他のキーはそのままコピー 22 | 23 | # 新しいstate_dictをチェックポイントにセットする 24 | checkpoint['model_state_dict'] = new_state_dict 25 | 26 | # 修正したチェックポイントを新しいファイルとして保存する 27 | new_checkpoint_path = 'exp/adapter_layer_6_mhubert_147/checkpoint_199k_fixed.pt' 28 | torch.save(checkpoint, new_checkpoint_path) 29 | 30 | print(f"Fixed checkpoint saved to: {new_checkpoint_path}") 31 | -------------------------------------------------------------------------------- /docs/評価指標.md: -------------------------------------------------------------------------------- 1 | ### 1. 
MCD (Mel-Cepstral Distortion) 2 | Mel-cepstral distortion 3 | - What it measures: how much the quality and timbre of the audio are distorted (i.e. differ) relative to the original clean speech, by evaluating the similarity of the spectral envelope (the frequency characteristics that give a voice its identity). 4 | - Interpretation: lower is better. The closer to 0, the closer the timbre is to the original speech. It is one of the most commonly used metrics for evaluating voice conversion and speech synthesis quality. 5 | 6 | 7 | ### 2. XvecCos / ECAPACos 8 | Speaker-embedding cosine similarity 9 | - What it measures: how well the speaker identity of the audio is preserved relative to the original clean speech, computed as the similarity between speaker embeddings. 10 | - Interpretation: higher is better. It is a cosine similarity with a maximum of 1.0; the closer to 1.0, the more confidently the model judges both voices to belong to the same speaker. 11 | 12 | 13 | ### 3. WER (Word Error Rate) 14 | Word error rate 15 | - What it measures: the intelligibility of the speech. An automatic speech recognition (ASR) model transcribes the audio, and the transcript is scored against the reference text of the original clean speech. 16 | - Interpretation: lower is better. The closer to 0%, the more clearly and accurately the utterance can be understood. The formula is `(substitutions + deletions + insertions) / total number of reference words`. 17 | 18 | 19 | ### 4. logF0-RMSE (log F0 Root Mean Square Error) 20 | Root-mean-square error of the log fundamental frequency 21 | - What it measures: how accurately the pitch and intonation of the voice are reproduced relative to the original clean speech. F0 (the fundamental frequency) is the physical correlate of voice pitch. 22 | - Interpretation: lower is better. The closer to 0, the more exactly the intonation matches the original speech. Because human hearing perceives pitch on a logarithmic scale, the error is usually computed on log F0 rather than raw F0. 23 | 24 | 25 | ### Summary 26 | 27 | | Metric | What it evaluates | Better score | Description | 28 | | :--- | :--- | :--- | :--- | 29 | | **MCD** | Quality / timbre | **Lower** | Distortion of the spectral envelope. | 30 | | **XvecCos / ECAPACos** | Speaker identity | **Higher** | Whether the speaker identity is preserved. | 31 | | **WER** | Intelligibility | **Lower** | Whether the content can be recognized correctly. | 32 | | **logF0-RMSE** | Intonation / pitch | **Lower** | Whether the natural pitch contour is reproduced. | 33 | -------------------------------------------------------------------------------- /src/miipher_2/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | 4 | import torch 5 | import torchaudio 6 | from torch.utils.data import Dataset 7 | 8 | from miipher_2.utils.audio import SR, add_noise, add_reverb, codec 9 | 10 | # Constants for magic numbers 11 | REVERB_PROBABILITY = 0.5 12 | CODEC_PROBABILITY = 0.8 13 | 14 | 15 | class CleanNoisyDataset(Dataset): 16 | """ 17 | Returns: noisy (B,1,T), clean (B,1,T) 18 | """ 19 | 20 | def __init__(self, wav_files: list[str | Path]) -> None: 21 | self.wav_files = [Path(f) for f in wav_files] 22 | 23 | def __len__(self) -> int: 24 | return len(self.wav_files) 25 | 26 | def _degrade(self, wav: torch.Tensor) -> torch.Tensor: 27 | if random.random() < REVERB_PROBABILITY: 28 | wav = add_reverb(wav, rt60=random.uniform(0.2, 0.5)) 29 | wav = add_noise(wav, snr_db=random.uniform(5, 30)) 30 | if random.random() < CODEC_PROBABILITY: 31 | wav = codec(wav, random.choice(["mp3", "opus", "vorbis", "alaw", "amr"])) 32 | return wav 33 | 34 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 35 | wav, sr = torchaudio.load(self.wav_files[idx]) 36 | if sr != SR: 37 | wav = torchaudio.functional.resample(wav, sr, SR) 38 | wav = wav.mean(0, keepdim=True) # mono 39 | noisy = self._degrade(wav.clone()) 40 | return noisy, wav # (1,T), (1,T) 41 | -------------------------------------------------------------------------------- /configs/adapter_l2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | hubert_model_name: "rinna/japanese-hubert-large" # HuBERTベースモデル名 3 | hubert_layer: 2 4 | adapter_hidden_dim: 1024 # Adapterの中間層の次元数 5 | 6 | dataset: 7 | path_pattern: 8 | - /home/ayu/datasets/jvs_preprocessed/jvs-train-{000000..000025}.tar.gz 9 | val_path_pattern: 10 | - /home/ayu/datasets/jvs_preprocessed/jvs-val-{000000..000001}.tar.gz 11 | shuffle: 1000 # WebDataset 内部 shuffle バッファ 12 | batch_size: 8 13 | steps: 200000 14 | validation_interval: 500 15 | validation_batches: 1 16 | 17 | training: 18 | gradient_accumulation_steps: 1 19 |
mixed_precision: "no" 20 | dataloader_drop_last: true 21 | 22 | optim: 23 | lr: 2.0e-4 24 | weight_decay: 0.01 25 | betas: [0.9, 0.95] 26 | max_grad_norm: 1.0 27 | scheduler: 28 | name: "constant_with_warmup" 29 | warmup_steps: 100 30 | 31 | loader: 32 | num_workers: 8 33 | pin_memory: true 34 | 35 | save_dir: exp/adapter_l2 36 | log_interval: 100 # iter ごとに損失表示 37 | 38 | # Checkpoint configuration 39 | checkpoint: 40 | save_interval: 1000 # 1kステップごとにチェックポイント保存 41 | resume_from: null # 再開用チェックポイントパス 42 | keep_last_n: 500 # 最新N個のチェックポイントを保持 43 | save_wandb_metadata: true # wandb情報も保存 44 | 45 | # Wandb logging configuration 46 | wandb: 47 | enabled: true 48 | project: "miipher-2-adapter" 49 | entity: null # デフォルトのwandbエンティティを使用 50 | name: null # 実行名を自動生成 51 | tags: ["hubert", "adapter", "training"] 52 | notes: "Parallel Adapter training for Miipher-2" 53 | log_model: false 54 | -------------------------------------------------------------------------------- /configs/wavlm_base_l2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | hubert_model_name: "microsoft/wavlm-base" # SSLモデル名 3 | hubert_layer: 2 4 | adapter_hidden_dim: 768 # Adapterの中間層の次元数 5 | 6 | 7 | dataset: 8 | path_pattern: 9 | - /home/ayu/datasets/jvs_preprocessed/jvs-train-{000000..000025}.tar.gz 10 | val_path_pattern: 11 | - /home/ayu/datasets/jvs_preprocessed/jvs-val-{000000..000001}.tar.gz 12 | shuffle: 1000 # WebDataset 内部 shuffle バッファ 13 | batch_size: 16 14 | steps: 200000 15 | validation_interval: 500 16 | validation_batches: 1 17 | 18 | training: 19 | gradient_accumulation_steps: 1 20 | mixed_precision: "no" 21 | dataloader_drop_last: true 22 | 23 | optim: 24 | lr: 2.0e-4 25 | weight_decay: 0.01 26 | betas: [0.9, 0.95] 27 | max_grad_norm: 1.0 28 | scheduler: 29 | name: "constant_with_warmup" 30 | warmup_steps: 100 31 | 32 | loader: 33 | num_workers: 8 34 | pin_memory: true 35 | 36 | save_dir: exp/wavlm_base_l2 37 | log_interval: 100 # iter ごとに損失表示 38 | 39 | # Checkpoint configuration 40 | checkpoint: 41 | save_interval: 1000 # 1kステップごとにチェックポイント保存 42 | resume_from: null # 再開用チェックポイントパス 43 | keep_last_n: 500 # 最新N個のチェックポイントを保持 44 | save_wandb_metadata: true # wandb情報も保存 45 | 46 | # Wandb logging configuration 47 | wandb: 48 | enabled: true 49 | project: "miipher-2-adapter" 50 | entity: null # デフォルトのwandbエンティティを使用 51 | name: null # 実行名を自動生成 52 | tags: ["wavlm", "adapter", "training"] 53 | notes: "Parallel Adapter training for Miipher-2" 54 | log_model: false 55 | -------------------------------------------------------------------------------- /configs/wav2vec2_base_l2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | hubert_model_name: "rinna/japanese-wav2vec2-base" # SSLモデル名 3 | hubert_layer: 2 4 | adapter_hidden_dim: 768 # Adapterの中間層の次元数 5 | model_type: "wav2vec2" # モデルタイプを明示的に指定 6 | 7 | dataset: 8 | path_pattern: 9 | - /home/ayu/datasets/jvs_preprocessed/jvs-train-{000000..000025}.tar.gz 10 | val_path_pattern: 11 | - /home/ayu/datasets/jvs_preprocessed/jvs-val-{000000..000001}.tar.gz 12 | shuffle: 1000 # WebDataset 内部 shuffle バッファ 13 | batch_size: 16 14 | steps: 200000 15 | validation_interval: 500 16 | validation_batches: 1 17 | 18 | training: 19 | gradient_accumulation_steps: 1 20 | mixed_precision: "no" 21 | dataloader_drop_last: true 22 | 23 | optim: 24 | lr: 2.0e-4 25 | weight_decay: 0.01 26 | betas: [0.9, 0.95] 27 | max_grad_norm: 1.0 28 | scheduler: 29 | name: "constant_with_warmup" 30 | warmup_steps: 100 31 | 32 
| loader: 33 | num_workers: 8 34 | pin_memory: true 35 | 36 | save_dir: exp/wav2vec2_base_l2 37 | log_interval: 100 # iter ごとに損失表示 38 | 39 | # Checkpoint configuration 40 | checkpoint: 41 | save_interval: 1000 # 1kステップごとにチェックポイント保存 42 | resume_from: null # 再開用チェックポイントパス 43 | keep_last_n: 500 # 最新N個のチェックポイントを保持 44 | save_wandb_metadata: true # wandb情報も保存 45 | 46 | # Wandb logging configuration 47 | wandb: 48 | enabled: true 49 | project: "miipher-2-adapter" 50 | entity: null # デフォルトのwandbエンティティを使用 51 | name: null # 実行名を自動生成 52 | tags: ["wav2vec2", "adapter", "training"] 53 | notes: "Parallel Adapter training for Miipher-2" 54 | log_model: false 55 | -------------------------------------------------------------------------------- /configs/wav2vec2_large_l2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | hubert_model_name: "facebook/wav2vec2-large-xlsr-53" # SSLモデル名 3 | hubert_layer: 2 4 | adapter_hidden_dim: 1024 # Adapterの中間層の次元数 5 | model_type: "wav2vec2" # モデルタイプを明示的に指定 6 | 7 | dataset: 8 | path_pattern: 9 | - /home/ayu/datasets/jvs_preprocessed/jvs-train-{000000..000025}.tar.gz 10 | val_path_pattern: 11 | - /home/ayu/datasets/jvs_preprocessed/jvs-val-{000000..000001}.tar.gz 12 | shuffle: 1000 # WebDataset 内部 shuffle バッファ 13 | batch_size: 8 14 | steps: 200000 15 | validation_interval: 500 16 | validation_batches: 1 17 | 18 | training: 19 | gradient_accumulation_steps: 1 20 | mixed_precision: "no" 21 | dataloader_drop_last: true 22 | 23 | optim: 24 | lr: 2.0e-4 25 | weight_decay: 0.01 26 | betas: [0.9, 0.95] 27 | max_grad_norm: 1.0 28 | scheduler: 29 | name: "constant_with_warmup" 30 | warmup_steps: 100 31 | 32 | loader: 33 | num_workers: 8 34 | pin_memory: true 35 | 36 | save_dir: exp/wav2vec2_large_l2 37 | log_interval: 100 # iter ごとに損失表示 38 | 39 | # Checkpoint configuration 40 | checkpoint: 41 | save_interval: 1000 # 1kステップごとにチェックポイント保存 42 | resume_from: null # 再開用チェックポイントパス 43 | keep_last_n: 500 # 最新N個のチェックポイントを保持 44 | save_wandb_metadata: true # wandb情報も保存 45 | 46 | # Wandb logging configuration 47 | wandb: 48 | enabled: true 49 | project: "miipher-2-adapter" 50 | entity: null # デフォルトのwandbエンティティを使用 51 | name: null # 実行名を自動生成 52 | tags: ["wav2vec2", "adapter", "training"] 53 | notes: "Parallel Adapter training for Miipher-2" 54 | log_model: false 55 | -------------------------------------------------------------------------------- /results/summary_8khz.csv: -------------------------------------------------------------------------------- 1 | name,ecapa_cos_mean,ecapa_cos_std,count,dnsmos_p808_mean,dnsmos_p808_std,dnsmos_sig_mean,dnsmos_sig_std,dnsmos_bak_mean,dnsmos_bak_std,dnsmos_ovr_mean,dnsmos_ovr_std 2 | original,0.9999999868869781,7.436952758635848e-08,300,3.050243421100435,0.23513381104495876,3.2309267350593416,0.3713587652836165,3.2253586782567814,0.4590828290237858,2.596128871731708,0.37026601122320457 3 | 8khz_degraded,0.6724535860617955,0.07182601741817676,300,2.754626776322486,0.20492790223668916,3.2194568280633713,0.3530187739853043,3.2523069589040765,0.4412417081563691,2.5905245107841663,0.36210046582205996 4 | hubert_large_l2,0.4939083201686541,0.09100778810102289,300,2.9762201866006097,0.17697652104010989,3.1332352828245598,0.24435322035339524,3.5224471241749407,0.30528709624211,2.550861301299079,0.22267570896207334 5 | 
mhubert_l6,0.4328003499408563,0.10876108810551353,300,3.3253866619723182,0.2802464441122725,3.31210966479506,0.20738430355982668,4.089147848566308,0.08552702153371265,3.0442597620828993,0.21342142880950118 6 | miipher_1,0.4645548557738463,0.09016346312598403,300,3.4354899323062287,0.2720519377461226,3.2815594052001313,0.19911308807159914,4.044542159916082,0.10586306360835754,3.0086758017910626,0.210636941585021 7 | wav2vec2_base_l2,0.5179871685306231,0.10191337619534478,300,3.262616547255289,0.22771706425077864,3.190366639959415,0.19636426279724933,4.043022268836082,0.09618723276048831,2.885088914715393,0.21082752772753016 8 | wavlm_base_l2,0.5220717957615852,0.11150853319185897,300,3.312712347473417,0.22931823836750326,3.211558293290149,0.20107913191611443,4.053447031675322,0.09105232063963782,2.9114680497813272,0.218296744025802 9 | -------------------------------------------------------------------------------- /results/summary_degrade.csv: -------------------------------------------------------------------------------- 1 | name,ecapa_cos_mean,ecapa_cos_std,count,dnsmos_p808_mean,dnsmos_p808_std,dnsmos_sig_mean,dnsmos_sig_std,dnsmos_bak_mean,dnsmos_bak_std,dnsmos_ovr_mean,dnsmos_ovr_std 2 | original,0.9999999868869781,7.436952758635848e-08,300,3.050243421100435,0.23513381104495876,3.2309267350593416,0.3713587652836165,3.2253586782567814,0.4590828290237858,2.596128871731708,0.37026601122320457 3 | noise_degraded,0.9494669326146443,0.041764857755093844,300,2.803015498628692,0.2353215734550551,2.4490575133884236,0.8251526660236204,1.9792567965151198,0.6208523921044735,1.7907272296298289,0.5290055705886365 4 | hubert_large_l2,0.688731447160244,0.06276563002826571,300,3.1072856800385886,0.23204310756646057,3.271121402580063,0.1886189890298814,3.64059973449956,0.24070088856824107,2.7577867637388573,0.2206539901929696 5 | mhubert_l6,0.5240665201842785,0.10163385066011302,300,3.4740742364478487,0.27810250410296766,3.348599887545861,0.19448651716509424,4.095195120179198,0.10021498835544398,3.0956964281955064,0.2067577011406275 6 | miipher_1,0.6960392619172732,0.08371855199385392,300,3.536531817578134,0.2930578353680913,3.3452072112208695,0.1907791842371167,4.061753196933847,0.10214940153058852,3.0796673145208087,0.20281000017931688 7 | wav2vec2_base_l2,0.6112834508220355,0.08663966863767417,300,3.518550822131217,0.25897106099315137,3.3521183915677053,0.19190458214226785,4.087553531975515,0.0939524196407561,3.0846790690758463,0.20412992137199953 8 | wavlm_base_l2,0.6478396042188008,0.08288519060456478,300,3.5249499833281077,0.2503511193584741,3.3355331494241525,0.18966846982122368,4.080146309112477,0.1001820208404826,3.063547710308511,0.20543776670348513 9 | -------------------------------------------------------------------------------- /demo/README_spaces.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Miipher-2 Speech Enhancement Demo 3 | emoji: 🎵 4 | colorFrom: blue 5 | colorTo: purple 6 | sdk: gradio 7 | sdk_version: 4.0.0 8 | app_file: app.py 9 | pinned: false 10 | license: apache-2.0 11 | --- 12 | 13 | # Miipher-2 Speech Enhancement Demo 14 | 15 | Miipher-2 is a speech enhancement system that uses Parallel Adapters inserted into mHuBERT layers to improve audio quality. 
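To make the pipeline concrete, the sketch below builds the adapter-equipped feature cleaner from this repository and pushes a dummy waveform through it. This is a minimal sketch only: the trained adapter and vocoder checkpoints are not loaded, and the random tensor stands in for a real 16 kHz recording (the model and layer values mirror `configs/adapter_layer_6_mhubert_147.yaml`).

```python
import torch
from omegaconf import OmegaConf

from miipher_2.model.feature_cleaner import FeatureCleaner

# Model block mirroring the mHuBERT-147 adapter config used by this demo.
cfg_model = OmegaConf.create(
    {
        "hubert_model_name": "utter-project/mHuBERT-147",
        "hubert_layer": 6,
        "adapter_hidden_dim": 768,
    }
)

cleaner = FeatureCleaner(cfg_model).eval()

# One second of placeholder 16 kHz audio; real input would be a loaded, resampled waveform.
wav16 = torch.randn(1, 16000)
with torch.no_grad():
    feats = cleaner(wav16)  # (B, C, T/320) cleaned SSL features, later consumed by the vocoder

print(feats.shape)
```

The enhanced waveform itself is produced by the SSL vocoder in a second step, which the demo's `app.py` wires up with the downloaded checkpoints.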
16 | 17 | ## Features 18 | 19 | - **Real-time speech enhancement** from noisy or degraded audio 20 | - **Parallel Adapter architecture** for efficient fine-tuning 21 | - **Lightning SSL-Vocoder** for high-quality audio synthesis 22 | - **Easy-to-use Gradio interface** 23 | 24 | ## Model Architecture 25 | 26 | 1. **SSL Feature Extractor**: mHuBERT-147 (Layer 6) 27 | 2. **Parallel Adapter**: Lightweight feedforward network 28 | 3. **Lightning SSL-Vocoder**: HiFi-GAN based vocoder 29 | 30 | ## Usage 31 | 32 | 1. Upload an audio file or record using your microphone 33 | 2. Click "音声を修復" (Enhance Audio) 34 | 3. Listen to the enhanced audio output 35 | 36 | ## Models 37 | 38 | The demo automatically downloads the unified model from: 39 | - Complete Model: `Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1` (includes both Adapter and Vocoder) 40 | 41 | ## Technical Details 42 | 43 | - **Input**: Audio files (WAV, MP3, FLAC) 44 | - **Output**: Enhanced audio at 22050Hz 45 | - **Supported Languages**: Primarily trained on Japanese but works with other languages 46 | - **Processing**: Real-time inference on CPU/GPU 47 | 48 | ## License 49 | 50 | Apache-2.0 51 | 52 | ## Citation 53 | 54 | If you use Miipher-2 in your research, please cite: 55 | 56 | ```bibtex 57 | @article{miipher2, 58 | title={Miipher-2: Speech Enhancement with Parallel Adapters}, 59 | author={Your Name}, 60 | year={2024} 61 | } 62 | ``` -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Miipher 2 HuBERT HiFi GAN V0.1 3 | emoji: 🎤 4 | colorFrom: blue 5 | colorTo: purple 6 | sdk: gradio 7 | sdk_version: 5.38.0 8 | app_file: app.py 9 | pinned: false 10 | license: apache-2.0 11 | models: 12 | - Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1 13 | --- 14 | 15 | # 🎤 Miipher-2 Speech Enhancement Demo 16 | 17 | This is a Gradio demo for **Miipher-2**, a high-quality speech enhancement model that combines HuBERT, Parallel Adapters, and HiFi-GAN vocoder. 18 | 19 | ## Features 20 | 21 | - **Real-time speech enhancement** - Remove noise, reverb, and other degradations 22 | - **Multilingual support** - Built on mHuBERT-147 for 147 languages 23 | - **High-quality output** - 22.05kHz audio output 24 | - **Easy to use** - Simple drag-and-drop or microphone input 25 | 26 | ## Model Details 27 | 28 | - **Paper**: [Miipher-2: High-Quality Speech Enhancement](https://arxiv.org/abs/2505.04457) 29 | - **Model**: [Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1](https://huggingface.co/Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1) 30 | - **GitHub**: [open-miipher-2](https://github.com/your-repo/open-miipher-2) 31 | 32 | ## How to Use 33 | 34 | 1. **Upload** an audio file or record using microphone 35 | 2. Click **"Enhance Audio"** button 36 | 3. 
**Download** the enhanced result 37 | 38 | ## Technical Details 39 | 40 | The model uses: 41 | - **SSL Backbone**: mHuBERT-147 (multilingual) 42 | - **Adapter**: Parallel adapters inserted at layer 6 43 | - **Vocoder**: HiFi-GAN trained on SSL features 44 | - **Input**: Any sample rate (auto-resampled to 16kHz) 45 | - **Output**: 22.05kHz enhanced audio 46 | 47 | ## Citation 48 | 49 | ```bibtex 50 | @article{miipher2024, 51 | title={Miipher-2: High-Quality Speech Enhancement via Self-Supervised Learning}, 52 | author={Your Name and Others}, 53 | journal={arXiv preprint arXiv:2505.04457}, 54 | year={2024} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /configs/ums.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | hubert_model_name: "Atotti/Google-USM" 3 | hubert_layer: 12 4 | adapter_hidden_dim: 1536 # Adapterの中間層の次元数 5 | 6 | dataset: 7 | path_pattern: 8 | - /home/ayu/datasets/jvs_preprocessed/jvs-train-{000000..000025}.tar.gz 9 | - /home/ayu/datasets/fleurs-r_preprocessed/fleurs-r-train-{000000..000910}.tar.gz 10 | - /home/ayu/datasets/libritts_r_preprocessed/libritts_r-train-{000000..000504}.tar.gz 11 | val_path_pattern: 12 | - /home/ayu/datasets/jvs_preprocessed/jvs-val-{000000..000001}.tar.gz 13 | shuffle: 1000 # WebDataset 内部 shuffle バッファ 14 | batch_size: 16 15 | steps: 2000000 16 | validation_interval: 10000 17 | validation_batches: 16 18 | 19 | training: 20 | gradient_accumulation_steps: 1 21 | mixed_precision: "no" 22 | dataloader_drop_last: true 23 | 24 | optim: 25 | lr: 2.0e-4 26 | weight_decay: 0.01 27 | betas: [0.9, 0.95] 28 | max_grad_norm: 1.0 29 | scheduler: 30 | name: "cosine_with_restarts" 31 | warmup_steps: 10000 32 | first_cycle_steps: 100000 # 最初のサイクルの長さ 33 | cycle_mult: 1.0 # サイクル長の倍率 34 | max_lr: 2.0e-4 # 最大学習率 35 | min_lr: 1.0e-6 # 最小学習率 36 | 37 | loader: 38 | num_workers: 8 39 | pin_memory: true 40 | 41 | save_dir: exp/usm # チェックポイント保存先 42 | log_interval: 1000 # iter ごとに損失表示 43 | 44 | # Checkpoint configuration 45 | checkpoint: 46 | save_interval: 10000 # 1kステップごとにチェックポイント保存 47 | resume_from: null # 再開用チェックポイントパス 48 | keep_last_n: 500 # 最新N個のチェックポイントを保持 49 | save_wandb_metadata: true # wandb情報も保存 50 | 51 | # Wandb logging configuration 52 | wandb: 53 | enabled: true 54 | project: "miipher-2-adapter" 55 | entity: null # デフォルトのwandbエンティティを使用 56 | name: null # 実行名を自動生成 57 | tags: ["usm", "adapter", "training"] 58 | notes: "Parallel Adapter training for Miipher-2" 59 | log_model: false 60 | -------------------------------------------------------------------------------- /configs/preprocess_fleurs_r.yaml: -------------------------------------------------------------------------------- 1 | preprocess: 2 | preprocess_dataset: 3 | _target_: miipher_2.dataset.fleurs_r.FleursRCorpus 4 | root: /home/ayu/datasets/fleurs-r/ 5 | degradation: 6 | format_encoding_pairs: 7 | - format: mp3 8 | compression: 16 9 | - format: mp3 10 | compression: 32 11 | - format: mp3 12 | compression: 64 13 | - format: mp3 14 | compression: 128 15 | - format: ogg 16 | compression: -1 17 | - format: ogg 18 | compression: 0 19 | - format: ogg 20 | compression: 1 21 | - format: wav 22 | encoding: ALAW 23 | bits_per_sample: 8 24 | reverb_conditions: 25 | p: 0.5 26 | reverbation_times: 27 | max: 0.5 28 | min: 0.2 29 | room_xy: 30 | max: 10.0 31 | min: 2.0 32 | room_z: 33 | max: 5.0 34 | min: 2.0 35 | room_params: 36 | fs: 22050 37 | max_order: 10 38 | absorption: 0.2 39 | source_pos: 40 | - 1.0 41 | - 
1.0 42 | - 1.0 43 | mic_pos: 44 | - 1.0 45 | - 0.7 46 | - 1.2 47 | n_rirs: 1000 48 | background_noise: 49 | snr: 50 | max: 30.0 51 | min: 5.0 52 | patterns: 53 | - 54 | - /home/audio/TAU2023/dataset/TAU-urban-acoustic-scenes-2022-mobile-development/audio/ 55 | - '**/*.wav' 56 | - 57 | - /home/audio/TAU2021/datasets/TAU-urban-acoustic-scenes-2020-mobile-development/audio/ 58 | - '**/*.wav' 59 | train_tar_sink: 60 | _target_: webdataset.ShardWriter 61 | pattern: /home/ayu/datasets/fleurs-r_preprocessed/fleurs-r-train-%06d.tar.gz 62 | val_tar_sink: 63 | _target_: webdataset.ShardWriter 64 | pattern: /home/ayu/datasets/fleurs-r_preprocessed/fleurs-r-val-%06d.tar.gz 65 | val_size: 600 66 | n_repeats: 4 67 | sampling_rate: 22050 68 | 69 | -------------------------------------------------------------------------------- /src/miipher_2/extractors/ssl_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import HubertModel, Wav2Vec2Model 3 | 4 | 5 | class SSLExtractor(torch.nn.Module): 6 | def __init__(self, model_name: str, layer: int, model_type: str = "auto") -> None: 7 | super().__init__() 8 | self.layer = layer 9 | self.model_type = model_type 10 | 11 | if model_type == "auto": 12 | if "hubert" in model_name.lower(): 13 | model_type = "hubert" 14 | elif "wav2vec2" in model_name.lower(): 15 | model_type = "wav2vec2" 16 | elif "wavlm" in model_name.lower(): 17 | model_type = "hubert" # WavLM is based on HuBERT 18 | else: 19 | msg = f"Cannot auto-detect model type for {model_name}" 20 | raise ValueError(msg) 21 | 22 | if model_type == "hubert": 23 | self.model = HubertModel.from_pretrained(model_name, output_hidden_states=True) 24 | elif model_type == "wav2vec2": 25 | self.model = Wav2Vec2Model.from_pretrained(model_name, output_hidden_states=True) 26 | else: 27 | msg = f"Unsupported model type: {model_type}" 28 | raise ValueError(msg) 29 | 30 | self.model_type = model_type 31 | 32 | def forward(self, wav: torch.Tensor) -> torch.Tensor: 33 | """ 34 | Args: 35 | wav: (B, T) float32, 16 kHz, -1 ~ 1 36 | Returns: 37 | feat: (B, C, T/320) 50 Hz 38 | """ 39 | outputs = self.model(wav, output_hidden_states=True, return_dict=True) 40 | hs: list[torch.Tensor] = outputs.hidden_states 41 | # 指定された層を転置して返す 42 | return hs[self.layer + 1].transpose(1, 2).contiguous() 43 | 44 | 45 | # 後方互換性のためのエイリアス 46 | class HubertExtractor(SSLExtractor): 47 | def __init__(self, model_name: str, layer: int) -> None: 48 | super().__init__(model_name, layer, model_type="hubert") 49 | -------------------------------------------------------------------------------- /configs/adapter_layer_6_mhubert_147.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | hubert_model_name: "utter-project/mHuBERT-147" # HuBERTベースモデル名 3 | hubert_layer: 6 4 | adapter_hidden_dim: 768 # Adapterの中間層の次元数 5 | 6 | dataset: 7 | path_pattern: 8 | - /home/ayu/datasets/jvs_preprocessed/jvs-train-{000000..000025}.tar.gz 9 | - /home/ayu/datasets/fleurs-r_preprocessed/fleurs-r-train-{000000..000918}.tar.gz 10 | - /home/ayu/datasets/libritts_r_preprocessed/libritts_r-train-{000000..000504}.tar.gz 11 | val_path_pattern: 12 | - /home/ayu/datasets/jvs_preprocessed/jvs-val-{000000..000001}.tar.gz 13 | shuffle: 1000 # WebDataset 内部 shuffle バッファ 14 | batch_size: 16 15 | steps: 2000000 16 | validation_interval: 10000 17 | validation_batches: 1 18 | 19 | training: 20 | gradient_accumulation_steps: 1 21 | mixed_precision: "no" 22 | 
dataloader_drop_last: true 23 | 24 | optim: 25 | lr: 2.0e-4 26 | weight_decay: 0.01 27 | betas: [0.9, 0.95] 28 | max_grad_norm: 1.0 29 | scheduler: 30 | name: "cosine_with_restarts" 31 | warmup_steps: 10000 32 | first_cycle_steps: 100000 # 最初のサイクルの長さ 33 | cycle_mult: 1.0 # サイクル長の倍率 34 | max_lr: 2.0e-4 # 最大学習率 35 | min_lr: 1.0e-6 # 最小学習率 36 | 37 | loader: 38 | num_workers: 8 39 | pin_memory: true 40 | 41 | save_dir: exp/adapter_layer_6_mhubert_147_multi_lang # チェックポイント保存先 42 | log_interval: 1000 # iter ごとに損失表示 43 | 44 | # Checkpoint configuration 45 | checkpoint: 46 | save_interval: 10000 # 1kステップごとにチェックポイント保存 47 | resume_from: null # 再開用チェックポイントパス 48 | keep_last_n: 500 # 最新N個のチェックポイントを保持 49 | save_wandb_metadata: true # wandb情報も保存 50 | 51 | # Wandb logging configuration 52 | wandb: 53 | enabled: true 54 | project: "miipher-2-adapter" 55 | entity: null # デフォルトのwandbエンティティを使用 56 | name: null # 実行名を自動生成 57 | tags: ["hubert", "adapter", "training"] 58 | notes: "Parallel Adapter training for Miipher-2" 59 | log_model: false 60 | -------------------------------------------------------------------------------- /configs/preprocess_jvs.yaml: -------------------------------------------------------------------------------- 1 | preprocess: 2 | preprocess_dataset: 3 | _target_: torch.utils.data.ConcatDataset 4 | datasets: 5 | - 6 | _target_: miipher_2.dataset.jvs_corpus.JVSCorpus 7 | root: /home/ayu/datasets/jvs_ver1/ 8 | degradation: 9 | format_encoding_pairs: 10 | - format: mp3 11 | compression: 16 12 | - format: mp3 13 | compression: 32 14 | - format: mp3 15 | compression: 64 16 | - format: mp3 17 | compression: 128 18 | - format: ogg 19 | compression: -1 20 | - format: ogg 21 | compression: 0 22 | - format: ogg 23 | compression: 1 24 | - format: wav 25 | encoding: ALAW 26 | bits_per_sample: 8 27 | reverb_conditions: 28 | p: 0.5 29 | reverbation_times: 30 | max: 0.5 31 | min: 0.2 32 | room_xy: 33 | max: 10.0 34 | min: 2.0 35 | room_z: 36 | max: 5.0 37 | min: 2.0 38 | room_params: 39 | fs: 22050 40 | max_order: 10 41 | absorption: 0.2 42 | source_pos: 43 | - 1.0 44 | - 1.0 45 | - 1.0 46 | mic_pos: 47 | - 1.0 48 | - 0.7 49 | - 1.2 50 | n_rirs: 1000 51 | background_noise: 52 | snr: 53 | max: 30.0 54 | min: 5.0 55 | patterns: 56 | - 57 | - /home/audio/TAU2023/dataset/TAU-urban-acoustic-scenes-2022-mobile-development/audio/ 58 | - '**/*.wav' 59 | - 60 | - /home/audio/TAU2021/datasets/TAU-urban-acoustic-scenes-2020-mobile-development/audio/ 61 | - '**/*.wav' 62 | train_tar_sink: 63 | _target_: webdataset.ShardWriter 64 | pattern: /home/ayu/datasets/jvs_preprocessed/jvs-train-%06d.tar.gz 65 | val_tar_sink: 66 | _target_: webdataset.ShardWriter 67 | pattern: /home/ayu/datasets/jvs_preprocessed/jvs-val-%06d.tar.gz 68 | val_size: 600 69 | n_repeats: 4 70 | sampling_rate: 22050 71 | 72 | -------------------------------------------------------------------------------- /configs/preprocess_libritts_r.yaml: -------------------------------------------------------------------------------- 1 | preprocess: 2 | preprocess_dataset: 3 | _target_: torch.utils.data.ConcatDataset 4 | datasets: 5 | - 6 | _target_: miipher_2.dataset.libritts_r.LibriTTSRCorpus 7 | root: /home/ayu/datasets/libritts_r/LibriTTS_R/ 8 | degradation: 9 | format_encoding_pairs: 10 | - format: mp3 11 | compression: 16 12 | - format: mp3 13 | compression: 32 14 | - format: mp3 15 | compression: 64 16 | - format: mp3 17 | compression: 128 18 | - format: ogg 19 | compression: -1 20 | - format: ogg 21 | compression: 0 22 | - format: ogg 23 | compression: 
1 24 | - format: wav 25 | encoding: ALAW 26 | bits_per_sample: 8 27 | reverb_conditions: 28 | p: 0.5 29 | reverbation_times: 30 | max: 0.5 31 | min: 0.2 32 | room_xy: 33 | max: 10.0 34 | min: 2.0 35 | room_z: 36 | max: 5.0 37 | min: 2.0 38 | room_params: 39 | fs: 22050 40 | max_order: 10 41 | absorption: 0.2 42 | source_pos: 43 | - 1.0 44 | - 1.0 45 | - 1.0 46 | mic_pos: 47 | - 1.0 48 | - 0.7 49 | - 1.2 50 | n_rirs: 1000 51 | background_noise: 52 | snr: 53 | max: 30.0 54 | min: 5.0 55 | patterns: 56 | - 57 | - /home/audio/TAU2023/dataset/TAU-urban-acoustic-scenes-2022-mobile-development/audio/ 58 | - '**/*.wav' 59 | - 60 | - /home/audio/TAU2021/datasets/TAU-urban-acoustic-scenes-2020-mobile-development/audio/ 61 | - '**/*.wav' 62 | train_tar_sink: 63 | _target_: webdataset.ShardWriter 64 | pattern: /home/ayu/datasets/libritts_r_preprocessed/libritts_r-train-%06d.tar.gz 65 | val_tar_sink: 66 | _target_: webdataset.ShardWriter 67 | pattern: /home/ayu/datasets/libritts_r_preprocessed/libritts_r-val-%06d.tar.gz 68 | val_size: 600 69 | n_repeats: 4 70 | sampling_rate: 22050 -------------------------------------------------------------------------------- /src/miipher_2/utils/audio.py: -------------------------------------------------------------------------------- 1 | import io 2 | import subprocess 3 | from pathlib import Path 4 | 5 | import torch 6 | import torchaudio 7 | 8 | SR = 16000 9 | 10 | 11 | def load(path: str | Path) -> torch.Tensor: 12 | wav, sr = torchaudio.load(path) 13 | if sr != SR: 14 | wav = torchaudio.functional.resample(wav, sr, SR) 15 | return wav.mean(0, keepdim=True) # mono 16 | 17 | 18 | def save(path: str | Path, wav: torch.Tensor, sr: int) -> None: 19 | torchaudio.save(path, wav, sr) 20 | 21 | 22 | # --- 劣化関数 ----------------------------------------------- 23 | def add_noise(wav: torch.Tensor, snr_db: float = 20) -> torch.Tensor: 24 | noise = torch.randn_like(wav) 25 | sig_pow = wav.pow(2).mean() 26 | noise_pow = noise.pow(2).mean() 27 | factor = (sig_pow / noise_pow / (10 ** (snr_db / 10))) ** 0.5 28 | return wav + factor * noise 29 | 30 | 31 | def add_reverb(wav: torch.Tensor, rt60: float = 0.3) -> torch.Tensor: 32 | # sox コンボリューションを手軽に呼ぶ 33 | cmd = ["sox", "-t", "wav", "-", "-t", "wav", "-", "reverb", f"{rt60 * 1000:.1f}"] 34 | with subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as p: # noqa: S603 35 | out, _ = p.communicate(wav.cpu().numpy().astype("float32").tobytes()) 36 | wav_r, _ = torchaudio.load(io.BytesIO(out)) 37 | return wav_r.to(wav.device) 38 | 39 | 40 | def codec(wav: torch.Tensor, codec_name: str = "mp3") -> torch.Tensor: 41 | if codec_name == "mp3": 42 | enc = ["-ar", "16000", "-ac", "1", "-codec:a", "libmp3lame", "-b:a", "64k"] 43 | elif codec_name == "opus": 44 | enc = ["-ar", "16000", "-ac", "1", "-codec:a", "libopus", "-b:a", "32k"] 45 | elif codec_name == "vorbis": 46 | enc = ["-codec:a", "libvorbis", "-qscale:a", "3"] 47 | elif codec_name == "alaw": 48 | enc = ["-codec:a", "alaw"] 49 | else: # amr 50 | enc = ["-codec:a", "libopencore_amrwb", "-b:a", "16k"] 51 | cmd_enc = ["ffmpeg", "-f", "wav", "-i", "-", *enc, "-f", "wav", "-"] 52 | proc = subprocess.run( # noqa: S603 53 | cmd_enc, input=wav.cpu().numpy().astype("float32").tobytes(), capture_output=True, check=False 54 | ) 55 | wav_d, _ = torchaudio.load(io.BytesIO(proc.stdout)) 56 | return wav_d.to(wav.device) 57 | -------------------------------------------------------------------------------- /scripts/eval_figure.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import seaborn as sns 4 | 5 | noise_modes = ["degrade", "8khz", "PA_E3"] 6 | 7 | for noise_mode in noise_modes: 8 | 9 | df_v2 = pd.read_csv(f"results/{noise_mode}_miipher_2.csv") 10 | df_v1 = pd.read_csv(f"results/{noise_mode}_miipher.csv") 11 | 12 | # データを整形 13 | metrics = ["MCD", "XvecCos", "ECAPACos", "WER", "logF0_RMSE"] 14 | 15 | # Miipher v2のデータ 16 | df_v2_restored = df_v2[metrics] 17 | df_v2_restored["Condition"] = "Miipher2" 18 | 19 | # Miipher v1のデータ 20 | df_v1_restored = df_v1[metrics] 21 | df_v1_restored["Condition"] = "Miipher" 22 | 23 | # Degradedのデータ (どちらのファイルも同じはずなのでv2から使用) 24 | df_degraded = df_v2[[f"Deg_{m}" for m in metrics]] 25 | df_degraded.columns = metrics 26 | df_degraded["Condition"] = "Degraded" 27 | 28 | # 3つの条件を結合 29 | df_combined = pd.concat([df_degraded, df_v1_restored, df_v2_restored]) 30 | 31 | # プロットしやすいようにさらに整形 (Melt) 32 | df_melted = df_combined.melt(id_vars=["Condition"], var_name="Metric", value_name="Value") 33 | 34 | # 指標ごとに「高いほど良い」か「低いほど良い」かを定義 35 | lower_is_better = ["MCD", "WER", "logF0_RMSE"] 36 | 37 | # 描画 38 | sns.set_theme(style="whitegrid") 39 | g = sns.catplot( 40 | data=df_melted, 41 | x="Condition", 42 | y="Value", 43 | col="Metric", 44 | kind="box", 45 | order=["Degraded", "Miipher", "Miipher2"], 46 | palette={"Degraded": "lightcoral", "Miipher": "skyblue", "Miipher2": "mediumseagreen"}, 47 | height=4.5, 48 | aspect=0.8, 49 | sharey=False, 50 | ) 51 | 52 | # タイトルとラベルを調整 53 | g.fig.suptitle("Performance Comparison: Miipher vs Miipher2", y=1.03, size=18, weight="bold") 54 | g.set_xlabels("") 55 | g.set_ylabels("Metric Value") 56 | 57 | for _i, ax in enumerate(g.axes.flat): 58 | metric_name = ax.get_title().split("= ")[1] 59 | if metric_name in lower_is_better: 60 | ax.set_title(f"{metric_name}\n(Lower is better)") 61 | else: 62 | ax.set_title(f"{metric_name}\n(Higher is better)") 63 | 64 | plt.tight_layout(rect=[0, 0, 1, 0.97]) 65 | plt.savefig(f"results/{noise_mode}_miipher_vs_miipher2.png") 66 | -------------------------------------------------------------------------------- /src/miipher_2/model/feature_cleaner.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from torch import nn 6 | 7 | from miipher_2.adapters.parallel_adapter import ParallelAdapter 8 | from miipher_2.extractors.hubert import SSLExtractor 9 | 10 | 11 | class FeatureCleaner(nn.Module): 12 | def __init__(self, cfg_model: DictConfig) -> None: 13 | super().__init__() 14 | model_type = cfg_model.get("model_type", "auto") 15 | self.extractor = SSLExtractor( 16 | model_name=cfg_model.hubert_model_name, 17 | layer=cfg_model.hubert_layer - 1, 18 | model_type=model_type, 19 | ) 20 | 21 | # ベースとなるSSLモデルの全パラメータを凍結 22 | self.extractor.model.eval() 23 | for param in self.extractor.model.parameters(): 24 | param.requires_grad = False 25 | 26 | hubert_dim = self.extractor.model.config.hidden_size 27 | 28 | num_layers_to_patch = cfg_model.hubert_layer 29 | 30 | self.adapters = nn.ModuleList( 31 | [ParallelAdapter(dim=hubert_dim, hidden=cfg_model.adapter_hidden_dim) for _ in range(num_layers_to_patch)] 32 | ) 33 | 34 | for i, blk in enumerate(self.extractor.model.encoder.layers[:num_layers_to_patch]): 35 | original_ff_forward = blk.feed_forward.forward 36 | adapter_module = self.adapters[i] 37 | 38 | # 
Adapterを挿入 39 | def patched_forward( 40 | hidden_states: torch.Tensor, 41 | _orig_ff: Callable = original_ff_forward, 42 | _ad: ParallelAdapter = adapter_module, 43 | ) -> torch.Tensor: 44 | # 元のFeedForward(MLP)の出力を計算 45 | ff_output = _orig_ff(hidden_states) 46 | 47 | # 同じ入力からAdapterの出力を計算 48 | adapter_output = _ad(hidden_states) 49 | 50 | # 元のMLP出力にアダプターの出力を加算する 51 | return ff_output + adapter_output 52 | 53 | # feed_forwardモジュールのforwardメソッドを新しい関数で上書き 54 | blk.feed_forward.forward = patched_forward 55 | 56 | for param in blk.final_layer_norm.parameters(): 57 | param.requires_grad = True 58 | 59 | def forward(self, wav16: torch.Tensor) -> torch.Tensor: 60 | return self.extractor(wav16) 61 | -------------------------------------------------------------------------------- /cmds/degrade.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import random 4 | import time 5 | from pathlib import Path 6 | 7 | import torch 8 | import torchaudio 9 | from tqdm.auto import tqdm 10 | 11 | from miipher_2.utils.eval_utils import degrade_waveform, get_logger 12 | 13 | log = get_logger("degrade") 14 | 15 | 16 | def load_noises(noise_dir: Path, sr: int = 16000): 17 | pool = [] 18 | for p in itertools.islice(noise_dir.glob("**/*.*"), 300): 19 | try: 20 | wav, s = torchaudio.load(p) 21 | if s != sr: 22 | wav = torchaudio.functional.resample(wav, s, sr) 23 | pool.append(wav) 24 | except Exception: 25 | continue 26 | if not pool: 27 | msg = "noise_dir から有効な wav が読み込めませんでした" 28 | raise RuntimeError(msg) 29 | log.info(f"Loaded {len(pool)} noise files") 30 | return pool 31 | 32 | 33 | def main() -> None: 34 | ap = argparse.ArgumentParser() 35 | ap.add_argument("--clean_dir", required=True, type=Path) 36 | ap.add_argument("--noise_dir", required=True, type=Path) 37 | ap.add_argument("--out_dir", required=True, type=Path) 38 | ap.add_argument("--seed", type=int, default=1234) 39 | ap.add_argument("--sr", type=int, default=16000) 40 | args = ap.parse_args() 41 | 42 | random.seed(args.seed) 43 | torch.manual_seed(args.seed) 44 | 45 | noises = load_noises(args.noise_dir, args.sr) 46 | 47 | clean_files = sorted(args.clean_dir.glob("*.wav")) 48 | tot = len(clean_files) 49 | log.info(f"Start degrading {tot} wav files") 50 | 51 | t0 = time.time() 52 | for i, wav_path in enumerate(clean_files, 1): 53 | wav, sr_ = torchaudio.load(wav_path) 54 | if sr_ != args.sr: 55 | wav = torchaudio.functional.resample(wav, sr_, args.sr) 56 | 57 | degraded = degrade_waveform(wav, args.sr, noises) 58 | 59 | out_path = args.out_dir / wav_path.relative_to(args.clean_dir) 60 | out_path.parent.mkdir(parents=True, exist_ok=True) 61 | torchaudio.save(out_path, degraded, args.sr, encoding="PCM_S", bits_per_sample=16) 62 | 63 | if i % 10 == 0 or i == tot: 64 | elapsed = time.time() - t0 65 | log.info(f"{i}/{tot} files done | elapsed {elapsed / 60:.1f} min") 66 | 67 | log.info("All files degraded successfully") 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /cmds/check_adapter_arch.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | from omegaconf import DictConfig 4 | from torchinfo import summary 5 | 6 | from miipher_2.model.feature_cleaner import FeatureCleaner 7 | 8 | 9 | @hydra.main(version_base=None, config_path="../configs", config_name="adapter_layer_6_mhubert_147") 10 | def main(cfg: DictConfig) -> None: 
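    """Inspect the FeatureCleaner: print its module tree, run a torchinfo summary on a dummy (2, 48000) input (3 s of 16 kHz audio), and report total / trainable / frozen / adapter parameter counts."""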
11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | print(f"Using device: {device}") 13 | 14 | print("Creating FeatureCleaner model...") 15 | model = FeatureCleaner(cfg.model) 16 | model.to(device) 17 | model.eval() 18 | 19 | print("=" * 40, "Model architecture", "=" * 40) 20 | print(model) 21 | print("=" * 80) 22 | 23 | # Sample input (16kHz audio, 3 seconds) 24 | batch_size = 2 25 | sequence_length = 48000 # 3 seconds at 16kHz 26 | input_shape = (batch_size, sequence_length) 27 | 28 | print("=" * 40, f"\nModel architecture summary (input shape: {input_shape})", "=" * 40) 29 | 30 | summary( 31 | model, 32 | input_size=input_shape, 33 | device=device, 34 | dtypes=[torch.float32], 35 | depth=3, 36 | col_names=["input_size", "output_size", "num_params", "trainable"], 37 | row_settings=["var_names"], 38 | ) 39 | 40 | print("\n" + "=" * 80) 41 | print("Parameter counts:") 42 | 43 | total_params = sum(p.numel() for p in model.parameters()) 44 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 45 | frozen_params = total_params - trainable_params 46 | 47 | print(f"Total parameters: {total_params:,}") 48 | print(f"Trainable parameters: {trainable_params:,}") 49 | print(f"Frozen parameters: {frozen_params:,}") 50 | print(f"Trainable ratio: {trainable_params / total_params:.2%}") 51 | 52 | # Show adapter-specific parameters 53 | adapter_params = 0 54 | for name, param in model.named_parameters(): 55 | if "adapter" in name and param.requires_grad: 56 | adapter_params += param.numel() 57 | 58 | print(f"Adapter parameters: {adapter_params:,}") 59 | 60 | print("\n" + "=" * 80) 61 | print("Trainable modules:") 62 | for name, param in model.named_parameters(): 63 | if param.requires_grad: 64 | print(f" {name}: {param.shape} ({param.numel():,} params)") 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /src/miipher_2/dataset/jvs_corpus.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Literal 3 | 4 | import torchaudio 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class JVSCorpus(Dataset): 9 | def __init__(self, root: str, exclude_speakers: tuple = ()) -> None: 10 | super().__init__() 11 | self.root = Path(root) 12 | self.speakers = [f.stem for f in self.root.glob("jvs*") if f.is_dir() and f.stem not in exclude_speakers] 13 | self.clean_texts = {} 14 | self.wav_files = [] 15 | for speaker in self.speakers: 16 | transcript_files = (self.root / speaker).glob("**/transcripts_utf8.txt") 17 | for transcript_file in transcript_files: 18 | subset = transcript_file.parent.name 19 | with transcript_file.open() as f: 20 | lines = f.readlines() 21 | for line in lines: 22 | wav_name, text = line.strip().split(":") 23 | self.clean_texts[f"{speaker}/{subset}/{wav_name}"] = text 24 | wav_path = self.root / Path(f"{speaker}/{subset}/wav24kHz16bit/{wav_name}.wav") 25 | if wav_path.exists(): 26 | self.wav_files.append(wav_path) 27 | 28 | def __getitem__(self, index: int) -> dict[str, str]: 29 | wav_path = self.wav_files[index] 30 | wav_tensor, sr = torchaudio.load(wav_path) 31 | wav_path = wav_path.resolve() 32 | speaker = wav_path.parent.parent.parent.stem 33 | subset = wav_path.parent.parent.stem 34 | wav_name = wav_path.stem 35 | 36 | clean_text = self.clean_texts[f"{speaker}/{subset}/{wav_name}"] 37 | 38 | basename = f"{subset}_{speaker}_{wav_name}" 39 | return { 40 | "wav_path": 
str(wav_path), 41 | "speaker": speaker, 42 | "clean_text": clean_text, 43 | "basename": basename, 44 | "lang_code": "jpn", 45 | } 46 | 47 | def __len__(self) -> int: 48 | return len(self.wav_files) 49 | 50 | @property 51 | def speaker_dict(self) -> dict[str, int]: 52 | speakers = set() 53 | for wav_path in self.wav_files: 54 | speakers.add(wav_path.parent.parent.parent.stem) 55 | return {x: idx for idx, x in enumerate(speakers)} 56 | 57 | @property 58 | def lang_code(self) -> Literal["jpn"]: 59 | return "jpn" 60 | -------------------------------------------------------------------------------- /docs/LIGHTNING_VOCODER_MIGRATION.md: -------------------------------------------------------------------------------- 1 | # Lightning SSL-Vocoder Migration Guide 2 | 3 | ## 変更点 4 | 5 | Miipher-2 の推論機能を HiFiGAN から Lightning SSL-Vocoder に移行しました。 6 | 7 | ### 主な変更 8 | 9 | 1. **HiFiGAN + PreNet** → **Lightning SSL-Vocoder** (Conformer + HiFiGAN統合) 10 | 2. 設定ファイルの簡素化(PreNet設定が不要) 11 | 3. チェックポイント形式の変更(`.pt` → `.ckpt`) 12 | 13 | ## 必要な準備 14 | 15 | ### 1. SSL-Vocoder モデルの準備 16 | 17 | `/home/ayu/GitHub/ssl-vocoders` で学習したモデルを使用してください。 18 | 19 | ```bash 20 | # ssl-vocodersでの学習例 21 | cd /home/ayu/GitHub/ssl-vocoders 22 | python train.py --config config/wavlm_large.yaml 23 | ``` 24 | 25 | ### 2. 設定ファイルの更新 26 | 27 | #### Before (旧HiFiGAN設定): 28 | ```yaml 29 | vocoder_ckpt: "exp/hifigan_pretrain_layer_4/checkpoint_96k.pt" 30 | prenet: 31 | in_dim: 768 32 | n_layers: 4 33 | mel_dim: 80 34 | src_fps: 50.0 35 | tgt_hop: 256 36 | sr: 22050 37 | ``` 38 | 39 | #### After (Lightning SSL-Vocoder設定): 40 | ```yaml 41 | vocoder_ckpt: "/path/to/ssl-vocoder-checkpoint.ckpt" 42 | # PreNet設定は不要(Lightning SSL-Vocoderが内部で処理) 43 | ``` 44 | 45 | ## 使用方法 46 | 47 | ### 単一ファイルの推論 48 | 49 | ```bash 50 | # 設定ファイルのvocoder_ckptを更新 51 | vim configs/infer.yaml 52 | 53 | # 推論実行 54 | uv run python -m miipher_2.infer --config configs/infer.yaml 55 | ``` 56 | 57 | ### バッチ推論 58 | 59 | ```bash 60 | # 設定ファイルのvocoder_ckptを更新 61 | vim configs/infer_dir.yaml 62 | 63 | # バッチ推論実行 64 | uv run python -m miipher_2.infer_dir --config configs/infer_dir.yaml 65 | ``` 66 | 67 | ## 推奨モデル 68 | 69 | SSL-Vocodersで利用可能な事前学習済みモデル: 70 | 71 | | モデル | SSL特徴量 | 品質 | 用途 | 72 | |--------|-----------|------|------| 73 | | `wavlm-large` | WavLM Large | 最高 | 高品質推論 | 74 | | `hubert-base` | HuBERT Base | 良好 | 軽量推論 | 75 | | `wav2vec2-base` | Wav2Vec2.0 Base | 良好 | 汎用 | 76 | 77 | ## 注意事項 78 | 79 | 1. **チェックポイント形式**: `.pt` → `.ckpt` に変更 80 | 2. **PreNet設定削除**: Lightning SSL-Vocoderが内部で処理 81 | 3. **依存関係**: `lightning` パッケージが必要 82 | 4. 
**SSL特徴量の一致**: FeatureCleanerとSSL-Vocoderで同じSSL特徴量を使用 83 | 84 | ## トラブルシューティング 85 | 86 | ### ImportError: lightning_vocoders 87 | ```bash 88 | # __init__.py と hifigan.py が正しくコピーされているか確認 89 | ls src/miipher_2/lightning_vocoders/ 90 | ``` 91 | 92 | ### CheckpointNotFound 93 | ```bash 94 | # パスが正しいか確認 95 | ls -la /path/to/ssl-vocoder-checkpoint.ckpt 96 | ``` 97 | 98 | ### SSL特徴量の不一致 99 | ```bash 100 | # FeatureCleanerとSSL-Vocoderで同じHuBERTレイヤを使用しているか確認 101 | ``` -------------------------------------------------------------------------------- /results/results_table.tex: -------------------------------------------------------------------------------- 1 | \documentclass{article} 2 | \usepackage{booktabs} 3 | \usepackage{multirow} 4 | \usepackage{siunitx} 5 | \usepackage[margin=1in]{geometry} 6 | \usepackage{adjustbox} 7 | 8 | \begin{document} 9 | 10 | \begin{table}[htbp] 11 | \centering 12 | \sisetup{ 13 | table-format=1.2, 14 | round-mode=places, 15 | round-precision=2 16 | } 17 | 18 | \caption{音声復元の種類とDNSMOSスコア比較} 19 | 20 | \setlength{\tabcolsep}{4pt} 21 | 22 | \begin{adjustbox}{max width=\linewidth} 23 | \begin{tabular}{ 24 | l 25 | l 26 | *{5}{S[table-format=1.2]} 27 | } 28 | \toprule 29 | \textbf{音声復元の種類} & \textbf{劣化手法} 30 | & \textbf{ecapa cos} & \textbf{dnsmos p808} 31 | & \textbf{dnsmos sig} & \textbf{dnsmos bak} 32 | & \textbf{dnsmos} \\ 33 | \midrule 34 | \multirow{7}{*}{} 35 | original & \multirow{7}{*}{8kHzに変換} 36 | & 1.00 & 3.05 & 3.23 & 3.23 & 2.60 \\ 37 | degraded & 38 | & 0.67 & 2.75 & 3.22 & 3.25 & 2.59 \\ 39 | miipher-1 & 40 | & 0.46 & 3.44 & 3.28 & 4.04 & 3.01 \\ 41 | hubert\_large\_l2 & 42 | & 0.49 & 2.98 & 3.13 & 3.52 & 2.55 \\ 43 | mhubert\_l6 & 44 | & 0.43 & 3.33 & 3.31 & 4.09 & 3.04 \\ 45 | wav2vec2\_base\_l2 & 46 | & 0.52 & 3.26 & 3.19 & 4.04 & 2.89 \\ 47 | wavlm\_base\_l2 & 48 | & 0.52 & 3.31 & 3.21 & 4.05 & 2.91 \\ 49 | \midrule 50 | \multirow{7}{*}{} 51 | original & \multirow{7}{*}{残響・背景雑音} 52 | & 1.00 & 3.05 & 3.23 & 3.23 & 2.60 \\ 53 | degraded & 54 | & 0.95 & 2.80 & 2.45 & 1.98 & 1.79 \\ 55 | miipher-1 & 56 | & 0.70 & 3.54 & 3.35 & 4.06 & 3.08 \\ 57 | hubert\_large\_l2 & 58 | & 0.69 & 3.11 & 3.27 & 3.64 & 2.76 \\ 59 | mhubert\_l6 & 60 | & 0.52 & 3.47 & 3.35 & 4.10 & 3.10 \\ 61 | wav2vec2\_base\_l2 & 62 | & 0.61 & 3.52 & 3.35 & 4.09 & 3.08 \\ 63 | wavlm\_base\_l2 & 64 | & 0.65 & 3.52 & 3.34 & 4.08 & 3.06 \\ 65 | \bottomrule 66 | \end{tabular} 67 | \end{adjustbox} 68 | \end{table} 69 | 70 | \end{document} 71 | -------------------------------------------------------------------------------- /src/miipher_2/utils/ema.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class EMA: 8 | """ 9 | Exponential Moving Average for model parameters. 10 | This class maintains a shadow copy of the model's parameters and updates them 11 | with a decaying average of the current parameters. 12 | """ 13 | 14 | def __init__(self, model: nn.Module, decay: float) -> None: 15 | """ 16 | Args: 17 | model (nn.Module): The model to apply EMA to. 18 | decay (float): The decay factor for the moving average. 
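                Each call to update() computes shadow = decay * shadow + (1 - decay) * param, so values close to 1.0 make the shadow weights track the live parameters slowly.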
19 | """ 20 | self.model = model 21 | self.decay = decay 22 | self.shadow = OrderedDict() 23 | self.backup = OrderedDict() 24 | 25 | def register(self) -> None: 26 | """Register the EMA parameters by creating a shadow copy.""" 27 | for name, param in self.model.named_parameters(): 28 | if param.requires_grad: 29 | self.shadow[name] = param.data.clone() 30 | print(f"[INFO] Registered {len(self.shadow)} parameters for EMA.") 31 | 32 | def update(self) -> None: 33 | """Update the shadow parameters with the current model parameters.""" 34 | for name, param in self.model.named_parameters(): 35 | if param.requires_grad and name in self.shadow: 36 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] 37 | self.shadow[name] = new_average 38 | 39 | def apply_shadow(self) -> None: 40 | """Apply the shadow parameters to the model for evaluation.""" 41 | self.backup = OrderedDict() 42 | for name, param in self.model.named_parameters(): 43 | if param.requires_grad and name in self.shadow: 44 | self.backup[name] = param.data 45 | param.data = self.shadow[name] 46 | 47 | def restore(self) -> None: 48 | """Restore the original parameters from the backup.""" 49 | for name, param in self.model.named_parameters(): 50 | if param.requires_grad and name in self.backup: 51 | param.data = self.backup[name] 52 | self.backup = OrderedDict() 53 | 54 | def state_dict(self) -> OrderedDict: 55 | """Return the state dictionary of the shadow parameters.""" 56 | return self.shadow 57 | 58 | def load_state_dict(self, state_dict: OrderedDict) -> None: 59 | """Load the shadow parameters from a state dictionary.""" 60 | self.shadow = state_dict 61 | print("[INFO] Loaded EMA shadow parameters.") 62 | -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | このファイルは、このリポジトリでコードを作業する際にClaude Code (claude.ai/code) にガイダンスを提供します。 4 | 5 | ## 開発コマンド 6 | 7 | ### 環境セットアップ 8 | ```bash 9 | uv sync 10 | ``` 11 | 12 | ### リンター・フォーマッター 13 | ```bash 14 | # コードの品質チェック 15 | uv run ruff check . 16 | 17 | # コードフォーマット 18 | uv run ruff format . 19 | 20 | # 型チェック 21 | uv run mypy . 22 | ``` 23 | 24 | ### データ前処理 25 | ```bash 26 | # JVSコーパスの前処理 27 | uv run cmd/preprocess.py --config-name preprocess_jvs 28 | 29 | # LibriTTSの前処理 30 | uv run cmd/preprocess.py --config-name preprocess_libritts_r 31 | 32 | # FLEURSの前処理 33 | uv run cmd/preprocess.py --config-name preprocess_fleurs_r 34 | ``` 35 | 36 | ### モデル学習 37 | ```bash 38 | # Parallel Adapter学習 39 | uv run cmd/train_adapter.py --config-name adapter_layer_6_mhubert_147 40 | 41 | # 学習再開(特定のチェックポイントから) 42 | uv run cmd/train_adapter.py checkpoint.resume_from="exp/adapter_layer_6_mhubert_147/checkpoint_199k.pt" --config-name adapter_layer_6_mhubert_147 43 | 44 | # Lightning SSL-Vocoder事前学習 45 | uv run cmd/pre_train_vocoder.py --config-name hifigan_pretrain_layer_6_mhubert_147 46 | ``` 47 | 48 | ### 推論・評価 49 | ```bash 50 | # バッチ推論 51 | uv run cmd/inference_dir.py --config-name infer_dir 52 | 53 | # 評価用劣化音声生成 54 | uv run cmd/degrade.py --clean_dir --noise_dir --out_dir 55 | 56 | # 評価実行 57 | uv run cmd/evaluate.py --clean_dir --degraded_dir --restored_dir --outfile 58 | ``` 59 | 60 | ## アーキテクチャ概要 61 | 62 | ### プロジェクト構造 63 | - `src/miipher_2/`: メインのPythonモジュール 64 | - `cmd/`: CLI エントリーポイント 65 | - `configs/`: Hydra設定ファイル (YAML) 66 | - `exp/`: 学習チェックポイント出力先 67 | 68 | ### 主要コンポーネント 69 | 70 | #### 1. 
Parallel Adapter (`src/miipher_2/adapters/parallel_adapter.py`) 71 | - HuBERTの特定層に挿入される軽量なフィードフォワードネットワーク 72 | - LayerNorm + Linear + GELU + Linear の構成 73 | - Xavier初期化を0.01スケールで適用 74 | 75 | #### 2. HuBERT Feature Extractor (`src/miipher_2/extractors/hubert.py`) 76 | - 事前学習済みmHuBERT-147モデルを使用 77 | - 指定した層(デフォルト6層)の特徴量を抽出 78 | - Parallel Adapterが特定層に挿入される 79 | 80 | #### 3. Lightning SSL-Vocoder (`src/miipher_2/lightning_vocoders/`) 81 | - HiFi-GANベースの音声合成モデル 82 | - PyTorch Lightningでの実装 83 | - SSL特徴量からメルスペクトログラムを生成 84 | 85 | #### 4. WebDataset Loader (`src/miipher_2/data/webdataset_loader.py`) 86 | - 大規模データセットの効率的な読み込み 87 | - WebDataset形式での並列処理 88 | - 動的なバッチ生成とシャッフル 89 | 90 | #### 5. 学習・推論パイプライン 91 | - `src/miipher_2/train/adapter.py`: Adapter学習ロジック 92 | - `src/miipher_2/utils/infer.py`: 推論実行ロジック 93 | - Wandb統合による学習監視 94 | 95 | ### 設定管理 96 | - Hydra設定を使用してハイパーパラメータを管理 97 | - 主要設定: `configs/adapter_layer_6_mhubert_147.yaml` 98 | - チェックポイント機能による学習再開サポート 99 | 100 | ### データフロー 101 | 1. 音声データをWebDataset形式で前処理 102 | 2. HuBERTで特徴抽出(Parallel Adapter適用) 103 | 3. 抽出した特徴量でFeature Cleanerを学習 104 | 4. Lightning SSL-Vocoderで音声合成 105 | 5. 複数の評価指標で性能測定 -------------------------------------------------------------------------------- /demo/models/miipher2/README.md: -------------------------------------------------------------------------------- 1 | # Miipher-2: Speech Enhancement Model 2 | 3 | Complete speech enhancement system consisting of a Parallel Adapter and Lightning SSL-Vocoder. 4 | 5 | ## Model Components 6 | 7 | ### 1. Parallel Adapter 8 | - **Architecture**: Lightweight feedforward network inserted into mHuBERT-147 9 | - **Target Layer**: Layer 6 10 | - **Hidden Dimension**: 768 11 | - **Training Steps**: 199k 12 | - **File**: `checkpoint_199k_fixed.pt` 13 | 14 | ### 2. 
Lightning SSL-Vocoder 15 | - **Architecture**: HiFi-GAN based vocoder with PyTorch Lightning 16 | - **Input**: SSL features from enhanced mHuBERT 17 | - **Output**: High-quality audio at 22050Hz 18 | - **Training**: 77 epochs, 137108 steps 19 | - **File**: `epoch=77-step=137108.ckpt` 20 | 21 | ## Usage 22 | 23 | ```python 24 | import torch 25 | from omegaconf import DictConfig 26 | from miipher_2.model.feature_cleaner import FeatureCleaner 27 | from miipher_2.lightning_vocoders.lightning_module import HiFiGANLightningModule 28 | from huggingface_hub import hf_hub_download 29 | 30 | # Download model files 31 | adapter_path = hf_hub_download( 32 | repo_id="YOUR_USERNAME/miipher2", 33 | filename="checkpoint_199k_fixed.pt" 34 | ) 35 | vocoder_path = hf_hub_download( 36 | repo_id="YOUR_USERNAME/miipher2", 37 | filename="epoch=77-step=137108.ckpt" 38 | ) 39 | 40 | # Load models 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | 43 | # Feature Cleaner (Adapter) 44 | config = DictConfig({ 45 | "hubert_model_name": "utter-project/mHuBERT-147", 46 | "hubert_layer": 6, 47 | "adapter_hidden_dim": 768 48 | }) 49 | 50 | cleaner = FeatureCleaner(config).to(device).eval() 51 | checkpoint = torch.load(adapter_path, map_location=device, weights_only=False) 52 | cleaner.load_state_dict(checkpoint["model_state_dict"]) 53 | 54 | # Vocoder 55 | vocoder = HiFiGANLightningModule.load_from_checkpoint( 56 | vocoder_path, map_location=device 57 | ).to(device).eval() 58 | 59 | # Inference 60 | with torch.inference_mode(): 61 | # Extract and clean features 62 | enhanced_features = cleaner(input_audio) 63 | 64 | # Generate audio 65 | batch = {"input_feature": enhanced_features.transpose(1, 2)} 66 | restored_audio = vocoder.generator_forward(batch) 67 | ``` 68 | 69 | ## Model Performance 70 | 71 | - **Target**: Speech enhancement from noisy/degraded audio 72 | - **Training Data**: Japanese Voice Speech corpus (JVS) and multilingual datasets 73 | - **Evaluation**: Improved speech quality metrics (STOI, PESQ, etc.) 
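
The usage snippet above assumes `input_audio` is already a mono 16 kHz waveform tensor of shape `(1, num_samples)` on the target device. The sketch below shows one way to prepare such a tensor with `torchaudio` and to write the 22050 Hz output; it is illustrative rather than part of the released code, and the file names are placeholders.

```python
import torchaudio

# Load any audio file and convert it to the mono 16 kHz format the cleaner expects.
wav, sr = torchaudio.load("noisy_input.wav")  # placeholder path
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)
input_audio = wav.mean(0, keepdim=True).to(device)  # shape (1, num_samples)

# ... run the inference block shown above to obtain `restored_audio` ...

# Drop the batch dimension (as in miipher_2.utils.infer) and save at the vocoder's 22050 Hz output rate.
torchaudio.save("restored.wav", restored_audio.squeeze(0).cpu().float(), 22050)
```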
74 | 75 | ## Files 76 | 77 | - `checkpoint_199k_fixed.pt` (442MB) - Parallel Adapter weights 78 | - `epoch=77-step=137108.ckpt` (1.2GB) - Lightning SSL-Vocoder weights 79 | - `config.json` - Model configuration and metadata 80 | 81 | ## Citation 82 | 83 | ```bibtex 84 | @article{miipher2, 85 | title={Miipher-2: Speech Enhancement with Parallel Adapters}, 86 | author={Miipher-2 Team}, 87 | year={2024} 88 | } 89 | ``` 90 | 91 | ## License 92 | 93 | Apache-2.0 -------------------------------------------------------------------------------- /src/miipher_2/dataset/libritts_r.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | class LibriTTSRCorpus: 5 | def __init__(self, root: str): 6 | self.root = Path(root) 7 | self.samples = [] 8 | self._load_samples() 9 | 10 | def _load_samples(self): 11 | splits = [ 12 | "train-clean-100", 13 | "train-clean-360", 14 | "train-other-500", 15 | "dev-clean", 16 | "dev-other", 17 | "test-clean", 18 | "test-other", 19 | ] 20 | 21 | for split in splits: 22 | split_path = self.root / split 23 | if not split_path.exists(): 24 | continue 25 | 26 | for speaker_dir in split_path.glob("*"): 27 | if not speaker_dir.is_dir(): 28 | continue 29 | 30 | speaker_id = speaker_dir.name 31 | 32 | for chapter_dir in speaker_dir.glob("*"): 33 | if not chapter_dir.is_dir(): 34 | continue 35 | 36 | chapter_id = chapter_dir.name 37 | 38 | # Load transcription file 39 | trans_file = chapter_dir / f"{speaker_id}_{chapter_id}.trans.tsv" 40 | if not trans_file.exists(): 41 | continue 42 | 43 | # Parse transcription file 44 | transcriptions = {} 45 | with trans_file.open(encoding="utf-8") as f: 46 | for line in f: 47 | line = line.strip() 48 | if not line: 49 | continue 50 | parts = line.split("\t") 51 | if len(parts) >= 3: 52 | utterance_id = parts[0] 53 | normalized_text = parts[2] # Use normalized text (3rd column) 54 | transcriptions[utterance_id] = normalized_text 55 | 56 | # Find corresponding wav files 57 | for utterance_id, text in transcriptions.items(): 58 | wav_path = chapter_dir / f"{utterance_id}.wav" 59 | if wav_path.exists(): 60 | self.samples.append( 61 | { 62 | "wav_path": str(wav_path), 63 | "speaker": speaker_id, 64 | "clean_text": text, 65 | "basename": utterance_id, 66 | "lang_code": "eng", 67 | } 68 | ) 69 | 70 | def __getitem__(self, index: int) -> dict[str, str]: 71 | return self.samples[index] 72 | 73 | def __len__(self) -> int: 74 | return len(self.samples) 75 | 76 | @property 77 | def speaker_dict(self) -> dict[str, int]: 78 | speakers = sorted(set(sample["speaker"] for sample in self.samples)) 79 | return {speaker: idx for idx, speaker in enumerate(speakers)} 80 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "miipher-2" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "accelerate>=1.7.0", 9 | "datasets>=3.6.0", 10 | "hydra-core>=1.3.2", 11 | "jiwer>=4.0.0", 12 | "librosa>=0.11.0", 13 | "lightning>=2.5.1.post0", 14 | "matplotlib>=3.10.3", 15 | "numpy>=2.2.6", 16 | "peft>=0.15.2", 17 | "pillow>=11.2.1", 18 | "pyroomacoustics>=0.8.4", 19 | "pysptk>=1.0.1", 20 | "pyworld>=0.3.5", 21 | "seaborn>=0.13.2", 22 | "soundfile>=0.13.1", 23 | "speechbrain>=1.0.0", 24 | "torch>=2.7.0", 25 | "torchaudio>=2.7.0", 26 | "torchinfo>=1.8.0", 27 | 
"torchmetrics[audio]>=1.0.0", 28 | "tqdm>=4.67.1", 29 | "transformers>=4.53.1", 30 | "trl[peft]>=0.18.1", 31 | "wandb>=0.19.11", 32 | "webdataset>=0.2.111", 33 | "numba>=0.61.2", 34 | ] 35 | 36 | [build-system] 37 | requires = ["setuptools>=68", "wheel"] 38 | build-backend = "setuptools.build_meta" 39 | 40 | [tool.setuptools.package-dir] 41 | "" = "src" 42 | 43 | [tool.setuptools.packages.find] 44 | where = ["src"] 45 | 46 | [dependency-groups] 47 | dev = [ 48 | "ruff>=0.11.0", 49 | "ty>=0.0.1a10", 50 | ] 51 | 52 | [tool.ruff] 53 | line-length = 120 54 | 55 | [tool.ruff.format] 56 | docstring-code-format = true 57 | 58 | [tool.ruff.lint] 59 | select = ["ALL"] 60 | ignore = [ 61 | "T201", 62 | "COM812", 63 | "ISC001", 64 | "PGH003", 65 | "FBT003", 66 | "C901", 67 | "PLR0915", 68 | "PLR0913", 69 | ] 70 | unfixable = [ 71 | "F401", 72 | "F841", 73 | ] 74 | 75 | pydocstyle.convention = "google" 76 | 77 | [tool.ruff.lint.per-file-ignores] 78 | "*.py" = [ 79 | "D", 80 | "S101", 81 | "N802", 82 | "ARG", 83 | "S311", 84 | "S301", 85 | ] 86 | "__init__.py" = [ 87 | "F401", 88 | ] 89 | 90 | [tool.ruff.lint.pylint] 91 | max-args = 6 92 | 93 | 94 | [tool.mypy] 95 | python_version = "3.11" 96 | warn_return_any = true 97 | warn_unused_configs = true 98 | disallow_untyped_defs = true 99 | disallow_incomplete_defs = true 100 | check_untyped_defs = true 101 | disallow_untyped_decorators = true 102 | no_implicit_optional = true 103 | warn_redundant_casts = true 104 | warn_unused_ignores = true 105 | warn_no_return = true 106 | warn_unreachable = true 107 | ignore_missing_imports = true 108 | 109 | [tool.pyrefly] 110 | 111 | #### configuring what to type check and where to import from 112 | project_includes = ["."] 113 | project_excludes = ["**/.[!/.]*", "**/tests"] 114 | search_path = ["."] 115 | import_root = ["."] 116 | site_package_path = [".venv/lib/python3.12/site-packages"] 117 | 118 | #### configuring your python environment 119 | python_platform = "linux" 120 | python_version = "3.12" 121 | python_interpreter = ".venv/bin/python3" 122 | 123 | #### configuring your type check settings 124 | ignore_errors_in_generated_code = true 125 | use_untyped_imports = true 126 | ignore_missing_source = true 127 | 128 | [tool.pyrefly.errors] 129 | bad-assignment = false 130 | invalid-argument = false 131 | 132 | [tool.uv.workspace] 133 | members = [ 134 | "miipher2-hf", 135 | ] 136 | 137 | -------------------------------------------------------------------------------- /auto_resume_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Configuration 4 | CONFIG_NAME="${1:-adapter_l2}" # Default to adapter_l2 if not provided 5 | CHECKPOINT_DIR="exp/${CONFIG_NAME}" 6 | LOG_FILE="${CHECKPOINT_DIR}/training.log" 7 | 8 | # Colors for output 9 | RED='\033[0;31m' 10 | GREEN='\033[0;32m' 11 | YELLOW='\033[1;33m' 12 | NC='\033[0m' # No Color 13 | 14 | echo -e "${GREEN}Auto-resume training script for config: ${CONFIG_NAME}${NC}" 15 | echo -e "${GREEN}Checkpoint directory: ${CHECKPOINT_DIR}${NC}" 16 | 17 | # Create checkpoint directory if it doesn't exist 18 | mkdir -p "${CHECKPOINT_DIR}" 19 | 20 | # Function to find the latest checkpoint 21 | find_latest_checkpoint() { 22 | local latest_checkpoint="" 23 | local latest_step=0 24 | 25 | # Find all checkpoint files and extract the one with highest step number 26 | for checkpoint in "${CHECKPOINT_DIR}"/checkpoint_*.pt; do 27 | if [ -f "$checkpoint" ]; then 28 | # Extract step number from filename (e.g., 
checkpoint_5k.pt -> 5) 29 | step=$(basename "$checkpoint" | sed -n 's/checkpoint_\([0-9]*\)k\.pt/\1/p') 30 | if [ -n "$step" ] && [ "$step" -gt "$latest_step" ]; then 31 | latest_step=$step 32 | latest_checkpoint=$checkpoint 33 | fi 34 | fi 35 | done 36 | 37 | echo "$latest_checkpoint" 38 | } 39 | 40 | # Function to run training 41 | run_training() { 42 | local checkpoint_arg="" 43 | local latest_checkpoint=$(find_latest_checkpoint) 44 | 45 | if [ -n "$latest_checkpoint" ]; then 46 | echo -e "${YELLOW}Found checkpoint: ${latest_checkpoint}${NC}" 47 | checkpoint_arg="checkpoint.resume_from=\"${latest_checkpoint}\"" 48 | else 49 | echo -e "${YELLOW}No checkpoint found, starting from scratch${NC}" 50 | fi 51 | 52 | # Construct and run the training command 53 | local cmd="uv run cmds/train_adapter.py ${checkpoint_arg} --config-name ${CONFIG_NAME}" 54 | echo -e "${GREEN}Running: ${cmd}${NC}" 55 | 56 | # Execute the command and capture exit status 57 | eval $cmd 2>&1 | tee -a "$LOG_FILE" 58 | return ${PIPESTATUS[0]} 59 | } 60 | 61 | # Main loop with retry logic 62 | MAX_RETRIES=100 # Prevent infinite loops 63 | retry_count=0 64 | wait_time=30 # Wait time in seconds before retry 65 | 66 | while [ $retry_count -lt $MAX_RETRIES ]; do 67 | echo -e "\n${GREEN}[Attempt $((retry_count + 1))/${MAX_RETRIES}] Starting training...${NC}" 68 | 69 | # Run training 70 | run_training 71 | exit_code=$? 72 | 73 | if [ $exit_code -eq 0 ]; then 74 | echo -e "\n${GREEN}Training completed successfully!${NC}" 75 | break 76 | else 77 | echo -e "\n${RED}Training crashed with exit code: ${exit_code}${NC}" 78 | echo -e "${YELLOW}Waiting ${wait_time} seconds before retry...${NC}" 79 | 80 | # Log crash information 81 | echo "[$(date)] Training crashed, exit code: ${exit_code}" >> "$LOG_FILE" 82 | 83 | # Wait before retry 84 | sleep $wait_time 85 | 86 | # Increment retry counter 87 | ((retry_count++)) 88 | 89 | # Check if we've reached max retries 90 | if [ $retry_count -ge $MAX_RETRIES ]; then 91 | echo -e "${RED}Maximum retries (${MAX_RETRIES}) reached. 
Exiting.${NC}" 92 | exit 1 93 | fi 94 | fi 95 | done 96 | 97 | echo -e "${GREEN}Script finished.${NC}" -------------------------------------------------------------------------------- /src/miipher_2/preprocess/preprocessor.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import pathlib 4 | 5 | import hydra 6 | import torch 7 | import torchaudio 8 | import tqdm 9 | import webdataset 10 | from omegaconf import DictConfig 11 | from torch.utils.data import DataLoader 12 | 13 | from miipher_2.preprocess import DegradationApplier 14 | 15 | 16 | class Preprocessor: 17 | """ 18 | Preprocess dataset 19 | """ 20 | 21 | def __init__(self, cfg: DictConfig) -> None: 22 | """ 23 | Args: 24 | cfg: hydra config 25 | """ 26 | self.cfg = cfg 27 | self.dataset = hydra.utils.instantiate(cfg.preprocess.preprocess_dataset) 28 | self.sampling_rate = self.cfg.sampling_rate 29 | self.degradation_model = DegradationApplier(cfg.preprocess.degradation) 30 | self.text2phone_dict: dict[str, str] = {} 31 | self.n_repeats = cfg.preprocess.n_repeats 32 | 33 | @torch.inference_mode() # type: ignore 34 | def process_utterance( 35 | self, 36 | basename: str, 37 | audio_file_path: pathlib.Path, 38 | lang_code: str, 39 | ) -> list[dict[str, bytes | str]]: 40 | orig_waveform, orig_sample_rate = torchaudio.load(audio_file_path) 41 | 42 | waveform: torch.Tensor = torchaudio.functional.resample( 43 | orig_waveform, orig_sample_rate, new_freq=self.sampling_rate 44 | )[0] # remove channel dimension only support mono 45 | 46 | with audio_file_path.open(mode="rb") as f: 47 | wav_bytes = f.read() 48 | samples: list[dict[str, bytes | str]] = [] 49 | for i in range(self.n_repeats): 50 | degraded_speech = self.apply_noise(waveform) 51 | buff = io.BytesIO() 52 | torchaudio.save( 53 | buff, 54 | src=degraded_speech.unsqueeze(0), 55 | sample_rate=self.sampling_rate, 56 | format="wav", 57 | ) 58 | buff.seek(0) 59 | 60 | sample = { 61 | "__key__": basename + f"_{i}", 62 | "speech.wav": wav_bytes, 63 | "degraded_speech.wav": buff.read(), 64 | "resampled_speech.pth": webdataset.torch_dumps(waveform), 65 | } 66 | samples.append(sample) 67 | return samples 68 | 69 | def apply_noise(self, waveform: torch.Tensor) -> torch.Tensor: 70 | return self.degradation_model.process(waveform, self.sampling_rate) 71 | 72 | def build_from_path(self) -> None: 73 | pathlib.Path("/".join(self.cfg.preprocess.train_tar_sink.pattern.split("/")[:-1])).mkdir(exist_ok=True) 74 | train_sink = hydra.utils.instantiate(self.cfg.preprocess.train_tar_sink) 75 | val_sink = hydra.utils.instantiate(self.cfg.preprocess.val_tar_sink) 76 | cpu_count = os.cpu_count() 77 | num_workers: int = cpu_count if cpu_count is not None else 8 78 | dataloader = DataLoader(self.dataset, batch_size=1, shuffle=True, num_workers=num_workers) 79 | for idx, data in enumerate(tqdm.tqdm(dataloader)): 80 | basename = data["basename"][0] 81 | wav_path = data["wav_path"][0] 82 | lang_code = data["lang_code"][0] 83 | result = self.process_utterance(basename, pathlib.Path(wav_path), lang_code) 84 | sink = train_sink if idx >= self.cfg.preprocess.val_size else val_sink 85 | for sample in result: 86 | sink.write(sample) 87 | train_sink.close() 88 | val_sink.close() 89 | -------------------------------------------------------------------------------- /cmds/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | evaluate_speech_restoration.py 4 | 
============================== 5 | Clean / Degraded / Restored から指標を算出し CSV 出力 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | import argparse 11 | import time 12 | import warnings 13 | from pathlib import Path 14 | 15 | import pandas as pd 16 | import torch 17 | import torchaudio 18 | from tqdm.auto import tqdm 19 | 20 | import miipher_2.utils.eval_utils as U 21 | 22 | log = U.get_logger("eval") 23 | 24 | 25 | def load_wav(path: Path, sr: int): 26 | wav, s = torchaudio.load(path) 27 | if s != sr: 28 | wav = torchaudio.functional.resample(wav, s, sr) 29 | return wav 30 | 31 | 32 | def main() -> None: 33 | ap = argparse.ArgumentParser() 34 | ap.add_argument("--clean_dir", required=True, type=Path) 35 | ap.add_argument("--degraded_dir", required=True, type=Path) 36 | ap.add_argument("--restored_dir", required=True, type=Path) 37 | ap.add_argument("--outfile", default=Path("eval_results.csv"), type=Path) 38 | ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") 39 | ap.add_argument("--sr", type=int, default=16000) 40 | ap.add_argument("--log_every", type=int, default=10, help="log every N files") 41 | args = ap.parse_args() 42 | 43 | t0 = time.time() 44 | log.info("Loading ASR & speaker models (this may take a while)…") 45 | asr_model, asr_proc = U.load_asr(args.device) 46 | xvec, ecapa = U.load_spk_models(args.device) 47 | log.info(f"Models loaded in {time.time() - t0:.1f} s") 48 | 49 | rows, start = [], time.time() 50 | clean_files = sorted(args.clean_dir.glob("**/*.wav")) 51 | tot = len(clean_files) 52 | log.info(f"Start evaluation on {tot} files") 53 | 54 | for i, cl_path in enumerate(clean_files, 1): 55 | rel = cl_path.relative_to(args.clean_dir) 56 | deg_path = args.degraded_dir / rel 57 | res_path = args.restored_dir / rel 58 | if not deg_path.exists() or not res_path.exists(): 59 | warnings.warn(f"skip {rel} (missing degraded/restored)", stacklevel=2) 60 | continue 61 | 62 | cl, deg, res = (load_wav(p, args.sr) for p in (cl_path, deg_path, res_path)) 63 | 64 | rows.append( 65 | { 66 | "file": str(rel), 67 | "MCD": U.mcd(cl, res, args.sr), 68 | "XvecCos": U.speaker_cos(cl.to(args.device), res.to(args.device), args.sr, xvec), 69 | "ECAPACos": U.speaker_cos(cl.to(args.device), res.to(args.device), args.sr, ecapa), 70 | "WER": U.asr_wer(cl.to(args.device), res.to(args.device), args.sr, asr_model, asr_proc, args.device), 71 | "logF0_RMSE": U.logf0_rmse(cl, res, args.sr), 72 | # 劣化比較 73 | "Deg_MCD": U.mcd(cl, deg, args.sr), 74 | "Deg_XvecCos": U.speaker_cos(cl.to(args.device), deg.to(args.device), args.sr, xvec), 75 | "Deg_ECAPACos": U.speaker_cos(cl.to(args.device), deg.to(args.device), args.sr, ecapa), 76 | "Deg_WER": U.asr_wer( 77 | cl.to(args.device), deg.to(args.device), args.sr, asr_model, asr_proc, args.device 78 | ), 79 | "Deg_logF0_RMSE": U.logf0_rmse(cl, deg, args.sr), 80 | } 81 | ) 82 | 83 | if i % args.log_every == 0 or i == tot: 84 | elapsed = time.time() - start 85 | log.info(f"{i}/{tot} files evaluated | elapsed {elapsed / 60:.1f} min") 86 | 87 | df = pd.DataFrame(rows) 88 | df.to_csv(args.outfile, index=False) 89 | log.info(f"CSV saved to {args.outfile}") 90 | log.info(df.describe().loc[["mean", "std", "min", "max"]].to_string()) 91 | log.info("Done") 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /scripts/upload_to_hf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | 
Hugging Face Hubにmiipher-2モデルをアップロードするスクリプト 4 | """ 5 | 6 | import os 7 | import shutil 8 | from pathlib import Path 9 | from huggingface_hub import HfApi, create_repo, upload_folder 10 | import argparse 11 | 12 | def setup_model_repo(model_dir: Path, repo_id: str): 13 | """モデルリポジトリをセットアップしてアップロード""" 14 | 15 | # Hugging Face APIを初期化 16 | api = HfApi() 17 | 18 | print(f"Creating repository: {repo_id}") 19 | try: 20 | create_repo( 21 | repo_id=repo_id, 22 | repo_type="model", 23 | exist_ok=True, 24 | private=False 25 | ) 26 | print("✅ Repository created/verified") 27 | except Exception as e: 28 | print(f"Repository creation: {e}") 29 | 30 | # モデルフォルダをアップロード 31 | print(f"Uploading model files from {model_dir}") 32 | try: 33 | upload_folder( 34 | folder_path=str(model_dir), 35 | repo_id=repo_id, 36 | repo_type="model", 37 | commit_message="Upload Miipher-2 complete model (Adapter + Vocoder)" 38 | ) 39 | print("✅ Model uploaded successfully!") 40 | except Exception as e: 41 | print(f"❌ Upload failed: {e}") 42 | return False 43 | 44 | return True 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser(description="Upload Miipher-2 model to Hugging Face Hub") 48 | parser.add_argument( 49 | "--model-dir", 50 | type=str, 51 | default="models/miipher2", 52 | help="Path to model directory" 53 | ) 54 | parser.add_argument( 55 | "--repo-id", 56 | type=str, 57 | default="Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1", 58 | help="Hugging Face repository ID" 59 | ) 60 | parser.add_argument( 61 | "--check-files", 62 | action="store_true", 63 | help="Check if required files exist before upload" 64 | ) 65 | 66 | args = parser.parse_args() 67 | 68 | model_dir = Path(args.model_dir) 69 | 70 | # モデルディレクトリの存在確認 71 | if not model_dir.exists(): 72 | print(f"❌ Model directory not found: {model_dir}") 73 | return 74 | 75 | # 必要なファイルの確認 76 | required_files = [ 77 | "config.json", 78 | "README.md", 79 | "checkpoint_199k_fixed.pt", 80 | "epoch=77-step=137108.ckpt" 81 | ] 82 | 83 | print("📋 Checking required files...") 84 | missing_files = [] 85 | for file in required_files: 86 | file_path = model_dir / file 87 | if file_path.exists(): 88 | size_mb = file_path.stat().st_size / (1024 * 1024) 89 | print(f" ✅ {file} ({size_mb:.1f}MB)") 90 | else: 91 | print(f" ❌ {file} (missing)") 92 | missing_files.append(file) 93 | 94 | if missing_files: 95 | print(f"❌ Missing files: {missing_files}") 96 | if not args.check_files: 97 | print("Use --check-files to only check without uploading") 98 | return 99 | 100 | if args.check_files: 101 | print("📋 File check completed") 102 | return 103 | 104 | # Hugging Face tokenの確認 105 | hf_token = os.getenv("HF_TOKEN") 106 | if not hf_token: 107 | print("❌ HF_TOKEN environment variable not set") 108 | print("Please run: export HF_TOKEN=your_token_here") 109 | return 110 | 111 | print(f"🚀 Starting upload to {args.repo_id}") 112 | success = setup_model_repo(model_dir, args.repo_id) 113 | 114 | if success: 115 | print(f"🎉 Model successfully uploaded to: https://huggingface.co/{args.repo_id}") 116 | else: 117 | print("❌ Upload failed") 118 | 119 | if __name__ == "__main__": 120 | main() -------------------------------------------------------------------------------- /src/miipher_2/dataset/fleurs_r.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FleursRCorpus(Dataset): 9 | """ 10 | fleurs-r コーパス用のPyTorch Datasetクラス。 11 | 複数の言語やサブセット(train, 
dev, test)を横断してデータをロードできます。 12 | """ 13 | 14 | def __init__(self, root: str, subset: str | list[str] = "all", language: str | list[str] = "all") -> None: 15 | """ 16 | Args: 17 | root (str): fleurs-r データセットのルートディレクトリ。 18 | subset (str | List[str]): 使用するサブセット。"train", "dev", "test"のいずれか、 19 | そのリスト、または"all"を指定。デフォルトは"all"。 20 | language (str | List[str]): 使用する言語 (例: "ja_jp")。 21 | そのリスト、または"all"を指定。デフォルトは"all"。 22 | """ 23 | super().__init__() 24 | self.root = Path(root).resolve() 25 | self.samples: list[dict[str, str]] = [] 26 | data_path = self.root / "data" 27 | 28 | # 1. 処理対象の言語を決定 29 | if language == "all": 30 | # xx_xx または xxx_xxxx 形式のディレクトリ名を持つ言語ディレクトリをすべて取得 31 | lang_dirs = [ 32 | p for p in data_path.iterdir() if p.is_dir() and re.match(r"^[a-z]{2,3}(_[a-z]{2,4}){1,2}$", p.name) 33 | ] 34 | target_languages = sorted([p.name for p in lang_dirs]) 35 | elif isinstance(language, list): 36 | target_languages = language 37 | else: 38 | target_languages = [language] 39 | 40 | # 2. 処理対象のサブセットを決定 41 | if subset == "all": 42 | target_subsets = ["train", "dev", "test"] 43 | elif isinstance(subset, list): 44 | target_subsets = subset 45 | else: 46 | target_subsets = [subset] 47 | 48 | print(f"Target languages: {len(target_languages)}") 49 | print(f"Target subsets: {target_subsets}") 50 | 51 | # 3. 全ての対象データをループで読み込み 52 | for lang in target_languages: 53 | for sub in target_subsets: 54 | lang_dir = data_path / lang 55 | tsv_path = lang_dir / f"{sub}.tsv" 56 | audio_dir = lang_dir / "audio" / sub / sub 57 | 58 | if not (tsv_path.exists() and audio_dir.exists()): 59 | continue 60 | 61 | print(f"Processing: {lang}/{sub}") 62 | try: 63 | metadata = pd.read_csv( 64 | tsv_path, 65 | sep="\t", 66 | header=None, 67 | usecols=[0, 1, 2], 68 | names=["speaker_id", "wav_name", "raw_text"], 69 | on_bad_lines="warn", 70 | ) 71 | except Exception as e: 72 | print(f"Warning: Could not read TSV {tsv_path}: {e}") 73 | continue 74 | 75 | lang_code_out = "jpn" if lang == "ja_jp" else lang.split("_")[0] 76 | 77 | for _, row in metadata.iterrows(): 78 | speaker = str(row["speaker_id"]) 79 | wav_name = row["wav_name"] 80 | wav_path = audio_dir / wav_name 81 | 82 | if wav_path.is_file(): 83 | self.samples.append( 84 | { 85 | "wav_path": str(wav_path), 86 | "speaker": speaker, 87 | "clean_text": row["raw_text"], 88 | "basename": f"{lang}_{sub}_{speaker}_{wav_path.stem}", 89 | "lang_code": lang_code_out, 90 | } 91 | ) 92 | 93 | print(f"Total samples found: {len(self.samples)}") 94 | 95 | def __getitem__(self, index: int) -> dict[str, str]: 96 | return self.samples[index] 97 | 98 | def __len__(self) -> int: 99 | return len(self.samples) 100 | 101 | @property 102 | def speaker_dict(self) -> dict[str, int]: 103 | """一意の話者IDと整数のインデックスをマッピングした辞書を返します。""" 104 | if not hasattr(self, "_speaker_dict"): 105 | speakers = sorted({str(sample["speaker"]) for sample in self.samples}) 106 | self._speaker_dict = {speaker: idx for idx, speaker in enumerate(speakers)} 107 | return self._speaker_dict 108 | -------------------------------------------------------------------------------- /docs/checkpoint_guide.md: -------------------------------------------------------------------------------- 1 | # Checkpoint機能 2 | 3 | Miipher-2の学習におけるチェックポイント機能の使用方法について説明します。 4 | 5 | ## 概要 6 | 7 | チェックポイント機能により、以下が可能になります: 8 | 9 | - **自動保存**: 1,000ステップごとに学習状態を自動保存 10 | - **学習再開**: 中断された学習を正確な状態から再開 11 | - **Wandb連携**: 学習履歴の連続性を保持 12 | - **自動クリーンアップ**: 古いチェックポイントの自動削除 13 | 14 | ## 設定 15 | 16 | ### Adapter学習の設定 (`configs/adapter.yaml`) 17 | 18 | ```yaml 19 | 
# Checkpoint configuration 20 | checkpoint: 21 | save_interval: 1000 # 1kステップごとにチェックポイント保存 22 | resume_from: null # 再開用チェックポイントパス 23 | keep_last_n: 5 # 最新N個のチェックポイントを保持 24 | save_wandb_metadata: true # wandb情報も保存 25 | 26 | # Wandb logging configuration 27 | wandb: 28 | enabled: true 29 | project: "miipher-2-adapter" 30 | entity: null 31 | name: null 32 | tags: ["adapter", "training"] 33 | notes: "Parallel Adapter training for Miipher-2" 34 | log_model: false 35 | ``` 36 | 37 | ### HiFi-GAN学習の設定 (`configs/hifigan_finetune.yaml`) 38 | 39 | ```yaml 40 | # Checkpoint configuration 41 | checkpoint: 42 | save_interval: 1000 # 1kステップごとにチェックポイント保存 43 | resume_from: null # 再開用チェックポイントパス 44 | keep_last_n: 5 # 最新N個のチェックポイントを保持 45 | save_wandb_metadata: true # wandb情報も保存 46 | 47 | # Wandb logging configuration 48 | wandb: 49 | enabled: true 50 | project: "miipher-2-hifigan" 51 | entity: null 52 | name: null 53 | tags: ["hifigan", "vocoder", "finetune"] 54 | notes: "HiFi-GAN fine-tuning for Miipher-2" 55 | log_model: true 56 | log_audio: true 57 | ``` 58 | 59 | ## 使用方法 60 | 61 | ### 1. 通常の学習開始 62 | 63 | ```bash 64 | # Adapter学習 65 | uv run cmd/train_adapter.py 66 | 67 | # HiFi-GAN学習 68 | uv run cmd/train_vocoder.py 69 | ``` 70 | 71 | ### 2. 特定のチェックポイントから再開 72 | 73 | ```bash 74 | # Adapter学習 75 | uv run cmd/train_adapter.py checkpoint.resume_from="exp/adapter/checkpoint_5k.pt" 76 | 77 | # HiFi-GAN学習 78 | uv run cmd/train_vocoder.py checkpoint.resume_from="exp/hifigan_ft/checkpoint_10k.pt" 79 | ``` 80 | 81 | ### 3. 自動再開スクリプトの使用 82 | 83 | 最新のチェックポイントを自動で見つけて再開: 84 | 85 | ```bash 86 | # Adapter学習の自動再開 87 | bash scripts/auto_resume_adapter.sh 88 | 89 | # HiFi-GAN学習の自動再開 90 | bash scripts/auto_resume_hifigan.sh 91 | ``` 92 | 93 | 94 | 95 | ## チェックポイントファイルの構造 96 | 97 | ### Adapter学習のチェックポイント 98 | 99 | ```python 100 | { 101 | 'step': 5000, # 現在のステップ数 102 | 'model_state_dict': {...}, # モデルの状態 103 | 'optimizer_state_dict': {...}, # オプティマイザの状態 104 | 'scheduler_state_dict': {...}, # スケジューラの状態 105 | 'wandb_run_id': 'abc123', # Wandb Run ID 106 | 'wandb_run_name': 'adapter_run_1', # Wandb Run名 107 | 'wandb_project': 'miipher-2-adapter', # Wandbプロジェクト名 108 | 'config': {...}, # 学習設定 109 | 'random_states': { # 乱数状態(再現性のため) 110 | 'python': ..., 111 | 'numpy': ..., 112 | 'torch': ..., 113 | 'torch_cuda': ... 114 | } 115 | } 116 | ``` 117 | 118 | ### HiFi-GAN学習のチェックポイント 119 | 120 | ```python 121 | { 122 | 'step': 10000, # 現在のステップ数 123 | 'model_state_dict': {...}, # Generator状態 124 | 'optimizer_state_dict': {...}, # Generator Optimizer状態 125 | 'mpd_state_dict': {...}, # Multi-Period Discriminator状態 126 | 'msd_state_dict': {...}, # Multi-Scale Discriminator状態 127 | 'opt_d_state_dict': {...}, # Discriminator Optimizer状態 128 | 'scaler_state_dict': {...}, # AMP Scaler状態 129 | 'wandb_run_id': 'def456', # Wandb Run ID 130 | 'wandb_run_name': 'hifigan_run_1', # Wandb Run名 131 | 'config': {...}, # 学習設定 132 | 'random_states': {...} # 乱数状態 133 | } 134 | ``` 135 | 136 | ## Wandb連携 137 | 138 | ### 自動ID継承(デフォルト) 139 | 140 | チェックポイントを指定して学習を再開すると: 141 | 142 | - **自動的にWandb IDが継承**される 143 | - 同じWandb RunIDで学習を継続 144 | - 学習曲線が途切れない 145 | - メトリクスの連続性が保たれる 146 | 147 | ### 設定の互換性チェック 148 | 149 | 学習再開時に重要なパラメータの変更を自動検出: 150 | 151 | ``` 152 | [WARNING] Configuration changes detected: 153 | - optim.lr: 0.0002 -> 0.0001 154 | - batch_size: 16 -> 32 155 | These changes may affect training consistency. 
156 | ``` 157 | 158 | チェックされるパラメータ: 159 | - `model`: モデル構造 160 | - `optim.lr`: 学習率 161 | - `batch_size`: バッチサイズ 162 | - `dataset.num_examples`: データセット例数 163 | - `steps`: ステップ数(HiFi-GAN) 164 | 165 | ## ファイル管理 166 | 167 | ### 自動クリーンアップ 168 | 169 | `keep_last_n: 5`の設定により: 170 | 171 | - 最新5個のチェックポイントのみ保持 172 | - 古いチェックポイントは自動削除 173 | - ディスク容量の節約 174 | 175 | ### ファイル命名規則 176 | 177 | ``` 178 | exp/adapter/checkpoint_1k.pt # 1,000ステップ 179 | exp/adapter/checkpoint_2k.pt # 2,000ステップ 180 | exp/adapter/checkpoint_3k.pt # 3,000ステップ 181 | ... 182 | ``` 183 | 184 | ## 高度な使用方法 185 | 186 | ### カスタムチェックポイント間隔 187 | 188 | ```yaml 189 | checkpoint: 190 | save_interval: 500 # 500ステップごとに保存 191 | ``` 192 | 193 | ### 特定のチェックポイントのみ保持 194 | 195 | ```yaml 196 | checkpoint: 197 | keep_last_n: 10 # 最新10個を保持 198 | ``` 199 | 200 | ### Wandbメタデータの無効化 201 | 202 | ```yaml 203 | checkpoint: 204 | save_wandb_metadata: false 205 | ``` 206 | -------------------------------------------------------------------------------- /src/miipher_2/utils/infer.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from tqdm import tqdm 6 | 7 | from miipher_2.lightning_vocoders.lightning_module import HiFiGANLightningModule 8 | from miipher_2.model.feature_cleaner import FeatureCleaner 9 | from miipher_2.utils.audio import load, save 10 | 11 | 12 | @torch.inference_mode() 13 | def run_inference(cfg: DictConfig) -> None: 14 | """ 15 | 設定ファイルに基づいてMiipher-2の音声修復推論を実行する 16 | 17 | Args: 18 | cfg (DictConfig): Hydraによって読み込まれた設定オブジェクト 19 | """ 20 | device = torch.device(cfg.device) 21 | 22 | # 1. FeatureCleaner 23 | print("Loading FeatureCleaner model...") 24 | cleaner = FeatureCleaner(cfg.model).to(device).eval() 25 | adapter_checkpoint = torch.load(cfg.adapter_ckpt, map_location=device, weights_only=False) 26 | cleaner.load_state_dict(adapter_checkpoint["model_state_dict"]) 27 | print("FeatureCleaner model loaded.") 28 | 29 | # 2. Vocoder (Lightning SSL-Vocoder) 30 | print("Loading Lightning SSL-Vocoder...") 31 | vocoder_ckpt_path = pathlib.Path(cfg.vocoder_ckpt) 32 | if not vocoder_ckpt_path.exists(): 33 | msg = f"Vocoder checkpoint not found at: {vocoder_ckpt_path}" 34 | raise FileNotFoundError(msg) 35 | 36 | vocoder = HiFiGANLightningModule.load_from_checkpoint( 37 | vocoder_ckpt_path, map_location=device 38 | ).to(device).eval() 39 | print("Lightning SSL-Vocoder loaded.") 40 | 41 | # 3. 音声ファイルを読み込み、推論実行 42 | print(f"Processing input file: {cfg.input_wav}") 43 | input_wav = load(cfg.input_wav).to(device) 44 | 45 | with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")): 46 | cleaned_features = cleaner(input_wav) 47 | # Lightning SSL-Vocoderの入力形式に合わせる (batch, seq_len, input_channels) 48 | batch = {"input_feature": cleaned_features.transpose(1, 2)} 49 | restored_wav = vocoder.generator_forward(batch) 50 | 51 | # 4. 修復された音声を保存 52 | output_path = pathlib.Path(cfg.output_wav) 53 | output_path.parent.mkdir(parents=True, exist_ok=True) 54 | 55 | # 【修正点】 .squeeze(0) でバッチ次元を削除 56 | save(output_path, restored_wav.squeeze(0).cpu().to(torch.float32), sr=cfg.output_sampling_rate) 57 | 58 | print(f"Restored audio saved to: {output_path}") 59 | 60 | 61 | @torch.inference_mode() 62 | def run_inference_dir(cfg: DictConfig) -> None: 63 | """ 64 | ディレクトリ内の全ての音声ファイルに対して一括で音声修復推論を実行する 65 | 66 | Args: 67 | cfg (DictConfig): Hydraによって読み込まれた設定オブジェクト 68 | """ 69 | device = torch.device(cfg.device) 70 | 71 | # 1. 
モデルの読み込み (ループの外で一度だけ行います) 72 | print("Loading models...") 73 | cleaner = FeatureCleaner(cfg.model).to(device).eval() 74 | adapter_checkpoint = torch.load(cfg.adapter_ckpt, map_location=device, weights_only=False) 75 | cleaner.load_state_dict(adapter_checkpoint["model_state_dict"]) 76 | 77 | vocoder_ckpt_path = pathlib.Path(cfg.vocoder_ckpt) 78 | if not vocoder_ckpt_path.exists(): 79 | msg = f"Vocoder checkpoint not found at: {vocoder_ckpt_path}" 80 | raise FileNotFoundError(msg) 81 | 82 | vocoder = HiFiGANLightningModule.load_from_checkpoint( 83 | vocoder_ckpt_path, map_location=device 84 | ).to(device).eval() 85 | print("Models loaded successfully.") 86 | 87 | # 2. 入力ファイルリストを作成 88 | input_dir = pathlib.Path(cfg.input_dir) 89 | output_dir = pathlib.Path(cfg.output_dir) 90 | 91 | if not input_dir.is_dir(): 92 | print(f"Error: Input directory not found at '{input_dir}'") 93 | return 94 | 95 | audio_files = [] 96 | for ext in cfg.extensions: 97 | # rglobでサブディレクトリも再帰的に検索 98 | audio_files.extend(input_dir.rglob(f"*{ext}")) 99 | 100 | if not audio_files: 101 | print(f"No audio files found in '{input_dir}' with extensions {cfg.extensions}") 102 | return 103 | 104 | print(f"Found {len(audio_files)} files. Starting batch inference...") 105 | 106 | # 3. 各ファイルに対してループ処理を実行 107 | for input_path in tqdm(audio_files, desc="Processing files"): 108 | try: 109 | relative_path = input_path.relative_to(input_dir) 110 | output_path = output_dir / relative_path 111 | 112 | # 出力ディレクトリを作成 113 | output_path.parent.mkdir(parents=True, exist_ok=True) 114 | 115 | # 3a. 音声ファイルを読み込み 116 | input_wav = load(input_path).to(device) 117 | 118 | # 3b. 推論実行 119 | with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")): 120 | cleaned_features = cleaner(input_wav) 121 | # Lightning SSL-Vocoderの入力形式に合わせる (batch, seq_len, input_channels) 122 | batch = {"input_feature": cleaned_features.transpose(1, 2)} 123 | restored_wav = vocoder.generator_forward(batch) 124 | 125 | # 3c. 
修復された音声を保存 126 | save(output_path, restored_wav.squeeze(0).cpu().to(torch.float32), sr=cfg.output_sampling_rate) 127 | 128 | except Exception as e: # noqa: BLE001 129 | # エラーが発生しても処理を止めず、エラーメッセージを表示して次のファイルへ進みます 130 | tqdm.write(f"Failed to process {input_path}: {e}") 131 | 132 | print("\nBatch inference finished.") 133 | -------------------------------------------------------------------------------- /scripts/generate_latex_table.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pathlib import Path 3 | 4 | def format_value(value, precision=2): 5 | """数値を指定の精度でフォーマット""" 6 | if isinstance(value, (int, float)): 7 | return f"{value:.{precision}f}" 8 | return str(value) 9 | 10 | def generate_latex_table(): 11 | """CSVファイルからLaTeXテーブルを生成""" 12 | results_dir = Path("/home/ayu/GitHub/open-miipher-2/results") 13 | 14 | # CSVファイルを読み込む 15 | df_8khz = pd.read_csv(results_dir / "summary_8khz.csv") 16 | df_degrade = pd.read_csv(results_dir / "summary_degrade.csv") 17 | 18 | # LaTeX文書の開始 19 | latex_content = r"""\documentclass{article} 20 | \usepackage{booktabs} 21 | \usepackage{multirow} 22 | \usepackage{siunitx} 23 | \usepackage[margin=1in]{geometry} 24 | \usepackage{adjustbox} 25 | 26 | \begin{document} 27 | 28 | \begin{table}[htbp] 29 | \centering 30 | \sisetup{ 31 | table-format=1.2, 32 | round-mode=places, 33 | round-precision=2 34 | } 35 | 36 | \caption{音声復元の種類とDNSMOSスコア比較} 37 | 38 | \setlength{\tabcolsep}{4pt} 39 | 40 | \begin{adjustbox}{max width=\linewidth} 41 | \begin{tabular}{ 42 | l 43 | l 44 | *{5}{S[table-format=1.2]} 45 | } 46 | \toprule 47 | \textbf{音声復元の種類} & \textbf{劣化手法} 48 | & \textbf{ecapa cos} & \textbf{dnsmos p808} 49 | & \textbf{dnsmos sig} & \textbf{dnsmos bak} 50 | & \textbf{dnsmos} \\ 51 | \midrule 52 | """ 53 | 54 | # モデル名のマッピング 55 | name_mapping = { 56 | 'original': 'original', 57 | '8khz_degraded': 'degraded', 58 | 'noise_degraded': 'degraded', 59 | 'miipher_1': 'miipher-1', 60 | 'hubert_large_l2': 'hubert\\_large\\_l2', 61 | 'mhubert_l6': 'mhubert\\_l6', 62 | 'wav2vec2_base_l2': 'wav2vec2\\_base\\_l2', 63 | 'wavlm_base_l2': 'wavlm\\_base\\_l2' 64 | } 65 | 66 | # 表示する行の順序 67 | row_order = ['original', '8khz_degraded', 'miipher_1', 'hubert_large_l2', 'mhubert_l6', 'wav2vec2_base_l2', 'wavlm_base_l2'] 68 | row_order_degrade = ['original', 'noise_degraded', 'miipher_1', 'hubert_large_l2', 'mhubert_l6', 'wav2vec2_base_l2', 'wavlm_base_l2'] 69 | 70 | # 8kHz結果のセクション 71 | num_rows = len(row_order) 72 | for i, row_name in enumerate(row_order): 73 | if row_name in df_8khz['name'].values: 74 | row = df_8khz[df_8khz['name'] == row_name].iloc[0] 75 | 76 | # 最初の行でmultirowを開始 77 | if i == 0: 78 | latex_content += f" \\multirow{{{num_rows}}}{{*}}{{}}\n" 79 | 80 | # モデル名 81 | model_name = name_mapping.get(row_name, row_name) 82 | latex_content += f" {model_name:<17} " 83 | 84 | # 劣化手法(最初の行のみ) 85 | if i == 0: 86 | latex_content += f"& \\multirow{{{num_rows}}}{{*}}{{8kHzに変換}}" 87 | else: 88 | latex_content += "& " 89 | 90 | # 数値データ 91 | latex_content += f"\n & {format_value(row['ecapa_cos_mean'])}" 92 | latex_content += f" & {format_value(row['dnsmos_p808_mean'])}" 93 | latex_content += f" & {format_value(row['dnsmos_sig_mean'])}" 94 | latex_content += f" & {format_value(row['dnsmos_bak_mean'])}" 95 | latex_content += f" & {format_value(row['dnsmos_ovr_mean'])} \\\\\n" 96 | 97 | latex_content += " \\midrule\n" 98 | 99 | # Degrade結果のセクション 100 | num_rows = len(row_order_degrade) 101 | for i, row_name in 
enumerate(row_order_degrade): 102 | if row_name in df_degrade['name'].values: 103 | row = df_degrade[df_degrade['name'] == row_name].iloc[0] 104 | 105 | # 最初の行でmultirowを開始 106 | if i == 0: 107 | latex_content += f" \\multirow{{{num_rows}}}{{*}}{{}}\n" 108 | 109 | # モデル名 110 | model_name = name_mapping.get(row_name, row_name) 111 | latex_content += f" {model_name:<17} " 112 | 113 | # 劣化手法(最初の行のみ) 114 | if i == 0: 115 | latex_content += f"& \\multirow{{{num_rows}}}{{*}}{{残響・背景雑音}}" 116 | else: 117 | latex_content += "& " 118 | 119 | # 数値データ 120 | latex_content += f"\n & {format_value(row['ecapa_cos_mean'])}" 121 | latex_content += f" & {format_value(row['dnsmos_p808_mean'])}" 122 | latex_content += f" & {format_value(row['dnsmos_sig_mean'])}" 123 | latex_content += f" & {format_value(row['dnsmos_bak_mean'])}" 124 | latex_content += f" & {format_value(row['dnsmos_ovr_mean'])} \\\\\n" 125 | 126 | # LaTeX文書の終了 127 | latex_content += r""" \bottomrule 128 | \end{tabular} 129 | \end{adjustbox} 130 | \end{table} 131 | 132 | \end{document} 133 | """ 134 | 135 | # ファイルに保存 136 | output_path = results_dir / "results_table.tex" 137 | with open(output_path, 'w', encoding='utf-8') as f: 138 | f.write(latex_content) 139 | 140 | print(f"LaTeX table generated: {output_path}") 141 | 142 | # 画面にも出力 143 | print("\n" + "="*80) 144 | print("Generated LaTeX code:") 145 | print("="*80) 146 | print(latex_content) 147 | 148 | if __name__ == "__main__": 149 | generate_latex_table() 150 | -------------------------------------------------------------------------------- /src/miipher_2/data/webdataset_loader.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterator 2 | 3 | import torch 4 | import torchaudio 5 | import webdataset as wds 6 | from braceexpand import braceexpand 7 | from torch.utils.data import IterableDataset 8 | 9 | 10 | def _ensure_2d(tensor: torch.Tensor) -> torch.Tensor: 11 | """音声テンソルが必ず [channels, length] の2次元になるように保証する""" 12 | if tensor.dim() == 1: 13 | # テンソルが1次元の場合、チャンネル次元を追加する 14 | return tensor.unsqueeze(0) 15 | return tensor 16 | 17 | 18 | class AdapterDataset(IterableDataset): 19 | """Adapter学習用: 全て16kHzに変換する""" 20 | 21 | def __init__(self, pattern: str | list[str], shuffle: int = 1000) -> None: 22 | # 複数のパターンに対応 23 | if isinstance(pattern, str): 24 | patterns = [pattern] 25 | else: 26 | patterns = pattern 27 | 28 | # ブレース展開を適用 29 | expanded_patterns = [] 30 | for p in patterns: 31 | expanded_patterns.extend(list(braceexpand(p))) 32 | 33 | self.dataset = ( 34 | wds.WebDataset( 35 | expanded_patterns, 36 | resampled=True, 37 | shardshuffle=True, 38 | ) 39 | .shuffle(shuffle) 40 | .decode(wds.torch_audio) 41 | ) 42 | self.target_sr = 16000 43 | 44 | def __iter__(self) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: 45 | for sample in self.dataset: 46 | clean_wav, clean_sr = sample["speech.wav"] 47 | noisy_wav, noisy_sr = sample["degraded_speech.wav"] 48 | 49 | # ロードした直後に次元数を2Dに統一する 50 | clean_wav = _ensure_2d(clean_wav) 51 | noisy_wav = _ensure_2d(noisy_wav) 52 | 53 | # それぞれの正しいsrを使って16kHzにリサンプリング 54 | clean_16k = torchaudio.functional.resample(clean_wav, orig_freq=clean_sr, new_freq=self.target_sr) 55 | noisy_16k = torchaudio.functional.resample(noisy_wav, orig_freq=noisy_sr, new_freq=self.target_sr) 56 | 57 | # .mean(0, keepdim=True)はステレオ音声をモノラルに変換する安全策として残しておく 58 | yield noisy_16k.mean(0, keepdim=True), clean_16k.mean(0, keepdim=True) 59 | 60 | 61 | class VocoderDataset(IterableDataset): 62 | """Vocoder学習用: 
劣化音声は16kHz、クリーン音声は22.05kHzで出力 63 | 64 | Args: 65 | IterableDataset (_type_): _description_ 66 | """ 67 | 68 | def __init__(self, pattern: str | list[str], shuffle: int = 1000) -> None: 69 | # 複数のパターンに対応 70 | if isinstance(pattern, str): 71 | patterns = [pattern] 72 | else: 73 | patterns = pattern 74 | 75 | # ブレース展開を適用 76 | expanded_patterns = [] 77 | for p in patterns: 78 | expanded_patterns.extend(list(braceexpand(p))) 79 | 80 | self.dataset = wds.WebDataset(expanded_patterns, resampled=True, shardshuffle=True).shuffle(shuffle).decode(wds.torch_audio) 81 | self.input_sr = 16000 82 | self.target_sr = 22050 83 | 84 | def __iter__(self) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: 85 | for sample in self.dataset: 86 | clean_wav, clean_sr = sample["speech.wav"] 87 | noisy_wav, noisy_sr = sample["degraded_speech.wav"] 88 | 89 | # ロードした直後に次元数を2Dに統一する 90 | clean_wav = _ensure_2d(clean_wav) 91 | noisy_wav = _ensure_2d(noisy_wav) 92 | 93 | # 劣化音声はHuBERTに入力するため16kHzにリサンプリング 94 | noisy_16k = torchaudio.functional.resample(noisy_wav, orig_freq=noisy_sr, new_freq=self.input_sr) 95 | 96 | # クリーン音声は教師信号なので22.05kHzのまま 97 | if clean_sr != self.target_sr: 98 | clean_22k = torchaudio.functional.resample(clean_wav, orig_freq=clean_sr, new_freq=self.target_sr) 99 | else: 100 | clean_22k = clean_wav 101 | 102 | # .mean(0, keepdim=True)はステレオ音声をモノラルに変換する安全策として残しておく 103 | yield noisy_16k.mean(0, keepdim=True), clean_22k.mean(0, keepdim=True) 104 | 105 | 106 | class CleanVocoderDataset(IterableDataset): 107 | """Vocoder事前学習用: クリーン音声を16kHzと22.05kHzの両方で出力""" 108 | 109 | def __init__(self, pattern: str | list[str], shuffle: int = 1000) -> None: 110 | # 複数のパターンに対応 111 | if isinstance(pattern, str): 112 | patterns = [pattern] 113 | else: 114 | patterns = pattern 115 | 116 | # ブレース展開を適用 117 | expanded_patterns = [] 118 | for p in patterns: 119 | expanded_patterns.extend(list(braceexpand(p))) 120 | 121 | self.dataset = wds.WebDataset(expanded_patterns, resampled=True, shardshuffle=True).shuffle(shuffle).decode(wds.torch_audio) 122 | self.input_sr = 16000 123 | self.target_sr = 22050 124 | 125 | def __iter__(self) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: 126 | for sample in self.dataset: 127 | # noisy_speech.wav の代わりに speech.wav を使う 128 | clean_wav, clean_sr = sample["speech.wav"] 129 | 130 | # ロードした直後に次元数を2Dに統一する 131 | clean_wav = _ensure_2d(clean_wav) 132 | 133 | # HuBERTに入力するため16kHzにリサンプリング 134 | clean_16k = torchaudio.functional.resample(clean_wav, orig_freq=clean_sr, new_freq=self.input_sr) 135 | 136 | # 教師信号なので22.05kHzのまま 137 | if clean_sr != self.target_sr: 138 | clean_22k = torchaudio.functional.resample(clean_wav, orig_freq=clean_sr, new_freq=self.target_sr) 139 | else: 140 | clean_22k = clean_wav 141 | 142 | # .mean(0, keepdim=True)はステレオ音声をモノラルに変換する安全策 143 | yield clean_16k.mean(0, keepdim=True), clean_22k.mean(0, keepdim=True) 144 | -------------------------------------------------------------------------------- /src/miipher_2/preprocess/noise_augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pyroomacoustics as pra 6 | import torch 7 | import torchaudio 8 | from omegaconf import DictConfig 9 | from tqdm import tqdm 10 | 11 | 12 | def align_waveform(wav1: torch.Tensor, wav2: torch.Tensor) -> tuple[int, torch.Tensor]: 13 | assert wav2.size(1) >= wav1.size(1) 14 | diff = wav2.size(1) - wav1.size(1) 15 | min_mse = float("inf") 16 | best_i = -1 17 | 18 | for i in range(diff): 
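# Brute-force alignment: slide a window of wav1's length over wav2 and keep the offset
# with the smallest MSE. This undoes the extra samples that codec round-trips (e.g. MP3
# zero-padding) introduce, so applyCodec below can crop back to the original length.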
19 | segment = wav2[:, i : i + wav1.size(1)] 20 | mse: float = torch.mean((wav1 - segment).pow(2)).item() 21 | if mse < min_mse: 22 | min_mse = mse 23 | best_i = i 24 | 25 | return best_i, wav2[:, best_i : best_i + wav1.size(1)] 26 | 27 | 28 | class DegradationApplier: 29 | def __init__(self, cfg: DictConfig) -> None: 30 | self.format_encoding_pairs = cfg.format_encoding_pairs 31 | self.reverb_conditions = cfg.reverb_conditions 32 | self.background_noise = cfg.background_noise 33 | self.cfg = cfg 34 | self.rirs: list[torch.Tensor] = [] 35 | self.prepare_rir(cfg.n_rirs) 36 | self.noise_audio_paths = [] 37 | for root, pattern in self.cfg.background_noise.patterns: 38 | self.noise_audio_paths.extend(list(Path(root).glob(pattern))) 39 | 40 | def applyCodec(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor: 41 | if len(self.format_encoding_pairs) == 0: 42 | return waveform 43 | param: dict = random.choice(self.format_encoding_pairs) 44 | audio_format: str = param["format"] 45 | compression: int | None = param.get("compression") 46 | codec_config = torchaudio.io.CodecConfig(compression_level=compression) if compression else None 47 | eff = torchaudio.io.AudioEffector(format=audio_format, codec_config=codec_config) 48 | wav_tc = waveform.transpose(0, 1) 49 | aug_tc = eff.apply(wav_tc, sample_rate) 50 | augmented = aug_tc.transpose(0, 1).contiguous() 51 | # mp3 encoding may increase the length of the waveform by zero-padding 52 | if waveform.size(1) != augmented.size(1): 53 | best_idx, augmented = align_waveform(waveform, augmented) 54 | return augmented.float() 55 | 56 | def applyReverb(self, waveform: torch.Tensor) -> torch.Tensor: 57 | if len(self.rirs) == 0: 58 | raise RuntimeError 59 | rir = random.choice(self.rirs) 60 | augmented = torchaudio.functional.fftconvolve(waveform, rir) 61 | # rir convolution may increase the length of the waveform 62 | if waveform.size(1) != augmented.size(1): 63 | augmented = augmented[:, : waveform.size(1)] 64 | return augmented.float() 65 | 66 | def prepare_rir(self, n_rirs: int) -> None: 67 | for _ in tqdm(range(n_rirs)): 68 | xy_min_max = self.reverb_conditions.room_xy 69 | z_min_max = self.reverb_conditions.room_z 70 | x = random.uniform(xy_min_max.min, xy_min_max.max) 71 | y = random.uniform(xy_min_max.min, xy_min_max.max) 72 | z = random.uniform(z_min_max.min, z_min_max.max) 73 | corners = np.array([[0, 0], [0, y], [x, y], [x, 0]]).T 74 | room = pra.Room.from_corners(corners, **self.reverb_conditions.room_params) 75 | room.extrude(z) 76 | room.add_source(self.cfg.reverb_conditions.source_pos) 77 | room.add_microphone(self.cfg.reverb_conditions.mic_pos) 78 | 79 | room.compute_rir() 80 | rir = torch.tensor(np.array(room.rir[0])) 81 | rir = rir / rir.norm(p=2) 82 | self.rirs.append(rir) 83 | 84 | def applyBackgroundNoise(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor: 85 | snr_max, snr_min = self.background_noise.snr.max, self.background_noise.snr.min 86 | snr = random.uniform(snr_min, snr_max) 87 | 88 | noise_path = random.choice(self.noise_audio_paths) 89 | noise, noise_sr = torchaudio.load(noise_path) 90 | noise /= noise.norm(p=2) 91 | if noise.size(0) > 1: 92 | noise = noise[0].unsqueeze(0) 93 | noise = torchaudio.functional.resample(noise, noise_sr, sample_rate) 94 | if not noise.size(1) < waveform.size(1): 95 | start_idx = random.randint(0, noise.size(1) - waveform.size(1)) 96 | end_idx = start_idx + waveform.size(1) 97 | noise = noise[:, start_idx:end_idx] 98 | else: 99 | noise = noise.repeat(1, waveform.size(1) // 
noise.size(1) + 1)[:, : waveform.size(1)] 100 | if noise.abs().max() > 0: 101 | augmented = torchaudio.functional.add_noise(waveform=waveform, noise=noise, snr=torch.tensor([snr])) 102 | else: 103 | augmented = waveform 104 | return augmented 105 | 106 | def process(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor: 107 | if len(waveform.shape) == 1: 108 | waveform = waveform.unsqueeze(0) 109 | org_len = waveform.size(1) 110 | waveform = self.applyBackgroundNoise(waveform, sample_rate) 111 | if random.random() > self.cfg.reverb_conditions.p: 112 | waveform = self.applyReverb(waveform) 113 | waveform = self.applyCodec(waveform, sample_rate) 114 | assert org_len == waveform.size(1), f"{org_len}, {waveform.size(1)}" 115 | return waveform.squeeze() 116 | 117 | def __call__(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor: 118 | return self.process(waveform, sample_rate) 119 | -------------------------------------------------------------------------------- /cmds/evaluate_simple.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | evaluate_simple.py 4 | ================== 5 | ECAPAの話者性とDNSMOSの2指標でのみ評価 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | import argparse 11 | import logging 12 | import sys 13 | import time 14 | import warnings 15 | from pathlib import Path 16 | 17 | import pandas as pd 18 | import torch 19 | import torchaudio 20 | from speechbrain.pretrained import SpeakerRecognition 21 | from tqdm.auto import tqdm 22 | 23 | # DNSMOS 24 | from torchmetrics.audio import DeepNoiseSuppressionMeanOpinionScore 25 | 26 | 27 | def get_logger(name: str = "eval", level: int = logging.INFO) -> logging.Logger: 28 | logger = logging.getLogger(name) 29 | if not logger.handlers: 30 | fmt = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S") 31 | h = logging.StreamHandler(sys.stdout) 32 | h.setFormatter(fmt) 33 | logger.addHandler(h) 34 | logger.setLevel(level) 35 | return logger 36 | 37 | 38 | log = get_logger() 39 | 40 | 41 | def load_wav(path: Path, sr: int): 42 | wav, s = torchaudio.load(path) 43 | if s != sr: 44 | wav = torchaudio.functional.resample(wav, s, sr) 45 | return wav 46 | 47 | 48 | def speaker_cos_ecapa(ref: torch.Tensor, syn: torch.Tensor, ecapa_model) -> float: 49 | """ECAPAモデルを使用して話者類似度を計算""" 50 | if ref.dim() == 1: 51 | ref = ref.unsqueeze(0) 52 | if syn.dim() == 1: 53 | syn = syn.unsqueeze(0) 54 | 55 | device = next(ecapa_model.parameters()).device 56 | ref, syn = ref.to(device), syn.to(device) 57 | 58 | with torch.no_grad(): 59 | emb_ref = ecapa_model.encode_batch(ref) 60 | emb_syn = ecapa_model.encode_batch(syn) 61 | 62 | # 次元調整 63 | while emb_ref.dim() > 1: 64 | emb_ref = emb_ref.mean(dim=0) 65 | while emb_syn.dim() > 1: 66 | emb_syn = emb_syn.mean(dim=0) 67 | 68 | sim = torch.nn.functional.cosine_similarity(emb_ref, emb_syn, dim=0, eps=1e-8) 69 | return float(sim) 70 | 71 | 72 | def compute_dnsmos(wav: torch.Tensor, model) -> dict: 73 | """DNSMOSスコアを計算""" 74 | # DNSMOSは音声テンソルを受け取る(1次元で入力) 75 | if wav.dim() > 1: 76 | wav = wav.squeeze() 77 | 78 | with torch.no_grad(): 79 | scores = model(wav) 80 | 81 | # DNSMOSは4つのスコアを返す: [p808_mos, mos_sig, mos_bak, mos_ovr] 82 | return { 83 | 'p808_mos': float(scores[0]), 84 | 'mos_sig': float(scores[1]), 85 | 'mos_bak': float(scores[2]), 86 | 'mos_ovr': float(scores[3]) 87 | } 88 | 89 | 90 | def main() -> None: 91 | ap = argparse.ArgumentParser() 92 | ap.add_argument("--clean_dir", required=True, 
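The `DegradationApplier` above reads its configuration purely through attribute access. A minimal sketch of the expected config shape, assuming hypothetical values throughout — the paths, room sizes, and codec settings below are illustrative placeholders, not the repository's actual `preprocess_*.yaml` contents:

```python
import torchaudio
from omegaconf import OmegaConf

from miipher_2.preprocess.noise_augmentation import DegradationApplier

# Hypothetical config mirroring the attributes DegradationApplier reads.
cfg = OmegaConf.create({
    "n_rirs": 50,                                  # RIRs pre-computed in prepare_rir()
    "format_encoding_pairs": [                     # codec round-trips sampled in applyCodec()
        {"format": "mp3", "compression": 8},
        {"format": "ogg"},                         # no compression -> default codec config
    ],
    "reverb_conditions": {
        "p": 0.5,                                  # reverb is applied when random() > p
        "room_xy": {"min": 2.0, "max": 10.0},      # room footprint range (metres)
        "room_z": {"min": 2.0, "max": 5.0},
        "room_params": {"fs": 16000, "max_order": 10},
        "source_pos": [1.0, 1.0, 1.2],
        "mic_pos": [1.5, 1.5, 1.0],
    },
    "background_noise": {
        "snr": {"min": 5, "max": 30},              # SNR range in dB
        "patterns": [["/path/to/noise_corpus", "**/*.wav"]],
    },
})

degrader = DegradationApplier(cfg)
wav, sr = torchaudio.load("clean_speech.wav")      # (channels, samples)
noisy = degrader(wav, sr)                          # degraded waveform, same length as the input
```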
type=Path) 93 | ap.add_argument("--restored_dir", required=True, type=Path) 94 | ap.add_argument("--outfile", default=Path("eval_results_simple.csv"), type=Path) 95 | ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") 96 | ap.add_argument("--sr", type=int, default=16000) 97 | args = ap.parse_args() 98 | 99 | t0 = time.time() 100 | log.info("Loading models...") 101 | 102 | # ECAPAモデル 103 | ecapa = SpeakerRecognition.from_hparams( 104 | "speechbrain/spkrec-ecapa-voxceleb", 105 | run_opts={"device": args.device}, 106 | savedir=str(Path.home() / ".cache/speechbrain/ecapa"), 107 | ) 108 | 109 | # DNSMOSモデル 110 | dnsmos_model = DeepNoiseSuppressionMeanOpinionScore(fs=args.sr, personalized=False) 111 | dnsmos_model = dnsmos_model.to(args.device) 112 | 113 | log.info(f"Models loaded in {time.time() - t0:.1f} s") 114 | 115 | rows = [] 116 | clean_files = sorted(args.clean_dir.glob("**/*.wav")) 117 | tot = len(clean_files) 118 | log.info(f"Start evaluation on {tot} files") 119 | 120 | ecapa_scores = [] 121 | dnsmos_scores = [] 122 | 123 | for cl_path in tqdm(clean_files, desc="Evaluating"): 124 | rel = cl_path.relative_to(args.clean_dir) 125 | res_path = args.restored_dir / rel 126 | 127 | if not res_path.exists(): 128 | warnings.warn(f"skip {rel} (missing restored file)", stacklevel=2) 129 | continue 130 | 131 | cl = load_wav(cl_path, args.sr) 132 | res = load_wav(res_path, args.sr) 133 | 134 | # ECAPA話者類似度 135 | ecapa_score = speaker_cos_ecapa(cl, res, ecapa) 136 | ecapa_scores.append(ecapa_score) 137 | 138 | # DNSMOSスコア 139 | res_device = res.to(args.device) 140 | dnsmos_score = compute_dnsmos(res_device, dnsmos_model) 141 | dnsmos_scores.append(dnsmos_score) 142 | 143 | rows.append({ 144 | "file": str(rel), 145 | "ECAPA_cos": ecapa_score, 146 | "DNSMOS_p808": dnsmos_score['p808_mos'], 147 | "DNSMOS_sig": dnsmos_score['mos_sig'], 148 | "DNSMOS_bak": dnsmos_score['mos_bak'], 149 | "DNSMOS_ovr": dnsmos_score['mos_ovr'], 150 | }) 151 | 152 | # 結果をCSVに保存 153 | df = pd.DataFrame(rows) 154 | df.to_csv(args.outfile, index=False) 155 | log.info(f"CSV saved to {args.outfile}") 156 | 157 | # 平均値の表示 158 | log.info("\n=== Evaluation Results (Mean) ===") 159 | log.info(f"ECAPA Speaker Similarity: {sum(ecapa_scores) / len(ecapa_scores):.4f}") 160 | 161 | # DNSMOS各スコアの平均値を計算 162 | avg_p808 = sum(s['p808_mos'] for s in dnsmos_scores) / len(dnsmos_scores) 163 | avg_sig = sum(s['mos_sig'] for s in dnsmos_scores) / len(dnsmos_scores) 164 | avg_bak = sum(s['mos_bak'] for s in dnsmos_scores) / len(dnsmos_scores) 165 | avg_ovr = sum(s['mos_ovr'] for s in dnsmos_scores) / len(dnsmos_scores) 166 | 167 | log.info(f"DNSMOS P808: {avg_p808:.4f}") 168 | log.info(f"DNSMOS Signal: {avg_sig:.4f}") 169 | log.info(f"DNSMOS Background: {avg_bak:.4f}") 170 | log.info(f"DNSMOS Overall: {avg_ovr:.4f}") 171 | log.info("=================================") 172 | 173 | 174 | if __name__ == "__main__": 175 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Miipher-2 2 | 3 |

4 | Python 5 | PyTorch 6 | arXiv 7 | Hugging Face 8 |

9 | 10 |

11 | Unofficial implementation of Miipher-2: High-quality speech enhancement via HuBERT + Parallel Adapter 12 |

13 | 14 |

15 | Key Features • 16 | Demo • 17 | Quick Start • 18 | Model Zoo • 19 | Training • 20 | Evaluation • 21 | Citation 22 |

23 | 24 | --- 25 | 26 | ## 🚀 Key Features 27 | 28 | - **Speech enhancement** based on [Miipher-2](https://arxiv.org/abs/2505.04457) architecture 29 | - **Lightweight Parallel Adapter** design for efficient feature adaptation 30 | - **Pre-trained models** available on [🤗 Hugging Face](https://huggingface.co/Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1) 31 | - **Comprehensive evaluation pipeline** with multiple metrics 32 | 33 | ## 🎧 Demo 34 | 35 | Experience the power of our model speech enhancement: 36 | 37 |
| 🔊 Degraded Audio | ✨ Enhanced Audio |
|:---:|:---:|
| Degraded spectrogram<br>*Noisy input* | Enhanced spectrogram<br>*Clean output* |
61 | 62 | ## 🛠️ Quick Start 63 | 64 | ### Prerequisites 65 | 66 | ```bash 67 | # Install dependencies using uv 68 | uv sync 69 | ``` 70 | 71 | ### 📁 Project Structure 72 | 73 | ``` 74 | open-miipher-2/ 75 | ├── configs/ # Hydra configuration files 76 | ├── src/miipher_2/ # Core Python modules 77 | ├── cmd/ # CLI entry points 78 | ├── exp/ # Model checkpoints 79 | └── docs/ # Documentation 80 | ``` 81 | 82 | ### 🚀 Quick Inference 83 | 84 | Use our pre-trained model for instant speech enhancement: 85 | 86 | ```bash 87 | # Download pre-trained model from Hugging Face 88 | # Model: miipher-2-HuBERT-HiFi-GAN-v0.1 89 | 90 | # Run inference on your audio files 91 | uv run cmd/inference_dir.py --config-name infer_dir 92 | ``` 93 | 94 | ## 🤗 Model Zoo 95 | 96 | | Model | SSL Backbone | Adapter Layers | Vocoder | Download | 97 | |-------|--------------|----------------|---------|----------| 98 | | miipher-2 HuBERT HiFi-GAN v0.1 | mHuBERT-147 | Layer 6 | HiFi-GAN | [🤗 HuggingFace](https://huggingface.co/Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1) | 99 | 100 | ## 📚 Training 101 | 102 | ### Step 1: Data Preprocessing 103 | 104 | Generate pseudo-degraded dataset from clean speech: 105 | 106 | ```bash 107 | # Process JVS corpus (Japanese) 108 | uv run cmd/preprocess.py --config-name preprocess_jvs 109 | 110 | # Process LibriTTS (English) 111 | uv run cmd/preprocess.py --config-name preprocess_libritts_r 112 | 113 | # Process FLEURS (Multilingual) 114 | uv run cmd/preprocess.py --config-name preprocess_fleurs_r 115 | ``` 116 | 117 | Output is saved in WebDataset format for efficient data loading. 118 | 119 | ### Step 2: Train Parallel Adapter 120 | 121 | ```bash 122 | # Train adapter module 123 | uv run cmd/train_adapter.py --config-name adapter_layer_6_mhubert_147 124 | 125 | # Resume from checkpoint 126 | uv run cmd/train_adapter.py \ 127 | checkpoint.resume_from="exp/adapter_layer_6_mhubert_147/checkpoint_199k.pt" \ 128 | --config-name adapter_layer_6_mhubert_147 129 | ``` 130 | 131 | ### Step 3: Train SSL-Vocoder 132 | 133 | ```bash 134 | # Pre-train Lightning SSL-Vocoder 135 | uv run cmd/pre_train_vocoder.py --config-name hifigan_pretrain_layer_6_mhubert_147 136 | ``` 137 | 138 | > 💡 **Note**: Configuration is automatically inherited from checkpoint unless explicitly overridden. 139 | 140 | 141 | ## 📊 Evaluation 142 | 143 | ### Step 1: Generate Degraded Test Data 144 | 145 | Create evaluation dataset with various noise conditions: 146 | 147 | ```bash 148 | uv run cmd/degrade.py \ 149 | --clean_dir \ 150 | --noise_dir \ 151 | --out_dir 152 | ``` 153 | 154 | ### Step 2: Run Enhancement 155 | 156 | Process degraded audio through the model: 157 | 158 | ```bash 159 | uv run cmd/inference_dir.py --config-name infer_dir 160 | ``` 161 | 162 | ### Step 3: Compute Metrics 163 | 164 | Evaluate enhancement quality with multiple metrics: 165 | 166 | ```bash 167 | uv run cmd/evaluate.py \ 168 | --clean_dir \ 169 | --degraded_dir \ 170 | --restored_dir \ 171 | --outfile results.csv 172 | ``` 173 | 174 | Metrics include: 175 | - **PESQ** (Perceptual Evaluation of Speech Quality) 176 | - **STOI** (Short-Time Objective Intelligibility) 177 | - **SI-SDR** (Scale-Invariant Signal-to-Distortion Ratio) 178 | - **MOS-LQO** (Mean Opinion Score) 179 | 180 | ## 🏗️ Architecture 181 | 182 | ![alt text](docs/image.png) 183 | 184 | ### Key Components 185 | 186 | 1. **HuBERT Feature Extractor**: Multilingual HuBERT (mHuBERT-147) for robust speech representations 187 | 2. 
**Parallel Adapter**: Lightweight feed-forward network inserted at specific layers 188 | 3. **Feature Cleaner**: Denoising module operating on SSL features 189 | 4. **Lightning SSL-Vocoder**: HiFi-GAN-based vocoder 190 | 191 | ## 🔧 Configuration 192 | 193 | All configurations are managed through Hydra. Key config files: 194 | 195 | - `configs/adapter_layer_6_mhubert_147.yaml` - Adapter training 196 | - `configs/infer_dir.yaml` - Inference settings 197 | - `configs/preprocess_*.yaml` - Data preprocessing 198 | -------------------------------------------------------------------------------- /demo/app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import torch 3 | import torchaudio 4 | import numpy as np 5 | from pathlib import Path 6 | from huggingface_hub import hf_hub_download 7 | from omegaconf import DictConfig 8 | 9 | from miipher_2.model.feature_cleaner import FeatureCleaner 10 | from miipher_2.lightning_vocoders.lightning_module import HiFiGANLightningModule 11 | 12 | # Model configuration 13 | MODEL_REPO_ID = "Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1" 14 | ADAPTER_FILENAME = "checkpoint_199k_fixed.pt" 15 | VOCODER_FILENAME = "epoch=77-step=137108.ckpt" 16 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | SAMPLE_RATE_INPUT = 16000 18 | SAMPLE_RATE_OUTPUT = 22050 19 | 20 | # Cache for models 21 | models_cache = {} 22 | 23 | def download_models(): 24 | """Download models from Hugging Face Hub""" 25 | print("Downloading models from Hugging Face Hub...") 26 | 27 | adapter_path = hf_hub_download( 28 | repo_id=MODEL_REPO_ID, 29 | filename=ADAPTER_FILENAME, 30 | cache_dir="./models" 31 | ) 32 | 33 | vocoder_path = hf_hub_download( 34 | repo_id=MODEL_REPO_ID, 35 | filename=VOCODER_FILENAME, 36 | cache_dir="./models" 37 | ) 38 | 39 | return adapter_path, vocoder_path 40 | 41 | def load_models(): 42 | """Load models into memory""" 43 | if "cleaner" in models_cache and "vocoder" in models_cache: 44 | return models_cache["cleaner"], models_cache["vocoder"] 45 | 46 | adapter_path, vocoder_path = download_models() 47 | 48 | # Model configuration 49 | model_config = DictConfig({ 50 | "hubert_model_name": "utter-project/mHuBERT-147", 51 | "hubert_layer": 6, 52 | "adapter_hidden_dim": 768 53 | }) 54 | 55 | # Initialize FeatureCleaner 56 | print("Loading FeatureCleaner...") 57 | cleaner = FeatureCleaner(model_config).to(DEVICE).eval() 58 | 59 | # Load adapter weights 60 | adapter_checkpoint = torch.load(adapter_path, map_location=DEVICE, weights_only=False) 61 | cleaner.load_state_dict(adapter_checkpoint["model_state_dict"]) 62 | 63 | # Load vocoder 64 | print("Loading vocoder...") 65 | vocoder = HiFiGANLightningModule.load_from_checkpoint( 66 | vocoder_path, map_location=DEVICE 67 | ).to(DEVICE).eval() 68 | 69 | # Cache models 70 | models_cache["cleaner"] = cleaner 71 | models_cache["vocoder"] = vocoder 72 | 73 | return cleaner, vocoder 74 | 75 | @torch.inference_mode() 76 | def enhance_audio(audio_path, progress=gr.Progress()): 77 | """Enhance audio using Miipher-2 model""" 78 | try: 79 | progress(0, desc="Loading models...") 80 | cleaner, vocoder = load_models() 81 | 82 | progress(0.2, desc="Loading audio...") 83 | # Load audio 84 | waveform, sr = torchaudio.load(audio_path) 85 | 86 | # Resample to 16kHz if needed 87 | if sr != SAMPLE_RATE_INPUT: 88 | waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE_INPUT) 89 | 90 | # Convert to mono if stereo 91 | waveform = waveform.mean(0, keepdim=True) 92 
| 93 | # Move to device 94 | waveform = waveform.to(DEVICE) 95 | 96 | progress(0.4, desc="Extracting features...") 97 | # Extract features using FeatureCleaner 98 | with torch.no_grad(), torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=(DEVICE.type == "cuda")): 99 | features = cleaner(waveform) 100 | 101 | # Ensure correct shape for vocoder 102 | if features.dim() == 2: 103 | features = features.unsqueeze(0) 104 | 105 | progress(0.7, desc="Generating enhanced audio...") 106 | # Generate audio using vocoder 107 | # Lightning SSL-Vocoderの入力形式に合わせる (batch, seq_len, input_channels) 108 | batch = {"input_feature": features.transpose(1, 2)} 109 | enhanced_audio = vocoder.generator_forward(batch) 110 | 111 | # Convert to numpy 112 | enhanced_audio = enhanced_audio.squeeze(0).cpu().to(torch.float32).detach().numpy() 113 | 114 | progress(1.0, desc="Enhancement complete!") 115 | 116 | # Save audio using torchaudio to avoid Gradio format issues 117 | enhanced_audio = np.clip(enhanced_audio, -1.0, 1.0) 118 | enhanced_audio_tensor = torch.from_numpy(enhanced_audio) 119 | 120 | # Ensure 2D tensor: (channels, samples) 121 | if enhanced_audio_tensor.dim() == 1: 122 | enhanced_audio_tensor = enhanced_audio_tensor.unsqueeze(0) 123 | 124 | # Save to temporary file using torchaudio 125 | import tempfile 126 | with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file: 127 | torchaudio.save(tmp_file.name, enhanced_audio_tensor, SAMPLE_RATE_OUTPUT) 128 | return tmp_file.name 129 | 130 | except Exception as e: 131 | raise gr.Error(f"Error during enhancement: {str(e)}") 132 | 133 | # Create Gradio interface 134 | def create_interface(): 135 | title = "🎤 Miipher-2 Speech Enhancement" 136 | 137 | description = """ 138 |
139 |

High-quality speech enhancement using Miipher-2 (HuBERT + Parallel Adapter + HiFi-GAN)

140 |

📄 Paper | 141 | 🤗 Model | 142 | 💻 GitHub

143 |
144 | """ 145 | 146 | article = """ 147 | ## How it works 148 | 149 | 1. **Upload** a noisy or degraded audio file 150 | 2. **Process** using Miipher-2 model 151 | 3. **Download** the enhanced audio 152 | 153 | ### Model Details 154 | - **SSL Backbone**: mHuBERT-147 (Multilingual) 155 | - **Adapter**: Parallel adapters at layer 6 156 | - **Vocoder**: HiFi-GAN trained on SSL features 157 | - **Input**: Any sample rate (automatically resampled to 16kHz) 158 | - **Output**: 22.05kHz high-quality audio 159 | 160 | ### Tips 161 | - Works best with speech audio 162 | - Supports various noise types (background noise, reverb, etc.) 163 | - Processing time depends on audio length and hardware 164 | """ 165 | 166 | examples = [ 167 | ["examples/noisy_speech_1.wav"], 168 | ["examples/noisy_speech_2.wav"], 169 | ["examples/reverb_speech.wav"], 170 | ] 171 | 172 | with gr.Blocks(title=title, theme=gr.themes.Soft()) as demo: 173 | gr.Markdown(f"# {title}") 174 | gr.Markdown(description) 175 | 176 | with gr.Row(): 177 | with gr.Column(): 178 | input_audio = gr.Audio( 179 | label="Input Audio (Noisy/Degraded)", 180 | type="filepath", 181 | sources=["upload", "microphone"] 182 | ) 183 | 184 | enhance_btn = gr.Button("🚀 Enhance Audio", variant="primary") 185 | 186 | with gr.Column(): 187 | output_audio = gr.Audio( 188 | label="Enhanced Audio", 189 | type="filepath", 190 | interactive=False 191 | ) 192 | 193 | # Add examples if they exist 194 | examples_dir = Path("examples") 195 | if examples_dir.exists(): 196 | example_files = list(examples_dir.glob("*.wav")) + list(examples_dir.glob("*.mp3")) 197 | if example_files: 198 | gr.Examples( 199 | examples=[[str(f)] for f in example_files[:3]], 200 | inputs=input_audio, 201 | outputs=output_audio, 202 | fn=enhance_audio, 203 | cache_examples=True 204 | ) 205 | 206 | gr.Markdown(article) 207 | 208 | # Connect the enhancement function 209 | enhance_btn.click( 210 | fn=enhance_audio, 211 | inputs=input_audio, 212 | outputs=output_audio, 213 | show_progress=True 214 | ) 215 | 216 | return demo 217 | 218 | # Launch the app 219 | if __name__ == "__main__": 220 | # Pre-load models 221 | print("Pre-loading models...") 222 | load_models() 223 | print("Models loaded successfully!") 224 | 225 | # Create and launch interface 226 | demo = create_interface() 227 | demo.launch() 228 | -------------------------------------------------------------------------------- /src/miipher_2/train/adapter.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from torch import nn, optim 6 | from torch.nn.utils.rnn import pad_sequence 7 | from torch.utils.data import DataLoader 8 | from transformers.optimization import get_scheduler 9 | 10 | import wandb 11 | from miipher_2.data.webdataset_loader import AdapterDataset 12 | from miipher_2.extractors.hubert import SSLExtractor 13 | from miipher_2.model.feature_cleaner import FeatureCleaner 14 | from miipher_2.utils.checkpoint import ( 15 | get_resume_checkpoint_path, 16 | load_checkpoint, 17 | restore_random_states, 18 | save_checkpoint, 19 | setup_wandb_resume, 20 | ) 21 | 22 | 23 | def collate_tensors(batch: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]: 24 | noisy_tensors, clean_tensors = zip(*batch, strict=False) 25 | noisy_tensors = [x.squeeze(0) for x in noisy_tensors] 26 | clean_tensors = [x.squeeze(0) for x in clean_tensors] 27 | noisy = pad_sequence(noisy_tensors, batch_first=True, 
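# Each dataset item is a (1, T) mono waveform; squeeze(0) above drops the channel axis so
# pad_sequence can right-pad every clip to the longest one in the batch, producing
# (B, T_max) tensors for both the noisy inputs and the clean targets.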
padding_value=0.0) 28 | clean = pad_sequence(clean_tensors, batch_first=True, padding_value=0.0) 29 | return noisy, clean 30 | 31 | 32 | @torch.no_grad() 33 | def validate( 34 | model: nn.Module, target_model: nn.Module, val_dl: DataLoader, loss_fns: dict, limit_batches: int | None = None 35 | ) -> dict: 36 | model.eval() 37 | total_losses = {"total": 0.0, "mae": 0.0, "mse": 0.0, "sc": 0.0} 38 | total_count = 0 39 | 40 | for i, (noisy, clean) in enumerate(val_dl): 41 | if limit_batches is not None and i >= limit_batches: 42 | break 43 | noisy, clean = noisy.cuda(), clean.cuda() # noqa: PLW2901 44 | 45 | target = target_model(clean) 46 | pred = model(noisy) 47 | 48 | min_len = min(pred.size(2), target.size(2)) 49 | pred, target = pred[:, :, :min_len], target[:, :, :min_len] 50 | 51 | mae_loss = loss_fns["mae"](pred, target) 52 | mse_loss = loss_fns["mse"](pred, target) 53 | sc_loss = (pred - target).pow(2).sum() / (target.pow(2).sum() + 1e-9) 54 | loss = mae_loss + mse_loss + sc_loss 55 | 56 | total_losses["total"] += loss.item() * noisy.size(0) 57 | total_losses["mae"] += mae_loss.item() * noisy.size(0) 58 | total_losses["mse"] += mse_loss.item() * noisy.size(0) 59 | total_losses["sc"] += sc_loss.item() * noisy.size(0) 60 | total_count += noisy.size(0) 61 | 62 | avg_losses = {key: val / total_count for key, val in total_losses.items()} 63 | model.train() 64 | return avg_losses 65 | 66 | 67 | def train_adapter(cfg: DictConfig) -> None: 68 | resume_checkpoint_path = get_resume_checkpoint_path(cfg) 69 | resumed_checkpoint = None 70 | if resume_checkpoint_path: 71 | resumed_checkpoint = load_checkpoint(str(resume_checkpoint_path)) 72 | restore_random_states(resumed_checkpoint) 73 | print(f"[INFO] Resuming from step {resumed_checkpoint['step']}") 74 | 75 | setup_wandb_resume(cfg, resumed_checkpoint) 76 | 77 | dl = DataLoader( 78 | AdapterDataset(cfg.dataset.path_pattern, shuffle=cfg.dataset.shuffle), 79 | batch_size=cfg.batch_size, 80 | num_workers=cfg.loader.num_workers, 81 | pin_memory=cfg.loader.pin_memory, 82 | collate_fn=collate_tensors, 83 | drop_last=True, 84 | ) 85 | 86 | val_dl = DataLoader( 87 | AdapterDataset(cfg.dataset.val_path_pattern, shuffle=False), # シャッフルは不要 88 | batch_size=cfg.batch_size, 89 | num_workers=cfg.loader.num_workers, 90 | pin_memory=cfg.loader.pin_memory, 91 | collate_fn=collate_tensors, 92 | drop_last=False, 93 | ) 94 | 95 | model = FeatureCleaner(cfg.model).cuda().float() 96 | 97 | # SSLExtractorを使用してモデルタイプを自動判定またはconfigから取得 98 | model_type = cfg.model.get("model_type", "auto") 99 | target_model = ( 100 | SSLExtractor( 101 | model_name=cfg.model.hubert_model_name, 102 | layer=cfg.model.hubert_layer - 1, 103 | model_type=model_type, 104 | ) 105 | .cuda() 106 | .float() 107 | .eval() 108 | ) 109 | for param in target_model.parameters(): 110 | param.requires_grad = False 111 | 112 | opt = optim.AdamW( 113 | filter(lambda p: p.requires_grad, model.parameters()), 114 | lr=cfg.optim.lr, 115 | weight_decay=cfg.optim.weight_decay, 116 | betas=(cfg.optim.betas[0], cfg.optim.betas[1]), 117 | ) 118 | 119 | # スケジューラの設定 120 | if cfg.optim.scheduler.name == "cosine_with_restarts": 121 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 122 | 123 | scheduler = CosineAnnealingWarmRestarts( 124 | optimizer=opt, 125 | T_0=cfg.optim.scheduler.first_cycle_steps, 126 | T_mult=int(cfg.optim.scheduler.cycle_mult), 127 | eta_min=cfg.optim.scheduler.min_lr, 128 | ) 129 | else: 130 | scheduler = get_scheduler( 131 | name=cfg.optim.scheduler.name, 132 | optimizer=opt, 133 | 
num_warmup_steps=cfg.optim.scheduler.warmup_steps, 134 | num_training_steps=cfg.steps, 135 | ) 136 | 137 | start_it = 0 138 | if resumed_checkpoint: 139 | model.load_state_dict(resumed_checkpoint["model_state_dict"]) 140 | opt.load_state_dict(resumed_checkpoint["optimizer_state_dict"]) 141 | if "scheduler_state_dict" in resumed_checkpoint: 142 | scheduler.load_state_dict(resumed_checkpoint["scheduler_state_dict"]) 143 | start_it = resumed_checkpoint.get("step", 0) + 1 144 | print("[INFO] Restored model, optimizer, and scheduler states") 145 | 146 | mae_loss_fn = nn.L1Loss() 147 | mse_loss_fn = nn.MSELoss() 148 | 149 | dl_iter = iter(dl) 150 | 151 | for it in range(start_it, cfg.steps): 152 | if it > 0 and it % cfg.validation_interval == 0: 153 | val_losses = validate( 154 | model, 155 | target_model, 156 | val_dl, 157 | {"mae": mae_loss_fn, "mse": mse_loss_fn}, 158 | limit_batches=cfg.get("validation_batches"), 159 | ) 160 | wandb.log({f"val_loss/{key}": val for key, val in val_losses.items()}, step=it) 161 | print(f"[Adapter] it:{it:>7} | Validation Loss: {val_losses['total']:.4f}") 162 | 163 | try: 164 | noisy, clean = next(dl_iter) 165 | except StopIteration: 166 | dl_iter = iter(dl) 167 | noisy, clean = next(dl_iter) 168 | 169 | noisy, clean = noisy.cuda(), clean.cuda() 170 | 171 | with torch.no_grad(): 172 | target = target_model(clean) 173 | 174 | pred = model(noisy) 175 | 176 | min_len = min(pred.size(2), target.size(2)) 177 | pred = pred[:, :, :min_len] 178 | target = target[:, :, :min_len] 179 | 180 | mae_loss = mae_loss_fn(pred, target) 181 | mse_loss = mse_loss_fn(pred, target) 182 | sc_loss = (pred - target).pow(2).sum() / (target.pow(2).sum() + 1e-9) 183 | loss = mae_loss + mse_loss + sc_loss 184 | 185 | opt.zero_grad(set_to_none=True) 186 | loss.backward() 187 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg.optim.max_grad_norm) 188 | opt.step() 189 | scheduler.step() 190 | 191 | if it % cfg.log_interval == 0: 192 | log_data = { 193 | "iteration": it, 194 | "loss/total": loss.item(), 195 | "loss/mae": mae_loss.item(), 196 | "loss/mse": mse_loss.item(), 197 | "loss/sc": sc_loss.item(), 198 | "learning_rate": scheduler.get_last_lr()[0], 199 | } 200 | print( 201 | f"[Adapter] it:{it:>7} | " 202 | f"Total Loss={loss.item():.4f} | " 203 | f"MAE={mae_loss.item():.4f} | " 204 | f"MSE={mse_loss.item():.4f} | " 205 | f"SC={sc_loss.item():.4f}" 206 | ) 207 | if cfg.wandb.enabled: 208 | wandb.log(log_data, step=it) 209 | 210 | if hasattr(cfg, "checkpoint") and it > 0 and it % cfg.checkpoint.save_interval == 0: 211 | save_checkpoint( 212 | checkpoint_dir=cfg.save_dir, 213 | step=it, 214 | model_state=model.state_dict(), 215 | optimizer_state=opt.state_dict(), 216 | scheduler_state=scheduler.state_dict(), 217 | cfg=cfg, 218 | keep_last_n=cfg.checkpoint.keep_last_n, 219 | ) 220 | 221 | # 最終モデル保存 222 | sd = pathlib.Path(cfg.save_dir) 223 | sd.mkdir(parents=True, exist_ok=True) 224 | model_path = sd / "adapter_final.pt" 225 | torch.save(model.state_dict(), model_path) 226 | 227 | if cfg.wandb.enabled and cfg.wandb.log_model: 228 | artifact = wandb.Artifact("adapter_model", type="model") 229 | artifact.add_file(str(model_path)) 230 | wandb.log_artifact(artifact) 231 | print("[INFO] Model saved as wandb artifact") 232 | 233 | if cfg.wandb.enabled: 234 | wandb.finish() 235 | -------------------------------------------------------------------------------- /src/miipher_2/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | 
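`train_adapter` above drives its resume logic through the helpers defined in `miipher_2/utils/checkpoint.py` below. A minimal save/resume round-trip with those helpers, assuming a placeholder experiment directory and a toy stand-in model rather than values from this repository:

```python
from torch import nn, optim

from miipher_2.utils.checkpoint import (
    find_latest_checkpoint,
    load_checkpoint,
    restore_random_states,
    save_checkpoint,
)

model = nn.Linear(8, 8)            # stand-in for FeatureCleaner
opt = optim.AdamW(model.parameters(), lr=1e-4)

# Save at step 1000 -> exp/demo/checkpoint_1k.pt; files beyond keep_last_n are pruned.
save_checkpoint(
    checkpoint_dir="exp/demo",
    step=1000,
    model_state=model.state_dict(),
    optimizer_state=opt.state_dict(),
    keep_last_n=3,
)

# Later: locate the newest checkpoint and restore model, optimizer, and RNG states.
latest = find_latest_checkpoint("exp/demo")
if latest is not None:
    ckpt = load_checkpoint(str(latest))
    model.load_state_dict(ckpt["model_state_dict"])
    opt.load_state_dict(ckpt["optimizer_state_dict"])
    restore_random_states(ckpt)
    start_step = ckpt["step"] + 1  # mirrors how train_adapter computes start_it
```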
import os 2 | import random 3 | from pathlib import Path 4 | from typing import Any 5 | 6 | import numpy as np 7 | import torch 8 | from omegaconf import DictConfig 9 | 10 | import wandb 11 | 12 | 13 | def save_checkpoint( 14 | checkpoint_dir: str, 15 | step: int, 16 | model_state: dict[str, Any], 17 | optimizer_state: dict[str, Any], 18 | scheduler_state: dict[str, Any] | None = None, 19 | additional_states: dict[str, Any] | None = None, 20 | cfg: DictConfig | None = None, 21 | keep_last_n: int = 5, 22 | ) -> str: 23 | """ 24 | チェックポイントを保存する 25 | 26 | Args: 27 | checkpoint_dir: チェックポイント保存ディレクトリ 28 | step: 現在のステップ数 29 | model_state: モデルの状態辞書 30 | optimizer_state: オプティマイザの状態辞書 31 | scheduler_state: スケジューラの状態辞書 32 | additional_states: 追加の状態辞書 33 | cfg: 設定 34 | keep_last_n: 保持するチェックポイント数 35 | 36 | Returns: 37 | 保存されたチェックポイントのパス 38 | """ 39 | checkpoint_dir_path = Path(checkpoint_dir) 40 | checkpoint_dir_path.mkdir(parents=True, exist_ok=True) 41 | 42 | # チェックポイントファイル名 43 | checkpoint_name = f"checkpoint_{step // 1000}k.pt" 44 | checkpoint_path = checkpoint_dir_path / checkpoint_name 45 | 46 | # 保存するデータ 47 | checkpoint_data = { 48 | "step": step, 49 | "model_state_dict": model_state, 50 | "optimizer_state_dict": optimizer_state, 51 | } 52 | 53 | if scheduler_state is not None: 54 | checkpoint_data["scheduler_state_dict"] = scheduler_state 55 | 56 | if additional_states is not None: 57 | checkpoint_data.update(additional_states) 58 | 59 | # Wandb情報を保存 60 | if wandb.run is not None: 61 | checkpoint_data["wandb_run_id"] = wandb.run.id 62 | checkpoint_data["wandb_run_name"] = wandb.run.name 63 | checkpoint_data["wandb_project"] = wandb.run.project 64 | 65 | # 設定を保存 66 | if cfg is not None: 67 | checkpoint_data["config"] = dict(cfg) 68 | 69 | # 乱数状態を保存 70 | checkpoint_data["random_states"] = { 71 | "python": random.getstate(), 72 | "numpy": np.random.get_state(), # noqa: NPY002 73 | "torch": torch.get_rng_state(), 74 | } 75 | 76 | if torch.cuda.is_available(): 77 | checkpoint_data["random_states"]["torch_cuda"] = torch.cuda.get_rng_state_all() 78 | 79 | # チェックポイント保存 80 | torch.save(checkpoint_data, checkpoint_path) 81 | print(f"[INFO] Checkpoint saved: {checkpoint_path}") 82 | 83 | # 古いチェックポイントを削除 84 | cleanup_old_checkpoints(checkpoint_dir_path, keep_last_n) 85 | 86 | return str(checkpoint_path) 87 | 88 | 89 | def load_checkpoint(checkpoint_path: str) -> dict[str, Any]: 90 | """ 91 | チェックポイントを読み込む 92 | 93 | Args: 94 | checkpoint_path: チェックポイントファイルのパス 95 | 96 | Returns: 97 | チェックポイントデータ 98 | """ 99 | if not Path(checkpoint_path).exists(): 100 | msg = f"Checkpoint not found: {checkpoint_path}" 101 | raise FileNotFoundError(msg) 102 | 103 | try: 104 | checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) 105 | print(f"[INFO] Checkpoint loaded: {checkpoint_path}") 106 | return checkpoint # noqa: TRY300 107 | except Exception as e: 108 | msg = f"Failed to load checkpoint {checkpoint_path}: {e}" 109 | raise RuntimeError(msg) from e 110 | 111 | 112 | def validate_config_compatibility(checkpoint_config: dict[str, Any], current_config: dict[str, Any]) -> list[str]: 113 | """ 114 | 設定の互換性をチェックし、重要なパラメータの変更を検出 115 | 116 | Args: 117 | checkpoint_config: チェックポイントに保存された設定 118 | current_config: 現在の設定 119 | 120 | Returns: 121 | 変更された重要なパラメータのリスト 122 | """ 123 | critical_params = [ 124 | "model", 125 | "optim.lr", 126 | "batch_size", 127 | "dataset.num_examples", 128 | "steps", 129 | ] 130 | 131 | warnings = [] 132 | 133 | def get_nested_value(config: dict[str, Any], key: str) -> Any | 
None: # noqa: ANN401 134 | """ネストされた設定値を取得""" 135 | keys = key.split(".") 136 | value = config 137 | for k in keys: 138 | if isinstance(value, dict) and k in value: 139 | value = value[k] 140 | else: 141 | return None 142 | return value 143 | 144 | for param in critical_params: 145 | checkpoint_val = get_nested_value(checkpoint_config, param) 146 | current_val = get_nested_value(current_config, param) 147 | 148 | if checkpoint_val is not None and current_val is not None and checkpoint_val != current_val: 149 | warnings.append(f"{param}: {checkpoint_val} -> {current_val}") 150 | 151 | return warnings 152 | 153 | 154 | def setup_wandb_resume(cfg: DictConfig, checkpoint: dict[str, Any] | None = None) -> None: 155 | """ 156 | Wandbの初期化と再開設定 157 | 158 | Args: 159 | cfg: 設定 160 | checkpoint: チェックポイントデータ 161 | """ 162 | if not cfg.wandb.enabled: 163 | return 164 | 165 | if checkpoint is not None and "wandb_run_id" in checkpoint: 166 | # チェックポイントから自動的にWandb IDを継承 167 | wandb.init( 168 | project=cfg.wandb.project, 169 | entity=cfg.wandb.entity, 170 | id=checkpoint["wandb_run_id"], 171 | resume="must", 172 | config=dict(cfg), # 現在の設定を使用 173 | tags=cfg.wandb.tags, 174 | notes=cfg.wandb.notes, 175 | ) 176 | print(f"[INFO] Resumed wandb run: {checkpoint['wandb_run_id']}") 177 | 178 | # 設定の互換性をチェック 179 | if "config" in checkpoint: 180 | config_warnings = validate_config_compatibility(checkpoint["config"], dict(cfg)) 181 | if config_warnings: 182 | print("[WARNING] Configuration changes detected:") 183 | for warning in config_warnings: 184 | print(f" - {warning}") 185 | print("These changes may affect training consistency.") 186 | else: 187 | # 新規学習 188 | wandb.init( 189 | project=cfg.wandb.project, 190 | entity=cfg.wandb.entity, 191 | name=cfg.wandb.name, 192 | tags=cfg.wandb.tags, 193 | config=dict(cfg), 194 | notes=cfg.wandb.notes, 195 | ) 196 | print(f"[INFO] Started new wandb run: {wandb.run.id}") 197 | 198 | 199 | def restore_random_states(checkpoint: dict[str, Any]) -> None: 200 | """ 201 | 乱数状態を復元する 202 | 203 | Args: 204 | checkpoint: チェックポイントデータ 205 | """ 206 | if "random_states" not in checkpoint: 207 | print("[WARNING] No random states found in checkpoint") 208 | return 209 | 210 | random_states = checkpoint["random_states"] 211 | 212 | try: 213 | random.setstate(random_states["python"]) 214 | np.random.set_state(random_states["numpy"]) # noqa: NPY002 215 | torch.set_rng_state(random_states["torch"]) 216 | 217 | if torch.cuda.is_available() and "torch_cuda" in random_states: 218 | torch.cuda.set_rng_state_all(random_states["torch_cuda"]) 219 | 220 | print("[INFO] Random states restored") 221 | except Exception as e: # noqa: BLE001 222 | print(f"[WARNING] Failed to restore random states: {e}") 223 | 224 | 225 | def cleanup_old_checkpoints(checkpoint_dir: Path, keep_last_n: int) -> None: 226 | """ 227 | 古いチェックポイントを削除する 228 | 229 | Args: 230 | checkpoint_dir: チェックポイントディレクトリ 231 | keep_last_n: 保持するチェックポイント数 232 | """ 233 | if keep_last_n <= 0: 234 | return 235 | 236 | # チェックポイントファイルを取得 237 | checkpoint_files = list(checkpoint_dir.glob("checkpoint_*.pt")) 238 | 239 | if len(checkpoint_files) <= keep_last_n: 240 | return 241 | 242 | checkpoint_files.sort(key=os.path.getmtime, reverse=True) 243 | 244 | # 古いファイルを削除 245 | files_to_delete = checkpoint_files[keep_last_n:] 246 | for file_path in files_to_delete: 247 | try: 248 | Path(file_path).unlink() 249 | print(f"[INFO] Removed old checkpoint: {file_path}") 250 | except Exception as e: # noqa: BLE001 251 | print(f"[WARNING] Failed to remove checkpoint 
{file_path}: {e}") 252 | 253 | 254 | def find_latest_checkpoint(checkpoint_dir: str) -> Path | None: 255 | """ 256 | 最新のチェックポイントを見つける 257 | 258 | Args: 259 | checkpoint_dir: チェックポイントディレクトリ 260 | 261 | Returns: 262 | 最新のチェックポイントパス 263 | """ 264 | checkpoint_files = list(Path(checkpoint_dir).glob("checkpoint_*.pt")) 265 | if not checkpoint_files: 266 | return None 267 | 268 | # 最新のファイルを返す 269 | return max(checkpoint_files, key=os.path.getmtime) 270 | 271 | 272 | def get_resume_checkpoint_path(cfg: DictConfig) -> Path | None: 273 | """ 274 | 再開用チェックポイントパスを取得する 275 | コマンドライン引数で`checkpoint.resume_from`が指定された場合のみパスを返す 276 | """ 277 | if hasattr(cfg, "checkpoint") and cfg.checkpoint.resume_from: 278 | checkpoint_path = Path(cfg.checkpoint.resume_from) 279 | if checkpoint_path.exists(): 280 | print(f"[INFO] Resuming from specified checkpoint: {checkpoint_path}") 281 | return checkpoint_path 282 | 283 | # 指定されたファイルが見つからない場合は、警告ではなくエラーを発生させて処理を停止 284 | msg = f"Checkpoint specified in `resume_from` not found: {checkpoint_path}" 285 | raise FileNotFoundError(msg) 286 | 287 | # `resume_from`が指定されていない場合は、常にNoneを返し、新規学習を開始する 288 | return None 289 | -------------------------------------------------------------------------------- /src/miipher_2/lightning_vocoders/lightning_module.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from pathlib import Path 3 | from typing import Any, Optional 4 | 5 | import hydra 6 | import numpy as np 7 | import torch 8 | import torchaudio 9 | import transformers 10 | from lightning.pytorch import LightningModule, loggers 11 | from lightning.pytorch.utilities.types import STEP_OUTPUT 12 | from omegaconf import DictConfig 13 | from webdataset import resampled 14 | 15 | from .hifigan import ( 16 | Generator, 17 | MultiPeriodDiscriminator, 18 | MultiScaleDiscriminator, 19 | discriminator_loss, 20 | feature_loss, 21 | generator_loss, 22 | ) 23 | 24 | 25 | class Preprocessor(torch.nn.Module): 26 | def __init__(self, cfg:DictConfig) -> None: 27 | super().__init__() 28 | self.resampler = torchaudio.transforms.Resample(cfg.sample_rate,16_000) 29 | self.mel_spec = torchaudio.transforms.MelSpectrogram( 30 | cfg.sample_rate, 31 | cfg.preprocess.stft.n_fft, 32 | cfg.preprocess.stft.win_length, 33 | cfg.preprocess.stft.hop_length, 34 | cfg.preprocess.mel.f_min, 35 | cfg.preprocess.mel.f_max, 36 | n_mels=cfg.preprocess.mel.n_mels, 37 | ) 38 | def get_logmelspec(self,waveform): 39 | melspec = self.mel_spec(waveform) 40 | logmelspec = torch.log(torch.clamp_min(melspec, 1.0e-5) * 1.0).to(torch.float32) 41 | return logmelspec 42 | 43 | class HiFiGANLightningModule(LightningModule): 44 | def __init__(self, cfg: DictConfig) -> None: 45 | super().__init__() 46 | self.generator = Generator(cfg.model.generator) 47 | self.multi_period_discriminator = MultiPeriodDiscriminator() 48 | self.multi_scale_discriminator = MultiScaleDiscriminator() 49 | self.automatic_optimization = False 50 | self.preprocessor = Preprocessor(cfg) 51 | self.cfg = cfg 52 | self.save_hyperparameters() 53 | 54 | def configure_optimizers(self) -> Any: 55 | opt_g = hydra.utils.instantiate( 56 | self.cfg.model.optim.opt_g, params=self.generator.parameters() 57 | ) 58 | opt_d = hydra.utils.instantiate( 59 | self.cfg.model.optim.opt_d, 60 | params=itertools.chain( 61 | self.multi_scale_discriminator.parameters(), 62 | self.multi_period_discriminator.parameters(), 63 | ), 64 | ) 65 | scheduler_g = hydra.utils.instantiate( 66 | self.cfg.model.optim.scheduler_g, 
optimizer=opt_g 67 | ) 68 | scheduler_d = hydra.utils.instantiate( 69 | self.cfg.model.optim.scheduler_d, optimizer=opt_d 70 | ) 71 | 72 | return [opt_g, opt_d], [ 73 | {"name": "scheduler_g", "scheduler": scheduler_g}, 74 | {"name": "scheduler_d", "scheduler": scheduler_d}, 75 | ] 76 | 77 | def generator_forward(self,batch): 78 | generator_input = batch["input_feature"] 79 | wav_generator_out = self.generator(generator_input) 80 | return wav_generator_out 81 | 82 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT: 83 | wav,generator_input, _ = ( 84 | batch["resampled_speech.pth"], 85 | batch["input_feature"], 86 | batch["filenames"], 87 | ) 88 | mel = self.preprocessor.get_logmelspec(wav) 89 | wav = wav.unsqueeze(1) 90 | wav_generator_out = self.generator_forward(batch) 91 | output_length = min(wav_generator_out.size(2),wav.size(2)) 92 | wav = wav[:,:,:output_length] 93 | wav_generator_out = wav_generator_out[:,:,:output_length] 94 | 95 | opt_g, opt_d = self.optimizers() 96 | sch_g, sch_d = self.lr_schedulers() 97 | if self.global_step >= self.cfg.model.adversarial_start_step: 98 | opt_d.zero_grad() 99 | 100 | # mpd 101 | mpd_out_real, mpd_out_fake, _, _ = self.multi_period_discriminator( 102 | wav, wav_generator_out.detach() 103 | ) 104 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( 105 | mpd_out_real, mpd_out_fake 106 | ) 107 | 108 | # msd 109 | msd_out_real, msd_out_fake, _, _ = self.multi_scale_discriminator( 110 | wav, wav_generator_out.detach() 111 | ) 112 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( 113 | msd_out_real, msd_out_fake 114 | ) 115 | 116 | loss_disc_all = loss_disc_s + loss_disc_f 117 | self.manual_backward(loss_disc_all) 118 | opt_d.step() 119 | sch_d.step() 120 | self.log("train/discriminator/loss_disc_f", loss_disc_f) 121 | self.log("train/discriminator/loss_disc_s", loss_disc_s) 122 | else: 123 | loss_disc_f = loss_disc_s = 0.0 124 | 125 | # generator 126 | opt_g.zero_grad() 127 | predicted_mel = self.preprocessor.get_logmelspec(wav_generator_out.squeeze(1)) 128 | loss_recons = self.reconstruction_loss(mel, predicted_mel) 129 | loss_g = loss_recons * self.cfg.model.loss.recons_coef 130 | if self.global_step >= self.cfg.model.adversarial_start_step: 131 | ( 132 | mpd_out_real, 133 | mpd_out_fake, 134 | fmap_f_real, 135 | fmap_f_generated, 136 | ) = self.multi_period_discriminator(wav, wav_generator_out) 137 | loss_fm_mpd = feature_loss(fmap_f_real, fmap_f_generated) 138 | 139 | # msd 140 | ( 141 | msd_out_real, 142 | msd_out_fake, 143 | fmap_scale_real, 144 | fmap_scale_generated, 145 | ) = self.multi_scale_discriminator(wav, wav_generator_out) 146 | loss_fm_msd = feature_loss(fmap_scale_real, fmap_scale_generated) 147 | 148 | loss_g_mpd, losses_gen_f = generator_loss(mpd_out_fake) 149 | loss_g_msd, losses_gen_s = generator_loss(msd_out_fake) 150 | loss_g += loss_fm_mpd * self.cfg.model.loss.fm_mpd_coef 151 | loss_g += loss_fm_msd * self.cfg.model.loss.fm_msd_coef 152 | loss_g += loss_g_mpd * self.cfg.model.loss.g_mpd_coef 153 | loss_g += loss_g_msd * self.cfg.model.loss.g_msd_coef 154 | self.log("train/generator/loss_fm_mpd", loss_fm_mpd) 155 | self.log("train/generator/loss_fm_msd", loss_fm_msd) 156 | self.log("train/generator/loss_g_mpd", loss_g_mpd) 157 | self.log("train/generator/loss_g_msd", loss_g_msd) 158 | self.manual_backward(loss_g) 159 | self.log("train/loss_reconstruction", loss_recons) 160 | self.log("train/generator/loss", loss_g) 161 | opt_g.step() 162 | sch_g.step() 163 | 164 | def 
validation_step(self, batch, batch_idx): 165 | wav, generator_input, filename, wav_lens = ( 166 | batch["resampled_speech.pth"], 167 | batch["input_feature"], 168 | batch["filenames"], 169 | batch["wav_lens"], 170 | ) 171 | mel = self.preprocessor.get_logmelspec(wav) 172 | wav_generator_out = self.generator_forward(batch) 173 | predicted_mel = self.preprocessor.get_logmelspec(wav_generator_out.squeeze(1)) 174 | loss_recons = self.reconstruction_loss(mel, predicted_mel) 175 | if ( 176 | batch_idx < self.cfg.model.logging_wav_samples 177 | and self.global_rank == 0 178 | and self.local_rank == 0 179 | ): 180 | self.log_audio( 181 | wav_generator_out[0] 182 | .squeeze()[: wav_lens[0]] 183 | .cpu() 184 | .numpy() 185 | .astype(np.float32), 186 | name=f"generated/{filename[0]}", 187 | sampling_rate=self.cfg.sample_rate, 188 | ) 189 | self.log_audio( 190 | wav[0].squeeze()[: wav_lens[0]].cpu().numpy().astype(np.float32), 191 | name=f"natural/{filename[0]}", 192 | sampling_rate=self.cfg.sample_rate, 193 | ) 194 | 195 | self.log("val/reconstruction", loss_recons) 196 | def on_test_start(self): 197 | Path(f"{self.output_path}").mkdir(exist_ok=True,parents=True) 198 | def test_step(self,batch,batch_idx): 199 | generator_input = batch["input_feature"] 200 | wav_generator_out = self.generator_forward(batch) 201 | return wav_generator_out 202 | def on_test_batch_end(self, outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0): 203 | for output,filename,resampled in zip(outputs,batch["filenames"],batch["resampled_speech.pth"], strict=False): 204 | torchaudio.save(filepath=f"{self.output_path}/{filename}.wav",src=output.cpu(),sample_rate=self.cfg.sample_rate) 205 | torchaudio.save(filepath=f"{self.output_path}/{filename}_gt.wav",src=resampled.unsqueeze(0).cpu(),sample_rate=self.cfg.sample_rate) 206 | 207 | 208 | def reconstruction_loss(self, mel_gt, mel_predicted): 209 | length = min(mel_gt.size(2), mel_predicted.size(2)) 210 | return torch.nn.L1Loss()( 211 | mel_gt[:, :, :length], 212 | mel_predicted[:, :, :length], 213 | ) 214 | 215 | def log_audio(self, audio, name, sampling_rate): 216 | for logger in self.loggers: 217 | if type(logger) == loggers.WandbLogger: 218 | import wandb 219 | 220 | wandb.log( 221 | {name: wandb.Audio(audio, sample_rate=sampling_rate)}, 222 | step=self.global_step, 223 | ) 224 | elif type(logger) == loggers.TensorBoardLogger: 225 | logger.experiment.add_audio( 226 | name, 227 | audio, 228 | self.global_step, 229 | sampling_rate, 230 | ) 231 | -------------------------------------------------------------------------------- /src/miipher_2/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | audio_eval_utils.py 3 | =================== 4 | * 劣化パイプライン 5 | * 指標計算(MCD / X‑vector / ECAPA / WER / log‑F0‑RMSE) 6 | └ ASR・話者モデルは **必要時にのみ動的インポート** してロード時間を短縮 7 | * 共通 Logger utility 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | import io 13 | 14 | # ───────────────────────── logging helper ────────────────────────── 15 | import logging 16 | import math 17 | import os 18 | import random 19 | import subprocess 20 | import sys 21 | import tempfile 22 | import time 23 | from pathlib import Path 24 | from typing import Dict, List 25 | 26 | import numpy as np 27 | import pysptk 28 | import pyworld as pw 29 | import soundfile as sf 30 | import torch 31 | import torchaudio 32 | 33 | 34 | def get_logger(name: str = "audio_eval", level: int = logging.INFO) -> logging.Logger: 35 | logger = 
logging.getLogger(name) 36 | if not logger.handlers: 37 | fmt = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S") 38 | h = logging.StreamHandler(sys.stdout) 39 | h.setFormatter(fmt) 40 | logger.addHandler(h) 41 | logger.setLevel(level) 42 | return logger 43 | 44 | 45 | log = get_logger(__name__) 46 | 47 | # ────────────────────── 劣化パイプライン ↓ ────────────────────── 48 | 49 | 50 | def _encode_with_codec(wave: torch.Tensor, sr: int, fmt: str) -> torch.Tensor: 51 | assert fmt in ("mp3", "ogg") 52 | with tempfile.TemporaryDirectory() as d: 53 | raw = Path(d, "raw.wav") 54 | torchaudio.save(str(raw), wave, sr, encoding="PCM_S", bits_per_sample=16) 55 | coded = Path(d, f"coded.{fmt}") 56 | codec = "libmp3lame" if fmt == "mp3" else "libvorbis" 57 | subprocess.run( 58 | ["ffmpeg", "-loglevel", "quiet", "-y", "-i", raw, "-codec:a", codec, "-b:a", "64k", coded], 59 | check=True, 60 | ) 61 | rec, sr2 = torchaudio.load(str(coded)) 62 | assert sr2 == sr 63 | return rec 64 | 65 | 66 | def _apply_reverb(wave: torch.Tensor, sr: int) -> torch.Tensor: 67 | wav, _ = torchaudio.sox_effects.apply_effects_tensor(wave, sr, [["reverb", "50", "50", "100"]]) 68 | return wav 69 | 70 | 71 | def _mix_noise(clean: torch.Tensor, noise: torch.Tensor, snr_db: float) -> torch.Tensor: 72 | if noise.size(1) < clean.size(1): 73 | r = math.ceil(clean.size(1) / noise.size(1)) 74 | noise = noise.repeat(1, r)[:, : clean.size(1)] 75 | else: 76 | s = random.randint(0, noise.size(1) - clean.size(1)) 77 | noise = noise[:, s : s + clean.size(1)] 78 | 79 | pow_c = clean.pow(2).mean() 80 | pow_n = noise.pow(2).mean() 81 | scale = math.sqrt((pow_c / (10 ** (snr_db / 10))) / (pow_n + 1e-12)) 82 | return clean + noise * scale 83 | 84 | 85 | def degrade_waveform(clean: torch.Tensor, sr: int, noise_pool: list[torch.Tensor]) -> torch.Tensor: 86 | wav = clean.clone() 87 | wav = _encode_with_codec(wav, sr, random.choice(["mp3", "ogg"])) 88 | if random.random() < 0.5: 89 | wav = _apply_reverb(wav, sr) 90 | wav = _mix_noise(wav, random.choice(noise_pool), random.uniform(5, 30)) 91 | peak = wav.abs().max() 92 | if peak > 0: 93 | wav *= 0.89 / peak 94 | return wav 95 | 96 | 97 | def _next_pow2(x: int) -> int: 98 | """x 以上の最小 2 の冪.""" 99 | return 1 << (x - 1).bit_length() 100 | 101 | 102 | def _mcep_feat( 103 | x: np.ndarray, 104 | sr: int, 105 | order: int = 24, 106 | frame_shift_ms: float = 5.0, 107 | frame_len_ms: float = 25.0, 108 | ) -> np.ndarray: 109 | """ 110 | WORLD 系 MCD 計算用メルケプストラム系列抽出. 111 | 112 | Parameters 113 | ---------- 114 | x : np.ndarray 115 | 1‑D waveform (float64 推奨, range ‑1…1) 116 | sr : int 117 | Sample rate (e.g. 
16000) 118 | order : int, optional 119 | メルケプストラム次数 (default 24 → 25 次元; いわゆる MCD‑13 なら 12) 120 | frame_shift_ms, frame_len_ms : float 121 | フレーム長 / シフト (ms) 122 | 123 | Returns 124 | ------- 125 | mc : ndarray, shape (num_frames, order+1) 126 | """ 127 | 128 | # ---------- 定数 ---------- 129 | alpha = 0.42 if sr == 16000 else 0.455 # 22.05 kHz → 0.455 等 130 | hop = int(frame_shift_ms * 0.001 * sr) 131 | win = int(frame_len_ms * 0.001 * sr) 132 | nfft = _next_pow2(win * 2) # 2 の冪でないと SPTK が落ちる 133 | 134 | # ---------- 窓関数 ---------- 135 | try: 136 | from pysptk.window import hamming # ≥ 0.3 137 | 138 | w = hamming(win) 139 | except (ImportError, AttributeError): 140 | w = np.hamming(win) # NumPy fallback 141 | 142 | # ---------- フレーム逐次処理 ---------- 143 | feats: list[np.ndarray] = [] 144 | err_frames = 0 145 | 146 | for i in range(0, len(x) - win, hop): 147 | frame = x[i : i + win] * w 148 | # 低振幅時に log(0)→ -inf を避けるため +1e‑12 149 | psd = np.abs(np.fft.rfft(frame, n=nfft)) ** 2 + 1e-12 150 | 151 | try: 152 | mc = pysptk.mcep( 153 | psd, 154 | order=order, 155 | alpha=alpha, 156 | maxiter=100, 157 | etype=1, # log magnitude spectral distortion 158 | ) 159 | except RuntimeError: 160 | # 無音 or 発散で mcep 失敗 → ゼロベクトルを採用 161 | err_frames += 1 162 | mc = np.zeros(order + 1, dtype=np.float64) 163 | 164 | feats.append(mc) 165 | 166 | if err_frames: 167 | log.debug("mcep: %d frames fell back to zeros (%.2f%%)", err_frames, 100 * err_frames / max(len(feats), 1)) 168 | 169 | if not feats: 170 | # クリッピング等で 1 フレームも取れない場合の保険 171 | return np.zeros((1, order + 1), dtype=np.float64) 172 | 173 | return np.vstack(feats) 174 | 175 | 176 | def _to_mono_1d_numpy(wave: torch.Tensor) -> np.ndarray: 177 | """PyTorch テンソルを、pyworld が要求するモノラル/1D/float64/C-contiguous な NumPy 配列に変換する""" 178 | # チャンネル次元があれば平均化してモノラルにする 179 | if wave.dim() > 1 and wave.shape[0] > 1: 180 | wave = wave.mean(dim=0, keepdim=True) 181 | 182 | # NumPy 配列に変換し、型とメモリ配置を整える 183 | np_wave = wave.squeeze().cpu().numpy().astype(np.float64) 184 | return np.ascontiguousarray(np_wave) 185 | 186 | 187 | def mcd(ref: torch.Tensor, syn: torch.Tensor, sr: int) -> float: 188 | ref_np = _to_mono_1d_numpy(ref) 189 | syn_np = _to_mono_1d_numpy(syn) 190 | 191 | r, s = _mcep_feat(ref_np, sr), _mcep_feat(syn_np, sr) 192 | 193 | L = min(len(r), len(s)) 194 | diff = (r[:L] - s[:L]) ** 2 195 | return (10 / np.log(10)) * np.sqrt(2 * diff.sum(axis=1)).mean() 196 | 197 | 198 | def logf0_rmse(ref: torch.Tensor, syn: torch.Tensor, sr: int) -> float: 199 | ref_np = _to_mono_1d_numpy(ref) 200 | syn_np = _to_mono_1d_numpy(syn) 201 | 202 | fr, _ = pw.harvest(ref_np, sr) 203 | fs, _ = pw.harvest(syn_np, sr) 204 | 205 | L = min(len(fr), len(fs)) 206 | mask = np.logical_and(fr[:L] > 0, fs[:L] > 0) 207 | if mask.sum() == 0: 208 | return float("nan") 209 | return float(np.sqrt(np.mean((np.log(fr[:L][mask]) - np.log(fs[:L][mask])) ** 2))) 210 | 211 | 212 | # ─── 以下は必要時ロード ────────────────────────────────────────── 213 | def _lazy_import_speechbrain(): 214 | global SpeakerRecognition 215 | from speechbrain.pretrained import SpeakerRecognition # type: ignore 216 | 217 | return SpeakerRecognition 218 | 219 | 220 | def _lazy_import_transformers(): 221 | global AutoProcessor, AutoModelForSpeechSeq2Seq 222 | from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor # type: ignore 223 | 224 | return AutoProcessor, AutoModelForSpeechSeq2Seq 225 | 226 | 227 | def speaker_cos( 228 | ref: torch.Tensor, 229 | syn: torch.Tensor, 230 | sr: int, 231 | recognizer, 232 | ) -> float: 233 | """ 
234 | Robust cosine‑similarity (scalar) between two utterances. 235 | 236 | * 正常系: 返り値 ∈ [‑1, 1] 237 | * エンベディング形状が (B, T, D) / (B, D) / (D,) いずれでも動く 238 | """ 239 | 240 | # -------- waveform shape → (1, T) -------- 241 | if ref.dim() == 1: 242 | ref = ref.unsqueeze(0) 243 | if syn.dim() == 1: 244 | syn = syn.unsqueeze(0) 245 | 246 | device = next(recognizer.parameters()).device 247 | ref, syn = ref.to(device), syn.to(device) 248 | 249 | with torch.no_grad(): 250 | emb_ref = recognizer.encode_batch(ref) # shape: (⋯, D) 251 | emb_syn = recognizer.encode_batch(syn) 252 | 253 | # -------- すべての非最終次元を平均 -------- 254 | # (B, D) -> (D,) / (B, T, D) -> (D,) / (D,) -> (D,) 255 | while emb_ref.dim() > 1: 256 | emb_ref = emb_ref.mean(dim=0) 257 | while emb_syn.dim() > 1: 258 | emb_syn = emb_syn.mean(dim=0) 259 | 260 | # -------- cosine → scalar -------- 261 | sim = torch.nn.functional.cosine_similarity(emb_ref, emb_syn, dim=0, eps=1e-8) # ==> shape () 262 | return float(sim) 263 | 264 | 265 | def load_spk_models(device: str): 266 | SpeakerRecognition = _lazy_import_speechbrain() 267 | x = SpeakerRecognition.from_hparams( 268 | "speechbrain/spkrec-xvect-voxceleb", 269 | run_opts={"device": device}, 270 | savedir=str(Path.home() / ".cache/speechbrain/xvect"), 271 | ) 272 | e = SpeakerRecognition.from_hparams( 273 | "speechbrain/spkrec-ecapa-voxceleb", 274 | run_opts={"device": device}, 275 | savedir=str(Path.home() / ".cache/speechbrain/ecapa"), 276 | ) 277 | return x, e 278 | 279 | 280 | def load_asr(device: str): 281 | AutoProcessor, AutoModelForSpeechSeq2Seq = _lazy_import_transformers() 282 | proc = AutoProcessor.from_pretrained("openai/whisper-small") 283 | model = AutoModelForSpeechSeq2Seq.from_pretrained( 284 | "openai/whisper-small", 285 | torch_dtype=torch.float16 if "cuda" in device else torch.float32, 286 | ).to(device) 287 | model.eval() 288 | return model, proc 289 | 290 | 291 | def asr_wer(ref: torch.Tensor, syn: torch.Tensor, sr: int, model, proc, device: str) -> float: 292 | with torch.no_grad(): 293 | ir = proc(ref.squeeze().cpu().numpy(), sampling_rate=sr, return_tensors="pt") 294 | isyn = proc(syn.squeeze().cpu().numpy(), sampling_rate=sr, return_tensors="pt") 295 | 296 | model_dtype = next(model.parameters()).dtype 297 | ir = ir.to(device) 298 | ir["input_features"] = ir["input_features"].to(model_dtype) 299 | isyn = isyn.to(device) 300 | isyn["input_features"] = isyn["input_features"].to(model_dtype) 301 | 302 | tr = proc.decode(model.generate(**ir)[0]) 303 | ts = proc.decode(model.generate(**isyn)[0]) 304 | from jiwer import wer # local import for light start‑up 305 | 306 | return wer(tr, ts) 307 | -------------------------------------------------------------------------------- /scripts/aggregate_results.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import pandas as pd 3 | import numpy as np 4 | from pathlib import Path 5 | 6 | def load_csv(csv_path): 7 | """CSVファイルを読み込む""" 8 | return pd.read_csv(csv_path) 9 | 10 | def calculate_stats(df, name): 11 | """統計情報を計算""" 12 | stats = { 13 | 'name': name, 14 | 'ecapa_cos_mean': df['ECAPA_cos'].mean(), 15 | 'ecapa_cos_std': df['ECAPA_cos'].std(), 16 | 'count': len(df) 17 | } 18 | 19 | # DNSMOSの4つのスコアを処理 20 | if 'DNSMOS_p808' in df.columns: 21 | stats.update({ 22 | 'dnsmos_p808_mean': df['DNSMOS_p808'].mean(), 23 | 'dnsmos_p808_std': df['DNSMOS_p808'].std(), 24 | 'dnsmos_sig_mean': df['DNSMOS_sig'].mean(), 25 | 'dnsmos_sig_std': df['DNSMOS_sig'].std(), 26 | 
'dnsmos_bak_mean': df['DNSMOS_bak'].mean(), 27 | 'dnsmos_bak_std': df['DNSMOS_bak'].std(), 28 | 'dnsmos_ovr_mean': df['DNSMOS_ovr'].mean(), 29 | 'dnsmos_ovr_std': df['DNSMOS_ovr'].std(), 30 | }) 31 | # 旧DNSMOSv2フォーマットとの互換性 32 | elif 'DNSMOSv2' in df.columns: 33 | stats.update({ 34 | 'dnsmos_mean': df['DNSMOSv2'].mean(), 35 | 'dnsmos_std': df['DNSMOSv2'].std(), 36 | }) 37 | 38 | return stats 39 | 40 | def main(): 41 | # CSVファイルパス 42 | results_dir = Path("/home/ayu/GitHub/open-miipher-2/results") 43 | 44 | # モデルのサブディレクトリ 45 | model_dirs = ['hubert_large_l2', 'mhubert_l6', 'miipher_1', 'wav2vec2_base_l2', 'wavlm_base_l2'] 46 | 47 | # 元の音声の品質(samples.csv) 48 | results_original = [] 49 | original_csv_path = results_dir / 'samples.csv' 50 | if original_csv_path.exists(): 51 | df = load_csv(original_csv_path) 52 | results_original.append(calculate_stats(df, 'original')) 53 | 54 | # 8kHz劣化後の品質(samples_8khz_16khz.csv) 55 | results_8khz_baseline = [] 56 | baseline_8khz_csv_path = results_dir / 'samples_8khz_16khz.csv' 57 | if baseline_8khz_csv_path.exists(): 58 | df = load_csv(baseline_8khz_csv_path) 59 | results_8khz_baseline.append(calculate_stats(df, '8khz_degraded')) 60 | 61 | # ノイズ劣化後の品質(degrade_samples.csv) 62 | results_degrade_baseline = [] 63 | baseline_degrade_csv_path = results_dir / 'degrade_samples.csv' 64 | if baseline_degrade_csv_path.exists(): 65 | df = load_csv(baseline_degrade_csv_path) 66 | results_degrade_baseline.append(calculate_stats(df, 'noise_degraded')) 67 | 68 | # 8kHzの結果 69 | results_8khz = [] 70 | for model_dir in model_dirs: 71 | csv_path = results_dir / model_dir / 'samples_8khz_16khz.csv' 72 | if csv_path.exists(): 73 | df = load_csv(csv_path) 74 | results_8khz.append(calculate_stats(df, model_dir)) 75 | 76 | # degradeの結果 77 | results_degrade = [] 78 | for model_dir in model_dirs: 79 | csv_path = results_dir / model_dir / 'degrade_samples.csv' 80 | if csv_path.exists(): 81 | df = load_csv(csv_path) 82 | results_degrade.append(calculate_stats(df, model_dir)) 83 | 84 | # 結果を表示 85 | print("## 8kHz Results (samples_8khz_16khz.csv)") 86 | # DNSMOSとUTMOSの両方に対応した表示 87 | first_result = results_8khz[0] if results_8khz else {} 88 | if 'dnsmos_ovr_mean' in first_result: 89 | print("\n| Method | ECAPA-cos (mean±std) | DNSMOS Overall (mean±std) | DNSMOS SIG | DNSMOS BAK | Samples |") 90 | print("|--------|---------------------|---------------------------|------------|------------|---------|") 91 | # 元の音声品質を最初に表示 92 | if results_original: 93 | r = results_original[0] 94 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_ovr_mean']:.4f}±{r['dnsmos_ovr_std']:.4f} | {r['dnsmos_sig_mean']:.4f} | {r['dnsmos_bak_mean']:.4f} | {r['count']} |") 95 | # 8kHz劣化後の品質を表示 96 | if results_8khz_baseline: 97 | r = results_8khz_baseline[0] 98 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_ovr_mean']:.4f}±{r['dnsmos_ovr_std']:.4f} | {r['dnsmos_sig_mean']:.4f} | {r['dnsmos_bak_mean']:.4f} | {r['count']} |") 99 | for r in results_8khz: 100 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_ovr_mean']:.4f}±{r['dnsmos_ovr_std']:.4f} | {r['dnsmos_sig_mean']:.4f} | {r['dnsmos_bak_mean']:.4f} | {r['count']} |") 101 | else: 102 | print("\n| Method | ECAPA-cos (mean±std) | DNSMOSv2 (mean±std) | Samples |") 103 | print("|--------|---------------------|---------------------|---------|") 104 | if results_original: 105 | r = results_original[0] 106 | print(f"| {r['name']} | 
{r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_mean']:.4f}±{r['dnsmos_std']:.4f} | {r['count']} |") 107 | if results_8khz_baseline: 108 | r = results_8khz_baseline[0] 109 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_mean']:.4f}±{r['dnsmos_std']:.4f} | {r['count']} |") 110 | for r in results_8khz: 111 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_mean']:.4f}±{r['dnsmos_std']:.4f} | {r['count']} |") 112 | 113 | print("\n## Degrade Results (degrade_samples.csv)") 114 | # DNSMOSとUTMOSの両方に対応した表示 115 | first_degrade = results_degrade[0] if results_degrade else {} 116 | if 'dnsmos_ovr_mean' in first_degrade: 117 | print("\n| Method | ECAPA-cos (mean±std) | DNSMOS Overall (mean±std) | DNSMOS SIG | DNSMOS BAK | Samples |") 118 | print("|--------|---------------------|---------------------------|------------|------------|---------|") 119 | # 元の音声品質を最初に表示 120 | if results_original: 121 | r = results_original[0] 122 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_ovr_mean']:.4f}±{r['dnsmos_ovr_std']:.4f} | {r['dnsmos_sig_mean']:.4f} | {r['dnsmos_bak_mean']:.4f} | {r['count']} |") 123 | # ノイズ劣化後の品質を表示 124 | if results_degrade_baseline: 125 | r = results_degrade_baseline[0] 126 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_ovr_mean']:.4f}±{r['dnsmos_ovr_std']:.4f} | {r['dnsmos_sig_mean']:.4f} | {r['dnsmos_bak_mean']:.4f} | {r['count']} |") 127 | for r in results_degrade: 128 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_ovr_mean']:.4f}±{r['dnsmos_ovr_std']:.4f} | {r['dnsmos_sig_mean']:.4f} | {r['dnsmos_bak_mean']:.4f} | {r['count']} |") 129 | else: 130 | print("\n| Method | ECAPA-cos (mean±std) | DNSMOSv2 (mean±std) | Samples |") 131 | print("|--------|---------------------|---------------------|---------|") 132 | if results_original: 133 | r = results_original[0] 134 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_mean']:.4f}±{r['dnsmos_std']:.4f} | {r['count']} |") 135 | if results_degrade_baseline: 136 | r = results_degrade_baseline[0] 137 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_mean']:.4f}±{r['dnsmos_std']:.4f} | {r['count']} |") 138 | for r in results_degrade: 139 | print(f"| {r['name']} | {r['ecapa_cos_mean']:.4f}±{r['ecapa_cos_std']:.4f} | {r['dnsmos_mean']:.4f}±{r['dnsmos_std']:.4f} | {r['count']} |") 140 | 141 | # 改善率の計算(miipher_1をベースラインとして使用) 142 | print("\n## Improvement Rates") 143 | 144 | # miipher_1をベースラインとして見つける 145 | baseline_8khz = None 146 | baseline_degrade = None 147 | 148 | for r in results_8khz: 149 | if r['name'] == 'miipher_1': 150 | baseline_8khz = r 151 | break 152 | 153 | for r in results_degrade: 154 | if r['name'] == 'miipher_1': 155 | baseline_degrade = r 156 | break 157 | 158 | print("\n### Relative to miipher_1") 159 | print("\n#### 8kHz") 160 | if baseline_8khz: 161 | for method in results_8khz: 162 | if method['name'] != 'miipher_1': 163 | ecapa_improve = ((method['ecapa_cos_mean'] - baseline_8khz['ecapa_cos_mean']) / baseline_8khz['ecapa_cos_mean']) * 100 164 | 165 | if 'dnsmos_ovr_mean' in method: 166 | dnsmos_improve = ((method['dnsmos_ovr_mean'] - baseline_8khz['dnsmos_ovr_mean']) / baseline_8khz['dnsmos_ovr_mean']) * 100 167 | print(f"{method['name']}: ECAPA-cos {ecapa_improve:+.2f}%, DNSMOS Overall {dnsmos_improve:+.2f}%") 168 | elif 
'dnsmos_mean' in method: 169 | dnsmos_improve = ((method['dnsmos_mean'] - baseline_8khz['dnsmos_mean']) / baseline_8khz['dnsmos_mean']) * 100 170 | print(f"{method['name']}: ECAPA-cos {ecapa_improve:+.2f}%, DNSMOSv2 {dnsmos_improve:+.2f}%") 171 | 172 | print("\n#### Degrade") 173 | if baseline_degrade: 174 | for method in results_degrade: 175 | if method['name'] != 'miipher_1': 176 | ecapa_improve = ((method['ecapa_cos_mean'] - baseline_degrade['ecapa_cos_mean']) / baseline_degrade['ecapa_cos_mean']) * 100 177 | 178 | if 'dnsmos_ovr_mean' in method: 179 | dnsmos_improve = ((method['dnsmos_ovr_mean'] - baseline_degrade['dnsmos_ovr_mean']) / baseline_degrade['dnsmos_ovr_mean']) * 100 180 | print(f"{method['name']}: ECAPA-cos {ecapa_improve:+.2f}%, DNSMOS Overall {dnsmos_improve:+.2f}%") 181 | elif 'dnsmos_mean' in method: 182 | dnsmos_improve = ((method['dnsmos_mean'] - baseline_degrade['dnsmos_mean']) / baseline_degrade['dnsmos_mean']) * 100 183 | print(f"{method['name']}: ECAPA-cos {ecapa_improve:+.2f}%, DNSMOSv2 {dnsmos_improve:+.2f}%") 184 | 185 | # データフレームとして保存 186 | # 元の音声品質と劣化後の品質を各結果に含める 187 | combined_8khz = [] 188 | combined_degrade = [] 189 | 190 | if results_original: 191 | combined_8khz.append(results_original[0]) 192 | combined_degrade.append(results_original[0]) 193 | 194 | if results_8khz_baseline: 195 | combined_8khz.append(results_8khz_baseline[0]) 196 | 197 | if results_degrade_baseline: 198 | combined_degrade.append(results_degrade_baseline[0]) 199 | 200 | combined_8khz.extend(results_8khz) 201 | combined_degrade.extend(results_degrade) 202 | 203 | df_8khz = pd.DataFrame(combined_8khz) 204 | df_degrade = pd.DataFrame(combined_degrade) 205 | 206 | df_8khz.to_csv(results_dir / 'summary_8khz.csv', index=False) 207 | df_degrade.to_csv(results_dir / 'summary_degrade.csv', index=False) 208 | print("\n## Summary files saved") 209 | print(f"- {results_dir / 'summary_8khz.csv'}") 210 | print(f"- {results_dir / 'summary_degrade.csv'}") 211 | 212 | if __name__ == "__main__": 213 | main() 214 | -------------------------------------------------------------------------------- /src/miipher_2/lightning_vocoders/_hifigan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock1(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock1, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 
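        # Descriptive note: convs1 (above) stacks dilated convolutions (dilations 1, 3, 5 by
        # default) to widen the receptive field, while convs2 (below) pairs each of them with a
        # dilation-1 convolution, following the original HiFi-GAN ResBlock1 design.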
58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class ResBlock2(torch.nn.Module): 113 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 114 | super(ResBlock2, self).__init__() 115 | self.h = h 116 | self.convs = nn.ModuleList( 117 | [ 118 | weight_norm( 119 | Conv1d( 120 | channels, 121 | channels, 122 | kernel_size, 123 | 1, 124 | dilation=dilation[0], 125 | padding=get_padding(kernel_size, dilation[0]), 126 | ) 127 | ), 128 | weight_norm( 129 | Conv1d( 130 | channels, 131 | channels, 132 | kernel_size, 133 | 1, 134 | dilation=dilation[1], 135 | padding=get_padding(kernel_size, dilation[1]), 136 | ) 137 | ), 138 | ] 139 | ) 140 | self.convs.apply(init_weights) 141 | 142 | def forward(self, x): 143 | for c in self.convs: 144 | xt = F.leaky_relu(x, LRELU_SLOPE) 145 | xt = c(xt) 146 | x = xt + x 147 | return x 148 | 149 | def remove_weight_norm(self): 150 | for l in self.convs: 151 | remove_weight_norm(l) 152 | 153 | 154 | class Generator(torch.nn.Module): 155 | def __init__(self, h): 156 | super(Generator, self).__init__() 157 | self.h = h 158 | self.num_kernels = len(h.resblock_kernel_sizes) 159 | self.num_upsamples = len(h.upsample_rates) 160 | self.conv_pre = weight_norm( 161 | Conv1d(h.num_input_channels, h.upsample_initial_channel, 7, 1, padding=3) 162 | ) 163 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 164 | 165 | self.ups = nn.ModuleList() 166 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 167 | self.ups.append( 168 | weight_norm( 169 | ConvTranspose1d( 170 | h.upsample_initial_channel // (2**i), 171 | h.upsample_initial_channel // (2 ** (i + 1)), 172 | k, 173 | u, 174 | padding=(k - u) // 2, 175 | ) 176 | ) 177 | ) 178 | 179 | self.resblocks = nn.ModuleList() 180 | for i in range(len(self.ups)): 181 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 182 | for j, (k, d) in enumerate( 183 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 184 | ): 185 | self.resblocks.append(resblock(h, ch, k, d)) 186 | 187 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 188 | self.ups.apply(init_weights) 189 | self.conv_post.apply(init_weights) 190 | 191 | def forward(self, x): 192 | x = self.conv_pre(x.transpose(1, 2)) 193 | for i in range(self.num_upsamples): 194 | x = F.leaky_relu(x, LRELU_SLOPE) 195 | x = self.ups[i](x) 196 | xs = None 197 | for j in range(self.num_kernels): 198 | if xs is None: 199 | xs = self.resblocks[i * self.num_kernels + j](x) 200 | else: 201 | xs += self.resblocks[i * self.num_kernels + j](x) 202 | x = xs / self.num_kernels 203 | x = 
F.leaky_relu(x) 204 | x = self.conv_post(x) 205 | x = torch.tanh(x) 206 | 207 | return x 208 | 209 | def remove_weight_norm(self): 210 | print("Removing weight norm...") 211 | for l in self.ups: 212 | remove_weight_norm(l) 213 | for l in self.resblocks: 214 | l.remove_weight_norm() 215 | remove_weight_norm(self.conv_pre) 216 | remove_weight_norm(self.conv_post) 217 | 218 | 219 | class DiscriminatorP(torch.nn.Module): 220 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 221 | super(DiscriminatorP, self).__init__() 222 | self.period = period 223 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 224 | self.convs = nn.ModuleList( 225 | [ 226 | norm_f( 227 | Conv2d( 228 | 1, 229 | 32, 230 | (kernel_size, 1), 231 | (stride, 1), 232 | padding=(get_padding(5, 1), 0), 233 | ) 234 | ), 235 | norm_f( 236 | Conv2d( 237 | 32, 238 | 128, 239 | (kernel_size, 1), 240 | (stride, 1), 241 | padding=(get_padding(5, 1), 0), 242 | ) 243 | ), 244 | norm_f( 245 | Conv2d( 246 | 128, 247 | 512, 248 | (kernel_size, 1), 249 | (stride, 1), 250 | padding=(get_padding(5, 1), 0), 251 | ) 252 | ), 253 | norm_f( 254 | Conv2d( 255 | 512, 256 | 1024, 257 | (kernel_size, 1), 258 | (stride, 1), 259 | padding=(get_padding(5, 1), 0), 260 | ) 261 | ), 262 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 263 | ] 264 | ) 265 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 266 | 267 | def forward(self, x): 268 | fmap = [] 269 | 270 | # 1d to 2d 271 | b, c, t = x.shape 272 | if t % self.period != 0: # pad first 273 | n_pad = self.period - (t % self.period) 274 | x = F.pad(x, (0, n_pad), "reflect") 275 | t = t + n_pad 276 | x = x.view(b, c, t // self.period, self.period) 277 | 278 | for l in self.convs: 279 | x = l(x) 280 | x = F.leaky_relu(x, LRELU_SLOPE) 281 | fmap.append(x) 282 | x = self.conv_post(x) 283 | fmap.append(x) 284 | x = torch.flatten(x, 1, -1) 285 | 286 | return x, fmap 287 | 288 | 289 | class MultiPeriodDiscriminator(torch.nn.Module): 290 | def __init__(self): 291 | super(MultiPeriodDiscriminator, self).__init__() 292 | self.discriminators = nn.ModuleList( 293 | [ 294 | DiscriminatorP(2), 295 | DiscriminatorP(3), 296 | DiscriminatorP(5), 297 | DiscriminatorP(7), 298 | DiscriminatorP(11), 299 | ] 300 | ) 301 | 302 | def forward(self, y, y_hat): 303 | y_d_rs = [] 304 | y_d_gs = [] 305 | fmap_rs = [] 306 | fmap_gs = [] 307 | for i, d in enumerate(self.discriminators): 308 | y_d_r, fmap_r = d(y) 309 | y_d_g, fmap_g = d(y_hat) 310 | y_d_rs.append(y_d_r) 311 | fmap_rs.append(fmap_r) 312 | y_d_gs.append(y_d_g) 313 | fmap_gs.append(fmap_g) 314 | 315 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 316 | 317 | 318 | class DiscriminatorS(torch.nn.Module): 319 | def __init__(self, use_spectral_norm=False): 320 | super(DiscriminatorS, self).__init__() 321 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 322 | self.convs = nn.ModuleList( 323 | [ 324 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 325 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 326 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 327 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 328 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 329 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 330 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 331 | ] 332 | ) 333 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 334 | 335 | def forward(self, x): 336 | fmap = [] 337 | for l in self.convs: 338 | x = l(x) 339 | x 
= F.leaky_relu(x, LRELU_SLOPE) 340 | fmap.append(x) 341 | x = self.conv_post(x) 342 | fmap.append(x) 343 | x = torch.flatten(x, 1, -1) 344 | 345 | return x, fmap 346 | 347 | 348 | class MultiScaleDiscriminator(torch.nn.Module): 349 | def __init__(self): 350 | super(MultiScaleDiscriminator, self).__init__() 351 | self.discriminators = nn.ModuleList( 352 | [ 353 | DiscriminatorS(use_spectral_norm=True), 354 | DiscriminatorS(), 355 | DiscriminatorS(), 356 | ] 357 | ) 358 | self.meanpools = nn.ModuleList( 359 | [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] 360 | ) 361 | 362 | def forward(self, y, y_hat): 363 | y_d_rs = [] 364 | y_d_gs = [] 365 | fmap_rs = [] 366 | fmap_gs = [] 367 | for i, d in enumerate(self.discriminators): 368 | if i != 0: 369 | y = self.meanpools[i - 1](y) 370 | y_hat = self.meanpools[i - 1](y_hat) 371 | y_d_r, fmap_r = d(y) 372 | y_d_g, fmap_g = d(y_hat) 373 | y_d_rs.append(y_d_r) 374 | fmap_rs.append(fmap_r) 375 | y_d_gs.append(y_d_g) 376 | fmap_gs.append(fmap_g) 377 | 378 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 379 | 380 | 381 | def feature_loss(fmap_r, fmap_g): 382 | loss = 0 383 | for dr, dg in zip(fmap_r, fmap_g): 384 | for rl, gl in zip(dr, dg): 385 | loss += torch.mean(torch.abs(rl - gl)) 386 | 387 | return loss * 2 388 | 389 | 390 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 391 | loss = 0 392 | r_losses = [] 393 | g_losses = [] 394 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 395 | r_loss = torch.mean((1 - dr) ** 2) 396 | g_loss = torch.mean(dg**2) 397 | loss += r_loss + g_loss 398 | r_losses.append(r_loss.item()) 399 | g_losses.append(g_loss.item()) 400 | 401 | return loss, r_losses, g_losses 402 | 403 | 404 | def generator_loss(disc_outputs): 405 | loss = 0 406 | gen_losses = [] 407 | for dg in disc_outputs: 408 | l = torch.mean((1 - dg) ** 2) 409 | gen_losses.append(l) 410 | loss += l 411 | 412 | return loss, gen_losses 413 | --------------------------------------------------------------------------------
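A minimal usage sketch (not a file from the repository) showing how the metric helpers defined in `src/miipher_2/utils/eval_utils.py` might be combined to score a single reference/restored pair. It assumes the `src/` layout is installed as the `miipher_2` package, that both files are 16 kHz mono (Whisper and the SpeechBrain speaker models expect 16 kHz input), and the wav paths are placeholders.

```
import torch
import torchaudio

from miipher_2.utils.eval_utils import (
    asr_wer,
    load_asr,
    load_spk_models,
    logf0_rmse,
    mcd,
    speaker_cos,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder paths: any 16 kHz mono clean/restored pair.
ref, sr = torchaudio.load("ref.wav")
syn, sr_syn = torchaudio.load("restored.wav")
assert sr == sr_syn

# Signal-level metrics (no pretrained models needed).
print("MCD [dB]    :", mcd(ref, syn, sr))
print("log-F0 RMSE :", logf0_rmse(ref, syn, sr))

# Speaker similarity: x-vector and ECAPA recognizers are loaded lazily.
xvect, ecapa = load_spk_models(device)
print("x-vector cos:", speaker_cos(ref, syn, sr, xvect))
print("ECAPA cos   :", speaker_cos(ref, syn, sr, ecapa))

# Whisper-small WER between the reference and restored transcripts.
asr_model, asr_proc = load_asr(device)
print("WER         :", asr_wer(ref, syn, sr, asr_model, asr_proc, device))
```

Note that `speaker_cos` averages every non-final embedding dimension, so it works whether the recognizer returns `(B, D)` or `(B, T, D)` embeddings.

Similarly, a hedged sketch of instantiating the `Generator` from `_hifigan.py`. Only the attribute names on `h` come from `Generator.__init__`; the values here are illustrative HiFi-GAN V1-style assumptions, whereas the real values come from the project's Hydra configuration.

```
from types import SimpleNamespace

import torch

from miipher_2.lightning_vocoders._hifigan import Generator

# Assumed hyperparameters (HiFi-GAN V1 style); values are illustrative only.
h = SimpleNamespace(
    num_input_channels=768,              # e.g. an SSL feature dimension (assumed)
    resblock="1",                        # "1" -> ResBlock1, anything else -> ResBlock2
    upsample_rates=[8, 8, 2, 2],         # overall upsampling factor 8*8*2*2 = 256
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
)

gen = Generator(h)
feats = torch.randn(1, 100, h.num_input_channels)  # (batch, frames, feature_dim)
with torch.no_grad():
    wav = gen(feats)                                # forward() transposes to (B, C, T)
print(wav.shape)                                    # torch.Size([1, 1, 25600]) = 100 frames * 256
```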