├── .github └── workflows │ └── stale.yml ├── .gitignore ├── LICENSE ├── README.md ├── aet.py ├── aet_demo.py ├── aet_sample ├── src.wav ├── tar.wav └── transfer.wav ├── configs ├── test │ ├── melspec │ │ ├── audio_effect_transfer.yaml │ │ ├── dual.yaml │ │ ├── pretrain_jvs.yaml │ │ ├── ssl_jsut.yaml │ │ └── ssl_tono.yaml │ └── vocfeats │ │ ├── audio_effect_transfer.yaml │ │ ├── dual.yaml │ │ ├── pretrain_jvs.yaml │ │ ├── ssl_jsut.yaml │ │ └── ssl_tono.yaml └── train │ ├── melspec │ ├── dual.yaml │ ├── pretrain_jvs.yaml │ ├── ssl_jsut.yaml │ └── ssl_tono.yaml │ └── vocfeats │ ├── dual.yaml │ ├── pretrain_jvs.yaml │ ├── ssl_jsut.yaml │ └── ssl_tono.yaml ├── dataset.py ├── eval.py ├── hifigan ├── LICENSE ├── __init__.py ├── config_melspec.json ├── config_vocfeats.json └── models.py ├── imgs └── method.jpg ├── lightning_module.py ├── model.py ├── preprocess.py ├── pretrained_models.md ├── requirements.txt ├── setup.sh ├── simulated_data.py ├── train.py └── utils.py /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time. 2 | # 3 | # You can adjust the behavior by modifying this file. 4 | # For more information, see: 5 | # https://github.com/actions/stale 6 | name: Mark stale issues and pull requests 7 | 8 | on: 9 | schedule: 10 | - cron: '23 21 * * *' 11 | 12 | jobs: 13 | stale: 14 | 15 | runs-on: ubuntu-latest 16 | permissions: 17 | issues: write 18 | pull-requests: write 19 | 20 | steps: 21 | - uses: actions/stale@v5 22 | with: 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | stale-issue-message: 'Stale issue message' 25 | stale-pr-message: 'Stale pull request message' 26 | stale-issue-label: 'no-issue-activity' 27 | stale-pr-label: 'no-pr-activity' 28 | days-before-stale: 30 29 | days-before-close: 5 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | __pycache__ 107 | .vscode 108 | .DS_Store 109 | 110 | # data, checkpoint, and models 111 | output*/ 112 | data/ 113 | *.ckpt 114 | preprocessed*/ 115 | hifigan/hifigan* 116 | ckpts_*/ 117 | *.pth -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Takaaki Saeki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SelfRemaster: Self-Supervised Speech Restoration 2 | 3 | Official implementation of [SelfRemaster: Self-Supervised Speech Restoration with Analysis-by-Synthesis Approach Using Channel Modeling](https://arxiv.org/abs/2203.12937) (to appear in *INTERSPEECH 2022*). 4 | 5 | ## Note 6 | This repo contains an older version of the code, but is kept for compatibility. 7 | 8 | **The latest version is available [here](https://github.com/Takaaki-Saeki/ssl_speech_restoration_v2/tree/master)**. 
9 | 10 | ## Demo 11 | - [Audio samples](https://takaaki-saeki.github.io/ssl_remaster_demo/) 12 | - Audio effect transfer with [Gradio + HuggingFace Spaces 🤗](https://huggingface.co/spaces/saefro991/aet_demo) 13 | 14 | ## Setup 15 | 1. Clone this repository: `git clone https://github.com/Takaaki-Saeki/ssl_speech_restoration.git`. 16 | 2. Change into the repository: `cd ssl_speech_restoration`. 17 | 3. Install the Python packages and download some pretrained models: `./setup.sh`. 18 | 19 | ## Getting started 20 | - If you use the default Japanese corpora: 21 | - Download [JSUT Basic5000](https://sites.google.com/site/shinnosuketakamichi/publication/jsut) and [JVS Corpus](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_corpus) 22 | - Downsample them to 22.05 kHz and place them under `data/` as `jsut_22k` and `jvs_22k` (a minimal downsampling sketch is given after this list). 23 | - JSUT is a single-speaker dataset and requires the directory structure `jsut_22k/*.wav`. Note that this is the ground-truth clean speech that corresponds to the simulated data and is not used for training; you may want `jsut_22k` only to compare the restored speech with the ground truth. 24 | - JVS parallel100 contains data from 100 speakers and requires the directory structure `jvs_22k/${spkr_name}/*.wav`. This clean speech dataset is used for the backward learning of the dual-learning method. 25 | - Place the simulated low-quality data under `./data` as `jsut_22k-low`. 26 | - Or you can use arbitrary datasets by modifying the config files (see the config sketch after this list). 27 |
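The downsampling step above is not scripted in this repository. Below is a minimal sketch using `torchaudio` (already used throughout this repo); the input path, the flat `*.wav` layout, and the `downsample_dir` helper name are illustrative assumptions, so adapt the glob pattern for the per-speaker JVS layout (`*/*.wav`).

```python
# Minimal sketch: resample a directory of wav files to 22.05 kHz.
# Assumes a flat in_dir/*.wav layout; adjust the glob for JVS-style
# per-speaker subdirectories (in_dir/*/*.wav).
import pathlib
import torchaudio


def downsample_dir(in_dir, out_dir, target_sr=22050):
    in_dir, out_dir = pathlib.Path(in_dir), pathlib.Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for wav_path in sorted(in_dir.glob("*.wav")):
        wav, sr = torchaudio.load(str(wav_path))
        if sr != target_sr:
            # Resample to the target rate; wav stays (channels, samples).
            wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(wav)
        torchaudio.save(str(out_dir / wav_path.name), wav, target_sr)


if __name__ == "__main__":
    # Example placeholders: point in_dir at the original JSUT wav directory.
    downsample_dir("path/to/jsut_basic5000_wav", "data/jsut_22k")
```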
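If you adapt the recipes to your own corpus (see the Training section below for the commands), the fields you typically edit in a training config such as `configs/train/melspec/ssl_jsut.yaml` are the data paths, the corpus type, and the split sizes. The values below are illustrative placeholders, not tested settings:

```yaml
general:
  corpus_type: "single"                 # or "multi-seen" / "multi-unseen"
  source_path: "./data/my_corpus-low"   # low-quality speech to be restored
  aux_path: null                        # optional clean counterpart; null if unavailable
  preprocessed_path: "./preprocessed/my_corpus"
  output_path: "./output/melspec/my_corpus"

preprocess:
  n_train: 270                          # train/val/test split sizes for your corpus
  n_val: 34
  n_test: 30
```

After editing the config, re-run `preprocess.py` with it before training.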
28 | ## Training 29 | 30 | You can choose the `MelSpec` or `SourceFilter` model with the `--config_path` option. 31 | As shown in the paper, the `MelSpec` model gives higher quality. 32 | 33 | First, you need to split the data into train/val/test sets and dump them with the following command. 34 | ```shell 35 | python preprocess.py --config_path configs/train/${feature}/ssl_jsut.yaml 36 | ``` 37 | 38 | To perform self-supervised learning with dual learning, run the following command. 39 | ```shell 40 | python train.py \ 41 | --config_path configs/train/${feature}/ssl_jsut.yaml \ 42 | --stage ssl-dual \ 43 | --run_name ssl_melspec_dual 44 | ``` 45 | For other options, refer to `train.py`. 46 | 47 | Note that you might need to tune some parameters for your own datasets. 48 | In our experience, `learning_rate` and `beta` are crucial parameters. 49 | For example, if the training is unstable, consider making `beta` smaller (e.g., `beta: 0.001`). 50 | 51 | ## Speech restoration 52 | To perform speech restoration of the test data, run the following command. 53 | ```shell 54 | python eval.py \ 55 | --config_path configs/test/${feature}/ssl_jsut.yaml \ 56 | --ckpt_path ${path to checkpoint} \ 57 | --stage ssl-dual \ 58 | --run_name ssl_melspec_dual 59 | ``` 60 | For other options, see `eval.py`. 61 | 62 | ## Audio effect transfer 63 | You can run a simple audio effect transfer demo using a model pretrained with real data. 64 | Run the following command. 65 | ```shell 66 | python aet_demo.py 67 | ``` 68 | 69 | Or you can customize the dataset or model. 70 | You need to edit `audio_effect_transfer.yaml` and run the following command. 71 | ```shell 72 | python aet.py \ 73 | --config_path configs/test/melspec/audio_effect_transfer.yaml \ 74 | --stage ssl-dual \ 75 | --run_name aet_melspec_dual 76 | ``` 77 | For other options, see `aet.py`. 78 | 79 | 80 | ## Pretrained models 81 | See [here](./pretrained_models.md). 82 | 83 | ## Reproducing results 84 | You can generate simulated low-quality data as in the paper with the following command. 85 | ```shell 86 | python simulated_data.py \ 87 | --in_dir ${input_directory (e.g., path to jsut_22k)} \ 88 | --output_dir ${output_directory (e.g., path to jsut_22k-low)} \ 89 | --corpus_type ${single-speaker corpus or multi-speaker corpus} \ 90 | --deg_type lowpass 91 | ``` 92 | 93 | Then download the pretrained model corresponding to the `deg_type` and run the following command. 94 | ```shell 95 | python eval.py \ 96 | --config_path configs/train/${feature}/ssl_jsut.yaml \ 97 | --ckpt_path ${path to checkpoint} \ 98 | --stage ssl-dual \ 99 | --run_name ssl_melspec_dual 100 | ``` 101 | 102 | ## Citation 103 | ```bib 104 | @article{saeki22selfremaster, 105 | title={{SelfRemaster}: {S}elf-Supervised Speech Restoration with Analysis-by-Synthesis Approach Using Channel Modeling}, 106 | author={T. Saeki and S. Takamichi and T. Nakamura and N. Tanji and H. Saruwatari}, 107 | journal={arXiv preprint arXiv:2203.12937}, 108 | year={2022} 109 | } 110 | ``` 111 | 112 | ## Reference 113 | - [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646) 114 | - [VoiceFixer: Toward General Speech Restoration with Neural Vocoder](https://arxiv.org/abs/2109.13731) 115 | -------------------------------------------------------------------------------- /aet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | import yaml 4 | import torch 5 | import torchaudio 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | import random 9 | import librosa 10 | from dataset import Dataset 11 | import pickle 12 | from lightning_module import ( 13 | SSLStepLightningModule, 14 | SSLDualLightningModule, 15 | ) 16 | from utils import plot_and_save_mels 17 | import os 18 | import tqdm 19 | 20 | 21 | def get_arg(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--stage", required=True, type=str) 24 | parser.add_argument("--config_path", required=True, type=pathlib.Path) 25 | parser.add_argument("--exist_src_aux", action="store_true") 26 | parser.add_argument("--run_name", required=True, type=str) 27 | return parser.parse_args() 28 | 29 | 30 | class AETDataset(Dataset): 31 | def __init__(self, filetxt, src_config, tar_config): 32 | self.config = src_config 33 | 34 | self.preprocessed_dir_src = pathlib.Path( 35 | src_config["general"]["preprocessed_path"] 36 | ) 37 | self.preprocessed_dir_tar = pathlib.Path( 38 | tar_config["general"]["preprocessed_path"] 39 | ) 40 | for item in [ 41 | "sampling_rate", 42 | "fft_length", 43 | "frame_length", 44 | "frame_shift", 45 | "fmin", 46 | "fmax", 47 | "n_mels", 48 | ]: 49 | assert src_config["preprocess"][item] == tar_config["preprocess"][item] 50 | 51 | self.spec_module = torchaudio.transforms.MelSpectrogram( 52 | sample_rate=src_config["preprocess"]["sampling_rate"], 53 | n_fft=src_config["preprocess"]["fft_length"], 54 | win_length=src_config["preprocess"]["frame_length"], 55 | hop_length=src_config["preprocess"]["frame_shift"], 56 | f_min=src_config["preprocess"]["fmin"], 57 | f_max=src_config["preprocess"]["fmax"], 58 | n_mels=src_config["preprocess"]["n_mels"], 59 | power=1, 60 | center=True, 61 | norm="slaney", 62 | mel_scale="slaney", 63 | ) 64 | 65 | with open(self.preprocessed_dir_src / filetxt, "r") as fr: 66 | self.filelist_src = [pathlib.Path(path.strip("\n")) for path in fr] 67 | with open(self.preprocessed_dir_tar / filetxt, "r") as fr: 68 | self.filelist_tar = [pathlib.Path(path.strip("\n")) for path in
fr] 69 | 70 | self.d_out = {"src": {}, "tar": {}} 71 | for item in ["wavs", "wavsaux"]: 72 | self.d_out["src"][item] = [] 73 | self.d_out["tar"][item] = [] 74 | 75 | for swp in self.filelist_src: 76 | if src_config["general"]["corpus_type"] == "single": 77 | basename = str(swp.stem) 78 | else: 79 | basename = str(swp.parent.name) + "-" + str(swp.stem) 80 | with open( 81 | self.preprocessed_dir_src / "{}.pickle".format(basename), "rb" 82 | ) as fw: 83 | d_preprocessed = pickle.load(fw) 84 | for item in ["wavs", "wavsaux"]: 85 | try: 86 | self.d_out["src"][item].extend(d_preprocessed[item]) 87 | except: 88 | pass 89 | 90 | for twp in self.filelist_tar: 91 | if tar_config["general"]["corpus_type"] == "single": 92 | basename = str(twp.stem) 93 | else: 94 | basename = str(twp.parent.name) + "-" + str(twp.stem) 95 | with open( 96 | self.preprocessed_dir_tar / "{}.pickle".format(basename), "rb" 97 | ) as fw: 98 | d_preprocessed = pickle.load(fw) 99 | for item in ["wavs", "wavsaux"]: 100 | try: 101 | self.d_out["tar"][item].extend(d_preprocessed[item]) 102 | except: 103 | pass 104 | 105 | min_len = min(len(self.d_out["src"]["wavs"]), len(self.d_out["tar"]["wavs"])) 106 | for spk in ["src", "tar"]: 107 | for item in ["wavs", "wavsaux"]: 108 | if self.d_out[spk][item] != None: 109 | self.d_out[spk][item] = np.asarray(self.d_out[spk][item][:min_len]) 110 | 111 | def __len__(self): 112 | return len(self.d_out["src"]["wavs"]) 113 | 114 | def __getitem__(self, idx): 115 | d_batch = {} 116 | 117 | for spk in ["src", "tar"]: 118 | for item in ["wavs", "wavsaux"]: 119 | if self.d_out[spk][item].size > 0: 120 | d_batch["{}_{}".format(item, spk)] = torch.from_numpy( 121 | self.d_out[spk][item][idx] 122 | ) 123 | d_batch["{}_{}".format(item, spk)] = self.normalize_waveform( 124 | d_batch["{}_{}".format(item, spk)], db=-3 125 | ) 126 | 127 | d_batch["melspecs_src"] = self.calc_spectrogram(d_batch["wavs_src"]) 128 | return d_batch 129 | 130 | 131 | class AETModule(torch.nn.Module): 132 | """ 133 | src: Dataset from which we extract the channel features 134 | tar: Dataset to which the src channel features are added 135 | """ 136 | 137 | def __init__(self, args, chmatch_config, src_config, tar_config): 138 | super().__init__() 139 | if args.stage == "ssl-step": 140 | LModule = SSLStepLightningModule 141 | elif args.stage == "ssl-dual": 142 | LModule = SSLDualLightningModule 143 | else: 144 | raise NotImplementedError() 145 | 146 | src_model = LModule(src_config).load_from_checkpoint( 147 | checkpoint_path=chmatch_config["general"]["source"]["ckpt_path"], 148 | config=src_config, 149 | ) 150 | self.src_config = src_config 151 | 152 | self.encoder_src = src_model.encoder 153 | if src_config["general"]["use_gst"]: 154 | self.gst_src = src_model.gst 155 | else: 156 | self.channelfeats_src = src_model.channelfeats 157 | self.channel_src = src_model.channel 158 | 159 | def forward(self, melspecs_src, wavsaux_tar): 160 | if self.src_config["general"]["use_gst"]: 161 | chfeats_src = self.gst_src(melspecs_src.transpose(1, 2)) 162 | else: 163 | _, enc_hidden_src = self.encoder_src( 164 | melspecs_src.unsqueeze(1).transpose(2, 3) 165 | ) 166 | chfeats_src = self.channelfeats_src(enc_hidden_src) 167 | wavschmatch_tar = self.channel_src(wavsaux_tar, chfeats_src) 168 | return wavschmatch_tar 169 | 170 | 171 | def calc_deg_baseline(wav, char_vector, tar_config): 172 | wav = wav[0, ...].cpu().detach().numpy() 173 | spec = librosa.stft( 174 | wav, 175 | n_fft=tar_config["preprocess"]["fft_length"], 176 | 
hop_length=tar_config["preprocess"]["frame_shift"], 177 | win_length=tar_config["preprocess"]["frame_length"], 178 | ) 179 | spec_converted = spec * char_vector.reshape(-1, 1) 180 | wav_converted = librosa.istft( 181 | spec_converted, 182 | hop_length=tar_config["preprocess"]["frame_shift"], 183 | win_length=tar_config["preprocess"]["frame_length"], 184 | ) 185 | wav_converted = torch.from_numpy(wav_converted).to(torch.float32).unsqueeze(0) 186 | return wav_converted 187 | 188 | 189 | def calc_deg_charactaristics(chmatch_config): 190 | src_config = yaml.load( 191 | open(chmatch_config["general"]["source"]["config_path"], "r"), 192 | Loader=yaml.FullLoader, 193 | ) 194 | tar_config = yaml.load( 195 | open(chmatch_config["general"]["target"]["config_path"], "r"), 196 | Loader=yaml.FullLoader, 197 | ) 198 | # configs 199 | preprocessed_dir = pathlib.Path(src_config["general"]["preprocessed_path"]) 200 | n_train = src_config["preprocess"]["n_train"] 201 | SR = src_config["preprocess"]["sampling_rate"] 202 | 203 | os.makedirs(preprocessed_dir, exist_ok=True) 204 | 205 | sourcepath = pathlib.Path(src_config["general"]["source_path"]) 206 | 207 | if src_config["general"]["corpus_type"] == "single": 208 | fulllist = list(sourcepath.glob("*.wav")) 209 | random.seed(0) 210 | random.shuffle(fulllist) 211 | train_filelist = fulllist[:n_train] 212 | elif src_config["general"]["corpus_type"] == "multi-seen": 213 | fulllist = list(sourcepath.glob("*/*.wav")) 214 | random.seed(0) 215 | random.shuffle(fulllist) 216 | train_filelist = fulllist[:n_train] 217 | elif src_config["general"]["corpus_type"] == "multi-unseen": 218 | spk_list = list(set([x.parent for x in sourcepath.glob("*/*.wav")])) 219 | train_filelist = [] 220 | random.seed(0) 221 | random.shuffle(spk_list) 222 | for i, spk in enumerate(spk_list): 223 | sourcespkpath = sourcepath / spk 224 | if i < n_train: 225 | train_filelist.extend(list(sourcespkpath.glob("*.wav"))) 226 | else: 227 | raise NotImplementedError( 228 | "corpus_type specified in config.yaml should be {single, multi-seen, multi-unseen}" 229 | ) 230 | 231 | specs_all = np.zeros((tar_config["preprocess"]["fft_length"] // 2 + 1, 1)) 232 | 233 | for wp in tqdm.tqdm(train_filelist): 234 | wav, _ = librosa.load(wp, sr=SR) 235 | spec = np.abs( 236 | librosa.stft( 237 | wav, 238 | n_fft=src_config["preprocess"]["fft_length"], 239 | hop_length=src_config["preprocess"]["frame_shift"], 240 | win_length=src_config["preprocess"]["frame_length"], 241 | ) 242 | ) 243 | 244 | auxpath = pathlib.Path(src_config["general"]["aux_path"]) 245 | if src_config["general"]["corpus_type"] == "single": 246 | wav_aux, _ = librosa.load(auxpath / wp.name, sr=SR) 247 | else: 248 | wav_aux, _ = librosa.load(auxpath / wp.parent.name / wp.name, sr=SR) 249 | spec_aux = np.abs( 250 | librosa.stft( 251 | wav_aux, 252 | n_fft=src_config["preprocess"]["fft_length"], 253 | hop_length=src_config["preprocess"]["frame_shift"], 254 | win_length=src_config["preprocess"]["frame_length"], 255 | ) 256 | ) 257 | min_len = min(spec.shape[1], spec_aux.shape[1]) 258 | spec_diff = spec[:, :min_len] / (spec_aux[:, :min_len] + 1e-10) 259 | specs_all = np.hstack([specs_all, np.mean(spec_diff, axis=1).reshape(-1, 1)]) 260 | 261 | char_vector = np.mean(specs_all, axis=1) 262 | char_vector = char_vector / (np.sum(char_vector) + 1e-10) 263 | return char_vector 264 | 265 | 266 | def normalize_waveform(wav, tar_config, db=-3): 267 | wav, _ = torchaudio.sox_effects.apply_effects_tensor( 268 | wav, 269 | 
tar_config["preprocess"]["sampling_rate"], 270 | [["norm", "{}".format(db)]], 271 | ) 272 | return wav 273 | 274 | 275 | def main(args, chmatch_config, device): 276 | src_config = yaml.load( 277 | open(chmatch_config["general"]["source"]["config_path"], "r"), 278 | Loader=yaml.FullLoader, 279 | ) 280 | tar_config = yaml.load( 281 | open(chmatch_config["general"]["target"]["config_path"], "r"), 282 | Loader=yaml.FullLoader, 283 | ) 284 | output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name 285 | dataset = AETDataset("test.txt", src_config, tar_config) 286 | loader = DataLoader(dataset, batch_size=1, shuffle=False) 287 | chmatch_module = AETModule(args, chmatch_config, src_config, tar_config).to(device) 288 | 289 | if args.exist_src_aux: 290 | char_vector = calc_deg_charactaristics(chmatch_config) 291 | 292 | for idx, batch in enumerate(tqdm.tqdm(loader)): 293 | melspecs_src = batch["melspecs_src"].to(device) 294 | wavsdeg_src = batch["wavs_src"].to(device) 295 | wavsaux_tar = batch["wavsaux_tar"].to(device) 296 | if args.exist_src_aux: 297 | wavsdegbaseline_tar = calc_deg_baseline( 298 | batch["wavsaux_tar"], char_vector, tar_config 299 | ) 300 | wavsdegbaseline_tar = normalize_waveform(wavsdegbaseline_tar, tar_config) 301 | wavsdeg_tar = batch["wavs_tar"].to(device) 302 | wavsmatch_tar = normalize_waveform( 303 | chmatch_module(melspecs_src, wavsaux_tar).cpu().detach(), tar_config 304 | ) 305 | torchaudio.save( 306 | output_path / "test_wavs" / "{}-src_wavsdeg.wav".format(idx), 307 | wavsdeg_src.cpu(), 308 | src_config["preprocess"]["sampling_rate"], 309 | ) 310 | torchaudio.save( 311 | output_path / "test_wavs" / "{}-tar_wavsaux.wav".format(idx), 312 | wavsaux_tar.cpu(), 313 | tar_config["preprocess"]["sampling_rate"], 314 | ) 315 | if args.exist_src_aux: 316 | torchaudio.save( 317 | output_path / "test_wavs" / "{}-tar_wavsdegbaseline.wav".format(idx), 318 | wavsdegbaseline_tar.cpu(), 319 | tar_config["preprocess"]["sampling_rate"], 320 | ) 321 | torchaudio.save( 322 | output_path / "test_wavs" / "{}-tar_wavsdeg.wav".format(idx), 323 | wavsdeg_tar.cpu(), 324 | tar_config["preprocess"]["sampling_rate"], 325 | ) 326 | torchaudio.save( 327 | output_path / "test_wavs" / "{}-tar_wavsmatch.wav".format(idx), 328 | wavsmatch_tar.cpu(), 329 | tar_config["preprocess"]["sampling_rate"], 330 | ) 331 | plot_and_save_mels( 332 | wavsdeg_src[0, ...].cpu().detach(), 333 | output_path / "test_mels" / "{}-src_melsdeg.png".format(idx), 334 | src_config, 335 | ) 336 | plot_and_save_mels( 337 | wavsaux_tar[0, ...].cpu().detach(), 338 | output_path / "test_mels" / "{}-tar_melsaux.png".format(idx), 339 | tar_config, 340 | ) 341 | if args.exist_src_aux: 342 | plot_and_save_mels( 343 | wavsdegbaseline_tar[0, ...].cpu().detach(), 344 | output_path / "test_mels" / "{}-tar_melsdegbaseline.png".format(idx), 345 | tar_config, 346 | ) 347 | plot_and_save_mels( 348 | wavsdeg_tar[0, ...].cpu().detach(), 349 | output_path / "test_mels" / "{}-tar_melsdeg.png".format(idx), 350 | tar_config, 351 | ) 352 | plot_and_save_mels( 353 | wavsmatch_tar[0, ...].cpu().detach(), 354 | output_path / "test_mels" / "{}-tar_melsmatch.png".format(idx), 355 | tar_config, 356 | ) 357 | 358 | 359 | if __name__ == "__main__": 360 | args = get_arg() 361 | chmatch_config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader) 362 | output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name 363 | os.makedirs(output_path, exist_ok=True) 364 | os.makedirs(output_path / "test_wavs", 
exist_ok=True) 365 | os.makedirs(output_path / "test_mels", exist_ok=True) 366 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 367 | 368 | main(args, chmatch_config, device) 369 | -------------------------------------------------------------------------------- /aet_demo.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import yaml 3 | import torch 4 | import torchaudio 5 | import numpy as np 6 | from lightning_module import SSLDualLightningModule 7 | import gradio as gr 8 | 9 | 10 | def normalize_waveform(wav, sr, db=-3): 11 | wav, _ = torchaudio.sox_effects.apply_effects_tensor( 12 | wav.unsqueeze(0), 13 | sr, 14 | [["norm", "{}".format(db)]], 15 | ) 16 | return wav.squeeze(0) 17 | 18 | 19 | def calc_spectrogram(wav, config): 20 | spec_module = torchaudio.transforms.MelSpectrogram( 21 | sample_rate=config["preprocess"]["sampling_rate"], 22 | n_fft=config["preprocess"]["fft_length"], 23 | win_length=config["preprocess"]["frame_length"], 24 | hop_length=config["preprocess"]["frame_shift"], 25 | f_min=config["preprocess"]["fmin"], 26 | f_max=config["preprocess"]["fmax"], 27 | n_mels=config["preprocess"]["n_mels"], 28 | power=1, 29 | center=True, 30 | norm="slaney", 31 | mel_scale="slaney", 32 | ) 33 | specs = spec_module(wav) 34 | log_spec = torch.log( 35 | torch.clamp_min(specs, config["preprocess"]["min_magnitude"]) 36 | * config["preprocess"]["comp_factor"] 37 | ).to(torch.float32) 38 | return log_spec 39 | 40 | 41 | def transfer(audio): 42 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 43 | 44 | wp_src = pathlib.Path("aet_sample/src.wav") 45 | wav_src, sr = torchaudio.load(wp_src) 46 | sr_inp, wav_tar = audio 47 | wav_tar = wav_tar / (np.max(np.abs(wav_tar)) * 1.1) 48 | wav_tar = torch.from_numpy(wav_tar.astype(np.float32)) 49 | resampler = torchaudio.transforms.Resample( 50 | orig_freq=sr_inp, 51 | new_freq=sr, 52 | ) 53 | wav_tar = resampler(wav_tar) 54 | config_path = pathlib.Path("configs/test/melspec/ssl_tono.yaml") 55 | config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) 56 | 57 | melspec_src = calc_spectrogram(normalize_waveform(wav_src.squeeze(0), sr), config) 58 | wav_tar = normalize_waveform(wav_tar.squeeze(0), sr) 59 | ckpt_path = pathlib.Path("aet_sample/tono_aet_melspec.ckpt") 60 | src_model = SSLDualLightningModule(config).load_from_checkpoint( 61 | checkpoint_path=ckpt_path, 62 | config=config, 63 | ) 64 | 65 | encoder_src = src_model.encoder.to(device) 66 | channelfeats_src = src_model.channelfeats.to(device) 67 | channel_src = src_model.channel.to(device) 68 | 69 | _, enc_hidden_src = encoder_src( 70 | melspec_src.unsqueeze(0).unsqueeze(1).transpose(2, 3).to(device) 71 | ) 72 | chfeats_src = channelfeats_src(enc_hidden_src) 73 | wav_transfer = channel_src(wav_tar.unsqueeze(0), chfeats_src) 74 | wav_transfer = wav_transfer.cpu().detach().numpy()[0, :] 75 | return sr, wav_transfer 76 | 77 | 78 | if __name__ == "__main__": 79 | iface = gr.Interface( 80 | transfer, 81 | "audio", 82 | gr.outputs.Audio(type="numpy"), 83 | examples=[["aet_sample/tar.wav"]], 84 | title="Audio effect transfer demo", 85 | description="Add channel feature of Japanese old audio recording to any high-quality audio", 86 | ) 87 | 88 | iface.launch(share=True) 89 | -------------------------------------------------------------------------------- /aet_sample/src.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Takaaki-Saeki/ssl_speech_restoration/32c420a346c9169b710eb9bd33c0fe0462dd2ccb/aet_sample/src.wav -------------------------------------------------------------------------------- /aet_sample/tar.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/ssl_speech_restoration/32c420a346c9169b710eb9bd33c0fe0462dd2ccb/aet_sample/tar.wav -------------------------------------------------------------------------------- /aet_sample/transfer.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/ssl_speech_restoration/32c420a346c9169b710eb9bd33c0fe0462dd2ccb/aet_sample/transfer.wav -------------------------------------------------------------------------------- /configs/test/melspec/audio_effect_transfer.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | preprocessed_path: "./preprocessed/audio_effect_transfer" 3 | output_path: "./output/melspec/audio_effect_transfer" 4 | feature_type: "melspec" 5 | source: 6 | dataset_path: "./data/tono" 7 | config_path: "./configs/test/melspec/ssl_tono.yaml" 8 | ckpt_path: "./ckpts_tono/tono_melspec_multi_nopre_0217.ckpt" 9 | target: 10 | dataset_path: "./data/jvs_22k-low" 11 | config_path: "./configs/test/melspec/pretrain_jvs.yaml" 12 | use_gst: False 13 | 14 | preprocess: 15 | sampling_rate: 22050 16 | segment_length: -1 17 | frame_shift: 256 18 | 19 | model: null 20 | 21 | train: 22 | epoch: 100 23 | batchsize: 8 24 | multi_gpu_mode: False 25 | num_workers: 4 26 | learning_rate: 0.001 27 | grad_clip_thresh: 1.0 28 | logger_step: 1000 -------------------------------------------------------------------------------- /configs/test/melspec/dual.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "pretrain" 3 | corpus_type: "multi-unseen" 4 | source_path: "./data/jvs_22k" 5 | aux_path: null 6 | preprocessed_path: "./preprocessed/dual" 7 | preprocess: 8 | n_train: 90 9 | n_val: 5 10 | n_test: 5 11 | sampling_rate: 22050 12 | segment_length: 2 -------------------------------------------------------------------------------- /configs/test/melspec/pretrain_jvs.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "pretrain" 3 | corpus_type: "multi-unseen" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/jvs_22k-low" 5 | aux_path: "./data/jvs_22k" 6 | preprocessed_path: "./preprocessed/jvs" 7 | output_path: "./output/melspec/pretrain" 8 | test_wav_path: null 9 | feature_type: "melspec" 10 | hifigan_path: "./hifigan/hifigan_melspec_universal" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 90 16 | n_val: 5 17 | n_test: 5 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | comp_factor: 1.0 26 | min_magnitude: 0.00001 27 | max_wav_value: 32768.0 28 | segment_length: -1 29 | 30 | train: 31 | batchsize: 8 32 | epoch: 50 33 | alpha: 0.1 34 | augment: True 35 | multi_gpu_mode: False 36 | num_workers: 4 37 | learning_rate: 0.005 38 | grad_clip_thresh: 1.0 39 | logger_step: 1000 40 | load_pretrained: False 41 | pretrained_path: null 42 | early_stopping: False 43 | multi_scale_loss: 44 | use_linear: False 45 | gamma: 1.0 46 | feature_loss: 47 | type: "mae" 
-------------------------------------------------------------------------------- /configs/test/melspec/ssl_jsut.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "ssl" 3 | corpus_type: "single" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/jsut_22k-low" 5 | aux_path: "./data/jsut_22k" 6 | preprocessed_path: "./preprocessed/jsut-low" 7 | output_path: "./output/melspec/jsut-low" 8 | test_wav_path: null 9 | feature_type: "melspec" 10 | hifigan_path: "./hifigan/hifigan_melspec_universal" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 4950 16 | n_val: 25 17 | n_test: 25 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | comp_factor: 1.0 26 | min_magnitude: 0.00001 27 | bitrate: "16k" 28 | max_wav_value: 32768.0 29 | segment_length: -1 30 | 31 | train: 32 | batchsize: 1 33 | epoch: 50 34 | epoch_channel: 25 35 | multi_gpu_mode: False 36 | num_workers: 4 37 | learning_rate: 0.001 38 | alpha: 0.1 39 | beta: 0.001 40 | augment: False 41 | grad_clip_thresh: 1.0 42 | logger_step: 1000 43 | load_pretrained: False 44 | pretrained_path: null 45 | fix_channel: False 46 | early_stopping: False 47 | multi_scale_loss: 48 | use_linear: True 49 | gamma: 1.0 50 | feature_loss: 51 | type: "mae" 52 | 53 | dual: 54 | enable: True 55 | config_path: ./configs/test/melspec/dual.yaml -------------------------------------------------------------------------------- /configs/test/melspec/ssl_tono.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "ssl" 3 | corpus_type: "single" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/tono_22k" 5 | aux_path: null 6 | preprocessed_path: "./preprocessed/tono" 7 | output_path: "./output/melspec/tono" 8 | test_wav_path: null 9 | feature_type: "melspec" 10 | hifigan_path: "./hifigan/hifigan_melspec_universal" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 270 16 | n_val: 34 17 | n_test: 30 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | comp_factor: 1.0 26 | min_magnitude: 0.00001 27 | bitrate: "16k" 28 | max_wav_value: 32768.0 29 | segment_length: -1 30 | 31 | train: 32 | batchsize: 4 33 | epoch: 50 34 | epoch_channel: 25 35 | multi_gpu_mode: False 36 | num_workers: 4 37 | learning_rate: 0.001 38 | alpha: 0.1 39 | beta: 0.001 40 | grad_clip_thresh: 1.0 41 | logger_step: 1000 42 | load_pretrained: False 43 | pretrained_path: null 44 | fix_channel: False 45 | early_stopping: False 46 | multi_scale_loss: 47 | use_linear: True 48 | gamma: 1.0 49 | feature_loss: 50 | type: "mae" 51 | 52 | dual: 53 | enable: True 54 | config_path: ./configs/train/melspec/dual.yaml -------------------------------------------------------------------------------- /configs/test/vocfeats/audio_effect_transfer.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | preprocessed_path: "./preprocessed/audio_effect_transfer" 3 | output_path: "./output/vocfeats/audio_effect_transfer" 4 | feature_type: "vocfeats" 5 | source: 6 | dataset_path: "./data/tono" 7 | config_path: "./configs/test/melspec/ssl_tono.yaml" 8 | ckpt_path: "./ckpts_tono/tono_melspec_multi_nopre_0217.ckpt" 9 | target: 10 | dataset_path: "./data/jvs_22k-low" 11 | config_path: 
"./configs/test/vocfeats/pretrain_jvs.yaml" 12 | use_gst: False 13 | 14 | preprocess: 15 | sampling_rate: 22050 16 | segment_length: -1 17 | frame_shift: 256 18 | 19 | model: null 20 | 21 | train: 22 | epoch: 100 23 | batchsize: 8 24 | multi_gpu_mode: False 25 | num_workers: 4 26 | learning_rate: 0.001 27 | grad_clip_thresh: 1.0 28 | logger_step: 1000 -------------------------------------------------------------------------------- /configs/test/vocfeats/dual.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "pretrain" 3 | corpus_type: "multi-unseen" 4 | source_path: "./data/jvs_22k" 5 | aux_path: null 6 | preprocessed_path: "./preprocessed/dual" 7 | preprocess: 8 | n_train: 90 9 | n_val: 5 10 | n_test: 5 11 | sampling_rate: 22050 12 | segment_length: 2 -------------------------------------------------------------------------------- /configs/test/vocfeats/pretrain_jvs.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "pretrain" 3 | corpus_type: "multi-unseen" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/jvs_22k-low" 5 | aux_path: "./data/jvs_22k" 6 | preprocessed_path: "./preprocessed/jvs" 7 | output_path: "./output/vocfeats/pretrain" 8 | test_wav_path: null 9 | feature_type: "vocfeats" 10 | hifigan_path: "./hifigan/hifigan_jvs_40d_600k" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 90 16 | n_val: 5 17 | n_test: 5 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | cep_order: 40 26 | f0_extractor: "dio" 27 | comp_factor: 1.0 28 | min_magnitude: 0.00001 29 | max_wav_value: 32768.0 30 | segment_length: -1 31 | 32 | train: 33 | batchsize: 8 34 | epoch: 50 35 | alpha: 0.1 36 | augment: True 37 | multi_gpu_mode: False 38 | num_workers: 4 39 | learning_rate: 0.005 40 | grad_clip_thresh: 1.0 41 | logger_step: 1000 42 | load_pretrained: False 43 | pretrained_path: null 44 | early_stopping: False 45 | multi_scale_loss: 46 | use_linear: True 47 | gamma: 1.0 48 | feature_loss: 49 | type: "mae" -------------------------------------------------------------------------------- /configs/test/vocfeats/ssl_jsut.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "ssl" 3 | corpus_type: "single" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/jsut_22k-low" 5 | aux_path: "./data/jsut_22k" 6 | preprocessed_path: "./preprocessed/jsut-low" 7 | output_path: "./output/vocfeats/jsut-low" 8 | test_wav_path: null 9 | feature_type: "vocfeats" 10 | hifigan_path: "./hifigan/hifigan_jvs_40d_600k" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 4950 16 | n_val: 25 17 | n_test: 25 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | cep_order: 40 26 | f0_extractor: "harvest" 27 | comp_factor: 1.0 28 | min_magnitude: 0.00001 29 | bitrate: "16k" 30 | max_wav_value: 32768.0 31 | segment_length: -1 32 | 33 | train: 34 | batchsize: 1 35 | epoch: 50 36 | epoch_channel: 25 37 | multi_gpu_mode: False 38 | num_workers: 4 39 | learning_rate: 0.001 40 | alpha: 0.1 41 | beta: 0.001 42 | augment: False 43 | grad_clip_thresh: 1.0 44 | logger_step: 1000 45 | load_pretrained: False 46 | pretrained_path: null 47 | fix_channel: False 48 | early_stopping: False 49 | multi_scale_loss: 50 | 
use_linear: True 51 | gamma: 1.0 52 | feature_loss: 53 | type: "mae" 54 | 55 | dual: 56 | enable: True 57 | config_path: ./configs/test/vocfeats/dual.yaml -------------------------------------------------------------------------------- /configs/test/vocfeats/ssl_tono.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "ssl" 3 | corpus_type: "single" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/tono" 5 | aux_path: null 6 | preprocessed_path: "./preprocessed/tono-denoise" 7 | output_path: "./output/vocfeats/tono-denoise" 8 | test_wav_path: null 9 | feature_type: "vocfeats" 10 | hifigan_path: "./hifigan/hifigan_jvs_40d_600k" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 270 16 | n_val: 34 17 | n_test: 30 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | cep_order: 40 26 | comp_factor: 1.0 27 | min_magnitude: 0.00001 28 | bitrate: "16k" 29 | f0_extractor: "harvest" 30 | max_wav_value: 32768.0 31 | segment_length: -1 32 | 33 | train: 34 | batchsize: 4 35 | epoch: 50 36 | epoch_channel: 25 37 | multi_gpu_mode: False 38 | num_workers: 4 39 | learning_rate: 0.001 40 | alpha: 0.1 41 | beta: 0.001 42 | grad_clip_thresh: 1.0 43 | logger_step: 1000 44 | load_pretrained: False 45 | pretrained_path: null 46 | fix_channel: False 47 | early_stopping: False 48 | multi_scale_loss: 49 | use_linear: True 50 | gamma: 1.0 51 | feature_loss: 52 | type: "mae" 53 | 54 | dual: 55 | enable: True 56 | config_path: ./configs/train/vocfeats/dual.yaml -------------------------------------------------------------------------------- /configs/train/melspec/dual.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "pretrain" 3 | corpus_type: "multi-unseen" 4 | source_path: "./data/jvs_22k" 5 | aux_path: null 6 | preprocessed_path: "./preprocessed/dual" 7 | preprocess: 8 | n_train: 90 9 | n_val: 5 10 | n_test: 5 11 | sampling_rate: 22050 12 | segment_length: 2 13 | -------------------------------------------------------------------------------- /configs/train/melspec/pretrain_jvs.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "pretrain" 3 | corpus_type: "multi-unseen" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/jvs_22k-low" 5 | aux_path: "./data/jvs_22k" 6 | preprocessed_path: "./preprocessed/jvs" 7 | output_path: "./output/melspec/pretrain" 8 | test_wav_path: null 9 | feature_type: "melspec" 10 | hifigan_path: "./hifigan/hifigan_melspec_universal" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 90 16 | n_val: 5 17 | n_test: 5 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | comp_factor: 1.0 26 | min_magnitude: 0.00001 27 | max_wav_value: 32768.0 28 | segment_length: 2 29 | 30 | train: 31 | batchsize: 8 32 | epoch: 50 33 | alpha: 0.1 34 | augment: True 35 | multi_gpu_mode: False 36 | num_workers: 4 37 | learning_rate: 0.005 38 | grad_clip_thresh: 1.0 39 | logger_step: 1000 40 | load_pretrained: False 41 | pretrained_path: null 42 | early_stopping: False 43 | multi_scale_loss: 44 | use_linear: False 45 | gamma: 1.0 46 | feature_loss: 47 | type: "mae" -------------------------------------------------------------------------------- /configs/train/melspec/ssl_jsut.yaml:
-------------------------------------------------------------------------------- 1 | general: 2 | stage: "ssl" 3 | corpus_type: "single" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/jsut_22k-low" 5 | aux_path: "./data/jsut_22k" 6 | preprocessed_path: "./preprocessed/jsut-low" 7 | output_path: "./output/melspec/jsut-low" 8 | test_wav_path: null 9 | feature_type: "melspec" 10 | hifigan_path: "./hifigan/hifigan_melspec_universal" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 4950 16 | n_val: 25 17 | n_test: 25 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | comp_factor: 1.0 26 | min_magnitude: 0.00001 27 | bitrate: "16k" 28 | max_wav_value: 32768.0 29 | segment_length: 2 30 | 31 | train: 32 | batchsize: 4 33 | epoch: 50 34 | epoch_channel: 25 35 | multi_gpu_mode: False 36 | num_workers: 4 37 | learning_rate: 0.001 38 | alpha: 0.1 39 | beta: 0.001 40 | grad_clip_thresh: 1.0 41 | logger_step: 1000 42 | load_pretrained: True 43 | pretrained_path: null 44 | fix_channel: False 45 | early_stopping: False 46 | multi_scale_loss: 47 | use_linear: True 48 | gamma: 1.0 49 | feature_loss: 50 | type: "mae" 51 | 52 | dual: 53 | enable: True 54 | config_path: ./configs/train/melspec/dual.yaml -------------------------------------------------------------------------------- /configs/train/melspec/ssl_tono.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "ssl" 3 | corpus_type: "single" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/tono" 5 | aux_path: null 6 | preprocessed_path: "./preprocessed/tono" 7 | output_path: "./output/melspec/tono" 8 | test_wav_path: null 9 | feature_type: "melspec" 10 | hifigan_path: "./hifigan/hifigan_melspec_universal" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 270 16 | n_val: 34 17 | n_test: 30 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | comp_factor: 1.0 26 | min_magnitude: 0.00001 27 | bitrate: "16k" 28 | max_wav_value: 32768.0 29 | segment_length: 2 30 | 31 | train: 32 | batchsize: 4 33 | epoch: 50 34 | epoch_channel: 25 35 | multi_gpu_mode: False 36 | num_workers: 4 37 | learning_rate: 0.001 38 | alpha: 0.1 39 | beta: 0.001 40 | grad_clip_thresh: 1.0 41 | logger_step: 1000 42 | load_pretrained: False 43 | pretrained_path: null 44 | fix_channel: False 45 | early_stopping: False 46 | multi_scale_loss: 47 | use_linear: True 48 | gamma: 1.0 49 | feature_loss: 50 | type: "mae" 51 | 52 | dual: 53 | enable: True 54 | config_path: ./configs/train/melspec/dual.yaml -------------------------------------------------------------------------------- /configs/train/vocfeats/dual.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "pretrain" 3 | corpus_type: "multi-unseen" 4 | source_path: "./data/jvs_22k" 5 | aux_path: null 6 | preprocessed_path: "./preprocessed/dual" 7 | preprocess: 8 | n_train: 90 9 | n_val: 5 10 | n_test: 5 11 | sampling_rate: 22050 12 | segment_length: 2 -------------------------------------------------------------------------------- /configs/train/vocfeats/pretrain_jvs.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "pretrain" 3 | corpus_type: "multi-unseen" # (single, multi-seen, multi-unseen) 4 | 
source_path: "./data/jvs_22k-low" 5 | aux_path: "./data/jvs_22k" 6 | preprocessed_path: "./preprocessed/jvs" 7 | output_path: "./output/vocfeats/pretrain" 8 | test_wav_path: null 9 | feature_type: "vocfeats" 10 | hifigan_path: "./hifigan/hifigan_jvs_40d_600k" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 90 16 | n_val: 5 17 | n_test: 5 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | cep_order: 40 26 | f0_extractor: "dio" 27 | comp_factor: 1.0 28 | min_magnitude: 0.00001 29 | max_wav_value: 32768.0 30 | segment_length: 2 31 | 32 | train: 33 | batchsize: 8 34 | epoch: 50 35 | alpha: 0.1 36 | augment: True 37 | multi_gpu_mode: False 38 | num_workers: 4 39 | learning_rate: 0.005 40 | grad_clip_thresh: 1.0 41 | logger_step: 1000 42 | load_pretrained: False 43 | pretrained_path: null 44 | early_stopping: False 45 | multi_scale_loss: 46 | use_linear: True 47 | gamma: 1.0 48 | feature_loss: 49 | type: "mae" -------------------------------------------------------------------------------- /configs/train/vocfeats/ssl_jsut.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "ssl" 3 | corpus_type: "single" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/jsut_22k-low" 5 | aux_path: "./data/jsut_22k" 6 | preprocessed_path: "./preprocessed/jsut-low" 7 | output_path: "./output/vocfeats/jsut-low" 8 | test_wav_path: null 9 | feature_type: "vocfeats" 10 | hifigan_path: "./hifigan/hifigan_jvs_40d_600k" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 4950 16 | n_val: 25 17 | n_test: 25 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | cep_order: 40 26 | comp_factor: 1.0 27 | min_magnitude: 0.00001 28 | bitrate: "16k" 29 | f0_extractor: "harvest" 30 | max_wav_value: 32768.0 31 | segment_length: 2 32 | 33 | train: 34 | batchsize: 4 35 | epoch: 50 36 | epoch_channel: 25 37 | multi_gpu_mode: False 38 | num_workers: 4 39 | learning_rate: 0.001 40 | alpha: 0.1 41 | beta: 0.001 42 | grad_clip_thresh: 1.0 43 | logger_step: 1000 44 | load_pretrained: True 45 | pretrained_path: null 46 | fix_channel: False 47 | early_stopping: False 48 | multi_scale_loss: 49 | use_linear: True 50 | gamma: 1.0 51 | feature_loss: 52 | type: "mae" 53 | 54 | dual: 55 | enable: True 56 | config_path: ./configs/train/vocfeats/dual.yaml -------------------------------------------------------------------------------- /configs/train/vocfeats/ssl_tono.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | stage: "ssl" 3 | corpus_type: "single" # (single, multi-seen, multi-unseen) 4 | source_path: "./data/tono" 5 | aux_path: null 6 | preprocessed_path: "./preprocessed/tono" 7 | output_path: "./output/vocfeats/tono" 8 | test_wav_path: null 9 | feature_type: "vocfeats" 10 | hifigan_path: "./hifigan/hifigan_jvs_40d_600k" 11 | power_norm: True 12 | use_gst: False 13 | 14 | preprocess: 15 | n_train: 270 16 | n_val: 34 17 | n_test: 30 18 | sampling_rate: 22050 19 | frame_length: 1024 20 | frame_shift: 256 21 | fft_length: 1024 22 | fmin: 0 23 | fmax: 8000 24 | n_mels: 80 25 | cep_order: 40 26 | comp_factor: 1.0 27 | min_magnitude: 0.00001 28 | bitrate: "16k" 29 | f0_extractor: "harvest" 30 | max_wav_value: 32768.0 31 | segment_length: 2 32 | 33 | train: 34 | batchsize: 4 35 | epoch: 50 
36 | epoch_channel: 25 37 | multi_gpu_mode: False 38 | num_workers: 4 39 | learning_rate: 0.001 40 | alpha: 0.1 41 | beta: 0.001 42 | grad_clip_thresh: 1.0 43 | logger_step: 1000 44 | load_pretrained: False 45 | pretrained_path: null 46 | fix_channel: False 47 | early_stopping: False 48 | multi_scale_loss: 49 | use_linear: True 50 | gamma: 1.0 51 | feature_loss: 52 | type: "mae" 53 | 54 | dual: 55 | enable: True 56 | config_path: ./configs/train/vocfeats/dual.yaml -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pathlib 3 | import torch 4 | from torch.utils.data.dataloader import DataLoader 5 | import pytorch_lightning as pl 6 | import numpy as np 7 | import yaml 8 | import torchaudio 9 | import pyworld 10 | import pysptk 11 | import random 12 | 13 | 14 | class DataModule(pl.LightningDataModule): 15 | def __init__(self, config): 16 | super().__init__() 17 | self.config = config 18 | self.batchsize = config["train"]["batchsize"] 19 | self.preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"]) 20 | 21 | def setup(self, stage): 22 | 23 | if not self.preprocessed_dir.exists(): 24 | raise RuntimeError("Preprocessed directory was not be found") 25 | 26 | if "dual" in self.config: 27 | if self.config["dual"]["enable"]: 28 | task_config = yaml.load( 29 | open(self.config["dual"]["config_path"], "r"), 30 | Loader=yaml.FullLoader, 31 | ) 32 | task_preprocessed_dir = ( 33 | self.preprocessed_dir.parent 34 | / pathlib.Path(task_config["general"]["preprocessed_path"]).name 35 | ) 36 | if not task_preprocessed_dir.exists(): 37 | raise RuntimeError( 38 | "Preprocessed directory for multi-task learning was not be found" 39 | ) 40 | 41 | self.flnames = { 42 | "train": "train.txt", 43 | "val": "val.txt", 44 | "test": "test.txt", 45 | } 46 | 47 | def get_ds(self, phase): 48 | ds = Dataset(self.flnames[phase], self.config) 49 | return ds 50 | 51 | def get_loader(self, phase): 52 | ds = self.get_ds(phase) 53 | dl = DataLoader( 54 | ds, 55 | self.batchsize, 56 | shuffle=True if phase == "train" else False, 57 | num_workers=self.config["train"]["num_workers"], 58 | drop_last=True, 59 | ) 60 | return dl 61 | 62 | def train_dataloader(self): 63 | return self.get_loader(phase="train") 64 | 65 | def val_dataloader(self): 66 | return self.get_loader(phase="val") 67 | 68 | def test_dataloader(self): 69 | return self.get_loader(phase="test") 70 | 71 | 72 | class Dataset(torch.utils.data.Dataset): 73 | def __init__(self, filetxt, config): 74 | 75 | self.preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"]) 76 | self.config = config 77 | self.spec_module = torchaudio.transforms.MelSpectrogram( 78 | sample_rate=config["preprocess"]["sampling_rate"], 79 | n_fft=config["preprocess"]["fft_length"], 80 | win_length=config["preprocess"]["frame_length"], 81 | hop_length=config["preprocess"]["frame_shift"], 82 | f_min=config["preprocess"]["fmin"], 83 | f_max=config["preprocess"]["fmax"], 84 | n_mels=config["preprocess"]["n_mels"], 85 | power=1, 86 | center=True, 87 | norm="slaney", 88 | mel_scale="slaney", 89 | ) 90 | self.resample_candidate = [8000, 11025, 12000, 16000] 91 | self.quantization_candidate = range(2 ** 6, 2 ** 10 + 2, 2) 92 | self.segment_length = config["preprocess"]["segment_length"] 93 | 94 | with open(self.preprocessed_dir / filetxt, "r") as fr: 95 | self.filelist = [pathlib.Path(path.strip("\n")) for path in fr] 96 | 97 | 
self.d_out = dict() 98 | for item in ["wavs", "wavsaux"]: 99 | self.d_out[item] = [] 100 | 101 | for wp in self.filelist: 102 | 103 | if config["general"]["corpus_type"] == "single": 104 | basename = str(wp.stem) 105 | else: 106 | basename = str(wp.parent.name) + "-" + str(wp.stem) 107 | 108 | with open(self.preprocessed_dir / "{}.pickle".format(basename), "rb") as fw: 109 | d_preprocessed = pickle.load(fw) 110 | 111 | for item in ["wavs", "wavsaux"]: 112 | try: 113 | self.d_out[item].extend(d_preprocessed[item]) 114 | except: 115 | pass 116 | 117 | for item in ["wavs", "wavsaux"]: 118 | if self.d_out[item] != None: 119 | self.d_out[item] = np.asarray(self.d_out[item]) 120 | 121 | if "dual" in self.config: 122 | if self.config["dual"]["enable"]: 123 | task_config = yaml.load( 124 | open(config["dual"]["config_path"], "r"), 125 | Loader=yaml.FullLoader, 126 | ) 127 | task_preprocessed_dir = ( 128 | self.preprocessed_dir.parent 129 | / pathlib.Path(task_config["general"]["preprocessed_path"]).name 130 | ) 131 | with open(task_preprocessed_dir / filetxt, "r") as fr: 132 | task_filelist = [pathlib.Path(path.strip("\n")) for path in fr] 133 | self.d_out["wavstask"] = [] 134 | for wp in task_filelist: 135 | if task_config["general"]["corpus_type"] == "single": 136 | basename = str(wp.stem) 137 | else: 138 | basename = str(wp.parent.name) + "-" + str(wp.stem) 139 | with open( 140 | task_preprocessed_dir / "{}.pickle".format(basename), "rb" 141 | ) as fw: 142 | d_preprocessed = pickle.load(fw) 143 | self.d_out["wavstask"].extend(d_preprocessed["wavs"]) 144 | self.d_out["wavstask"] = np.asarray(self.d_out["wavstask"]) 145 | 146 | def __len__(self): 147 | return len(self.d_out["wavs"]) 148 | 149 | def __getitem__(self, idx): 150 | 151 | d_batch = {} 152 | 153 | if self.d_out["wavs"].size > 0: 154 | d_batch["wavs"] = torch.from_numpy(self.d_out["wavs"][idx]) 155 | 156 | if self.d_out["wavsaux"].size > 0: 157 | d_batch["wavsaux"] = torch.from_numpy(self.d_out["wavsaux"][idx]) 158 | 159 | if (self.d_out["wavs"].size > 0) & (self.segment_length > 0): 160 | if self.d_out["wavsaux"].size > 0: 161 | d_batch["wavs"], d_batch["wavsaux"] = self.get_segment( 162 | d_batch["wavs"], 163 | self.segment_length, 164 | d_batch["wavsaux"] 165 | ) 166 | else: 167 | d_batch["wavs"] = self.get_segment(d_batch["wavs"], self.segment_length) 168 | 169 | if self.config["general"]["stage"] == "pretrain": 170 | if self.config["train"]["augment"]: 171 | d_batch["wavs"] = self.augmentation(d_batch["wavsaux"]) 172 | d_batch["wavs"] = self.normalize_waveform(d_batch["wavs"], db=-3) 173 | d_batch["wavsaux"] = self.normalize_waveform(d_batch["wavsaux"], db=-3) 174 | if len(d_batch["wavs"]) != len(d_batch["wavsaux"]): 175 | min_seq_len = min(len(d_batch["wavs"]), len(d_batch["wavsaux"])) 176 | d_batch["wavs"] = d_batch["wavs"][:min_seq_len] 177 | d_batch["wavsaux"] = d_batch["wavsaux"][:min_seq_len] 178 | d_batch["melspecs"] = self.calc_spectrogram(d_batch["wavs"]) 179 | if self.config["general"]["feature_type"] == "melspec": 180 | d_batch["melspecsaux"] = self.calc_spectrogram(d_batch["wavsaux"]) 181 | elif self.config["general"]["feature_type"] == "vocfeats": 182 | d_batch["melceps"] = self.calc_melcep(d_batch["wavsaux"]) 183 | d_batch["f0s"] = self.calc_f0(d_batch["wavs"]) 184 | d_batch["melcepssrc"] = self.calc_melcep(d_batch["wavs"]) 185 | else: 186 | raise NotImplementedError() 187 | 188 | elif self.config["general"]["stage"].startswith("ssl"): 189 | d_batch["wavs"] = self.normalize_waveform(d_batch["wavs"], db=-3) 190 | 
d_batch["melspecs"] = self.calc_spectrogram(d_batch["wavs"]) 191 | if self.config["general"]["feature_type"] == "vocfeats": 192 | d_batch["f0s"] = self.calc_f0(d_batch["wavs"]) 193 | d_batch["melcepssrc"] = self.calc_melcep(d_batch["wavs"]) 194 | if self.d_out["wavsaux"].size > 0: 195 | d_batch["wavsaux"] = self.normalize_waveform(d_batch["wavsaux"], db=-3) 196 | if self.config["general"]["feature_type"] == "melspec": 197 | d_batch["melspecsaux"] = self.calc_spectrogram(d_batch["wavsaux"]) 198 | elif self.config["general"]["feature_type"] == "vocfeats": 199 | d_batch["melceps"] = self.calc_melcep(d_batch["wavsaux"]) 200 | if "dual" in self.config: 201 | if self.config["dual"]["enable"]: 202 | rand_idx = random.randint(0, len(self.d_out["wavstask"]) - 1) 203 | d_batch["wavstask"] = torch.from_numpy(self.d_out["wavstask"][rand_idx]) 204 | if self.segment_length > 0: 205 | d_batch["wavstask"] = self.get_segment( 206 | d_batch["wavstask"], self.segment_length 207 | ) 208 | d_batch["wavstask"] = self.normalize_waveform( 209 | d_batch["wavstask"], db=-3 210 | ) 211 | if self.config["general"]["feature_type"] == "melspec": 212 | d_batch["melspecstask"] = self.calc_spectrogram( 213 | d_batch["wavstask"] 214 | ) 215 | elif self.config["general"]["feature_type"] == "vocfeats": 216 | d_batch["melcepstask"] = self.calc_melcep(d_batch["wavstask"]) 217 | else: 218 | raise NotImplementedError() 219 | else: 220 | raise NotImplementedError() 221 | 222 | return d_batch 223 | 224 | def calc_spectrogram(self, wav): 225 | specs = self.spec_module(wav) 226 | log_spec = torch.log( 227 | torch.clamp_min(specs, self.config["preprocess"]["min_magnitude"]) 228 | * self.config["preprocess"]["comp_factor"] 229 | ).to(torch.float32) 230 | return log_spec 231 | 232 | def calc_melcep(self, wav): 233 | wav = wav.numpy() 234 | _, sp, _ = pyworld.wav2world( 235 | wav.astype(np.float64), 236 | self.config["preprocess"]["sampling_rate"], 237 | fft_size=self.config["preprocess"]["fft_length"], 238 | frame_period=( 239 | self.config["preprocess"]["frame_shift"] 240 | / self.config["preprocess"]["sampling_rate"] 241 | * 1000 242 | ), 243 | ) 244 | melcep = pysptk.sp2mc( 245 | sp, 246 | order=self.config["preprocess"]["cep_order"], 247 | alpha=pysptk.util.mcepalpha(self.config["preprocess"]["sampling_rate"]), 248 | ).transpose(1, 0) 249 | melcep = torch.from_numpy(melcep).to(torch.float32) 250 | return melcep 251 | 252 | def calc_f0(self, wav): 253 | if self.config["preprocess"]["f0_extractor"] == "dio": 254 | return self.calc_f0_dio(wav) 255 | elif self.config["preprocess"]["f0_extractor"] == "harvest": 256 | return self.calc_f0_harvest(wav) 257 | elif self.config["preprocess"]["f0_extractor"] == "swipe": 258 | return self.calc_f0_swipe(wav) 259 | else: 260 | raise NotImplementedError() 261 | 262 | def calc_f0_dio(self, wav): 263 | wav = wav.numpy() 264 | _f0, _t = pyworld.dio( 265 | wav.astype(np.float64), 266 | self.config["preprocess"]["sampling_rate"], 267 | frame_period=( 268 | self.config["preprocess"]["frame_shift"] 269 | / self.config["preprocess"]["sampling_rate"] 270 | * 1000 271 | ), 272 | ) 273 | f0 = pyworld.stonemask( 274 | wav.astype(np.float64), _f0, _t, self.config["preprocess"]["sampling_rate"] 275 | ) 276 | f0 = torch.from_numpy(f0).to(torch.float32) 277 | return f0 278 | 279 | def calc_f0_harvest(self, wav): 280 | wav = wav.numpy() 281 | _f0, _t = pyworld.harvest( 282 | wav.astype(np.float64), 283 | self.config["preprocess"]["sampling_rate"], 284 | frame_period=( 285 | self.config["preprocess"]["frame_shift"] 
286 | / self.config["preprocess"]["sampling_rate"] 287 | * 1000 288 | ), 289 | ) 290 | f0 = pyworld.stonemask( 291 | wav.astype(np.float64), _f0, _t, self.config["preprocess"]["sampling_rate"] 292 | ) 293 | f0 = torch.from_numpy(f0).to(torch.float32) 294 | return f0 295 | 296 | def calc_f0_swipe(self, wav): 297 | wav = wav.numpy() 298 | f0 = pysptk.sptk.swipe( 299 | wav.astype(np.float64), 300 | fs=self.config["preprocess"]["sampling_rate"], 301 | min=71, 302 | max=800, 303 | hopsize=self.config["preprocess"]["frame_shift"], 304 | otype="f0", 305 | ) 306 | f0 = torch.from_numpy(f0).to(torch.float32) 307 | return f0 308 | 309 | def augmentation(self, wav): 310 | wav /= torch.max(torch.abs(wav)) 311 | new_freq = random.choice(self.resample_candidate) 312 | new_quantization = random.choice(self.quantization_candidate) 313 | mulaw_encoder = torchaudio.transforms.MuLawEncoding( 314 | quantization_channels=new_quantization 315 | ) 316 | wav_quantized = mulaw_encoder(wav) / new_quantization * 2.0 - 1.0 317 | downsampler = torchaudio.transforms.Resample( 318 | orig_freq=self.config["preprocess"]["sampling_rate"], 319 | new_freq=new_freq, 320 | resampling_method="sinc_interpolation", 321 | lowpass_filter_width=6, 322 | dtype=torch.float32, 323 | ) 324 | upsampler = torchaudio.transforms.Resample( 325 | orig_freq=new_freq, 326 | new_freq=self.config["preprocess"]["sampling_rate"], 327 | resampling_method="sinc_interpolation", 328 | lowpass_filter_width=6, 329 | dtype=torch.float32, 330 | ) 331 | wav_processed = upsampler(downsampler(wav_quantized)) 332 | return wav_processed 333 | 334 | def normalize_waveform(self, wav, db=-3): 335 | wav, _ = torchaudio.sox_effects.apply_effects_tensor( 336 | wav.unsqueeze(0), 337 | self.config["preprocess"]["sampling_rate"], 338 | [["norm", "{}".format(db)]], 339 | ) 340 | return wav.squeeze(0) 341 | 342 | def get_segment(self, wav, segment_length, wavaux=None): 343 | seg_size = self.config["preprocess"]["sampling_rate"] * segment_length 344 | if len(wav) >= seg_size: 345 | max_wav_start = len(wav) - seg_size 346 | wav_start = random.randint(0, max_wav_start) 347 | wav = wav[wav_start : wav_start + seg_size] 348 | if wavaux != None: 349 | wavaux = wavaux[wav_start : wav_start + seg_size] 350 | else: 351 | wav = torch.nn.functional.pad(wav, (0, seg_size - len(wav)), "constant") 352 | if wavaux != None: 353 | wavaux = torch.nn.functional.pad(wavaux, (0, seg_size - len(wavaux)), "constant") 354 | if wavaux != None: 355 | return wav, wavaux 356 | else: 357 | return wav 358 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pathlib 4 | import yaml 5 | 6 | from pytorch_lightning import Trainer 7 | from pytorch_lightning.loggers.csv_logs import CSVLogger 8 | 9 | from dataset import DataModule 10 | from lightning_module import ( 11 | PretrainLightningModule, 12 | SSLStepLightningModule, 13 | SSLDualLightningModule, 14 | ) 15 | 16 | 17 | def get_arg(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--config_path", required=True, type=pathlib.Path) 20 | parser.add_argument("--ckpt_path", required=True, type=pathlib.Path) 21 | parser.add_argument( 22 | "--stage", required=True, type=str, choices=["pretrain", "ssl-step", "ssl-dual"] 23 | ) 24 | parser.add_argument("--run_name", required=True, type=str) 25 | return parser.parse_args() 26 | 27 | 28 | def eval(args, config, output_path): 29 | 30 
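# Usage sketch (illustrative; the YAML and checkpoint paths are placeholders,
# not files shipped with the repository):
#
#     python eval.py \
#         --config_path <path/to/test_config.yaml> \
#         --ckpt_path <path/to/trained_model.ckpt> \
#         --stage ssl-dual \
#         --run_name my_eval_run
#
# --stage must be one of {pretrain, ssl-step, ssl-dual}; its value overwrites
# config["general"]["stage"], which is then used below to choose the matching
# LightningModule before trainer.test() is run on the DataModule.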
| csvlogger = CSVLogger(save_dir=output_path, name="test_log") 31 | trainer = Trainer( 32 | gpus=-1, 33 | deterministic=False, 34 | auto_select_gpus=True, 35 | benchmark=True, 36 | logger=[csvlogger], 37 | default_root_dir=os.getcwd(), 38 | ) 39 | 40 | if config["general"]["stage"] == "pretrain": 41 | model = PretrainLightningModule(config).load_from_checkpoint( 42 | checkpoint_path=args.ckpt_path, config=config 43 | ) 44 | elif config["general"]["stage"] == "ssl-step": 45 | model = SSLStepLightningModule(config).load_from_checkpoint( 46 | checkpoint_path=args.ckpt_path, config=config 47 | ) 48 | elif config["general"]["stage"] == "ssl-dual": 49 | model = SSLDualLightningModule(config).load_from_checkpoint( 50 | checkpoint_path=args.ckpt_path, config=config 51 | ) 52 | else: 53 | raise NotImplementedError() 54 | 55 | datamodule = DataModule(config) 56 | trainer.test(model=model, verbose=True, datamodule=datamodule) 57 | 58 | 59 | if __name__ == "__main__": 60 | args = get_arg() 61 | config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader) 62 | output_path = str(pathlib.Path(config["general"]["output_path"]) / args.run_name) 63 | config["general"]["stage"] = str(getattr(args, "stage")) 64 | 65 | eval(args, config, output_path) 66 | -------------------------------------------------------------------------------- /hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
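Illustrative sketch (not part of the repository): the HiFi-GAN Generator defined in hifigan/models.py is typically built from one of the bundled JSON configs through the AttrDict wrapper in hifigan/__init__.py. The checkpoint filename and the "generator" state-dict key below are assumptions following the upstream HiFi-GAN checkpoint layout; the loader actually used in this repository is utils.load_vocoder.

import json
import torch
from hifigan import AttrDict, Generator

# The hyperparameters are plain JSON; AttrDict exposes the keys as attributes,
# which is what Generator(h) reads (h.feat_order, h.upsample_rates, ...).
with open("hifigan/config_melspec.json") as f:
    h = AttrDict(json.load(f))

generator = Generator(h)

# Hypothetical checkpoint path; upstream HiFi-GAN checkpoints keep the
# generator weights under the "generator" key.
ckpt = torch.load("hifigan/generator_melspec.ckpt", map_location="cpu")
generator.load_state_dict(ckpt["generator"])
generator.eval()
generator.remove_weight_norm()

# The four upsampling stages (8 * 8 * 2 * 2 = 256) match hop_size, so
# 100 feature frames become 100 * 256 = 25600 waveform samples.
feats = torch.randn(1, h.feat_order, 100)  # (batch, feat_order, frames)
with torch.no_grad():
    wav = generator(feats)                 # (batch, 1, samples)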
-------------------------------------------------------------------------------- /hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Generator 2 | 3 | 4 | class AttrDict(dict): 5 | def __init__(self, *args, **kwargs): 6 | super(AttrDict, self).__init__(*args, **kwargs) 7 | self.__dict__ = self -------------------------------------------------------------------------------- /hifigan/config_melspec.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "feat_order": 80, 27 | 28 | "fmin": 0, 29 | "fmax": 8000, 30 | "fmax_for_loss": null, 31 | 32 | "num_workers": 4, 33 | 34 | "dist_config": { 35 | "dist_backend": "nccl", 36 | "dist_url": "tcp://localhost:54321", 37 | "world_size": 1 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /hifigan/config_vocfeats.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "feat_order": 42, 27 | 28 | "fmin": 0, 29 | "fmax": 8000, 30 | "fmax_for_loss": null, 31 | 32 | "num_workers": 4, 33 | 34 | "dist_config": { 35 | "dist_backend": "nccl", 36 | "dist_url": "tcp://localhost:54321", 37 | "world_size": 1 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | def init_weights(m, mean=0.0, std=0.01): 10 | classname = m.__class__.__name__ 11 | if classname.find("Conv") != -1: 12 | m.weight.data.normal_(mean, std) 13 | 14 | 15 | def get_padding(kernel_size, dilation=1): 16 | return int((kernel_size * dilation - dilation) / 2) 17 | 18 | 19 | class ResBlock1(torch.nn.Module): 20 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 21 | super(ResBlock1, self).__init__() 22 | self.h = h 23 | self.convs1 = nn.ModuleList([ 24 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 25 | padding=get_padding(kernel_size, dilation[0]))), 26 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 27 | 
padding=get_padding(kernel_size, dilation[1]))), 28 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 29 | padding=get_padding(kernel_size, dilation[2]))) 30 | ]) 31 | self.convs1.apply(init_weights) 32 | 33 | self.convs2 = nn.ModuleList([ 34 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 35 | padding=get_padding(kernel_size, 1))), 36 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 37 | padding=get_padding(kernel_size, 1))), 38 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 39 | padding=get_padding(kernel_size, 1))) 40 | ]) 41 | self.convs2.apply(init_weights) 42 | 43 | def forward(self, x): 44 | for c1, c2 in zip(self.convs1, self.convs2): 45 | xt = F.leaky_relu(x, LRELU_SLOPE) 46 | xt = c1(xt) 47 | xt = F.leaky_relu(xt, LRELU_SLOPE) 48 | xt = c2(xt) 49 | x = xt + x 50 | return x 51 | 52 | def remove_weight_norm(self): 53 | for l in self.convs1: 54 | remove_weight_norm(l) 55 | for l in self.convs2: 56 | remove_weight_norm(l) 57 | 58 | 59 | class ResBlock2(torch.nn.Module): 60 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 61 | super(ResBlock2, self).__init__() 62 | self.h = h 63 | self.convs = nn.ModuleList([ 64 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 65 | padding=get_padding(kernel_size, dilation[0]))), 66 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 67 | padding=get_padding(kernel_size, dilation[1]))) 68 | ]) 69 | self.convs.apply(init_weights) 70 | 71 | def forward(self, x): 72 | for c in self.convs: 73 | xt = F.leaky_relu(x, LRELU_SLOPE) 74 | xt = c(xt) 75 | x = xt + x 76 | return x 77 | 78 | def remove_weight_norm(self): 79 | for l in self.convs: 80 | remove_weight_norm(l) 81 | 82 | 83 | class Generator(torch.nn.Module): 84 | def __init__(self, h): 85 | super(Generator, self).__init__() 86 | self.h = h 87 | self.num_kernels = len(h.resblock_kernel_sizes) 88 | self.num_upsamples = len(h.upsample_rates) 89 | self.conv_pre = weight_norm(Conv1d(h.feat_order, h.upsample_initial_channel, 7, 1, padding=3)) 90 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 91 | 92 | self.ups = nn.ModuleList() 93 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 94 | self.ups.append(weight_norm( 95 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), 96 | k, u, padding=(k-u)//2))) 97 | 98 | self.resblocks = nn.ModuleList() 99 | for i in range(len(self.ups)): 100 | ch = h.upsample_initial_channel//(2**(i+1)) 101 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 102 | self.resblocks.append(resblock(h, ch, k, d)) 103 | 104 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 105 | self.ups.apply(init_weights) 106 | self.conv_post.apply(init_weights) 107 | 108 | def forward(self, x): 109 | x = self.conv_pre(x) 110 | for i in range(self.num_upsamples): 111 | x = F.leaky_relu(x, LRELU_SLOPE) 112 | x = self.ups[i](x) 113 | xs = None 114 | for j in range(self.num_kernels): 115 | if xs is None: 116 | xs = self.resblocks[i*self.num_kernels+j](x) 117 | else: 118 | xs += self.resblocks[i*self.num_kernels+j](x) 119 | x = xs / self.num_kernels 120 | x = F.leaky_relu(x) 121 | x = self.conv_post(x) 122 | x = torch.tanh(x) 123 | 124 | return x 125 | 126 | def remove_weight_norm(self): 127 | print('Removing weight norm...') 128 | for l in self.ups: 129 | remove_weight_norm(l) 130 | for l in self.resblocks: 131 | 
l.remove_weight_norm() 132 | remove_weight_norm(self.conv_pre) 133 | remove_weight_norm(self.conv_post) -------------------------------------------------------------------------------- /imgs/method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/ssl_speech_restoration/32c420a346c9169b710eb9bd33c0fe0462dd2ccb/imgs/method.jpg -------------------------------------------------------------------------------- /lightning_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torchaudio 4 | import os 5 | import pathlib 6 | import tqdm 7 | from model import ( 8 | EncoderModule, 9 | ChannelFeatureModule, 10 | ChannelModule, 11 | MultiScaleSpectralLoss, 12 | GSTModule, 13 | ) 14 | from utils import ( 15 | manual_logging, 16 | load_vocoder, 17 | plot_and_save_mels, 18 | plot_and_save_mels_all, 19 | ) 20 | 21 | 22 | class PretrainLightningModule(pl.LightningModule): 23 | """ 24 | Supervised pretraining for low-resource settings. 25 | This module provides supervised pretraining for the analysis and channel modules. 26 | """ 27 | 28 | def __init__(self, config): 29 | super().__init__() 30 | self.save_hyperparameters() 31 | self.config = config 32 | if config["general"]["use_gst"]: 33 | self.encoder = EncoderModule(config) 34 | self.gst = GSTModule(config) 35 | else: 36 | self.encoder = EncoderModule(config, use_channel=True) 37 | self.channelfeats = ChannelFeatureModule(config) 38 | 39 | self.channel = ChannelModule(config) 40 | self.vocoder = load_vocoder(config) 41 | 42 | self.criteria_a = MultiScaleSpectralLoss(config) 43 | if "feature_loss" in config["train"]: 44 | if config["train"]["feature_loss"]["type"] == "mae": 45 | self.criteria_b = torch.nn.L1Loss() 46 | else: 47 | self.criteria_b = torch.nn.MSELoss() 48 | else: 49 | self.criteria = torch.nn.L1Loss() 50 | self.alpha = config["train"]["alpha"] 51 | 52 | def forward(self, melspecs, wavsaux): 53 | if self.config["general"]["use_gst"]: 54 | enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3)) 55 | chfeats = self.gst(melspecs.transpose(1, 2)) 56 | else: 57 | enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3)) 58 | chfeats = self.channelfeats(enc_hidden) 59 | enc_out = enc_out.squeeze(1).transpose(1, 2) 60 | wavsdeg = self.channel(wavsaux, chfeats) 61 | return enc_out, wavsdeg 62 | 63 | def training_step(self, batch, batch_idx): 64 | if self.config["general"]["use_gst"]: 65 | enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) 66 | chfeats = self.gst(batch["melspecs"].transpose(1, 2)) 67 | else: 68 | enc_out, enc_hidden = self.encoder( 69 | batch["melspecs"].unsqueeze(1).transpose(2, 3) 70 | ) 71 | chfeats = self.channelfeats(enc_hidden) 72 | enc_out = enc_out.squeeze(1).transpose(1, 2) 73 | wavsdeg = self.channel(batch["wavsaux"], chfeats) 74 | loss_recons = self.criteria_a(wavsdeg, batch["wavs"]) 75 | if self.config["general"]["feature_type"] == "melspec": 76 | loss_encoder = self.criteria_b(enc_out, batch["melspecsaux"]) 77 | elif self.config["general"]["feature_type"] == "vocfeats": 78 | loss_encoder = self.criteria_b(enc_out, batch["melceps"]) 79 | loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder 80 | self.log( 81 | "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True 82 | ) 83 | self.log( 84 | "train_loss_recons", 85 | loss_recons, 86 | on_step=True, 87 | on_epoch=True, 
88 | prog_bar=True, 89 | logger=True, 90 | ) 91 | self.log( 92 | "train_loss_encoder", 93 | loss_encoder, 94 | on_step=True, 95 | on_epoch=True, 96 | prog_bar=True, 97 | logger=True, 98 | ) 99 | return loss 100 | 101 | def validation_step(self, batch, batch_idx): 102 | if self.config["general"]["use_gst"]: 103 | enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) 104 | chfeats = self.gst(batch["melspecs"].transpose(1, 2)) 105 | else: 106 | enc_out, enc_hidden = self.encoder( 107 | batch["melspecs"].unsqueeze(1).transpose(2, 3) 108 | ) 109 | chfeats = self.channelfeats(enc_hidden) 110 | enc_out = enc_out.squeeze(1).transpose(1, 2) 111 | wavsdeg = self.channel(batch["wavsaux"], chfeats) 112 | loss_recons = self.criteria_a(wavsdeg, batch["wavs"]) 113 | if self.config["general"]["feature_type"] == "melspec": 114 | val_aux_feats = batch["melspecsaux"] 115 | feats_name = "melspec" 116 | loss_encoder = self.criteria_b(enc_out, val_aux_feats) 117 | elif self.config["general"]["feature_type"] == "vocfeats": 118 | val_aux_feats = batch["melceps"] 119 | feats_name = "melcep" 120 | loss_encoder = self.criteria_b(enc_out, val_aux_feats) 121 | loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder 122 | logger_img_dict = { 123 | "val_src_melspec": batch["melspecs"], 124 | "val_pred_{}".format(feats_name): enc_out, 125 | "val_aux_{}".format(feats_name): val_aux_feats, 126 | } 127 | logger_wav_dict = { 128 | "val_src_wav": batch["wavs"], 129 | "val_pred_wav": wavsdeg, 130 | "val_aux_wav": batch["wavsaux"], 131 | } 132 | return { 133 | "val_loss": loss, 134 | "val_loss_recons": loss_recons, 135 | "val_loss_encoder": loss_encoder, 136 | "logger_dict": [logger_img_dict, logger_wav_dict], 137 | } 138 | 139 | def validation_epoch_end(self, outputs): 140 | val_loss = torch.stack([out["val_loss"] for out in outputs]).mean().item() 141 | val_loss_recons = ( 142 | torch.stack([out["val_loss_recons"] for out in outputs]).mean().item() 143 | ) 144 | val_loss_encoder = ( 145 | torch.stack([out["val_loss_encoder"] for out in outputs]).mean().item() 146 | ) 147 | self.log("val_loss", val_loss, on_epoch=True, prog_bar=True, logger=True) 148 | self.log( 149 | "val_loss_recons", 150 | val_loss_recons, 151 | on_epoch=True, 152 | prog_bar=True, 153 | logger=True, 154 | ) 155 | self.log( 156 | "val_loss_encoder", 157 | val_loss_encoder, 158 | on_epoch=True, 159 | prog_bar=True, 160 | logger=True, 161 | ) 162 | self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image") 163 | self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio") 164 | 165 | def test_step(self, batch, batch_idx): 166 | if self.config["general"]["use_gst"]: 167 | enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) 168 | chfeats = self.gst(batch["melspecs"].transpose(1, 2)) 169 | else: 170 | enc_out, enc_hidden = self.encoder( 171 | batch["melspecs"].unsqueeze(1).transpose(2, 3) 172 | ) 173 | chfeats = self.channelfeats(enc_hidden) 174 | enc_out = enc_out.squeeze(1).transpose(1, 2) 175 | wavsdeg = self.channel(batch["wavsaux"], chfeats) 176 | if self.config["general"]["feature_type"] == "melspec": 177 | enc_feats = enc_out 178 | enc_feats_aux = batch["melspecsaux"] 179 | elif self.config["general"]["feature_type"] == "vocfeats": 180 | enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) 181 | enc_feats_aux = torch.cat( 182 | (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1 183 | ) 184 | recons_wav = self.vocoder(enc_feats_aux).squeeze(1) 185 | remas = 
self.vocoder(enc_feats).squeeze(1) 186 | if self.config["general"]["feature_type"] == "melspec": 187 | enc_feats_input = batch["melspecs"] 188 | elif self.config["general"]["feature_type"] == "vocfeats": 189 | enc_feats_input = torch.cat( 190 | (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1 191 | ) 192 | input_recons = self.vocoder(enc_feats_input).squeeze(1) 193 | if "wavsaux" in batch: 194 | gt_wav = batch["wavsaux"] 195 | else: 196 | gt_wav = None 197 | return { 198 | "reconstructed": recons_wav, 199 | "remastered": remas, 200 | "channeled": wavsdeg, 201 | "groundtruth": gt_wav, 202 | "input": batch["wavs"], 203 | "input_recons": input_recons, 204 | } 205 | 206 | def test_epoch_end(self, outputs): 207 | wav_dir = ( 208 | pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs" 209 | ) 210 | os.makedirs(wav_dir, exist_ok=True) 211 | mel_dir = ( 212 | pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels" 213 | ) 214 | os.makedirs(mel_dir, exist_ok=True) 215 | print("Saving mel spectrogram plots ...") 216 | for idx, out in enumerate(tqdm.tqdm(outputs)): 217 | for key in [ 218 | "reconstructed", 219 | "remastered", 220 | "channeled", 221 | "input", 222 | "input_recons", 223 | "groundtruth", 224 | ]: 225 | if out[key] != None: 226 | torchaudio.save( 227 | wav_dir / "{}-{}.wav".format(idx, key), 228 | out[key][0, ...].unsqueeze(0).cpu(), 229 | sample_rate=self.config["preprocess"]["sampling_rate"], 230 | channels_first=True, 231 | ) 232 | plot_and_save_mels( 233 | out[key][0, ...].cpu(), 234 | mel_dir / "{}-{}.png".format(idx, key), 235 | self.config, 236 | ) 237 | plot_and_save_mels_all( 238 | out, 239 | [ 240 | "reconstructed", 241 | "remastered", 242 | "channeled", 243 | "input", 244 | "input_recons", 245 | "groundtruth", 246 | ], 247 | mel_dir / "{}-all.png".format(idx), 248 | self.config, 249 | ) 250 | 251 | def configure_optimizers(self): 252 | optimizer = torch.optim.Adam( 253 | self.parameters(), lr=self.config["train"]["learning_rate"] 254 | ) 255 | lr_scheduler_config = { 256 | "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( 257 | optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True 258 | ), 259 | "interval": "epoch", 260 | "frequency": 3, 261 | "monitor": "val_loss", 262 | } 263 | return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} 264 | 265 | def tflogger(self, logger_dict, data_type): 266 | for lg in self.logger.experiment: 267 | if type(lg).__name__ == "SummaryWriter": 268 | tensorboard = lg 269 | for key in logger_dict.keys(): 270 | manual_logging( 271 | logger=tensorboard, 272 | item=logger_dict[key], 273 | idx=0, 274 | tag=key, 275 | global_step=self.global_step, 276 | data_type=data_type, 277 | config=self.config, 278 | ) 279 | 280 | 281 | class SSLBaseModule(pl.LightningModule): 282 | def __init__(self, config): 283 | super().__init__() 284 | self.save_hyperparameters() 285 | self.config = config 286 | if config["general"]["use_gst"]: 287 | self.encoder = EncoderModule(config) 288 | self.gst = GSTModule(config) 289 | else: 290 | self.encoder = EncoderModule(config, use_channel=True) 291 | self.channelfeats = ChannelFeatureModule(config) 292 | self.channel = ChannelModule(config) 293 | 294 | if config["train"]["load_pretrained"]: 295 | pre_model = PretrainLightningModule.load_from_checkpoint( 296 | checkpoint_path=config["train"]["pretrained_path"] 297 | ) 298 | self.encoder.load_state_dict(pre_model.encoder.state_dict(), strict=False) 299 | 
self.channel.load_state_dict(pre_model.channel.state_dict(), strict=False) 300 | if config["general"]["use_gst"]: 301 | self.gst.load_state_dict(pre_model.gst.state_dict(), strict=False) 302 | else: 303 | self.channelfeats.load_state_dict( 304 | pre_model.channelfeats.state_dict(), strict=False 305 | ) 306 | 307 | self.vocoder = load_vocoder(config) 308 | self.criteria = self.get_loss_function(config) 309 | 310 | def training_step(self, batch, batch_idx): 311 | raise NotImplementedError() 312 | 313 | def validation_step(self, batch, batch_idx): 314 | raise NotImplementedError() 315 | 316 | def validation_epoch_end(self, outputs): 317 | raise NotImplementedError() 318 | 319 | def configure_optimizers(self): 320 | raise NotImplementedError() 321 | 322 | def get_loss_function(self, config): 323 | raise NotImplementedError() 324 | 325 | def forward(self, melspecs, f0s=None): 326 | if self.config["general"]["use_gst"]: 327 | enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3)) 328 | chfeats = self.gst(melspecs.transpose(1, 2)) 329 | else: 330 | enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3)) 331 | chfeats = self.channelfeats(enc_hidden) 332 | enc_out = enc_out.squeeze(1).transpose(1, 2) 333 | if self.config["general"]["feature_type"] == "melspec": 334 | enc_feats = enc_out 335 | elif self.config["general"]["feature_type"] == "vocfeats": 336 | enc_feats = torch.cat((f0s.unsqueeze(1), enc_out), dim=1) 337 | remas = self.vocoder(enc_feats).squeeze(1) 338 | wavsdeg = self.channel(remas, chfeats) 339 | return remas, wavsdeg 340 | 341 | def test_step(self, batch, batch_idx): 342 | if self.config["general"]["use_gst"]: 343 | enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) 344 | chfeats = self.gst(batch["melspecs"].transpose(1, 2)) 345 | else: 346 | enc_out, enc_hidden = self.encoder( 347 | batch["melspecs"].unsqueeze(1).transpose(2, 3) 348 | ) 349 | chfeats = self.channelfeats(enc_hidden) 350 | enc_out = enc_out.squeeze(1).transpose(1, 2) 351 | if self.config["general"]["feature_type"] == "melspec": 352 | enc_feats = enc_out 353 | elif self.config["general"]["feature_type"] == "vocfeats": 354 | enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) 355 | remas = self.vocoder(enc_feats).squeeze(1) 356 | wavsdeg = self.channel(remas, chfeats) 357 | if self.config["general"]["feature_type"] == "melspec": 358 | enc_feats_input = batch["melspecs"] 359 | elif self.config["general"]["feature_type"] == "vocfeats": 360 | enc_feats_input = torch.cat( 361 | (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1 362 | ) 363 | input_recons = self.vocoder(enc_feats_input).squeeze(1) 364 | if "wavsaux" in batch: 365 | gt_wav = batch["wavsaux"] 366 | if self.config["general"]["feature_type"] == "melspec": 367 | enc_feats_aux = batch["melspecsaux"] 368 | elif self.config["general"]["feature_type"] == "vocfeats": 369 | enc_feats_aux = torch.cat( 370 | (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1 371 | ) 372 | recons_wav = self.vocoder(enc_feats_aux).squeeze(1) 373 | else: 374 | gt_wav = None 375 | recons_wav = None 376 | return { 377 | "reconstructed": recons_wav, 378 | "remastered": remas, 379 | "channeled": wavsdeg, 380 | "input": batch["wavs"], 381 | "input_recons": input_recons, 382 | "groundtruth": gt_wav, 383 | } 384 | 385 | def test_epoch_end(self, outputs): 386 | wav_dir = ( 387 | pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs" 388 | ) 389 | os.makedirs(wav_dir, exist_ok=True) 390 | mel_dir = ( 391 | 
pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels" 392 | ) 393 | os.makedirs(mel_dir, exist_ok=True) 394 | print("Saving mel spectrogram plots ...") 395 | for idx, out in enumerate(tqdm.tqdm(outputs)): 396 | plot_keys = [] 397 | for key in [ 398 | "reconstructed", 399 | "remastered", 400 | "channeled", 401 | "input", 402 | "input_recons", 403 | "groundtruth", 404 | ]: 405 | if out[key] != None: 406 | plot_keys.append(key) 407 | torchaudio.save( 408 | wav_dir / "{}-{}.wav".format(idx, key), 409 | out[key][0, ...].unsqueeze(0).cpu(), 410 | sample_rate=self.config["preprocess"]["sampling_rate"], 411 | channels_first=True, 412 | ) 413 | plot_and_save_mels( 414 | out[key][0, ...].cpu(), 415 | mel_dir / "{}-{}.png".format(idx, key), 416 | self.config, 417 | ) 418 | plot_and_save_mels_all( 419 | out, 420 | plot_keys, 421 | mel_dir / "{}-all.png".format(idx), 422 | self.config, 423 | ) 424 | 425 | def tflogger(self, logger_dict, data_type): 426 | for lg in self.logger.experiment: 427 | if type(lg).__name__ == "SummaryWriter": 428 | tensorboard = lg 429 | for key in logger_dict.keys(): 430 | manual_logging( 431 | logger=tensorboard, 432 | item=logger_dict[key], 433 | idx=0, 434 | tag=key, 435 | global_step=self.global_step, 436 | data_type=data_type, 437 | config=self.config, 438 | ) 439 | 440 | 441 | class SSLStepLightningModule(SSLBaseModule): 442 | """ 443 | Self-supervised speech restoration model 444 | with fine-tuning supervisedly pretrained model, 445 | correspond to ``SSL-pre'' in the paper. 446 | 447 | This module provises step-wise learning, which only trains the channel module at early epochs 448 | and then only train the analysis module ar later epochs to stabilize the traininig. 449 | """ 450 | 451 | def __init__(self, config): 452 | super().__init__(config) 453 | if config["train"]["fix_channel"]: 454 | for param in self.channel.parameters(): 455 | param.requires_grad = False 456 | 457 | def training_step(self, batch, batch_idx, optimizer_idx): 458 | if self.config["general"]["use_gst"]: 459 | enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) 460 | chfeats = self.gst(batch["melspecs"].transpose(1, 2)) 461 | else: 462 | enc_out, enc_hidden = self.encoder( 463 | batch["melspecs"].unsqueeze(1).transpose(2, 3) 464 | ) 465 | chfeats = self.channelfeats(enc_hidden) 466 | enc_out = enc_out.squeeze(1).transpose(1, 2) 467 | if self.config["general"]["feature_type"] == "melspec": 468 | enc_feats = enc_out 469 | elif self.config["general"]["feature_type"] == "vocfeats": 470 | enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) 471 | remas = self.vocoder(enc_feats).squeeze(1) 472 | wavsdeg = self.channel(remas, chfeats) 473 | loss = self.criteria(wavsdeg, batch["wavs"]) 474 | self.log( 475 | "train_loss", 476 | loss, 477 | on_step=True, 478 | on_epoch=True, 479 | prog_bar=True, 480 | logger=True, 481 | ) 482 | return loss 483 | 484 | def validation_step(self, batch, batch_idx): 485 | if self.config["general"]["use_gst"]: 486 | enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) 487 | chfeats = self.gst(batch["melspecs"].transpose(1, 2)) 488 | else: 489 | enc_out, enc_hidden = self.encoder( 490 | batch["melspecs"].unsqueeze(1).transpose(2, 3) 491 | ) 492 | chfeats = self.channelfeats(enc_hidden) 493 | enc_out = enc_out.squeeze(1).transpose(1, 2) 494 | if self.config["general"]["feature_type"] == "melspec": 495 | enc_feats = enc_out 496 | feats_name = "melspec" 497 | elif self.config["general"]["feature_type"] == 
"vocfeats": 498 | enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) 499 | feats_name = "melcep" 500 | remas = self.vocoder(enc_feats).squeeze(1) 501 | wavsdeg = self.channel(remas, chfeats) 502 | loss = self.criteria(wavsdeg, batch["wavs"]) 503 | logger_img_dict = { 504 | "val_src_melspec": batch["melspecs"], 505 | "val_pred_{}".format(feats_name): enc_out, 506 | } 507 | for auxfeats in ["melceps", "melspecsaux"]: 508 | if auxfeats in batch: 509 | logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats] 510 | logger_wav_dict = { 511 | "val_src_wav": batch["wavs"], 512 | "val_remastered_wav": remas, 513 | "val_pred_wav": wavsdeg, 514 | } 515 | if "wavsaux" in batch: 516 | logger_wav_dict["val_aux_wav"] = batch["wavsaux"] 517 | d_out = {"val_loss": loss, "logger_dict": [logger_img_dict, logger_wav_dict]} 518 | return d_out 519 | 520 | def validation_epoch_end(self, outputs): 521 | self.log( 522 | "val_loss", 523 | torch.stack([out["val_loss"] for out in outputs]).mean().item(), 524 | on_epoch=True, 525 | prog_bar=True, 526 | logger=True, 527 | ) 528 | self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image") 529 | self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio") 530 | 531 | def optimizer_step( 532 | self, 533 | epoch, 534 | batch_idx, 535 | optimizer, 536 | optimizer_idx, 537 | optimizer_closure, 538 | on_tpu=False, 539 | using_native_amp=False, 540 | using_lbfgs=False, 541 | ): 542 | if epoch < self.config["train"]["epoch_channel"]: 543 | if optimizer_idx == 0: 544 | optimizer.step(closure=optimizer_closure) 545 | elif optimizer_idx == 1: 546 | optimizer_closure() 547 | else: 548 | if optimizer_idx == 0: 549 | optimizer_closure() 550 | elif optimizer_idx == 1: 551 | optimizer.step(closure=optimizer_closure) 552 | 553 | def configure_optimizers(self): 554 | if self.config["train"]["fix_channel"]: 555 | if self.config["general"]["use_gst"]: 556 | optimizer_channel = torch.optim.Adam( 557 | self.gst.parameters(), lr=self.config["train"]["learning_rate"] 558 | ) 559 | else: 560 | optimizer_channel = torch.optim.Adam( 561 | self.channelfeats.parameters(), 562 | lr=self.config["train"]["learning_rate"], 563 | ) 564 | optimizer_encoder = torch.optim.Adam( 565 | self.encoder.parameters(), lr=self.config["train"]["learning_rate"] 566 | ) 567 | else: 568 | if self.config["general"]["use_gst"]: 569 | optimizer_channel = torch.optim.Adam( 570 | [ 571 | {"params": self.channel.parameters()}, 572 | {"params": self.gst.parameters()}, 573 | ], 574 | lr=self.config["train"]["learning_rate"], 575 | ) 576 | else: 577 | optimizer_channel = torch.optim.Adam( 578 | [ 579 | {"params": self.channel.parameters()}, 580 | {"params": self.channelfeats.parameters()}, 581 | ], 582 | lr=self.config["train"]["learning_rate"], 583 | ) 584 | optimizer_encoder = torch.optim.Adam( 585 | self.encoder.parameters(), lr=self.config["train"]["learning_rate"] 586 | ) 587 | optimizers = [optimizer_channel, optimizer_encoder] 588 | schedulers = [ 589 | { 590 | "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( 591 | optimizers[0], mode="min", factor=0.5, min_lr=1e-5, verbose=True 592 | ), 593 | "interval": "epoch", 594 | "frequency": 3, 595 | "monitor": "val_loss", 596 | }, 597 | { 598 | "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( 599 | optimizers[1], mode="min", factor=0.5, min_lr=1e-5, verbose=True 600 | ), 601 | "interval": "epoch", 602 | "frequency": 3, 603 | "monitor": "val_loss", 604 | }, 605 | ] 606 | return optimizers, schedulers 607 | 
608 | def get_loss_function(self, config): 609 | return MultiScaleSpectralLoss(config) 610 | 611 | 612 | class SSLDualLightningModule(SSLBaseModule): 613 | """ 614 | Self-supervised speech restoration model with dual learning, 615 | correspond to ``SSL-dual'' or ``SSL-dual-pre'' in the paper. 616 | 617 | This module provises dual-learning. 618 | In addition to the basic training framework, we introduce a training task that 619 | propagates information in the reverse direction. 620 | """ 621 | 622 | def __init__(self, config): 623 | super().__init__(config) 624 | if config["train"]["fix_channel"]: 625 | for param in self.channel.parameters(): 626 | param.requires_grad = False 627 | self.spec_module = torchaudio.transforms.MelSpectrogram( 628 | sample_rate=config["preprocess"]["sampling_rate"], 629 | n_fft=config["preprocess"]["fft_length"], 630 | win_length=config["preprocess"]["frame_length"], 631 | hop_length=config["preprocess"]["frame_shift"], 632 | f_min=config["preprocess"]["fmin"], 633 | f_max=config["preprocess"]["fmax"], 634 | n_mels=config["preprocess"]["n_mels"], 635 | power=1, 636 | center=True, 637 | norm="slaney", 638 | mel_scale="slaney", 639 | ) 640 | self.beta = config["train"]["beta"] 641 | self.criteria_a, self.criteria_b = self.get_loss_function(config) 642 | 643 | def training_step(self, batch, batch_idx): 644 | if self.config["general"]["use_gst"]: 645 | enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) 646 | chfeats = self.gst(batch["melspecs"].transpose(1, 2)) 647 | else: 648 | enc_out, enc_hidden = self.encoder( 649 | batch["melspecs"].unsqueeze(1).transpose(2, 3) 650 | ) 651 | chfeats = self.channelfeats(enc_hidden) 652 | enc_out = enc_out.squeeze(1).transpose(1, 2) 653 | if self.config["general"]["feature_type"] == "melspec": 654 | enc_feats = enc_out 655 | elif self.config["general"]["feature_type"] == "vocfeats": 656 | enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) 657 | remas = self.vocoder(enc_feats).squeeze(1) 658 | wavsdeg = self.channel(remas, chfeats) 659 | loss_recons = self.criteria_a(wavsdeg, batch["wavs"]) 660 | 661 | with torch.no_grad(): 662 | wavsdegtask = self.channel(batch["wavstask"], chfeats) 663 | melspecstask = self.calc_spectrogram(wavsdegtask) 664 | if self.config["general"]["use_gst"]: 665 | enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3)) 666 | else: 667 | enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3)) 668 | enc_out_task = enc_out_task.squeeze(1).transpose(1, 2) 669 | if self.config["general"]["feature_type"] == "melspec": 670 | loss_task = self.criteria_b(enc_out_task, batch["melspecstask"]) 671 | elif self.config["general"]["feature_type"] == "vocfeats": 672 | loss_task = self.criteria_b(enc_out_task, batch["melcepstask"]) 673 | loss = self.beta * loss_recons + (1 - self.beta) * loss_task 674 | 675 | self.log( 676 | "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True 677 | ) 678 | self.log( 679 | "train_loss_recons", 680 | loss_recons, 681 | on_step=True, 682 | on_epoch=True, 683 | prog_bar=True, 684 | logger=True, 685 | ) 686 | self.log( 687 | "train_loss_task", 688 | loss_task, 689 | on_step=True, 690 | on_epoch=True, 691 | prog_bar=True, 692 | logger=True, 693 | ) 694 | return loss 695 | 696 | def validation_step(self, batch, batch_idx): 697 | if self.config["general"]["use_gst"]: 698 | enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) 699 | chfeats = self.gst(batch["melspecs"].transpose(1, 2)) 700 | 
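# Note (illustrative comment): the dual objective above mixes the usual
# reconstruction loss on the degraded input with the auxiliary task, in which
# the separate `wavstask` utterances are passed through the channel module and
# the encoder has to recover their features:
#     loss = beta * loss_recons + (1 - beta) * loss_task
# beta is read from config["train"]["beta"]; for example beta = 0.5 would
# weight both terms equally (the actual value depends on the training config).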
else: 701 | enc_out, enc_hidden = self.encoder( 702 | batch["melspecs"].unsqueeze(1).transpose(2, 3) 703 | ) 704 | chfeats = self.channelfeats(enc_hidden) 705 | enc_out = enc_out.squeeze(1).transpose(1, 2) 706 | if self.config["general"]["feature_type"] == "melspec": 707 | enc_feats = enc_out 708 | feats_name = "melspec" 709 | elif self.config["general"]["feature_type"] == "vocfeats": 710 | enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) 711 | feats_name = "melcep" 712 | remas = self.vocoder(enc_feats).squeeze(1) 713 | wavsdeg = self.channel(remas, chfeats) 714 | loss_recons = self.criteria_a(wavsdeg, batch["wavs"]) 715 | 716 | wavsdegtask = self.channel(batch["wavstask"], chfeats) 717 | melspecstask = self.calc_spectrogram(wavsdegtask) 718 | if self.config["general"]["use_gst"]: 719 | enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3)) 720 | else: 721 | enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3)) 722 | enc_out_task = enc_out_task.squeeze(1).transpose(1, 2) 723 | if self.config["general"]["feature_type"] == "melspec": 724 | enc_out_task_truth = batch["melspecstask"] 725 | loss_task = self.criteria_b(enc_out_task, enc_out_task_truth) 726 | elif self.config["general"]["feature_type"] == "vocfeats": 727 | enc_out_task_truth = batch["melcepstask"] 728 | loss_task = self.criteria_b(enc_out_task, enc_out_task_truth) 729 | loss = self.beta * loss_recons + (1 - self.beta) * loss_task 730 | 731 | logger_img_dict = { 732 | "val_src_melspec": batch["melspecs"], 733 | "val_pred_{}".format(feats_name): enc_out, 734 | "val_truth_{}_task".format(feats_name): enc_out_task_truth, 735 | "val_pred_{}_task".format(feats_name): enc_out_task, 736 | } 737 | for auxfeats in ["melceps", "melspecsaux"]: 738 | if auxfeats in batch: 739 | logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats] 740 | logger_wav_dict = { 741 | "val_src_wav": batch["wavs"], 742 | "val_remastered_wav": remas, 743 | "val_pred_wav": wavsdeg, 744 | "val_truth_wavtask": batch["wavstask"], 745 | "val_deg_wavtask": wavsdegtask, 746 | } 747 | if "wavsaux" in batch: 748 | logger_wav_dict["val_aux_wav"] = batch["wavsaux"] 749 | 750 | d_out = { 751 | "val_loss": loss, 752 | "val_loss_recons": loss_recons, 753 | "val_loss_task": loss_task, 754 | "logger_dict": [logger_img_dict, logger_wav_dict], 755 | } 756 | return d_out 757 | 758 | def validation_epoch_end(self, outputs): 759 | self.log( 760 | "val_loss", 761 | torch.stack([out["val_loss"] for out in outputs]).mean().item(), 762 | on_epoch=True, 763 | prog_bar=True, 764 | logger=True, 765 | ) 766 | self.log( 767 | "val_loss_recons", 768 | torch.stack([out["val_loss_recons"] for out in outputs]).mean().item(), 769 | on_epoch=True, 770 | prog_bar=True, 771 | logger=True, 772 | ) 773 | self.log( 774 | "val_loss_task", 775 | torch.stack([out["val_loss_task"] for out in outputs]).mean().item(), 776 | on_epoch=True, 777 | prog_bar=True, 778 | logger=True, 779 | ) 780 | self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image") 781 | self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio") 782 | 783 | def test_step(self, batch, batch_idx): 784 | if self.config["general"]["use_gst"]: 785 | enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) 786 | chfeats = self.gst(batch["melspecs"].transpose(1, 2)) 787 | else: 788 | enc_out, enc_hidden = self.encoder( 789 | batch["melspecs"].unsqueeze(1).transpose(2, 3) 790 | ) 791 | chfeats = self.channelfeats(enc_hidden) 792 | enc_out = 
enc_out.squeeze(1).transpose(1, 2) 793 | if self.config["general"]["feature_type"] == "melspec": 794 | enc_feats = enc_out 795 | elif self.config["general"]["feature_type"] == "vocfeats": 796 | enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) 797 | remas = self.vocoder(enc_feats).squeeze(1) 798 | wavsdeg = self.channel(remas, chfeats) 799 | if self.config["general"]["feature_type"] == "melspec": 800 | enc_feats_input = batch["melspecs"] 801 | elif self.config["general"]["feature_type"] == "vocfeats": 802 | enc_feats_input = torch.cat( 803 | (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1 804 | ) 805 | input_recons = self.vocoder(enc_feats_input).squeeze(1) 806 | 807 | wavsdegtask = self.channel(batch["wavstask"], chfeats) 808 | if "wavsaux" in batch: 809 | gt_wav = batch["wavsaux"] 810 | if self.config["general"]["feature_type"] == "melspec": 811 | enc_feats_aux = batch["melspecsaux"] 812 | elif self.config["general"]["feature_type"] == "vocfeats": 813 | enc_feats_aux = torch.cat( 814 | (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1 815 | ) 816 | recons_wav = self.vocoder(enc_feats_aux).squeeze(1) 817 | else: 818 | gt_wav = None 819 | recons_wav = None 820 | return { 821 | "reconstructed": recons_wav, 822 | "remastered": remas, 823 | "channeled": wavsdeg, 824 | "channeled_task": wavsdegtask, 825 | "input": batch["wavs"], 826 | "input_recons": input_recons, 827 | "groundtruth": gt_wav, 828 | } 829 | 830 | def test_epoch_end(self, outputs): 831 | wav_dir = ( 832 | pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs" 833 | ) 834 | os.makedirs(wav_dir, exist_ok=True) 835 | mel_dir = ( 836 | pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels" 837 | ) 838 | os.makedirs(mel_dir, exist_ok=True) 839 | print("Saving mel spectrogram plots ...") 840 | for idx, out in enumerate(tqdm.tqdm(outputs)): 841 | plot_keys = [] 842 | for key in [ 843 | "reconstructed", 844 | "remastered", 845 | "channeled", 846 | "channeled_task", 847 | "input", 848 | "input_recons", 849 | "groundtruth", 850 | ]: 851 | if out[key] != None: 852 | plot_keys.append(key) 853 | torchaudio.save( 854 | wav_dir / "{}-{}.wav".format(idx, key), 855 | out[key][0, ...].unsqueeze(0).cpu(), 856 | sample_rate=self.config["preprocess"]["sampling_rate"], 857 | channels_first=True, 858 | ) 859 | plot_and_save_mels( 860 | out[key][0, ...].cpu(), 861 | mel_dir / "{}-{}.png".format(idx, key), 862 | self.config, 863 | ) 864 | plot_and_save_mels_all( 865 | out, 866 | plot_keys, 867 | mel_dir / "{}-all.png".format(idx), 868 | self.config, 869 | ) 870 | 871 | def configure_optimizers(self): 872 | optimizer = torch.optim.Adam( 873 | self.parameters(), lr=self.config["train"]["learning_rate"] 874 | ) 875 | lr_scheduler_config = { 876 | "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( 877 | optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True 878 | ), 879 | "interval": "epoch", 880 | "frequency": 3, 881 | "monitor": "val_loss", 882 | } 883 | return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} 884 | 885 | def calc_spectrogram(self, wav): 886 | specs = self.spec_module(wav) 887 | log_spec = torch.log( 888 | torch.clamp_min(specs, self.config["preprocess"]["min_magnitude"]) 889 | * self.config["preprocess"]["comp_factor"] 890 | ).to(torch.float32) 891 | return log_spec 892 | 893 | def get_loss_function(self, config): 894 | if config["train"]["feature_loss"]["type"] == "mae": 895 | feature_loss = torch.nn.L1Loss() 896 | else: 897 | feature_loss = 
torch.nn.MSELoss() 898 | return MultiScaleSpectralLoss(config), feature_loss 899 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | import numpy as np 7 | 8 | 9 | class EncoderModule(nn.Module): 10 | """ 11 | Analysis module based on 2D conv U-Net 12 | Inspired by https://github.com/haoheliu/voicefixer 13 | 14 | Args: 15 | config (dict): config 16 | use_channel (bool): output channel feature or not 17 | """ 18 | 19 | def __init__(self, config, use_channel=False): 20 | super().__init__() 21 | 22 | self.channels = 1 23 | self.use_channel = use_channel 24 | self.downsample_ratio = 2 ** 4 25 | 26 | self.down_block1 = DownBlockRes2D( 27 | in_channels=self.channels, 28 | out_channels=32, 29 | downsample=(2, 2), 30 | activation="relu", 31 | momentum=0.01, 32 | ) 33 | self.down_block2 = DownBlockRes2D( 34 | in_channels=32, 35 | out_channels=64, 36 | downsample=(2, 2), 37 | activation="relu", 38 | momentum=0.01, 39 | ) 40 | self.down_block3 = DownBlockRes2D( 41 | in_channels=64, 42 | out_channels=128, 43 | downsample=(2, 2), 44 | activation="relu", 45 | momentum=0.01, 46 | ) 47 | self.down_block4 = DownBlockRes2D( 48 | in_channels=128, 49 | out_channels=256, 50 | downsample=(2, 2), 51 | activation="relu", 52 | momentum=0.01, 53 | ) 54 | self.conv_block5 = ConvBlockRes2D( 55 | in_channels=256, 56 | out_channels=256, 57 | size=3, 58 | activation="relu", 59 | momentum=0.01, 60 | ) 61 | self.up_block1 = UpBlockRes2D( 62 | in_channels=256, 63 | out_channels=256, 64 | stride=(2, 2), 65 | activation="relu", 66 | momentum=0.01, 67 | ) 68 | self.up_block2 = UpBlockRes2D( 69 | in_channels=256, 70 | out_channels=128, 71 | stride=(2, 2), 72 | activation="relu", 73 | momentum=0.01, 74 | ) 75 | self.up_block3 = UpBlockRes2D( 76 | in_channels=128, 77 | out_channels=64, 78 | stride=(2, 2), 79 | activation="relu", 80 | momentum=0.01, 81 | ) 82 | self.up_block4 = UpBlockRes2D( 83 | in_channels=64, 84 | out_channels=32, 85 | stride=(2, 2), 86 | activation="relu", 87 | momentum=0.01, 88 | ) 89 | 90 | self.after_conv_block1 = ConvBlockRes2D( 91 | in_channels=32, 92 | out_channels=32, 93 | size=3, 94 | activation="relu", 95 | momentum=0.01, 96 | ) 97 | 98 | self.after_conv2 = nn.Conv2d( 99 | in_channels=32, 100 | out_channels=1, 101 | kernel_size=(1, 1), 102 | stride=(1, 1), 103 | padding=(0, 0), 104 | bias=True, 105 | ) 106 | 107 | if config["general"]["feature_type"] == "melspec": 108 | out_dim = config["preprocess"]["n_mels"] 109 | elif config["general"]["feature_type"] == "vocfeats": 110 | out_dim = config["preprocess"]["cep_order"] + 1 111 | else: 112 | raise NotImplementedError() 113 | 114 | self.after_linear = nn.Linear( 115 | in_features=80, 116 | out_features=out_dim, 117 | bias=True, 118 | ) 119 | 120 | if self.use_channel: 121 | self.conv_channel = ConvBlockRes2D( 122 | in_channels=256, 123 | out_channels=256, 124 | size=3, 125 | activation="relu", 126 | momentum=0.01, 127 | ) 128 | 129 | def forward(self, x): 130 | """ 131 | Forward 132 | 133 | Args: 134 | mel spectrogram: (batch, 1, time, freq) 135 | 136 | Return: 137 | speech feature (mel spectrogram or mel cepstrum): (batch, 1, time, freq) 138 | input of channel feature module (batch, 256, time, freq) 139 | """ 140 | 141 | origin_len = x.shape[2] 142 | pad_len = ( 143 | int(np.ceil(x.shape[2] / 
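# Note (illustrative comment): the time axis is zero-padded up to the next
# multiple of downsample_ratio (2 ** 4 = 16) so that the four 2x poolings and
# the four 2x upsamplings of the U-Net line up again. For example, 101 input
# frames give pad_len = int(np.ceil(101 / 16)) * 16 - 101 = 11, and the output
# is cropped back to the original 101 frames at the end of forward().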
self.downsample_ratio)) * self.downsample_ratio 144 | - origin_len 145 | ) 146 | x = F.pad(x, pad=(0, 0, 0, pad_len)) 147 | x = x[..., 0 : x.shape[-1] - 1] 148 | 149 | (x1_pool, x1) = self.down_block1(x) 150 | (x2_pool, x2) = self.down_block2(x1_pool) 151 | (x3_pool, x3) = self.down_block3(x2_pool) 152 | (x4_pool, x4) = self.down_block4(x3_pool) 153 | x_center = self.conv_block5(x4_pool) 154 | x5 = self.up_block1(x_center, x4) 155 | x6 = self.up_block2(x5, x3) 156 | x7 = self.up_block3(x6, x2) 157 | x8 = self.up_block4(x7, x1) 158 | x = self.after_conv_block1(x8) 159 | x = self.after_conv2(x) 160 | 161 | x = F.pad(x, pad=(0, 1)) 162 | x = x[:, :, 0:origin_len, :] 163 | 164 | x = self.after_linear(x) 165 | 166 | if self.use_channel: 167 | x_channel = self.conv_channel(x4_pool) 168 | return x, x_channel 169 | else: 170 | return x 171 | 172 | 173 | class ChannelModule(nn.Module): 174 | """ 175 | Channel module based on 1D conv U-Net 176 | 177 | Args: 178 | config (dict): config 179 | """ 180 | 181 | def __init__(self, config): 182 | super().__init__() 183 | 184 | self.channels = 1 185 | self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks} 186 | 187 | self.down_block1 = DownBlockRes1D( 188 | in_channels=self.channels, 189 | out_channels=32, 190 | downsample=2, 191 | activation="relu", 192 | momentum=0.01, 193 | ) 194 | self.down_block2 = DownBlockRes1D( 195 | in_channels=32, 196 | out_channels=64, 197 | downsample=2, 198 | activation="relu", 199 | momentum=0.01, 200 | ) 201 | self.down_block3 = DownBlockRes1D( 202 | in_channels=64, 203 | out_channels=128, 204 | downsample=2, 205 | activation="relu", 206 | momentum=0.01, 207 | ) 208 | self.down_block4 = DownBlockRes1D( 209 | in_channels=128, 210 | out_channels=256, 211 | downsample=2, 212 | activation="relu", 213 | momentum=0.01, 214 | ) 215 | self.down_block5 = DownBlockRes1D( 216 | in_channels=256, 217 | out_channels=512, 218 | downsample=2, 219 | activation="relu", 220 | momentum=0.01, 221 | ) 222 | self.conv_block6 = ConvBlockRes1D( 223 | in_channels=512, 224 | out_channels=384, 225 | size=3, 226 | activation="relu", 227 | momentum=0.01, 228 | ) 229 | self.up_block1 = UpBlockRes1D( 230 | in_channels=512, 231 | out_channels=512, 232 | stride=2, 233 | activation="relu", 234 | momentum=0.01, 235 | ) 236 | self.up_block2 = UpBlockRes1D( 237 | in_channels=512, 238 | out_channels=256, 239 | stride=2, 240 | activation="relu", 241 | momentum=0.01, 242 | ) 243 | self.up_block3 = UpBlockRes1D( 244 | in_channels=256, 245 | out_channels=128, 246 | stride=2, 247 | activation="relu", 248 | momentum=0.01, 249 | ) 250 | self.up_block4 = UpBlockRes1D( 251 | in_channels=128, 252 | out_channels=64, 253 | stride=2, 254 | activation="relu", 255 | momentum=0.01, 256 | ) 257 | self.up_block5 = UpBlockRes1D( 258 | in_channels=64, 259 | out_channels=32, 260 | stride=2, 261 | activation="relu", 262 | momentum=0.01, 263 | ) 264 | 265 | self.after_conv_block1 = ConvBlockRes1D( 266 | in_channels=32, 267 | out_channels=32, 268 | size=3, 269 | activation="relu", 270 | momentum=0.01, 271 | ) 272 | 273 | self.after_conv2 = nn.Conv1d( 274 | in_channels=32, 275 | out_channels=1, 276 | kernel_size=1, 277 | stride=1, 278 | padding=0, 279 | bias=True, 280 | ) 281 | 282 | def forward(self, x, h): 283 | """ 284 | Forward 285 | 286 | Args: 287 | clean waveform: (batch, n_channel (1), time) 288 | channel feature: (batch, feature_dim) 289 | Outputs: 290 | degraded waveform: (batch, n_channel (1), time) 291 | """ 292 | x = x.unsqueeze(1) 293 | 294 | origin_len = 
x.shape[2] 295 | pad_len = ( 296 | int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio 297 | - origin_len 298 | ) 299 | x = F.pad(x, pad=(0, pad_len)) 300 | x = x[..., 0 : x.shape[-1] - 1] 301 | 302 | (x1_pool, x1) = self.down_block1(x) 303 | (x2_pool, x2) = self.down_block2(x1_pool) 304 | (x3_pool, x3) = self.down_block3(x2_pool) 305 | (x4_pool, x4) = self.down_block4(x3_pool) 306 | (x5_pool, x5) = self.down_block5(x4_pool) 307 | x_center = self.conv_block6(x5_pool) 308 | x_concat = torch.cat( 309 | (x_center, h.unsqueeze(2).expand(-1, -1, x_center.size(2))), dim=1 310 | ) 311 | x6 = self.up_block1(x_concat, x5) 312 | x7 = self.up_block2(x6, x4) 313 | x8 = self.up_block3(x7, x3) 314 | x9 = self.up_block4(x8, x2) 315 | x10 = self.up_block5(x9, x1) 316 | x = self.after_conv_block1(x10) 317 | x = self.after_conv2(x) 318 | 319 | x = F.pad(x, pad=(0, 1)) 320 | x = x[..., 0:origin_len] 321 | 322 | return x.squeeze(1) 323 | 324 | 325 | class ChannelFeatureModule(nn.Module): 326 | """ 327 | Channel feature module based on 2D convolution layers 328 | 329 | Args: 330 | config (dict): config 331 | """ 332 | 333 | def __init__(self, config): 334 | super().__init__() 335 | self.conv_blocks_in = ConvBlockRes2D( 336 | in_channels=256, 337 | out_channels=512, 338 | size=3, 339 | activation="relu", 340 | momentum=0.01, 341 | ) 342 | self.down_block1 = DownBlockRes2D( 343 | in_channels=512, 344 | out_channels=256, 345 | downsample=(2, 2), 346 | activation="relu", 347 | momentum=0.01, 348 | ) 349 | self.down_block2 = DownBlockRes2D( 350 | in_channels=256, 351 | out_channels=256, 352 | downsample=(2, 2), 353 | activation="relu", 354 | momentum=0.01, 355 | ) 356 | self.conv_block_out = ConvBlockRes2D( 357 | in_channels=256, 358 | out_channels=128, 359 | size=3, 360 | activation="relu", 361 | momentum=0.01, 362 | ) 363 | self.avgpool2d = torch.nn.AdaptiveAvgPool2d(1) 364 | 365 | def forward(self, x): 366 | """ 367 | Forward 368 | 369 | Args: 370 | output of analysis module: (batch, 256, time, freq) 371 | 372 | Return: 373 | channel feature: (batch, feature_dim) 374 | """ 375 | x = self.conv_blocks_in(x) 376 | x, _ = self.down_block1(x) 377 | x, _ = self.down_block2(x) 378 | x = self.conv_block_out(x) 379 | x = self.avgpool2d(x) 380 | x = x.squeeze(3).squeeze(2) 381 | return x 382 | 383 | 384 | class ConvBlockRes2D(nn.Module): 385 | def __init__(self, in_channels, out_channels, size, activation, momentum): 386 | super().__init__() 387 | 388 | self.activation = activation 389 | if type(size) == type((3, 4)): 390 | pad = size[0] // 2 391 | size = size[0] 392 | else: 393 | pad = size // 2 394 | size = size 395 | 396 | self.conv1 = nn.Conv2d( 397 | in_channels=in_channels, 398 | out_channels=out_channels, 399 | kernel_size=(size, size), 400 | stride=(1, 1), 401 | dilation=(1, 1), 402 | padding=(pad, pad), 403 | bias=False, 404 | ) 405 | 406 | self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) 407 | 408 | self.conv2 = nn.Conv2d( 409 | in_channels=out_channels, 410 | out_channels=out_channels, 411 | kernel_size=(size, size), 412 | stride=(1, 1), 413 | dilation=(1, 1), 414 | padding=(pad, pad), 415 | bias=False, 416 | ) 417 | 418 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum) 419 | 420 | if in_channels != out_channels: 421 | self.shortcut = nn.Conv2d( 422 | in_channels=in_channels, 423 | out_channels=out_channels, 424 | kernel_size=(1, 1), 425 | stride=(1, 1), 426 | padding=(0, 0), 427 | ) 428 | self.is_shortcut = True 429 | else: 430 | self.is_shortcut = False 431 | 432 | def 
forward(self, x): 433 | origin = x 434 | x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01)) 435 | x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01)) 436 | 437 | if self.is_shortcut: 438 | return self.shortcut(origin) + x 439 | else: 440 | return origin + x 441 | 442 | 443 | class ConvBlockRes1D(nn.Module): 444 | def __init__(self, in_channels, out_channels, size, activation, momentum): 445 | super().__init__() 446 | 447 | self.activation = activation 448 | pad = size // 2 449 | 450 | self.conv1 = nn.Conv1d( 451 | in_channels=in_channels, 452 | out_channels=out_channels, 453 | kernel_size=size, 454 | stride=1, 455 | dilation=1, 456 | padding=pad, 457 | bias=False, 458 | ) 459 | 460 | self.bn1 = nn.BatchNorm1d(in_channels, momentum=momentum) 461 | 462 | self.conv2 = nn.Conv1d( 463 | in_channels=out_channels, 464 | out_channels=out_channels, 465 | kernel_size=size, 466 | stride=1, 467 | dilation=1, 468 | padding=pad, 469 | bias=False, 470 | ) 471 | 472 | self.bn2 = nn.BatchNorm1d(out_channels, momentum=momentum) 473 | 474 | if in_channels != out_channels: 475 | self.shortcut = nn.Conv1d( 476 | in_channels=in_channels, 477 | out_channels=out_channels, 478 | kernel_size=1, 479 | stride=1, 480 | padding=0, 481 | ) 482 | self.is_shortcut = True 483 | else: 484 | self.is_shortcut = False 485 | 486 | def forward(self, x): 487 | origin = x 488 | x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01)) 489 | x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01)) 490 | 491 | if self.is_shortcut: 492 | return self.shortcut(origin) + x 493 | else: 494 | return origin + x 495 | 496 | 497 | class DownBlockRes2D(nn.Module): 498 | def __init__(self, in_channels, out_channels, downsample, activation, momentum): 499 | super().__init__() 500 | size = 3 501 | 502 | self.conv_block1 = ConvBlockRes2D( 503 | in_channels, out_channels, size, activation, momentum 504 | ) 505 | self.conv_block2 = ConvBlockRes2D( 506 | out_channels, out_channels, size, activation, momentum 507 | ) 508 | self.conv_block3 = ConvBlockRes2D( 509 | out_channels, out_channels, size, activation, momentum 510 | ) 511 | self.conv_block4 = ConvBlockRes2D( 512 | out_channels, out_channels, size, activation, momentum 513 | ) 514 | self.avg_pool2d = torch.nn.AvgPool2d(downsample) 515 | 516 | def forward(self, x): 517 | encoder = self.conv_block1(x) 518 | encoder = self.conv_block2(encoder) 519 | encoder = self.conv_block3(encoder) 520 | encoder = self.conv_block4(encoder) 521 | encoder_pool = self.avg_pool2d(encoder) 522 | return encoder_pool, encoder 523 | 524 | 525 | class DownBlockRes1D(nn.Module): 526 | def __init__(self, in_channels, out_channels, downsample, activation, momentum): 527 | super().__init__() 528 | size = 3 529 | 530 | self.conv_block1 = ConvBlockRes1D( 531 | in_channels, out_channels, size, activation, momentum 532 | ) 533 | self.conv_block2 = ConvBlockRes1D( 534 | out_channels, out_channels, size, activation, momentum 535 | ) 536 | self.conv_block3 = ConvBlockRes1D( 537 | out_channels, out_channels, size, activation, momentum 538 | ) 539 | self.conv_block4 = ConvBlockRes1D( 540 | out_channels, out_channels, size, activation, momentum 541 | ) 542 | self.avg_pool1d = torch.nn.AvgPool1d(downsample) 543 | 544 | def forward(self, x): 545 | encoder = self.conv_block1(x) 546 | encoder = self.conv_block2(encoder) 547 | encoder = self.conv_block3(encoder) 548 | encoder = self.conv_block4(encoder) 549 | encoder_pool = self.avg_pool1d(encoder) 550 | return encoder_pool, encoder 551 | 552 | 553 | 
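# Shape sketch (illustrative example with arbitrary dimensions, relying only on
# the classes defined above): ConvBlockRes2D preserves the (time, freq)
# resolution, while DownBlockRes2D additionally returns an average-pooled copy
# at half resolution that feeds the U-Net skip connections.
#
#     block = ConvBlockRes2D(1, 32, size=3, activation="relu", momentum=0.01)
#     down = DownBlockRes2D(32, 64, downsample=(2, 2), activation="relu", momentum=0.01)
#     x = torch.randn(2, 1, 64, 80)   # (batch, channel, frames, n_mels)
#     y = block(x)                    # residual block output: (2, 32, 64, 80)
#     pooled, skip = down(y)          # pooled: (2, 64, 32, 40), skip: (2, 64, 64, 80)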
class UpBlockRes2D(nn.Module): 554 | def __init__(self, in_channels, out_channels, stride, activation, momentum): 555 | super().__init__() 556 | size = 3 557 | self.activation = activation 558 | 559 | self.conv1 = torch.nn.ConvTranspose2d( 560 | in_channels=in_channels, 561 | out_channels=out_channels, 562 | kernel_size=(size, size), 563 | stride=stride, 564 | padding=(0, 0), 565 | output_padding=(0, 0), 566 | bias=False, 567 | dilation=(1, 1), 568 | ) 569 | 570 | self.bn1 = nn.BatchNorm2d(in_channels) 571 | self.conv_block2 = ConvBlockRes2D( 572 | out_channels * 2, out_channels, size, activation, momentum 573 | ) 574 | self.conv_block3 = ConvBlockRes2D( 575 | out_channels, out_channels, size, activation, momentum 576 | ) 577 | self.conv_block4 = ConvBlockRes2D( 578 | out_channels, out_channels, size, activation, momentum 579 | ) 580 | self.conv_block5 = ConvBlockRes2D( 581 | out_channels, out_channels, size, activation, momentum 582 | ) 583 | 584 | def prune(self, x, both=False): 585 | """Prune the shape of x after transpose convolution.""" 586 | if both: 587 | x = x[:, :, 0:-1, 0:-1] 588 | else: 589 | x = x[:, :, 0:-1, :] 590 | return x 591 | 592 | def forward(self, input_tensor, concat_tensor, both=False): 593 | x = self.conv1(F.relu_(self.bn1(input_tensor))) 594 | x = self.prune(x, both=both) 595 | x = torch.cat((x, concat_tensor), dim=1) 596 | x = self.conv_block2(x) 597 | x = self.conv_block3(x) 598 | x = self.conv_block4(x) 599 | x = self.conv_block5(x) 600 | return x 601 | 602 | 603 | class UpBlockRes1D(nn.Module): 604 | def __init__(self, in_channels, out_channels, stride, activation, momentum): 605 | super().__init__() 606 | size = 3 607 | self.activation = activation 608 | 609 | self.conv1 = torch.nn.ConvTranspose1d( 610 | in_channels=in_channels, 611 | out_channels=out_channels, 612 | kernel_size=size, 613 | stride=stride, 614 | padding=0, 615 | output_padding=0, 616 | bias=False, 617 | dilation=1, 618 | ) 619 | 620 | self.bn1 = nn.BatchNorm1d(in_channels) 621 | self.conv_block2 = ConvBlockRes1D( 622 | out_channels * 2, out_channels, size, activation, momentum 623 | ) 624 | self.conv_block3 = ConvBlockRes1D( 625 | out_channels, out_channels, size, activation, momentum 626 | ) 627 | self.conv_block4 = ConvBlockRes1D( 628 | out_channels, out_channels, size, activation, momentum 629 | ) 630 | self.conv_block5 = ConvBlockRes1D( 631 | out_channels, out_channels, size, activation, momentum 632 | ) 633 | 634 | def prune(self, x): 635 | """Prune the shape of x after transpose convolution.""" 636 | print(x.shape) 637 | x = x[:, 0:-1, :] 638 | print(x.shape) 639 | return x 640 | 641 | def forward(self, input_tensor, concat_tensor): 642 | x = self.conv1(F.relu_(self.bn1(input_tensor))) 643 | # x = self.prune(x) 644 | x = torch.cat((x, concat_tensor), dim=1) 645 | x = self.conv_block2(x) 646 | x = self.conv_block3(x) 647 | x = self.conv_block4(x) 648 | x = self.conv_block5(x) 649 | return x 650 | 651 | 652 | class MultiScaleSpectralLoss(nn.Module): 653 | """ 654 | Multi scale spectral loss 655 | https://openreview.net/forum?id=B1x1ma4tDr 656 | 657 | Args: 658 | config (dict): config 659 | """ 660 | 661 | def __init__(self, config): 662 | super().__init__() 663 | try: 664 | self.use_linear = config["train"]["multi_scale_loss"]["use_linear"] 665 | self.gamma = config["train"]["multi_scale_loss"]["gamma"] 666 | except KeyError: 667 | self.use_linear = False 668 | 669 | self.fft_sizes = [2048, 512, 256, 128, 64] 670 | self.spectrograms = [] 671 | for fftsize in self.fft_sizes: 672 | 
self.spectrograms.append( 673 | torchaudio.transforms.Spectrogram( 674 | n_fft=fftsize, hop_length=fftsize // 4, power=2 675 | ) 676 | ) 677 | self.spectrograms = nn.ModuleList(self.spectrograms) 678 | self.criteria = nn.L1Loss() 679 | self.eps = 1e-10 680 | 681 | def forward(self, wav_out, wav_target): 682 | """ 683 | Forward 684 | 685 | Args: 686 | wav_out: output of channel module (batch, time) 687 | wav_target: input degraded waveform (batch, time) 688 | 689 | Return: 690 | loss 691 | """ 692 | loss = 0.0 693 | length = min(wav_out.size(1), wav_target.size(1)) 694 | for spectrogram in self.spectrograms: 695 | S_out = spectrogram(wav_out[..., :length]) 696 | S_target = spectrogram(wav_target[..., :length]) 697 | log_S_out = torch.log(S_out + self.eps) 698 | log_S_target = torch.log(S_target + self.eps) 699 | if self.use_linear: 700 | loss += self.criteria(S_out, S_target) + self.gamma * self.criteria( 701 | log_S_out, log_S_target 702 | ) 703 | else: 704 | loss += self.criteria(log_S_out, log_S_target) 705 | return loss 706 | 707 | 708 | class ReferenceEncoder(nn.Module): 709 | def __init__( 710 | self, idim=80, ref_enc_filters=[32, 32, 64, 64, 128, 128], ref_dim=128 711 | ): 712 | super().__init__() 713 | K = len(ref_enc_filters) 714 | filters = [1] + ref_enc_filters 715 | 716 | convs = [ 717 | nn.Conv2d( 718 | in_channels=filters[i], 719 | out_channels=filters[i + 1], 720 | kernel_size=(3, 3), 721 | stride=(2, 2), 722 | padding=(1, 1), 723 | ) 724 | for i in range(K) 725 | ] 726 | self.convs = nn.ModuleList(convs) 727 | self.bns = nn.ModuleList( 728 | [nn.BatchNorm2d(num_features=ref_enc_filters[i]) for i in range(K)] 729 | ) 730 | 731 | out_channels = self.calculate_channels(idim, 3, 2, 1, K) 732 | 733 | self.gru = nn.GRU( 734 | input_size=ref_enc_filters[-1] * out_channels, 735 | hidden_size=ref_dim, 736 | batch_first=True, 737 | ) 738 | self.n_mel_channels = idim 739 | 740 | def forward(self, inputs): 741 | 742 | out = inputs.view(inputs.size(0), 1, -1, self.n_mel_channels) 743 | for conv, bn in zip(self.convs, self.bns): 744 | out = conv(out) 745 | out = bn(out) 746 | out = F.relu(out) 747 | 748 | out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] 749 | N, T = out.size(0), out.size(1) 750 | out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] 751 | 752 | self.gru.flatten_parameters() 753 | 754 | _, out = self.gru(out) 755 | 756 | return out.squeeze(0) 757 | 758 | def calculate_channels(self, L, kernel_size, stride, pad, n_convs): 759 | for _ in range(n_convs): 760 | L = (L - kernel_size + 2 * pad) // stride + 1 761 | return L 762 | 763 | 764 | class STL(nn.Module): 765 | def __init__(self, ref_dim=128, num_heads=4, token_num=10, token_dim=128): 766 | super().__init__() 767 | self.embed = nn.Parameter(torch.FloatTensor(token_num, token_dim // num_heads)) 768 | d_q = ref_dim 769 | d_k = token_dim // num_heads 770 | self.attention = MultiHeadAttention( 771 | query_dim=d_q, key_dim=d_k, num_units=token_dim, num_heads=num_heads 772 | ) 773 | init.normal_(self.embed, mean=0, std=0.5) 774 | 775 | def forward(self, inputs): 776 | N = inputs.size(0) 777 | query = inputs.unsqueeze(1) 778 | keys = ( 779 | torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) 780 | ) # [N, token_num, token_embedding_size // num_heads] 781 | style_embed = self.attention(query, keys) 782 | return style_embed 783 | 784 | 785 | class MultiHeadAttention(nn.Module): 786 | """ 787 | Multi head attention 788 | https://github.com/KinglittleQ/GST-Tacotron 789 | 790 | """ 791 | 792 | def 
__init__(self, query_dim, key_dim, num_units, num_heads): 793 | super().__init__() 794 | self.num_units = num_units 795 | self.num_heads = num_heads 796 | self.key_dim = key_dim 797 | 798 | self.W_query = nn.Linear( 799 | in_features=query_dim, out_features=num_units, bias=False 800 | ) 801 | self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) 802 | self.W_value = nn.Linear( 803 | in_features=key_dim, out_features=num_units, bias=False 804 | ) 805 | 806 | def forward(self, query, key): 807 | """ 808 | Forward 809 | 810 | Args: 811 | query: (batch, T_q, query_dim) 812 | key: (batch, T_k, key_dim) 813 | 814 | Return: 815 | out: (N, T_q, num_units) 816 | """ 817 | querys = self.W_query(query) # [N, T_q, num_units] 818 | 819 | keys = self.W_key(key) # [N, T_k, num_units] 820 | values = self.W_value(key) 821 | 822 | split_size = self.num_units // self.num_heads 823 | querys = torch.stack( 824 | torch.split(querys, split_size, dim=2), dim=0 825 | ) # [h, N, T_q, num_units/h] 826 | keys = torch.stack( 827 | torch.split(keys, split_size, dim=2), dim=0 828 | ) # [h, N, T_k, num_units/h] 829 | values = torch.stack( 830 | torch.split(values, split_size, dim=2), dim=0 831 | ) # [h, N, T_k, num_units/h] 832 | 833 | # score = softmax(QK^T / (d_k ** 0.5)) 834 | scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] 835 | scores = scores / (self.key_dim ** 0.5) 836 | scores = F.softmax(scores, dim=3) 837 | 838 | # out = score * V 839 | out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] 840 | out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze( 841 | 0 842 | ) # [N, T_q, num_units] 843 | 844 | return out 845 | 846 | 847 | class GSTModule(nn.Module): 848 | def __init__(self, config): 849 | super().__init__() 850 | self.encoder_post = ReferenceEncoder( 851 | idim=config["preprocess"]["n_mels"], 852 | ref_dim=256, 853 | ) 854 | self.stl = STL(ref_dim=256, num_heads=8, token_num=10, token_dim=128) 855 | 856 | def forward(self, inputs): 857 | acoustic_embed = self.encoder_post(inputs) 858 | style_embed = self.stl(acoustic_embed) 859 | return style_embed.squeeze(1) 860 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import librosa 4 | import tqdm 5 | import pickle 6 | import random 7 | import argparse 8 | import yaml 9 | import pathlib 10 | 11 | 12 | def get_arg(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--config_path", required=True, type=pathlib.Path) 15 | parser.add_argument("--corpus_type", default=None, type=str) 16 | parser.add_argument("--source_path", default=None, type=pathlib.Path) 17 | parser.add_argument("--source_path_task", default=None, type=pathlib.Path) 18 | parser.add_argument("--aux_path", default=None, type=pathlib.Path) 19 | parser.add_argument("--preprocessed_path", default=None, type=pathlib.Path) 20 | parser.add_argument("--n_train", default=None, type=int) 21 | parser.add_argument("--n_val", default=None, type=int) 22 | parser.add_argument("--n_test", default=None, type=int) 23 | return parser.parse_args() 24 | 25 | 26 | def preprocess(config): 27 | 28 | # configs 29 | preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"]) 30 | n_train = config["preprocess"]["n_train"] 31 | n_val = config["preprocess"]["n_val"] 32 | n_test = config["preprocess"]["n_test"] 33 | SR = config["preprocess"]["sampling_rate"] 34 | 35 | 
os.makedirs(preprocessed_dir, exist_ok=True) 36 | 37 | sourcepath = pathlib.Path(config["general"]["source_path"]).resolve() 38 | 39 | if config["general"]["corpus_type"] == "single": 40 | fulllist = list(sourcepath.glob("*.wav")) 41 | random.seed(0) 42 | random.shuffle(fulllist) 43 | train_filelist = fulllist[:n_train] 44 | val_filelist = fulllist[n_train : n_train + n_val] 45 | test_filelist = fulllist[n_train + n_val : n_train + n_val + n_test] 46 | filelist = train_filelist + val_filelist + test_filelist 47 | elif config["general"]["corpus_type"] == "multi-seen": 48 | fulllist = list(sourcepath.glob("*/*.wav")) 49 | random.seed(0) 50 | random.shuffle(fulllist) 51 | train_filelist = fulllist[:n_train] 52 | val_filelist = fulllist[n_train : n_train + n_val] 53 | test_filelist = fulllist[n_train + n_val : n_train + n_val + n_test] 54 | filelist = train_filelist + val_filelist + test_filelist 55 | elif config["general"]["corpus_type"] == "multi-unseen": 56 | spk_list = list(set([x.parent for x in sourcepath.glob("*/*.wav")])) 57 | train_filelist = [] 58 | val_filelist = [] 59 | test_filelist = [] 60 | random.seed(0) 61 | random.shuffle(spk_list) 62 | for i, spk in enumerate(spk_list): 63 | sourcespkpath = sourcepath / spk 64 | if i < n_train: 65 | train_filelist.extend(list(sourcespkpath.glob("*.wav"))) 66 | elif i < n_train + n_val: 67 | val_filelist.extend(list(sourcespkpath.glob("*.wav"))) 68 | elif i < n_train + n_val + n_test: 69 | test_filelist.extend(list(sourcespkpath.glob("*.wav"))) 70 | filelist = train_filelist + val_filelist + test_filelist 71 | else: 72 | raise NotImplementedError( 73 | "corpus_type specified in config.yaml should be {single, multi-seen, multi-unseen}" 74 | ) 75 | 76 | with open(preprocessed_dir / "train.txt", "w", encoding="utf-8") as f: 77 | for m in train_filelist: 78 | f.write(str(m) + "\n") 79 | with open(preprocessed_dir / "val.txt", "w", encoding="utf-8") as f: 80 | for m in val_filelist: 81 | f.write(str(m) + "\n") 82 | with open(preprocessed_dir / "test.txt", "w", encoding="utf-8") as f: 83 | for m in test_filelist: 84 | f.write(str(m) + "\n") 85 | 86 | for wp in tqdm.tqdm(filelist): 87 | 88 | if config["general"]["corpus_type"] == "single": 89 | basename = str(wp.stem) 90 | else: 91 | basename = str(wp.parent.name) + "-" + str(wp.stem) 92 | 93 | wav, _ = librosa.load(wp, sr=SR) 94 | wavsegs = [] 95 | 96 | if config["general"]["aux_path"] != None: 97 | auxpath = pathlib.Path(config["general"]["aux_path"]) 98 | if config["general"]["corpus_type"] == "single": 99 | wav_aux, _ = librosa.load(auxpath / wp.name, sr=SR) 100 | else: 101 | wav_aux, _ = librosa.load(auxpath / wp.parent.name / wp.name, sr=SR) 102 | wavauxsegs = [] 103 | 104 | if config["general"]["aux_path"] == None: 105 | wavsegs.append(wav) 106 | else: 107 | min_seq_len = min(len(wav), len(wav_aux)) 108 | wav = wav[:min_seq_len] 109 | wav_aux = wav_aux[:min_seq_len] 110 | wavsegs.append(wav) 111 | wavauxsegs.append(wav_aux) 112 | 113 | wavsegs = np.asarray(wavsegs).astype(np.float32) 114 | if config["general"]["aux_path"] != None: 115 | wavauxsegs = np.asarray(wavauxsegs).astype(np.float32) 116 | else: 117 | wavauxsegs = None 118 | 119 | d_preprocessed = {"wavs": wavsegs, "wavsaux": wavauxsegs} 120 | 121 | with open(preprocessed_dir / "{}.pickle".format(basename), "wb") as fw: 122 | pickle.dump(d_preprocessed, fw) 123 | 124 | 125 | if __name__ == "__main__": 126 | args = get_arg() 127 | 128 | config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader) 129 | for key in 
["corpus_type", "source_path", "aux_path", "preprocessed_path"]: 130 | if getattr(args, key) != None: 131 | config["general"][key] = str(getattr(args, key)) 132 | for key in ["n_train", "n_val", "n_test"]: 133 | if getattr(args, key) != None: 134 | config["preprocess"][key] = getattr(args, key) 135 | 136 | print("Performing preprocessing ...") 137 | preprocess(config) 138 | 139 | if "dual" in config: 140 | if config["dual"]["enable"]: 141 | task_config = yaml.load( 142 | open(config["dual"]["config_path"], "r"), Loader=yaml.FullLoader 143 | ) 144 | task_preprocessed_dir = ( 145 | pathlib.Path(config["general"]["preprocessed_path"]).parent 146 | / pathlib.Path(task_config["general"]["preprocessed_path"]).name 147 | ) 148 | task_config["general"]["preprocessed_path"] = task_preprocessed_dir 149 | if args.source_path_task != None: 150 | task_config["general"]["source_path"] = args.source_path_task 151 | print("Performing preprocessing for dual learning ...") 152 | preprocess(task_config) 153 | -------------------------------------------------------------------------------- /pretrained_models.md: -------------------------------------------------------------------------------- 1 | # Pretrained models 2 | 3 | ### Pretrained HiFi-GAN with SourceFilter features 4 | 5 | HiFi-GAN-based synthethis modules to synthesize waveform from source-filter vocoder features trained on JVS or VCTK. 6 | Scripts for training are available in [another repo](https://github.com/Takaaki-Saeki/hifi-gan/tree/voc_feat). 7 | `hifigan_jvs_40d_600k` is used in the default configuration. 8 | 9 | |Name|Feature|Dataset|Iteration|Link| 10 | |------|---|---|---|---| 11 | |hifigan_jvs_40d_600k|40-D Melcep. + F0 (WORLD)|JVS|600K|[Download](https://drive.google.com/file/d/1lkvtAJ3xTny5qmxyVcNPWQRB9MjdPlAY/view?usp=sharing)| 12 | |hifigan_jvs_40d_1000k|40-D Melcep. + F0 (WORLD)|JVS|1000K|[Download](https://drive.google.com/file/d/1ZJbhWeAgs0RhoZ41puIKRFuioKyw0g8q/view?usp=sharing)| 13 | |hifigan_vctk_40d_600k|40-D Melcep. + F0 (WORLD)|VCTK|600K|[Download](https://drive.google.com/file/d/1SnzZNt25eOCrrcMzF9KUjZ2kqVKE_cWf/view?usp=sharing)| 14 | |hifigan_vctk-jvs_40d_400k|40-D Melcep. + F0 (WORLD)|JVS+VCTK|400K|[Download](https://drive.google.com/file/d/1I4HlZZZIXleKy7YtJP55MKpckwyvjZKm/view?usp=sharing)| 15 | |hifigan_vctk-jvs_60d_400k|60-D Melcep. + F0 (WORLD)|JVS+VCTK|400K|[Download](https://drive.google.com/file/d/1kBJoTGgVSpGRkuEccZyJYiTLIim9kc48/view?usp=sharing)| 16 | 17 | ### SSL pretarined models for speech restoration 18 | 19 | Speech restoration models trained on simulated data. 
20 | 21 | |Name|Dataset|Distortion|Feature|Link| 22 | |------|---|---|---|---| 23 | |jsut-bandlimited_melspec.ckpt|JSUT Basic5000|Bandlimited|MelSpec|[Download](https://drive.google.com/file/d/1KwCEZ7pmfP__MjlE1sINnKs0Njw311m0/view?usp=sharing)| 24 | |jsut-bandlimited_vocfeats.ckpt|JSUT Basic5000|Bandlimited|SourceFilter|[Download](https://drive.google.com/file/d/1MB3FqxAHbDOWICib5tlEQ4DFM2e_7oJD/view?usp=sharing)| 25 | |jsut-clip_melspec.ckpt|JSUT Basic5000|Clipping|MelSpec|[Download](https://drive.google.com/file/d/19IkXv3rOwOeJ6TFNRp-x-cM4UWzNt6Ud/view?usp=sharing)| 26 | |jsut-clip_vocfeats.ckpt|JSUT Basic5000|Clipping|SourceFilter|[Download](https://drive.google.com/file/d/1_xfJqwJR-WhMSPZaTYE9xNqABQyFiu9m/view?usp=sharing)| 27 | |jsut-qr_melspec.ckpt|JSUT Basic5000|Quantized & Resampled|MelSpec|[Download](https://drive.google.com/file/d/1hn_q_hPROZlo_l89b0S2yPY2zUZk-RJf/view?usp=sharing)| 28 | |jsut-qr_vocfeats.ckpt|JSUT Basic5000|Quantized & Resampled|SourceFilter|[Download](https://drive.google.com/file/d/1_AdhP1KwdOKK_w6dZigiZ3yVe2vkCwyc/view?usp=sharing)| 29 | |jsut-overdrive_melspec.ckpt|JSUT Basic5000|Overdrive|MelSpec|[Download](https://drive.google.com/file/d/1I1Rhz8GwaUROPX8NOyqBKrAeDOSkaJTA/view?usp=sharing)| 30 | |jsut-overdrive_vocfeats.ckpt|JSUT Basic5000|Overdrive|SourceFilter|[Download](https://drive.google.com/file/d/1G_YjC8UZTTdDL93vCSQHu0lSiIF4_fmM/view?usp=sharing)| 31 | 32 | ### Supervised pretrained models 33 | 34 | Supervised pretrained models for applying our method in low-resource settings. 35 | There are two types of analysis module: `Normal` and `GST`. 36 | `Normal` extracts restored speech features and channel features simultaneously in the analysis module. 37 | `GST` extracts channel features using a separate GST encoder. 38 | We use the `Normal` method in our paper because our preliminary experiments showed it to give slightly higher quality. 39 | 40 | |Name|Analysis module type|Feature|Dataset|Link| 41 | |------|---|---|---|---| 42 | |pretrain_melspec_normal.ckpt|Normal|MelSpec|JVS|[Download](https://drive.google.com/file/d/11bqYcyF0OqKogr4pDr7qeS7QysTbeOWd/view?usp=sharing)| 43 | |pretrain_melspec_gst.ckpt|GST|MelSpec|JVS|[Download](https://drive.google.com/file/d/1vX9cTUnBFxjMfx_IP_RDtzEzi0_7z8ks/view?usp=sharing)| 44 | |pretrain_vocfeats_normal.ckpt|Normal|SourceFilter|JVS|[Download](https://drive.google.com/file/d/1d2Nh9bbMEAW6gfy8PP0g3mFJImj8Tes9/view?usp=sharing)| 45 | |pretrain_vocfeats_gst.ckpt|GST|SourceFilter|JVS|[Download](https://drive.google.com/file/d/1Qehs1sU0GSPX5VWqJs5tFaoxtzfPy11j/view?usp=sharing)| 46 | 47 | ### SSL pretrained models for audio effect transfer 48 | 49 | The following model was trained on the real data described in the paper and is intended for audio effect transfer. 50 | This makes it possible to apply effects to arbitrary speech data so that it sounds like an old recording. 51 | Note that the following model uses `MelSpec` features.
52 | 53 | |Name|Distortion|Link| 54 | |------|---|---| 55 | |tono.ckpt|Tono no mukashibanashi|[Download](https://drive.google.com/file/d/1xJzUNqwwf145YuSFQRZ4KjwxGtcL7rol/view?usp=sharing)| -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.20.3 2 | torch==1.10.0 3 | torchaudio==0.10.0 4 | librosa==0.8.1 5 | pyworld==0.3.0 6 | pysptk==0.1.18 7 | matplotlib==3.4.3 8 | PyYAML==5.4.1 9 | SoundFile==0.10.3.post1 10 | pytorch-lightning==1.5.9 11 | tqdm==4.62.1 12 | pypesq==1.2.4 13 | gradio==2.8.12 -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | function download_gdrive () { 2 | FILE_ID=$1 3 | FILE_NAME=$2 4 | curl -sc /tmp/cookie "https://drive.google.com/uc?export=download&id=${FILE_ID}" > /dev/null 5 | CODE="$(awk '/_warning_/ {print $NF}' /tmp/cookie)" 6 | curl -Lb /tmp/cookie "https://drive.google.com/uc?export=download&confirm=${CODE}&id=${FILE_ID}" -o ${FILE_NAME} 7 | } 8 | 9 | echo "Installing packages ..." 10 | pip install -r requirements.txt 11 | 12 | echo "Downloading pretrained model for audio effect transfer ..." 13 | curl -OL https://sarulab.sakura.ne.jp/saeki/selfremaster/pretrained/tono_aet_melspec.ckpt 14 | mv tono_aet_melspec.ckpt aet_sample/ 15 | 16 | echo "Downloading pretrained HiFi-GAN for MelSpec ..." 17 | download_gdrive 10OJ2iznutxzp8MEIS6lBVaIS_g5c_70V hifigan_melspec_universal 18 | mv hifigan_melspec_universal hifigan/ 19 | 20 | if [ -n "$1" ]; then 21 | exit 0 22 | fi 23 | 24 | echo "Downloading pretrained HiFi-GAN for SourceFilter ..." 25 | curl -OL https://sarulab.sakura.ne.jp/saeki/selfremaster/pretrained/hifigan_jvs_40d_600k 26 | mv hifigan_jvs_40d_600k hifigan/ 27 | 28 | mkdir -p data 29 | 30 | echo "Done!" 
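[Editorial note] The setup script above places `hifigan_melspec_universal` (and optionally `hifigan_jvs_40d_600k`) under `hifigan/`, and these are the checkpoints listed in `pretrained_models.md`. As a rough orientation, here is a minimal, hypothetical sketch of loading such a checkpoint with `load_vocoder()` defined in `utils.py` later in this listing. `load_vocoder` only reads `general.feature_type` and `general.hifigan_path` from the config, so a small hand-built dict suffices here; the mel dimensionality and the input layout passed to the vocoder are assumptions, not documented usage.

```python
# Hypothetical sketch (not part of the repository): loading a HiFi-GAN
# checkpoint downloaded by setup.sh through load_vocoder() from utils.py.
import torch
from utils import load_vocoder

config = {
    "general": {
        "feature_type": "melspec",                            # selects hifigan/config_melspec.json
        "hifigan_path": "hifigan/hifigan_melspec_universal",  # placed there by setup.sh
    }
}
vocoder = load_vocoder(config).eval()  # weight norm removed, gradients already disabled inside

with torch.no_grad():
    mel = torch.randn(1, 80, 200)      # assumed (batch, n_mels, frames) mel input
    wav = vocoder(mel)                 # waveform, (batch, 1, samples) with the standard HiFi-GAN generator
```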
-------------------------------------------------------------------------------- /simulated_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | import os 4 | import tqdm 5 | import soundfile as sf 6 | import torch 7 | import torchaudio 8 | import numpy as np 9 | 10 | 11 | def get_arg(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--in_dir", required=True, type=pathlib.Path) 14 | parser.add_argument("--out_dir", required=True, type=pathlib.Path) 15 | parser.add_argument( 16 | "--corpus_type", required=True, type=str, choices=["single", "multi"] 17 | ) 18 | parser.add_argument( 19 | "--deg_type", 20 | required=True, 21 | type=str, 22 | choices=["lowpass", "clipping", "mulaw", "overdrive"], 23 | ) 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def lowpass(args): 29 | in_dir = args.in_dir 30 | out_dir = args.out_dir 31 | 32 | if args.corpus_type == "single": 33 | data_type = "single" 34 | else: 35 | data_type = "multi" 36 | 37 | os.makedirs(out_dir, exist_ok=True) 38 | 39 | if data_type == "single": 40 | wavslist = list(in_dir.glob("*.wav")) 41 | elif data_type == "multi": 42 | wavslist = list(in_dir.glob("*/*.wav")) 43 | else: 44 | raise NotImplementedError() 45 | 46 | for wp in tqdm.tqdm(wavslist): 47 | wav, sr = torchaudio.load(wp) 48 | if data_type == "multi": 49 | os.makedirs(out_dir / wp.parent.name, exist_ok=True) 50 | wav_processed = torchaudio.functional.lowpass_biquad( 51 | wav, sample_rate=sr, cutoff_freq=1000, Q=1.0 52 | ) 53 | wav_out = wav_norm(wav_processed, sr) 54 | wav_out = wav_out.squeeze(0).numpy() 55 | if data_type == "single": 56 | sf.write(out_dir / wp.name, wav_out, sr) 57 | else: 58 | sf.write(out_dir / wp.parent.name / wp.name, wav_out, sr) 59 | 60 | 61 | def clipping(args): 62 | in_dir = args.in_dir 63 | out_dir = args.out_dir 64 | 65 | if args.corpus_type == "single": 66 | data_type = "single" 67 | else: 68 | data_type = "multi" 69 | 70 | os.makedirs(out_dir, exist_ok=True) 71 | 72 | if data_type == "single": 73 | wavslist = list(in_dir.glob("*.wav")) 74 | elif data_type == "multi": 75 | wavslist = list(in_dir.glob("*/*.wav")) 76 | else: 77 | raise NotImplementedError() 78 | 79 | eta = 0.25 80 | 81 | for wp in tqdm.tqdm(wavslist): 82 | wav, sr = sf.read(wp) 83 | if data_type == "multi": 84 | os.makedirs(out_dir / wp.parent.name, exist_ok=True) 85 | amp = eta * np.max(wav) 86 | wav_processed = np.maximum(np.minimum(wav, amp), -amp) 87 | wav_processed = torch.from_numpy(wav_processed.astype(np.float32)).unsqueeze(0) 88 | wav_out = wav_norm(wav_processed, sr) 89 | wav_out = wav_out.squeeze(0).numpy() 90 | if data_type == "single": 91 | sf.write(out_dir / wp.name, wav_out, sr) 92 | else: 93 | sf.write(out_dir / wp.parent.name / wp.name, wav_out, sr) 94 | 95 | 96 | def mulaw(args): 97 | in_dir = args.in_dir 98 | out_dir = args.out_dir 99 | 100 | if args.corpus_type == "single": 101 | data_type = "single" 102 | else: 103 | data_type = "multi" 104 | 105 | os.makedirs(out_dir, exist_ok=True) 106 | 107 | if data_type == "single": 108 | wavslist = list(in_dir.glob("*.wav")) 109 | elif data_type == "multi": 110 | wavslist = list(in_dir.glob("*/*.wav")) 111 | else: 112 | raise NotImplementedError() 113 | 114 | for wp in tqdm.tqdm(wavslist): 115 | wav, sr = sf.read(wp) 116 | if data_type == "multi": 117 | os.makedirs(out_dir / wp.parent.name, exist_ok=True) 118 | wav /= torch.max(torch.abs(torch.from_numpy(wav))) 119 | new_freq = 8000 120 | new_quantization = 
128 121 | mulaw_encoder = torchaudio.transforms.MuLawEncoding( 122 | quantization_channels=new_quantization 123 | ) 124 | wav_quantized = mulaw_encoder(wav) / new_quantization * 2.0 - 1.0 125 | downsampler = torchaudio.transforms.Resample( 126 | orig_freq=sr, 127 | new_freq=new_freq, 128 | resampling_method="sinc_interpolation", 129 | lowpass_filter_width=6, 130 | dtype=torch.float32, 131 | ) 132 | upsampler = torchaudio.transforms.Resample( 133 | orig_freq=new_freq, 134 | new_freq=sr, 135 | resampling_method="sinc_interpolation", 136 | lowpass_filter_width=6, 137 | dtype=torch.float32, 138 | ) 139 | wav_processed = upsampler(downsampler(wav_quantized)) 140 | wav_out = wav_norm(wav_processed, sr) 141 | wav_out = wav_out.squeeze(0).numpy() 142 | if data_type == "single": 143 | sf.write(out_dir / wp.name, wav_out, sr) 144 | else: 145 | sf.write(out_dir / wp.parent.name / wp.name, wav_out, sr) 146 | 147 | 148 | def overdrive(args): 149 | in_dir = args.in_dir 150 | out_dir = args.out_dir 151 | 152 | if args.corpus_type == "single": 153 | data_type = "single" 154 | else: 155 | data_type = "multi" 156 | 157 | os.makedirs(out_dir, exist_ok=True) 158 | 159 | if data_type == "single": 160 | wavslist = list(in_dir.glob("*.wav")) 161 | elif data_type == "multi": 162 | wavslist = list(in_dir.glob("*/*.wav")) 163 | else: 164 | raise NotImplementedError() 165 | 166 | for wp in tqdm.tqdm(wavslist): 167 | wav, sr = sf.read(wp) 168 | if data_type == "multi": 169 | os.makedirs(out_dir / wp.parent.name, exist_ok=True) 170 | wav_processed = torchaudio.functional.overdrive( 171 | torch.from_numpy(wav.astype(np.float32)).unsqueeze(0), gain=40, colour=20 172 | ) 173 | wav_out = wav_norm(wav_processed, sr) 174 | wav_out = wav_out.squeeze(0).numpy() 175 | if data_type == "single": 176 | sf.write(out_dir / wp.name, wav_out, sr) 177 | else: 178 | sf.write(out_dir / wp.parent.name / wp.name, wav_out, sr) 179 | 180 | 181 | def wav_norm(wav_processed, sr): 182 | wav_out, _ = torchaudio.sox_effects.apply_effects_tensor( 183 | wav_processed, 184 | sr, 185 | [["norm", "{}".format(-3)]], 186 | ) 187 | return wav_out 188 | 189 | 190 | if __name__ == "__main__": 191 | args = get_arg() 192 | if args.deg_type == "lowpass": 193 | lowpass(args) 194 | elif args.deg_type == "clipping": 195 | clipping(args) 196 | elif args.deg_type == "mulaw": 197 | mulaw(args) 198 | elif args.deg_type == "overdrive": 199 | overdrive(args) 200 | else: 201 | raise NotImplementedError() 202 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pathlib 4 | import yaml 5 | from dataset import DataModule 6 | from pytorch_lightning import Trainer 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pytorch_lightning.loggers.csv_logs import CSVLogger 9 | from pytorch_lightning.loggers import TensorBoardLogger 10 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 11 | from lightning_module import ( 12 | PretrainLightningModule, 13 | SSLStepLightningModule, 14 | SSLDualLightningModule, 15 | ) 16 | from utils import configure_args 17 | 18 | 19 | def get_arg(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--config_path", required=True, type=pathlib.Path) 22 | parser.add_argument( 23 | "--stage", required=True, type=str, choices=["pretrain", "ssl-step", "ssl-dual"] 24 | ) 25 | parser.add_argument("--run_name", required=True, type=str) 26 | 
parser.add_argument("--corpus_type", default=None, type=str) 27 | parser.add_argument("--source_path", default=None, type=pathlib.Path) 28 | parser.add_argument("--aux_path", default=None, type=pathlib.Path) 29 | parser.add_argument("--preprocessed_path", default=None, type=pathlib.Path) 30 | parser.add_argument("--n_train", default=None, type=int) 31 | parser.add_argument("--n_val", default=None, type=int) 32 | parser.add_argument("--n_test", default=None, type=int) 33 | parser.add_argument("--epoch", default=None, type=int) 34 | parser.add_argument("--load_pretrained", action="store_true") 35 | parser.add_argument("--pretrained_path", default=None, type=pathlib.Path) 36 | parser.add_argument("--early_stopping", action="store_true") 37 | parser.add_argument("--alpha", default=None, type=float) 38 | parser.add_argument("--beta", default=None, type=float) 39 | parser.add_argument("--learning_rate", default=None, type=float) 40 | parser.add_argument( 41 | "--feature_loss_type", default=None, type=str, choices=["mae", "mse"] 42 | ) 43 | parser.add_argument("--debug", action="store_true") 44 | return parser.parse_args() 45 | 46 | 47 | def train(args, config, output_path): 48 | debug = args.debug 49 | 50 | csvlogger = CSVLogger(save_dir=str(output_path), name="train_log") 51 | tblogger = TensorBoardLogger(save_dir=str(output_path), name="tf_log") 52 | 53 | checkpoint_callback = ModelCheckpoint( 54 | dirpath=str(output_path), 55 | save_weights_only=True, 56 | save_top_k=-1, 57 | every_n_epochs=1, 58 | monitor="val_loss", 59 | ) 60 | callbacks = [checkpoint_callback] 61 | if config["train"]["early_stopping"]: 62 | earlystop_callback = EarlyStopping( 63 | monitor="val_loss", min_delta=0.0, patience=15, mode="min" 64 | ) 65 | callbacks.append(earlystop_callback) 66 | 67 | trainer = Trainer( 68 | max_epochs=1 if debug else config["train"]["epoch"], 69 | gpus=-1, 70 | deterministic=False, 71 | auto_select_gpus=True, 72 | benchmark=True, 73 | default_root_dir=os.getcwd(), 74 | limit_train_batches=0.01 if debug else 1.0, 75 | limit_val_batches=0.5 if debug else 1.0, 76 | callbacks=callbacks, 77 | logger=[csvlogger, tblogger], 78 | gradient_clip_val=config["train"]["grad_clip_thresh"], 79 | flush_logs_every_n_steps=config["train"]["logger_step"], 80 | val_check_interval=0.5, 81 | ) 82 | 83 | if config["general"]["stage"] == "pretrain": 84 | model = PretrainLightningModule(config) 85 | elif config["general"]["stage"] == "ssl-step": 86 | model = SSLStepLightningModule(config) 87 | elif config["general"]["stage"] == "ssl-dual": 88 | model = SSLDualLightningModule(config) 89 | else: 90 | raise NotImplementedError() 91 | 92 | datamodule = DataModule(config) 93 | trainer.fit(model, datamodule=datamodule) 94 | 95 | 96 | if __name__ == "__main__": 97 | 98 | args = get_arg() 99 | config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader) 100 | 101 | output_path = pathlib.Path(config["general"]["output_path"]) / args.run_name 102 | os.makedirs(output_path, exist_ok=True) 103 | 104 | config, args = configure_args(config, args) 105 | 106 | train(args=args, config=config, output_path=output_path) 107 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import librosa.display 2 | import matplotlib.pyplot as plt 3 | import json 4 | import torch 5 | import torchaudio 6 | import hifigan 7 | 8 | 9 | def manual_logging(logger, item, idx, tag, global_step, data_type, config): 10 | 11 | 
if data_type == "audio": 12 | audio = item[idx, ...].detach().cpu().numpy() 13 | logger.add_audio( 14 | tag, 15 | audio, 16 | global_step, 17 | sample_rate=config["preprocess"]["sampling_rate"], 18 | ) 19 | elif data_type == "image": 20 | image = item[idx, ...].detach().cpu().numpy() 21 | fig, ax = plt.subplots() 22 | _ = librosa.display.specshow( 23 | image, 24 | x_axis="time", 25 | y_axis="linear", 26 | sr=config["preprocess"]["sampling_rate"], 27 | hop_length=config["preprocess"]["frame_shift"], 28 | fmax=config["preprocess"]["sampling_rate"] // 2, 29 | ax=ax, 30 | ) 31 | logger.add_figure(tag, fig, global_step) 32 | else: 33 | raise NotImplementedError( 34 | "Data type given to logger should be [audio] or [image]" 35 | ) 36 | 37 | 38 | def load_vocoder(config): 39 | with open( 40 | "hifigan/config_{}.json".format(config["general"]["feature_type"]), "r" 41 | ) as f: 42 | config_hifigan = hifigan.AttrDict(json.load(f)) 43 | vocoder = hifigan.Generator(config_hifigan) 44 | vocoder.load_state_dict(torch.load(config["general"]["hifigan_path"])["generator"]) 45 | vocoder.remove_weight_norm() 46 | for param in vocoder.parameters(): 47 | param.requires_grad = False 48 | return vocoder 49 | 50 | 51 | def get_conv_padding(kernel_size, dilation=1): 52 | return int((kernel_size * dilation - dilation) / 2) 53 | 54 | 55 | def plot_and_save_mels(wav, save_path, config): 56 | spec_module = torchaudio.transforms.MelSpectrogram( 57 | sample_rate=config["preprocess"]["sampling_rate"], 58 | n_fft=config["preprocess"]["fft_length"], 59 | win_length=config["preprocess"]["frame_length"], 60 | hop_length=config["preprocess"]["frame_shift"], 61 | f_min=config["preprocess"]["fmin"], 62 | f_max=config["preprocess"]["fmax"], 63 | n_mels=config["preprocess"]["n_mels"], 64 | power=1, 65 | center=True, 66 | norm="slaney", 67 | mel_scale="slaney", 68 | ) 69 | spec = spec_module(wav.unsqueeze(0)) 70 | log_spec = torch.log( 71 | torch.clamp_min(spec, config["preprocess"]["min_magnitude"]) 72 | * config["preprocess"]["comp_factor"] 73 | ) 74 | fig, ax = plt.subplots() 75 | _ = librosa.display.specshow( 76 | log_spec.squeeze(0).numpy(), 77 | x_axis="time", 78 | y_axis="linear", 79 | sr=config["preprocess"]["sampling_rate"], 80 | hop_length=config["preprocess"]["frame_shift"], 81 | fmax=config["preprocess"]["sampling_rate"] // 2, 82 | ax=ax, 83 | cmap="viridis", 84 | ) 85 | fig.savefig(save_path, bbox_inches="tight", pad_inches=0) 86 | 87 | 88 | def plot_and_save_mels_all(wavs, keys, save_path, config): 89 | spec_module = torchaudio.transforms.MelSpectrogram( 90 | sample_rate=config["preprocess"]["sampling_rate"], 91 | n_fft=config["preprocess"]["fft_length"], 92 | win_length=config["preprocess"]["frame_length"], 93 | hop_length=config["preprocess"]["frame_shift"], 94 | f_min=config["preprocess"]["fmin"], 95 | f_max=config["preprocess"]["fmax"], 96 | n_mels=config["preprocess"]["n_mels"], 97 | power=1, 98 | center=True, 99 | norm="slaney", 100 | mel_scale="slaney", 101 | ) 102 | fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(18, 18)) 103 | for i, key in enumerate(keys): 104 | wav = wavs[key][0, ...].cpu() 105 | spec = spec_module(wav.unsqueeze(0)) 106 | log_spec = torch.log( 107 | torch.clamp_min(spec, config["preprocess"]["min_magnitude"]) 108 | * config["preprocess"]["comp_factor"] 109 | ) 110 | ax[i // 3, i % 3].set(title=key) 111 | _ = librosa.display.specshow( 112 | log_spec.squeeze(0).numpy(), 113 | x_axis="time", 114 | y_axis="linear", 115 | sr=config["preprocess"]["sampling_rate"], 116 | 
hop_length=config["preprocess"]["frame_shift"], 117 | fmax=config["preprocess"]["sampling_rate"] // 2, 118 | ax=ax[i // 3, i % 3], 119 | cmap="viridis", 120 | ) 121 | fig.savefig(save_path, bbox_inches="tight", pad_inches=0) 122 | 123 | 124 | def configure_args(config, args): 125 | for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]: 126 | if getattr(args, key) != None: 127 | config["general"][key] = str(getattr(args, key)) 128 | 129 | for key in ["n_train", "n_val", "n_test"]: 130 | if getattr(args, key) != None: 131 | config["preprocess"][key] = getattr(args, key) 132 | 133 | for key in ["alpha", "beta", "learning_rate", "epoch"]: 134 | if getattr(args, key) != None: 135 | config["train"][key] = getattr(args, key) 136 | 137 | for key in ["load_pretrained", "early_stopping"]: 138 | config["train"][key] = getattr(args, key) 139 | 140 | if args.feature_loss_type != None: 141 | config["train"]["feature_loss"]["type"] = args.feature_loss_type 142 | 143 | for key in ["pretrained_path"]: 144 | if getattr(args, key) != None: 145 | config["train"][key] = str(getattr(args, key)) 146 | 147 | return config, args 148 | --------------------------------------------------------------------------------
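[Editorial note] For reference, the `GST` analysis-module variant mentioned in `pretrained_models.md` corresponds to the `GSTModule` class in `model.py` above (reference encoder followed by style-token attention). The following is a minimal, hypothetical sketch of how it maps a batch of mel-spectrograms to fixed-size channel (style) embeddings; the batch size, frame count, and `n_mels` value are assumptions.

```python
# Hypothetical sketch (not part of the repository): channel/style embedding
# extraction with the GST-based analysis path from model.py.
import torch
from model import GSTModule

config = {"preprocess": {"n_mels": 80}}  # assumed n_mels; real runs read this from the YAML configs
gst = GSTModule(config)

mel = torch.randn(4, 200, 80)            # assumed (batch, frames, n_mels) mel-spectrogram batch
channel_embed = gst(mel)                 # (batch, 128): one global style-token embedding per utterance
```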