├── LICENSE ├── README.md ├── convert.py ├── convert_batch.py ├── data ├── __init__.py ├── intra_speaker_dataset.py ├── preprocess_dataset.py └── utils.py ├── docs ├── favicons │ ├── android-icon-144x144.png │ ├── android-icon-192x192.png │ ├── android-icon-36x36.png │ ├── android-icon-48x48.png │ ├── android-icon-72x72.png │ ├── android-icon-96x96.png │ ├── apple-icon-114x114.png │ ├── apple-icon-120x120.png │ ├── apple-icon-144x144.png │ ├── apple-icon-152x152.png │ ├── apple-icon-180x180.png │ ├── apple-icon-57x57.png │ ├── apple-icon-60x60.png │ ├── apple-icon-72x72.png │ ├── apple-icon-76x76.png │ ├── apple-icon-precomposed.png │ ├── apple-icon.png │ ├── browserconfig.xml │ ├── favicon-16x16.png │ ├── favicon-32x32.png │ ├── favicon-96x96.png │ ├── favicon.ico │ ├── manifest.json │ ├── ms-icon-144x144.png │ ├── ms-icon-150x150.png │ ├── ms-icon-310x310.png │ └── ms-icon-70x70.png ├── imgs │ ├── model_arch.png │ ├── p225_p227 │ │ ├── s2s-f2m-p225_002-p227_001_010_025_052_090.attn.png │ │ ├── s2s-f2m-p225_002-p227_002.attn.png │ │ ├── s2s-f2m-p225_020-p227_001_010_027_030_038.attn.png │ │ └── s2s-f2m-p225_020-p227_020.attn.png │ ├── p227_p225 │ │ ├── s2s-m2f-p227_002-p225_001_010_025_052_090.attn.png │ │ ├── s2s-m2f-p227_002-p225_002.attn.png │ │ ├── s2s-m2f-p227_020-p225_001_010_027_030_038.attn.png │ │ └── s2s-m2f-p227_020-p225_020.attn.png │ ├── p228_p232 │ │ ├── s2s-f2m-p228_002-p232_001_010_025_052_090.attn.png │ │ ├── s2s-f2m-p228_002-p232_002.attn.png │ │ ├── s2s-f2m-p228_020-p232_001_010_027_030_038.attn.png │ │ └── s2s-f2m-p228_020-p232_020.attn.png │ ├── p232_p228 │ │ ├── s2s-m2f-p232_002-p228_001_010_025_052_090.attn.png │ │ ├── s2s-m2f-p232_002-p228_002.attn.png │ │ ├── s2s-m2f-p232_020-p228_001_010_027_030_038.attn.png │ │ └── s2s-m2f-p232_020-p228_020.attn.png │ └── smoother_extractor.png ├── index.html ├── style.css └── wavs │ ├── AdaIN_u2u-f2f-slt_a0541-lnh.wav │ ├── AdaIN_u2u-f2m-clb_b0503-rms.wav │ ├── AdaIN_u2u-m2f-bdl_a0283-ljm.wav │ ├── AdaIN_u2u-m2m-rms_a0296-bdl.wav │ ├── AutoVC_u2u-f2f-slt_a0541-lnh.wav │ ├── AutoVC_u2u-f2m-clb_b0503-rms.wav │ ├── AutoVC_u2u-m2f-bdl_a0283-ljm.wav │ ├── AutoVC_u2u-m2m-rms_a0296-bdl.wav │ ├── bdl_a0283.wav │ ├── bdl_a0296.wav │ ├── clb_b0503.wav │ ├── ljm_a0283.wav │ ├── lnh_a0541.wav │ ├── p225_002.wav │ ├── p225_020.wav │ ├── p227_002.wav │ ├── p227_020.wav │ ├── p228_002.wav │ ├── p228_020.wav │ ├── p232_002.wav │ ├── p232_020.wav │ ├── rms_a0296.wav │ ├── rms_b0503.wav │ ├── s2s-f2m-p225_002-p227_001_010_025_052_090.wav │ ├── s2s-f2m-p225_002-p227_002.wav │ ├── s2s-f2m-p225_020-p227_001_010_027_030_038.wav │ ├── s2s-f2m-p225_020-p227_020.wav │ ├── s2s-f2m-p228_002-p232_001_010_025_052_090.wav │ ├── s2s-f2m-p228_002-p232_002.wav │ ├── s2s-f2m-p228_020-p232_001_010_027_030_038.wav │ ├── s2s-f2m-p228_020-p232_020.wav │ ├── s2s-m2f-p227_002-p225_001_010_025_052_090.wav │ ├── s2s-m2f-p227_002-p225_002.wav │ ├── s2s-m2f-p227_020-p225_001_010_027_030_038.wav │ ├── s2s-m2f-p227_020-p225_020.wav │ ├── s2s-m2f-p232_002-p228_001_010_025_052_090.wav │ ├── s2s-m2f-p232_002-p228_002.wav │ ├── s2s-m2f-p232_020-p228_001_010_027_030_038.wav │ ├── s2s-m2f-p232_020-p228_020.wav │ ├── slt_a0541.wav │ ├── u2u-f2f-slt_a0541-lnh.wav │ ├── u2u-f2m-clb_b0503-rms.wav │ ├── u2u-m2f-bdl_a0283-ljm.wav │ └── u2u-m2m-rms_a0296-bdl.wav ├── models ├── __init__.py ├── convolutional_transformer.py ├── model.py └── utils.py ├── preprocess.py ├── requirements.txt └── train.py /LICENSE: 
--------------------------------------------------------------------------------
1 | Copyright 2020 Yist Lin
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FragmentVC
2 | 
3 | Here is the official implementation of the paper, [FragmentVC: Any-to-Any Voice Conversion by End-to-End Extracting and Fusing Fine-Grained Voice Fragments With Attention](https://arxiv.org/abs/2010.14150).
4 | In this paper we proposed FragmentVC, in which the latent phonetic structure of the utterance from the source speaker is obtained from Wav2Vec 2.0, while the spectral features of the utterance(s) from the target speaker are obtained from log mel-spectrograms.
5 | By aligning the hidden structures of the two different feature spaces with a two-stage training process, FragmentVC is able to extract fine-grained voice fragments from the target speaker utterance(s) and fuse them into the desired utterance, all based on the attention mechanism of Transformer as verified with analysis on attention maps, and is accomplished end-to-end.
6 | 
7 | The following are the overall model architecture and the conceptual illustration.
8 | 
9 | ![Model architecture](docs/imgs/model_arch.png)
10 | 
11 | And the architectures of the smoother and extractor blocks.
12 | 
13 | ![Smoother and extractor blocks](docs/imgs/smoother_extractor.png)
14 | 
15 | For the audio samples and attention map analyses, please refer to our [demo page](https://yistlin.github.io/FragmentVC/).
16 | 
17 | ## Usage
18 | 
19 | You can download the pretrained model as well as the vocoder following the link under the **Releases** section on the sidebar.
20 | 
21 | The whole project was developed with Python 3.8 and torch 1.6, and the pretrained model as well as the vocoder were converted to [TorchScript](https://pytorch.org/docs/stable/jit.html), so backward compatibility is not guaranteed.
22 | You can install the dependencies with
23 | 
24 | ```bash
25 | pip install -r requirements.txt
26 | ```
27 | 
28 | If you encounter any problems while installing *fairseq*, please refer to [pytorch/fairseq](https://github.com/pytorch/fairseq) for the installation instructions.
29 | 
30 | ### Wav2Vec
31 | 
32 | In our implementation, we use Wav2Vec 2.0 Base without finetuning, which was trained on LibriSpeech.
33 | You can download the checkpoint [wav2vec_small.pt](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) from [pytorch/fairseq](https://github.com/pytorch/fairseq).
34 | 
35 | ### Vocoder
36 | 
37 | The WaveRNN-based neural vocoder is from [yistLin/universal-vocoder](https://github.com/yistLin/universal-vocoder), which is based on the paper [Towards achieving robust universal neural vocoding](https://arxiv.org/abs/1811.06292).
38 | 
39 | ## Voice conversion with pretrained models
40 | 
41 | You can convert an utterance from the source speaker using multiple utterances from the target speaker, e.g.
42 | ```bash
43 | python convert.py \
44 |     -w <wav2vec_path> \
45 |     -v <vocoder_path> \
46 |     -c <ckpt_path> \
47 |     VCTK-Corpus/wav48/p225/p225_001.wav \ # source utterance
48 |     VCTK-Corpus/wav48/p227/p227_002.wav \ # target utterance 1/3
49 |     VCTK-Corpus/wav48/p227/p227_003.wav \ # target utterance 2/3
50 |     VCTK-Corpus/wav48/p227/p227_004.wav \ # target utterance 3/3
51 |     output.wav
52 | ```
53 | 
54 | Or you can prepare a conversion-pairs information file in YAML format, like
55 | ```YAML
56 | # pairs_info.yaml
57 | pair1:
58 |     source: VCTK-Corpus/wav48/p225/p225_001.wav
59 |     target:
60 |         - VCTK-Corpus/wav48/p227/p227_001.wav
61 | pair2:
62 |     source: VCTK-Corpus/wav48/p225/p225_001.wav
63 |     target:
64 |         - VCTK-Corpus/wav48/p227/p227_002.wav
65 |         - VCTK-Corpus/wav48/p227/p227_003.wav
66 |         - VCTK-Corpus/wav48/p227/p227_004.wav
67 | ```
68 | 
69 | And convert multiple pairs at the same time, e.g.
70 | ```bash
71 | python convert_batch.py \
72 |     -w <wav2vec_path> \
73 |     -v <vocoder_path> \
74 |     -c <ckpt_path> \
75 |     pairs_info.yaml \
76 |     outputs # the output directory of conversion results
77 | ```
78 | 
79 | After the conversion, the output directory, `outputs`, will contain
80 | ```text
81 | pair1.wav
82 | pair1.mel.png
83 | pair1.attn.png
84 | pair2.wav
85 | pair2.mel.png
86 | pair2.attn.png
87 | ```
88 | where `*.wav` are the converted utterances, `*.mel.png` are the plotted mel-spectrograms of the former, and `*.attn.png` are the attention maps between *Conv1d 1* and *Extractor 3* (please refer to the model architecture above).
89 | 
90 | ## Train from scratch
91 | 
92 | Empirically, if you train the model on the CSTR VCTK Corpus, it takes about 1 hour to preprocess the data and around 12 hours to train to 200K steps (on an RTX 2080 Ti).
93 | 
94 | ### Preprocessing
95 | 
96 | You can preprocess multiple corpora by passing multiple paths.
97 | But each path should be the directory that directly contains the speaker directories,
98 | e.g.
99 | ```bash
100 | python preprocess.py \
101 |     VCTK-Corpus/wav48 \
102 |     LibriTTS/train-clean-360 \
103 |     <other corpus directories> \
104 |     features # the output directory of preprocessed features
105 | ```
106 | 
107 | After preprocessing, the output directory will contain:
108 | ```text
109 | metadata.json
110 | utterance-000x7gsj.tar
111 | utterance-00wq7b0f.tar
112 | utterance-01lpqlnr.tar
113 | ...
114 | ```
115 | 
116 | ### Training
117 | 
118 | ```bash
119 | python train.py features --save_dir ./ckpts
120 | ```
121 | 
122 | You can further specify `--preload` to preload all training data into RAM and speed up training.
123 | If `--comment <comment>` is specified, e.g. `--comment vctk`, the training logs will be placed under a newly created directory like `logs/2020-02-02_12:34:56_vctk`; otherwise there won't be any logging.
124 | For more details, refer to the usage shown by `python train.py -h`.
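As a quick reference for how the pieces above fit together (source utterance → Wav2Vec 2.0 features, target utterances → concatenated log mel-spectrogram, FragmentVC, then the vocoder), here is a minimal sketch of calling the conversion pipeline from Python rather than through `convert.py`. It mirrors `convert.py`; the checkpoint paths are the defaults used by `convert_batch.py`, the `source.wav` / `target_*.wav` filenames are placeholders, and it uses librosa-based trimming (`trim=True`) instead of the sox VAD trimming in `convert.py` to keep the example short.

```python
# Minimal programmatic conversion sketch (illustrative; mirrors convert.py).
# Assumptions: the checkpoint paths below match the downloaded files, and
# source.wav / target_1.wav / target_2.wav are placeholder file names.
import numpy as np
import soundfile as sf
import torch

from data import load_wav, log_mel_spectrogram
from models import load_pretrained_wav2vec

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

wav2vec = load_pretrained_wav2vec("checkpoints/wav2vec_small.pt").to(device)
model = torch.jit.load("checkpoints/fragmentvc.pt").to(device).eval()
vocoder = torch.jit.load("checkpoints/vocoder.pt").to(device).eval()

# Source utterance -> Wav2Vec 2.0 features.
src_wav = load_wav("source.wav", 16000, trim=True)
src_wav = torch.FloatTensor(src_wav).unsqueeze(0).to(device)

# Target utterances -> one concatenated log mel-spectrogram.
# Positional arguments follow data.utils.log_mel_spectrogram:
# (wav, preemph, sample_rate, n_mels, n_fft, hop_length, win_length, f_min)
tgt_mels = [
    log_mel_spectrogram(load_wav(path, 16000, trim=True), 0.97, 16000, 80, 1304, 326, 1304, 80)
    for path in ["target_1.wav", "target_2.wav"]
]
tgt_mel = np.concatenate(tgt_mels, axis=0)
tgt_mel = torch.FloatTensor(tgt_mel.T).unsqueeze(0).to(device)

with torch.no_grad():
    src_feat = wav2vec.extract_features(src_wav, None)[0]
    out_mel, attns = model(src_feat, tgt_mel)
    out_mel = out_mel.transpose(1, 2).squeeze(0)
    out_wav = vocoder.generate([out_mel])[0]

sf.write("output.wav", out_wav.cpu().numpy(), 16000)
```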
125 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Convert using one source utterance and multiple target utterances.""" 3 | 4 | import warnings 5 | from datetime import datetime 6 | from pathlib import Path 7 | from copy import deepcopy 8 | 9 | import torch 10 | import numpy as np 11 | import soundfile as sf 12 | from jsonargparse import ArgumentParser, ActionConfigFile 13 | 14 | import sox 15 | 16 | from data import load_wav, log_mel_spectrogram, plot_mel, plot_attn 17 | from models import load_pretrained_wav2vec 18 | 19 | 20 | def parse_args(): 21 | """Parse command-line arguments.""" 22 | parser = ArgumentParser() 23 | parser.add_argument("source_path", type=str) 24 | parser.add_argument("target_paths", type=str, nargs="+") 25 | parser.add_argument("-w", "--wav2vec_path", type=str, required=True) 26 | parser.add_argument("-c", "--ckpt_path", type=str, required=True) 27 | parser.add_argument("-v", "--vocoder_path", type=str, required=True) 28 | parser.add_argument("-o", "--output_path", type=str, default="output.wav") 29 | 30 | parser.add_argument("--sample_rate", type=int, default=16000) 31 | parser.add_argument("--preemph", type=float, default=0.97) 32 | parser.add_argument("--hop_len", type=int, default=326) 33 | parser.add_argument("--win_len", type=int, default=1304) 34 | parser.add_argument("--n_fft", type=int, default=1304) 35 | parser.add_argument("--n_mels", type=int, default=80) 36 | parser.add_argument("--f_min", type=int, default=80) 37 | parser.add_argument("--audio_config", action=ActionConfigFile) 38 | 39 | return vars(parser.parse_args()) 40 | 41 | 42 | def main( 43 | source_path, 44 | target_paths, 45 | wav2vec_path, 46 | ckpt_path, 47 | vocoder_path, 48 | output_path, 49 | sample_rate, 50 | preemph, 51 | hop_len, 52 | win_len, 53 | n_fft, 54 | n_mels, 55 | f_min, 56 | **kwargs, 57 | ): 58 | """Main function.""" 59 | 60 | begin_time = step_moment = datetime.now() 61 | 62 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 63 | 64 | wav2vec = load_pretrained_wav2vec(wav2vec_path).to(device) 65 | print("[INFO] Wav2Vec is loaded from", wav2vec_path) 66 | 67 | model = torch.jit.load(ckpt_path).to(device).eval() 68 | print("[INFO] FragmentVC is loaded from", ckpt_path) 69 | 70 | vocoder = torch.jit.load(vocoder_path).to(device).eval() 71 | print("[INFO] Vocoder is loaded from", vocoder_path) 72 | 73 | elaspe_time = datetime.now() - step_moment 74 | step_moment = datetime.now() 75 | print("[INFO] elasped time", elaspe_time.total_seconds()) 76 | 77 | tfm = sox.Transformer() 78 | tfm.vad(location=1) 79 | tfm.vad(location=-1) 80 | 81 | src_wav = load_wav(source_path, sample_rate) 82 | src_wav = deepcopy(tfm.build_array(input_array=src_wav, sample_rate_in=sample_rate)) 83 | src_wav = torch.FloatTensor(src_wav).unsqueeze(0).to(device) 84 | print("[INFO] source waveform shape:", src_wav.shape) 85 | 86 | tgt_mels = [] 87 | for tgt_path in target_paths: 88 | tgt_wav = load_wav(tgt_path, sample_rate) 89 | tgt_wav = tfm.build_array(input_array=tgt_wav, sample_rate_in=sample_rate) 90 | tgt_wav = deepcopy(tgt_wav) 91 | tgt_mel = log_mel_spectrogram( 92 | tgt_wav, preemph, sample_rate, n_mels, n_fft, hop_len, win_len, f_min 93 | ) 94 | tgt_mels.append(tgt_mel) 95 | 96 | tgt_mel = np.concatenate(tgt_mels, axis=0) 97 | tgt_mel = torch.FloatTensor(tgt_mel.T).unsqueeze(0).to(device) 98 | print("[INFO] target 
spectrograms shape:", tgt_mel.shape) 99 | 100 | with torch.no_grad(): 101 | src_feat = wav2vec.extract_features(src_wav, None)[0] 102 | print("[INFO] source Wav2Vec feature shape:", src_feat.shape) 103 | 104 | elaspe_time = datetime.now() - step_moment 105 | step_moment = datetime.now() 106 | print("[INFO] elasped time", elaspe_time.total_seconds()) 107 | 108 | out_mel, attns = model(src_feat, tgt_mel) 109 | out_mel = out_mel.transpose(1, 2).squeeze(0) 110 | print("[INFO] converted spectrogram shape:", out_mel.shape) 111 | 112 | elaspe_time = datetime.now() - step_moment 113 | step_moment = datetime.now() 114 | print("[INFO] elasped time", elaspe_time.total_seconds()) 115 | 116 | out_wav = vocoder.generate([out_mel])[0] 117 | out_wav = out_wav.cpu().numpy() 118 | print("[INFO] generated waveform shape:", out_wav.shape) 119 | 120 | elaspe_time = datetime.now() - step_moment 121 | step_moment = datetime.now() 122 | print("[INFO] elasped time", elaspe_time.total_seconds()) 123 | 124 | wav_path = Path(output_path) 125 | sf.write(wav_path, out_wav, sample_rate) 126 | print("[INFO] generated waveform is saved to", wav_path) 127 | 128 | mel_path = wav_path.with_suffix(".mel.png") 129 | plot_mel(out_mel, filename=mel_path) 130 | print("[INFO] mel-spectrogram plot is saved to", mel_path) 131 | 132 | attn_path = wav_path.with_suffix(".attn.png") 133 | plot_attn(attns, filename=attn_path) 134 | print("[INFO] attention plot is saved to", attn_path) 135 | 136 | elaspe_time = datetime.now() - begin_time 137 | print("[INFO] Overall elasped time", elaspe_time.total_seconds()) 138 | 139 | 140 | if __name__ == "__main__": 141 | warnings.filterwarnings("ignore") 142 | main(**parse_args()) 143 | -------------------------------------------------------------------------------- /convert_batch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Convert multiple pairs.""" 3 | 4 | import warnings 5 | from pathlib import Path 6 | from functools import partial 7 | from multiprocessing import Pool, cpu_count 8 | 9 | import yaml 10 | import torch 11 | import numpy as np 12 | import soundfile as sf 13 | from jsonargparse import ArgumentParser, ActionConfigFile 14 | 15 | from data import load_wav, log_mel_spectrogram, plot_mel, plot_attn 16 | from models import load_pretrained_wav2vec 17 | 18 | 19 | def parse_args(): 20 | """Parse command-line arguments.""" 21 | parser = ArgumentParser() 22 | parser.add_argument("info_path", type=str) 23 | parser.add_argument("output_dir", type=str, default=".") 24 | parser.add_argument("-c", "--ckpt_path", default="checkpoints/fragmentvc.pt") 25 | parser.add_argument("-w", "--wav2vec_path", default="checkpoints/wav2vec_small.pt") 26 | parser.add_argument("-v", "--vocoder_path", default="checkpoints/vocoder.pt") 27 | 28 | parser.add_argument("--sample_rate", type=int, default=16000) 29 | parser.add_argument("--preemph", type=float, default=0.97) 30 | parser.add_argument("--hop_len", type=int, default=326) 31 | parser.add_argument("--win_len", type=int, default=1304) 32 | parser.add_argument("--n_fft", type=int, default=1304) 33 | parser.add_argument("--n_mels", type=int, default=80) 34 | parser.add_argument("--f_min", type=int, default=80) 35 | parser.add_argument("--audio_config", action=ActionConfigFile) 36 | 37 | return vars(parser.parse_args()) 38 | 39 | 40 | def main( 41 | info_path, 42 | output_dir, 43 | ckpt_path, 44 | wav2vec_path, 45 | vocoder_path, 46 | sample_rate, 47 | preemph, 48 | hop_len, 49 | win_len, 50 | 
n_fft, 51 | n_mels, 52 | f_min, 53 | **kwargs, 54 | ): 55 | """Main function.""" 56 | 57 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 58 | 59 | wav2vec = load_pretrained_wav2vec(wav2vec_path).to(device) 60 | print("[INFO] Wav2Vec is loaded from", wav2vec_path) 61 | 62 | model = torch.jit.load(ckpt_path).to(device).eval() 63 | print("[INFO] FragmentVC is loaded from", ckpt_path) 64 | 65 | vocoder = torch.jit.load(vocoder_path).to(device).eval() 66 | print("[INFO] Vocoder is loaded from", vocoder_path) 67 | 68 | path2wav = partial(load_wav, sample_rate=sample_rate) 69 | wav2mel = partial( 70 | log_mel_spectrogram, 71 | preemph=preemph, 72 | sample_rate=sample_rate, 73 | n_mels=n_mels, 74 | n_fft=n_fft, 75 | hop_length=hop_len, 76 | win_length=win_len, 77 | f_min=f_min, 78 | ) 79 | 80 | with open(info_path) as f: 81 | infos = yaml.load(f, Loader=yaml.FullLoader) 82 | 83 | out_mels = [] 84 | attns = [] 85 | 86 | for pair_name, pair in infos.items(): 87 | src_wav = load_wav(pair["source"], sample_rate, trim=True) 88 | src_wav = torch.FloatTensor(src_wav).unsqueeze(0).to(device) 89 | 90 | with Pool(cpu_count()) as pool: 91 | tgt_wavs = pool.map(path2wav, pair["target"]) 92 | tgt_mels = pool.map(wav2mel, tgt_wavs) 93 | 94 | tgt_mel = np.concatenate(tgt_mels, axis=0) 95 | tgt_mel = torch.FloatTensor(tgt_mel.T).unsqueeze(0).to(device) 96 | 97 | with torch.no_grad(): 98 | src_feat = wav2vec.extract_features(src_wav, None)[0] 99 | 100 | out_mel, attn = model(src_feat, tgt_mel) 101 | out_mel = out_mel.transpose(1, 2).squeeze(0) 102 | 103 | out_mels.append(out_mel) 104 | attns.append(attn) 105 | 106 | print(f"[INFO] Pair {pair_name} converted") 107 | 108 | print("[INFO] Generating waveforms...") 109 | 110 | with torch.no_grad(): 111 | out_wavs = vocoder.generate(out_mels) 112 | 113 | print("[INFO] Waveforms generated") 114 | 115 | out_dir = Path(output_dir) 116 | out_dir.mkdir(parents=True, exist_ok=True) 117 | 118 | for pair_name, out_mel, out_wav, attn in zip( 119 | infos.keys(), out_mels, out_wavs, attns 120 | ): 121 | out_wav = out_wav.cpu().numpy() 122 | out_path = Path(out_dir, pair_name) 123 | 124 | plot_mel(out_mel, filename=out_path.with_suffix(".mel.png")) 125 | plot_attn(attn, filename=out_path.with_suffix(".attn.png")) 126 | sf.write(out_path.with_suffix(".wav"), out_wav, sample_rate) 127 | 128 | 129 | if __name__ == "__main__": 130 | warnings.filterwarnings("ignore") 131 | main(**parse_args()) 132 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocess_dataset import PreprocessDataset 2 | from .intra_speaker_dataset import IntraSpeakerDataset, collate_batch 3 | from .utils import * 4 | -------------------------------------------------------------------------------- /data/intra_speaker_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset for reconstruction scheme.""" 2 | 3 | import json 4 | import random 5 | from pathlib import Path 6 | from concurrent.futures import ThreadPoolExecutor 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from torch.utils.data import Dataset 11 | from torch.nn.utils.rnn import pad_sequence 12 | 13 | 14 | class IntraSpeakerDataset(Dataset): 15 | """Dataset for reconstruction scheme. 16 | 17 | Returns: 18 | speaker_id: speaker id number. 19 | feat: Wav2Vec feature tensor. 20 | mel: log mel spectrogram tensor. 
21 | """ 22 | 23 | def __init__(self, data_dir, metadata_path, n_samples=5, pre_load=False): 24 | with open(metadata_path, "r") as f: 25 | metadata = json.load(f) 26 | 27 | executor = ThreadPoolExecutor(max_workers=4) 28 | futures = [] 29 | 30 | for speaker_name, utterances in metadata.items(): 31 | for utterance in utterances: 32 | futures.append( 33 | executor.submit( 34 | _process_data, 35 | speaker_name, 36 | data_dir, 37 | utterance["feature_path"], 38 | pre_load, 39 | ) 40 | ) 41 | 42 | self.data = [] 43 | self.speaker_to_indices = {} 44 | for i, future in enumerate(tqdm(futures, ncols=0)): 45 | result = future.result() 46 | speaker_name = result[0] 47 | self.data.append(result) 48 | if speaker_name not in self.speaker_to_indices: 49 | self.speaker_to_indices[speaker_name] = [i] 50 | else: 51 | self.speaker_to_indices[speaker_name].append(i) 52 | 53 | self.data_dir = Path(data_dir) 54 | self.n_samples = n_samples 55 | self.pre_load = pre_load 56 | 57 | def __len__(self): 58 | return len(self.data) 59 | 60 | def _get_data(self, index): 61 | if self.pre_load: 62 | speaker_name, content_emb, target_mel = self.data[index] 63 | else: 64 | speaker_name, content_emb, target_mel = _load_data(*self.data[index]) 65 | return speaker_name, content_emb, target_mel 66 | 67 | def __getitem__(self, index): 68 | speaker_name, content_emb, target_mel = self._get_data(index) 69 | utterance_indices = self.speaker_to_indices[speaker_name].copy() 70 | utterance_indices.remove(index) 71 | 72 | sampled_mels = [] 73 | for sampled_id in random.sample(utterance_indices, self.n_samples): 74 | sampled_mel = self._get_data(sampled_id)[2] 75 | sampled_mels.append(sampled_mel) 76 | 77 | reference_mels = torch.cat(sampled_mels, dim=0) 78 | 79 | return content_emb, reference_mels, target_mel 80 | 81 | 82 | def _process_data(speaker_name, data_dir, feature_path, load): 83 | if load: 84 | return _load_data(speaker_name, data_dir, feature_path) 85 | else: 86 | return speaker_name, data_dir, feature_path 87 | 88 | 89 | def _load_data(speaker_name, data_dir, feature_path): 90 | feature = torch.load(Path(data_dir, feature_path)) 91 | content_emb = feature["feat"] 92 | target_mel = feature["mel"] 93 | 94 | return speaker_name, content_emb, target_mel 95 | 96 | 97 | def collate_batch(batch): 98 | """Collate a batch of data.""" 99 | srcs, refs, tgts = zip(*batch) 100 | 101 | src_lens = [len(src) for src in srcs] 102 | ref_lens = [len(ref) for ref in refs] 103 | tgt_lens = [len(tgt) for tgt in tgts] 104 | overlap_lens = [ 105 | min(src_len, tgt_len) for src_len, tgt_len in zip(src_lens, tgt_lens) 106 | ] 107 | 108 | srcs = pad_sequence(srcs, batch_first=True) # (batch, max_src_len, wav2vec_dim) 109 | 110 | src_masks = [torch.arange(srcs.size(1)) >= src_len for src_len in src_lens] 111 | src_masks = torch.stack(src_masks) # (batch, max_src_len) 112 | 113 | refs = pad_sequence(refs, batch_first=True, padding_value=-20) 114 | refs = refs.transpose(1, 2) # (batch, mel_dim, max_ref_len) 115 | 116 | ref_masks = [torch.arange(refs.size(2)) >= ref_len for ref_len in ref_lens] 117 | ref_masks = torch.stack(ref_masks) # (batch, max_ref_len) 118 | 119 | tgts = pad_sequence(tgts, batch_first=True, padding_value=-20) 120 | tgts = tgts.transpose(1, 2) # (batch, mel_dim, max_tgt_len) 121 | 122 | tgt_masks = [torch.arange(tgts.size(2)) >= tgt_len for tgt_len in tgt_lens] 123 | tgt_masks = torch.stack(tgt_masks) # (batch, max_tgt_len) 124 | 125 | return srcs, src_masks, refs, ref_masks, tgts, tgt_masks, overlap_lens 126 | 
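A minimal sketch of wiring `IntraSpeakerDataset` and `collate_batch` into a PyTorch `DataLoader`, assuming a `features` directory containing the `metadata.json` produced by `preprocess.py`; the batch size here is an arbitrary example value.

```python
# Illustrative sketch: feeding IntraSpeakerDataset through a DataLoader with
# collate_batch. The paths ("features", "features/metadata.json") and
# batch_size are assumptions; point them at the output of preprocess.py.
from torch.utils.data import DataLoader

from data import IntraSpeakerDataset, collate_batch

dataset = IntraSpeakerDataset("features", "features/metadata.json", n_samples=5)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_batch)

srcs, src_masks, refs, ref_masks, tgts, tgt_masks, overlap_lens = next(iter(loader))
# srcs:         (batch, max_src_len, wav2vec_dim) padded Wav2Vec features
# refs, tgts:   (batch, mel_dim, max_len) padded log mel-spectrograms
# *_masks:      boolean masks, True at padded positions
# overlap_lens: min(src_len, tgt_len) for each example in the batch
```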
-------------------------------------------------------------------------------- /data/preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | """Precompute Wav2Vec features and spectrograms.""" 2 | 3 | from copy import deepcopy 4 | from pathlib import Path 5 | 6 | import torch 7 | from librosa.util import find_files 8 | 9 | import sox 10 | 11 | from .utils import load_wav, log_mel_spectrogram 12 | 13 | 14 | class PreprocessDataset(torch.utils.data.Dataset): 15 | """Prefetch audio data for preprocessing.""" 16 | 17 | def __init__( 18 | self, 19 | data_dirs, 20 | trim_method, 21 | sample_rate, 22 | preemph, 23 | hop_len, 24 | win_len, 25 | n_fft, 26 | n_mels, 27 | f_min, 28 | ): 29 | 30 | data = [] 31 | 32 | for data_dir in data_dirs: 33 | data_dir_path = Path(data_dir) 34 | speaker_dirs = [x for x in data_dir_path.iterdir() if x.is_dir()] 35 | 36 | for speaker_dir in speaker_dirs: 37 | audio_paths = find_files(speaker_dir) 38 | if len(audio_paths) == 0: 39 | continue 40 | 41 | speaker_name = speaker_dir.name 42 | for audio_path in audio_paths: 43 | data.append((speaker_name, audio_path)) 44 | 45 | self.trim_method = trim_method 46 | self.sample_rate = sample_rate 47 | self.preemph = preemph 48 | self.hop_len = hop_len 49 | self.win_len = win_len 50 | self.n_fft = n_fft 51 | self.n_mels = n_mels 52 | self.f_min = f_min 53 | self.data = data 54 | 55 | if trim_method == "vad": 56 | tfm = sox.Transformer() 57 | tfm.vad(location=1) 58 | tfm.vad(location=-1) 59 | self.sox_transform = tfm 60 | 61 | def __len__(self): 62 | return len(self.data) 63 | 64 | def __getitem__(self, index): 65 | speaker_name, audio_path = self.data[index] 66 | 67 | if self.trim_method == "librosa": 68 | wav = load_wav(audio_path, self.sample_rate, trim=True) 69 | elif self.trim_method == "vad": 70 | wav = load_wav(audio_path, self.sample_rate) 71 | trim_wav = self.sox_transform.build_array( 72 | input_array=wav, sample_rate_in=self.sample_rate 73 | ) 74 | wav = deepcopy(trim_wav if len(trim_wav) > 10 else wav) 75 | 76 | mel = log_mel_spectrogram( 77 | wav, 78 | self.preemph, 79 | self.sample_rate, 80 | self.n_mels, 81 | self.n_fft, 82 | self.hop_len, 83 | self.win_len, 84 | self.f_min, 85 | ) 86 | return speaker_name, audio_path, torch.FloatTensor(wav), torch.FloatTensor(mel) 87 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for data manipulation.""" 2 | 3 | from typing import Union 4 | from pathlib import Path 5 | 6 | import librosa 7 | import numpy as np 8 | import matplotlib 9 | from matplotlib import pyplot as plt 10 | from scipy.signal import lfilter 11 | 12 | matplotlib.use("Agg") 13 | 14 | 15 | def load_wav( 16 | audio_path: Union[str, Path], sample_rate: int, trim: bool = False 17 | ) -> np.ndarray: 18 | """Load and preprocess waveform.""" 19 | wav = librosa.load(audio_path, sr=sample_rate)[0] 20 | wav = wav / (np.abs(wav).max() + 1e-6) 21 | if trim: 22 | _, (start_frame, end_frame) = librosa.effects.trim( 23 | wav, top_db=25, frame_length=512, hop_length=128 24 | ) 25 | start_frame = max(0, start_frame - 0.1 * sample_rate) 26 | end_frame = min(len(wav), end_frame + 0.1 * sample_rate) 27 | 28 | start = int(start_frame) 29 | end = int(end_frame) 30 | if end - start > 1000: # prevent empty slice 31 | wav = wav[start:end] 32 | 33 | return wav 34 | 35 | 36 | def log_mel_spectrogram( 37 | x: np.ndarray, 38 | preemph: 
float, 39 | sample_rate: int, 40 | n_mels: int, 41 | n_fft: int, 42 | hop_length: int, 43 | win_length: int, 44 | f_min: int, 45 | ) -> np.ndarray: 46 | """Create a log Mel spectrogram from a raw audio signal.""" 47 | x = lfilter([1, -preemph], [1], x) 48 | magnitude = np.abs( 49 | librosa.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length) 50 | ) 51 | mel_fb = librosa.filters.mel(sample_rate, n_fft, n_mels=n_mels, fmin=f_min) 52 | mel_spec = np.dot(mel_fb, magnitude) 53 | log_mel_spec = np.log(mel_spec + 1e-9) 54 | return log_mel_spec.T 55 | 56 | 57 | def plot_mel(gt_mel, predicted_mel=None, filename="mel.png"): 58 | if predicted_mel is not None: 59 | fig, axes = plt.subplots(2, 1, squeeze=False, figsize=(10, 10)) 60 | else: 61 | fig, axes = plt.subplots(1, 1, squeeze=False, figsize=(10, 10)) 62 | 63 | axes[0][0].imshow(gt_mel.detach().cpu().numpy().T, origin="lower") 64 | axes[0][0].set_aspect(1, adjustable="box") 65 | axes[0][0].set_ylim(1.0, 80) 66 | axes[0][0].set_title("ground-truth mel-spectrogram", fontsize="medium") 67 | axes[0][0].tick_params(labelsize="x-small", left=False, labelleft=False) 68 | 69 | if predicted_mel is not None: 70 | axes[1][0].imshow(predicted_mel.detach().cpu().numpy(), origin="lower") 71 | axes[1][0].set_aspect(1.0, adjustable="box") 72 | axes[1][0].set_ylim(0, 80) 73 | axes[1][0].set_title("predicted mel-spectrogram", fontsize="medium") 74 | axes[1][0].tick_params(labelsize="x-small", left=False, labelleft=False) 75 | 76 | plt.tight_layout() 77 | plt.savefig(filename) 78 | plt.close() 79 | 80 | 81 | def plot_attn(attn, filename="attn.png"): 82 | fig, axes = plt.subplots(len(attn), 1, squeeze=False, figsize=(10, 10)) 83 | 84 | for i, layer_attn in enumerate(attn): 85 | axes[i][0].imshow(attn[i][0].detach().cpu().numpy(), origin="lower") 86 | axes[i][0].set_title("layer {}".format(i), fontsize="medium") 87 | axes[i][0].tick_params(labelsize="x-small") 88 | axes[i][0].set_xlabel("target") 89 | axes[i][0].set_ylabel("source") 90 | 91 | plt.tight_layout() 92 | plt.savefig(filename) 93 | plt.close() 94 | -------------------------------------------------------------------------------- /docs/favicons/android-icon-144x144.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/android-icon-144x144.png -------------------------------------------------------------------------------- /docs/favicons/android-icon-192x192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/android-icon-192x192.png -------------------------------------------------------------------------------- /docs/favicons/android-icon-36x36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/android-icon-36x36.png -------------------------------------------------------------------------------- /docs/favicons/android-icon-48x48.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/android-icon-48x48.png -------------------------------------------------------------------------------- /docs/favicons/android-icon-72x72.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/android-icon-72x72.png -------------------------------------------------------------------------------- /docs/favicons/android-icon-96x96.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/android-icon-96x96.png -------------------------------------------------------------------------------- /docs/favicons/apple-icon-114x114.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon-114x114.png -------------------------------------------------------------------------------- /docs/favicons/apple-icon-120x120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon-120x120.png -------------------------------------------------------------------------------- /docs/favicons/apple-icon-144x144.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon-144x144.png -------------------------------------------------------------------------------- /docs/favicons/apple-icon-152x152.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon-152x152.png -------------------------------------------------------------------------------- /docs/favicons/apple-icon-180x180.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon-180x180.png -------------------------------------------------------------------------------- /docs/favicons/apple-icon-57x57.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon-57x57.png -------------------------------------------------------------------------------- /docs/favicons/apple-icon-60x60.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon-60x60.png -------------------------------------------------------------------------------- /docs/favicons/apple-icon-72x72.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon-72x72.png -------------------------------------------------------------------------------- /docs/favicons/apple-icon-76x76.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon-76x76.png 
-------------------------------------------------------------------------------- /docs/favicons/apple-icon-precomposed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon-precomposed.png -------------------------------------------------------------------------------- /docs/favicons/apple-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/apple-icon.png -------------------------------------------------------------------------------- /docs/favicons/browserconfig.xml: -------------------------------------------------------------------------------- 1 | 2 | #ffffff -------------------------------------------------------------------------------- /docs/favicons/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/favicon-16x16.png -------------------------------------------------------------------------------- /docs/favicons/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/favicon-32x32.png -------------------------------------------------------------------------------- /docs/favicons/favicon-96x96.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/favicon-96x96.png -------------------------------------------------------------------------------- /docs/favicons/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/favicon.ico -------------------------------------------------------------------------------- /docs/favicons/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "App", 3 | "icons": [ 4 | { 5 | "src": "\/android-icon-36x36.png", 6 | "sizes": "36x36", 7 | "type": "image\/png", 8 | "density": "0.75" 9 | }, 10 | { 11 | "src": "\/android-icon-48x48.png", 12 | "sizes": "48x48", 13 | "type": "image\/png", 14 | "density": "1.0" 15 | }, 16 | { 17 | "src": "\/android-icon-72x72.png", 18 | "sizes": "72x72", 19 | "type": "image\/png", 20 | "density": "1.5" 21 | }, 22 | { 23 | "src": "\/android-icon-96x96.png", 24 | "sizes": "96x96", 25 | "type": "image\/png", 26 | "density": "2.0" 27 | }, 28 | { 29 | "src": "\/android-icon-144x144.png", 30 | "sizes": "144x144", 31 | "type": "image\/png", 32 | "density": "3.0" 33 | }, 34 | { 35 | "src": "\/android-icon-192x192.png", 36 | "sizes": "192x192", 37 | "type": "image\/png", 38 | "density": "4.0" 39 | } 40 | ] 41 | } -------------------------------------------------------------------------------- /docs/favicons/ms-icon-144x144.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/ms-icon-144x144.png 
-------------------------------------------------------------------------------- /docs/favicons/ms-icon-150x150.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/ms-icon-150x150.png -------------------------------------------------------------------------------- /docs/favicons/ms-icon-310x310.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/ms-icon-310x310.png -------------------------------------------------------------------------------- /docs/favicons/ms-icon-70x70.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/favicons/ms-icon-70x70.png -------------------------------------------------------------------------------- /docs/imgs/model_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/model_arch.png -------------------------------------------------------------------------------- /docs/imgs/p225_p227/s2s-f2m-p225_002-p227_001_010_025_052_090.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p225_p227/s2s-f2m-p225_002-p227_001_010_025_052_090.attn.png -------------------------------------------------------------------------------- /docs/imgs/p225_p227/s2s-f2m-p225_002-p227_002.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p225_p227/s2s-f2m-p225_002-p227_002.attn.png -------------------------------------------------------------------------------- /docs/imgs/p225_p227/s2s-f2m-p225_020-p227_001_010_027_030_038.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p225_p227/s2s-f2m-p225_020-p227_001_010_027_030_038.attn.png -------------------------------------------------------------------------------- /docs/imgs/p225_p227/s2s-f2m-p225_020-p227_020.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p225_p227/s2s-f2m-p225_020-p227_020.attn.png -------------------------------------------------------------------------------- /docs/imgs/p227_p225/s2s-m2f-p227_002-p225_001_010_025_052_090.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p227_p225/s2s-m2f-p227_002-p225_001_010_025_052_090.attn.png -------------------------------------------------------------------------------- /docs/imgs/p227_p225/s2s-m2f-p227_002-p225_002.attn.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p227_p225/s2s-m2f-p227_002-p225_002.attn.png -------------------------------------------------------------------------------- /docs/imgs/p227_p225/s2s-m2f-p227_020-p225_001_010_027_030_038.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p227_p225/s2s-m2f-p227_020-p225_001_010_027_030_038.attn.png -------------------------------------------------------------------------------- /docs/imgs/p227_p225/s2s-m2f-p227_020-p225_020.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p227_p225/s2s-m2f-p227_020-p225_020.attn.png -------------------------------------------------------------------------------- /docs/imgs/p228_p232/s2s-f2m-p228_002-p232_001_010_025_052_090.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p228_p232/s2s-f2m-p228_002-p232_001_010_025_052_090.attn.png -------------------------------------------------------------------------------- /docs/imgs/p228_p232/s2s-f2m-p228_002-p232_002.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p228_p232/s2s-f2m-p228_002-p232_002.attn.png -------------------------------------------------------------------------------- /docs/imgs/p228_p232/s2s-f2m-p228_020-p232_001_010_027_030_038.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p228_p232/s2s-f2m-p228_020-p232_001_010_027_030_038.attn.png -------------------------------------------------------------------------------- /docs/imgs/p228_p232/s2s-f2m-p228_020-p232_020.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p228_p232/s2s-f2m-p228_020-p232_020.attn.png -------------------------------------------------------------------------------- /docs/imgs/p232_p228/s2s-m2f-p232_002-p228_001_010_025_052_090.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p232_p228/s2s-m2f-p232_002-p228_001_010_025_052_090.attn.png -------------------------------------------------------------------------------- /docs/imgs/p232_p228/s2s-m2f-p232_002-p228_002.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p232_p228/s2s-m2f-p232_002-p228_002.attn.png -------------------------------------------------------------------------------- /docs/imgs/p232_p228/s2s-m2f-p232_020-p228_001_010_027_030_038.attn.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p232_p228/s2s-m2f-p232_020-p228_001_010_027_030_038.attn.png -------------------------------------------------------------------------------- /docs/imgs/p232_p228/s2s-m2f-p232_020-p228_020.attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/p232_p228/s2s-m2f-p232_020-p228_020.attn.png -------------------------------------------------------------------------------- /docs/imgs/smoother_extractor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/imgs/smoother_extractor.png -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | FragmentVC Demo 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 |
31 |
32 |

Audio & Attention Map Demo

33 |

FRAGMENTVC: ANY-TO-ANY VOICE CONVERSION BY END-TO-END EXTRACTING AND FUSING FINE-GRAINED VOICE FRAGMENTS WITH ATTENTION

34 |
35 |

36 | Abstract: 37 | Any-to-any voice conversion aims to convert the voice from and to any speakers even unseen during training, which is much more challenging compared to one-to-one or many-to-many tasks, but much more attractive in real-world scenarios. 38 | In this paper we proposed FragmentVC, in which the latent phonetic structure of the utterance from the source speaker is obtained from Wav2Vec 2.0, while the spectral features of the utterance(s) from the target speaker are obtained from log mel-spectrograms. 39 | By aligning the hidden structures of the two different feature spaces with a two-stage training process, FragmentVC is able to extract fine-grained voice fragments from the target speaker utterance(s) and fuse them into the desired utterance, all based on the attention mechanism of Transformer as verified with analysis on attention maps, and is accomplished end-to-end. 40 | This approach is trained with reconstruction loss only without any disentanglement considerations between content and speaker information and doesn't require parallel data. 41 | Objective evaluation based on speaker verification and subjective evaluation with MOS both showed that this approach outperformed SOTA approaches, such as AdaIN-VC and AutoVC. 42 |

43 |

44 | GitHub (Source Code) 45 | arXiv (Preprint) 46 |

47 |
48 |
49 | 50 |
51 |

Seen-to-seen conversion

52 | 53 | In the following sections, there are 4 conversion pairs, each containing 4 speech utterances. 54 | The first 2 utterances are drawn from the CSTR VCTK Corpus: 55 |
    56 |
  1. An utterance from the source speaker, termed the source utterance
  2. An utterance from the target speaker, which has the same word transcription as the source utterance, termed the authentic utterance

 And the remaining 2 utterances are the conversion results using the source utterance as source:

  1. A synthetic utterance generated with the authentic utterance as target
  2. A synthetic utterance generated with 5 randomly sampled utterances from the target speaker as target
73 | 74 |
75 |
Pair 1
76 |
77 |
78 |
79 | Source speaker 80 |
81 |
82 | p225 83 |
84 |
85 | Target speaker 86 |
87 |
88 | p227 89 |
90 |
91 | Transcription 92 |
93 |
94 | “ 95 | Ask her to bring these things with her from the store. 96 | ” 97 |
98 |
99 | Source utterance 100 |
101 |
102 | 106 |
107 |
108 | Authentic utterance 109 |
110 | 111 | from the target speaker 112 | 113 |
114 |
115 | 119 |
120 |
121 | Conversion result 122 |
123 | 124 | with the authentic utterance as target 125 | 126 |
127 |
128 | 132 |
133 | 134 |
135 |
136 | Conversion result 137 |
138 | 139 | with 5 randomly sampled target utterances 140 | 141 |
142 |
143 | 147 |
148 | 149 |
150 |
151 |
152 |
153 | 154 |
155 |
Pair 2
156 |
157 |
158 |
159 | Source speaker 160 |
161 |
162 | p227 163 |
164 |
165 | Target speaker 166 |
167 |
168 | p225 169 |
170 |
171 | Transcription 172 |
173 |
174 | “ 175 | Many complicated ideas about the rainbow have been formed. 176 | ” 177 |
178 |
179 | Source utterance 180 |
181 |
182 | 186 |
187 |
188 | Authentic utterance 189 |
190 | 191 | from the target speaker 192 | 193 |
194 |
195 | 199 |
200 |
201 | Conversion result 202 |
203 | 204 | with the authentic utterance as target 205 | 206 |
207 |
208 | 212 |
213 | 214 |
215 |
216 | Conversion result 217 |
218 | 219 | with 5 randomly sampled target utterances 220 | 221 |
222 |
223 | 227 |
228 | 229 |
230 |
231 |
232 |
233 | 234 |
235 |
Pair 3
236 |
237 |
238 |
239 | Source speaker 240 |
241 |
242 | p228 243 |
244 |
245 | Target speaker 246 |
247 |
248 | p232 249 |
250 |
251 | Transcription 252 |
253 |
254 | “ 255 | Many complicated ideas about the rainbow have been formed. 256 | ” 257 |
258 |
259 | Source utterance 260 |
261 |
262 | 266 |
267 |
268 | Authentic utterance 269 |
270 | 271 | from the target speaker 272 | 273 |
274 |
275 | 279 |
280 |
281 | Conversion result 282 |
283 | 284 | with the authentic utterance as target 285 | 286 |
287 |
288 | 292 |
293 | 294 |
295 |
296 | Conversion result 297 |
298 | 299 | with 5 randomly sampled target utterances 300 | 301 |
302 |
303 | 307 |
308 | 309 |
310 |
311 |
312 |
313 | 314 |
315 |
Pair 4
316 |
317 |
318 |
319 | Source speaker 320 |
321 |
322 | p232 323 |
324 |
325 | Target speaker 326 |
327 |
328 | p228 329 |
330 |
331 | Transcription 332 |
333 |
334 | “ 335 | Ask her to bring these things with her from the store. 336 | ” 337 |
338 |
339 | Source utterance 340 |
341 |
342 | 346 |
347 |
348 | Authentic utterance 349 |
350 | 351 | from the target speaker 352 | 353 |
354 |
355 | 359 |
360 |
361 | Conversion result 362 |
363 | 364 | with the authentic utterance as target 365 | 366 |
367 |
368 | 372 |
373 | 374 |
375 |
376 | Conversion result 377 |
378 | 379 | with 5 randomly sampled target utterances 380 | 381 |
382 |
383 | 387 |
388 | 389 |
390 |
391 |
392 |
393 | 394 |

Unseen-to-unseen conversion

395 | 396 | In the following sections, there are 4 conversion pairs, each containing 4 speech utterances. 397 | The first 2 utterances are drawn from the CMU Arctic dataset: 398 |
    399 |
  1. An utterance from the source speaker, termed the source utterance
  2. An utterance from the target speaker, which has the same word transcription as the source utterance, termed the authentic utterance
406 | 407 |

408 | And the last one is the conversion result generated with the source utterance as source and 10 randomly sampled utterances from the target speaker as target. 409 |

410 | 411 |
412 |
Pair 1
413 |
414 |
415 |
416 | Source speaker 417 |
418 |
419 | slt 420 |
421 |
422 | Target speaker 423 |
424 |
425 | lnh 426 |
427 |
428 | Transcription 429 |
430 |
431 | “ 432 | The Warden with a quart of champagne. 433 | ” 434 |
435 |
436 | Source utterance 437 |
438 |
439 | 443 |
444 |
445 | Authentic utterance 446 |
447 | 448 | from the target speaker 449 | 450 |
451 |
452 | 456 |
457 |
458 | 459 | Conversion results 460 | 461 |
462 |
463 |
464 |
465 |
466 | FragmentVC 467 |
468 | 469 | with 10 randomly sampled target utterances 470 | 471 |
472 |
473 | 477 |
478 |
479 | AdaIN 480 |
481 |
482 | 486 |
487 |
488 | AutoVC 489 |
490 |
491 | 495 |
496 |
497 |
498 |
499 | 500 |
501 |
Pair 2
502 |
503 |
504 |
505 | Source speaker 506 |
507 |
508 | clb 509 |
510 |
511 | Target speaker 512 |
513 |
514 | rms 515 |
516 |
517 | Transcription 518 |
519 |
520 | “ 521 | The scents of strange vegetation blew off the tropic land. 522 | ” 523 |
524 |
525 | Source utterance 526 |
527 |
528 | 532 |
533 |
534 | Authentic utterance 535 |
536 | 537 | from the target speaker 538 | 539 |
540 |
541 | 545 |
546 |
547 | 548 | Conversion results 549 | 550 |
551 |
552 |
553 |
554 |
555 | FragmentVC 556 |
557 | 558 | with 10 randomly sampled target utterances 559 | 560 |
561 |
562 | 566 |
567 |
568 | AdaIN 569 |
570 |
571 | 575 |
576 |
577 | AutoVC 578 |
579 |
580 | 584 |
585 |
586 |
587 |
588 | 589 | 590 |
591 |
Pair 3
592 |
593 |
594 |
595 | Source speaker 596 |
597 |
598 | bdl 599 |
600 |
601 | Target speaker 602 |
603 |
604 | ljm 605 |
606 |
607 | Transcription 608 |
609 |
610 | “ 611 | The woman in you is only incidental, accidental, and irrelevant. 612 | ” 613 |
614 |
615 | Source utterance 616 |
617 |
618 | 622 |
623 |
624 | Authentic utterance 625 |
626 | 627 | from the target speaker 628 | 629 |
630 |
631 | 635 |
636 |
637 | 638 | Conversion results 639 | 640 |
641 |
642 |
643 |
644 |
645 | FragmentVC 646 |
647 | 648 | with 10 randomly sampled target utterances 649 | 650 |
651 |
652 | 656 |
657 |
658 | AdaIN 659 |
660 |
661 | 665 |
666 |
667 | AutoVC 668 |
669 |
670 | 674 |
675 |
676 |
677 |
678 | 679 |
680 |
Pair 4
681 |
682 |
683 |
684 | Source speaker 685 |
686 |
687 | rms 688 |
689 |
690 | Target speaker 691 |
692 |
693 | bdl 694 |
695 |
696 | Transcription 697 |
698 |
699 | “ 700 | Bassett was a fastidious man. 701 | ” 702 |
703 |
704 | Source utterance 705 |
706 |
707 | 711 |
712 |
713 | Authentic utterance 714 |
715 | 716 | from the target speaker 717 | 718 |
719 |
720 | 724 |
725 |
726 | 727 | Conversion results 728 | 729 |
730 |
731 |
732 |
733 |
734 | FragmentVC 735 |
736 | 737 | with 10 randomly sampled target utterances 738 | 739 |
740 |
741 | 745 |
746 |
747 | AdaIN 748 |
749 |
750 | 754 |
755 |
756 | AutoVC 757 |
758 |
759 | 763 |
764 |
765 |
766 |
767 | 768 |
769 | 770 |
771 |

© 台大語音實驗室 NTU Speech Lab

772 |
773 | 774 | 775 | -------------------------------------------------------------------------------- /docs/style.css: -------------------------------------------------------------------------------- 1 | p { 2 | text-align: justify; 3 | text-justify: inter-word; 4 | } 5 | audio { 6 | border-radius: 0%; 7 | } 8 | img { 9 | max-height:500px; 10 | max-width: 100%; 11 | height:auto; 12 | width:auto; 13 | } -------------------------------------------------------------------------------- /docs/wavs/AdaIN_u2u-f2f-slt_a0541-lnh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/AdaIN_u2u-f2f-slt_a0541-lnh.wav -------------------------------------------------------------------------------- /docs/wavs/AdaIN_u2u-f2m-clb_b0503-rms.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/AdaIN_u2u-f2m-clb_b0503-rms.wav -------------------------------------------------------------------------------- /docs/wavs/AdaIN_u2u-m2f-bdl_a0283-ljm.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/AdaIN_u2u-m2f-bdl_a0283-ljm.wav -------------------------------------------------------------------------------- /docs/wavs/AdaIN_u2u-m2m-rms_a0296-bdl.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/AdaIN_u2u-m2m-rms_a0296-bdl.wav -------------------------------------------------------------------------------- /docs/wavs/AutoVC_u2u-f2f-slt_a0541-lnh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/AutoVC_u2u-f2f-slt_a0541-lnh.wav -------------------------------------------------------------------------------- /docs/wavs/AutoVC_u2u-f2m-clb_b0503-rms.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/AutoVC_u2u-f2m-clb_b0503-rms.wav -------------------------------------------------------------------------------- /docs/wavs/AutoVC_u2u-m2f-bdl_a0283-ljm.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/AutoVC_u2u-m2f-bdl_a0283-ljm.wav -------------------------------------------------------------------------------- /docs/wavs/AutoVC_u2u-m2m-rms_a0296-bdl.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/AutoVC_u2u-m2m-rms_a0296-bdl.wav -------------------------------------------------------------------------------- /docs/wavs/bdl_a0283.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/bdl_a0283.wav -------------------------------------------------------------------------------- 
/docs/wavs/bdl_a0296.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/bdl_a0296.wav -------------------------------------------------------------------------------- /docs/wavs/clb_b0503.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/clb_b0503.wav -------------------------------------------------------------------------------- /docs/wavs/ljm_a0283.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/ljm_a0283.wav -------------------------------------------------------------------------------- /docs/wavs/lnh_a0541.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/lnh_a0541.wav -------------------------------------------------------------------------------- /docs/wavs/p225_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/p225_002.wav -------------------------------------------------------------------------------- /docs/wavs/p225_020.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/p225_020.wav -------------------------------------------------------------------------------- /docs/wavs/p227_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/p227_002.wav -------------------------------------------------------------------------------- /docs/wavs/p227_020.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/p227_020.wav -------------------------------------------------------------------------------- /docs/wavs/p228_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/p228_002.wav -------------------------------------------------------------------------------- /docs/wavs/p228_020.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/p228_020.wav -------------------------------------------------------------------------------- /docs/wavs/p232_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/p232_002.wav -------------------------------------------------------------------------------- /docs/wavs/p232_020.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/p232_020.wav -------------------------------------------------------------------------------- /docs/wavs/rms_a0296.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/rms_a0296.wav -------------------------------------------------------------------------------- /docs/wavs/rms_b0503.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/rms_b0503.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-f2m-p225_002-p227_001_010_025_052_090.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-f2m-p225_002-p227_001_010_025_052_090.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-f2m-p225_002-p227_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-f2m-p225_002-p227_002.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-f2m-p225_020-p227_001_010_027_030_038.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-f2m-p225_020-p227_001_010_027_030_038.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-f2m-p225_020-p227_020.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-f2m-p225_020-p227_020.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-f2m-p228_002-p232_001_010_025_052_090.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-f2m-p228_002-p232_001_010_025_052_090.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-f2m-p228_002-p232_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-f2m-p228_002-p232_002.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-f2m-p228_020-p232_001_010_027_030_038.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-f2m-p228_020-p232_001_010_027_030_038.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-f2m-p228_020-p232_020.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-f2m-p228_020-p232_020.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-m2f-p227_002-p225_001_010_025_052_090.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-m2f-p227_002-p225_001_010_025_052_090.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-m2f-p227_002-p225_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-m2f-p227_002-p225_002.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-m2f-p227_020-p225_001_010_027_030_038.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-m2f-p227_020-p225_001_010_027_030_038.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-m2f-p227_020-p225_020.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-m2f-p227_020-p225_020.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-m2f-p232_002-p228_001_010_025_052_090.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-m2f-p232_002-p228_001_010_025_052_090.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-m2f-p232_002-p228_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-m2f-p232_002-p228_002.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-m2f-p232_020-p228_001_010_027_030_038.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-m2f-p232_020-p228_001_010_027_030_038.wav -------------------------------------------------------------------------------- /docs/wavs/s2s-m2f-p232_020-p228_020.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/s2s-m2f-p232_020-p228_020.wav -------------------------------------------------------------------------------- /docs/wavs/slt_a0541.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/slt_a0541.wav -------------------------------------------------------------------------------- /docs/wavs/u2u-f2f-slt_a0541-lnh.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/u2u-f2f-slt_a0541-lnh.wav -------------------------------------------------------------------------------- /docs/wavs/u2u-f2m-clb_b0503-rms.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/u2u-f2m-clb_b0503-rms.wav -------------------------------------------------------------------------------- /docs/wavs/u2u-m2f-bdl_a0283-ljm.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/u2u-m2f-bdl_a0283-ljm.wav -------------------------------------------------------------------------------- /docs/wavs/u2u-m2m-rms_a0296-bdl.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yistLin/FragmentVC/79cc2ce3a585a5c31904a8dd36ad49120a601d8c/docs/wavs/u2u-m2m-rms_a0296-bdl.wav -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import FragmentVC 2 | from .utils import * 3 | -------------------------------------------------------------------------------- /models/convolutional_transformer.py: -------------------------------------------------------------------------------- 1 | """Convolutional transformer""" 2 | 3 | from typing import Optional, Tuple 4 | 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | from torch.nn import Module, Dropout, LayerNorm, Conv1d, MultiheadAttention 8 | 9 | 10 | class Smoother(Module): 11 | """Convolutional Transformer Encoder Layer""" 12 | 13 | def __init__(self, d_model: int, nhead: int, d_hid: int, dropout=0.1): 14 | super(Smoother, self).__init__() 15 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 16 | 17 | self.conv1 = Conv1d(d_model, d_hid, 9, padding=4) 18 | self.conv2 = Conv1d(d_hid, d_model, 1, padding=0) 19 | 20 | self.norm1 = LayerNorm(d_model) 21 | self.norm2 = LayerNorm(d_model) 22 | self.dropout1 = Dropout(dropout) 23 | self.dropout2 = Dropout(dropout) 24 | 25 | def forward( 26 | self, 27 | src: Tensor, 28 | src_mask: Optional[Tensor] = None, 29 | src_key_padding_mask: Optional[Tensor] = None, 30 | ) -> Tensor: 31 | # multi-head self attention 32 | src2 = self.self_attn( 33 | src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask 34 | )[0] 35 | 36 | # add & norm 37 | src = src + self.dropout1(src2) 38 | src = self.norm1(src) 39 | 40 | # conv1d 41 | src2 = src.transpose(0, 1).transpose(1, 2) 42 | src2 = self.conv2(F.relu(self.conv1(src2))) 43 | src2 = src2.transpose(1, 2).transpose(0, 1) 44 | 45 | # add & norm 46 | src = src + self.dropout2(src2) 47 | src = self.norm2(src) 48 | return src 49 | 50 | 51 | class Extractor(Module): 52 | """Convolutional Transformer Decoder Layer""" 53 | 54 | def __init__( 55 | self, d_model: int, nhead: int, d_hid: int, dropout=0.1, no_residual=False, 56 | ): 57 | super(Extractor, self).__init__() 58 | 59 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 60 | self.cross_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 61 | 62 | self.conv1 = Conv1d(d_model, d_hid, 9, padding=4) 63 | self.conv2 = Conv1d(d_hid, d_model, 1, padding=0) 64 | 65 | self.norm1 = 
LayerNorm(d_model) 66 | self.norm2 = LayerNorm(d_model) 67 | self.norm3 = LayerNorm(d_model) 68 | self.dropout1 = Dropout(dropout) 69 | self.dropout2 = Dropout(dropout) 70 | self.dropout3 = Dropout(dropout) 71 | 72 | self.no_residual = no_residual 73 | 74 | def forward( 75 | self, 76 | tgt: Tensor, 77 | memory: Tensor, 78 | tgt_mask: Optional[Tensor] = None, 79 | memory_mask: Optional[Tensor] = None, 80 | tgt_key_padding_mask: Optional[Tensor] = None, 81 | memory_key_padding_mask: Optional[Tensor] = None, 82 | ) -> Tuple[Tensor, Optional[Tensor]]: 83 | # multi-head self attention 84 | tgt2 = self.self_attn( 85 | tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 86 | )[0] 87 | 88 | # add & norm 89 | tgt = tgt + self.dropout1(tgt2) 90 | tgt = self.norm1(tgt) 91 | 92 | # multi-head cross attention 93 | tgt2, attn = self.cross_attn( 94 | tgt, 95 | memory, 96 | memory, 97 | attn_mask=memory_mask, 98 | key_padding_mask=memory_key_padding_mask, 99 | ) 100 | 101 | # add & norm 102 | if self.no_residual: 103 | tgt = self.dropout2(tgt2) 104 | else: 105 | tgt = tgt + self.dropout2(tgt2) 106 | tgt = self.norm2(tgt) 107 | 108 | # conv1d 109 | tgt2 = tgt.transpose(0, 1).transpose(1, 2) 110 | tgt2 = self.conv2(F.relu(self.conv1(tgt2))) 111 | tgt2 = tgt2.transpose(1, 2).transpose(0, 1) 112 | 113 | # add & norm 114 | tgt = tgt + self.dropout3(tgt2) 115 | tgt = self.norm3(tgt) 116 | 117 | return tgt, attn 118 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | """FragmentVC model architecture.""" 2 | 3 | from typing import Tuple, List, Optional 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | from .convolutional_transformer import Smoother, Extractor 10 | 11 | 12 | class FragmentVC(nn.Module): 13 | """ 14 | FragmentVC uses Wav2Vec feature of the source speaker to query and attend 15 | on mel spectrogram of the target speaker. 16 | """ 17 | 18 | def __init__(self, d_model=512): 19 | super().__init__() 20 | 21 | self.unet = UnetBlock(d_model) 22 | 23 | self.smoothers = nn.TransformerEncoder(Smoother(d_model, 2, 1024), num_layers=3) 24 | 25 | self.mel_linear = nn.Linear(d_model, 80) 26 | 27 | self.post_net = nn.Sequential( 28 | nn.Conv1d(80, 512, kernel_size=5, padding=2), 29 | nn.BatchNorm1d(512), 30 | nn.Tanh(), 31 | nn.Dropout(0.5), 32 | nn.Conv1d(512, 512, kernel_size=5, padding=2), 33 | nn.BatchNorm1d(512), 34 | nn.Tanh(), 35 | nn.Dropout(0.5), 36 | nn.Conv1d(512, 512, kernel_size=5, padding=2), 37 | nn.BatchNorm1d(512), 38 | nn.Tanh(), 39 | nn.Dropout(0.5), 40 | nn.Conv1d(512, 512, kernel_size=5, padding=2), 41 | nn.BatchNorm1d(512), 42 | nn.Tanh(), 43 | nn.Dropout(0.5), 44 | nn.Conv1d(512, 80, kernel_size=5, padding=2), 45 | nn.BatchNorm1d(80), 46 | nn.Dropout(0.5), 47 | ) 48 | 49 | def forward( 50 | self, 51 | srcs: Tensor, 52 | refs: Tensor, 53 | src_masks: Optional[Tensor] = None, 54 | ref_masks: Optional[Tensor] = None, 55 | ) -> Tuple[Tensor, List[Optional[Tensor]]]: 56 | """Forward function. 
57 | 58 | Args: 59 | srcs: (batch, src_len, 768) 60 | src_masks: (batch, src_len) 61 | refs: (batch, 80, ref_len) 62 | ref_masks: (batch, ref_len) 63 | """ 64 | 65 | # out: (src_len, batch, d_model) 66 | out, attns = self.unet(srcs, refs, src_masks=src_masks, ref_masks=ref_masks) 67 | 68 | # out: (src_len, batch, d_model) 69 | out = self.smoothers(out, src_key_padding_mask=src_masks) 70 | 71 | # out: (src_len, batch, 80) 72 | out = self.mel_linear(out) 73 | 74 | # out: (batch, 80, src_len) 75 | out = out.transpose(1, 0).transpose(2, 1) 76 | refined = self.post_net(out) 77 | out = out + refined 78 | 79 | # out: (batch, 80, src_len) 80 | return out, attns 81 | 82 | 83 | class UnetBlock(nn.Module): 84 | """Hierarchically attend on references.""" 85 | 86 | def __init__(self, d_model: int): 87 | super(UnetBlock, self).__init__() 88 | 89 | self.conv1 = nn.Conv1d(80, d_model, 3, padding=1, padding_mode="replicate") 90 | self.conv2 = nn.Conv1d(d_model, d_model, 3, padding=1, padding_mode="replicate") 91 | self.conv3 = nn.Conv1d(d_model, d_model, 3, padding=1, padding_mode="replicate") 92 | 93 | self.prenet = nn.Sequential( 94 | nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, d_model), 95 | ) 96 | 97 | self.extractor1 = Extractor(d_model, 2, 1024, no_residual=True) 98 | self.extractor2 = Extractor(d_model, 2, 1024) 99 | self.extractor3 = Extractor(d_model, 2, 1024) 100 | 101 | def forward( 102 | self, 103 | srcs: Tensor, 104 | refs: Tensor, 105 | src_masks: Optional[Tensor] = None, 106 | ref_masks: Optional[Tensor] = None, 107 | ) -> Tuple[Tensor, List[Optional[Tensor]]]: 108 | """Forward function. 109 | 110 | Args: 111 | srcs: (batch, src_len, 768) 112 | src_masks: (batch, src_len) 113 | refs: (batch, 80, ref_len) 114 | ref_masks: (batch, ref_len) 115 | """ 116 | 117 | # tgt: (batch, tgt_len, d_model) 118 | tgt = self.prenet(srcs) 119 | # tgt: (tgt_len, batch, d_model) 120 | tgt = tgt.transpose(0, 1) 121 | 122 | # ref*: (batch, d_model, mel_len) 123 | ref1 = self.conv1(refs) 124 | ref2 = self.conv2(F.relu(ref1)) 125 | ref3 = self.conv3(F.relu(ref2)) 126 | 127 | # out*: (tgt_len, batch, d_model) 128 | out, attn1 = self.extractor1( 129 | tgt, 130 | ref3.transpose(1, 2).transpose(0, 1), 131 | tgt_key_padding_mask=src_masks, 132 | memory_key_padding_mask=ref_masks, 133 | ) 134 | out, attn2 = self.extractor2( 135 | out, 136 | ref2.transpose(1, 2).transpose(0, 1), 137 | tgt_key_padding_mask=src_masks, 138 | memory_key_padding_mask=ref_masks, 139 | ) 140 | out, attn3 = self.extractor3( 141 | out, 142 | ref1.transpose(1, 2).transpose(0, 1), 143 | tgt_key_padding_mask=src_masks, 144 | memory_key_padding_mask=ref_masks, 145 | ) 146 | 147 | # out: (tgt_len, batch, d_model) 148 | return out, [attn1, attn2, attn3] 149 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | """Useful utilities.""" 2 | 3 | import math 4 | 5 | import torch 6 | from torch.optim import Optimizer 7 | from torch.optim.lr_scheduler import LambdaLR 8 | 9 | from fairseq.models.wav2vec import Wav2Vec2Model 10 | 11 | 12 | def load_pretrained_wav2vec(ckpt_path): 13 | """Load pretrained Wav2Vec model.""" 14 | ckpt = torch.load(ckpt_path) 15 | model = Wav2Vec2Model.build_model(ckpt["args"], task=None) 16 | model.load_state_dict(ckpt["model"]) 17 | model.remove_pretraining_modules() 18 | model.eval() 19 | return model 20 | 21 | 22 | def get_cosine_schedule_with_warmup( 23 | optimizer: Optimizer, 24 | 
num_warmup_steps: int, 25 | num_training_steps: int, 26 | num_cycles: float = 0.5, 27 | last_epoch: int = -1, 28 | ): 29 | """ 30 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 31 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 32 | initial lr set in the optimizer. 33 | 34 | Args: 35 | optimizer (:class:`~torch.optim.Optimizer`): 36 | The optimizer for which to schedule the learning rate. 37 | num_warmup_steps (:obj:`int`): 38 | The number of steps for the warmup phase. 39 | num_training_steps (:obj:`int`): 40 | The total number of training steps. 41 | num_cycles (:obj:`float`, `optional`, defaults to 0.5): 42 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 43 | following a half-cosine). 44 | last_epoch (:obj:`int`, `optional`, defaults to -1): 45 | The index of the last epoch when resuming training. 46 | 47 | Return: 48 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 49 | """ 50 | 51 | def lr_lambda(current_step): 52 | if current_step < num_warmup_steps: 53 | return float(current_step) / float(max(1, num_warmup_steps)) 54 | progress = float(current_step - num_warmup_steps) / float( 55 | max(1, num_training_steps - num_warmup_steps) 56 | ) 57 | return max( 58 | 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) 59 | ) 60 | 61 | return LambdaLR(optimizer, lr_lambda, last_epoch) 62 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Precompute Wav2Vec features.""" 3 | 4 | import os 5 | import json 6 | from pathlib import Path 7 | from tempfile import mkstemp 8 | from multiprocessing import cpu_count 9 | 10 | import tqdm 11 | import torch 12 | from torch.utils.data import DataLoader 13 | from jsonargparse import ArgumentParser, ActionConfigFile 14 | 15 | from models import load_pretrained_wav2vec 16 | from data import PreprocessDataset 17 | 18 | 19 | def parse_args(): 20 | """Parse command-line arguments.""" 21 | parser = ArgumentParser() 22 | parser.add_argument("data_dirs", type=str, nargs="+") 23 | parser.add_argument("wav2vec_path", type=str) 24 | parser.add_argument("out_dir", type=str) 25 | parser.add_argument("--trim_method", choices=["librosa", "vad"], default="vad") 26 | parser.add_argument("--n_workers", type=int, default=cpu_count()) 27 | 28 | parser.add_argument("--sample_rate", type=int, default=16000) 29 | parser.add_argument("--preemph", type=float, default=0.97) 30 | parser.add_argument("--hop_len", type=int, default=326) 31 | parser.add_argument("--win_len", type=int, default=1304) 32 | parser.add_argument("--n_fft", type=int, default=1304) 33 | parser.add_argument("--n_mels", type=int, default=80) 34 | parser.add_argument("--f_min", type=int, default=80) 35 | parser.add_argument("--audio_config", action=ActionConfigFile) 36 | 37 | return vars(parser.parse_args()) 38 | 39 | 40 | def main( 41 | data_dirs, 42 | wav2vec_path, 43 | out_dir, 44 | trim_method, 45 | n_workers, 46 | sample_rate, 47 | preemph, 48 | hop_len, 49 | win_len, 50 | n_fft, 51 | n_mels, 52 | f_min, 53 | **kwargs, 54 | ): 55 | """Main function.""" 56 | 57 | out_dir_path = Path(out_dir) 58 | 59 | if out_dir_path.exists(): 60 | assert out_dir_path.is_dir() 61 | else: 62 | out_dir_path.mkdir(parents=True) 63 | 64 | device 
= torch.device("cuda" if torch.cuda.is_available() else "cpu") 65 | 66 | dataset = PreprocessDataset( 67 | data_dirs, 68 | trim_method, 69 | sample_rate, 70 | preemph, 71 | hop_len, 72 | win_len, 73 | n_fft, 74 | n_mels, 75 | f_min, 76 | ) 77 | dataloader = DataLoader( 78 | dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=n_workers 79 | ) 80 | 81 | wav2vec = load_pretrained_wav2vec(wav2vec_path).to(device) 82 | 83 | speaker_infos = {} 84 | 85 | pbar = tqdm.tqdm(total=len(dataset), ncols=0) 86 | 87 | for speaker_name, audio_path, wav, mel in dataloader: 88 | if wav.size(-1) < 10: 89 | continue 90 | 91 | wav = wav.to(device) 92 | speaker_name = speaker_name[0] 93 | audio_path = audio_path[0] 94 | 95 | with torch.no_grad(): 96 | feat = wav2vec.extract_features(wav, None)[0] 97 | feat = feat.detach().cpu().squeeze(0) 98 | mel = mel.squeeze(0) 99 | 100 | fd, temp_file = mkstemp(suffix=".tar", prefix="utterance-", dir=out_dir_path) 101 | torch.save({"feat": feat, "mel": mel}, temp_file) 102 | os.close(fd) 103 | 104 | if speaker_name not in speaker_infos.keys(): 105 | speaker_infos[speaker_name] = [] 106 | 107 | speaker_infos[speaker_name].append( 108 | { 109 | "feature_path": Path(temp_file).name, 110 | "audio_path": audio_path, 111 | "feat_len": len(feat), 112 | "mel_len": len(mel), 113 | } 114 | ) 115 | 116 | pbar.update(dataloader.batch_size) 117 | 118 | with open(out_dir_path / "metadata.json", "w") as f: 119 | json.dump(speaker_infos, f, indent=2) 120 | 121 | 122 | if __name__ == "__main__": 123 | main(**parse_args()) 124 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e git://github.com/pytorch/fairseq.git@1a709b2a401ac8bd6d805c8a6a5f4d7f03b923ff#egg=fairseq 2 | editdistance==0.5.3 3 | librosa==0.8.0 4 | jsonargparse==2.32.2 5 | matplotlib==3.3.2 6 | PyYAML==5.3.1 7 | sox==1.4.1 8 | tensorboard==2.3.0 9 | torch==1.6.0 10 | tqdm==4.51.0 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Train FragmentVC model.""" 3 | 4 | import argparse 5 | import datetime 6 | import random 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.optim import AdamW 12 | from torch.utils.data import DataLoader, random_split 13 | from torch.utils.tensorboard import SummaryWriter 14 | from tqdm import tqdm 15 | 16 | from data import IntraSpeakerDataset, collate_batch 17 | from models import FragmentVC, get_cosine_schedule_with_warmup 18 | 19 | 20 | def parse_args(): 21 | """Parse command-line arguments.""" 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("data_dir", type=str) 24 | parser.add_argument("--save_dir", type=str, default=".") 25 | parser.add_argument("--total_steps", type=int, default=250000) 26 | parser.add_argument("--warmup_steps", type=int, default=500) 27 | parser.add_argument("--valid_steps", type=int, default=1000) 28 | parser.add_argument("--log_steps", type=int, default=100) 29 | parser.add_argument("--save_steps", type=int, default=10000) 30 | parser.add_argument("--milestones", type=int, nargs=2, default=[50000, 150000]) 31 | parser.add_argument("--exclusive_rate", type=float, default=1.0) 32 | parser.add_argument("--n_samples", type=int, default=10) 33 | parser.add_argument("--accu_steps", type=int, default=2) 34 | 
parser.add_argument("--batch_size", type=int, default=8) 35 | parser.add_argument("--n_workers", type=int, default=8) 36 | parser.add_argument("--preload", action="store_true") 37 | parser.add_argument("--comment", type=str) 38 | return vars(parser.parse_args()) 39 | 40 | 41 | def model_fn(batch, model, criterion, self_exclude, ref_included, device): 42 | """Forward a batch through model.""" 43 | 44 | srcs, src_masks, refs, ref_masks, tgts, tgt_masks, overlap_lens = batch 45 | 46 | srcs = srcs.to(device) 47 | src_masks = src_masks.to(device) 48 | refs = refs.to(device) 49 | ref_masks = ref_masks.to(device) 50 | tgts = tgts.to(device) 51 | tgt_masks = tgt_masks.to(device) 52 | 53 | if ref_included: 54 | if random.random() >= self_exclude: 55 | refs = torch.cat((refs, tgts), dim=2) 56 | ref_masks = torch.cat((ref_masks, tgt_masks), dim=1) 57 | else: 58 | refs = tgts 59 | ref_masks = tgt_masks 60 | 61 | outs, _ = model(srcs, refs, src_masks=src_masks, ref_masks=ref_masks) 62 | 63 | losses = [] 64 | for out, tgt, overlap_len in zip(outs.unbind(), tgts.unbind(), overlap_lens): 65 | loss = criterion(out[:, :overlap_len], tgt[:, :overlap_len]) 66 | losses.append(loss) 67 | 68 | return sum(losses) / len(losses) 69 | 70 | 71 | def valid(dataloader, model, criterion, device): 72 | """Validate on validation set.""" 73 | 74 | model.eval() 75 | running_loss = 0.0 76 | pbar = tqdm(total=len(dataloader.dataset), ncols=0, desc="Valid", unit=" uttr") 77 | 78 | for i, batch in enumerate(dataloader): 79 | with torch.no_grad(): 80 | loss = model_fn(batch, model, criterion, 1.0, True, device) 81 | running_loss += loss.item() 82 | 83 | pbar.update(dataloader.batch_size) 84 | pbar.set_postfix(loss=f"{running_loss / (i+1):.2f}") 85 | 86 | pbar.close() 87 | model.train() 88 | 89 | return running_loss / len(dataloader) 90 | 91 | 92 | def main( 93 | data_dir, 94 | save_dir, 95 | total_steps, 96 | warmup_steps, 97 | valid_steps, 98 | log_steps, 99 | save_steps, 100 | milestones, 101 | exclusive_rate, 102 | n_samples, 103 | accu_steps, 104 | batch_size, 105 | n_workers, 106 | preload, 107 | comment, 108 | ): 109 | """Main function.""" 110 | 111 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 112 | 113 | metadata_path = Path(data_dir) / "metadata.json" 114 | 115 | dataset = IntraSpeakerDataset(data_dir, metadata_path, n_samples, preload) 116 | lengths = [trainlen := int(0.9 * len(dataset)), len(dataset) - trainlen] 117 | trainset, validset = random_split(dataset, lengths) 118 | train_loader = DataLoader( 119 | trainset, 120 | batch_size=batch_size, 121 | shuffle=True, 122 | drop_last=True, 123 | num_workers=n_workers, 124 | pin_memory=True, 125 | collate_fn=collate_batch, 126 | ) 127 | valid_loader = DataLoader( 128 | validset, 129 | batch_size=batch_size * accu_steps, 130 | num_workers=n_workers, 131 | drop_last=True, 132 | pin_memory=True, 133 | collate_fn=collate_batch, 134 | ) 135 | train_iterator = iter(train_loader) 136 | 137 | if comment is not None: 138 | log_dir = "logs/" 139 | log_dir += datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 140 | log_dir += "_" + comment 141 | writer = SummaryWriter(log_dir) 142 | 143 | save_dir_path = Path(save_dir) 144 | save_dir_path.mkdir(parents=True, exist_ok=True) 145 | 146 | model = FragmentVC().to(device) 147 | model = torch.jit.script(model) 148 | 149 | criterion = nn.L1Loss() 150 | optimizer = AdamW(model.parameters(), lr=1e-4) 151 | scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps) 152 | 153 | best_loss = 
float("inf") 154 | best_state_dict = None 155 | 156 | self_exclude = 0.0 157 | ref_included = False 158 | 159 | pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step") 160 | 161 | for step in range(total_steps): 162 | batch_loss = 0.0 163 | 164 | for _ in range(accu_steps): 165 | try: 166 | batch = next(train_iterator) 167 | except StopIteration: 168 | train_iterator = iter(train_loader) 169 | batch = next(train_iterator) 170 | 171 | loss = model_fn(batch, model, criterion, self_exclude, ref_included, device) 172 | loss = loss / accu_steps 173 | batch_loss += loss.item() 174 | loss.backward() 175 | 176 | optimizer.step() 177 | scheduler.step() 178 | optimizer.zero_grad() 179 | 180 | pbar.update() 181 | pbar.set_postfix(loss=f"{batch_loss:.2f}", excl=self_exclude, step=step + 1) 182 | 183 | if step % log_steps == 0 and comment is not None: 184 | writer.add_scalar("Loss/train", batch_loss, step) 185 | writer.add_scalar("Self-exclusive Rate", self_exclude, step) 186 | 187 | if (step + 1) % valid_steps == 0: 188 | pbar.close() 189 | 190 | valid_loss = valid(valid_loader, model, criterion, device) 191 | 192 | if comment is not None: 193 | writer.add_scalar("Loss/valid", valid_loss, step + 1) 194 | 195 | if valid_loss < best_loss: 196 | best_loss = valid_loss 197 | best_state_dict = model.state_dict() 198 | 199 | pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step") 200 | 201 | if (step + 1) % save_steps == 0 and best_state_dict is not None: 202 | loss_str = f"{best_loss:.4f}".replace(".", "dot") 203 | best_ckpt_name = f"retriever-best-loss{loss_str}.pt" 204 | 205 | loss_str = f"{valid_loss:.4f}".replace(".", "dot") 206 | curr_ckpt_name = f"retriever-step{step+1}-loss{loss_str}.pt" 207 | 208 | current_state_dict = model.state_dict() 209 | model.cpu() 210 | 211 | model.load_state_dict(best_state_dict) 212 | model.save(str(save_dir_path / best_ckpt_name)) 213 | 214 | model.load_state_dict(current_state_dict) 215 | model.save(str(save_dir_path / curr_ckpt_name)) 216 | 217 | model.to(device) 218 | pbar.write(f"Step {step + 1}, best model saved. (loss={best_loss:.4f})") 219 | 220 | if (step + 1) >= milestones[1]: 221 | self_exclude = exclusive_rate 222 | 223 | elif (step + 1) == milestones[0]: 224 | ref_included = True 225 | optimizer = AdamW( 226 | [ 227 | {"params": model.unet.parameters(), "lr": 1e-6}, 228 | {"params": model.smoothers.parameters()}, 229 | {"params": model.mel_linear.parameters()}, 230 | {"params": model.post_net.parameters()}, 231 | ], 232 | lr=1e-4, 233 | ) 234 | scheduler = get_cosine_schedule_with_warmup( 235 | optimizer, warmup_steps, total_steps - milestones[0] 236 | ) 237 | pbar.write("Optimizer and scheduler restarted.") 238 | 239 | elif (step + 1) > milestones[0]: 240 | self_exclude = (step + 1 - milestones[0]) / (milestones[1] - milestones[0]) 241 | self_exclude *= exclusive_rate 242 | 243 | pbar.close() 244 | 245 | 246 | if __name__ == "__main__": 247 | main(**parse_args()) 248 | --------------------------------------------------------------------------------