├── LICENSE ├── README.md ├── assets ├── cover.png └── demo.pkl ├── data_loader.py ├── demo.ipynb ├── hparams.py ├── main.py ├── make_metadata.py ├── make_spect_f0.py ├── model.py ├── solver.py ├── synthesis.py ├── tfcompat ├── __init__.py └── hparam.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Speech Decomposition Via Triple Information Bottleneck 2 | 3 | This repository provides a PyTorch implementation of [SpeechSplit](https://arxiv.org/abs/2004.11284), which enables more detailed speaking style conversion by disentangling speech into content, timbre, rhythm and pitch. 
4 | 5 | Below is a short video that explains the main concepts of our work. If you find this work useful and use it in your research, please consider citing our paper. 6 | 7 | [![SpeechSplit](./assets/cover.png)](https://youtu.be/sIlQ3GcslD8) 8 | 9 | ``` 10 | @article{qian2020unsupervised, 11 | title={Unsupervised speech decomposition via triple information bottleneck}, 12 | author={Qian, Kaizhi and Zhang, Yang and Chang, Shiyu and Cox, David and Hasegawa-Johnson, Mark}, 13 | journal={arXiv preprint arXiv:2004.11284}, 14 | year={2020} 15 | } 16 | ``` 17 | 18 | 19 | ## Audio Demo 20 | 21 | The audio demo for SpeechSplit can be found [here](https://auspicious3000.github.io/SpeechSplit-Demo/). 22 | 23 | ## Dependencies 24 | - Python 3.6 25 | - Numpy 26 | - Scipy 27 | - PyTorch >= v1.2.0 28 | - librosa 29 | - pysptk 30 | - soundfile 31 | - matplotlib 32 | - wavenet_vocoder ```pip install wavenet_vocoder==0.1.1``` 33 | For more information, please refer to https://github.com/r9y9/wavenet_vocoder 34 | 35 | 36 | ## To Run Demo 37 | 38 | Download [pre-trained models](https://ibm.box.com/s/kgomuly35meo8xsh5mfxxklol8bulrry) to ```assets``` 39 | 40 | Download the same WaveNet vocoder model as in [AutoVC](https://github.com/auspicious3000/autovc) to ```assets``` 41 | 42 | A fast, high-quality pre-trained HiFi-GAN v1 (https://github.com/jik876/hifi-gan) vocoder model is now also available [here](https://ibm.box.com/s/asvv554v0zd09yipl2qadz49i7jpdhta). 43 | 44 | Run ```demo.ipynb``` 45 | 46 | If you have any problems with the vocoder, please refer to [AutoVC](https://github.com/auspicious3000/autovc), since both projects share the same vocoder scripts. 47 | 48 | 49 | ## To Train 50 | 51 | Download [training data](https://ibm.box.com/s/ahaj5zbuwu7jox47zxnsls2syf12g4c5) to ```assets```. 52 | The provided training data is very small and is intended for code verification purposes only. 53 | Please use the scripts to prepare your own data for training. 54 | 55 | 1.
Extract spectrogram and f0: ```python make_spect_f0.py``` 56 | 57 | 2. Generate training metadata: ```python make_metadata.py ``` 58 | 59 | 3. Run the training scripts: ```python main.py``` 60 | 61 | Please refer to Appendix B.4 for training guidance. 62 | 63 | 64 | ## Final Words 65 | 66 | This project is part of an ongoing research. We hope this repo is useful for your research. If you need any help or have any suggestions on improving the framework, please raise an issue and we will do our best to get back to you as soon as possible. 67 | 68 | 69 | -------------------------------------------------------------------------------- /assets/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/auspicious3000/SpeechSplit/c39330d84cdb7f2fd452058d46a17a4c578a0359/assets/cover.png -------------------------------------------------------------------------------- /assets/demo.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/auspicious3000/SpeechSplit/c39330d84cdb7f2fd452058d46a17a4c578a0359/assets/demo.pkl -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import numpy as np 5 | 6 | from functools import partial 7 | from numpy.random import uniform 8 | from multiprocessing import Process, Manager 9 | 10 | from torch.utils import data 11 | from torch.utils.data.sampler import Sampler 12 | 13 | 14 | class Utterances(data.Dataset): 15 | """Dataset class for the Utterances dataset.""" 16 | 17 | def __init__(self, root_dir, feat_dir, mode): 18 | """Initialize and preprocess the Utterances dataset.""" 19 | self.root_dir = root_dir 20 | self.feat_dir = feat_dir 21 | self.mode = mode 22 | self.step = 20 23 | self.split = 0 24 | 25 | metaname = 
os.path.join(self.root_dir, "train.pkl") 26 | meta = pickle.load(open(metaname, "rb")) 27 | 28 | manager = Manager() 29 | meta = manager.list(meta) 30 | dataset = manager.list(len(meta)*[None]) # <-- can be shared between processes. 31 | processes = [] 32 | for i in range(0, len(meta), self.step): # one loader process per chunk of self.step metadata entries 33 | p = Process(target=self.load_data, 34 | args=(meta[i:i+self.step],dataset,i,mode)) 35 | p.start() 36 | processes.append(p) 37 | for p in processes: 38 | p.join() 39 | 40 | 41 | # very important: copy the Manager list proxy into a plain local list (dataset = list(dataset)) so __getitem__ indexes local data rather than the Manager proxy 42 | if mode == 'train': 43 | self.train_dataset = list(dataset) 44 | self.num_tokens = len(self.train_dataset) 45 | elif mode == 'test': 46 | self.test_dataset = list(dataset) 47 | self.num_tokens = len(self.test_dataset) 48 | else: 49 | raise ValueError 50 | 51 | print('Finished loading {} dataset...'.format(mode)) 52 | 53 | 54 | 55 | def load_data(self, submeta, dataset, idx_offset, mode): 56 | for k, sbmt in enumerate(submeta): 57 | uttrs = len(sbmt)*[None] 58 | # fill in speaker id and embedding 59 | uttrs[0] = sbmt[0] 60 | uttrs[1] = sbmt[1] 61 | # fill in data: only the first feature file (sbmt[2]) is loaded; uttrs[3:] stay None 62 | sp_tmp = np.load(os.path.join(self.root_dir, sbmt[2])) 63 | f0_tmp = np.load(os.path.join(self.feat_dir, sbmt[2])) 64 | if self.mode == 'train': 65 | sp_tmp = sp_tmp[self.split:, :] 66 | f0_tmp = f0_tmp[self.split:] 67 | elif self.mode == 'test': 68 | sp_tmp = sp_tmp[:self.split, :] 69 | f0_tmp = f0_tmp[:self.split] 70 | else: 71 | raise ValueError 72 | uttrs[2] = ( sp_tmp, f0_tmp ) 73 | dataset[idx_offset+k] = uttrs 74 | 75 | 76 | 77 | def __getitem__(self, index): 78 | dataset = self.train_dataset if self.mode == 'train' else self.test_dataset 79 | 80 | list_uttrs = dataset[index] 81 | spk_id_org = list_uttrs[0] 82 | emb_org = list_uttrs[1] 83 | 84 | melsp, f0_org = list_uttrs[2] 85 | 86 | return melsp, emb_org, f0_org 87 | 88 | 89 | def __len__(self): 90 | """Return the number of dataset entries (one per metadata row).""" 91 | return self.num_tokens 92 | 93 | 94 | 95 | class MyCollator(object): 96 | def
__init__(self, hparams): 97 | self.min_len_seq = hparams.min_len_seq 98 | self.max_len_seq = hparams.max_len_seq 99 | self.max_len_pad = hparams.max_len_pad 100 | 101 | def __call__(self, batch): 102 | # batch[i] is a tuple of __getitem__ outputs 103 | new_batch = [] 104 | for token in batch: 105 | aa, b, c = token 106 | len_crop = np.random.randint(self.min_len_seq, self.max_len_seq+1, size=2) # 1.5s ~ 3s 107 | left = np.random.randint(0, len(aa)-len_crop[0], size=2) 108 | pdb.set_trace() 109 | 110 | a = aa[left[0]:left[0]+len_crop[0], :] 111 | c = c[left[0]:left[0]+len_crop[0]] 112 | 113 | a = np.clip(a, 0, 1) 114 | 115 | a_pad = np.pad(a, ((0,self.max_len_pad-a.shape[0]),(0,0)), 'constant') 116 | c_pad = np.pad(c[:,np.newaxis], ((0,self.max_len_pad-c.shape[0]),(0,0)), 'constant', constant_values=-1e10) 117 | 118 | new_batch.append( (a_pad, b, c_pad, len_crop[0]) ) 119 | 120 | batch = new_batch 121 | 122 | a, b, c, d = zip(*batch) 123 | melsp = torch.from_numpy(np.stack(a, axis=0)) 124 | spk_emb = torch.from_numpy(np.stack(b, axis=0)) 125 | pitch = torch.from_numpy(np.stack(c, axis=0)) 126 | len_org = torch.from_numpy(np.stack(d, axis=0)) 127 | 128 | return melsp, spk_emb, pitch, len_org 129 | 130 | 131 | 132 | 133 | class MultiSampler(Sampler): 134 | """Samples elements more than once in a single pass through the data. 
135 | """ 136 | def __init__(self, num_samples, n_repeats, shuffle=False): 137 | self.num_samples = num_samples 138 | self.n_repeats = n_repeats 139 | self.shuffle = shuffle 140 | 141 | def gen_sample_array(self): 142 | self.sample_idx_array = torch.arange(self.num_samples, dtype=torch.int64).repeat(self.n_repeats) 143 | if self.shuffle: 144 | self.sample_idx_array = self.sample_idx_array[torch.randperm(len(self.sample_idx_array))] 145 | return self.sample_idx_array 146 | 147 | def __iter__(self): 148 | return iter(self.gen_sample_array()) 149 | 150 | def __len__(self): 151 | return len(self.sample_idx_array) 152 | 153 | 154 | 155 | 156 | def get_loader(hparams): 157 | """Build and return a data loader.""" 158 | 159 | dataset = Utterances(hparams.root_dir, hparams.feat_dir, hparams.mode) 160 | 161 | my_collator = MyCollator(hparams) 162 | 163 | sampler = MultiSampler(len(dataset), hparams.samplier, shuffle=hparams.shuffle) 164 | 165 | worker_init_fn = lambda x: np.random.seed((torch.initial_seed()) % (2**32)) 166 | 167 | data_loader = data.DataLoader(dataset=dataset, 168 | batch_size=hparams.batch_size, 169 | sampler=sampler, 170 | num_workers=hparams.num_workers, 171 | drop_last=True, 172 | pin_memory=True, 173 | worker_init_fn=worker_init_fn, 174 | collate_fn=my_collator) 175 | return data_loader -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# demo conversion\n", 10 | "import torch\n", 11 | "import pickle\n", 12 | "import numpy as np\n", 13 | "from hparams import hparams\n", 14 | "from utils import pad_seq_to_2\n", 15 | "from utils import quantize_f0_numpy\n", 16 | "from model import Generator_3 as Generator\n", 17 | "from model import Generator_6 as F0_Converter\n", 18 | "\n", 19 | "\n", 20 | 
"device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", 21 | "G = Generator(hparams).eval().to(device)\n", 22 | "g_checkpoint = torch.load('assets/660000-G.ckpt', map_location=lambda storage, loc: storage)\n", 23 | "G.load_state_dict(g_checkpoint['model'])\n", 24 | "\n", 25 | "P = F0_Converter(hparams).eval().to(device)\n", 26 | "p_checkpoint = torch.load('assets/640000-P.ckpt', map_location=lambda storage, loc: storage)\n", 27 | "P.load_state_dict(p_checkpoint['model'])\n", 28 | "\n", 29 | "\n", 30 | "metadata = pickle.load(open('assets/demo.pkl', \"rb\"))\n", 31 | "\n", 32 | "\n", 33 | "sbmt_i = metadata[0]\n", 34 | "emb_org = torch.from_numpy(sbmt_i[1]).to(device)\n", 35 | "x_org, f0_org, len_org, uid_org = sbmt_i[2] \n", 36 | "uttr_org_pad, len_org_pad = pad_seq_to_2(x_org[np.newaxis,:,:], 192)\n", 37 | "uttr_org_pad = torch.from_numpy(uttr_org_pad).to(device)\n", 38 | "f0_org_pad = np.pad(f0_org, (0, 192-len_org), 'constant', constant_values=(0, 0))\n", 39 | "f0_org_quantized = quantize_f0_numpy(f0_org_pad)[0]\n", 40 | "f0_org_onehot = f0_org_quantized[np.newaxis, :, :]\n", 41 | "f0_org_onehot = torch.from_numpy(f0_org_onehot).to(device)\n", 42 | "uttr_f0_org = torch.cat((uttr_org_pad, f0_org_onehot), dim=-1)\n", 43 | "\n", 44 | "sbmt_j = metadata[1]\n", 45 | "emb_trg = torch.from_numpy(sbmt_j[1]).to(device)\n", 46 | "x_trg, f0_trg, len_trg, uid_trg = sbmt_j[2] \n", 47 | "uttr_trg_pad, len_trg_pad = pad_seq_to_2(x_trg[np.newaxis,:,:], 192)\n", 48 | "uttr_trg_pad = torch.from_numpy(uttr_trg_pad).to(device)\n", 49 | "f0_trg_pad = np.pad(f0_trg, (0, 192-len_trg), 'constant', constant_values=(0, 0))\n", 50 | "f0_trg_quantized = quantize_f0_numpy(f0_trg_pad)[0]\n", 51 | "f0_trg_onehot = f0_trg_quantized[np.newaxis, :, :]\n", 52 | "f0_trg_onehot = torch.from_numpy(f0_trg_onehot).to(device)\n", 53 | "\n", 54 | "with torch.no_grad():\n", 55 | " f0_pred = P(uttr_org_pad, f0_trg_onehot)[0]\n", 56 | " f0_pred_quantized = f0_pred.argmax(dim=-1).squeeze(0)\n", 57 | 
" f0_con_onehot = torch.zeros((1, 192, 257), device=device)\n", 58 | " f0_con_onehot[0, torch.arange(192), f0_pred_quantized] = 1\n", 59 | "uttr_f0_trg = torch.cat((uttr_org_pad, f0_con_onehot), dim=-1) \n", 60 | "\n", 61 | "\n", 62 | "conditions = ['R', 'F', 'U', 'RF', 'RU', 'FU', 'RFU']\n", 63 | "spect_vc = []\n", 64 | "with torch.no_grad():\n", 65 | " for condition in conditions:\n", 66 | " if condition == 'R':\n", 67 | " x_identic_val = G(uttr_f0_org, uttr_trg_pad, emb_org)\n", 68 | " if condition == 'F':\n", 69 | " x_identic_val = G(uttr_f0_trg, uttr_org_pad, emb_org)\n", 70 | " if condition == 'U':\n", 71 | " x_identic_val = G(uttr_f0_org, uttr_org_pad, emb_trg)\n", 72 | " if condition == 'RF':\n", 73 | " x_identic_val = G(uttr_f0_trg, uttr_trg_pad, emb_org)\n", 74 | " if condition == 'RU':\n", 75 | " x_identic_val = G(uttr_f0_org, uttr_trg_pad, emb_trg)\n", 76 | " if condition == 'FU':\n", 77 | " x_identic_val = G(uttr_f0_trg, uttr_org_pad, emb_trg)\n", 78 | " if condition == 'RFU':\n", 79 | " x_identic_val = G(uttr_f0_trg, uttr_trg_pad, emb_trg)\n", 80 | " \n", 81 | " if 'R' in condition:\n", 82 | " uttr_trg = x_identic_val[0, :len_trg, :].cpu().numpy()\n", 83 | " else:\n", 84 | " uttr_trg = x_identic_val[0, :len_org, :].cpu().numpy()\n", 85 | " \n", 86 | " spect_vc.append( ('{}_{}_{}_{}'.format(sbmt_i[0], sbmt_j[0], uid_org, condition), uttr_trg ) ) " 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# spectrogram to waveform\n", 96 | "import torch\n", 97 | "import soundfile\n", 98 | "import pickle\n", 99 | "import os\n", 100 | "from synthesis import build_model\n", 101 | "from synthesis import wavegen\n", 102 | "\n", 103 | "if not os.path.exists('results'):\n", 104 | " os.makedirs('results')\n", 105 | "\n", 106 | "model = build_model().to(device)\n", 107 | "checkpoint = torch.load(\"assets/checkpoint_step001000000_ema.pth\", map_location=torch.device(device))\n", 
108 | "model.load_state_dict(checkpoint[\"state_dict\"])\n", 109 | "\n", 110 | "for spect in spect_vc:\n", 111 | " name = spect[0]\n", 112 | " c = spect[1]\n", 113 | " print(name)\n", 114 | " waveform = wavegen(model, c=c) \n", 115 | " soundfile.write('results/'+name+'.wav', waveform, samplerate=16000)" 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "Python 3", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.6.9" 136 | } 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 4 140 | } 141 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from tfcompat.hparam import HParams 2 | 3 | # NOTE: If you want full control for model architecture. please take a look 4 | # at the code and change whatever you want. Some hyper parameters are hardcoded. 
5 | 6 | # Default hyperparameters: 7 | hparams = HParams( 8 | # model 9 | freq = 8, 10 | dim_neck = 8, 11 | freq_2 = 8, 12 | dim_neck_2 = 1, 13 | freq_3 = 8, 14 | dim_neck_3 = 32, 15 | out_channels = 10 * 3, 16 | layers = 24, 17 | stacks = 4, 18 | residual_channels = 512, 19 | gate_channels = 512, # split into 2 groups internally for gated activation 20 | skip_out_channels = 256, 21 | cin_channels = 80, 22 | gin_channels = -1, # i.e., speaker embedding dim 23 | weight_normalization = True, 24 | n_speakers = -1, 25 | dropout = 1 - 0.95, 26 | kernel_size = 3, 27 | upsample_conditional_features = True, 28 | upsample_scales = [4, 4, 4, 4], 29 | freq_axis_kernel_size = 3, 30 | legacy = True, 31 | 32 | dim_enc = 512, 33 | dim_enc_2 = 128, 34 | dim_enc_3 = 256, 35 | 36 | dim_freq = 80, 37 | dim_spk_emb = 82, 38 | dim_f0 = 257, 39 | dim_dec = 512, 40 | len_raw = 128, 41 | chs_grp = 16, 42 | 43 | # interp 44 | min_len_seg = 19, 45 | max_len_seg = 32, 46 | min_len_seq = 64, 47 | max_len_seq = 128, 48 | max_len_pad = 192, 49 | 50 | # data loader 51 | root_dir = 'assets/spmel', 52 | feat_dir = 'assets/raptf0', 53 | batch_size = 16, 54 | mode = 'train', 55 | shuffle = True, 56 | num_workers = 0, 57 | samplier = 8, 58 | 59 | # Convenient model builder 60 | builder = "wavenet", 61 | 62 | hop_size = 256, 63 | log_scale_min = float(-32.23619130191664), 64 | ) 65 | 66 | 67 | def hparams_debug_string(): 68 | values = hparams.values() 69 | hp = [' %s: %s' % (name, values[name]) for name in values] 70 | return 'Hyperparameters:\n' + '\n'.join(hp) 71 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from torch.backends import cudnn 5 | 6 | from solver import Solver 7 | from data_loader import get_loader 8 | from hparams import hparams, hparams_debug_string 9 | 10 | 11 | 12 | def str2bool(v): 13 | return 
v.lower() == 'true' # exact, case-insensitive match 14 | 15 | def main(config): 16 | # For fast training. 17 | cudnn.benchmark = True 18 | 19 | # Create directories if not exist. 20 | if not os.path.exists(config.log_dir): 21 | os.makedirs(config.log_dir) 22 | if not os.path.exists(config.model_save_dir): 23 | os.makedirs(config.model_save_dir) 24 | if not os.path.exists(config.sample_dir): 25 | os.makedirs(config.sample_dir) 26 | 27 | # Data loader. 28 | vcc_loader = get_loader(hparams) 29 | 30 | # Solver for training 31 | solver = Solver(vcc_loader, config, hparams) 32 | 33 | solver.train() 34 | 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | 40 | # Training configuration. 41 | parser.add_argument('--num_iters', type=int, default=1000000, help='number of total iterations') 42 | parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G') 43 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for Adam optimizer') 44 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer') 45 | parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step') 46 | 47 | # Miscellaneous. 48 | parser.add_argument('--use_tensorboard', type=str2bool, default=False) 49 | parser.add_argument('--device_id', type=int, default=0) 50 | 51 | # Directories. 52 | parser.add_argument('--log_dir', type=str, default='run/logs') 53 | parser.add_argument('--model_save_dir', type=str, default='run/models') 54 | parser.add_argument('--sample_dir', type=str, default='run/samples') 55 | 56 | # Step size.
57 | parser.add_argument('--log_step', type=int, default=10) 58 | parser.add_argument('--sample_step', type=int, default=1000) 59 | parser.add_argument('--model_save_step', type=int, default=1000) 60 | 61 | config = parser.parse_args() 62 | print(config) 63 | print(hparams_debug_string()) 64 | main(config) -------------------------------------------------------------------------------- /make_metadata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | 5 | rootDir = 'assets/spmel' 6 | dirName, subdirList, _ = next(os.walk(rootDir)) 7 | print('Found directory: %s' % dirName) 8 | 9 | 10 | speakers = [] 11 | for speaker in sorted(subdirList): 12 | print('Processing speaker: %s' % speaker) 13 | utterances = [] 14 | utterances.append(speaker) 15 | _, _, fileList = next(os.walk(os.path.join(dirName,speaker))) 16 | 17 | # use hardcoded onehot embeddings in order to be cosistent with the test speakers 18 | # modify as needed 19 | # may use generalized speaker embedding for zero-shot conversion 20 | spkid = np.zeros((82,), dtype=np.float32) 21 | if speaker == 'p226': 22 | spkid[1] = 1.0 23 | else: 24 | spkid[7] = 1.0 25 | utterances.append(spkid) 26 | 27 | # create file list 28 | for fileName in sorted(fileList): 29 | utterances.append(os.path.join(speaker,fileName)) 30 | speakers.append(utterances) 31 | 32 | with open(os.path.join(rootDir, 'train.pkl'), 'wb') as handle: 33 | pickle.dump(speakers, handle) -------------------------------------------------------------------------------- /make_spect_f0.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import numpy as np 5 | import soundfile as sf 6 | from scipy import signal 7 | from librosa.filters import mel 8 | from numpy.random import RandomState 9 | from pysptk import sptk 10 | from utils import butter_highpass 11 | from utils import 
speaker_normalization 12 | from utils import pySTFT 13 | 14 | 15 | mel_basis = mel(16000, 1024, fmin=90, fmax=7600, n_mels=80).T 16 | min_level = np.exp(-100 / 20 * np.log(10)) 17 | b, a = butter_highpass(30, 16000, order=5) 18 | 19 | spk2gen = pickle.load(open('assets/spk2gen.pkl', "rb")) 20 | 21 | 22 | # Modify as needed 23 | rootDir = 'assets/wavs' 24 | targetDir_f0 = 'assets/raptf0' 25 | targetDir = 'assets/spmel' 26 | 27 | 28 | dirName, subdirList, _ = next(os.walk(rootDir)) 29 | print('Found directory: %s' % dirName) 30 | 31 | for subdir in sorted(subdirList): 32 | print(subdir) 33 | 34 | if not os.path.exists(os.path.join(targetDir, subdir)): 35 | os.makedirs(os.path.join(targetDir, subdir)) 36 | if not os.path.exists(os.path.join(targetDir_f0, subdir)): 37 | os.makedirs(os.path.join(targetDir_f0, subdir)) 38 | _,_, fileList = next(os.walk(os.path.join(dirName,subdir))) 39 | 40 | if spk2gen[subdir] == 'M': 41 | lo, hi = 50, 250 42 | elif spk2gen[subdir] == 'F': 43 | lo, hi = 100, 600 44 | else: 45 | raise ValueError 46 | 47 | prng = RandomState(int(subdir[1:])) 48 | for fileName in sorted(fileList): 49 | # read audio file 50 | x, fs = sf.read(os.path.join(dirName,subdir,fileName)) 51 | assert fs == 16000 52 | if x.shape[0] % 256 == 0: 53 | x = np.concatenate((x, np.array([1e-06])), axis=0) 54 | y = signal.filtfilt(b, a, x) 55 | wav = y * 0.96 + (prng.rand(y.shape[0])-0.5)*1e-06 56 | 57 | # compute spectrogram 58 | D = pySTFT(wav).T 59 | D_mel = np.dot(D, mel_basis) 60 | D_db = 20 * np.log10(np.maximum(min_level, D_mel)) - 16 61 | S = (D_db + 100) / 100 62 | 63 | # extract f0 64 | f0_rapt = sptk.rapt(wav.astype(np.float32)*32768, fs, 256, min=lo, max=hi, otype=2) 65 | index_nonzero = (f0_rapt != -1e10) 66 | mean_f0, std_f0 = np.mean(f0_rapt[index_nonzero]), np.std(f0_rapt[index_nonzero]) 67 | f0_norm = speaker_normalization(f0_rapt, index_nonzero, mean_f0, std_f0) 68 | 69 | assert len(S) == len(f0_rapt) 70 | 71 | np.save(os.path.join(targetDir, subdir, 
fileName[:-4]), 72 | S.astype(np.float32), allow_pickle=False) 73 | np.save(os.path.join(targetDir_f0, subdir, fileName[:-4]), 74 | f0_norm.astype(np.float32), allow_pickle=False) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from math import ceil 7 | from utils import get_mask_from_lengths 8 | 9 | 10 | class LinearNorm(torch.nn.Module): 11 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 12 | super(LinearNorm, self).__init__() 13 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 14 | 15 | torch.nn.init.xavier_uniform_( 16 | self.linear_layer.weight, 17 | gain=torch.nn.init.calculate_gain(w_init_gain)) 18 | 19 | def forward(self, x): 20 | return self.linear_layer(x) 21 | 22 | 23 | 24 | class ConvNorm(torch.nn.Module): 25 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 26 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 27 | super(ConvNorm, self).__init__() 28 | if padding is None: 29 | assert(kernel_size % 2 == 1) 30 | padding = int(dilation * (kernel_size - 1) / 2) 31 | 32 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 33 | kernel_size=kernel_size, stride=stride, 34 | padding=padding, dilation=dilation, 35 | bias=bias) 36 | 37 | torch.nn.init.xavier_uniform_( 38 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 39 | 40 | def forward(self, signal): 41 | conv_signal = self.conv(signal) 42 | return conv_signal 43 | 44 | 45 | 46 | class Encoder_t(nn.Module): 47 | """Rhythm Encoder 48 | """ 49 | def __init__(self, hparams): 50 | super().__init__() 51 | 52 | self.dim_neck_2 = hparams.dim_neck_2 53 | self.freq_2 = hparams.freq_2 54 | self.dim_freq = hparams.dim_freq 55 | self.dim_enc_2 = hparams.dim_enc_2 56 | self.dim_emb = 
hparams.dim_spk_emb 57 | self.chs_grp = hparams.chs_grp 58 | 59 | convolutions = [] 60 | for i in range(1): 61 | conv_layer = nn.Sequential( 62 | ConvNorm(self.dim_freq if i==0 else self.dim_enc_2, 63 | self.dim_enc_2, 64 | kernel_size=5, stride=1, 65 | padding=2, 66 | dilation=1, w_init_gain='relu'), 67 | nn.GroupNorm(self.dim_enc_2//self.chs_grp, self.dim_enc_2)) 68 | convolutions.append(conv_layer) 69 | self.convolutions = nn.ModuleList(convolutions) 70 | 71 | self.lstm = nn.LSTM(self.dim_enc_2, self.dim_neck_2, 1, batch_first=True, bidirectional=True) 72 | 73 | 74 | def forward(self, x, mask): 75 | 76 | for conv in self.convolutions: 77 | x = F.relu(conv(x)) 78 | x = x.transpose(1, 2) 79 | 80 | self.lstm.flatten_parameters() 81 | outputs, _ = self.lstm(x) 82 | if mask is not None: 83 | outputs = outputs * mask 84 | out_forward = outputs[:, :, :self.dim_neck_2] 85 | out_backward = outputs[:, :, self.dim_neck_2:] 86 | 87 | codes = torch.cat((out_forward[:,self.freq_2-1::self.freq_2,:], out_backward[:,::self.freq_2,:]), dim=-1) 88 | 89 | return codes 90 | 91 | 92 | 93 | class Encoder_6(nn.Module): 94 | """F0 encoder 95 | """ 96 | def __init__(self, hparams): 97 | super().__init__() 98 | 99 | self.dim_neck_3 = hparams.dim_neck_3 100 | self.freq_3 = hparams.freq_3 101 | self.dim_f0 = hparams.dim_f0 102 | self.dim_enc_3 = hparams.dim_enc_3 103 | self.dim_emb = hparams.dim_spk_emb 104 | self.chs_grp = hparams.chs_grp 105 | self.register_buffer('len_org', torch.tensor(hparams.max_len_pad)) 106 | 107 | convolutions = [] 108 | for i in range(3): 109 | conv_layer = nn.Sequential( 110 | ConvNorm(self.dim_f0 if i==0 else self.dim_enc_3, 111 | self.dim_enc_3, 112 | kernel_size=5, stride=1, 113 | padding=2, 114 | dilation=1, w_init_gain='relu'), 115 | nn.GroupNorm(self.dim_enc_3//self.chs_grp, self.dim_enc_3)) 116 | convolutions.append(conv_layer) 117 | self.convolutions = nn.ModuleList(convolutions) 118 | 119 | self.lstm = nn.LSTM(self.dim_enc_3, self.dim_neck_3, 1, 
batch_first=True, bidirectional=True)

        self.interp = InterpLnr(hparams)

    def forward(self, x):
        """Encode ``x`` into down-sampled bidirectional LSTM codes.

        Each conv layer is followed by a pass through ``self.interp``
        (InterpLnr), which randomly resamples the time axis in training mode
        and returns its input unchanged in eval mode.

        Returns:
            Forward LSTM outputs taken at the last frame of every
            ``self.freq_3``-frame window, concatenated on the last dimension
            with backward LSTM outputs taken at the first frame of every
            window.
        """
        # NOTE(review): x appears to be (batch, channels, time): the convs run
        # on it directly and it is transposed before InterpLnr and the
        # batch_first LSTM -- confirm against callers.
        for conv in self.convolutions:
            x = F.relu(conv(x))
            x = x.transpose(1, 2)
            x = self.interp(x, self.len_org.expand(x.size(0)))
            x = x.transpose(1, 2)
        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        out_forward = outputs[:, :, :self.dim_neck_3]
        out_backward = outputs[:, :, self.dim_neck_3:]

        # Forward direction: last frame of each window; backward direction:
        # first frame of each window.
        codes = torch.cat((out_forward[:,self.freq_3-1::self.freq_3,:],
                           out_backward[:,::self.freq_3,:]), dim=-1)

        return codes



class Encoder_7(nn.Module):
    """Sync Encoder module

    Encodes the spectrogram part and the F0 part of a concatenated input with
    two parallel conv stacks. After every conv layer the two streams are
    concatenated and resampled by a single InterpLnr call, so both receive
    the same random time warping. Each stream then goes through its own
    BiLSTM and is down-sampled into codes.
    """
    def __init__(self, hparams):
        super().__init__()

        self.dim_neck = hparams.dim_neck
        self.freq = hparams.freq
        self.freq_3 = hparams.freq_3
        self.dim_enc = hparams.dim_enc
        self.dim_enc_3 = hparams.dim_enc_3
        self.dim_freq = hparams.dim_freq
        self.chs_grp = hparams.chs_grp
        # Scalar length (max_len_pad) expanded to the batch size in forward()
        # and passed to InterpLnr as every item's sequence length.
        self.register_buffer('len_org', torch.tensor(hparams.max_len_pad))
        self.dim_neck_3 = hparams.dim_neck_3
        self.dim_f0 = hparams.dim_f0

        # convolutions for code 1
        convolutions = []
        for i in range(3):
            conv_layer = nn.Sequential(
                ConvNorm(self.dim_freq if i==0 else self.dim_enc,
                         self.dim_enc,
                         kernel_size=5, stride=1,
                         padding=2,
                         dilation=1, w_init_gain='relu'),
                nn.GroupNorm(self.dim_enc//self.chs_grp, self.dim_enc))
            convolutions.append(conv_layer)
        self.convolutions_1 = nn.ModuleList(convolutions)

        self.lstm_1 = nn.LSTM(self.dim_enc, self.dim_neck, 2, batch_first=True, bidirectional=True)

        # convolutions for f0
        convolutions = []
        for i in range(3):
            conv_layer = nn.Sequential(
                ConvNorm(self.dim_f0 if i==0 else self.dim_enc_3,
                         self.dim_enc_3,
                         kernel_size=5, stride=1,
                         padding=2,
                         dilation=1, w_init_gain='relu'),
                nn.GroupNorm(self.dim_enc_3//self.chs_grp, self.dim_enc_3))
            convolutions.append(conv_layer)
        self.convolutions_2 = nn.ModuleList(convolutions)

        self.lstm_2 = nn.LSTM(self.dim_enc_3, self.dim_neck_3, 1, batch_first=True, bidirectional=True)

        self.interp = InterpLnr(hparams)


    def forward(self, x_f0):
        """Split ``x_f0`` into spectrogram and F0 channels and encode both.

        Args:
            x_f0: Tensor whose dim 1 holds ``dim_freq`` spectrogram channels
                followed by the F0 channels (sliced below); presumably
                (batch, channels, time) -- TODO confirm against callers.

        Returns:
            ``(codes_x, codes_f0)``: spectrogram codes down-sampled by
            ``self.freq`` and F0 codes down-sampled by ``self.freq_3``.
        """
        x = x_f0[:, :self.dim_freq, :]
        f0 = x_f0[:, self.dim_freq:, :]

        for conv_1, conv_2 in zip(self.convolutions_1, self.convolutions_2):
            x = F.relu(conv_1(x))
            f0 = F.relu(conv_2(f0))
            # Resample both streams in one call so they stay time-aligned.
            x_f0 = torch.cat((x, f0), dim=1).transpose(1, 2)
            x_f0 = self.interp(x_f0, self.len_org.expand(x.size(0)))
            x_f0 = x_f0.transpose(1, 2)
            x = x_f0[:, :self.dim_enc, :]
            f0 = x_f0[:, self.dim_enc:, :]


        # Switch to (batch, time, channels) for the batch_first LSTMs.
        x_f0 = x_f0.transpose(1, 2)
        x = x_f0[:, :, :self.dim_enc]
        f0 = x_f0[:, :, self.dim_enc:]

        # code 1
        x = self.lstm_1(x)[0]
        f0 = self.lstm_2(f0)[0]

        x_forward = x[:, :, :self.dim_neck]
        x_backward = x[:, :, self.dim_neck:]

        f0_forward = f0[:, :, :self.dim_neck_3]
        f0_backward = f0[:, :, self.dim_neck_3:]

        codes_x = torch.cat((x_forward[:,self.freq-1::self.freq,:],
                             x_backward[:,::self.freq,:]), dim=-1)

        codes_f0 = torch.cat((f0_forward[:,self.freq_3-1::self.freq_3,:],
                              f0_backward[:,::self.freq_3,:]), dim=-1)

        return codes_x, codes_f0



class Decoder_3(nn.Module):
    """Decoder module

    A 3-layer BiLSTM (512 units per direction) over the concatenated codes
    and speaker embedding, followed by a linear projection to ``dim_freq``.
    """
    def __init__(self, hparams):
        super().__init__()
        self.dim_neck = hparams.dim_neck
        self.dim_neck_2 = hparams.dim_neck_2
        self.dim_emb = hparams.dim_spk_emb
        self.dim_freq = hparams.dim_freq
        self.dim_neck_3 = hparams.dim_neck_3

        # Input width = both directions of the three code streams + speaker embedding.
        self.lstm = nn.LSTM(self.dim_neck*2+self.dim_neck_2*2+self.dim_neck_3*2+self.dim_emb,
                            512, 3, batch_first=True, bidirectional=True)

        self.linear_projection = LinearNorm(1024, self.dim_freq)

    def forward(self, x):
        """Map (batch, time, features) inputs to ``dim_freq`` outputs per step."""
        outputs, _ = self.lstm(x)

        decoder_output = self.linear_projection(outputs)

        return decoder_output



class Decoder_4(nn.Module):
    """For F0 converter

    A 2-layer BiLSTM (256 units per direction) over the rhythm and F0 codes,
    followed by a linear projection to ``dim_f0``.
    """
    def __init__(self, hparams):
        super().__init__()
        self.dim_neck_2 = hparams.dim_neck_2
        self.dim_f0 = hparams.dim_f0
        self.dim_neck_3 = hparams.dim_neck_3

        self.lstm = nn.LSTM(self.dim_neck_2*2+self.dim_neck_3*2,
                            256, 2, batch_first=True, bidirectional=True)

        self.linear_projection = LinearNorm(512, self.dim_f0)

    def forward(self, x):
        """Map (batch, time, features) inputs to ``dim_f0`` outputs per step."""
        outputs, _ = self.lstm(x)

        decoder_output = self.linear_projection(outputs)

        return decoder_output



class Generator_3(nn.Module):
    """SpeechSplit model

    ``encoder_1`` (Encoder_7) yields content and F0 codes, ``encoder_2``
    (Encoder_t, defined elsewhere) yields rhythm codes, and ``Decoder_3``
    reconstructs the output from these codes plus a speaker embedding.
    """
    def __init__(self, hparams):
        super().__init__()

        self.encoder_1 = Encoder_7(hparams)
        self.encoder_2 = Encoder_t(hparams)
        self.decoder = Decoder_3(hparams)

        self.freq = hparams.freq
        self.freq_2 = hparams.freq_2
        self.freq_3 = hparams.freq_3


    def forward(self, x_f0, x_org, c_trg):
        """Reconstruct from content/F0, rhythm and speaker inputs.

        Args:
            x_f0: Input for ``encoder_1`` (spectrogram + F0 channels). It is
                transposed with ``transpose(2, 1)`` before encoding, so it is
                presumably (batch, time, channels).
            x_org: Input for the rhythm encoder ``encoder_2``; also transposed
                with ``transpose(2, 1)``.
            c_trg: Speaker embedding, repeated across every time step.

        Returns:
            ``Decoder_3`` output for every time step of ``x_f0``.
        """
        x_1 = x_f0.transpose(2,1)
        codes_x, codes_f0 = self.encoder_1(x_1)
        # Repeat each code along time to undo the encoder's down-sampling.
        code_exp_1 = codes_x.repeat_interleave(self.freq, dim=1)
        code_exp_3 = codes_f0.repeat_interleave(self.freq_3, dim=1)

        x_2 = x_org.transpose(2,1)
        codes_2 = self.encoder_2(x_2, None)
        # presumably freq_2 matches Encoder_t's down-sampling -- not visible here.
        code_exp_2 = codes_2.repeat_interleave(self.freq_2, dim=1)

        encoder_outputs = torch.cat((code_exp_1, code_exp_2, code_exp_3,
                                     c_trg.unsqueeze(1).expand(-1,x_1.size(-1),-1)), dim=-1)

        mel_outputs = self.decoder(encoder_outputs)

        return mel_outputs


    def rhythm(self, x_org):
        """Return only the rhythm codes produced by ``encoder_2`` for ``x_org``."""
        x_2 = x_org.transpose(2,1)
        codes_2 = self.encoder_2(x_2, None)

        return codes_2



class Generator_6(nn.Module):
    """F0 converter

    Predicts F0 from rhythm codes (``Encoder_t``) and target-F0 codes
    (``Encoder_6``) via ``Decoder_4``.
    """
    def __init__(self, hparams):
        super().__init__()

        self.encoder_2 = Encoder_t(hparams)
        self.encoder_3 = Encoder_6(hparams)
        self.decoder = Decoder_4(hparams)
        self.freq_2 = hparams.freq_2
        self.freq_3 = hparams.freq_3


    def forward(self, x_org, f0_trg):
        """Return ``Decoder_4`` outputs (``dim_f0`` wide) for every time step.

        The local name ``mel_outputs`` is historical; the values come from a
        projection to ``dim_f0``, not to spectrogram bins.
        """
        x_2 = x_org.transpose(2,1)
        codes_2 = self.encoder_2(x_2, None)
        code_exp_2 = codes_2.repeat_interleave(self.freq_2, dim=1)

        x_3 = f0_trg.transpose(2,1)
        codes_3 = self.encoder_3(x_3)
        code_exp_3 = codes_3.repeat_interleave(self.freq_3, dim=1)

        encoder_outputs = torch.cat((code_exp_2, code_exp_3), dim=-1)

        mel_outputs = self.decoder(encoder_outputs)

        return mel_outputs



class InterpLnr(nn.Module):
    """Random piecewise-linear time resampling of a batch of sequences.

    In training mode each sequence is cut into random-length segments, each
    segment is resampled by a random factor with linear interpolation, and
    the pieces are re-concatenated and padded/truncated to ``max_len_pad``
    frames. In eval mode the input is returned unchanged.
    """

    def __init__(self, hparams):
        super().__init__()
        self.max_len_seq = hparams.max_len_seq
        self.max_len_pad = hparams.max_len_pad

        self.min_len_seg = hparams.min_len_seg
        self.max_len_seg = hparams.max_len_seg

        # Enough segments to cover max_len_seq frames even if every segment
        # has the minimum length.
        self.max_num_seg = self.max_len_seq // self.min_len_seg + 1


    def pad_sequences(self, sequences):
        """Stack (length, channels) tensors into a zero-padded tensor.

        Returns a (len(sequences), max_len_pad, channels) tensor; sequences
        longer than ``max_len_pad`` are truncated.
        """
        channel_dim = sequences[0].size()[-1]
        out_dims = (len(sequences), self.max_len_pad, channel_dim)
        out_tensor = sequences[0].data.new(*out_dims).fill_(0)

        for i, tensor in enumerate(sequences):
            length = tensor.size(0)
            out_tensor[i, :length, :] = tensor[:self.max_len_pad]

        return out_tensor


    def forward(self, x, len_seq):
        """Randomly resample ``x`` along time (training mode only).

        Args:
            x: Tensor indexed as ``x[batch, time, channel]`` (see the gather
                below).
            len_seq: Valid length of each batch element, one entry per item.

        Returns:
            ``x`` unchanged in eval mode; otherwise a tensor of shape
            (batch, max_len_pad, channels).
        """
        if not self.training:
            return x

        device = x.device
        batch_size = x.size(0)

        # indices of each sub segment
        # Output positions 0 .. 2*max_len_seg-1 for each of the
        # batch_size*max_num_seg segments; kept positions satisfy
        # j < scale*(len_seg-1) < 1.5*max_len_seg, so this is enough headroom.
        indices = torch.arange(self.max_len_seg*2, device=device)\
                  .unsqueeze(0).expand(batch_size*self.max_num_seg, -1)
        # scales of each sub segment
        # Uniform in [0.5, 1.5); output position j reads source position j / scale.
        scales = torch.rand(batch_size*self.max_num_seg,
                            device=device) + 0.5

        idx_scaled = indices / scales.unsqueeze(-1)
        idx_scaled_fl = torch.floor(idx_scaled)
        # Fractional part = linear-interpolation weight toward the next frame.
        lambda_ = idx_scaled - idx_scaled_fl

        # Segment lengths in [min_len_seg, max_len_seg) (randint's high is exclusive).
        len_seg = torch.randint(low=self.min_len_seg,
                                high=self.max_len_seg,
                                size=(batch_size*self.max_num_seg,1),
                                device=device)

        # end point of each segment
        # Keep positions whose floor+1 still lies inside the segment.
        idx_mask = idx_scaled_fl < (len_seg - 1)

        offset = len_seg.view(batch_size, -1).cumsum(dim=-1)
        # offset starts from the 2nd segment
        offset = F.pad(offset[:, :-1], (1,0), value=0).view(-1, 1)

        # Source positions in the whole sequence (segment start + local position).
        idx_scaled_org = idx_scaled_fl + offset

        # Keep positions whose floor+1 is still inside the valid sequence length.
        len_seq_rp = torch.repeat_interleave(len_seq, self.max_num_seg)
        idx_mask_org = idx_scaled_org < (len_seq_rp - 1).unsqueeze(-1)

        idx_mask_final = idx_mask & idx_mask_org

        # Number of output frames produced for each batch element.
        counts = idx_mask_final.sum(dim=-1).view(batch_size, -1).sum(dim=-1)

        index_1 = torch.repeat_interleave(torch.arange(batch_size,
                                                       device=device), counts)

        index_2_fl = idx_scaled_org[idx_mask_final].long()
        index_2_cl = index_2_fl + 1

        y_fl = x[index_1, index_2_fl, :]
        y_cl = x[index_1, index_2_cl, :]
        lambda_f = lambda_[idx_mask_final].unsqueeze(-1)

        y = (1-lambda_f)*y_fl + lambda_f*y_cl

        # Regroup the flat frame list per batch element, then pad to max_len_pad.
        sequences = torch.split(y, counts.tolist(), dim=0)

        seq_padded = self.pad_sequences(sequences)

        return seq_padded


--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
from model import Generator_3 as Generator
from model import InterpLnr
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import numpy as np
import os
| import time 9 | import datetime 10 | import pickle 11 | 12 | from utils import pad_seq_to_2, quantize_f0_torch, quantize_f0_numpy 13 | 14 | # use demo data for simplicity 15 | # make your own validation set as needed 16 | validation_pt = pickle.load(open('assets/demo.pkl', "rb")) 17 | 18 | class Solver(object): 19 | """Solver for training""" 20 | 21 | def __init__(self, vcc_loader, config, hparams): 22 | """Initialize configurations.""" 23 | 24 | # Data loader. 25 | self.vcc_loader = vcc_loader 26 | self.hparams = hparams 27 | 28 | # Training configurations. 29 | self.num_iters = config.num_iters 30 | self.g_lr = config.g_lr 31 | self.beta1 = config.beta1 32 | self.beta2 = config.beta2 33 | self.resume_iters = config.resume_iters 34 | 35 | # Miscellaneous. 36 | self.use_tensorboard = config.use_tensorboard 37 | self.use_cuda = torch.cuda.is_available() 38 | self.device = torch.device('cuda:{}'.format(config.device_id) if self.use_cuda else 'cpu') 39 | 40 | # Directories. 41 | self.log_dir = config.log_dir 42 | self.sample_dir = config.sample_dir 43 | self.model_save_dir = config.model_save_dir 44 | 45 | # Step size. 46 | self.log_step = config.log_step 47 | self.sample_step = config.sample_step 48 | self.model_save_step = config.model_save_step 49 | 50 | 51 | # Build the model and tensorboard. 
52 | self.build_model() 53 | if self.use_tensorboard: 54 | self.build_tensorboard() 55 | 56 | 57 | def build_model(self): 58 | self.G = Generator(self.hparams) 59 | 60 | self.Interp = InterpLnr(self.hparams) 61 | 62 | self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2]) 63 | self.print_network(self.G, 'G') 64 | 65 | self.G.to(self.device) 66 | self.Interp.to(self.device) 67 | 68 | 69 | def print_network(self, model, name): 70 | """Print out the network information.""" 71 | num_params = 0 72 | for p in model.parameters(): 73 | num_params += p.numel() 74 | print(model) 75 | print(name) 76 | print("The number of parameters: {}".format(num_params)) 77 | 78 | 79 | def print_optimizer(self, opt, name): 80 | print(opt) 81 | print(name) 82 | 83 | 84 | def restore_model(self, resume_iters): 85 | print('Loading the trained models from step {}...'.format(resume_iters)) 86 | G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters)) 87 | g_checkpoint = torch.load(G_path, map_location=lambda storage, loc: storage) 88 | self.G.load_state_dict(g_checkpoint['model']) 89 | self.g_optimizer.load_state_dict(g_checkpoint['optimizer']) 90 | self.g_lr = self.g_optimizer.param_groups[0]['lr'] 91 | 92 | 93 | def build_tensorboard(self): 94 | """Build a tensorboard logger.""" 95 | from torch.utils.tensorboard import SummaryWriter 96 | self.writer = SummaryWriter(self.log_dir) 97 | 98 | 99 | def reset_grad(self): 100 | """Reset the gradient buffers.""" 101 | self.g_optimizer.zero_grad() 102 | 103 | 104 | #===================================================================================================================== 105 | 106 | 107 | 108 | def train(self): 109 | # Set data loader. 110 | data_loader = self.vcc_loader 111 | 112 | # Fetch fixed inputs for debugging. 113 | data_iter = iter(data_loader) 114 | 115 | # Start training from scratch or resume training. 
116 | start_iters = 0 117 | if self.resume_iters: 118 | print('Resuming ...') 119 | start_iters = self.resume_iters 120 | self.num_iters += self.resume_iters 121 | self.restore_model(self.resume_iters) 122 | self.print_optimizer(self.g_optimizer, 'G_optimizer') 123 | 124 | # Learning rate cache for decaying. 125 | g_lr = self.g_lr 126 | print ('Current learning rates, g_lr: {}.'.format(g_lr)) 127 | 128 | # Print logs in specified order 129 | keys = ['G/loss_id'] 130 | 131 | # Start training. 132 | print('Start training...') 133 | start_time = time.time() 134 | for i in range(start_iters, self.num_iters): 135 | 136 | # =================================================================================== # 137 | # 1. Preprocess input data # 138 | # =================================================================================== # 139 | 140 | # Fetch real images and labels. 141 | try: 142 | x_real_org, emb_org, f0_org, len_org = next(data_iter) 143 | except: 144 | data_iter = iter(data_loader) 145 | x_real_org, emb_org, f0_org, len_org = next(data_iter) 146 | 147 | x_real_org = x_real_org.to(self.device) 148 | emb_org = emb_org.to(self.device) 149 | len_org = len_org.to(self.device) 150 | f0_org = f0_org.to(self.device) 151 | 152 | 153 | # =================================================================================== # 154 | # 2. Train the generator # 155 | # =================================================================================== # 156 | 157 | self.G = self.G.train() 158 | 159 | # Identity mapping loss 160 | x_f0 = torch.cat((x_real_org, f0_org), dim=-1) 161 | x_f0_intrp = self.Interp(x_f0, len_org) 162 | f0_org_intrp = quantize_f0_torch(x_f0_intrp[:,:,-1])[0] 163 | x_f0_intrp_org = torch.cat((x_f0_intrp[:,:,:-1], f0_org_intrp), dim=-1) 164 | 165 | x_identic = self.G(x_f0_intrp_org, x_real_org, emb_org) 166 | g_loss_id = F.mse_loss(x_real_org, x_identic, reduction='mean') 167 | 168 | # Backward and optimize. 
169 | g_loss = g_loss_id 170 | self.reset_grad() 171 | g_loss.backward() 172 | self.g_optimizer.step() 173 | 174 | # Logging. 175 | loss = {} 176 | loss['G/loss_id'] = g_loss_id.item() 177 | 178 | 179 | # =================================================================================== # 180 | # 4. Miscellaneous # 181 | # =================================================================================== # 182 | 183 | # Print out training information. 184 | if (i+1) % self.log_step == 0: 185 | et = time.time() - start_time 186 | et = str(datetime.timedelta(seconds=et))[:-7] 187 | log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters) 188 | for tag in keys: 189 | log += ", {}: {:.8f}".format(tag, loss[tag]) 190 | print(log) 191 | 192 | if self.use_tensorboard: 193 | for tag, value in loss.items(): 194 | self.writer.add_scalar(tag, value, i+1) 195 | 196 | 197 | # Save model checkpoints. 198 | if (i+1) % self.model_save_step == 0: 199 | G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1)) 200 | torch.save({'model': self.G.state_dict(), 201 | 'optimizer': self.g_optimizer.state_dict()}, G_path) 202 | print('Saved model checkpoints into {}...'.format(self.model_save_dir)) 203 | 204 | 205 | # Validation. 
206 | if (i+1) % self.sample_step == 0: 207 | self.G = self.G.eval() 208 | with torch.no_grad(): 209 | loss_val = [] 210 | for val_sub in validation_pt: 211 | emb_org_val = torch.from_numpy(val_sub[1]).to(self.device) 212 | for k in range(2, 3): 213 | x_real_pad, _ = pad_seq_to_2(val_sub[k][0][np.newaxis,:,:], 192) 214 | len_org = torch.tensor([val_sub[k][2]]).to(self.device) 215 | f0_org = np.pad(val_sub[k][1], (0, 192-val_sub[k][2]), 'constant', constant_values=(0, 0)) 216 | f0_quantized = quantize_f0_numpy(f0_org)[0] 217 | f0_onehot = f0_quantized[np.newaxis, :, :] 218 | f0_org_val = torch.from_numpy(f0_onehot).to(self.device) 219 | x_real_pad = torch.from_numpy(x_real_pad).to(self.device) 220 | x_f0 = torch.cat((x_real_pad, f0_org_val), dim=-1) 221 | x_identic_val = self.G(x_f0, x_real_pad, emb_org_val) 222 | g_loss_val = F.mse_loss(x_real_pad, x_identic_val, reduction='sum') 223 | loss_val.append(g_loss_val.item()) 224 | val_loss = np.mean(loss_val) 225 | print('Validation loss: {}'.format(val_loss)) 226 | if self.use_tensorboard: 227 | self.writer.add_scalar('Validation_loss', val_loss, i+1) 228 | 229 | 230 | # plot test samples 231 | if (i+1) % self.sample_step == 0: 232 | self.G = self.G.eval() 233 | with torch.no_grad(): 234 | for val_sub in validation_pt: 235 | emb_org_val = torch.from_numpy(val_sub[1]).to(self.device) 236 | for k in range(2, 3): 237 | x_real_pad, _ = pad_seq_to_2(val_sub[k][0][np.newaxis,:,:], 192) 238 | len_org = torch.tensor([val_sub[k][2]]).to(self.device) 239 | f0_org = np.pad(val_sub[k][1], (0, 192-val_sub[k][2]), 'constant', constant_values=(0, 0)) 240 | f0_quantized = quantize_f0_numpy(f0_org)[0] 241 | f0_onehot = f0_quantized[np.newaxis, :, :] 242 | f0_org_val = torch.from_numpy(f0_onehot).to(self.device) 243 | x_real_pad = torch.from_numpy(x_real_pad).to(self.device) 244 | x_f0 = torch.cat((x_real_pad, f0_org_val), dim=-1) 245 | x_f0_F = torch.cat((x_real_pad, torch.zeros_like(f0_org_val)), dim=-1) 246 | x_f0_C = 
torch.cat((torch.zeros_like(x_real_pad), f0_org_val), dim=-1) 247 | 248 | x_identic_val = self.G(x_f0, x_real_pad, emb_org_val) 249 | x_identic_woF = self.G(x_f0_F, x_real_pad, emb_org_val) 250 | x_identic_woR = self.G(x_f0, torch.zeros_like(x_real_pad), emb_org_val) 251 | x_identic_woC = self.G(x_f0_C, x_real_pad, emb_org_val) 252 | 253 | melsp_gd_pad = x_real_pad[0].cpu().numpy().T 254 | melsp_out = x_identic_val[0].cpu().numpy().T 255 | melsp_woF = x_identic_woF[0].cpu().numpy().T 256 | melsp_woR = x_identic_woR[0].cpu().numpy().T 257 | melsp_woC = x_identic_woC[0].cpu().numpy().T 258 | 259 | min_value = np.min(np.hstack([melsp_gd_pad, melsp_out, melsp_woF, melsp_woR, melsp_woC])) 260 | max_value = np.max(np.hstack([melsp_gd_pad, melsp_out, melsp_woF, melsp_woR, melsp_woC])) 261 | 262 | fig, (ax1,ax2,ax3,ax4,ax5) = plt.subplots(5, 1, sharex=True) 263 | im1 = ax1.imshow(melsp_gd_pad, aspect='auto', vmin=min_value, vmax=max_value) 264 | im2 = ax2.imshow(melsp_out, aspect='auto', vmin=min_value, vmax=max_value) 265 | im3 = ax3.imshow(melsp_woC, aspect='auto', vmin=min_value, vmax=max_value) 266 | im4 = ax4.imshow(melsp_woR, aspect='auto', vmin=min_value, vmax=max_value) 267 | im5 = ax5.imshow(melsp_woF, aspect='auto', vmin=min_value, vmax=max_value) 268 | plt.savefig(f'{self.sample_dir}/{i+1}_{val_sub[0]}_{k}.png', dpi=150) 269 | plt.close(fig) -------------------------------------------------------------------------------- /synthesis.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Synthesis waveform from trained WaveNet. 
4 | 5 | Modified from https://github.com/r9y9/wavenet_vocoder 6 | """ 7 | 8 | import torch 9 | from tqdm import tqdm 10 | import librosa 11 | from hparams import hparams 12 | from wavenet_vocoder import builder 13 | 14 | torch.set_num_threads(4) 15 | use_cuda = torch.cuda.is_available() 16 | device = torch.device("cuda" if use_cuda else "cpu") 17 | 18 | 19 | def build_model(): 20 | 21 | model = getattr(builder, hparams.builder)( 22 | out_channels=hparams.out_channels, 23 | layers=hparams.layers, 24 | stacks=hparams.stacks, 25 | residual_channels=hparams.residual_channels, 26 | gate_channels=hparams.gate_channels, 27 | skip_out_channels=hparams.skip_out_channels, 28 | cin_channels=hparams.cin_channels, 29 | gin_channels=hparams.gin_channels, 30 | weight_normalization=hparams.weight_normalization, 31 | n_speakers=hparams.n_speakers, 32 | dropout=hparams.dropout, 33 | kernel_size=hparams.kernel_size, 34 | upsample_conditional_features=hparams.upsample_conditional_features, 35 | upsample_scales=hparams.upsample_scales, 36 | freq_axis_kernel_size=hparams.freq_axis_kernel_size, 37 | scalar_input=True, 38 | legacy=hparams.legacy, 39 | ) 40 | return model 41 | 42 | 43 | 44 | def wavegen(model, c=None, tqdm=tqdm): 45 | """Generate waveform samples by WaveNet. 
46 | 47 | """ 48 | 49 | model.eval() 50 | model.make_generation_fast_() 51 | 52 | Tc = c.shape[0] 53 | upsample_factor = hparams.hop_size 54 | # Overwrite length according to feature size 55 | length = Tc * upsample_factor 56 | 57 | # B x C x T 58 | c = torch.FloatTensor(c.T).unsqueeze(0) 59 | 60 | initial_input = torch.zeros(1, 1, 1).fill_(0.0) 61 | 62 | # Transform data to GPU 63 | initial_input = initial_input.to(device) 64 | c = None if c is None else c.to(device) 65 | 66 | with torch.no_grad(): 67 | y_hat = model.incremental_forward( 68 | initial_input, c=c, g=None, T=length, tqdm=tqdm, softmax=True, quantize=True, 69 | log_scale_min=hparams.log_scale_min) 70 | 71 | y_hat = y_hat.view(-1).cpu().data.numpy() 72 | 73 | return y_hat -------------------------------------------------------------------------------- /tfcompat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/auspicious3000/SpeechSplit/c39330d84cdb7f2fd452058d46a17a4c578a0359/tfcompat/__init__.py -------------------------------------------------------------------------------- /tfcompat/hparam.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Hyperparameter values.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import json 21 | import numbers 22 | import re 23 | 24 | import six 25 | 26 | ## from tensorflow.contrib.training.python.training import hparam_pb2 27 | ## from tensorflow.python.framework import ops 28 | ## from tensorflow.python.util import compat 29 | ## from tensorflow.python.util import deprecation 30 | 31 | # Define the regular expression for parsing a single clause of the input 32 | # (delimited by commas). A legal clause looks like: 33 | # []? = 34 | # where is either a single token or [] enclosed list of tokens. 35 | # For example: "var[1] = a" or "x = [1,2,3]" 36 | PARAM_RE = re.compile(r""" 37 | (?P[a-zA-Z][\w\.]*) # variable name: "var" or "x" 38 | (\[\s*(?P\d+)\s*\])? # (optional) index: "1" or None 39 | \s*=\s* 40 | ((?P[^,\[]*) # single value: "a" or None 41 | | 42 | \[(?P[^\]]*)\]) # list of values: None or "1,2,3" 43 | ($|,\s*)""", re.VERBOSE) 44 | 45 | 46 | def _parse_fail(name, var_type, value, values): 47 | """Helper function for raising a value error for bad assignment.""" 48 | raise ValueError( 49 | 'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' % 50 | (name, var_type.__name__, value, values)) 51 | 52 | 53 | def _reuse_fail(name, values): 54 | """Helper function for raising a value error for reuse of name.""" 55 | raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name, 56 | values)) 57 | 58 | 59 | def _process_scalar_value(name, parse_fn, var_type, m_dict, values, 60 | results_dictionary): 61 | """Update results_dictionary with a scalar value. 62 | 63 | Used to update the results_dictionary to be returned by parse_values when 64 | encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) 65 | 66 | Mutates results_dictionary. 
67 | 68 | Args: 69 | name: Name of variable in assignment ("s" or "arr"). 70 | parse_fn: Function for parsing the actual value. 71 | var_type: Type of named variable. 72 | m_dict: Dictionary constructed from regex parsing. 73 | m_dict['val']: RHS value (scalar) 74 | m_dict['index']: List index value (or None) 75 | values: Full expression being parsed 76 | results_dictionary: The dictionary being updated for return by the parsing 77 | function. 78 | 79 | Raises: 80 | ValueError: If the name has already been used. 81 | """ 82 | try: 83 | parsed_value = parse_fn(m_dict['val']) 84 | except ValueError: 85 | _parse_fail(name, var_type, m_dict['val'], values) 86 | 87 | # If no index is provided 88 | if not m_dict['index']: 89 | if name in results_dictionary: 90 | _reuse_fail(name, values) 91 | results_dictionary[name] = parsed_value 92 | else: 93 | if name in results_dictionary: 94 | # The name has already been used as a scalar, then it 95 | # will be in this dictionary and map to a non-dictionary. 96 | if not isinstance(results_dictionary.get(name), dict): 97 | _reuse_fail(name, values) 98 | else: 99 | results_dictionary[name] = {} 100 | 101 | index = int(m_dict['index']) 102 | # Make sure the index position hasn't already been assigned a value. 103 | if index in results_dictionary[name]: 104 | _reuse_fail('{}[{}]'.format(name, index), values) 105 | results_dictionary[name][index] = parsed_value 106 | 107 | 108 | def _process_list_value(name, parse_fn, var_type, m_dict, values, 109 | results_dictionary): 110 | """Update results_dictionary from a list of values. 111 | 112 | Used to update results_dictionary to be returned by parse_values when 113 | encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) 114 | 115 | Mutates results_dictionary. 116 | 117 | Args: 118 | name: Name of variable in assignment ("arr"). 119 | parse_fn: Function for parsing individual values. 120 | var_type: Type of named variable. 121 | m_dict: Dictionary constructed from regex parsing. 
122 | m_dict['val']: RHS value (scalar) 123 | values: Full expression being parsed 124 | results_dictionary: The dictionary being updated for return by the parsing 125 | function. 126 | 127 | Raises: 128 | ValueError: If the name has an index or the values cannot be parsed. 129 | """ 130 | if m_dict['index'] is not None: 131 | raise ValueError('Assignment of a list to a list index.') 132 | elements = filter(None, re.split('[ ,]', m_dict['vals'])) 133 | # Make sure the name hasn't already been assigned a value 134 | if name in results_dictionary: 135 | raise _reuse_fail(name, values) 136 | try: 137 | results_dictionary[name] = [parse_fn(e) for e in elements] 138 | except ValueError: 139 | _parse_fail(name, var_type, m_dict['vals'], values) 140 | 141 | 142 | def _cast_to_type_if_compatible(name, param_type, value): 143 | """Cast hparam to the provided type, if compatible. 144 | 145 | Args: 146 | name: Name of the hparam to be cast. 147 | param_type: The type of the hparam. 148 | value: The value to be cast, if compatible. 149 | 150 | Returns: 151 | The result of casting `value` to `param_type`. 152 | 153 | Raises: 154 | ValueError: If the type of `value` is not compatible with param_type. 155 | * If `param_type` is a string type, but `value` is not. 156 | * If `param_type` is a boolean, but `value` is not, or vice versa. 157 | * If `param_type` is an integer type, but `value` is not. 158 | * If `param_type` is a float type, but `value` is not a numeric type. 159 | """ 160 | fail_msg = ( 161 | "Could not cast hparam '%s' of type '%s' from value %r" % 162 | (name, param_type, value)) 163 | 164 | # Some callers use None, for which we can't do any casting/checking. :( 165 | if issubclass(param_type, type(None)): 166 | return value 167 | 168 | # Avoid converting a non-string type to a string. 
169 | if (issubclass(param_type, (six.string_types, six.binary_type)) and 170 | not isinstance(value, (six.string_types, six.binary_type))): 171 | raise ValueError(fail_msg) 172 | 173 | # Avoid converting a number or string type to a boolean or vice versa. 174 | if issubclass(param_type, bool) != isinstance(value, bool): 175 | raise ValueError(fail_msg) 176 | 177 | # Avoid converting float to an integer (the reverse is fine). 178 | if (issubclass(param_type, numbers.Integral) and 179 | not isinstance(value, numbers.Integral)): 180 | raise ValueError(fail_msg) 181 | 182 | # Avoid converting a non-numeric type to a numeric type. 183 | if (issubclass(param_type, numbers.Number) and 184 | not isinstance(value, numbers.Number)): 185 | raise ValueError(fail_msg) 186 | 187 | return param_type(value) 188 | 189 | 190 | def parse_values(values, type_map): 191 | """Parses hyperparameter values from a string into a python map. 192 | 193 | `values` is a string containing comma-separated `name=value` pairs. 194 | For each pair, the value of the hyperparameter named `name` is set to 195 | `value`. 196 | 197 | If a hyperparameter name appears multiple times in `values`, a ValueError 198 | is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2'). 199 | 200 | If a hyperparameter name in both an index assignment and scalar assignment, 201 | a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1'). 202 | 203 | The hyperparameter name may contain '.' symbols, which will result in an 204 | attribute name that is only accessible through the getattr and setattr 205 | functions. (And must be first explicit added through add_hparam.) 206 | 207 | WARNING: Use of '.' in your variable names is allowed, but is not well 208 | supported and not recommended. 209 | 210 | The `value` in `name=value` must follows the syntax according to the 211 | type of the parameter: 212 | 213 | * Scalar integer: A Python-parsable integer point value. E.g.: 1, 214 | 100, -12. 
215 | * Scalar float: A Python-parsable floating point value. E.g.: 1.0, 216 | -.54e89. 217 | * Boolean: Either true or false. 218 | * Scalar string: A non-empty sequence of characters, excluding comma, 219 | spaces, and square brackets. E.g.: foo, bar_1. 220 | * List: A comma separated list of scalar values of the parameter type 221 | enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low]. 222 | 223 | When index assignment is used, the corresponding type_map key should be the 224 | list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not 225 | "arr[1]"). 226 | 227 | Args: 228 | values: String. Comma separated list of `name=value` pairs where 229 | 'value' must follow the syntax described above. 230 | type_map: A dictionary mapping hyperparameter names to types. Note every 231 | parameter name in values must be a key in type_map. The values must 232 | conform to the types indicated, where a value V is said to conform to a 233 | type T if either V has type T, or V is a list of elements of type T. 234 | Hence, for a multidimensional parameter 'x' taking float values, 235 | 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. 236 | 237 | Returns: 238 | A python map mapping each name to either: 239 | * A scalar value. 240 | * A list of scalar values. 241 | * A dictionary mapping index numbers to scalar values. 242 | (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}") 243 | 244 | Raises: 245 | ValueError: If there is a problem with input. 246 | * If `values` cannot be parsed. 247 | * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]'). 248 | * If the same rvalue is assigned two different values (e.g. 
'a=1,a=2', 249 | 'a[1]=1,a[1]=2', or 'a=1,a=[1]') 250 | """ 251 | results_dictionary = {} 252 | pos = 0 253 | while pos < len(values): 254 | m = PARAM_RE.match(values, pos) 255 | if not m: 256 | raise ValueError('Malformed hyperparameter value: %s' % values[pos:]) 257 | # Check that there is a comma between parameters and move past it. 258 | pos = m.end() 259 | # Parse the values. 260 | m_dict = m.groupdict() 261 | name = m_dict['name'] 262 | if name not in type_map: 263 | raise ValueError('Unknown hyperparameter type for %s' % name) 264 | type_ = type_map[name] 265 | 266 | # Set up correct parsing function (depending on whether type_ is a bool) 267 | if type_ == bool: 268 | 269 | def parse_bool(value): 270 | if value in ['true', 'True']: 271 | return True 272 | elif value in ['false', 'False']: 273 | return False 274 | else: 275 | try: 276 | return bool(int(value)) 277 | except ValueError: 278 | _parse_fail(name, type_, value, values) 279 | 280 | parse = parse_bool 281 | else: 282 | parse = type_ 283 | 284 | # If a singe value is provided 285 | if m_dict['val'] is not None: 286 | _process_scalar_value(name, parse, type_, m_dict, values, 287 | results_dictionary) 288 | 289 | # If the assigned value is a list: 290 | elif m_dict['vals'] is not None: 291 | _process_list_value(name, parse, type_, m_dict, values, 292 | results_dictionary) 293 | 294 | else: # Not assigned a list or value 295 | _parse_fail(name, type_, '', values) 296 | 297 | return results_dictionary 298 | 299 | 300 | class HParams(object): 301 | """Class to hold a set of hyperparameters as name-value pairs. 302 | 303 | A `HParams` object holds hyperparameters used to build and train a model, 304 | such as the number of hidden units in a neural net layer or the learning rate 305 | to use when training. 306 | 307 | You first create a `HParams` object by specifying the names and values of the 308 | hyperparameters. 
309 | 310 | To make them easily accessible the parameter names are added as direct 311 | attributes of the class. A typical usage is as follows: 312 | 313 | ```python 314 | # Create a HParams object specifying names and values of the model 315 | # hyperparameters: 316 | hparams = HParams(learning_rate=0.1, num_hidden_units=100) 317 | 318 | # The hyperparameter are available as attributes of the HParams object: 319 | hparams.learning_rate ==> 0.1 320 | hparams.num_hidden_units ==> 100 321 | ``` 322 | 323 | Hyperparameters have type, which is inferred from the type of their value 324 | passed at construction type. The currently supported types are: integer, 325 | float, boolean, string, and list of integer, float, boolean, or string. 326 | 327 | You can override hyperparameter values by calling the 328 | [`parse()`](#HParams.parse) method, passing a string of comma separated 329 | `name=value` pairs. This is intended to make it possible to override 330 | any hyperparameter values from a single command-line flag to which 331 | the user passes 'hyper-param=value' pairs. It avoids having to define 332 | one flag for each hyperparameter. 333 | 334 | The syntax expected for each value depends on the type of the parameter. 335 | See `parse()` for a description of the syntax. 336 | 337 | Example: 338 | 339 | ```python 340 | # Define a command line flag to pass name=value pairs. 341 | # For example using argparse: 342 | import argparse 343 | parser = argparse.ArgumentParser(description='Train my model.') 344 | parser.add_argument('--hparams', type=str, 345 | help='Comma separated list of "name=value" pairs.') 346 | args = parser.parse_args() 347 | ... 
348 | def my_program(): 349 | # Create a HParams object specifying the names and values of the 350 | # model hyperparameters: 351 | hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100, 352 | activations=['relu', 'tanh']) 353 | 354 | # Override hyperparameters values by parsing the command line 355 | hparams.parse(args.hparams) 356 | 357 | # If the user passed `--hparams=learning_rate=0.3` on the command line 358 | # then 'hparams' has the following attributes: 359 | hparams.learning_rate ==> 0.3 360 | hparams.num_hidden_units ==> 100 361 | hparams.activations ==> ['relu', 'tanh'] 362 | 363 | # If the hyperparameters are in json format use parse_json: 364 | hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}') 365 | ``` 366 | """ 367 | 368 | _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks. 369 | 370 | def __init__(self, hparam_def=None, model_structure=None, **kwargs): 371 | """Create an instance of `HParams` from keyword arguments. 372 | 373 | The keyword arguments specify name-values pairs for the hyperparameters. 374 | The parameter types are inferred from the type of the values passed. 375 | 376 | The parameter names are added as attributes of `HParams` object, so they 377 | can be accessed directly with the dot notation `hparams._name_`. 378 | 379 | Example: 380 | 381 | ```python 382 | # Define 3 hyperparameters: 'learning_rate' is a float parameter, 383 | # 'num_hidden_units' an integer parameter, and 'activation' a string 384 | # parameter. 385 | hparams = tf.HParams( 386 | learning_rate=0.1, num_hidden_units=100, activation='relu') 387 | 388 | hparams.activation ==> 'relu' 389 | ``` 390 | 391 | Note that a few names are reserved and cannot be used as hyperparameter 392 | names. If you use one of the reserved name the constructor raises a 393 | `ValueError`. 394 | 395 | Args: 396 | hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef 397 | protocol buffer. 
If provided, this object is initialized by 398 | deserializing hparam_def. Otherwise **kwargs is used. 399 | model_structure: An instance of ModelStructure, defining the feature 400 | crosses to be used in the Trial. 401 | **kwargs: Key-value pairs where the key is the hyperparameter name and 402 | the value is the value for the parameter. 403 | 404 | Raises: 405 | ValueError: If both `hparam_def` and initialization values are provided, 406 | or if one of the arguments is invalid. 407 | 408 | """ 409 | # Register the hyperparameters and their type in _hparam_types. 410 | # This simplifies the implementation of parse(). 411 | # _hparam_types maps the parameter name to a tuple (type, bool). 412 | # The type value is the type of the parameter for scalar hyperparameters, 413 | # or the type of the list elements for multidimensional hyperparameters. 414 | # The bool value is True if the value is a list, False otherwise. 415 | self._hparam_types = {} 416 | self._model_structure = model_structure 417 | if hparam_def: 418 | ## self._init_from_proto(hparam_def) 419 | ## if kwargs: 420 | ## raise ValueError('hparam_def and initialization values are ' 421 | ## 'mutually exclusive') 422 | raise ValueError('hparam_def has been disabled in this version') 423 | else: 424 | for name, value in six.iteritems(kwargs): 425 | self.add_hparam(name, value) 426 | 427 | ## def _init_from_proto(self, hparam_def): 428 | ## """Creates a new HParams from `HParamDef` protocol buffer. 429 | ## 430 | ## Args: 431 | ## hparam_def: `HParamDef` protocol buffer. 432 | ## """ 433 | ## assert isinstance(hparam_def, hparam_pb2.HParamDef) 434 | ## for name, value in hparam_def.hparam.items(): 435 | ## kind = value.WhichOneof('kind') 436 | ## if kind.endswith('_value'): 437 | ## # Single value. 438 | ## if kind.startswith('int64'): 439 | ## # Setting attribute value to be 'int' to ensure the type is compatible 440 | ## # with both Python2 and Python3. 
441 | ## self.add_hparam(name, int(getattr(value, kind))) 442 | ## elif kind.startswith('bytes'): 443 | ## # Setting attribute value to be 'str' to ensure the type is compatible 444 | ## # with both Python2 and Python3. UTF-8 encoding is assumed. 445 | ## self.add_hparam(name, compat.as_str(getattr(value, kind))) 446 | ## else: 447 | ## self.add_hparam(name, getattr(value, kind)) 448 | ## else: 449 | ## # List of values. 450 | ## if kind.startswith('int64'): 451 | ## # Setting attribute value to be 'int' to ensure the type is compatible 452 | ## # with both Python2 and Python3. 453 | ## self.add_hparam(name, [int(v) for v in getattr(value, kind).value]) 454 | ## elif kind.startswith('bytes'): 455 | ## # Setting attribute value to be 'str' to ensure the type is compatible 456 | ## # with both Python2 and Python3. UTF-8 encoding is assumed. 457 | ## self.add_hparam( 458 | ## name, [compat.as_str(v) for v in getattr(value, kind).value]) 459 | ## else: 460 | ## self.add_hparam(name, [v for v in getattr(value, kind).value]) 461 | 462 | def add_hparam(self, name, value): 463 | """Adds {name, value} pair to hyperparameters. 464 | 465 | Args: 466 | name: Name of the hyperparameter. 467 | value: Value of the hyperparameter. Can be one of the following types: 468 | int, float, string, int list, float list, or string list. 469 | 470 | Raises: 471 | ValueError: if one of the arguments is invalid. 472 | """ 473 | # Keys in kwargs are unique, but 'name' could the name of a pre-existing 474 | # attribute of this object. In that case we refuse to use it as a 475 | # hyperparameter name. 
476 | if getattr(self, name, None) is not None: 477 | raise ValueError('Hyperparameter name is reserved: %s' % name) 478 | if isinstance(value, (list, tuple)): 479 | if not value: 480 | raise ValueError( 481 | 'Multi-valued hyperparameters cannot be empty: %s' % name) 482 | self._hparam_types[name] = (type(value[0]), True) 483 | else: 484 | self._hparam_types[name] = (type(value), False) 485 | setattr(self, name, value) 486 | 487 | def set_hparam(self, name, value): 488 | """Set the value of an existing hyperparameter. 489 | 490 | This function verifies that the type of the value matches the type of the 491 | existing hyperparameter. 492 | 493 | Args: 494 | name: Name of the hyperparameter. 495 | value: New value of the hyperparameter. 496 | 497 | Raises: 498 | ValueError: If there is a type mismatch. 499 | """ 500 | param_type, is_list = self._hparam_types[name] 501 | if isinstance(value, list): 502 | if not is_list: 503 | raise ValueError( 504 | 'Must not pass a list for single-valued parameter: %s' % name) 505 | setattr(self, name, [ 506 | _cast_to_type_if_compatible(name, param_type, v) for v in value]) 507 | else: 508 | if is_list: 509 | raise ValueError( 510 | 'Must pass a list for multi-valued parameter: %s.' % name) 511 | setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) 512 | 513 | def del_hparam(self, name): 514 | """Removes the hyperparameter with key 'name'. 515 | 516 | Args: 517 | name: Name of the hyperparameter. 518 | """ 519 | if hasattr(self, name): 520 | delattr(self, name) 521 | del self._hparam_types[name] 522 | 523 | def parse(self, values): 524 | """Override hyperparameter values, parsing new values from a string. 525 | 526 | See parse_values for more detail on the allowed format for values. 527 | 528 | Args: 529 | values: String. Comma separated list of `name=value` pairs where 530 | 'value' must follow the syntax described above. 531 | 532 | Returns: 533 | The `HParams` instance. 
534 | 535 | Raises: 536 | ValueError: If `values` cannot be parsed. 537 | """ 538 | type_map = dict() 539 | for name, t in self._hparam_types.items(): 540 | param_type, _ = t 541 | type_map[name] = param_type 542 | 543 | values_map = parse_values(values, type_map) 544 | return self.override_from_dict(values_map) 545 | 546 | def override_from_dict(self, values_dict): 547 | """Override hyperparameter values, parsing new values from a dictionary. 548 | 549 | Args: 550 | values_dict: Dictionary of name:value pairs. 551 | 552 | Returns: 553 | The `HParams` instance. 554 | 555 | Raises: 556 | ValueError: If `values_dict` cannot be parsed. 557 | """ 558 | for name, value in values_dict.items(): 559 | self.set_hparam(name, value) 560 | return self 561 | 562 | ## @deprecation.deprecated(None, 'Use `override_from_dict`.') 563 | def set_from_map(self, values_map): 564 | """DEPRECATED. Use override_from_dict.""" 565 | return self.override_from_dict(values_dict=values_map) 566 | 567 | def set_model_structure(self, model_structure): 568 | self._model_structure = model_structure 569 | 570 | def get_model_structure(self): 571 | return self._model_structure 572 | 573 | def to_json(self, indent=None, separators=None, sort_keys=False): 574 | """Serializes the hyperparameters into JSON. 575 | 576 | Args: 577 | indent: If a non-negative integer, JSON array elements and object members 578 | will be pretty-printed with that indent level. An indent level of 0, or 579 | negative, will only insert newlines. `None` (the default) selects the 580 | most compact representation. 581 | separators: Optional `(item_separator, key_separator)` tuple. Default is 582 | `(', ', ': ')`. 583 | sort_keys: If `True`, the output dictionaries will be sorted by key. 584 | 585 | Returns: 586 | A JSON string. 
587 | """ 588 | return json.dumps( 589 | self.values(), 590 | indent=indent, 591 | separators=separators, 592 | sort_keys=sort_keys) 593 | 594 | def parse_json(self, values_json): 595 | """Override hyperparameter values, parsing new values from a json object. 596 | 597 | Args: 598 | values_json: String containing a json object of name:value pairs. 599 | 600 | Returns: 601 | The `HParams` instance. 602 | 603 | Raises: 604 | ValueError: If `values_json` cannot be parsed. 605 | """ 606 | values_map = json.loads(values_json) 607 | return self.override_from_dict(values_map) 608 | 609 | def values(self): 610 | """Return the hyperparameter values as a Python dictionary. 611 | 612 | Returns: 613 | A dictionary with hyperparameter names as keys. The values are the 614 | hyperparameter values. 615 | """ 616 | return {n: getattr(self, n) for n in self._hparam_types.keys()} 617 | 618 | def get(self, key, default=None): 619 | """Returns the value of `key` if it exists, else `default`.""" 620 | if key in self._hparam_types: 621 | # Ensure that default is compatible with the parameter type. 622 | if default is not None: 623 | param_type, is_param_list = self._hparam_types[key] 624 | type_str = 'list<%s>' % param_type if is_param_list else str(param_type) 625 | fail_msg = ("Hparam '%s' of type '%s' is incompatible with " 626 | 'default=%s' % (key, type_str, default)) 627 | 628 | is_default_list = isinstance(default, list) 629 | if is_param_list != is_default_list: 630 | raise ValueError(fail_msg) 631 | 632 | try: 633 | if is_default_list: 634 | for value in default: 635 | _cast_to_type_if_compatible(key, param_type, value) 636 | else: 637 | _cast_to_type_if_compatible(key, param_type, default) 638 | except ValueError as e: 639 | raise ValueError('%s. 
%s' % (fail_msg, e)) 640 | 641 | return getattr(self, key) 642 | 643 | return default 644 | 645 | def __contains__(self, key): 646 | return key in self._hparam_types 647 | 648 | def __str__(self): 649 | return str(sorted(self.values().items())) 650 | 651 | def __repr__(self): 652 | return '%s(%s)' % (type(self).__name__, self.__str__()) 653 | 654 | @staticmethod 655 | def _get_kind_name(param_type, is_list): 656 | """Returns the field name given parameter type and is_list. 657 | 658 | Args: 659 | param_type: Data type of the hparam. 660 | is_list: Whether this is a list. 661 | 662 | Returns: 663 | A string representation of the field name. 664 | 665 | Raises: 666 | ValueError: If parameter type is not recognized. 667 | """ 668 | if issubclass(param_type, bool): 669 | # This check must happen before issubclass(param_type, six.integer_types), 670 | # since Python considers bool to be a subclass of int. 671 | typename = 'bool' 672 | elif issubclass(param_type, six.integer_types): 673 | # Setting 'int' and 'long' types to be 'int64' to ensure the type is 674 | # compatible with both Python2 and Python3. 675 | typename = 'int64' 676 | elif issubclass(param_type, (six.string_types, six.binary_type)): 677 | # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is 678 | # compatible with both Python2 and Python3. 679 | typename = 'bytes' 680 | elif issubclass(param_type, float): 681 | typename = 'float' 682 | else: 683 | raise ValueError('Unsupported parameter type: %s' % str(param_type)) 684 | 685 | suffix = 'list' if is_list else 'value' 686 | return '_'.join([typename, suffix]) 687 | 688 | ## def to_proto(self, export_scope=None): # pylint: disable=unused-argument 689 | ## """Converts a `HParams` object to a `HParamDef` protocol buffer. 690 | ## 691 | ## Args: 692 | ## export_scope: Optional `string`. Name scope to remove. 693 | ## 694 | ## Returns: 695 | ## A `HParamDef` protocol buffer. 
696 | ## """ 697 | ## hparam_proto = hparam_pb2.HParamDef() 698 | ## for name in self._hparam_types: 699 | ## # Parse the values. 700 | ## param_type, is_list = self._hparam_types.get(name, (None, None)) 701 | ## kind = HParams._get_kind_name(param_type, is_list) 702 | ## 703 | ## if is_list: 704 | ## if kind.startswith('bytes'): 705 | ## v_list = [compat.as_bytes(v) for v in getattr(self, name)] 706 | ## else: 707 | ## v_list = [v for v in getattr(self, name)] 708 | ## getattr(hparam_proto.hparam[name], kind).value.extend(v_list) 709 | ## else: 710 | ## v = getattr(self, name) 711 | ## if kind.startswith('bytes'): 712 | ## v = compat.as_bytes(getattr(self, name)) 713 | ## setattr(hparam_proto.hparam[name], kind, v) 714 | ## 715 | ## return hparam_proto 716 | 717 | ## @staticmethod 718 | ## def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument 719 | ## return HParams(hparam_def=hparam_def) 720 | 721 | 722 | ## ops.register_proto_function( 723 | ## 'hparams', 724 | ## proto_type=hparam_pb2.HParamDef, 725 | ## to_proto=HParams.to_proto, 726 | ## from_proto=HParams.from_proto) 727 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import numpy as np 4 | from scipy import signal 5 | from librosa.filters import mel 6 | from scipy.signal import get_window 7 | 8 | 9 | 10 | def butter_highpass(cutoff, fs, order=5): 11 | nyq = 0.5 * fs 12 | normal_cutoff = cutoff / nyq 13 | b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) 14 | return b, a 15 | 16 | 17 | 18 | def pySTFT(x, fft_length=1024, hop_length=256): 19 | 20 | x = np.pad(x, int(fft_length//2), mode='reflect') 21 | 22 | noverlap = fft_length - hop_length 23 | shape = x.shape[:-1]+((x.shape[-1]-noverlap)//hop_length, fft_length) 24 | strides = x.strides[:-1]+(hop_length*x.strides[-1], x.strides[-1]) 25 | result = 
np.lib.stride_tricks.as_strided(x, shape=shape, 26 | strides=strides) 27 | 28 | fft_window = get_window('hann', fft_length, fftbins=True) 29 | result = np.fft.rfft(fft_window * result, n=fft_length).T 30 | 31 | return np.abs(result) 32 | 33 | 34 | 35 | def speaker_normalization(f0, index_nonzero, mean_f0, std_f0): 36 | # f0 is logf0 37 | f0 = f0.astype(float).copy() 38 | #index_nonzero = f0 != 0 39 | f0[index_nonzero] = (f0[index_nonzero] - mean_f0) / std_f0 / 4.0 40 | f0[index_nonzero] = np.clip(f0[index_nonzero], -1, 1) 41 | f0[index_nonzero] = (f0[index_nonzero] + 1) / 2.0 42 | return f0 43 | 44 | 45 | 46 | def quantize_f0_numpy(x, num_bins=256): 47 | # x is logf0 48 | assert x.ndim==1 49 | x = x.astype(float).copy() 50 | uv = (x<=0) 51 | x[uv] = 0.0 52 | assert (x >= 0).all() and (x <= 1).all() 53 | x = np.round(x * (num_bins-1)) 54 | x = x + 1 55 | x[uv] = 0.0 56 | enc = np.zeros((len(x), num_bins+1), dtype=np.float32) 57 | enc[np.arange(len(x)), x.astype(np.int32)] = 1.0 58 | return enc, x.astype(np.int64) 59 | 60 | 61 | 62 | def quantize_f0_torch(x, num_bins=256): 63 | # x is logf0 64 | B = x.size(0) 65 | x = x.view(-1).clone() 66 | uv = (x<=0) 67 | x[uv] = 0 68 | assert (x >= 0).all() and (x <= 1).all() 69 | x = torch.round(x * (num_bins-1)) 70 | x = x + 1 71 | x[uv] = 0 72 | enc = torch.zeros((x.size(0), num_bins+1), device=x.device) 73 | enc[torch.arange(x.size(0)), x.long()] = 1 74 | return enc.view(B, -1, num_bins+1), x.view(B, -1).long() 75 | 76 | 77 | 78 | def get_mask_from_lengths(lengths, max_len): 79 | ids = torch.arange(0, max_len, device=lengths.device) 80 | mask = (ids >= lengths.unsqueeze(1)).bool() 81 | return mask 82 | 83 | 84 | 85 | def pad_seq_to_2(x, len_out=128): 86 | len_pad = (len_out - x.shape[1]) 87 | assert len_pad >= 0 88 | return np.pad(x, ((0,0),(0,len_pad),(0,0)), 'constant'), len_pad --------------------------------------------------------------------------------