├── .python-version ├── src └── miipher │ ├── __init__.py │ ├── __pycache__ │ └── __init__.cpython-310.pyc │ ├── dataset │ ├── libritts.py │ ├── preprocess_for_infer.py │ ├── jvs_corpus.py │ └── datamodule.py │ ├── model │ ├── miipher.py │ └── modules.py │ ├── preprocess │ ├── preprocessor.py │ └── noiseAugmentation.py │ └── lightning_module.py ├── .vscode ├── settings.json └── launch.json ├── examples ├── pretrained_models │ └── EncoderClassifier-8f6f7fdaa9628acf73e21ad1f99d5f83 │ │ └── hyperparams.yaml ├── preprocess.py ├── train.py ├── demo.py ├── prepare_finetune_vocoder_dataset.py └── configs │ └── config.yaml ├── scripts └── run_training.sh ├── tests └── test_models.py ├── pyproject.toml ├── LICENSE ├── README.md ├── .gitignore ├── requirements.lock └── requirements-dev.lock /.python-version: -------------------------------------------------------------------------------- 1 | 3.10.11 2 | -------------------------------------------------------------------------------- /src/miipher/__init__.py: -------------------------------------------------------------------------------- 1 | def hello(): 2 | return "Hello from miipher!" 3 | -------------------------------------------------------------------------------- /src/miipher/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wataru-Nakata/miipher/HEAD/src/miipher/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "tests" 4 | ], 5 | "python.testing.unittestEnabled": false, 6 | "python.testing.pytestEnabled": true 7 | } -------------------------------------------------------------------------------- /examples/pretrained_models/EncoderClassifier-8f6f7fdaa9628acf73e21ad1f99d5f83/hyperparams.yaml: -------------------------------------------------------------------------------- 1 | /home/wnakata/.cache/huggingface/hub/models--speechbrain--spkrec-ecapa-voxceleb/snapshots/5c0be3875fda05e81f3c004ed8c7c06be308de1e/hyperparams.yaml -------------------------------------------------------------------------------- /scripts/run_training.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | #$ -l rt_AF=1 3 | #$ -l h_rt=168:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | cd examples 13 | python3 train.py 14 | -------------------------------------------------------------------------------- /examples/preprocess.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | from lightning.pytorch import seed_everything 4 | 5 | from miipher.preprocess.preprocessor import Preprocessor 6 | 7 | 8 | @hydra.main(version_base="1.3", config_name="config", config_path="./configs") 9 | def main(cfg: DictConfig): 10 | seed_everything(1234) 11 | preprocssor = Preprocessor(cfg=cfg) 12 | preprocssor.build_from_path() 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from miipher.model.miipher import Miipher 4 | 5 | 6 | class TestMiipher(unittest.TestCase): 7 | def setUp(self) -> None: 8 | self.miipher = Miipher(512, 256, 1024, 1024, 4, 2) 9 | 10 | def test_miipher(self): 11 | phone_feature = torch.rand(2, 129, 512) 12 | speaker_feature = torch.rand(2, 256) 13 | ssl_feature = torch.rand(2, 121, 1024) 14 | output = self.miipher.forward( 15 | phone_feature, speaker_feature, ssl_feature, torch.tensor([121, 121]) 16 | ) 17 | self.assertTrue(output[0].size() == ssl_feature.size()) 18 | -------------------------------------------------------------------------------- /examples/train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from lightning.pytorch.core import datamodule 3 | from omegaconf import DictConfig 4 | from lightning.pytorch import Trainer, seed_everything 5 | import torch 6 | from miipher import lightning_module 7 | from miipher.lightning_module import MiipherLightningModule 8 | from miipher.dataset.datamodule import MiipherDataModule 9 | 10 | 11 | @hydra.main(version_base="1.3", config_name="config", config_path="./configs") 12 | def main(cfg: DictConfig): 13 | torch.set_float32_matmul_precision("medium") 14 | lightning_module = MiipherLightningModule(cfg) 15 | datamodule = MiipherDataModule(cfg) 16 | loggers = hydra.utils.instantiate(cfg.train.loggers) 17 | trainer = hydra.utils.instantiate(cfg.train.trainer, logger=loggers) 18 | trainer.fit(lightning_module, datamodule) 19 | 20 | 21 | if __name__ == "__main__": 22 | main() 23 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true 14 | }, 15 | { 16 | "name": "Python: preprocess", 17 | "type": "python", 18 | "request": "launch", 19 | "program": "preprocess.py", 20 | "console": "integratedTerminal", 21 | "cwd": "${workspaceFolder}/examples", 22 | "justMyCode": true 23 | } 24 | ] 25 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "miipher" 3 | version = "0.1.0" 4 | description = "Add a short description here" 5 | authors = [ 6 | { name = "Wataru Nakata", email = "wataru9871@gmail.com" } 7 | ] 8 | dependencies = ["lightning~=2.0.5", "torchaudio~=2.0.2", "speechbrain~=0.5.14", "matplotlib~=3.7.2", "pyroomacoustics~=0.7.3", "hydra-core~=1.3.2", "webdataset~=0.2.48", "text2phonemesequence~=0.1.4", "mecab-python3~=1.0.6", "unidic~=1.1.0", "wandb~=0.15.7", "lightning_vocoders @ git+https://github.com/Wataru-Nakata/ssl-vocoders", "llvmlite~=0.40.1", "gradio~=3.45.2"] 9 | readme = "README.md" 10 | requires-python = ">= 3.10" 11 | 12 | [build-system] 13 | requires = ["hatchling"] 14 | build-backend = "hatchling.build" 15 | 16 | [tool.rye] 17 | managed = true 18 | dev-dependencies = ["ipykernel~=6.24.0", "pytest~=7.4.0", "black~=23.7.0"] 19 | [tool.hatch.metadata] 20 | allow-direct-references = true 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Wataru-Nakata 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # miipher 2 | This repository provides an unofficial implementation of the speech restoration model Miipher. 3 | Miipher was originally proposed by Koizumi et al. [arxiv](https://arxiv.org/abs/2303.01664) 4 | Please note that the model provided in this repository doesn't represent the performance of the original model proposed by Koizumi et al.
as this implementation differs in many ways from the paper. 5 | 6 | # Installation 7 | Install with pip. The installation has been confirmed on Python 3.10.11. 8 | ```bash 9 | pip install git+https://github.com/Wataru-Nakata/miipher 10 | ``` 11 | 12 | # Pretrained model 13 | The pretrained model is trained on [LibriTTS-R](http://www.openslr.org/141/) and [JVS corpus](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_corpus), 14 | and is provided under the **CC-BY-NC-2.0 license**. 15 | 16 | The models are hosted on [huggingface](https://huggingface.co/spaces/Wataru/Miipher/). 17 | 18 | To use the pretrained model, please refer to `examples/demo.py`. 19 | 20 | # Differences from the original paper 21 | | | [original paper](https://arxiv.org/abs/2303.01664) | This repo | 22 | |---|---|---| 23 | | Clean speech dataset | proprietary | [LibriTTS-R](http://www.openslr.org/141/) and [JVS corpus](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_corpus) | 24 | | Noise dataset | TAU Urban Audio-Visual Scenes 2021 dataset | TAU Urban Audio-Visual Scenes 2021 dataset and Slakh2100 | 25 | | Speech SSL model | [W2v-BERT XL](https://arxiv.org/abs/2108.06209) | [WavLM-large](https://arxiv.org/abs/2110.13900) | 26 | | Language SSL model | [PnG BERT](https://arxiv.org/abs/2103.15060) | [XPhoneBERT](https://github.com/VinAIResearch/XPhoneBERT) | 27 | | Feature cleaner building block | [DF-Conformer](https://arxiv.org/abs/2106.15813) | [Conformer](https://arxiv.org/abs/2005.08100) | 28 | | Vocoder | [WaveFit](https://arxiv.org/abs/2210.01029) | [HiFi-GAN](https://arxiv.org/abs/2010.05646) | 29 | | X-Vector model | Streaming Conformer-based speaker encoding model | [speechbrain/spkrec-xvect-voxceleb](https://huggingface.co/speechbrain/spkrec-xvect-voxceleb) | 30 | 31 | # LICENSE 32 | Code in this repo: MIT License 33 | 34 | Weights on huggingface: CC-BY-NC-2.0 license 35 | 36 | -------------------------------------------------------------------------------- /examples/demo.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from miipher.dataset.preprocess_for_infer import PreprocessForInfer 3 | from miipher.lightning_module import MiipherLightningModule 4 | from lightning_vocoders.models.hifigan.xvector_lightning_module import HiFiGANXvectorLightningModule 5 | import torch 6 | import torchaudio 7 | import hydra 8 | import tempfile 9 | 10 | miipher_path = "https://huggingface.co/spaces/Wataru/Miipher/resolve/main/miipher.ckpt" 11 | miipher = MiipherLightningModule.load_from_checkpoint(miipher_path,map_location='cpu') 12 | vocoder = HiFiGANXvectorLightningModule.load_from_checkpoint("https://huggingface.co/spaces/Wataru/Miipher/resolve/main/vocoder_finetuned.ckpt",map_location='cpu') 13 | xvector_model = hydra.utils.instantiate(vocoder.cfg.data.xvector.model) 14 | xvector_model = xvector_model.to('cpu') 15 | preprocessor = PreprocessForInfer(miipher.cfg) 16 | 17 | @torch.inference_mode() 18 | def main(wav_path,transcript,lang_code): 19 | wav,sr =torchaudio.load(wav_path) 20 | wav = wav[0].unsqueeze(0) 21 | batch = preprocessor.process( 22 | 'test', 23 | (torch.tensor(wav),sr), 24 | word_segmented_text=transcript, 25 | lang_code=lang_code 26 | ) 27 | 28 | miipher.feature_extractor(batch) 29 | ( 30 | phone_feature, 31 | speaker_feature, 32 | degraded_ssl_feature, 33 | _, 34 | ) = miipher.feature_extractor(batch) 35 | cleaned_ssl_feature, _ = miipher(phone_feature,speaker_feature,degraded_ssl_feature) 36 | vocoder_xvector =
xvector_model.encode_batch(batch['degraded_wav_16k'].view(1,-1).cpu()).squeeze(1) 37 | cleaned_wav = vocoder.generator_forward({"input_feature": cleaned_ssl_feature, "xvector": vocoder_xvector})[0].T 38 | with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as fp: 39 | torchaudio.save(fp,cleaned_wav.view(1,-1), sample_rate=22050,format='wav') 40 | return fp.name 41 | 42 | inputs = [gr.Audio(label="noisy audio",type='filepath'),gr.Textbox(label="Transcript", value="Your transcript here", max_lines=1), 43 | gr.Radio(label="Language", choices=["eng-us", "jpn"], value="eng-us")] 44 | outputs = gr.Audio(label="Output") 45 | 46 | demo = gr.Interface(fn=main, inputs=inputs, outputs=outputs) 47 | 48 | demo.launch() 49 | -------------------------------------------------------------------------------- /src/miipher/dataset/libritts.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torchaudio 3 | from pathlib import Path 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class LibriTTSCorpus(Dataset): 8 | def __init__(self, root) -> None: 9 | super().__init__() 10 | self.root = Path(root) 11 | self.wav_files = list(self.root.glob("**/*.wav")) 12 | self.wav_files = [ 13 | x 14 | for x in self.wav_files 15 | if ( 16 | Path(str(x).replace(".wav", ".normalized.txt")).exists() 17 | and Path(str(x).replace(".wav", ".original.txt")).exists() 18 | ) 19 | ] 20 | 21 | def __getitem__(self, index): 22 | wav_path = self.wav_files[index] 23 | wav_path = wav_path.resolve() 24 | basename = wav_path.stem 25 | m = re.search(r"^(\d+?)\_(\d+?)\_(\d+?\_\d+?)$", basename) 26 | speaker, chapter, utt_id = m.group(1), m.group(2), m.group(3) 27 | with wav_path.with_suffix(".normalized.txt").open() as f: 28 | lines = f.readlines() 29 | line = " ".join(lines) 30 | line = line.strip() 31 | clean_text = line 32 | with wav_path.with_suffix(".original.txt").open() as f: 33 | lines = f.readlines() 34 | line = " ".join(lines) 35 | line = line.strip() 36 | punc_text = line 37 | 38 | output = { 39 | "wav_path": str(wav_path), 40 | "speaker": speaker, 41 | "chapter": chapter, 42 | "utt_id": utt_id, 43 | "clean_text": clean_text, 44 | "word_segmented_text": clean_text, 45 | "punc_text": punc_text, 46 | "basename": basename, 47 | "lang_code": "eng-us" 48 | # "phones": phones 49 | } 50 | 51 | return output 52 | 53 | def __len__(self): 54 | return len(self.wav_files) 55 | 56 | @property 57 | def speaker_dict(self): 58 | speakers = set() 59 | for wav_file in self.wav_files: 60 | basename = wav_file.stem 61 | m = re.search(r"^(\d+?)\_(\d+?)\_(\d+?\_\d+?)$", basename) 62 | speaker, chapter, utt_id = m.group(1), m.group(2), m.group(3) 63 | speakers.add(speaker) 64 | speaker_dict = {x: idx for idx, x in enumerate(speakers)} 65 | return speaker_dict 66 | 67 | @property 68 | def lang_code(self): 69 | return "eng-us" 70 | -------------------------------------------------------------------------------- /src/miipher/dataset/preprocess_for_infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import torchaudio 4 | from torch.nn.utils.rnn import pad_sequence 5 | 6 | class PreprocessForInfer(torch.nn.Module): 7 | def __init__(self,cfg): 8 | super().__init__() 9 | self.phoneme_tokenizer = hydra.utils.instantiate( 10 | cfg.preprocess.phoneme_tokenizer 11 | ) 12 | self.speech_ssl_processor = hydra.utils.instantiate( 13 | cfg.data.speech_ssl_processor.processor 14 | ) 15 | self.speech_ssl_sr = 
cfg.data.speech_ssl_processor.sr 16 | self.cfg = cfg 17 | self.text2phone_dict = dict() 18 | 19 | @torch.inference_mode() 20 | def get_phonemes_input_ids(self, word_segmented_text, lang_code): 21 | if lang_code not in self.text2phone_dict.keys(): 22 | self.text2phone_dict[lang_code] = hydra.utils.instantiate( 23 | self.cfg.preprocess.text2phone_model, language=lang_code, is_cuda=False 24 | ) 25 | input_phonemes = self.text2phone_dict[lang_code].infer_sentence( 26 | word_segmented_text 27 | ) 28 | input_ids = self.phoneme_tokenizer(input_phonemes, return_tensors="pt") 29 | return input_ids, input_phonemes 30 | def process(self,basename, degraded_audio,word_segmented_text=None,lang_code=None, phoneme_text=None): 31 | degraded_audio,sr = degraded_audio 32 | output = dict() 33 | 34 | if word_segmented_text != None and lang_code != None: 35 | input_ids, input_phonems = self.get_phonemes_input_ids( 36 | word_segmented_text, lang_code 37 | ) 38 | output['phoneme_input_ids'] = input_ids 39 | elif phoneme_text == None: 40 | raise ValueError 41 | else: 42 | output["phoneme_input_ids"] = self.phoneme_tokenizer( 43 | phoneme_text, return_tensors="pt", padding=True 44 | ) 45 | 46 | degraded_16k = torchaudio.functional.resample( 47 | degraded_audio, sr, new_freq=16000 48 | ).squeeze() 49 | degraded_wav_16ks = [degraded_16k] 50 | 51 | output["degraded_ssl_input"] = self.speech_ssl_processor( 52 | [x.cpu().numpy() for x in degraded_wav_16ks], 53 | return_tensors="pt", 54 | sampling_rate=16000, 55 | padding=True, 56 | ) 57 | output["degraded_wav_16k"] = pad_sequence(degraded_wav_16ks, batch_first=True) 58 | output["degraded_wav_16k_lengths"] = torch.tensor( 59 | [degraded_wav_16k.size(0) for degraded_wav_16k in degraded_wav_16ks] 60 | ) 61 | return output 62 | 63 | -------------------------------------------------------------------------------- /src/miipher/dataset/jvs_corpus.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | from torch.utils.data import Dataset 3 | from pathlib import Path 4 | import MeCab 5 | 6 | 7 | class JVSCorpus(Dataset): 8 | def __init__(self, root, exclude_speakers=[]) -> None: 9 | super().__init__() 10 | self.root = Path(root) 11 | self.speakers = [ 12 | f.stem 13 | for f in self.root.glob("jvs*") 14 | if f.is_dir() and f.stem not in exclude_speakers 15 | ] 16 | self.clean_texts = dict() 17 | self.wav_files = [] 18 | for speaker in self.speakers: 19 | transcript_files = (self.root / speaker).glob("**/transcripts_utf8.txt") 20 | for transcript_file in transcript_files: 21 | subset = transcript_file.parent.name 22 | with transcript_file.open() as f: 23 | lines = f.readlines() 24 | for line in lines: 25 | wav_name, text = line.strip().split(":") 26 | self.clean_texts[f"{speaker}/{subset}/{wav_name}"] = text 27 | wav_path = self.root / Path( 28 | f"{speaker}/{subset}/wav24kHz16bit/{wav_name}.wav" 29 | ) 30 | if wav_path.exists(): 31 | self.wav_files.append(wav_path) 32 | 33 | self.tokenizer = MeCab.Tagger("-Owakati") 34 | 35 | def __getitem__(self, index): 36 | wav_path = self.wav_files[index] 37 | wav_tensor, sr = torchaudio.load(wav_path) 38 | wav_path = wav_path.resolve() 39 | speaker = wav_path.parent.parent.parent.stem 40 | subset = wav_path.parent.parent.stem 41 | wav_name = wav_path.stem 42 | 43 | clean_text = self.clean_texts[f"{speaker}/{subset}/{wav_name}"] 44 | 45 | basename = f"{subset}_{speaker}_{wav_name}" 46 | tokenized = self.tokenizer.parse(clean_text) 47 | output = { 48 | "wav_path": str(wav_path), 49 | 
"speaker": speaker, 50 | "clean_text": clean_text, 51 | "word_segmented_text": tokenized, 52 | "basename": basename, 53 | "lang_code": "jpn", 54 | } 55 | 56 | return output 57 | 58 | def __len__(self): 59 | return len(self.wav_files) 60 | 61 | @property 62 | def speaker_dict(self): 63 | speakers = set() 64 | for wav_path in self.wav_files: 65 | speakers.add(wav_path.parent.parent.parent.stem) 66 | speaker_dict = {x: idx for idx, x in enumerate(speakers)} 67 | return speaker_dict 68 | 69 | @property 70 | def lang_code(self): 71 | return "jpn" 72 | -------------------------------------------------------------------------------- /examples/prepare_finetune_vocoder_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from miipher.lightning_module import MiipherLightningModule 4 | import webdataset 5 | from pathlib import Path 6 | from miipher.dataset.preprocess_for_infer import PreprocessForInfer 7 | from lightning_vocoders.models.hifigan.xvector_lightning_module import HiFiGANXvectorLightningModule 8 | import hydra 9 | from tqdm import tqdm 10 | 11 | @torch.inference_mode() 12 | def main(miipher_path: Path): 13 | torch.set_float32_matmul_precision("medium") 14 | miipher = MiipherLightningModule.load_from_checkpoint(miipher_path) 15 | train_dataset = webdataset.WebDataset(miipher.cfg.data.train_dataset_path).decode(webdataset.torch_audio) 16 | val_dataset = webdataset.WebDataset(miipher.cfg.data.val_dataset_path).decode(webdataset.torch_audio) 17 | train_dl = DataLoader(train_dataset,num_workers=8) 18 | val_dl = DataLoader(val_dataset,num_workers=8) 19 | preprocessor = PreprocessForInfer(miipher.cfg) 20 | train_sink = webdataset.TarWriter("/mnt/hdd/finetune_train.tar.gz") 21 | val_sink = webdataset.TarWriter("/mnt/hdd/finetune_val.tar.gz") 22 | device = miipher.device 23 | miipher.feature_extractor.to(device) 24 | vocoder = HiFiGANXvectorLightningModule.load_from_checkpoint("https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wavlm-large-l8-xvector/wavlm-large-l8-xvector.ckpt",map_location='cpu') 25 | vocoder.eval() 26 | xvector_model = hydra.utils.instantiate(vocoder.cfg.data.xvector.model) 27 | xvector_model = xvector_model.to('cpu') 28 | del(vocoder) 29 | xvector_model.eval() 30 | 31 | for dl,sink in zip([val_dl,train_dl],[val_sink,train_sink]): 32 | for sample in tqdm(dl): 33 | batch = preprocessor.process( 34 | sample['__key__'], 35 | sample['degraded_speech.wav'], 36 | phoneme_text= sample['phoneme.txt'] 37 | ) 38 | for k,v in batch.items(): 39 | batch[k] = v.to(device) 40 | ( 41 | phone_feature, 42 | speaker_feature, 43 | degraded_ssl_feature, 44 | _, 45 | ) = miipher.feature_extractor(batch) 46 | cleaned_ssl_feature, _ = miipher(phone_feature.to(device),speaker_feature.to(device),degraded_ssl_feature.to(device)) 47 | vocoder_xvector = xvector_model.encode_batch(batch['degraded_wav_16k'].view(1,-1).cpu()).squeeze(1) 48 | 49 | sample_to_write = { 50 | "__key__": sample['__key__'][0], 51 | "resampled_speech.pth": webdataset.torch_dumps(sample['resampled_speech.pth'][0].cpu()), 52 | "miipher_cleaned_feature.pth": webdataset.torch_dumps(cleaned_ssl_feature[0].cpu()), 53 | "xvector.pth": webdataset.torch_dumps(vocoder_xvector.view(-1).cpu()) 54 | } 55 | sink.write(sample_to_write) 56 | sink.close() 57 | 58 | 59 | 60 | if __name__ == "__main__": 61 | ckpt_path = Path("miipher/0kt6hnn2/checkpoints/epoch=19-step=400000.ckpt") 62 | main(ckpt_path) 63 | 
-------------------------------------------------------------------------------- /src/miipher/dataset/datamodule.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch import LightningDataModule 2 | from torch.utils.data import DataLoader 3 | from torch.nn.utils.rnn import pad_sequence 4 | import webdataset as wds 5 | import torch 6 | import torchaudio 7 | import hydra 8 | 9 | 10 | class MiipherDataModule(LightningDataModule): 11 | def __init__(self, cfg) -> None: 12 | super().__init__() 13 | self.speech_ssl_processor = hydra.utils.instantiate( 14 | cfg.data.speech_ssl_processor.processor 15 | ) 16 | self.speech_ssl_sr = cfg.data.speech_ssl_processor.sr 17 | self.phoneme_tokenizer = hydra.utils.instantiate(cfg.data.phoneme_tokenizer) 18 | self.cfg = cfg 19 | 20 | def setup(self, stage: str): 21 | self.train_dataset = ( 22 | wds.WebDataset( 23 | self.cfg.data.train_dataset_path, 24 | resampled=True, 25 | nodesplitter=wds.split_by_node, 26 | ) 27 | .shuffle(1000) 28 | .decode(wds.torch_audio) 29 | # .decode(self.decode_phoneme_input) 30 | .repeat(2) 31 | .with_length(20000 * self.cfg.data.train_batch_size) 32 | ) 33 | self.val_dataset = ( 34 | wds.WebDataset( 35 | self.cfg.data.val_dataset_path, nodesplitter=wds.split_by_node 36 | ) 37 | .decode(wds.torch_audio) 38 | # .decode(self.decode_phoneme_input) 39 | .repeat(2) 40 | .with_length(3000 * 4 // self.cfg.data.val_batch_size) 41 | ) 42 | 43 | def train_dataloader(self): 44 | return DataLoader( 45 | self.train_dataset, 46 | batch_size=self.cfg.data.train_batch_size, 47 | collate_fn=self.collate_fn, 48 | num_workers=8, 49 | ) 50 | 51 | def val_dataloader(self): 52 | return DataLoader( 53 | self.val_dataset, 54 | batch_size=self.cfg.data.val_batch_size, 55 | collate_fn=self.collate_fn, 56 | num_workers=8, 57 | ) 58 | 59 | @torch.no_grad() 60 | def collate_fn(self, batch): 61 | output = dict() 62 | degraded_wav_16ks = [] 63 | clean_wav_16ks = [] 64 | 65 | for sample in batch: 66 | clean_wav, sr = sample["speech.wav"] 67 | clean_wav_16ks.append( 68 | torchaudio.functional.resample(clean_wav, sr, new_freq=16000).squeeze()[:16000*20] 69 | ) 70 | degraded_wav, sr = sample["degraded_speech.wav"] 71 | degraded_wav_16ks.append( 72 | torchaudio.functional.resample( 73 | degraded_wav, sr, new_freq=16000 74 | ).squeeze()[:16000*20] 75 | ) 76 | output["degraded_wav_16k"] = pad_sequence(degraded_wav_16ks, batch_first=True) 77 | output["degraded_wav_16k_lengths"] = torch.tensor( 78 | [degraded_wav_16k.size(0) for degraded_wav_16k in degraded_wav_16ks] 79 | ) 80 | output["clean_ssl_input"] = self.speech_ssl_processor( 81 | [x.numpy() for x in clean_wav_16ks], 82 | return_tensors="pt", 83 | sampling_rate=16000, 84 | padding=True, 85 | ) 86 | output["degraded_ssl_input"] = self.speech_ssl_processor( 87 | [x.numpy() for x in degraded_wav_16ks], 88 | return_tensors="pt", 89 | sampling_rate=16000, 90 | padding=True, 91 | ) 92 | output["phoneme_input_ids"] = self.phoneme_tokenizer( 93 | [b["phoneme.txt"] for b in batch], return_tensors="pt", padding=True 94 | ) 95 | return output 96 | -------------------------------------------------------------------------------- /examples/configs/config.yaml: -------------------------------------------------------------------------------- 1 | preprocess: 2 | preprocess_dataset: 3 | _target_: torch.utils.data.ConcatDataset 4 | datasets: 5 | - 6 | _target_: miipher.dataset.jvs_corpus.JVSCorpus 7 | root: /mnt/hdd/datasets/jvs_ver1/ 8 | - 9 | _target_: 
miipher.dataset.libritts.LibriTTSCorpus 10 | root: /mnt/hdd/datasets/libritts-r/LibriTTS_R/ 11 | phoneme_tokenizer: 12 | _target_: transformers.AutoTokenizer.from_pretrained 13 | pretrained_model_name_or_path: "vinai/xphonebert-base" 14 | text2phone_model: 15 | _target_: text2phonemesequence.Text2PhonemeSequence 16 | is_cuda: True 17 | degration: 18 | format_encoding_pairs: 19 | - format: mp3 20 | compression: 16 21 | - format: mp3 22 | compression: 32 23 | - format: mp3 24 | compression: 64 25 | - format: mp3 26 | compression: 128 27 | - format: vorbis 28 | compression: -1 29 | - format: vorbis 30 | compression: 0 31 | - format: vorbis 32 | compression: 1 33 | - format: wav 34 | encoding: ALAW 35 | bits_per_sample: 8 36 | reverb_conditions: 37 | p: 0.5 38 | reverbation_times: 39 | max: 0.5 40 | min: 0.2 41 | room_xy: 42 | max: 10.0 43 | min: 2.0 44 | room_z: 45 | max: 5.0 46 | min: 2.0 47 | room_params: 48 | fs: 22050 49 | max_order: 10 50 | absorption: 0.2 51 | source_pos: 52 | - 1.0 53 | - 1.0 54 | - 1.0 55 | mic_pos: 56 | - 1.0 57 | - 0.7 58 | - 1.2 59 | n_rirs: 1000 60 | background_noise: 61 | snr: 62 | max: 30.0 63 | min: 5.0 64 | patterns: 65 | - 66 | - /mnt/hdd/datasets/slakh2100_flac_redux 67 | - '**/mix.flac' 68 | - 69 | - /mnt/hdd/datasets/TAU_urban/audio/ 70 | - '**/*.wav' 71 | train_tar_sink: 72 | _target_: webdataset.ShardWriter 73 | pattern: "/mnt/hdd/miipher/miipher-train-%06d.tar.gz" 74 | val_tar_sink: 75 | _target_: webdataset.ShardWriter 76 | pattern: "/mnt/hdd/miipher/miipher-val-%06d.tar.gz" 77 | val_size: 6000 78 | n_repeats: 4 79 | sample_rate: 22050 80 | 81 | data: 82 | train_dataset_path: /mnt/hdd/miipher/miipher-train-{000000..000663}.tar.gz 83 | val_dataset_path: /mnt/hdd/miipher/miipher-val-{000000..000007}.tar.gz 84 | train_batch_size: 8 85 | val_batch_size: 8 86 | speech_ssl_processor: 87 | processor: 88 | _target_: transformers.AutoFeatureExtractor.from_pretrained 89 | pretrained_model_name_or_path: "microsoft/wavlm-large" 90 | sr: 16_000 91 | phoneme_padding_idx: 1 92 | phoneme_tokenizer: 93 | _target_: transformers.AutoTokenizer.from_pretrained 94 | pretrained_model_name_or_path: "vinai/xphonebert-base" 95 | train: 96 | loggers: 97 | - _target_: lightning.pytorch.loggers.WandbLogger 98 | project: "miipher" 99 | trainer: 100 | _target_: lightning.Trainer 101 | accelerator: "gpu" 102 | devices: -1 103 | check_val_every_n_epoch: 1 104 | max_epochs: 3300 105 | model: 106 | ssl_models: 107 | model: 108 | _target_: transformers.AutoModel.from_pretrained 109 | pretrained_model_name_or_path: "microsoft/wavlm-large" 110 | sr: 16_000 111 | layer: 8 112 | phoneme_model: 113 | _target_: transformers.AutoModel.from_pretrained 114 | pretrained_model_name_or_path: "vinai/xphonebert-base" 115 | xvector_model: 116 | _target_: speechbrain.pretrained.EncoderClassifier.from_hparams 117 | source: speechbrain/spkrec-ecapa-voxceleb 118 | miipher: 119 | n_phone_feature: 768 120 | n_speaker_embedding: 192 121 | n_ssl_feature: 1024 122 | n_hidden_dim: 1024 123 | n_conformer_blocks: 4 124 | n_iters: 2 125 | optimizers: 126 | _target_: torch.optim.AdamW 127 | lr: 2e-5 128 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | examples/outputs/ 2 | examples/preprocessed_data/ 3 | *.pyc 4 | *.tsv 5 | *.ipynb 6 | *.ckpt 7 | # Created by https://www.toptal.com/developers/gitignore/api/python 8 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 
9 | 10 | ### Python ### 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/#use-with-ide 120 | .pdm.toml 121 | 122 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 170 | #.idea/ 171 | 172 | ### Python Patch ### 173 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 174 | poetry.toml 175 | 176 | # ruff 177 | .ruff_cache/ 178 | 179 | # LSP config files 180 | pyrightconfig.json 181 | 182 | # End of https://www.toptal.com/developers/gitignore/api/python 183 | 184 | wandb/ -------------------------------------------------------------------------------- /requirements.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | 9 | -e file:. 
10 | absl-py==1.4.0 11 | aiofiles==23.2.1 12 | aiohttp==3.8.4 13 | aiosignal==1.3.1 14 | altair==5.1.1 15 | antlr4-python3-runtime==4.9.3 16 | anyio==3.7.1 17 | appdirs==1.4.4 18 | arrow==1.2.3 19 | async-timeout==4.0.2 20 | attrs==23.1.0 21 | babel==2.12.1 22 | backoff==2.2.1 23 | beautifulsoup4==4.12.2 24 | blessed==1.20.0 25 | braceexpand==0.1.7 26 | cachetools==5.3.1 27 | certifi==2023.5.7 28 | cffi==1.15.1 29 | charset-normalizer==3.2.0 30 | click==8.1.6 31 | clldutils==3.19.0 32 | cmake==3.26.4 33 | colorama==0.4.6 34 | colorlog==6.7.0 35 | contourpy==1.1.0 36 | croniter==1.4.1 37 | csvw==3.1.3 38 | cycler==0.11.0 39 | cython==3.0.0 40 | dateutils==0.6.12 41 | deepdiff==6.3.1 42 | dill==0.3.7 43 | docker-pycreds==0.4.0 44 | exceptiongroup==1.1.2 45 | fastapi==0.100.0 46 | ffmpy==0.3.1 47 | filelock==3.12.2 48 | fonttools==4.41.0 49 | frozenlist==1.4.0 50 | fsspec==2023.6.0 51 | gitdb==4.0.10 52 | gitpython==3.1.32 53 | google-auth==2.17.3 54 | google-auth-oauthlib==1.0.0 55 | gradio==3.45.2 56 | gradio-client==0.5.3 57 | grpcio==1.57.0 58 | h11==0.14.0 59 | httpcore==0.18.0 60 | httpx==0.25.0 61 | huggingface-hub==0.16.4 62 | hydra-core==1.3.2 63 | hyperpyyaml==1.2.1 64 | idna==3.4 65 | importlib-resources==6.1.0 66 | inquirer==3.1.3 67 | isodate==0.6.1 68 | itsdangerous==2.1.2 69 | jinja2==3.1.2 70 | joblib==1.3.1 71 | jsonschema==4.18.4 72 | jsonschema-specifications==2023.7.1 73 | kiwisolver==1.4.4 74 | language-tags==1.2.0 75 | lightning==2.0.5 76 | lightning-cloud==0.5.37 77 | lightning-utilities==0.9.0 78 | lightning-vocoders @ git+https://github.com/Wataru-Nakata/ssl-vocoders 79 | lit==16.0.6 80 | llvmlite==0.40.1 81 | lxml==4.9.3 82 | markdown==3.4.3 83 | markdown-it-py==3.0.0 84 | markupsafe==2.1.3 85 | matplotlib==3.7.2 86 | mdurl==0.1.2 87 | mecab-python3==1.0.6 88 | mpmath==1.3.0 89 | multidict==6.0.4 90 | networkx==3.1 91 | numpy==1.25.1 92 | nvidia-cublas-cu11==11.10.3.66 93 | nvidia-cuda-cupti-cu11==11.7.101 94 | nvidia-cuda-nvrtc-cu11==11.7.99 95 | nvidia-cuda-runtime-cu11==11.7.99 96 | nvidia-cudnn-cu11==8.5.0.96 97 | nvidia-cufft-cu11==10.9.0.58 98 | nvidia-curand-cu11==10.2.10.91 99 | nvidia-cusolver-cu11==11.4.0.1 100 | nvidia-cusparse-cu11==11.7.4.91 101 | nvidia-nccl-cu11==2.14.3 102 | nvidia-nvtx-cu11==11.7.91 103 | oauthlib==3.2.2 104 | omegaconf==2.3.0 105 | ordered-set==4.1.0 106 | orjson==3.9.7 107 | packaging==23.1 108 | pandarallel==1.6.5 109 | pandas==2.0.3 110 | pathtools==0.1.2 111 | pillow==10.0.0 112 | plac==1.3.5 113 | protobuf==4.23.4 114 | psutil==5.9.5 115 | pyasn1==0.5.0 116 | pyasn1-modules==0.3.0 117 | pybind11==2.11.1 118 | pycparser==2.21 119 | pydantic==1.10.11 120 | pydub==0.25.1 121 | pygments==2.15.1 122 | pyjwt==2.8.0 123 | pylatexenc==2.10 124 | pyparsing==3.0.9 125 | pyroomacoustics==0.7.3 126 | pyrootutils==1.0.4 127 | python-dateutil==2.8.2 128 | python-dotenv==1.0.0 129 | python-editor==1.0.4 130 | python-multipart==0.0.6 131 | pytorch-lightning==2.0.5 132 | pytz==2023.3 133 | pyyaml==6.0.1 134 | rdflib==6.3.2 135 | readchar==4.0.5 136 | referencing==0.30.0 137 | regex==2023.6.3 138 | requests==2.31.0 139 | requests-oauthlib==1.3.1 140 | rfc3986==1.5.0 141 | rich==13.4.2 142 | rpds-py==0.9.2 143 | rsa==4.9 144 | ruamel-yaml==0.17.28 145 | ruamel-yaml-clib==0.2.7 146 | scipy==1.11.1 147 | segments==2.2.1 148 | semantic-version==2.10.0 149 | sentencepiece==0.1.99 150 | sentry-sdk==1.29.1 151 | setproctitle==1.3.2 152 | six==1.16.0 153 | smmap==5.0.0 154 | sniffio==1.3.0 155 | soundfile==0.12.1 156 | soupsieve==2.4.1 157 | 
speechbrain==0.5.15 158 | starlette==0.27.0 159 | starsessions==1.3.0 160 | sympy==1.12 161 | tabulate==0.9.0 162 | tensorboard==2.13.0 163 | tensorboard-data-server==0.7.1 164 | text2phonemesequence==0.1.4 165 | tokenizers==0.13.3 166 | toolz==0.12.0 167 | torch==2.0.1 168 | torchaudio==2.0.2 169 | torchmetrics==1.0.1 170 | tqdm==4.65.0 171 | traitlets==5.9.0 172 | transformers==4.29.2 173 | triton==2.0.0 174 | typing-extensions==4.7.1 175 | tzdata==2023.3 176 | unidic==1.1.0 177 | uritemplate==4.1.1 178 | urllib3==2.0.3 179 | uvicorn==0.23.1 180 | wandb==0.15.7 181 | wasabi==0.10.1 182 | wcwidth==0.2.6 183 | webdataset==0.2.48 184 | websocket-client==1.6.1 185 | websockets==11.0.3 186 | werkzeug==2.3.7 187 | wheel==0.40.0 188 | yarl==1.9.2 189 | # The following packages are considered to be unsafe in a requirements file: 190 | setuptools==68.0.0 191 | -------------------------------------------------------------------------------- /src/miipher/model/miipher.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .modules import FiLMLayer, PositionalEncoding, Postnet 3 | from torchaudio.models.conformer import ConformerLayer 4 | import torch 5 | 6 | 7 | class Miipher(nn.Module): 8 | def __init__( 9 | self, 10 | n_phone_feature, 11 | n_speaker_embedding, 12 | n_ssl_feature, 13 | n_hidden_dim, 14 | n_conformer_blocks, 15 | n_iters, 16 | ) -> None: 17 | super().__init__() 18 | self.phone_speaker_film = FiLMLayer(n_hidden_dim, n_hidden_dim) 19 | self.phone_linear = nn.Linear(n_phone_feature, n_hidden_dim) 20 | self.speaker_linear = nn.Linear(n_speaker_embedding, n_hidden_dim) 21 | 22 | self.ssl_linear = nn.Linear(n_ssl_feature, n_hidden_dim) 23 | 24 | self.positional_encoding = PositionalEncoding(n_hidden_dim) 25 | self.positional_encoding_film = FiLMLayer(n_hidden_dim, n_hidden_dim) 26 | self.conformer_blocks = nn.ModuleList() 27 | for i in range(n_conformer_blocks): 28 | self.conformer_blocks.append(FeatureCleanerBlock(n_hidden_dim, 8)) 29 | self.postnet = Postnet( 30 | n_hidden_dim, 31 | postnet_embedding_dim=512, 32 | postnet_kernel_size=5, 33 | postnet_n_convolutions=5, 34 | ) 35 | self.n_iters = n_iters 36 | self.n_conformer_blocks = n_conformer_blocks 37 | 38 | def forward( 39 | self, phone_feature, speaker_feature, ssl_feature, ssl_feature_lengths=None 40 | ): 41 | """ 42 | Args: 43 | phone_feature: (N, T, n_phone_feature) 44 | speaker_feature: (N, n_speaker_embedding) 45 | ssl_feature: (N, T, n_ssl_feature) 46 | """ 47 | N = phone_feature.size(0) 48 | assert speaker_feature.size(0) == N 49 | assert ssl_feature.size(0) == N 50 | phone_feature = self.phone_linear(phone_feature) 51 | speaker_feature = self.speaker_linear(speaker_feature) 52 | ssl_feature = self.ssl_linear(ssl_feature) 53 | intermediates = [] 54 | phone_speaker_feature = self.phone_speaker_film( 55 | phone_feature, speaker_feature.unsqueeze(1) 56 | ) 57 | for iteration_count in range(self.n_iters): 58 | pos_enc = self.positional_encoding( 59 | torch.tensor(iteration_count, device=self.device).unsqueeze(0).repeat(N) 60 | ) 61 | assert pos_enc.size(0) == N 62 | phone_speaker_feature = self.positional_encoding_film( 63 | phone_speaker_feature, pos_enc 64 | ) 65 | for i in range(self.n_conformer_blocks): 66 | ssl_feature = self.conformer_blocks[i]( 67 | ssl_feature.clone(), phone_speaker_feature, ssl_feature_lengths 68 | ) 69 | intermediates.append(ssl_feature.clone()) 70 | ssl_feature += self.postnet(ssl_feature.clone()) 71 | 
intermediates.append(ssl_feature.clone()) 72 | return ssl_feature, torch.stack(intermediates) 73 | 74 | @property 75 | def device(self): 76 | return next(iter(self.parameters())).device 77 | 78 | 79 | class FeatureCleanerBlock(nn.Module): 80 | def __init__(self, hidden_dim, num_heads) -> None: 81 | super().__init__() 82 | 83 | self.cross_attention = nn.MultiheadAttention( 84 | hidden_dim, num_heads, batch_first=True 85 | ) 86 | self.conformer_block = ConformerLayer(hidden_dim, hidden_dim * 4, num_heads, 31) 87 | self.layer_norm = nn.LayerNorm(1024) 88 | 89 | def forward( 90 | self, cleaning_feature, speaker_phone_feature, cleaning_feature_lengths=None 91 | ): 92 | if cleaning_feature_lengths is not None: 93 | mask = _lengths_to_padding_mask(cleaning_feature_lengths).T 94 | else: 95 | mask = None 96 | cleaning_feature += self.cross_attention( 97 | cleaning_feature.clone(), 98 | speaker_phone_feature, 99 | speaker_phone_feature, 100 | )[0] 101 | cleaning_feature = self.layer_norm(cleaning_feature.clone()) 102 | cleaning_feature += self.conformer_block( 103 | cleaning_feature.clone(), key_padding_mask=mask 104 | ) 105 | return cleaning_feature 106 | 107 | 108 | def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor: 109 | batch_size = lengths.shape[0] 110 | max_length = int(torch.max(lengths).item()) 111 | padding_mask = torch.arange( 112 | max_length, device=lengths.device, dtype=lengths.dtype 113 | ).expand(batch_size, max_length) >= lengths.unsqueeze(1) 114 | return padding_mask 115 | -------------------------------------------------------------------------------- /src/miipher/preprocess/preprocessor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import torchaudio 4 | import pathlib 5 | from omegaconf import DictConfig 6 | import numpy as np 7 | import webdataset 8 | import tqdm 9 | from torch.utils.data import DataLoader 10 | from speechbrain.pretrained import EncoderClassifier 11 | from .noiseAugmentation import DegrationApplier 12 | import io 13 | import threading 14 | from concurrent.futures import ThreadPoolExecutor, as_completed 15 | 16 | 17 | class Preprocessor: 18 | """ 19 | Preprocess dataset 20 | """ 21 | 22 | def __init__(self, cfg: DictConfig): 23 | """ 24 | Args: 25 | cfg: hydra config 26 | """ 27 | self.cfg = cfg 28 | self.dataset = hydra.utils.instantiate(cfg.preprocess.preprocess_dataset) 29 | self.sampling_rate = self.cfg.sample_rate 30 | self.phoneme_tokenizer = hydra.utils.instantiate( 31 | cfg.preprocess.phoneme_tokenizer 32 | ) 33 | self.degration_model = DegrationApplier(cfg.preprocess.degration) 34 | self.text2phone_dict = dict() 35 | self.n_repeats = cfg.preprocess.n_repeats 36 | 37 | @torch.inference_mode() 38 | def process_utterance( 39 | self, 40 | basename: str, 41 | audio_file_path, 42 | word_segmented_text: str, 43 | lang_code: str, 44 | ): 45 | orig_waveform, sample_rate = torchaudio.load(audio_file_path) 46 | 47 | waveform = torchaudio.functional.resample( 48 | orig_waveform, sample_rate, new_freq=self.sampling_rate 49 | )[ 50 | 0 51 | ] # remove channel dimension only support mono 52 | 53 | with open(audio_file_path, mode="rb") as f: 54 | wav_bytes = f.read() 55 | 56 | input_ids, input_phonems = self.get_phonemes_input_ids( 57 | word_segmented_text, lang_code 58 | ) 59 | 60 | samples = [] 61 | for i in range(self.n_repeats): 62 | degraded_speech = self.apply_noise(waveform) 63 | buff = io.BytesIO() 64 | torchaudio.save( 65 | buff, 66 | src=degraded_speech.unsqueeze(0), 
67 | sample_rate=self.sampling_rate, 68 | format="wav", 69 | ) 70 | buff.seek(0) 71 | 72 | sample = { 73 | "__key__": basename + f"_{i}", 74 | "speech.wav": wav_bytes, 75 | "degraded_speech.wav": buff.read(), 76 | "resampled_speech.pth": webdataset.torch_dumps(waveform), 77 | "word_segmented_text.txt": word_segmented_text, 78 | "phoneme_input_ids.pth": webdataset.torch_dumps(input_ids), 79 | "phoneme.txt": input_phonems, 80 | } 81 | samples.append(sample) 82 | return samples 83 | 84 | def apply_noise(self, waveform): 85 | waveform = self.degration_model.process(waveform, self.sampling_rate) 86 | return waveform 87 | 88 | @torch.inference_mode() 89 | def get_phonemes_input_ids(self, word_segmented_text, lang_code): 90 | if lang_code not in self.text2phone_dict.keys(): 91 | self.text2phone_dict[lang_code] = hydra.utils.instantiate( 92 | self.cfg.preprocess.text2phone_model, language=lang_code 93 | ) 94 | input_phonemes = self.text2phone_dict[lang_code].infer_sentence( 95 | word_segmented_text 96 | ) 97 | input_ids = self.phoneme_tokenizer(input_phonemes, return_tensors="pt") 98 | return input_ids, input_phonemes 99 | 100 | def build_from_path(self): 101 | pathlib.Path( 102 | "/".join(self.cfg.preprocess.train_tar_sink.pattern.split("/")[:-1]) 103 | ).mkdir(exist_ok=True) 104 | train_sink = hydra.utils.instantiate(self.cfg.preprocess.train_tar_sink) 105 | val_sink = hydra.utils.instantiate(self.cfg.preprocess.val_tar_sink) 106 | dataloader = DataLoader( 107 | self.dataset, batch_size=1, shuffle=True, num_workers=64 108 | ) 109 | for idx, data in enumerate(tqdm.tqdm(dataloader)): 110 | basename = data["basename"][0] 111 | wav_path = data["wav_path"][0] 112 | word_segmented_text = data["word_segmented_text"][0] 113 | lang_code = data["lang_code"][0] 114 | result = self.process_utterance( 115 | basename, wav_path, word_segmented_text, lang_code 116 | ) 117 | if idx >= self.cfg.preprocess.val_size: 118 | sink = train_sink 119 | else: 120 | sink = val_sink 121 | for sample in result: 122 | sink.write(sample) 123 | train_sink.close() 124 | val_sink.close() 125 | -------------------------------------------------------------------------------- /src/miipher/preprocess/noiseAugmentation.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from torch import nn as nn 3 | import torchaudio 4 | import random 5 | import pyroomacoustics as pra 6 | import numpy as np 7 | import torch 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | 11 | 12 | def align_waveform(wav1, wav2): 13 | assert wav2.size(1) >= wav1.size(1) 14 | diff = wav2.size(1) - wav1.size(1) 15 | min_mse = float("inf") 16 | best_i = -1 17 | 18 | for i in range(diff): 19 | segment = wav2[:, i : i + wav1.size(1)] 20 | mse = torch.mean((wav1 - segment) ** 2).item() 21 | if mse < min_mse: 22 | min_mse = mse 23 | best_i = i 24 | 25 | return best_i, wav2[:, best_i : best_i + wav1.size(1)] 26 | 27 | 28 | class DegrationApplier: 29 | def __init__(self, cfg) -> None: 30 | self.format_encoding_pairs = cfg.format_encoding_pairs 31 | self.reverb_conditions = cfg.reverb_conditions 32 | self.background_noise = cfg.background_noise 33 | self.cfg = cfg 34 | self.rirs = [] 35 | self.prepare_rir(cfg.n_rirs) 36 | self.noise_audio_paths = [] 37 | for root, pattern in self.cfg.background_noise.patterns: 38 | self.noise_audio_paths.extend(list(Path(root).glob(pattern))) 39 | 40 | def applyCodec(self, waveform, sample_rate): 41 | if len(self.format_encoding_pairs) == 0: 42 | return waveform 43 | 
param = random.choice(self.format_encoding_pairs) 44 | augmented = torchaudio.functional.apply_codec( 45 | waveform=waveform.float(), sample_rate=sample_rate, **param 46 | ) 47 | # mp3 encoding may increase the length of the waveform by zero-padding 48 | if waveform.size(1) != augmented.size(1): 49 | best_idx, augmented = align_waveform(waveform, augmented) 50 | return augmented.float() 51 | 52 | def applyReverb(self, waveform): 53 | if len(self.rirs) == 0: 54 | raise RuntimeError 55 | rir = random.choice(self.rirs) 56 | augmented = torchaudio.functional.fftconvolve(waveform, rir) 57 | # rir convolution may increase the length of the waveform 58 | if waveform.size(1) != augmented.size(1): 59 | augmented = augmented[:, : waveform.size(1)] 60 | return augmented.float() 61 | 62 | def prepare_rir(self, n_rirs): 63 | for i in tqdm(range(n_rirs)): 64 | xy_minmax = self.reverb_conditions.room_xy 65 | z_minmax = self.reverb_conditions.room_z 66 | x = random.uniform(xy_minmax.min, xy_minmax.max) 67 | y = random.uniform(xy_minmax.min, xy_minmax.max) 68 | z = random.uniform(z_minmax.min, z_minmax.max) 69 | corners = np.array([[0, 0], [0, y], [x, y], [x, 0]]).T 70 | room = pra.Room.from_corners(corners, **self.reverb_conditions.room_params) 71 | room.extrude(z) 72 | room.add_source(self.cfg.reverb_conditions.source_pos) 73 | room.add_microphone(self.cfg.reverb_conditions.mic_pos) 74 | 75 | room.compute_rir() 76 | rir = torch.tensor(np.array(room.rir[0])) 77 | rir = rir / rir.norm(p=2) 78 | self.rirs.append(rir) 79 | 80 | def applyBackgroundNoise(self, waveform, sample_rate): 81 | snr_max, snr_min = self.background_noise.snr.max, self.background_noise.snr.min 82 | snr = random.uniform(snr_min, snr_max) 83 | 84 | noise_path = random.choice(self.noise_audio_paths) 85 | noise, noise_sr = torchaudio.load(noise_path) 86 | noise /= noise.norm(p=2) 87 | if noise.size(0) > 1: 88 | noise = noise[0].unsqueeze(0) 89 | noise = torchaudio.functional.resample(noise, noise_sr, sample_rate) 90 | if not noise.size(1) < waveform.size(1): 91 | start_idx = random.randint(0, noise.size(1) - waveform.size(1)) 92 | end_idx = start_idx + waveform.size(1) 93 | noise = noise[:, start_idx:end_idx] 94 | else: 95 | noise = noise.repeat(1, waveform.size(1) // noise.size(1) + 1)[ 96 | :, : waveform.size(1) 97 | ] 98 | if noise.abs().max() > 0: 99 | augmented = torchaudio.functional.add_noise( 100 | waveform=waveform, noise=noise, snr=torch.tensor([snr]) 101 | ) 102 | else: 103 | augmented = waveform 104 | return augmented 105 | 106 | def process(self, waveform, sample_rate): 107 | if len(waveform.shape) == 1: 108 | waveform = waveform.unsqueeze(0) 109 | org_len = waveform.size(1) 110 | waveform = self.applyBackgroundNoise(waveform, sample_rate) 111 | if random.random() > self.cfg.reverb_conditions.p: 112 | waveform = self.applyReverb(waveform) 113 | waveform = self.applyCodec(waveform, sample_rate) 114 | assert org_len == waveform.size(1), f"{org_len}, {waveform.size(1)}" 115 | return waveform.squeeze() 116 | 117 | def __call__(self, waveform, sample_rate): 118 | return self.process(waveform, sample_rate) 119 | -------------------------------------------------------------------------------- /requirements-dev.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | 9 | -e file:. 
10 | absl-py==1.4.0 11 | aiofiles==23.2.1 12 | aiohttp==3.8.4 13 | aiosignal==1.3.1 14 | altair==5.1.1 15 | antlr4-python3-runtime==4.9.3 16 | anyio==3.7.1 17 | appdirs==1.4.4 18 | arrow==1.2.3 19 | asttokens==2.2.1 20 | async-timeout==4.0.2 21 | attrs==23.1.0 22 | babel==2.12.1 23 | backcall==0.2.0 24 | backoff==2.2.1 25 | beautifulsoup4==4.12.2 26 | black==23.7.0 27 | blessed==1.20.0 28 | braceexpand==0.1.7 29 | cachetools==5.3.1 30 | certifi==2023.5.7 31 | cffi==1.15.1 32 | charset-normalizer==3.2.0 33 | click==8.1.6 34 | clldutils==3.19.0 35 | cmake==3.26.4 36 | colorama==0.4.6 37 | colorlog==6.7.0 38 | comm==0.1.3 39 | contourpy==1.1.0 40 | croniter==1.4.1 41 | csvw==3.1.3 42 | cycler==0.11.0 43 | cython==3.0.0 44 | dateutils==0.6.12 45 | debugpy==1.6.7 46 | decorator==5.1.1 47 | deepdiff==6.3.1 48 | dill==0.3.7 49 | docker-pycreds==0.4.0 50 | exceptiongroup==1.1.2 51 | executing==1.2.0 52 | fastapi==0.100.0 53 | ffmpy==0.3.1 54 | filelock==3.12.2 55 | fonttools==4.41.0 56 | frozenlist==1.4.0 57 | fsspec==2023.6.0 58 | gitdb==4.0.10 59 | gitpython==3.1.32 60 | google-auth==2.17.3 61 | google-auth-oauthlib==1.0.0 62 | gradio==3.45.2 63 | gradio-client==0.5.3 64 | grpcio==1.57.0 65 | h11==0.14.0 66 | httpcore==0.18.0 67 | httpx==0.25.0 68 | huggingface-hub==0.16.4 69 | hydra-core==1.3.2 70 | hyperpyyaml==1.2.1 71 | idna==3.4 72 | importlib-resources==6.1.0 73 | iniconfig==2.0.0 74 | inquirer==3.1.3 75 | ipykernel==6.24.0 76 | ipython==8.14.0 77 | isodate==0.6.1 78 | itsdangerous==2.1.2 79 | jedi==0.18.2 80 | jinja2==3.1.2 81 | joblib==1.3.1 82 | jsonschema==4.18.4 83 | jsonschema-specifications==2023.7.1 84 | jupyter-client==8.3.0 85 | jupyter-core==5.3.1 86 | kiwisolver==1.4.4 87 | language-tags==1.2.0 88 | lightning==2.0.5 89 | lightning-cloud==0.5.37 90 | lightning-utilities==0.9.0 91 | lightning-vocoders @ git+https://github.com/Wataru-Nakata/ssl-vocoders 92 | lit==16.0.6 93 | llvmlite==0.40.1 94 | lxml==4.9.3 95 | markdown==3.4.3 96 | markdown-it-py==3.0.0 97 | markupsafe==2.1.3 98 | matplotlib==3.7.2 99 | matplotlib-inline==0.1.6 100 | mdurl==0.1.2 101 | mecab-python3==1.0.6 102 | mpmath==1.3.0 103 | multidict==6.0.4 104 | mypy-extensions==1.0.0 105 | nest-asyncio==1.5.6 106 | networkx==3.1 107 | numpy==1.25.1 108 | nvidia-cublas-cu11==11.10.3.66 109 | nvidia-cuda-cupti-cu11==11.7.101 110 | nvidia-cuda-nvrtc-cu11==11.7.99 111 | nvidia-cuda-runtime-cu11==11.7.99 112 | nvidia-cudnn-cu11==8.5.0.96 113 | nvidia-cufft-cu11==10.9.0.58 114 | nvidia-curand-cu11==10.2.10.91 115 | nvidia-cusolver-cu11==11.4.0.1 116 | nvidia-cusparse-cu11==11.7.4.91 117 | nvidia-nccl-cu11==2.14.3 118 | nvidia-nvtx-cu11==11.7.91 119 | oauthlib==3.2.2 120 | omegaconf==2.3.0 121 | ordered-set==4.1.0 122 | orjson==3.9.7 123 | packaging==23.1 124 | pandarallel==1.6.5 125 | pandas==2.0.3 126 | parso==0.8.3 127 | pathspec==0.11.2 128 | pathtools==0.1.2 129 | pexpect==4.8.0 130 | pickleshare==0.7.5 131 | pillow==10.0.0 132 | plac==1.3.5 133 | platformdirs==3.9.1 134 | pluggy==1.2.0 135 | prompt-toolkit==3.0.39 136 | protobuf==4.23.4 137 | psutil==5.9.5 138 | ptyprocess==0.7.0 139 | pure-eval==0.2.2 140 | pyasn1==0.5.0 141 | pyasn1-modules==0.3.0 142 | pybind11==2.11.1 143 | pycparser==2.21 144 | pydantic==1.10.11 145 | pydub==0.25.1 146 | pygments==2.15.1 147 | pyjwt==2.8.0 148 | pylatexenc==2.10 149 | pyparsing==3.0.9 150 | pyroomacoustics==0.7.3 151 | pyrootutils==1.0.4 152 | pytest==7.4.0 153 | python-dateutil==2.8.2 154 | python-dotenv==1.0.0 155 | python-editor==1.0.4 156 | python-multipart==0.0.6 157 | 
pytorch-lightning==2.0.5 158 | pytz==2023.3 159 | pyyaml==6.0.1 160 | pyzmq==25.1.0 161 | rdflib==6.3.2 162 | readchar==4.0.5 163 | referencing==0.30.0 164 | regex==2023.6.3 165 | requests==2.31.0 166 | requests-oauthlib==1.3.1 167 | rfc3986==1.5.0 168 | rich==13.4.2 169 | rpds-py==0.9.2 170 | rsa==4.9 171 | ruamel-yaml==0.17.28 172 | ruamel-yaml-clib==0.2.7 173 | scipy==1.11.1 174 | segments==2.2.1 175 | semantic-version==2.10.0 176 | sentencepiece==0.1.99 177 | sentry-sdk==1.29.1 178 | setproctitle==1.3.2 179 | six==1.16.0 180 | smmap==5.0.0 181 | sniffio==1.3.0 182 | soundfile==0.12.1 183 | soupsieve==2.4.1 184 | speechbrain==0.5.15 185 | stack-data==0.6.2 186 | starlette==0.27.0 187 | starsessions==1.3.0 188 | sympy==1.12 189 | tabulate==0.9.0 190 | tensorboard==2.13.0 191 | tensorboard-data-server==0.7.1 192 | text2phonemesequence==0.1.4 193 | tokenizers==0.13.3 194 | tomli==2.0.1 195 | toolz==0.12.0 196 | torch==2.0.1 197 | torchaudio==2.0.2 198 | torchmetrics==1.0.1 199 | tornado==6.3.2 200 | tqdm==4.65.0 201 | traitlets==5.9.0 202 | transformers==4.29.2 203 | triton==2.0.0 204 | typing-extensions==4.7.1 205 | tzdata==2023.3 206 | unidic==1.1.0 207 | uritemplate==4.1.1 208 | urllib3==2.0.3 209 | uvicorn==0.23.1 210 | wandb==0.15.7 211 | wasabi==0.10.1 212 | wcwidth==0.2.6 213 | webdataset==0.2.48 214 | websocket-client==1.6.1 215 | websockets==11.0.3 216 | werkzeug==2.3.7 217 | wheel==0.40.0 218 | yarl==1.9.2 219 | # The following packages are considered to be unsafe in a requirements file: 220 | setuptools==68.0.0 221 | -------------------------------------------------------------------------------- /src/miipher/model/modules.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | import torch 3 | import math 4 | 5 | 6 | class FiLMLayer(nn.Module): 7 | def __init__(self, input_channels, intermediate_channels) -> None: 8 | super().__init__() 9 | self.conv1 = nn.Conv1d( 10 | input_channels, intermediate_channels, kernel_size=3, stride=1, padding=1 11 | ) 12 | self.conv2 = nn.Conv1d( 13 | intermediate_channels, input_channels, kernel_size=3, stride=1, padding=1 14 | ) 15 | self.leaky_relu = nn.LeakyReLU(0.1) 16 | 17 | def forward(self, a: torch.Tensor, b: torch.Tensor): 18 | batch_size, K, D = a.size() 19 | Q = b.size(1) 20 | a = a.transpose(1, 2) 21 | output = self.conv2( 22 | (self.leaky_relu(self.conv1(a)).transpose(1, 2) + b).transpose(1, 2) 23 | ) 24 | output = output.permute(0, 2, 1) 25 | assert output.size() == (batch_size, K, D) 26 | return output 27 | 28 | 29 | class PositionalEncoding(nn.Module): 30 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 31 | super().__init__() 32 | self.dropout = nn.Dropout(p=dropout) 33 | 34 | position = torch.arange(max_len).unsqueeze(1) 35 | div_term = torch.exp( 36 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) 37 | ) 38 | pe = torch.zeros(max_len, 1, d_model) 39 | pe[:, 0, 0::2] = torch.sin(position * div_term) 40 | pe[:, 0, 1::2] = torch.cos(position * div_term) 41 | self.register_buffer("pe", pe) 42 | 43 | def forward(self, x: torch.Tensor) -> torch.Tensor: 44 | """ 45 | Arguments: 46 | x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` 47 | """ 48 | return self.pe[x] 49 | 50 | 51 | class FeedForward(nn.Module): 52 | def __init__(self, dim, hidden_dim=2048, dropout=0.1): 53 | super().__init__() 54 | self.dropout = nn.Dropout(dropout) 55 | self.linear_1 = nn.Linear(dim, hidden_dim) 56 | self.relu = nn.ReLU() 57 | self.linear_2 
= nn.Linear(hidden_dim, dim)
58 |
59 | def forward(self, x):
60 | x = self.linear_1(x)
61 | x = self.relu(x)
62 | x = self.dropout(x)
63 | x = self.linear_2(x)
64 | return x
65 |
66 |
67 | class ConvNorm(torch.nn.Module):
68 | def __init__(
69 | self,
70 | in_channels,
71 | out_channels,
72 | kernel_size=1,
73 | stride=1,
74 | padding=None,
75 | dilation=1,
76 | bias=True,
77 | w_init_gain="linear",
78 | ):
79 | super(ConvNorm, self).__init__()
80 | if padding is None:
81 | assert kernel_size % 2 == 1
82 | padding = int(dilation * (kernel_size - 1) / 2)
83 |
84 | self.conv = torch.nn.Conv1d(
85 | in_channels,
86 | out_channels,
87 | kernel_size=kernel_size,
88 | stride=stride,
89 | padding=padding,
90 | dilation=dilation,
91 | bias=bias,
92 | )
93 |
94 | torch.nn.init.xavier_uniform_(
95 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
96 | )
97 |
98 | def forward(self, signal):
99 | conv_signal = self.conv(signal)
100 | return conv_signal
101 |
102 |
103 | class Postnet(nn.Module):
104 | """Postnet
105 | - A stack of ``postnet_n_convolutions`` 1-d convolutions with ``postnet_embedding_dim`` channels and kernel size ``postnet_kernel_size``
106 | """
107 |
108 | def __init__(
109 | self,
110 | n_mel_channels,
111 | postnet_embedding_dim,
112 | postnet_kernel_size,
113 | postnet_n_convolutions,
114 | ):
115 | super(Postnet, self).__init__()
116 | self.convolutions = nn.ModuleList()
117 |
118 | self.convolutions.append(
119 | nn.Sequential(
120 | ConvNorm(
121 | n_mel_channels,
122 | postnet_embedding_dim,
123 | kernel_size=postnet_kernel_size,
124 | stride=1,
125 | padding=int((postnet_kernel_size - 1) / 2),
126 | dilation=1,
127 | w_init_gain="tanh",
128 | ),
129 | nn.BatchNorm1d(postnet_embedding_dim),
130 | )
131 | )
132 |
133 | for i in range(1, postnet_n_convolutions - 1):
134 | self.convolutions.append(
135 | nn.Sequential(
136 | ConvNorm(
137 | postnet_embedding_dim,
138 | postnet_embedding_dim,
139 | kernel_size=postnet_kernel_size,
140 | stride=1,
141 | padding=int((postnet_kernel_size - 1) / 2),
142 | dilation=1,
143 | w_init_gain="tanh",
144 | ),
145 | nn.BatchNorm1d(postnet_embedding_dim),
146 | )
147 | )
148 |
149 | self.convolutions.append(
150 | nn.Sequential(
151 | ConvNorm(
152 | postnet_embedding_dim,
153 | n_mel_channels,
154 | kernel_size=postnet_kernel_size,
155 | stride=1,
156 | padding=int((postnet_kernel_size - 1) / 2),
157 | dilation=1,
158 | w_init_gain="linear",
159 | ),
160 | nn.BatchNorm1d(n_mel_channels),
161 | )
162 | )
163 |
164 | def forward(self, x):
165 | x = x.transpose(1, 2)
166 | for i in range(len(self.convolutions) - 1):
167 | x = torch.nn.functional.dropout(
168 | torch.tanh(self.convolutions[i](x)), 0.5, self.training
169 | )
170 | x = torch.nn.functional.dropout(self.convolutions[-1](x), 0.5, self.training)
171 | x = x.transpose(1, 2)
172 | return x
173 |
--------------------------------------------------------------------------------
/src/miipher/lightning_module.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 | from lightning.pytorch import LightningModule
3 | from lightning.pytorch.utilities.types import STEP_OUTPUT
4 | from lightning_vocoders.models.hifigan.lightning_module import MultiPeriodDiscriminator, MultiScaleDiscriminator
5 | from .model.miipher import Miipher
6 | from omegaconf import DictConfig
7 | from lightning.pytorch import loggers
8 | from torch import nn
9 | from typing import List
10 | from lightning_vocoders.models.hifigan.xvector_lightning_module import HiFiGANXvectorLightningModule
11 | import torch
12 |
import hydra
13 |
14 |
15 | class FeatureExtractor:
16 | def __init__(self, cfg) -> None:
17 | self.speech_ssl_model = hydra.utils.instantiate(cfg.model.ssl_models.model)
18 | self.speech_ssl_model.eval()
19 | self.phoneme_model = hydra.utils.instantiate(cfg.model.phoneme_model)
20 | self.phoneme_model.eval()
21 | self.xvector_model = hydra.utils.instantiate(cfg.model.xvector_model)
22 | self.xvector_model.eval()
23 | self.cfg = cfg
24 |
25 | @torch.inference_mode()
26 | def __call__(self, inputs):
27 | wav_16k = inputs["degraded_wav_16k"]
28 | wav_16k_lens = inputs["degraded_wav_16k_lengths"]
29 | feats = self.xvector_model.mods.compute_features(wav_16k)
30 | feats = self.xvector_model.mods.mean_var_norm(feats, wav_16k_lens)
31 | xvector = self.xvector_model.mods.embedding_model(feats, wav_16k_lens).squeeze(
32 | 1
33 | )
34 | phone_feature = self.phoneme_model(
35 | **inputs["phoneme_input_ids"]
36 | ).last_hidden_state
37 | if 'clean_ssl_input' in inputs:
38 | clean_ssl_feature = self.speech_ssl_model(
39 | **inputs["clean_ssl_input"], output_hidden_states=True
40 | )
41 | clean_ssl_feature = clean_ssl_feature.hidden_states[
42 | self.cfg.model.ssl_models.layer
43 | ]
44 | else:
45 | clean_ssl_feature = None
46 |
47 | degraded_ssl_feature = self.speech_ssl_model(
48 | **inputs["degraded_ssl_input"], output_hidden_states=True
49 | )
50 | degraded_ssl_feature = degraded_ssl_feature.hidden_states[
51 | self.cfg.model.ssl_models.layer
52 | ]
53 |
54 | return phone_feature, xvector, degraded_ssl_feature, clean_ssl_feature
55 |
56 | def to(self, device: torch.device):
57 | self.speech_ssl_model = self.speech_ssl_model.to(device)
58 | self.phoneme_model = self.phoneme_model.to(device)
59 | self.xvector_model = self.xvector_model.to(device)
60 |
61 | class MiipherLightningModule(LightningModule):
62 | def __init__(self, cfg: DictConfig) -> None:
63 | super().__init__()
64 |
65 | self.miipher = Miipher(**cfg.model.miipher)
66 | self.mse_loss = nn.MSELoss()
67 | self.mae_loss = nn.L1Loss()
68 | self.cfg = cfg
69 | self.feature_extractor = FeatureExtractor(cfg)
70 | # GAN discriminators (imported above) are not instantiated in this module
71 | self.save_hyperparameters()
72 |
73 | def on_fit_start(self):
74 | self.feature_extractor.to(self.device)
75 |
76 | def forward(self, phone_feature, speaker_feature, degraded_ssl_feature):
77 | cleaned_feature, intermediates = self.miipher.forward(
78 | phone_feature.clone(), speaker_feature.clone(), degraded_ssl_feature.clone()
79 | )
80 | return cleaned_feature, intermediates
81 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
82 | (
83 | phone_feature,
84 | speaker_feature,
85 | degraded_ssl_feature,
86 | clean_ssl_feature,
87 | ) = self.feature_extractor(batch)
88 |
89 | cleaned_feature, intermediates = self.miipher.forward(
90 | phone_feature.clone(), speaker_feature.clone(), degraded_ssl_feature.clone()
91 | )
92 | with torch.cuda.amp.autocast(enabled=False):
93 | loss = self.criterion(intermediates.float(), clean_ssl_feature.float(), log=True, stage='train')
94 | self.log("train/loss", loss, batch_size=phone_feature.size(0), prog_bar=True)
95 | return loss
96 |
97 | def validation_step(self, batch, batch_idx) -> STEP_OUTPUT | None:
98 | (
99 | phone_feature,
100 | speaker_feature,
101 | degraded_ssl_feature,
102 | clean_ssl_feature,
103 | ) = self.feature_extractor(batch)
104 | cleaned_feature, intermediates = self.miipher.forward(
105 | phone_feature, speaker_feature, degraded_ssl_feature
106 | )
107 | with torch.cuda.amp.autocast(enabled=False):
108 | loss = self.criterion(intermediates.float(),
clean_ssl_feature.float(), log=True, stage='val')
109 | self.log("val/loss", loss, batch_size=phone_feature.size(0))
110 | if batch_idx < 10 and self.global_rank == 0 and self.local_rank == 0:
111 | cleaned_wav = self.synthesis(cleaned_feature[0], batch["degraded_wav_16k"][0], batch["degraded_wav_16k_lengths"][0])
112 | self.log_audio(cleaned_wav, f"val/cleaned_wav/{batch_idx}", 22050)
113 | input_wav = self.synthesis(degraded_ssl_feature[0], batch["degraded_wav_16k"][0], batch["degraded_wav_16k_lengths"][0])
114 | self.log_audio(input_wav, f"val/input_wav/{batch_idx}", 22050)
115 | clean_wav = self.synthesis(clean_ssl_feature[0], batch["degraded_wav_16k"][0], batch["degraded_wav_16k_lengths"][0])
116 | self.log_audio(clean_wav, f"val/target_wav/{batch_idx}", 22050)
117 | return loss
118 |
119 | def configure_optimizers(self):
120 | return hydra.utils.instantiate(self.cfg.optimizers, params=self.miipher.parameters())
121 |
122 | def criterion(self, intermediates: List[torch.Tensor], target: torch.Tensor, log=False, stage='train'):
123 | loss = 0
124 | minimum_length = min(intermediates[0].size(1), target.size(1))
125 | target = target[:, :minimum_length, :].clone()
126 | for idx, intermediate in enumerate(intermediates):
127 | intermediate = intermediate[:, :minimum_length, :].clone()
128 | # each intermediate prediction is penalised with MAE, MSE and a spectral-convergence-style term
129 | mae_loss = self.mae_loss(intermediate, target)
130 | mse_loss = self.mse_loss(intermediate, target)
131 | spectral_loss = ((intermediate - target + 1e-7).norm(p=2, dim=(1, 2)).pow(2) / (target.norm(p=2, dim=(1, 2)).pow(2))).mean()
132 | loss += mae_loss + mse_loss + spectral_loss
133 | if log:
134 | self.log(f'{stage}/{idx}/mae_loss', mae_loss)
135 | self.log(f'{stage}/{idx}/mse_loss', mse_loss)
136 | self.log(f'{stage}/{idx}/spectral_loss', spectral_loss)
137 |
138 | return loss
139 | @torch.inference_mode()
140 | def synthesis(self, features: torch.Tensor, wav16k, wav16k_lens):
141 | vocoder = HiFiGANXvectorLightningModule.load_from_checkpoint("https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wavlm-large-l8-xvector/wavlm-large-l8-xvector.ckpt", map_location='cpu') # NOTE: the pretrained SSL-feature vocoder is re-loaded on every call
142 | vocoder.eval()
143 | xvector_model = hydra.utils.instantiate(vocoder.cfg.data.xvector.model)
144 | xvector_model.eval()
145 | xvector = xvector_model.encode_batch(wav16k.unsqueeze(0).cpu()).squeeze(1)
146 | vocoder = vocoder.float()
147 | return vocoder.generator_forward({"input_feature": features.unsqueeze(0).cpu().float(), "xvector": xvector.cpu().float()})[0].T
148 |
149 | def log_audio(self, audio, name, sampling_rate):
150 | for logger in self.loggers:
151 | match type(logger):
152 | case loggers.WandbLogger:
153 | import wandb
154 |
155 | wandb.log(
156 | {name: wandb.Audio(audio, sample_rate=sampling_rate)},
157 | step=self.global_step,
158 | )
159 | case loggers.TensorBoardLogger:
160 | logger.experiment.add_audio(
161 | name,
162 | audio,
163 | self.global_step,
164 | sampling_rate,
165 | )
166 |
--------------------------------------------------------------------------------
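For reference, a minimal self-contained sketch of the loss combination used in MiipherLightningModule.criterion above: every intermediate prediction of the iterative feature cleaner is cropped to the shorter sequence length and penalised with L1 + L2 + a spectral-convergence-style ratio term. The helper name combined_feature_loss and the tensor shapes below are illustrative and not part of the repository.

import torch
import torch.nn.functional as F


def combined_feature_loss(intermediates: torch.Tensor, target: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Sum MAE + MSE + a spectral-convergence-style term over stacked intermediate predictions."""
    min_len = min(intermediates[0].size(1), target.size(1))
    target = target[:, :min_len, :]
    loss = target.new_zeros(())
    for pred in intermediates:  # one prediction per iteration of the feature cleaner
        pred = pred[:, :min_len, :]
        mae = F.l1_loss(pred, target)
        mse = F.mse_loss(pred, target)
        sc = ((pred - target + eps).norm(p=2, dim=(1, 2)).pow(2)
              / target.norm(p=2, dim=(1, 2)).pow(2)).mean()
        loss = loss + mae + mse + sc
    return loss


if __name__ == "__main__":
    preds = torch.rand(3, 2, 121, 1024)  # (num_iterations, batch, frames, ssl_dim), illustrative shapes
    clean = torch.rand(2, 121, 1024)     # target SSL features
    print(combined_feature_loss(preds, clean))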