├── .gitignore ├── LICENSE ├── README.md ├── config └── config.json ├── conversion.py ├── data_loader.py ├── main.py ├── make_metadata.py ├── models ├── ecapa_tdnn.py ├── hifi_gan.py └── model_vc.py ├── samples ├── zs_f2f_conversion.wav ├── zs_f2f_source.wav ├── zs_f2f_target.wav ├── zs_f2m_conversion.wav ├── zs_f2m_source.wav ├── zs_f2m_target.wav ├── zs_m2f_conversion.wav ├── zs_m2f_source.wav ├── zs_m2f_target.wav ├── zs_m2m_conversion.wav ├── zs_m2m_source.wav └── zs_m2m_target.wav ├── solver_encoder.py └── utils ├── function_f.py ├── mel.py └── perturbation.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 cjchun3616
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Perturbation AUTOVC: Voice Conversion from Perturbation and Autoencoder Loss
2 | 
3 | This repository provides a PyTorch implementation of Perturbation AUTOVC.
4 | 
5 | ### Audio Samples
6 | 
7 | The audio samples for Perturbation AUTOVC can be found at [this link](https://www.amclab.kr/demo/perturbation_autovc/).
8 | 
9 | ### Dependencies
10 | - Python3
11 | - Numpy
12 | - PyTorch
13 | - librosa
14 | - tqdm
15 | - parselmouth (`pip install praat-parselmouth`)
16 | - torchaudio
17 | - omegaconf
18 | - HiFi-GAN vocoder
19 | - ECAPA-TDNN
20 | 
21 | 
22 | ## Pre-trained models
23 | 
24 | You can also use the pretrained models we provide.
25 | 
26 | [Download pretrained models](https://drive.google.com/drive/folders/1N3Uo4nM8vtWBqNmoYsqTlRxayM-3owbU?usp=sharing)
27 | 
28 | Place the pre-trained models at `./checkpoint`
29 | 
30 | 
31 | ### Speaker Encoder
32 | 
33 | We use ECAPA-TDNN as the speaker encoder.
34 | 
35 | For more information, please refer to [ECAPA-TDNN](https://github.com/taoruijie/ecapa-tdnn)
36 | 
37 | ### Vocoder
38 | 
39 | We use HiFi-GAN as the vocoder.
40 | 
41 | Download the pretrained HiFi-GAN config and checkpoint (`pretrained/UNIVERSAL_V1`) from [HiFi-GAN](http://github.com/jik876/hifi-gan)
42 | 
43 | Place the checkpoint at `./checkpoint` and the config at `./config`
44 | 
45 | 
46 | ## Datasets
47 | 
48 | The dataset used for training is:
49 | - VCTK:
50 |   - CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit
51 |   - https://datashare.ed.ac.uk/handle/10283/2651
52 | 
53 | Place the dataset at `datasets/wavs/`
54 | 
55 | 
56 | ## Preprocess the dataset
57 | 
58 | To preprocess the dataset (perturbation with `praat-parselmouth`, mel-spectrogram extraction, and metadata generation), run `python make_metadata.py`
59 | 
60 | ```python
61 | parser = argparse.ArgumentParser()
62 | 
63 | parser.add_argument('--wav_dir', type=str, default='./datasets/wavs', help='path of wav directory')
64 | parser.add_argument('--real_dir', type=str, default='./datasets/real', help='save path of original mel-spectrogram')
65 | parser.add_argument('--perturb_dir', type=str, default='./datasets/perturb', help='save path of perturbation mel-spectrogram')
66 | parser.add_argument('--save_path', type=str, default='./datasets/', help='save path of metadata')
67 | ```
68 | 
69 | If the data has already been processed, comment out this line:
70 | ```python
71 | make_data(wav_dir, real_dir, perturb_dir)
72 | ```
73 | 
74 | When this is done, `metadata.pkl` is created at `--save_path`.
75 | 
76 | 
77 | ## Training
78 | 
79 | Training expects `metadata.pkl`, which includes the speaker embeddings produced by ECAPA-TDNN.
80 | 
81 | Once `metadata.pkl` is ready, run `python main.py`
82 | 
83 | ```python
84 | parser = argparse.ArgumentParser()
85 | 
86 | # Model configuration.
87 | parser.add_argument('--lambda_cd', type=float, default=1, help='weight for hidden code loss')
88 | parser.add_argument('--dim_emb', type=int, default=192, help='speaker embedding dimensions')
89 | parser.add_argument('--dim_pre', type=int, default=512)
90 | parser.add_argument('--freq', type=int, default=1, help='downsampling factor')
91 | 
92 | # Save configuration.
93 | parser.add_argument('--resume', type=str, default=None, help='path to load model')
94 | parser.add_argument('--save_dir', type=str, default='./model/test', help='path to save model')
95 | parser.add_argument('--pt_name', type=str, default='test_model', help='model name')
96 | 
97 | # Data configuration.
98 | parser.add_argument('--data_dir', type=str, default='./datasets/metadata.pkl', help='path to metadata')
99 | 
100 | # Training configuration.
101 | parser.add_argument('--batch_size', type=int, default=2, help='mini-batch size')
102 | parser.add_argument('--num_iters', type=int, default=1000000, help='number of total iterations')
103 | parser.add_argument('--len_crop', type=int, default=128, help='dataloader output sequence length')
104 | parser.add_argument('--log_step', type=int, default=10000)
105 | ```
106 | 
107 | Training typically converges when the reconstruction loss reaches about 0.01.
108 | 
109 | ## Inference
110 | 
111 | Run `python conversion.py --source_path={} --target_path={}`
112 | 
113 | You may want to edit `conversion.py` for custom manipulation.
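For reference, the conversion can also be driven from Python instead of the command line. The snippet below is a minimal sketch that mirrors the `__main__` block of `conversion.py`; the audio paths are placeholders pointing at the clips under `./samples`, and the checkpoint name assumes you kept the default `--ckpt_path` location.

```python
# Minimal sketch of programmatic use of conversion.py (paths are placeholders).
from argparse import Namespace
from conversion import Conversion

config = Namespace(
    source_path='./samples/zs_f2m_source.wav',  # source speech (content), sr=22050
    target_path='./samples/zs_f2m_target.wav',  # target speaker reference, sr=16000
    save_path='./result',
    ckpt_path='./checkpoint/model.pt',
    dim_emb=192,  # speaker embedding dimensions
    dim_pre=512,
    freq=1,       # downsampling factor
)

Conversion(config).conversion()  # writes conversion.wav and conversion.npy to save_path
```

The full set of command-line arguments accepted by `conversion.py` is listed below.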
114 | 115 | ```python 116 | parser = argparse.ArgumentParser() 117 | 118 | # Conversion configurations. 119 | parser.add_argument('--source_path', type=str, required=True, help='path to source audio file, sr=22050') 120 | parser.add_argument('--target_path', type=str, required=True, help='path to target audio file, sr=16000') 121 | parser.add_argument('--save_path', type=str, default='./result', help='path to save conversion audio') 122 | 123 | # Model configurations. 124 | parser.add_argument('--ckpt_path', type=str, default='./checkpoint/model.pt', help='path to model checkpoint') 125 | parser.add_argument('--dim_emb', type=int, default=192, help='speaker embedding dimensions.') 126 | parser.add_argument('--dim_pre', type=int, default=512) 127 | parser.add_argument('--freq', type=int, default=1, help='downsampling factor') 128 | ``` 129 | 130 | ## Acknowledgment 131 | This work was partly supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2022-0-00963, Localization Technology Development on Spoken Language Synthesis and Translation of OTT Media Contents). 132 | 133 | 134 | ## License 135 | This code is MIT-licensed. The license applies to our pre-trained models as well. 136 | -------------------------------------------------------------------------------- /config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /conversion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import librosa 5 | import argparse 6 | import numpy as np 7 | import soundfile as sf 8 | 9 | from math import ceil 10 | from models.model_vc import Generator 11 | from omegaconf import OmegaConf 12 | from models.hifi_gan import Generator as hifigan_vocoder 13 | 14 | from utils.mel import mel_spectrogram 15 | from utils.perturbation import load_wav 16 | from make_metadata import Speaker_Encoder 17 | 18 | class Conversion(object): 19 | def __init__(self, config): 20 | # Inference configurations. 21 | self.source_path = config.source_path 22 | self.target_path = config.target_path 23 | self.save_path = config.save_path 24 | 25 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | # Voice Conversion model configuration. 28 | self.ckpt_path = config.ckpt_path 29 | self.dim_emb = config.dim_emb 30 | self.dim_pre = config.dim_pre 31 | self.freq = config.freq 32 | self.load_model() 33 | 34 | # Vocoder configurations. 
35 | self.hifigan_config = './config/config.json' 36 | self.hifigan_ckpt = './checkpoint/g_02500000' 37 | self.vocoder() 38 | 39 | # Speaker Encoder configuration. 40 | self.speaker_encoder = Speaker_Encoder() 41 | 42 | def load_model(self): 43 | '''Load voice conversion model.''' 44 | self.G = Generator(self.dim_emb, self.dim_pre, self.freq) 45 | g_checkpoint = torch.load(self.ckpt_path) 46 | self.G.load_state_dict(g_checkpoint['model_state_dict']) 47 | self.G.to(self.device) 48 | self.G.eval() 49 | 50 | def vocoder(self): 51 | '''Load vocoder.''' 52 | hifigan_config = OmegaConf.load(self.hifigan_config) 53 | self.vocoder = hifigan_vocoder(hifigan_config) 54 | 55 | state_dict_g = torch.load(self.hifigan_ckpt) 56 | self.vocoder.to(self.device) 57 | self.vocoder.load_state_dict(state_dict_g['generator']) 58 | self.vocoder.eval() 59 | 60 | def pad_seq(self, x, base=32): 61 | len_out = int(base * ceil(float(x.shape[-1])/base)) 62 | len_pad = len_out - x.shape[-1] 63 | assert len_pad >= 0 64 | x_org = np.pad(x, ((0,0),(0,len_pad)), 'constant') 65 | uttr_org = torch.from_numpy(x_org[np.newaxis, :, :]).to(self.device) 66 | return uttr_org, len_pad 67 | 68 | def extract_energy(self, uttr_org): 69 | energy = torch.mean(uttr_org, dim=1, keepdim=True) 70 | return energy 71 | 72 | def conversion(self): 73 | '''Conversion process''' 74 | # Preprocess input data. 75 | source_wav, _ = load_wav(self.source_path, fs=22050) 76 | source = mel_spectrogram(source_wav) # Source mel-spectrogram. 77 | 78 | uttr_org, len_pad = self.pad_seq(source.squeeze()) # Energy. 79 | emb_trg = self.speaker_encoder.extract_speaker(self.target_path) # Target speaker embedding. 80 | emb_trg = torch.from_numpy(emb_trg).to(self.device) 81 | 82 | energy = self.extract_energy(uttr_org) 83 | emb_trg = emb_trg.unsqueeze(-1).expand(-1, -1, energy.shape[-1]) 84 | model_input = torch.cat((emb_trg, energy), dim=1) 85 | 86 | # Conversion. 87 | with torch.no_grad(): 88 | _, x_identic_psnt, a = self.G(uttr_org, model_input) 89 | 90 | if len_pad == 0: 91 | uttr_trg = x_identic_psnt[:, :, :] 92 | else: 93 | uttr_trg = x_identic_psnt[:, :, :-len_pad] 94 | self.save_wav(uttr_trg) 95 | 96 | def save_wav(self, uttr_trg): 97 | '''Save conversion waveform and mel-spectrogram at save path.''' 98 | # Make save directory. 99 | if not os.path.exists(self.save_path): 100 | os.makedirs(self.save_path) 101 | # Reconstruction to waveform from mel-spectrogram. 102 | self.vocoder.eval() 103 | with torch.no_grad(): 104 | conv_wav = self.vocoder(uttr_trg.squeeze().to(self.device)) 105 | conv_wav = conv_wav.squeeze().detach().cpu().numpy() 106 | sf.write(os.path.join(self.save_path, 'conversion.wav'), conv_wav, samplerate=22050) 107 | np.save(os.path.join(self.save_path, 'conversion.npy'), uttr_trg.cpu().detach().numpy().squeeze(), allow_pickle=False) 108 | print("Conversion Successful") 109 | 110 | if __name__ == '__main__': 111 | 112 | parser = argparse.ArgumentParser() 113 | 114 | # Conversion configurations. 115 | parser.add_argument('--source_path', type=str, required=True, help='path to source audio file, sr=22050') 116 | parser.add_argument('--target_path', type=str, required=True, help='path to target audio file, sr=16000') 117 | parser.add_argument('--save_path', type=str, default='./result', help='path to save conversion audio') 118 | 119 | # Model configurations. 
120 | parser.add_argument('--ckpt_path', type=str, default='./checkpoint/model.pt', help='path to model checkpoint') 121 | parser.add_argument('--dim_emb', type=int, default=192, help='speaker embedding dimensions.') 122 | parser.add_argument('--dim_pre', type=int, default=512) 123 | parser.add_argument('--freq', type=int, default=1, help='downsampling factor') 124 | 125 | config = parser.parse_args() 126 | print(config) 127 | conversion = Conversion(config) 128 | conversion.conversion() 129 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import pickle 5 | import numpy as np 6 | from torch.utils import data 7 | from multiprocessing import Process, Manager 8 | 9 | class Utterances(data.Dataset): 10 | 11 | def __init__(self, root_dir, spk_path, len_crop): 12 | 13 | self.root_dir = root_dir 14 | self.len_crop = len_crop 15 | self.spk_path = spk_path 16 | self.step = 10 17 | 18 | data_path = os.path.join(root_dir, self.spk_path) 19 | meta = pickle.load(open(data_path, "rb")) 20 | # meta = meta[:-20] # unseen speaker 21 | 22 | '''Load data using multiprocessing''' 23 | manager = Manager() 24 | meta = manager.list(meta) 25 | dataset = manager.list(len(meta)*[None]) 26 | processes = [] 27 | for i in range(0, len(meta), self.step): 28 | p = Process(target=self.load_data, 29 | args=(meta[i:i+self.step], dataset, i)) 30 | p.start() 31 | processes.append(p) 32 | for p in processes: 33 | p.join() 34 | 35 | self.train_dataset = dataset 36 | self.num_tokens = len(self.train_dataset) 37 | 38 | print('Finished loading the dataset...') 39 | 40 | 41 | def load_data(self, submeta, dataset, idx_offset): 42 | for k, sbmt in enumerate(submeta): 43 | uttrs = len(sbmt)*[None] 44 | for j, tmp in enumerate(sbmt): 45 | if j < 2: # name, speaker embedding 46 | uttrs[j] = tmp 47 | else: # Data path 48 | tmp_uttr = tmp 49 | x_real = tmp.replace('perturb', 'real') 50 | uttrs[j] = [tmp_uttr, x_real] 51 | dataset[idx_offset+k] = uttrs 52 | 53 | 54 | def __getitem__(self, index): 55 | 56 | dataset = self.train_dataset 57 | 58 | list_uttrs = dataset[index] 59 | emb_org = list_uttrs[1] 60 | 61 | # Pick random utterence with random crop. 62 | a = np.random.randint(2, len(list_uttrs)) 63 | tmp = list_uttrs[a] 64 | 65 | # Load the mel-spectrogram. 66 | uttr = np.load(tmp[0]) # mel-spectrogram with perturbation. 
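        # Note: each .npy file saved during preprocessing (make_data) is assumed to hold a
        # mel-spectrogram of shape (1, n_mels, frames); the padding/cropping below therefore
        # operates on the last axis, and the leading batch axis is squeezed out on return.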
67 | x_real = np.load(tmp[1]) # original mel-spectrogram 68 | 69 | if uttr.shape[-1] < self.len_crop: 70 | len_pad = self.len_crop - uttr.shape[-1] 71 | uttr_x = np.pad(uttr, ((0,0), (0,0),(0,len_pad)), 'constant') 72 | real_x = np.pad(x_real, ((0,0), (0,0),(0,len_pad)), 'constant') 73 | 74 | elif uttr.shape[-1] > self.len_crop: 75 | left = np.random.randint(uttr.shape[-1]-self.len_crop) 76 | uttr_x = uttr[:, :, left:left+self.len_crop] 77 | real_x = x_real[:, :, left:left+self.len_crop] 78 | else: 79 | uttr_x = uttr 80 | real_x = x_real 81 | 82 | return torch.from_numpy(real_x).squeeze(), torch.from_numpy(uttr_x).squeeze(), emb_org 83 | 84 | def __len__(self): 85 | '''Return the number of speakers''' 86 | return self.num_tokens 87 | 88 | 89 | def get_loader(root_dir, spk_path, batch_size=16, len_crop=128, num_workers=0): 90 | '''Build and return a data loader''' 91 | dataset = Utterances(root_dir, spk_path, len_crop) 92 | worker_init_fn = lambda x: np.random.seed((torch.initial_seed()) % (2**32)) 93 | data_loader = data.DataLoader(dataset=dataset, 94 | batch_size=batch_size, 95 | shuffle=True, 96 | num_workers=num_workers, 97 | drop_last=True, 98 | worker_init_fn=worker_init_fn) 99 | return data_loader 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from solver_encoder import Solver 4 | from data_loader import get_loader 5 | from torch.backends import cudnn 6 | 7 | 8 | def str2bool(v): 9 | return v.lower() in ('true') 10 | 11 | def main(config): 12 | # For fast training. 13 | cudnn.benchmark = True 14 | 15 | # Data loader. 16 | vcc_loader = get_loader(config.data_dir, config.data_path, config.batch_size, config.len_crop) 17 | 18 | solver = Solver(vcc_loader, config) 19 | 20 | solver.train() 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser() 24 | 25 | # Model configuration. 26 | parser.add_argument('--lambda_cd', type=float, default=1, help='weight for hidden code loss') 27 | parser.add_argument('--dim_emb', type=int, default=192, help='speaker embedding dimensions') 28 | parser.add_argument('--dim_pre', type=int, default=512) 29 | parser.add_argument('--freq', type=int, default=1, help='downsampling factor') 30 | 31 | # Save configuration. 32 | parser.add_argument('--resume', type=str, default=None, help='path to load model') 33 | parser.add_argument('--save_dir', type=str, default='./model/test', help='path to save model') 34 | parser.add_argument('--pt_name', type=str, default='test_model', help='model name') 35 | 36 | # Data configuration. 37 | parser.add_argument('--data_dir', type=str, default='./datasets', help='path to dataset') 38 | parser.add_argument('--data_path', type=str, default='metadata.pkl', help='name to metadata') 39 | 40 | # Training configuration. 
41 | parser.add_argument('--batch_size', type=int, default=2, help='mini-batch size') 42 | parser.add_argument('--num_iters', type=int, default=1000000, help='number of total iterations') 43 | parser.add_argument('--len_crop', type=int, default=128, help='dataloader output sequence length') 44 | parser.add_argument('--log_step', type=int, default=10000) 45 | 46 | config = parser.parse_args() 47 | print(config) 48 | main(config) -------------------------------------------------------------------------------- /make_metadata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import librosa 5 | import argparse 6 | import numpy as np 7 | from glob import glob 8 | from tqdm.auto import tqdm 9 | import torch.nn.functional as F 10 | 11 | from models import ecapa_tdnn 12 | from utils.perturbation import make_data 13 | from utils.perturbation import load_wav 14 | 15 | class Speaker_Encoder(object): 16 | def __init__(self, config=None): 17 | if config is not None: # For make metadata. 18 | self.wav_dir = config.wav_dir # Audio file directory. 19 | self.speaker = sorted(os.listdir(self.wav_dir)) # Speaker name list. 20 | self.perturb_dir = config.perturb_dir 21 | self.save_path = config.save_path # Metadata save path. 22 | self.speakers = [] 23 | 24 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | '''Load speaker encoder model.''' 27 | self.model_load() 28 | 29 | def model_load(self): 30 | self.model = ecapa_tdnn.ECAPA_TDNN(1024).to(self.device) 31 | self.model.load_state_dict(torch.load('./checkpoint/tdnn.pt')) 32 | 33 | def extract_speaker(self, filepath): 34 | '''Extract speaker information with ecapa-tdnn model.''' 35 | audio, _ = load_wav(filepath, fs=16000) # Load wav 36 | with torch.no_grad(): 37 | self.model.eval() 38 | feature1, _ = self.model(audio.float().to(self.device), False) 39 | feature1 = F.normalize(feature1, p=2, dim=1) 40 | emb = feature1.cpu().detach().numpy() 41 | return emb 42 | 43 | def make_metadata(self): 44 | '''Extract speaker embedding of each speaker''' 45 | for spk_name in tqdm(self.speaker, total=len(self.speaker)): 46 | test_dir = glob(os.path.join(self.wav_dir, spk_name, '*.wav')) # wav file list. 47 | data_list = glob(os.path.join(self.perturb_dir, spk_name, '*npy')) # mel-spectrogram with perturbation list. 48 | Data1 = [] 49 | # Extract speaker embedding per utterances of each speaker. 50 | for i in test_dir: 51 | utterances = [] 52 | emb = self.extract_speaker(i) 53 | Data1.extend(emb.reshape(1, 192)) 54 | Data1 = np.array(Data1) 55 | spk_emb = np.mean(Data1, axis=0) # Average of each utterance embedding. 56 | 57 | utterances.append(spk_name) 58 | utterances.append(spk_emb) # Speaker embedding with 192 dimensions. 59 | utterances.extend(data_list) # Data path. 
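            # Entry layout consumed by data_loader.Utterances:
            # [speaker_name, 192-dim speaker embedding, perturbed mel .npy path, ...]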
60 | self.speakers.append(utterances) 61 | 62 | '''Save metadata at save path.''' 63 | with open(os.path.join(self.save_path, 'metadata.pkl'), 'wb') as handle: 64 | pickle.dump(self.speakers, handle) 65 | 66 | if __name__ == '__main__': 67 | 68 | parser = argparse.ArgumentParser() 69 | 70 | parser.add_argument('--wav_dir', type=str, default='./datasets/wavs', help='path to wav directory') 71 | parser.add_argument('--real_dir', type=str, default='./datasets/real', help='path to save mel-spectrogram') 72 | parser.add_argument('--perturb_dir', type=str, default='./datasets/perturb', help='path to save perturbation mel-spectrogram') 73 | parser.add_argument('--save_path', type=str, default='./datasets/', help='path to save metadata') 74 | 75 | config = parser.parse_args() 76 | print(config) 77 | make_data(config.wav_dir, config.real_dir, config.perturb_dir) # If the data is processed in advance, please comment. 78 | speaker_encoder = Speaker_Encoder(config) 79 | speaker_encoder.make_metadata() -------------------------------------------------------------------------------- /models/ecapa_tdnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torchaudio 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | # We use a speaker recognition network, as a speaker encoder. 8 | # This is an implementation of the ecapa-tdnn model. 9 | # For more information about the model, visit "https://github.com/TaoRuijie/ECAPA-TDNN" 10 | # Paper: "Neural Analysis and Synthesis: Reconstructing Speech from Self-Supervised Representations" 11 | 12 | class SEModule(nn.Module): 13 | def __init__(self, channels, bottleneck=128): 14 | super(SEModule, self).__init__() 15 | self.se = nn.Sequential( 16 | nn.AdaptiveAvgPool1d(1), 17 | nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0), 18 | nn.ReLU(), 19 | nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0), 20 | nn.Sigmoid(), 21 | ) 22 | 23 | def forward(self, input): 24 | x = self.se(input) 25 | return input * x 26 | 27 | class Bottle2neck(nn.Module): 28 | 29 | def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale = 8): 30 | super(Bottle2neck, self).__init__() 31 | width = int(math.floor(planes / scale)) 32 | self.conv1 = nn.Conv1d(inplanes, width*scale, kernel_size=1) 33 | self.bn1 = nn.BatchNorm1d(width*scale) 34 | self.nums = scale -1 35 | convs = [] 36 | bns = [] 37 | num_pad = math.floor(kernel_size/2)*dilation 38 | for i in range(self.nums): 39 | convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad)) 40 | bns.append(nn.BatchNorm1d(width)) 41 | self.convs = nn.ModuleList(convs) 42 | self.bns = nn.ModuleList(bns) 43 | self.conv3 = nn.Conv1d(width*scale, planes, kernel_size=1) 44 | self.bn3 = nn.BatchNorm1d(planes) 45 | self.relu = nn.ReLU() 46 | self.width = width 47 | self.se = SEModule(planes) 48 | 49 | def forward(self, x): 50 | residual = x 51 | out = self.conv1(x) 52 | out = self.relu(out) 53 | out = self.bn1(out) 54 | 55 | spx = torch.split(out, self.width, 1) 56 | for i in range(self.nums): 57 | if i==0: 58 | sp = spx[i] 59 | else: 60 | sp = sp + spx[i] 61 | sp = self.convs[i](sp) 62 | sp = self.relu(sp) 63 | sp = self.bns[i](sp) 64 | if i==0: 65 | out = sp 66 | else: 67 | out = torch.cat((out, sp), 1) 68 | out = torch.cat((out, spx[self.nums]),1) 69 | 70 | out = self.conv3(out) 71 | out = self.relu(out) 72 | out = self.bn3(out) 73 | 74 | out = self.se(out) 75 | out += residual 76 | return 
out 77 | 78 | class PreEmphasis(torch.nn.Module): 79 | 80 | def __init__(self, coef: float = 0.97): 81 | super().__init__() 82 | self.coef = coef 83 | self.register_buffer( 84 | 'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0) 85 | ) 86 | 87 | def forward(self, input: torch.tensor) -> torch.tensor: 88 | input = input.unsqueeze(1) 89 | input = F.pad(input, (1, 0), 'reflect') 90 | return F.conv1d(input, self.flipped_filter).squeeze(1) 91 | 92 | class FbankAug(nn.Module): 93 | 94 | def __init__(self, freq_mask_width = (0, 8), time_mask_width = (0, 10)): 95 | self.time_mask_width = time_mask_width 96 | self.freq_mask_width = freq_mask_width 97 | super().__init__() 98 | 99 | def mask_along_axis(self, x, dim): 100 | original_size = x.shape 101 | batch, fea, time = x.shape 102 | if dim == 1: 103 | D = fea 104 | width_range = self.freq_mask_width 105 | else: 106 | D = time 107 | width_range = self.time_mask_width 108 | 109 | mask_len = torch.randint(width_range[0], width_range[1], (batch, 1), device=x.device).unsqueeze(2) 110 | mask_pos = torch.randint(0, max(1, D - mask_len.max()), (batch, 1), device=x.device).unsqueeze(2) 111 | arange = torch.arange(D, device=x.device).view(1, 1, -1) 112 | mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len)) 113 | mask = mask.any(dim=1) 114 | 115 | if dim == 1: 116 | mask = mask.unsqueeze(2) 117 | else: 118 | mask = mask.unsqueeze(1) 119 | 120 | x = x.masked_fill_(mask, 0.0) 121 | return x.view(*original_size) 122 | 123 | def forward(self, x): 124 | x = self.mask_along_axis(x, dim=2) 125 | x = self.mask_along_axis(x, dim=1) 126 | return x 127 | 128 | class ECAPA_TDNN(nn.Module): 129 | 130 | def __init__(self, C): 131 | 132 | super(ECAPA_TDNN, self).__init__() 133 | 134 | self.torchfbank = torch.nn.Sequential( 135 | PreEmphasis(), 136 | torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, \ 137 | f_min = 20, f_max = 7000, window_fn=torch.hamming_window, n_mels=80), 138 | ) 139 | 140 | self.specaug = FbankAug() # Spec augmentation 141 | 142 | self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2) 143 | self.relu = nn.ReLU() 144 | self.bn1 = nn.BatchNorm1d(C) 145 | self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8) 146 | self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8) 147 | self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8) 148 | # I fixed the shape of the output from MFA layer, that is close to the setting from ECAPA paper. 
149 | self.layer4 = nn.Conv1d(3*C, 1536, kernel_size=1) 150 | self.attention = nn.Sequential( 151 | nn.Conv1d(4608, 256, kernel_size=1), 152 | nn.ReLU(), 153 | nn.BatchNorm1d(256), 154 | nn.Tanh(), # I add this layer 155 | nn.Conv1d(256, 1536, kernel_size=1), 156 | nn.Softmax(dim=2), 157 | ) 158 | self.bn5 = nn.BatchNorm1d(3072) 159 | self.fc6 = nn.Linear(3072, 192) 160 | self.bn6 = nn.BatchNorm1d(192) 161 | self.fc7 = nn.Linear(192,5994) 162 | 163 | 164 | def forward(self, x, aug): 165 | with torch.no_grad(): 166 | x = self.torchfbank(x)+1e-6 167 | x = x.log() 168 | x = x - torch.mean(x, dim=-1, keepdim=True) 169 | if aug == True: 170 | x = self.specaug(x) 171 | 172 | x = self.conv1(x) 173 | x = self.relu(x) 174 | x = self.bn1(x) 175 | 176 | x1 = self.layer1(x) 177 | x2 = self.layer2(x+x1) 178 | x3 = self.layer3(x+x1+x2) 179 | 180 | x = self.layer4(torch.cat((x1,x2,x3),dim=1)) 181 | x = self.relu(x) 182 | 183 | t = x.size()[-1] 184 | 185 | global_x = torch.cat((x,torch.mean(x,dim=2,keepdim=True).repeat(1,1,t), torch.sqrt(torch.var(x,dim=2,keepdim=True).clamp(min=1e-4)).repeat(1,1,t)), dim=1) 186 | 187 | w = self.attention(global_x) 188 | 189 | mu = torch.sum(x * w, dim=2) 190 | sg = torch.sqrt( ( torch.sum((x**2) * w, dim=2) - mu**2 ).clamp(min=1e-4) ) 191 | 192 | x = torch.cat((mu,sg),1) 193 | x = self.bn5(x) 194 | emb = self.fc6(x) 195 | out = self.bn6(emb) 196 | #out = nn.ReLU()(out) 197 | #out = self.fc7(out) 198 | #emb = torch.nn.functional.dropout(emb,p=0.5,training=True) 199 | return emb,out -------------------------------------------------------------------------------- /models/hifi_gan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | # We use a HiFi-GAN, as a vocoder. 10 | # This is an implementation of the HiFi-GAN model. 
11 | # For more information about the model, visit "https://github.com/jik876/hifi-gan" 12 | # Paper: "HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis" 13 | 14 | def init_weights(m, mean=0.0, std=0.01): 15 | classname = m.__class__.__name__ 16 | if classname.find("Conv") != -1: 17 | m.weight.data.normal_(mean, std) 18 | 19 | 20 | def get_padding(kernel_size, dilation=1): 21 | return int((kernel_size * dilation - dilation) / 2) 22 | 23 | 24 | class ResBlock1(torch.nn.Module): 25 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 26 | super(ResBlock1, self).__init__() 27 | self.h = h 28 | self.convs1 = nn.ModuleList([ 29 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 30 | padding=get_padding(kernel_size, dilation[0]))), 31 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 32 | padding=get_padding(kernel_size, dilation[1]))), 33 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 34 | padding=get_padding(kernel_size, dilation[2]))) 35 | ]) 36 | self.convs1.apply(init_weights) 37 | 38 | self.convs2 = nn.ModuleList([ 39 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 40 | padding=get_padding(kernel_size, 1))), 41 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 42 | padding=get_padding(kernel_size, 1))), 43 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 44 | padding=get_padding(kernel_size, 1))) 45 | ]) 46 | self.convs2.apply(init_weights) 47 | 48 | def forward(self, x): 49 | for c1, c2 in zip(self.convs1, self.convs2): 50 | xt = F.leaky_relu(x, LRELU_SLOPE) 51 | xt = c1(xt) 52 | xt = F.leaky_relu(xt, LRELU_SLOPE) 53 | xt = c2(xt) 54 | x = xt + x 55 | return x 56 | 57 | def remove_weight_norm(self): 58 | for l in self.convs1: 59 | remove_weight_norm(l) 60 | for l in self.convs2: 61 | remove_weight_norm(l) 62 | 63 | 64 | class ResBlock2(torch.nn.Module): 65 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 66 | super(ResBlock2, self).__init__() 67 | self.h = h 68 | self.convs = nn.ModuleList([ 69 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 70 | padding=get_padding(kernel_size, dilation[0]))), 71 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 72 | padding=get_padding(kernel_size, dilation[1]))) 73 | ]) 74 | self.convs.apply(init_weights) 75 | 76 | def forward(self, x): 77 | for c in self.convs: 78 | xt = F.leaky_relu(x, LRELU_SLOPE) 79 | xt = c(xt) 80 | x = xt + x 81 | return x 82 | 83 | def remove_weight_norm(self): 84 | for l in self.convs: 85 | remove_weight_norm(l) 86 | 87 | 88 | class Generator(torch.nn.Module): 89 | def __init__(self, h): 90 | super(Generator, self).__init__() 91 | self.h = h 92 | self.num_kernels = len(h.resblock_kernel_sizes) 93 | self.num_upsamples = len(h.upsample_rates) 94 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) 95 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 96 | 97 | self.ups = nn.ModuleList() 98 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 99 | self.ups.append(weight_norm( 100 | ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)), 101 | k, u, padding=(k - u) // 2))) 102 | 103 | self.resblocks = nn.ModuleList() 104 | for i in range(len(self.ups)): 105 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 106 | for j, (k, d) in 
enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 107 | self.resblocks.append(resblock(h, ch, k, d)) 108 | 109 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 110 | self.ups.apply(init_weights) 111 | self.conv_post.apply(init_weights) 112 | 113 | def forward(self, x): 114 | x = self.conv_pre(x) 115 | for i in range(self.num_upsamples): 116 | x = F.leaky_relu(x, LRELU_SLOPE) 117 | x = self.ups[i](x) 118 | xs = None 119 | for j in range(self.num_kernels): 120 | if xs is None: 121 | xs = self.resblocks[i * self.num_kernels + j](x) 122 | else: 123 | xs += self.resblocks[i * self.num_kernels + j](x) 124 | x = xs / self.num_kernels 125 | x = F.leaky_relu(x) 126 | x = self.conv_post(x) 127 | x = torch.tanh(x) 128 | 129 | return x 130 | 131 | def remove_weight_norm(self): 132 | print('Removing weight norm...') 133 | for l in self.ups: 134 | remove_weight_norm(l) 135 | for l in self.resblocks: 136 | l.remove_weight_norm() 137 | remove_weight_norm(self.conv_pre) 138 | remove_weight_norm(self.conv_post) 139 | 140 | 141 | class DiscriminatorP(torch.nn.Module): 142 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 143 | super(DiscriminatorP, self).__init__() 144 | self.period = period 145 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 146 | self.convs = nn.ModuleList([ 147 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 148 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 149 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 150 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 151 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 152 | ]) 153 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 154 | 155 | def forward(self, x): 156 | fmap = [] 157 | 158 | # 1d to 2d 159 | b, c, t = x.shape 160 | if t % self.period != 0: # pad first 161 | n_pad = self.period - (t % self.period) 162 | x = F.pad(x, (0, n_pad), "reflect") 163 | t = t + n_pad 164 | x = x.view(b, c, t // self.period, self.period) 165 | 166 | for l in self.convs: 167 | x = l(x) 168 | x = F.leaky_relu(x, LRELU_SLOPE) 169 | fmap.append(x) 170 | x = self.conv_post(x) 171 | fmap.append(x) 172 | x = torch.flatten(x, 1, -1) 173 | 174 | return x, fmap 175 | 176 | 177 | class MultiPeriodDiscriminator(torch.nn.Module): 178 | def __init__(self): 179 | super(MultiPeriodDiscriminator, self).__init__() 180 | self.discriminators = nn.ModuleList([ 181 | DiscriminatorP(2), 182 | DiscriminatorP(3), 183 | DiscriminatorP(5), 184 | DiscriminatorP(7), 185 | DiscriminatorP(11), 186 | ]) 187 | 188 | def forward(self, y, y_hat): 189 | y_d_rs = [] 190 | y_d_gs = [] 191 | fmap_rs = [] 192 | fmap_gs = [] 193 | for i, d in enumerate(self.discriminators): 194 | y_d_r, fmap_r = d(y) 195 | y_d_g, fmap_g = d(y_hat) 196 | y_d_rs.append(y_d_r) 197 | fmap_rs.append(fmap_r) 198 | y_d_gs.append(y_d_g) 199 | fmap_gs.append(fmap_g) 200 | 201 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 202 | 203 | 204 | class DiscriminatorS(torch.nn.Module): 205 | def __init__(self, use_spectral_norm=False): 206 | super(DiscriminatorS, self).__init__() 207 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 208 | self.convs = nn.ModuleList([ 209 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 210 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 211 | norm_f(Conv1d(128, 256, 41, 2, groups=16, 
padding=20)), 212 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 213 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 214 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 215 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 216 | ]) 217 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 218 | 219 | def forward(self, x): 220 | fmap = [] 221 | for l in self.convs: 222 | x = l(x) 223 | x = F.leaky_relu(x, LRELU_SLOPE) 224 | fmap.append(x) 225 | x = self.conv_post(x) 226 | fmap.append(x) 227 | x = torch.flatten(x, 1, -1) 228 | 229 | return x, fmap 230 | 231 | 232 | class MultiScaleDiscriminator(torch.nn.Module): 233 | def __init__(self): 234 | super(MultiScaleDiscriminator, self).__init__() 235 | self.discriminators = nn.ModuleList([ 236 | DiscriminatorS(use_spectral_norm=True), 237 | DiscriminatorS(), 238 | DiscriminatorS(), 239 | ]) 240 | self.meanpools = nn.ModuleList([ 241 | AvgPool1d(4, 2, padding=2), 242 | AvgPool1d(4, 2, padding=2) 243 | ]) 244 | 245 | def forward(self, y, y_hat): 246 | y_d_rs = [] 247 | y_d_gs = [] 248 | fmap_rs = [] 249 | fmap_gs = [] 250 | for i, d in enumerate(self.discriminators): 251 | if i != 0: 252 | y = self.meanpools[i - 1](y) 253 | y_hat = self.meanpools[i - 1](y_hat) 254 | y_d_r, fmap_r = d(y) 255 | y_d_g, fmap_g = d(y_hat) 256 | y_d_rs.append(y_d_r) 257 | fmap_rs.append(fmap_r) 258 | y_d_gs.append(y_d_g) 259 | fmap_gs.append(fmap_g) 260 | 261 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 262 | 263 | 264 | def feature_loss(fmap_r, fmap_g): 265 | loss = 0 266 | for dr, dg in zip(fmap_r, fmap_g): 267 | for rl, gl in zip(dr, dg): 268 | loss += torch.mean(torch.abs(rl - gl)) 269 | 270 | return loss * 2 271 | 272 | 273 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 274 | loss = 0 275 | r_losses = [] 276 | g_losses = [] 277 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 278 | r_loss = torch.mean((1 - dr) ** 2) 279 | g_loss = torch.mean(dg ** 2) 280 | loss += (r_loss + g_loss) 281 | r_losses.append(r_loss.item()) 282 | g_losses.append(g_loss.item()) 283 | 284 | return loss, r_losses, g_losses 285 | 286 | 287 | def generator_loss(disc_outputs): 288 | loss = 0 289 | gen_losses = [] 290 | for dg in disc_outputs: 291 | l = torch.mean((1 - dg) ** 2) 292 | gen_losses.append(l) 293 | loss += l 294 | 295 | return loss, gen_losses 296 | -------------------------------------------------------------------------------- /models/model_vc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class LinearNorm(torch.nn.Module): 8 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 9 | super(LinearNorm, self).__init__() 10 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 11 | 12 | torch.nn.init.xavier_uniform_( 13 | self.linear_layer.weight, 14 | gain=torch.nn.init.calculate_gain(w_init_gain)) 15 | 16 | def forward(self, x): 17 | return self.linear_layer(x) 18 | 19 | 20 | class ConvNorm(torch.nn.Module): 21 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 22 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 23 | super(ConvNorm, self).__init__() 24 | if padding is None: 25 | assert(kernel_size % 2 == 1) 26 | padding = int(dilation * (kernel_size - 1) / 2) 27 | 28 | 29 | 30 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 31 | kernel_size=kernel_size, stride=stride, 32 | padding=padding, 
dilation=dilation, 33 | bias=bias) 34 | 35 | torch.nn.init.xavier_uniform_( 36 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 37 | 38 | def forward(self, signal): 39 | conv_signal = self.conv(signal) 40 | return conv_signal 41 | 42 | 43 | class Encoder(nn.Module): 44 | '''Extract content embedding containing only linguistic information''' 45 | def __init__(self, freq): 46 | super(Encoder, self).__init__() 47 | self.freq = freq 48 | 49 | convolutions = [] 50 | for i in range(3): 51 | conv_layer = nn.Sequential( 52 | ConvNorm(80 if i==0 else 512, 53 | 512, 54 | kernel_size=5, stride=1, 55 | padding=2, 56 | dilation=1, w_init_gain='relu'), 57 | nn.BatchNorm1d(512)) 58 | convolutions.append(conv_layer) 59 | self.convolutions = nn.ModuleList(convolutions) 60 | 61 | self.lstm = nn.LSTM(512, 256, 2, batch_first=True, bidirectional=True) 62 | 63 | def forward(self, x): 64 | x = x.squeeze(1) 65 | 66 | for conv in self.convolutions: 67 | x = F.relu(conv(x)) 68 | x = x.transpose(1, 2) 69 | 70 | self.lstm.flatten_parameters() 71 | outputs, _ = self.lstm(x) 72 | out_forward = outputs[:, :, :256] # 256 is content embedding dimension. 73 | out_backward = outputs[:, :, 256:] # It can be changed. 74 | 75 | codes = [] 76 | for i in range(0, outputs.size(1), self.freq): 77 | codes.append(torch.cat((out_forward[:,i+self.freq-1,:],out_backward[:,i,:]), dim=-1)) 78 | 79 | return codes 80 | 81 | class Decoder(nn.Module): 82 | '''Reconstruction mel-spectrogram using speaker embedding, content embedding and energy''' 83 | def __init__(self, dim_emb, dim_pre): 84 | super(Decoder, self).__init__() 85 | 86 | self.lstm1 = nn.LSTM(256*2+dim_emb+1, dim_pre, 1, batch_first=True) 87 | 88 | convolutions = [] 89 | for i in range(3): 90 | conv_layer = nn.Sequential( 91 | ConvNorm(dim_pre, 92 | dim_pre, 93 | kernel_size=5, stride=1, 94 | padding=2, 95 | dilation=1, w_init_gain='relu'), 96 | nn.BatchNorm1d(dim_pre)) 97 | convolutions.append(conv_layer) 98 | self.convolutions = nn.ModuleList(convolutions) 99 | 100 | self.lstm2 = nn.LSTM(dim_pre, 1024, 2, batch_first=True) 101 | 102 | self.linear_projection = LinearNorm(1024, 80) 103 | 104 | def forward(self, x): 105 | x, _ = self.lstm1(x) 106 | x = x.transpose(1, 2) 107 | 108 | for conv in self.convolutions: 109 | x = F.relu(conv(x)) 110 | x = x.transpose(1, 2) 111 | 112 | outputs, _ = self.lstm2(x) 113 | 114 | decoder_output = self.linear_projection(outputs) 115 | 116 | return decoder_output 117 | 118 | class Postnet(nn.Module): 119 | '''Five 1-d convolution with 512 channels and kernel size 5''' 120 | def __init__(self): 121 | super(Postnet, self).__init__() 122 | 123 | self.convolutions = nn.ModuleList() 124 | self.convolutions.append( 125 | nn.Sequential( 126 | ConvNorm(80, 512, 127 | kernel_size=5, stride=1, 128 | padding=2, 129 | dilation=1, w_init_gain='tanh'), 130 | nn.BatchNorm1d(512)) 131 | ) 132 | 133 | for i in range(1, 5 - 1): 134 | self.convolutions.append( 135 | nn.Sequential( 136 | ConvNorm(512, 137 | 512, 138 | kernel_size=5, stride=1, 139 | padding=2, 140 | dilation=1, w_init_gain='tanh'), 141 | nn.BatchNorm1d(512)) 142 | ) 143 | 144 | self.convolutions.append( 145 | nn.Sequential( 146 | ConvNorm(512, 80, 147 | kernel_size=5, stride=1, 148 | padding=2, 149 | dilation=1, w_init_gain='linear'), 150 | nn.BatchNorm1d(80)) 151 | ) 152 | 153 | def forward(self, x): 154 | for i in range(len(self.convolutions) - 1): 155 | x = torch.tanh(self.convolutions[i](x)) 156 | x = self.convolutions[-1](x) 157 | 158 | return x 159 | 160 | 161 | class 
Generator(nn.Module): 162 | '''Generate network''' 163 | def __init__(self, dim_emb, dim_pre, freq): 164 | super(Generator, self).__init__() 165 | 166 | self.encoder = Encoder(freq) 167 | self.decoder = Decoder(dim_emb, dim_pre) 168 | self.postnet = Postnet() 169 | 170 | def forward(self, x, c_trg): # c_trg includes speaker embedding and energy. 171 | codes = self.encoder(x) 172 | 173 | if c_trg is None: 174 | return torch.cat(codes, dim=-1) # Use for content loss calculation. 175 | 176 | tmp = [] 177 | for code in codes: 178 | tmp.append(code.unsqueeze(-1)) 179 | 180 | code_exp = torch.cat(tmp, dim=-1) # Content embedding. 181 | encoder_outputs = torch.cat((code_exp, c_trg), dim=1) # Concatenate content embedding, speaker embedding, energy. 182 | 183 | encoder_outputs = encoder_outputs.transpose(2, 1) 184 | mel_outputs = self.decoder(encoder_outputs) 185 | 186 | mel_outputs_postnet = self.postnet(mel_outputs.transpose(2,1)) 187 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2,1) 188 | 189 | return mel_outputs.permute(0, 2, 1), mel_outputs_postnet.permute(0, 2, 1), torch.cat(codes, dim=-1) -------------------------------------------------------------------------------- /samples/zs_f2f_conversion.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_f2f_conversion.wav -------------------------------------------------------------------------------- /samples/zs_f2f_source.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_f2f_source.wav -------------------------------------------------------------------------------- /samples/zs_f2f_target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_f2f_target.wav -------------------------------------------------------------------------------- /samples/zs_f2m_conversion.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_f2m_conversion.wav -------------------------------------------------------------------------------- /samples/zs_f2m_source.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_f2m_source.wav -------------------------------------------------------------------------------- /samples/zs_f2m_target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_f2m_target.wav -------------------------------------------------------------------------------- /samples/zs_m2f_conversion.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_m2f_conversion.wav -------------------------------------------------------------------------------- /samples/zs_m2f_source.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_m2f_source.wav -------------------------------------------------------------------------------- /samples/zs_m2f_target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_m2f_target.wav -------------------------------------------------------------------------------- /samples/zs_m2m_conversion.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_m2m_conversion.wav -------------------------------------------------------------------------------- /samples/zs_m2m_source.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_m2m_source.wav -------------------------------------------------------------------------------- /samples/zs_m2m_target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjchun3616/perturbation_autovc/e19ac3fa9d59fc4faebf9822ac6784b0e1ec6c0c/samples/zs_m2m_target.wav -------------------------------------------------------------------------------- /solver_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import datetime 5 | import torch.nn.functional as F 6 | from models.model_vc import Generator 7 | 8 | class Solver(object): 9 | 10 | def __init__(self, vcc_loader, config): 11 | # Data loader. 12 | self.vcc_loader = vcc_loader 13 | 14 | # Model configurations. 15 | self.lambda_cd = config.lambda_cd 16 | self.dim_emb = config.dim_emb 17 | self.dim_pre = config.dim_pre 18 | self.freq = config.freq 19 | 20 | # model checkpoint and save path. 21 | self.resume = config.resume 22 | self.save_dir = config.save_dir 23 | 24 | # Training configrations. 25 | self.batch_size = config.batch_size 26 | self.num_iters = config.num_iters 27 | 28 | self.use_cuda = torch.cuda.is_available() 29 | self.device = torch.device('cuda:0' if self.use_cuda else 'cpu') 30 | self.log_step = config.log_step 31 | 32 | # build the model. 33 | self.build_model() 34 | 35 | def build_model(self): 36 | 37 | self.G = Generator(self.dim_emb, self.dim_pre, self.freq) 38 | self.g_optimizer = torch.optim.Adam(self.G.parameters(), 1e-4) 39 | self.G.to(self.device) 40 | 41 | if self.resume is not None: 42 | print("checkpoint load at ", self.resume) 43 | g_checkpoint = torch.load(self.resume, map_location='cuda:0') 44 | self.G.load_state_dict(g_checkpoint['model_state_dict']) 45 | self.g_optimizer.load_state_dict(g_checkpoint['optimizer_state_dict']) 46 | 47 | def reset_grad(self): 48 | # Reset the gradient. 49 | self.g_optimizer.zero_grad() 50 | 51 | def train(self): 52 | # Set the loader. 53 | data_loader = self.vcc_loader 54 | 55 | # Print logs in specified order. 
56 | keys = ['G/loss_id','G/loss_id_psnt','G/loss_cd'] 57 | 58 | print('Start training...') 59 | start_time = time.time() 60 | self.G = self.G.train() 61 | 62 | for i in range(self.num_iters): 63 | '''Preprocess input data.''' 64 | data_iter = iter(data_loader) 65 | x_real, uttr, emb_org = next(data_iter) 66 | 67 | x_real = x_real.to(self.device) # Clean data. 68 | uttr = uttr.to(self.device) # Perturbed data. 69 | emb_org = emb_org.to(self.device) # Speaker information. 70 | 71 | energy = torch.mean(x_real, dim=1, keepdim=True) # Extract energy. 72 | emb_de = emb_org.unsqueeze(-1).expand(-1, -1, energy.shape[-1]) 73 | emb_de = torch.cat((emb_de, energy), dim=1) 74 | 75 | x_identic, x_identic_psnt, code_real = self.G(uttr, emb_de) 76 | 77 | x_identic = x_identic.squeeze() 78 | x_identic_psnt = x_identic_psnt.squeeze() 79 | 80 | # Calculate reconstruction loss. 81 | g_loss_id = F.mse_loss(x_real, x_identic) 82 | g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt) 83 | 84 | # Calculate cotent loss. 85 | code_reconst = self.G(x_identic_psnt, None) 86 | g_loss_cd = F.l1_loss(code_real, code_reconst) 87 | 88 | # Backward and optimize. 89 | g_loss = 2*(g_loss_id + g_loss_id_psnt) + self.lambda_cd * g_loss_cd # sigma=2, mu=1. 90 | self.reset_grad() 91 | g_loss.backward() 92 | self.g_optimizer.step() 93 | 94 | # Print out training losses. 95 | if i % 200 == 0: 96 | print("loss_id: ", g_loss_id.item(), "loss_id_psnt: ", g_loss_id_psnt.item(), "g_loss_cd: ", g_loss_cd.item()) 97 | 98 | if i % 100000 == 0: 99 | self.lambda_cd = self.lambda_cd * 0.9 100 | 101 | # Logging. 102 | loss = {} 103 | loss['G/loss_id'] = g_loss_id.item() 104 | loss['G/loss_id_psnt'] = g_loss_id_psnt.item() 105 | loss['G/loss_cd'] = g_loss_cd.item() 106 | 107 | # Print out training information. 108 | if (i+1) % self.log_step == 0: 109 | et = time.time() - start_time 110 | et = str(datetime.timedelta(seconds=et))[:-7] 111 | log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters) 112 | # Save the model at checkpoint. 
113 | torch.save({'model_state_dict': self.G.state_dict(), 114 | 'optimizer_state_dict':self.g_optimizer.state_dict()}, 115 | os.path.join(self.save_dir, f'{self.pt_name}_{i//10000+1}.pt')) 116 | for tag in keys: 117 | log += ", {}: {:.4f}".format(tag, loss[tag]) 118 | print(log) 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /utils/function_f.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import librosa 4 | import numpy as np 5 | import parselmouth 6 | import scipy.signal 7 | import torch 8 | import torchaudio.functional as AF 9 | 10 | PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT = 0.0 11 | PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT = 1.0 12 | PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT = 1.0 13 | PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT = 1.0 14 | PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT = 1.0 15 | 16 | 17 | def wav_to_Sound(wav, sampling_frequency: int = 22050) -> parselmouth.Sound: 18 | if isinstance(wav, parselmouth.Sound): 19 | sound = wav 20 | elif isinstance(wav, np.ndarray): 21 | sound = parselmouth.Sound(wav, sampling_frequency=sampling_frequency) 22 | elif isinstance(wav, list): 23 | wav_np = np.asarray(wav) 24 | sound = parselmouth.Sound(np.asarray(wav_np), sampling_frequency=sampling_frequency) 25 | else: 26 | raise NotImplementedError 27 | return sound 28 | 29 | 30 | def wav_to_Tensor(wav) -> torch.Tensor: 31 | if isinstance(wav, np.ndarray): 32 | wav_tensor = torch.from_numpy(wav) 33 | elif isinstance(wav, torch.Tensor): 34 | wav_tensor = wav 35 | elif isinstance(wav, parselmouth.Sound): 36 | wav_np = wav.values 37 | wav_tensor = torch.from_numpy(wav_np) 38 | else: 39 | raise NotImplementedError 40 | return wav_tensor 41 | 42 | 43 | def get_pitch_median(wav, sr: int = None): 44 | sound = wav_to_Sound(wav, sr) 45 | pitch = None 46 | pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT 47 | 48 | try: 49 | pitch = parselmouth.praat.call(sound, "To Pitch", 0.8 / 75, 75, 600) 50 | pitch_median = parselmouth.praat.call(pitch, "Get quantile", 0.0, 0.0, 0.5, "Hertz") 51 | except Exception as e: 52 | raise e 53 | pass 54 | 55 | return pitch, pitch_median 56 | 57 | 58 | def change_gender( 59 | sound, pitch=None, 60 | formant_shift_ratio: float = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT, 61 | new_pitch_median: float = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT, 62 | pitch_range_ratio: float = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT, 63 | duration_factor: float = PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT, ) -> parselmouth.Sound: 64 | try: 65 | if pitch is None: 66 | new_sound = parselmouth.praat.call( 67 | sound, "Change gender", 75, 600, 68 | formant_shift_ratio, 69 | new_pitch_median, 70 | pitch_range_ratio, 71 | duration_factor 72 | ) 73 | else: 74 | new_sound = parselmouth.praat.call( 75 | (sound, pitch), "Change gender", 76 | formant_shift_ratio, 77 | new_pitch_median, 78 | pitch_range_ratio, 79 | duration_factor 80 | ) 81 | except Exception as e: 82 | raise e 83 | 84 | return new_sound 85 | 86 | 87 | def apply_formant_and_pitch_shift( 88 | sound: parselmouth.Sound, 89 | formant_shift_ratio: float = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT, 90 | pitch_shift_ratio: float = PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT, 91 | pitch_range_ratio: float = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT, 92 | duration_factor: float = PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT) -> parselmouth.Sound: 93 | 94 | pitch = None 95 | new_pitch_median = 
PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT 96 | if pitch_shift_ratio != 1.: 97 | try: 98 | pitch, pitch_median = get_pitch_median(sound, None) 99 | new_pitch_median = pitch_median * pitch_shift_ratio 100 | 101 | # https://github.com/praat/praat/issues/1926#issuecomment-974909408 102 | pitch_minimum = parselmouth.praat.call(pitch, "Get minimum", 0.0, 0.0, "Hertz", "Parabolic") 103 | newMedian = pitch_median * pitch_shift_ratio 104 | scaledMinimum = pitch_minimum * pitch_shift_ratio 105 | resultingMinimum = newMedian + (scaledMinimum - newMedian) * pitch_range_ratio 106 | 107 | if resultingMinimum < 0: 108 | new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT # 0.0 109 | pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT # 1.0 110 | 111 | if math.isnan(new_pitch_median): 112 | new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT # 0.0 113 | pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT # 1.0 114 | 115 | except Exception as e: 116 | raise e 117 | 118 | new_sound = change_gender( 119 | sound, pitch, 120 | formant_shift_ratio, new_pitch_median, 121 | pitch_range_ratio, duration_factor) 122 | 123 | return new_sound 124 | 125 | 126 | # fs & pr 127 | def formant_and_pitch_shift(sound: parselmouth.Sound) -> parselmouth.Sound: 128 | r"""Calculate random factors and apply formant and pitch shift. 129 | 130 | Implements the formant shifting (fs) and pitch randomization (pr) perturbations described in the paper. 131 | """ 132 | formant_shifting_ratio = random.uniform(1.2, 1.5) # ratio > 1 raises the formants (male -> female direction) 133 | use_reciprocal = random.uniform(-1, 1) > 0 # use the reciprocal (ratio < 1) half the time for the opposite direction 134 | if use_reciprocal: 135 | formant_shifting_ratio = 1 / formant_shifting_ratio 136 | 137 | pitch_shift_ratio = random.uniform(1.2, 1.5) 138 | use_reciprocal = random.uniform(-1, 1) > 0 139 | if use_reciprocal: 140 | pitch_shift_ratio = 1 / pitch_shift_ratio 141 | 142 | pitch_range_ratio = random.uniform(1.1, 1.5) 143 | use_reciprocal = random.uniform(-1, 1) > 0 144 | if use_reciprocal: 145 | pitch_range_ratio = 1 / pitch_range_ratio 146 | 147 | sound_new = apply_formant_and_pitch_shift( 148 | sound, 149 | formant_shift_ratio=formant_shifting_ratio, 150 | pitch_shift_ratio=pitch_shift_ratio, 151 | pitch_range_ratio=pitch_range_ratio, 152 | duration_factor=1. 153 | ) 154 | return sound_new 155 | 156 | 157 | # fs 158 | def formant_shift(sound: parselmouth.Sound) -> parselmouth.Sound: 159 | formant_shifting_ratio = random.uniform(1.0, 1.3) 160 | use_reciprocal = random.uniform(-1, 1) > 0 161 | if use_reciprocal: 162 | formant_shifting_ratio = 1 / formant_shifting_ratio 163 | 164 | sound_new = apply_formant_and_pitch_shift( 165 | sound, 166 | formant_shift_ratio=formant_shifting_ratio, 167 | ) 168 | return sound_new 169 | 170 | 171 | def power_ratio(r: float, a: float, b: float): 172 | return a * math.pow((b / a), r) 173 | 174 | 175 | # peq 176 | def parametric_equalizer(wav: torch.Tensor, sr: int) -> torch.Tensor: 177 | cutoff_low_freq = 60. 178 | cutoff_high_freq = 10000.
179 | 180 | q_min = 2 181 | q_max = 5 182 | 183 | num_filters = 8 + 2 # 8 for peak, 2 for high/low 184 | key_freqs = [ 185 | power_ratio(float(z) / (num_filters), cutoff_low_freq, cutoff_high_freq) 186 | for z in range(num_filters) 187 | ] 188 | Qs = [ 189 | power_ratio(random.uniform(0, 1), q_min, q_max) 190 | for _ in range(num_filters) 191 | ] 192 | gains = [random.uniform(-12, 12) for _ in range(num_filters)] 193 | 194 | # peak filters 195 | for i in range(1, 9): 196 | wav = apply_iir_filter( 197 | wav, 198 | ftype='peak', 199 | dBgain=gains[i], 200 | cutoff_freq=key_freqs[i], 201 | sample_rate=sr, 202 | Q=Qs[i] 203 | ) 204 | 205 | # high-shelving filter 206 | wav = apply_iir_filter( 207 | wav, 208 | ftype='high', 209 | dBgain=gains[-1], 210 | cutoff_freq=key_freqs[-1], 211 | sample_rate=sr, 212 | Q=Qs[-1] 213 | ) 214 | 215 | # low-shelving filter 216 | wav = apply_iir_filter( 217 | wav, 218 | ftype='low', 219 | dBgain=gains[0], 220 | cutoff_freq=key_freqs[0], 221 | sample_rate=sr, 222 | Q=Qs[0] 223 | ) 224 | 225 | return wav 226 | 227 | 228 | # implemented using the cookbook https://webaudio.github.io/Audio-EQ-Cookbook/audio-eq-cookbook.html 229 | def lowShelf_coeffs(dBgain, cutoff_freq, sample_rate, Q): 230 | A = math.pow(10, dBgain / 40.) 231 | 232 | w0 = 2 * math.pi * cutoff_freq / sample_rate 233 | alpha = math.sin(w0) / 2 / Q 234 | # alpha = alpha / math.sqrt(2) * math.sqrt(A + 1 / A) 235 | 236 | b0 = A * ((A + 1) - (A - 1) * math.cos(w0) + 2 * math.sqrt(A) * alpha) 237 | b1 = 2 * A * ((A - 1) - (A + 1) * math.cos(w0)) 238 | b2 = A * ((A + 1) - (A - 1) * math.cos(w0) - 2 * math.sqrt(A) * alpha) 239 | 240 | a0 = (A + 1) + (A - 1) * math.cos(w0) + 2 * math.sqrt(A) * alpha 241 | a1 = -2 * ((A - 1) + (A + 1) * math.cos(w0)) 242 | a2 = (A + 1) + (A - 1) * math.cos(w0) - 2 * math.sqrt(A) * alpha 243 | return b0, b1, b2, a0, a1, a2 244 | 245 | 246 | def highShelf_coeffs(dBgain, cutoff_freq, sample_rate, Q): 247 | A = math.pow(10, dBgain / 40.) 248 | 249 | w0 = 2 * math.pi * cutoff_freq / sample_rate 250 | alpha = math.sin(w0) / 2 / Q 251 | # alpha = alpha / math.sqrt(2) * math.sqrt(A + 1 / A) 252 | 253 | b0 = A * ((A + 1) + (A - 1) * math.cos(w0) + 2 * math.sqrt(A) * alpha) 254 | b1 = -2 * A * ((A - 1) + (A + 1) * math.cos(w0)) 255 | b2 = A * ((A + 1) + (A - 1) * math.cos(w0) - 2 * math.sqrt(A) * alpha) 256 | 257 | a0 = (A + 1) - (A - 1) * math.cos(w0) + 2 * math.sqrt(A) * alpha 258 | a1 = 2 * ((A - 1) - (A + 1) * math.cos(w0)) 259 | a2 = (A + 1) - (A - 1) * math.cos(w0) - 2 * math.sqrt(A) * alpha 260 | return b0, b1, b2, a0, a1, a2 261 | 262 | 263 | def peaking_coeffs(dBgain, cutoff_freq, sample_rate, Q): 264 | A = math.pow(10, dBgain / 40.) 
265 | 266 | w0 = 2 * math.pi * cutoff_freq / sample_rate 267 | alpha = math.sin(w0) / 2 / Q 268 | # alpha = alpha / math.sqrt(2) * math.sqrt(A + 1 / A) 269 | 270 | b0 = 1 + alpha * A 271 | b1 = -2 * math.cos(w0) 272 | b2 = 1 - alpha * A 273 | 274 | a0 = 1 + alpha / A 275 | a1 = -2 * math.cos(w0) 276 | a2 = 1 - alpha / A 277 | return b0, b1, b2, a0, a1, a2 278 | 279 | 280 | def apply_iir_filter(wav: torch.Tensor, ftype, dBgain, cutoff_freq, sample_rate, Q, torch_backend=True): 281 | if ftype == 'low': 282 | b0, b1, b2, a0, a1, a2 = lowShelf_coeffs(dBgain, cutoff_freq, sample_rate, Q) 283 | elif ftype == 'high': 284 | b0, b1, b2, a0, a1, a2 = highShelf_coeffs(dBgain, cutoff_freq, sample_rate, Q) 285 | elif ftype == 'peak': 286 | b0, b1, b2, a0, a1, a2 = peaking_coeffs(dBgain, cutoff_freq, sample_rate, Q) 287 | else: 288 | raise NotImplementedError 289 | if torch_backend: 290 | return_wav = AF.biquad(wav, b0, b1, b2, a0, a1, a2) 291 | else: 292 | # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter_zi.html 293 | wav_numpy = wav.numpy() 294 | b = np.asarray([b0, b1, b2]) 295 | a = np.asarray([a0, a1, a2]) 296 | zi = scipy.signal.lfilter_zi(b, a) * wav_numpy[0] 297 | return_wav, _ = scipy.signal.lfilter(b, a, wav_numpy, zi=zi) 298 | return_wav = torch.from_numpy(return_wav) 299 | return return_wav 300 | 301 | 302 | peq = parametric_equalizer 303 | fs = formant_shift 304 | 305 | 306 | def g(wav: torch.Tensor, sr: int) -> torch.Tensor: 307 | 308 | wav = peq(wav, sr) 309 | wav_numpy = wav.numpy() 310 | 311 | sound = wav_to_Sound(wav_numpy, sampling_frequency=sr) 312 | sound = formant_shift(sound) 313 | 314 | wav = torch.from_numpy(sound.values).float().squeeze(0) 315 | return wav 316 | 317 | 318 | def f(wav: torch.Tensor, sr: int) -> torch.Tensor: 319 | 320 | wav = peq(wav, sr) 321 | wav_numpy = wav.numpy() 322 | 323 | sound = wav_to_Sound(wav_numpy, sampling_frequency=sr) 324 | sound = formant_and_pitch_shift(sound) 325 | 326 | wav = torch.from_numpy(sound.values).float().squeeze(0) 327 | return wav 328 | -------------------------------------------------------------------------------- /utils/mel.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | import numpy as np 5 | from librosa.util import normalize 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | MAX_WAV_VALUE = 32768.0 9 | 10 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 11 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 12 | 13 | 14 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 15 | return torch.log(torch.clamp(x, min=clip_val) * C) 16 | 17 | 18 | def spectral_normalize_torch(magnitudes): 19 | output = dynamic_range_compression_torch(magnitudes) 20 | return output 21 | 22 | mel_basis = {} 23 | hann_window = {} 24 | 25 | '''Make mel-spectrogram''' 26 | def mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=22050, 27 | hop_size=256, win_size=1024, fmin=0, fmax=7500, center=False): 28 | global mel_basis, hann_window 29 | if fmax not in mel_basis: 30 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 31 | mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 32 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 33 | 34 | y = torch.nn.functional.pad(y.unsqueeze(0), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 35 | mode='reflect') 36 | y = y.squeeze(1) 37 | spec = 
torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 38 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) # Explicit on newer PyTorch; keeps the (real, imag) layout used below. 39 | 40 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 41 | 42 | spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec) 43 | spec = spectral_normalize_torch(spec) 44 | 45 | return spec -------------------------------------------------------------------------------- /utils/perturbation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import librosa 3 | import torchaudio 4 | import numpy as np 5 | from tqdm.auto import tqdm 6 | 7 | from utils import function_f 8 | from utils.mel import mel_spectrogram 9 | 10 | '''Generate mel-spectrogram data for training''' 11 | def load_wav(filepath, fs): 12 | # Load wav and resample to the target rate. 13 | audio, sr = torchaudio.load(filepath) 14 | resampler = torchaudio.transforms.Resample(sr, fs) 15 | audio = resampler(audio) 16 | return audio, fs 17 | 18 | def make_data(wav_dir, real_dir, perturb_dir): 19 | 20 | dirName, subdirList, _ = next(os.walk(wav_dir)) 21 | 22 | for subdir in tqdm(sorted(subdirList), total=len(subdirList)): 23 | # Make output directories. 24 | os.makedirs(os.path.join(perturb_dir, subdir), exist_ok=True) 25 | os.makedirs(os.path.join(real_dir, subdir), exist_ok=True) 26 | 27 | _,_, fileList = next(os.walk(os.path.join(dirName,subdir))) 28 | for fileName in sorted(fileList): 29 | 30 | x, fs = load_wav(os.path.join(dirName,subdir,fileName), fs=22050) 31 | x_perturb = function_f.f(x, sr=fs) # Perturb audio. 32 | 33 | mel_perturb = mel_spectrogram(x_perturb) # Extract mel-spectrogram from the perturbed audio. 34 | mel_real = mel_spectrogram(x) # Extract mel-spectrogram from the clean audio. 35 | 36 | np.save(os.path.join(perturb_dir, subdir, fileName[:-4]), mel_perturb, allow_pickle=False) 37 | np.save(os.path.join(real_dir, subdir, fileName[:-4]), mel_real, allow_pickle=False) 38 | --------------------------------------------------------------------------------
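For orientation, the sketch below ties the utilities above together: it runs the offline preprocessing step (`make_data` from `utils/perturbation.py`) and then perturbs a single utterance in memory with `function_f.f` to show the tensor shapes involved. This is a minimal usage sketch, not part of the repository; the directory names and the sample file path are placeholder assumptions.

```python
# Hypothetical usage sketch; wav_dir/real_dir/perturb_dir and the sample path are placeholders.
import torchaudio

from utils import function_f
from utils.mel import mel_spectrogram
from utils.perturbation import make_data

# 1) Offline preprocessing: walk speaker subfolders under wav_dir and save
#    clean / perturbed 80-bin mel-spectrogram pairs as .npy files.
make_data(wav_dir='./wavs', real_dir='./data/real', perturb_dir='./data/perturb')

# 2) Perturb one utterance in memory to inspect what f does
#    (parametric EQ + formant shifting + pitch randomization).
wav, sr = torchaudio.load('./wavs/p225/p225_001.wav')    # placeholder path
wav = torchaudio.transforms.Resample(sr, 22050)(wav)     # match the 22.05 kHz rate used above
wav_perturbed = function_f.f(wav, sr=22050)              # 1-D float tensor
mel = mel_spectrogram(wav_perturbed)                     # shape: (1, 80, frames)
print(mel.shape)
```

The same `f` (peq + fs + pr) is what `make_data` applies before saving the perturbed mel-spectrograms, which the solver then feeds to the generator as `uttr` while the clean mel-spectrograms serve as the reconstruction target `x_real`.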