├── data
│   └── VCTK
│       ├── txt
│       │   ├── p301
│       │   │   ├── p301_001.txt
│       │   │   ├── p301_010.txt
│       │   │   ├── p301_002.txt
│       │   │   ├── p301_004.txt
│       │   │   ├── p301_009.txt
│       │   │   ├── p301_007.txt
│       │   │   ├── p301_006.txt
│       │   │   ├── p301_003.txt
│       │   │   ├── p301_005.txt
│       │   │   └── p301_008.txt
│       │   └── p226
│       │       ├── p226_001.txt
│       │       ├── p226_010.txt
│       │       ├── p226_002.txt
│       │       ├── p226_009.txt
│       │       ├── p226_004.txt
│       │       ├── p226_007.txt
│       │       ├── p226_006.txt
│       │       ├── p226_003.txt
│       │       ├── p226_005.txt
│       │       └── p226_008.txt
│       └── wav48
│           ├── p226
│           │   ├── p226_001.wav
│           │   ├── p226_002.wav
│           │   ├── p226_003.wav
│           │   ├── p226_004.wav
│           │   ├── p226_005.wav
│           │   ├── p226_006.wav
│           │   ├── p226_007.wav
│           │   ├── p226_008.wav
│           │   ├── p226_009.wav
│           │   └── p226_010.wav
│           └── p301
│               ├── p301_001.wav
│               ├── p301_002.wav
│               ├── p301_003.wav
│               ├── p301_004.wav
│               ├── p301_005.wav
│               ├── p301_006.wav
│               ├── p301_007.wav
│               ├── p301_008.wav
│               ├── p301_009.wav
│               └── p301_010.wav
├── assets
│   └── spk2gen_voicesplit.pkl
├── .gitignore
├── make_spk2gen.py
├── logger.py
├── nikl_spk.txt
├── task_launcher.py
├── plotting_utils.py
├── README.md
├── preprocess.py
├── hparams.py
├── conversion.py
├── utils.py
├── datasets
│   └── voicesplit.py
├── solver
│   ├── base.py
│   ├── autovc_f0.py
│   └── autovc.py
├── data_utils.py
└── architectures
    ├── arch_autovc.py
    └── arch_autovc_f0.py
/data/VCTK/txt/p301/p301_001.txt:
--------------------------------------------------------------------------------
1 | Please call Stella.
--------------------------------------------------------------------------------
/data/VCTK/txt/p226/p226_001.txt:
--------------------------------------------------------------------------------
1 | Please call Stella.
2 | 
--------------------------------------------------------------------------------
/data/VCTK/txt/p301/p301_010.txt:
--------------------------------------------------------------------------------
1 | People look, but no one ever finds it.
--------------------------------------------------------------------------------
/data/VCTK/txt/p226/p226_010.txt:
--------------------------------------------------------------------------------
1 | People look, but no one ever finds it.
2 | 
--------------------------------------------------------------------------------
/data/VCTK/txt/p301/p301_002.txt:
--------------------------------------------------------------------------------
1 | Ask her to bring these things with her from the store.
--------------------------------------------------------------------------------
/data/VCTK/txt/p226/p226_002.txt:
--------------------------------------------------------------------------------
1 | Ask her to bring these things with her from the store.
2 | 
--------------------------------------------------------------------------------
/data/VCTK/txt/p301/p301_004.txt:
--------------------------------------------------------------------------------
1 | We also need a small plastic snake and a big toy frog for the kids.
--------------------------------------------------------------------------------
/data/VCTK/txt/p301/p301_009.txt:
--------------------------------------------------------------------------------
1 | There is , according to legend, a boiling pot of gold at one end.
--------------------------------------------------------------------------------
/data/VCTK/txt/p226/p226_009.txt:
--------------------------------------------------------------------------------
1 | There is , according to legend, a boiling pot of gold at one end.
2 | -------------------------------------------------------------------------------- /data/VCTK/txt/p301/p301_007.txt: -------------------------------------------------------------------------------- 1 | The rainbow is a division of white light into many beautiful colors. -------------------------------------------------------------------------------- /data/VCTK/txt/p226/p226_004.txt: -------------------------------------------------------------------------------- 1 | We also need a small plastic snake and a big toy frog for the kids. 2 | -------------------------------------------------------------------------------- /data/VCTK/txt/p226/p226_007.txt: -------------------------------------------------------------------------------- 1 | The rainbow is a division of white light into many beautiful colors. 2 | -------------------------------------------------------------------------------- /assets/spk2gen_voicesplit.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/assets/spk2gen_voicesplit.pkl -------------------------------------------------------------------------------- /data/VCTK/txt/p301/p301_006.txt: -------------------------------------------------------------------------------- 1 | When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. -------------------------------------------------------------------------------- /data/VCTK/txt/p226/p226_006.txt: -------------------------------------------------------------------------------- 1 | When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. 2 | -------------------------------------------------------------------------------- /data/VCTK/wav48/p226/p226_001.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p226/p226_001.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p226/p226_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p226/p226_002.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p226/p226_003.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p226/p226_003.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p226/p226_004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p226/p226_004.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p226/p226_005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p226/p226_005.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p226/p226_006.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p226/p226_006.wav -------------------------------------------------------------------------------- 
/data/VCTK/wav48/p226/p226_007.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p226/p226_007.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p226/p226_008.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p226/p226_008.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p226/p226_009.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p226/p226_009.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p226/p226_010.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p226/p226_010.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p301/p301_001.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p301/p301_001.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p301/p301_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p301/p301_002.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p301/p301_003.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p301/p301_003.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p301/p301_004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p301/p301_004.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p301/p301_005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p301/p301_005.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p301/p301_006.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p301/p301_006.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p301/p301_007.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p301/p301_007.wav -------------------------------------------------------------------------------- /data/VCTK/wav48/p301/p301_008.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p301/p301_008.wav -------------------------------------------------------------------------------- 
/data/VCTK/wav48/p301/p301_009.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p301/p301_009.wav
--------------------------------------------------------------------------------
/data/VCTK/wav48/p301/p301_010.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mahdeslami11/f0-auto-vc/HEAD/data/VCTK/wav48/p301/p301_010.wav
--------------------------------------------------------------------------------
/data/VCTK/txt/p301/p301_003.txt:
--------------------------------------------------------------------------------
1 | Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob.
--------------------------------------------------------------------------------
/data/VCTK/txt/p301/p301_005.txt:
--------------------------------------------------------------------------------
1 | She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.
--------------------------------------------------------------------------------
/data/VCTK/txt/p226/p226_003.txt:
--------------------------------------------------------------------------------
1 | Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob.
2 | 
--------------------------------------------------------------------------------
/data/VCTK/txt/p226/p226_005.txt:
--------------------------------------------------------------------------------
1 | She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.
2 | 
--------------------------------------------------------------------------------
/data/VCTK/txt/p301/p301_008.txt:
--------------------------------------------------------------------------------
1 | These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon.
--------------------------------------------------------------------------------
/data/VCTK/txt/p226/p226_008.txt:
--------------------------------------------------------------------------------
1 | These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon.
2 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/deployment.xml
2 | .idea/f0-autovc.iml
3 | .idea/misc.xml
4 | .idea/modules.xml
5 | .idea/remote-mappings.xml
6 | .idea/vcs.xml
7 | .idea/workspace.xml
8 | .idea/libraries/R_User_Library.xml
9 | 
--------------------------------------------------------------------------------
/make_spk2gen.py:
--------------------------------------------------------------------------------
1 | """
2 | Make a pickle file of the following form.
3 | 
4 | spk2gen = {
5 | 'spk1' : 'M',
6 | 'spk2' : 'F',
7 | 'spk3' : 'M',
8 | .
9 | .
10 | 
11 | } 12 | """ 13 | 14 | import pickle 15 | 16 | spk2gen_file = "assets/spk2gen_nikl.pkl" 17 | 18 | with open("nikl_spk.txt", 'rt') as f: 19 | spkinfo = [l.rstrip() for l in f.readlines()] 20 | 21 | spk2gen = {} 22 | for info in spkinfo: 23 | spk, gender = info.split() 24 | spk2gen[spk] = gender 25 | 26 | with open(spk2gen_file, 'wb') as f: 27 | pickle.dump(spk2gen, f) -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # import tensorflow as tf 2 | from tensorboardX import SummaryWriter 3 | from plotting_utils import plot_spectrogram_to_numpy 4 | 5 | class Logger(object): 6 | """Using tensorboardX such that need no dependency on tensorflow.""" 7 | 8 | def __init__(self, log_dir): 9 | """Initialize summary writer.""" 10 | self.writer = SummaryWriter(log_dir) 11 | 12 | def scalar_summary(self, tag, value, step): 13 | self.writer.add_scalar(tag, value, step) 14 | 15 | def image_summary(self, tag, image, step): 16 | self.writer.add_image( 17 | tag, 18 | image, 19 | #plot_spectrogram_to_numpy(image.T), 20 | step, 21 | dataformats='HWC') 22 | 23 | def dist_summary(self, tag, weights, step): 24 | # # plot distribution of parameters 25 | # for tag, value in model.named_parameters(): 26 | # tag = tag.replace('.', '/') 27 | # self.add_histogram(tag, value.data.cpu().numpy(), iteration) 28 | self.writer.add_histogram(tag, weights, step) -------------------------------------------------------------------------------- /nikl_spk.txt: -------------------------------------------------------------------------------- 1 | fv18 F 2 | fv16 F 3 | fv17 F 4 | fv08 F 5 | fx07 F 6 | fx04 F 7 | fx15 F 8 | fx14 F 9 | fx12 F 10 | fy06 F 11 | fv05 F 12 | fy03 F 13 | fv02 F 14 | fx06 F 15 | fy07 F 16 | fx17 F 17 | fx05 F 18 | fy05 F 19 | fx16 F 20 | fz05 F 21 | fx01 F 22 | fx13 F 23 | fy14 F 24 | fy09 F 25 | fy02 F 26 | fv07 F 27 | fv10 F 28 | fv01 F 29 | fy13 F 30 | fx02 F 31 | fv19 F 32 | fv20 F 33 | fv12 F 34 | fy18 F 35 | fx11 F 36 | fy11 F 37 | fx18 F 38 | fy12 F 39 | fv13 F 40 | fv15 F 41 | fy08 F 42 | fx10 F 43 | fy17 F 44 | fx08 F 45 | fy04 F 46 | fy01 F 47 | fy16 F 48 | fv06 F 49 | fv04 F 50 | fx03 F 51 | fy10 F 52 | fv14 F 53 | fv11 F 54 | fx19 F 55 | fz06 F 56 | fv09 F 57 | fx09 F 58 | fv03 F 59 | fx20 F 60 | mw13 M 61 | my06 M 62 | mw01 M 63 | mv18 M 64 | my09 M 65 | mv02 M 66 | mv01 M 67 | mv10 M 68 | mz01 M 69 | mv08 M 70 | mz06 M 71 | mw08 M 72 | mw11 M 73 | mv03 M 74 | mw10 M 75 | mv07 M 76 | mz09 M 77 | mw09 M 78 | mw06 M 79 | mv04 M 80 | mv05 M 81 | mw16 M 82 | mw19 M 83 | my04 M 84 | mv15 M 85 | my05 M 86 | mv17 M 87 | mv06 M 88 | mv12 M 89 | mv09 M 90 | mz03 M 91 | mw02 M 92 | mw18 M 93 | my08 M 94 | mw20 M 95 | my10 M 96 | my11 M 97 | mv19 M 98 | mz07 M 99 | mv16 M 100 | my01 M 101 | mz04 M 102 | mz05 M 103 | mw05 M 104 | mw17 M 105 | mv20 M 106 | mw04 M 107 | my07 M 108 | mz02 M 109 | mw14 M 110 | mw03 M 111 | mw07 M 112 | mz08 M 113 | mv13 M 114 | my02 M 115 | mw15 M 116 | my03 M 117 | mv14 M 118 | mv11 M 119 | -------------------------------------------------------------------------------- /task_launcher.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from hparams import hparams 3 | from utils import prepare_dirs 4 | import os 5 | from data_utils import prepare_dataloaders 6 | import importlib 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description="") 10 | parser.add_argument('--name', 
type=str, default="autovc_voicesplit_freq16_seqlen128_l1_nof0")
11 | parser.add_argument('--save_path', type=str, default="/hd0/f0-autovc/exp")
12 | parser.add_argument('--data_dir', type=str, default='/hd0/f0-autovc/preprocessed/sr16000_npz')
13 | parser.add_argument('--checkpoint', type=str, default=None)
14 | parser.add_argument('--architecture', type=str, default='architectures/arch_autovc.py')
15 | parser.add_argument('--solver', type=str, default='solver/autovc.py')
16 | args = parser.parse_args()
17 | return args
18 | 
19 | 
20 | args = parse_args()
21 | 
22 | # prepare directory
23 | model_dir, log_dir, sample_dir = prepare_dirs(args)
24 | 
25 | # data loaders
26 | train_loader, val_loader, testset = prepare_dataloaders(args.data_dir, hparams)
27 | 
28 | # architecture
29 | arch = os.path.splitext(args.architecture)[0].replace("/", ".")
30 | print(" [*] Load architecture : {}".format(arch))
31 | 
32 | solver_mod = importlib.import_module(os.path.splitext(args.solver)[0].replace("/", "."))
33 | print(" [*] Load solver : {}".format(solver_mod))
34 | 
35 | # solver
36 | solver = solver_mod.AutoVC(arch,
37 | model_dir,
38 | log_dir,
39 | sample_dir)
40 | 
41 | if args.checkpoint:
42 | solver.load(args.checkpoint)
43 | 
44 | solver.train(
45 | train_loader,
46 | val_loader,
47 | testset,
48 | hparams.nepochs,
49 | hparams.save_every,
50 | verbose=True)
--------------------------------------------------------------------------------
/plotting_utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use("Agg")
3 | import matplotlib.pylab as plt
4 | import numpy as np
5 | 
6 | def save_figure_to_numpy(fig):
7 | # save it to a numpy array (np.frombuffer replaces the deprecated np.fromstring here)
8 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
9 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
10 | return data
11 | 
12 | def plot_spectrogram_to_numpy(title, spectrogram):
13 | fig, ax = plt.subplots(figsize=(12, 3))
14 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
15 | interpolation='none')
16 | plt.colorbar(im, ax=ax)
17 | plt.title(title)
18 | plt.xlabel("Frames")
19 | plt.ylabel("Channels")
20 | plt.tight_layout()
21 | 
22 | fig.canvas.draw()
23 | data = save_figure_to_numpy(fig)
24 | plt.close()
25 | return data
26 | 
27 | def plot_f0_to_numpy(title, f0):
28 | fig, ax = plt.subplots(figsize=(12, 3))
29 | im = ax.plot(range(len(f0)), f0, color='green')
30 | plt.title(title)
31 | plt.xlabel("Frames")
32 | plt.ylabel("F0")
33 | plt.tight_layout()
34 | 
35 | fig.canvas.draw()
36 | data = save_figure_to_numpy(fig)
37 | plt.close()
38 | return data
39 | 
40 | def plot_f0_outputs_to_numpy(title, f0_target, f0_predicted):
41 | fig, ax = plt.subplots(figsize=(12, 3))
42 | ax.plot(range(len(f0_target)), f0_target, alpha=0.5,
43 | color='green', label='target')
44 | ax.plot(range(len(f0_predicted)), f0_predicted, alpha=0.5,
45 | color='red', label='predicted')
46 | 
47 | plt.title(title)
48 | plt.xlabel("Frames (Green target, Red predicted)")
49 | plt.ylabel("F0")
50 | plt.tight_layout()
51 | 
52 | fig.canvas.draw()
53 | data = save_figure_to_numpy(fig)
54 | plt.close()
55 | return data
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## F0-AUTOVC: F0-Consistent Many-to-Many Non-Parallel Voice Conversion via Conditional Autoencoder
2 | This repository provides a PyTorch implementation of the paper [F0-AUTOVC](https://arxiv.org/abs/2004.07370).
3 | 
4 | Based on
5 | - https://github.com/auspicious3000/autovc
6 | - https://github.com/auspicious3000/SpeechSplit
7 | - https://github.com/christopher-beckham/amr
8 | ## Dependencies
9 | - Python 3.7
10 | - PyTorch 1.6.0
11 | - TensorFlow
12 | - NumPy
13 | - librosa
14 | - tqdm
15 | 
16 | ## Usage
17 | 1. Prepare dataset
18 | We used the [VCTK dataset](http://www.udialogue.org/download/cstr-vctk-corpus.html), as in the original paper,
19 | but you can use your own dataset.
20 | 
21 | 2. Prepare the speaker-to-gender file as shown in nikl_spk.txt and run ```make_spk2gen.py```
22 | * Format
23 | speaker1 gender1
24 | speaker2 gender2
25 | 
26 | * Example (the preprocessing code expects the genders 'M' and 'F'):
27 | p225 F
28 | p226 M
29 | p301 F
30 | p302 F
31 | .
32 | .
33 | 
34 | 3. Preprocess data using ```preprocess.py```
35 | 
36 | 4. Run ```task_launcher.py``` (a minimal command sketch for all four steps follows this file)
37 | This is the link to the main project that I am working on and explaining: https://github.com/hrnoh/f0-autovc
38 | The following articles will help you understand the subject better:
39 | [2002.00198] Transforming Spectrum and Prosody for Emotional Voice Conversion with Non-Parallel Training Data (arxiv.org)
40 | Blow: a single-scale hyperconditioned flow for non-parallel raw-audio voice conversion (neurips.cc)
41 | [1907.10185] Non-Parallel Voice Conversion with Cyclic Variational Autoencoder (arxiv.org)
42 | This video contains my summary of the article: https://drive.google.com/file/d/1QUo5kBkf8QxXhK-y53c4CvqTs0YDxM6z/view?usp=drivesdk
43 | [2108.04395] StarGAN-VC+ASR: StarGAN-based Non-Parallel Voice Conversion Regularized by Automatic Speech Recognition (arxiv.org)
44 | The last link contains my explanation of the main project: https://drive.google.com/file/d/1YxCwQQt4UT36T5Hz4iYW0nKdxEijgDy0/view?usp=drivesdk
--------------------------------------------------------------------------------
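A minimal end-to-end sketch of the four usage steps above. All paths and the experiment name are illustrative placeholders, not part of the repository; the flags are the ones defined in preprocess.py, task_launcher.py, and conversion.py below:

```
# step 2: build the speaker-to-gender pickle (the input list path is hardcoded in the script)
python make_spk2gen.py

# step 3: extract mel-spectrogram and F0 features into per-speaker npz files
python preprocess.py --name voicesplit --in_dir /path/to/wavs \
    --out_dir /path/to/preprocessed --spk2gen assets/spk2gen_voicesplit.pkl

# step 4: train (use arch_autovc.py + solver/autovc.py for the no-F0 baseline)
python task_launcher.py --name my_exp --save_path /path/to/exp \
    --data_dir /path/to/preprocessed \
    --architecture architectures/arch_autovc_f0.py --solver solver/autovc_f0.py

# convert with a trained checkpoint (checkpoints are saved as <epoch>.pkl under <save_path>/<name>/model)
python conversion.py --model autovc-f0 --data_dir /path/to/preprocessed \
    --out_dir generated --checkpoint /path/to/exp/my_exp/model/500.pkl \
    --architecture architectures/arch_autovc_f0.py
```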
/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from multiprocessing import cpu_count
4 | from tqdm import tqdm
5 | import importlib
6 | from hparams import hparams, hparams_debug_string
7 | import warnings
8 | warnings.simplefilter(action='ignore', category=FutureWarning)
9 | 
10 | 
11 | def preprocess(mod, in_dir, out_dir, spk_emb, spk2gen, num_workers):
12 | os.makedirs(out_dir, exist_ok=True)
13 | mod.build_from_path(hparams, in_dir, out_dir, spk_emb, spk2gen, num_workers=num_workers)
14 | 
15 | 
16 | if __name__ == "__main__":
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--name', type=str, default='voicesplit')
19 | parser.add_argument('--in_dir', type=str, default='/hd0/dataset/voicesplit/')
20 | parser.add_argument('--out_dir', type=str, default='/hd0/f0-autovc/preprocessed/sr16000_npz')
21 | parser.add_argument('--num_workers', type=int, default=None)
22 | parser.add_argument('--spk_emb', type=str, default=None, help='speaker embedding path (default: onehot)')
23 | parser.add_argument('--spk2gen', type=str, default="assets/spk2gen_voicesplit.pkl", help="pickle file path for converting speaker to gender")
24 | args = parser.parse_args()
25 | print(hparams_debug_string())
26 | 
27 | if not os.path.exists(args.out_dir):
28 | try:
29 | os.mkdir(args.out_dir)
30 | except FileExistsError:
31 | print(args.out_dir, "exists")
32 | 
33 | if not args.num_workers:
34 | args.num_workers = cpu_count()
35 | 
36 | assert args.name in ["VCTK", "NIKL", "voicesplit"]
37 | mod = importlib.import_module('datasets.{}'.format(args.name))
38 | 
39 | print("---------------------------------- Preprocessing starts! ----------------------------------")
40 | print("dataset: {}".format(args.name))
41 | print("load directory: {}".format(args.in_dir))
42 | print("output directory: {}".format(args.out_dir))
43 | 
44 | preprocess(mod, args.in_dir, args.out_dir, args.spk_emb, args.spk2gen, args.num_workers)
45 | print("---------------------------------- Preprocessing is done! ----------------------------------")
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import math
4 | import torch
5 | 
6 | # NOTE: If you want full control over the model architecture, please take a look
7 | # at the code and change whatever you want. Some hyperparameters are hardcoded.
8 | 
9 | # Default hyperparameters:
10 | hparams = tf.contrib.training.HParams(
11 | exp='AutoVC',
12 | n_speakers=12,
13 | 
14 | device= 'cuda' if torch.cuda.is_available() else 'cpu',
15 | #device= 'cpu',
16 | 
17 | # used during preprocessing
18 | used_spks=None,
19 | 
20 | ########## Audio ######################################
21 | sample_rate=16000,
22 | 
23 | # shift can be specified by either hop_size (preferred here) or frame_shift_ms
24 | hop_size=256, # 16 ms at 16 kHz
25 | fft_size=1024,
26 | win_size=1024, # 64 ms at 16 kHz
27 | num_mels=80,
28 | 
29 | min_level_db=-100,
30 | ref_level_db=16,
31 | 
32 | rescaling=True,
33 | rescaling_max=0.999,
34 | 
35 | trim_silence=False, # Whether to clip silence in Audio (at beginning and end of audio only, not the middle)
36 | # M-AILABS (and other datasets) trim params (these parameters are usually correct for any data, but definitely must be tuned for specific speakers)
37 | trim_fft_size=1024,
38 | trim_hop_size=256,
39 | trim_top_db=20,
40 | 
41 | # filter parameter
42 | cutoff=30,
43 | order=5,
44 | 
45 | # mel-basis parameters
46 | fmin=90,
47 | fmax=7600,
48 | 
49 | ########## Model Parameters ######################################
50 | # input
51 | seq_len = 128,
52 | 
53 | # Model
54 | dim_neck = 32,
55 | dim_emb = 12,
56 | dim_pre = 512,
57 | freq = 16,
58 | pitch_bin = 256,
59 | 
60 | ################################################################################
61 | # Training:
62 | batch_size = 2, # it is equal to N
63 | val_batch_size = 2,
64 | adam_beta1=0.9,
65 | adam_beta2=0.999,
66 | adam_eps=1e-8,
67 | amsgrad=False,
68 | initial_learning_rate= 1e-3,
69 | final_learning_rate = 1e-6,
70 | n_warmup_steps=4000, # ScheduledOptim : 4000, Exponential : 40000
71 | decay_rate = 0.000005,
72 | decay_step = 1000000,
73 | nepochs=500,
74 | ################################################################################
75 | # Save
76 | # per-epoch interval
77 | save_every=20,
78 | )
79 | 
80 | 
81 | def hparams_debug_string():
82 | values = hparams.values()
83 | hp = [' %s: %s' % (name, values[name]) for name in sorted(values)]
84 | return 'Hyperparameters:\n' + '\n'.join(hp)
85 | 
--------------------------------------------------------------------------------
/conversion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import torch
4 | import numpy as np
5 | from utils import pad_seq, make_onehot, quantize_f0_numpy
6 | from hparams import hparams
7 | import argparse
8 | import importlib
9 | import glob
10 | import itertools
11 | import torch
12 | from tqdm import tqdm
13 | 
14 | def conversion(args, net, device='cuda'):
15 | assert os.path.isdir(args.data_dir), 'Cannot find data dir: {}'.format(args.data_dir)
16 | 
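# Expected layout under args.data_dir, as produced by preprocess.py / datasets/voicesplit.py:
#     <data_dir>/<speaker>/{train,val,test}/<utterance>.npz
# Each npz holds 'mel' (T, num_mels), 'f0' (T,) normalized to [0, 1] with unvoiced
# frames left at -1e10, and 'spk_label' (an integer speaker index). The loop below
# takes the first test-set utterance of every speaker and converts all source/target pairs.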
17 | all_spk_path = [p for p in glob.glob(os.path.join(args.data_dir, '*')) if os.path.isdir(p)]
18 | all_test_samples = [glob.glob(os.path.join(p, 'test', '*.npz'))[0] for p in all_spk_path]
19 | os.makedirs(args.out_dir, exist_ok=True)
20 | 
21 | all_pair = itertools.product(all_test_samples, all_test_samples)
22 | for src, trg in tqdm(all_pair, desc="converting voices"):
23 | src_name = src.split('/')[-3]
24 | trg_name = trg.split('/')[-3]
25 | src_npz = np.load(src)
26 | trg_npz = np.load(trg)
27 | 
28 | x = src_npz['mel']
29 | p = src_npz['f0'][:, np.newaxis]
30 | emb_src_np = make_onehot(src_npz['spk_label'].item(), hparams.n_speakers)
31 | emb_trg_np = make_onehot(trg_npz['spk_label'].item(), hparams.n_speakers)
32 | 
33 | x_padded, pad_len = pad_seq(x, base=hparams.freq, constant_values=None)
34 | p_padded, pad_len = pad_seq(p, base=hparams.freq, constant_values=-1e10)
35 | 
36 | quantized_p, _ = quantize_f0_numpy(p_padded[:, 0], num_bins=hparams.pitch_bin)
37 | 
38 | x_src = torch.from_numpy(x_padded).unsqueeze(0).to(device)
39 | p_src = torch.from_numpy(quantized_p).unsqueeze(0).to(device)
40 | emb_src = torch.from_numpy(emb_src_np).unsqueeze(0).to(device)
41 | emb_trg = torch.from_numpy(emb_trg_np).unsqueeze(0).to(device)
42 | 
43 | if args.model == 'autovc':
44 | out, out_psnt, _ = net(x_src, emb_src, emb_trg)
45 | elif args.model == 'autovc-f0':
46 | out, out_psnt, _ = net(x_src, p_src, emb_src, emb_trg)
47 | else:
48 | raise ValueError("Wrong model name : {}".format(args.model))
49 | 
50 | 
51 | 
52 | if pad_len == 0:
53 | out_mel = out_psnt.squeeze().detach().cpu().numpy()[:, :]
54 | else:
55 | out_mel = out_psnt.squeeze().detach().cpu().numpy()[:-pad_len, :]
56 | src_mel = src_npz['mel']
57 | trg_mel = trg_npz['mel']
58 | 
59 | np.save(os.path.join(args.out_dir, '{}-{}-feats.npy'.format(src_name, os.path.splitext(src.split('/')[-1])[0])), src_mel)
60 | np.save(os.path.join(args.out_dir, '{}-{}-feats.npy'.format(trg_name, os.path.splitext(trg.split('/')[-1])[0])), trg_mel)
61 | np.save(os.path.join(args.out_dir, '{}-to-{}-{}.npy'.format(src_name, trg_name, os.path.splitext(src.split('/')[-1])[0])), out_mel)
62 | 
63 | 
64 | if __name__=='__main__':
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument('--model', type=str, default='autovc-f0', help="set 'autovc' or 'autovc-f0'")
67 | parser.add_argument('--data_dir', type=str, default='/hd0/f0-autovc/preprocessed/sr16000_npz/')
68 | parser.add_argument('--out_dir', type=str, default='generated')
69 | parser.add_argument('--checkpoint', type=str, default='/hd0/f0-autovc/exp/autovc_f0_voicesplit_freq8/model/60.pkl')
70 | parser.add_argument('--architecture', type=str, default='architectures/arch_autovc_f0.py')
71 | args = parser.parse_args()
72 | 
73 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
74 | 
75 | # architecture
76 | arch_name = os.path.splitext(args.architecture)[0].replace("/", ".")
77 | arch = importlib.import_module(arch_name)
78 | print(" [*] Load architecture : {}".format(arch_name))
79 | 
80 | # load model
81 | net = arch.get_network(hparams)
82 | net = net['net'].to(device)
83 | 
84 | assert os.path.isfile(args.checkpoint), 'Cannot find model checkpoint : {}'.format(args.checkpoint)
85 | dd = torch.load(args.checkpoint)
86 | net.load_state_dict(dd['net'], strict=True)
87 | net.eval()
88 | 
89 | conversion(args, net, device)
--------------------------------------------------------------------------------
/utils.py:
-------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import numpy as np 4 | from scipy import signal 5 | from librosa.filters import mel 6 | from scipy.signal import get_window 7 | import os 8 | from shutil import copyfile 9 | from math import ceil 10 | 11 | 12 | def butter_highpass(cutoff, fs, order=5): 13 | nyq = 0.5 * fs 14 | normal_cutoff = cutoff / nyq 15 | b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) 16 | return b, a 17 | 18 | 19 | 20 | def pySTFT(x, fft_length=1024, hop_length=256): 21 | 22 | x = np.pad(x, int(fft_length//2), mode='reflect') 23 | 24 | noverlap = fft_length - hop_length 25 | shape = x.shape[:-1]+((x.shape[-1]-noverlap)//hop_length, fft_length) 26 | strides = x.strides[:-1]+(hop_length*x.strides[-1], x.strides[-1]) 27 | result = np.lib.stride_tricks.as_strided(x, shape=shape, 28 | strides=strides) 29 | 30 | fft_window = get_window('hann', fft_length, fftbins=True) 31 | result = np.fft.rfft(fft_window * result, n=fft_length).T 32 | 33 | return np.abs(result) 34 | 35 | 36 | 37 | def speaker_normalization(f0, index_nonzero, mean_f0, std_f0): 38 | # f0 is logf0 39 | f0 = f0.astype(float).copy() 40 | #index_nonzero = f0 != 0 41 | f0[index_nonzero] = (f0[index_nonzero] - mean_f0) / std_f0 / 4.0 42 | f0[index_nonzero] = np.clip(f0[index_nonzero], -1, 1) 43 | f0[index_nonzero] = (f0[index_nonzero] + 1) / 2.0 44 | return f0 45 | 46 | 47 | 48 | def quantize_f0_numpy(x, num_bins=256): 49 | # x is logf0 50 | assert x.ndim==1 51 | x = x.astype(float).copy() 52 | uv = (x<=0) 53 | x[uv] = 0.0 54 | assert (x >= 0).all() and (x <= 1).all() 55 | x = np.round(x * (num_bins-1)) 56 | x = x + 1 57 | x[uv] = 0.0 58 | enc = np.zeros((len(x), num_bins+1), dtype=np.float32) 59 | enc[np.arange(len(x)), x.astype(np.int32)] = 1.0 60 | return enc, x.astype(np.int64) 61 | 62 | 63 | 64 | def quantize_f0_torch(x, num_bins=256): 65 | # x is logf0 66 | B = x.size(0) 67 | x = x.view(-1).clone() 68 | uv = (x<=0) 69 | x[uv] = 0 70 | assert (x >= 0).all() and (x <= 1).all() 71 | x = torch.round(x * (num_bins-1)) 72 | x = x + 1 73 | x[uv] = 0 74 | enc = torch.zeros((x.size(0), num_bins+1), device=x.device) 75 | enc[torch.arange(x.size(0)), x.long()] = 1 76 | return enc.view(B, -1, num_bins+1), x.view(B, -1).long() 77 | 78 | 79 | 80 | def get_mask_from_lengths(lengths, max_len): 81 | ids = torch.arange(0, max_len, device=lengths.device) 82 | mask = (ids >= lengths.unsqueeze(1)).bool() 83 | return mask 84 | 85 | 86 | 87 | def pad_seq_to_2(x, len_out=128): 88 | len_pad = (len_out - x.shape[1]) 89 | assert len_pad >= 0 90 | return np.pad(x, ((0,0),(0,len_pad),(0,0)), 'constant'), len_pad 91 | 92 | def pad_seq(x, base=32, constant_values=0): 93 | len_out = int(base * ceil(float(x.shape[0])/base)) 94 | len_pad = len_out - x.shape[0] 95 | assert len_pad >= 0 96 | return np.pad(x, ((0,len_pad),(0,0)), 'constant', constant_values=constant_values), len_pad 97 | 98 | def make_onehot(label, n_classes): 99 | speaker_vector = np.zeros(n_classes) 100 | speaker_vector[label] = 1 101 | return speaker_vector.astype(dtype=np.float32) 102 | 103 | def makedirs(path): 104 | if not os.path.exists(path): 105 | print(" [*] Make directories : {}".format(path)) 106 | os.makedirs(path) 107 | 108 | def prepare_dirs(config): 109 | if hasattr(config, 'save_path'): 110 | log_dir = config.save_path 111 | os.makedirs(log_dir, exist_ok=True) 112 | 113 | if hasattr(config, 'name'): 114 | exp_name = config.name 115 | root_path = os.path.join(log_dir, 
exp_name)
116 | log_path = os.path.join(root_path, 'log')
117 | model_path = os.path.join(root_path, 'model')
118 | sample_path = os.path.join(root_path, 'samples')
119 | 
120 | makedirs(log_path)
121 | makedirs(model_path)
122 | makedirs(sample_path)
123 | 
124 | copyfile("hparams.py", os.path.join(root_path, "hparams_exp.py"))
125 | copyfile(config.architecture, os.path.join(root_path, os.path.basename(config.architecture)))
126 | copyfile(config.solver, os.path.join(root_path, os.path.basename(config.solver)))
127 | 
128 | return model_path, log_path, sample_path
--------------------------------------------------------------------------------
/datasets/voicesplit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | warnings.simplefilter(action='ignore', category=FutureWarning)
4 | import numpy as np
5 | from tqdm import tqdm
6 | from concurrent.futures import ProcessPoolExecutor
7 | from functools import partial
8 | import pickle
9 | import pkbar
10 | import glob
11 | import torch
12 | 
13 | import random
14 | from scipy import signal
15 | from librosa.filters import mel
16 | from numpy.random import RandomState
17 | from pysptk import sptk
18 | import librosa
19 | from utils import butter_highpass
20 | from utils import speaker_normalization
21 | from utils import pySTFT
22 | from hparams import hparams
23 | 
24 | mel_basis = mel(hparams.sample_rate, hparams.fft_size, fmin=hparams.fmin, fmax=hparams.fmax, n_mels=hparams.num_mels).T
25 | min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
26 | b, a = butter_highpass(hparams.cutoff, hparams.sample_rate, order=hparams.order)
27 | 
28 | def build_from_path(hparams, in_dir, out_dir, spk_emb_path, spk2gen_path, num_workers=16):
29 | 
30 | executor = ProcessPoolExecutor(max_workers=num_workers)
31 | 
32 | # load spk paths
33 | if hparams.used_spks is not None:
34 | spk_paths = [p for p in glob.glob(os.path.join(in_dir, "*")) if os.path.isdir(p) and os.path.basename(p) in hparams.used_spks]
35 | else:
36 | spk_paths = [p for p in glob.glob(os.path.join(in_dir, "*")) if os.path.isdir(p)]
37 | 
38 | # load speaker embedding
39 | if spk_emb_path:
40 | spk_embs = pickle.load(open(spk_emb_path, 'rb'))
41 | 
42 | # load speaker to gender
43 | if spk2gen_path is not None:
44 | spk2gen = pickle.load(open(spk2gen_path, "rb"))
45 | else:
46 | raise ValueError
47 | 
48 | os.makedirs(out_dir, exist_ok=True)
49 | 
50 | # preprocessing per speaker
51 | for i, spk_path in enumerate(spk_paths):
52 | spk_name = os.path.basename(spk_path)
53 | 
54 | if spk_emb_path:
55 | emb_idx = -1
56 | for k in range(len(spk_embs)): # renamed from 'i' so the speaker index above is not clobbered
57 | if spk_embs[k][0] == spk_name:
58 | emb_idx = k
59 | break
60 | 
61 | gender = spk2gen[spk_name]
62 | assert gender == 'M' or gender == 'F'
63 | 
64 | # make speaker directory
65 | os.makedirs(os.path.join(out_dir, spk_name), exist_ok=True)
66 | os.makedirs(os.path.join(out_dir, spk_name, 'train'), exist_ok=True)
67 | os.makedirs(os.path.join(out_dir, spk_name, 'val'), exist_ok=True)
68 | os.makedirs(os.path.join(out_dir, spk_name, 'test'), exist_ok=True)
69 | 
70 | # glob all samples for a speaker
71 | all_wav_path = glob.glob(os.path.join(spk_path, "*.wav"))
72 | random.shuffle(all_wav_path)
73 | 
74 | total_num = len(all_wav_path)
75 | train_num = int(total_num * 0.95)
76 | val_num = total_num - train_num - 1
77 | test_num = 1
78 | 
79 | pbar = pkbar.Pbar(name='loading and processing dataset', target=len(all_wav_path))
80 | 
81 | futures = []
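# The loop below routes utterance j into the per-speaker split computed above:
# the first 95% of the shuffled files go to train, the next (total - train - 1)
# go to val, and the single remaining utterance goes to test.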
82 | for j, wav_path in tqdm(enumerate(all_wav_path)):
83 | wav_name = os.path.basename(wav_path)
84 | spk_emb = spk_embs[emb_idx][1] if spk_emb_path else None
85 | 
86 | if j < train_num:
87 | npz_name = os.path.join(out_dir, spk_name, 'train', wav_name[:-4] + ".npz")
88 | elif j >= train_num and j < train_num + val_num:
89 | npz_name = os.path.join(out_dir, spk_name, 'val', wav_name[:-4] + ".npz")
90 | else:
91 | npz_name = os.path.join(out_dir, spk_name, 'test', wav_name[:-4] + ".npz")
92 | 
93 | futures.append(executor.submit(partial(_processing_data, hparams, wav_path, i, spk_emb, gender, npz_name, pbar, j)))
94 | 
95 | results = [future.result() for future in futures if future.result() is not None]
96 | 
97 | print('Finish Preprocessing')
98 | 
99 | 
100 | def _processing_data(hparams, full_path, spk_label, spk_emb, gender, npz_name, pbar, i):
101 | if gender == 'M':
102 | lo, hi = 50, 250
103 | elif gender == 'F':
104 | lo, hi = 100, 600
105 | else:
106 | raise ValueError
107 | 
108 | prng = RandomState(int(random.random()))
109 | x, fs = librosa.load(full_path, sr=hparams.sample_rate)
110 | assert fs == hparams.sample_rate
111 | if x.shape[0] % hparams.hop_size == 0:
112 | x = np.concatenate((x, np.array([1e-06])), axis=0)
113 | y = signal.filtfilt(b, a, x)
114 | wav = y * 0.96 + (prng.rand(y.shape[0]) - 0.5) * 1e-06
115 | 
116 | # compute spectrogram
117 | D = pySTFT(wav).T
118 | D_mel = np.dot(D, mel_basis)
119 | D_db = 20 * np.log10(np.maximum(min_level, D_mel)) - hparams.ref_level_db
120 | S = (D_db + 100) / 100
121 | 
122 | # extract f0
123 | f0_rapt = sptk.rapt(wav.astype(np.float32) * 32768, fs, hparams.hop_size, min=lo, max=hi, otype=2)
124 | index_nonzero = (f0_rapt != -1e10)
125 | mean_f0, std_f0 = np.mean(f0_rapt[index_nonzero]), np.std(f0_rapt[index_nonzero])
126 | f0_norm = speaker_normalization(f0_rapt, index_nonzero, mean_f0, std_f0)
127 | 
128 | assert len(S) == len(f0_rapt)
129 | 
130 | data = {
131 | 'mel': S.astype(np.float32),
132 | 'f0': f0_norm.astype(np.float32),
133 | 'spk_label': spk_label
134 | }
135 | if spk_emb is not None:
136 | data['spk_emb'] = spk_emb
137 | 
138 | np.savez(npz_name, **data)
139 | pbar.update(i)
140 | 
141 | 
--------------------------------------------------------------------------------
/solver/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | from tqdm import tqdm
5 | from collections import OrderedDict
6 | from plotting_utils import plot_spectrogram_to_numpy
7 | 
8 | class Base:
9 | def __init__(self,
10 | model_dir,
11 | log_dir,
12 | sample_dir):
13 | self.model_dir = model_dir
14 | self.log_dir = log_dir
15 | self.sample_dir = sample_dir
16 | 
17 | self.global_epoch = 0
18 | self.global_step = 0
19 | 
20 | self.build_tensorboard(self.log_dir)
21 | 
22 | def train(self,
23 | itr_train,
24 | itr_valid,
25 | testset,
26 | epochs,
27 | save_every=1,
28 | verbose=True):
29 | 
30 | f_mode = 'w' if not os.path.exists("%s/results.txt" % self.log_dir) else 'a'
31 | f = None
32 | if self.log_dir is not None:
33 | f = open("%s/results.txt" % self.log_dir, f_mode)
34 | 
35 | try:
36 | for epoch in range(self.global_epoch, epochs):
37 | epoch_start_time = time.time()
38 | # Training.
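# Each epoch below: iterate itr_train and log per-step losses to TensorBoard,
# then iterate itr_valid, logging validation losses and mel-spectrogram image
# summaries; finally the per-epoch mean of every tracked value is appended to
# <log_dir>/results.txt and a checkpoint is written every `save_every` epochs.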
39 | if verbose: 40 | pbar = tqdm(total=len(itr_train)) 41 | train_dict = OrderedDict({'epoch': epoch+1}) 42 | # item, pose, id 43 | for b, batch in enumerate(itr_train): 44 | self.global_step += 1 45 | batch = self.prepare_batch(batch) 46 | losses, outputs = self.train_on_instance(*batch, 47 | iter=b+1, 48 | global_step = self.global_step) 49 | for key in losses: 50 | this_key = 'train/%s' % key 51 | if this_key not in train_dict: 52 | train_dict[this_key] = [] 53 | train_dict[this_key].append(losses[key]) 54 | self.logger.scalar_summary(this_key, losses[key], self.global_step) 55 | if verbose: 56 | pbar.update(1) 57 | pbar.set_postfix(self._get_stats(train_dict, 'train')) 58 | if verbose: 59 | pbar.close() 60 | valid_dict = {} 61 | # TODO: enable valid 62 | if verbose: 63 | pbar = tqdm(total=len(itr_valid)) 64 | # Validation. 65 | valid_dict = OrderedDict({}) 66 | for b, valid_batch in enumerate(itr_valid): 67 | valid_batch = self.prepare_batch(valid_batch) 68 | valid_losses, valid_outputs = self.eval_on_instance(*valid_batch, 69 | iter=b+1, 70 | global_step = self.global_step) 71 | for key in valid_losses: 72 | this_key = 'valid/%s' % key 73 | if this_key not in valid_dict: 74 | valid_dict[this_key] = [] 75 | valid_dict[this_key].append(valid_losses[key]) 76 | self.logger.scalar_summary(this_key, valid_losses[key], self.global_step) 77 | 78 | self.summary(valid_outputs, epoch) 79 | 80 | 81 | if verbose: 82 | pbar.update(1) 83 | pbar.set_postfix(self._get_stats(valid_dict, 'valid')) 84 | 85 | if verbose: 86 | pbar.close() 87 | # Step learning rates. 88 | # for sched in self.schedulers: 89 | # sched.step(self.global_step) 90 | # Update dictionary of values. 91 | all_dict = train_dict 92 | all_dict.update(valid_dict) 93 | for key in all_dict: 94 | all_dict[key] = np.mean(all_dict[key]) 95 | for key in self.optim: 96 | all_dict["lr_%s" % key] = \ 97 | self.optim[key].state_dict()['param_groups'][0]['lr'] 98 | all_dict['time'] = time.time() - epoch_start_time 99 | str_ = ",".join([str(all_dict[key]) for key in all_dict]) 100 | print(str_) 101 | if self.log_dir is not None: 102 | if (epoch+1) == 1: 103 | f.write(",".join(all_dict.keys()) + "\n") 104 | f.write(str_ + "\n") 105 | f.flush() 106 | if (epoch+1) % save_every == 0 and self.model_dir is not None: 107 | self.save(filename="%s/%i.pkl" % (self.model_dir, epoch+1), 108 | epoch=epoch+1) 109 | #self.summary(testset, epoch=epoch+1) 110 | 111 | self.global_epoch += 1 112 | except KeyboardInterrupt: 113 | self.save(filename="%s/%i.pkl" % (self.model_dir, epoch + 1), 114 | epoch=epoch + 1) 115 | print("%s/%i.pkl is saved!" 
% (self.model_dir, epoch + 1)) 116 | if f is not None: 117 | f.close() 118 | 119 | def vis_batch(self, batch, outputs): 120 | raise NotImplementedError() 121 | 122 | def build_tensorboard(self, log_dir): 123 | """Build a tensorboard logger.""" 124 | from logger import Logger 125 | self.logger = Logger(log_dir) 126 | 127 | -------------------------------------------------------------------------------- /solver/autovc_f0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from collections import OrderedDict 5 | from torch import optim 6 | from itertools import chain 7 | from solver.base import Base 8 | import argparse 9 | from hparams import hparams 10 | from utils import prepare_dirs 11 | from data_utils import prepare_dataloaders 12 | from plotting_utils import plot_spectrogram_to_numpy, plot_f0_to_numpy 13 | import importlib 14 | import os 15 | 16 | class AutoVC(Base): 17 | def __init__(self, 18 | architecture, 19 | model_dir, 20 | log_dir, 21 | sample_dir): 22 | super(AutoVC, self).__init__(model_dir, log_dir, sample_dir) 23 | 24 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 25 | #self.device = 'cpu' 26 | 27 | # load architecture 28 | arch = importlib.import_module(architecture) 29 | 30 | nets = arch.get_network(hparams) 31 | self.net = nets['net'].to(self.device) 32 | 33 | opt_args = {'lr': hparams.initial_learning_rate, 'betas': (hparams.adam_beta1, hparams.adam_beta2)} 34 | g_params = self.net.parameters() 35 | optim_g = optim.Adam(filter(lambda p: p.requires_grad, g_params), **opt_args) 36 | 37 | self.optim = { 38 | 'g': optim_g 39 | } 40 | 41 | self.last_epoch = 0 42 | self.load_strict = True 43 | 44 | def _get_stats(self, dict_, mode): 45 | stats = OrderedDict({}) 46 | for key in dict_.keys(): 47 | stats[key] = np.mean(dict_[key]) 48 | return stats 49 | 50 | def reset_grad(self): 51 | """Reset the gradient buffers.""" 52 | self.optim['g'].zero_grad() 53 | 54 | def _train(self): 55 | self.net.train() 56 | 57 | def _eval(self): 58 | self.net.eval() 59 | 60 | def train_on_instance(self, 61 | x_real, 62 | p_real, 63 | emb_org, 64 | **kwargs): 65 | self._train() 66 | # Identity mapping loss 67 | x_identic, x_identic_psnt, code_real = self.net(x_real, p_real, emb_org, emb_org) 68 | g_loss_id = F.mse_loss(x_real, x_identic) 69 | g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt) 70 | 71 | # Code semantic loss. 72 | code_reconst = self.net(x_identic_psnt, None, emb_org, None) 73 | g_loss_cd = F.l1_loss(code_real, code_reconst) 74 | 75 | # Backward and optimize. 
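# g_loss_id:      L2 between the input mel and the pre-postnet reconstruction
# g_loss_id_psnt: L2 between the input mel and the post-postnet reconstruction
# g_loss_cd:      L1 between the content codes of the input and the codes of its
#                 reconstruction (obtained by re-encoding with c_trg=None)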
76 | g_loss = g_loss_id + g_loss_id_psnt + g_loss_cd
77 | self.reset_grad()
78 | g_loss.backward()
79 | self.optim['g'].step()
80 | 
81 | ## ----------------------------------------------
82 | ## Collecting losses and outputs
83 | ## ----------------------------------------------
84 | losses = {
85 | 'G/loss_id': g_loss_id.item(),
86 | 'G/loss_id_psnt': g_loss_id_psnt.item(),
87 | 'G/loss_cd': g_loss_cd.item()
88 | }
89 | 
90 | outputs = {
91 | 'GT': x_real.detach().cpu(),
92 | 'recon': x_identic_psnt.detach().cpu(),
93 | }
94 | 
95 | return losses, outputs
96 | 
97 | def eval_on_instance(self,
98 | x_real,
99 | p_real,
100 | emb_org,
101 | **kwargs):
102 | self._eval()
103 | with torch.no_grad():
104 | # Identity mapping loss
105 | x_identic, x_identic_psnt, code_real = self.net(x_real, p_real, emb_org, emb_org)
106 | g_loss_id = F.mse_loss(x_real, x_identic)
107 | g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt)
108 | 
109 | # Code semantic loss.
110 | code_reconst = self.net(x_identic_psnt, None, emb_org, None)
111 | g_loss_cd = F.l1_loss(code_real, code_reconst)
112 | 
113 | # Total loss (no backward pass during evaluation).
114 | g_loss = g_loss_id + g_loss_id_psnt + g_loss_cd
115 | 
116 | ## ----------------------------------------------
117 | ## Collecting losses and outputs
118 | ## ----------------------------------------------
119 | losses = {
120 | 'G/loss_id': g_loss_id.item(),
121 | 'G/loss_id_psnt': g_loss_id_psnt.item(),
122 | 'G/loss_cd': g_loss_cd.item()
123 | }
124 | 
125 | outputs = {
126 | 'GT': x_real.detach().cpu(),
127 | 'recon': x_identic_psnt.detach().cpu(),
128 | }
129 | 
130 | return losses, outputs
131 | 
132 | def prepare_batch(self, batch):
133 | if len(batch) != 3:
134 | raise Exception("Expected batch to have three elements: " +
135 | "mel, quantized_p, spk")
136 | 
137 | x = batch["mel"].to(self.device)
138 | quantized_p = batch["quantized_p"].to(self.device)
139 | spk = batch["spk"].to(self.device)
140 | 
141 | return [x, quantized_p, spk]
142 | 
143 | def save(self, filename, epoch):
144 | dd = {}
145 | # Save the models.
146 | dd['net'] = self.net.state_dict()
147 | # Save the models' optim state.
148 | for key in self.optim:
149 | dd['optim_%s' % key] = self.optim[key].state_dict()
150 | dd['epoch'] = epoch
151 | dd['global_epoch'] = self.global_epoch
152 | dd['global_step'] = self.global_step
153 | torch.save(dd, filename)
154 | 
155 | def load(self, filename):
156 | # if not self.use_cuda:
157 | # map_location = lambda storage, loc: storage
158 | # else:
159 | # map_location = None
160 | dd = torch.load(filename)
161 | #map_location=map_location)
162 | # Load the models.
163 | self.net.load_state_dict(dd['net'], strict=self.load_strict)
164 | 
165 | # Load the models' optim state.
166 | for key in self.optim:
167 | self.optim[key].load_state_dict(dd['optim_%s' % key])
168 | self.last_epoch = dd['epoch']
169 | self.global_epoch = dd['global_epoch']
170 | self.global_step = dd['global_step']
171 | 
172 | def summary(self, outputs, epoch):
173 | self.logger.image_summary("mel-spectrogram/ground-truth",
174 | plot_spectrogram_to_numpy("ground-truth", outputs['GT'][0].numpy().T),
175 | epoch)
176 | self.logger.image_summary("mel-spectrogram/reconstruction",
177 | plot_spectrogram_to_numpy("reconstruction", outputs['recon'][0].numpy().T),
178 | epoch)
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | from math import ceil
3 | import numpy as np
4 | import torch
5 | import torch.utils.data
6 | from torch.utils.data import DataLoader
7 | import os
8 | import glob
9 | from hparams import hparams
10 | import pickle
11 | 
12 | from utils import quantize_f0_numpy
13 | 
14 | 
15 | class VCDataset(torch.utils.data.Dataset):
16 | """
17 | 1) globs preprocessed npz files (train or test split) for every speaker
18 | 2) loads the mel-spectrogram, normalized F0, and speaker label from each npz
19 | 3) converts the speaker label to a one-hot vector
20 | """
21 | def __init__(self, data_dir, hparams, is_train=True):
22 | self.spk_path = [p for p in glob.glob(os.path.join(data_dir, '*')) if os.path.isdir(p)]
23 | self.is_train = is_train
24 | self.npz_path, self.metadata = self.get_npz_path(self.spk_path)
25 | self.n_speakers = hparams.n_speakers
26 | 
27 | assert len(self.npz_path) > 0, "failed to find any npz files"
28 | random.seed(1234)
29 | random.shuffle(self.npz_path)
30 | 
31 | def get_npz_path(self, spk_path):
32 | metadata = {}
33 | npz_path = []
34 | for spk in spk_path:
35 | if self.is_train:
36 | spk_npz = glob.glob(os.path.join(spk, "train", "*.npz"))
37 | else:
38 | spk_npz = glob.glob(os.path.join(spk, "test", "*.npz"))
39 | npz_path += spk_npz
40 | 
41 | return npz_path, metadata
42 | 
43 | def get_sample(self, npz_path):
44 | # load one preprocessed npz sample
45 | npz = np.load(npz_path)
46 | mel = npz['mel'].T if npz['mel'].shape[0] == 80 else npz['mel']
47 | f0 = npz['f0']
48 | spk_label = self.get_speaker(npz['spk_label'].item())
49 | 
50 | return (mel, f0, spk_label)
51 | 
52 | def get_speaker(self, speaker):
53 | speaker_vector = np.zeros(self.n_speakers)
54 | speaker_vector[int(speaker)] = 1
55 | return speaker_vector.astype(dtype=np.float32)
56 | 
57 | 
58 | def __getitem__(self, index):
59 | return self.get_sample(self.npz_path[index])
60 | 
61 | def __len__(self):
62 | return len(self.npz_path)
63 | 
64 | class AutoVCCollate():
65 | def __init__(self, hparams):
66 | self.seq_len = hparams.seq_len
67 | 
68 | def __call__(self, batch):
69 | # batch : B * (mel, f0, spk)
70 | mels = [b[0] for b in batch]
71 | f0s = [b[1] for b in batch]
72 | spk = [b[2] for b in batch]
73 | 
74 | 
75 | mel_seg = []
76 | f0_seg = []
77 | speaker_embeddings = []
78 | for mel, f0, spk_emb in zip(mels, f0s, spk):
79 | frame_len = mel.shape[0]
80 | if frame_len < self.seq_len:
81 | len_pad = self.seq_len - frame_len
82 | x = np.pad(mel, ((0, len_pad), (0, 0)), 'constant')
83 | p = np.pad(f0, ((0, len_pad)), 'constant', constant_values=-1e10)
84 | else:
85 | start = np.random.randint(frame_len - self.seq_len + 1)
86 | x = mel[start:start + self.seq_len]
87 | p = f0[start:start + self.seq_len]
88 | 
89 | quantized_p, _ = quantize_f0_numpy(p, num_bins=hparams.pitch_bin)
90 | 
91 | mel_seg.append(x)
92 | f0_seg.append(quantized_p)
93 | speaker_embeddings.append(spk_emb)
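# quantize_f0_numpy above maps the normalized log-F0 in [0, 1] to a one-hot code
# over pitch_bin + 1 classes: bin 0 is reserved for unvoiced/padded frames
# (f0 <= 0, e.g. the -1e10 padding value), and bins 1..pitch_bin cover the voiced range.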
94 | 
95 | 
96 | out = {"mel": torch.FloatTensor(mel_seg),
97 | "quantized_p": torch.FloatTensor(f0_seg),
98 | "spk": torch.FloatTensor(speaker_embeddings),
99 | }
100 | 
101 | return out
102 | 
103 | class VCTestSet:
104 | def __init__(self, data_dir, hparams):
105 | self.n_speakers = hparams.n_speakers
106 | self.spk_path = glob.glob(os.path.join(data_dir, '*'))
107 | assert len(self.spk_path) > 0, "failed to find any speakers"
108 | 
109 | def get_random_pair(self):
110 | random.seed(1234)
111 | random.shuffle(self.spk_path)
112 | src_spk_path = self.spk_path[0]
113 | trg_spk_path = self.spk_path[1]
114 | 
115 | src_spk_name = os.path.basename(src_spk_path)
116 | trg_spk_name = os.path.basename(trg_spk_path)
117 | 
118 | src_npz_path = self.get_first_npz(src_spk_path)
119 | trg_npz_path = self.get_first_npz(trg_spk_path)
120 | 
121 | return (src_spk_name, src_npz_path, trg_spk_name, trg_npz_path)
122 | 
123 | def get_random_npz(self, spk_path):
124 | npz_path = glob.glob(os.path.join(spk_path, "test/*.npz"))
125 | idx = np.random.randint(0, len(npz_path))
126 | 
127 | return npz_path[idx]
128 | 
129 | def get_first_npz(self, spk_path):
130 | npz_path = glob.glob(os.path.join(spk_path, "test/*.npz"))
131 | 
132 | return npz_path[0]
133 | 
134 | def parse_npz(self, npz_path):
135 | npz = np.load(npz_path)
136 | mel = torch.from_numpy(npz['mel'])
137 | f0 = torch.from_numpy(npz['f0'])
138 | spk = torch.from_numpy(self.get_speaker(npz['spk_label'].item()))
139 | 
140 | mel = mel.unsqueeze(0).float()
141 | f0 = f0.unsqueeze(0).float()
142 | spk = spk.unsqueeze(0).float()
143 | 
144 | return (mel, f0, spk)
145 | 
146 | def get_speaker(self, speaker):
147 | speaker_vector = np.zeros(self.n_speakers)
148 | speaker_vector[int(speaker)] = 1
149 | return speaker_vector.astype(dtype=np.float32)
150 | 
151 | 
152 | def prepare_dataloaders(data_path, hparams):
153 | # Get data, data loaders and collate function ready
154 | trainset = VCDataset(data_path, hparams, is_train=True)
155 | valset = VCDataset(data_path, hparams, is_train=False)
156 | collate_fn = AutoVCCollate(hparams)
157 | 
158 | train_loader = DataLoader(trainset, num_workers=1, shuffle=True,
159 | batch_size=hparams.batch_size,
160 | drop_last=True, collate_fn=collate_fn)
161 | val_loader = DataLoader(valset, num_workers=1, shuffle=False,
162 | batch_size=hparams.val_batch_size,
163 | drop_last=True, collate_fn=collate_fn)
164 | test_set = VCTestSet(data_path, hparams)
165 | 
166 | return train_loader, val_loader, test_set
167 | 
168 | if __name__ == "__main__":
169 | train_loader, val_loader, test_set = prepare_dataloaders("/hd0/f0-autovc/preprocessed/sr16000_npz", hparams)
170 | train_iter = iter(train_loader)
171 | val_iter = iter(val_loader)
172 | 
173 | out = train_iter.__next__()
174 | """
175 | out = {"mel": torch.FloatTensor(mel_seg),
176 | "quantized_p": torch.FloatTensor(f0_seg),
177 | "spk": torch.FloatTensor(speaker_embeddings),
178 | }
179 | mel: (B, seq_len, num_mels), quantized_p: (B, seq_len, pitch_bin + 1),
180 | spk: (B, n_speakers)
181 | 
182 | """
183 | 
184 | print(out["mel"].size())
185 | print(out["quantized_p"].size())
186 | print(out["spk"].size())
187 | 
--------------------------------------------------------------------------------
/architectures/arch_autovc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | from hparams import hparams
7 | 
8 | 
9 | class LinearNorm(torch.nn.Module):
10 | def 
__init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 11 | super(LinearNorm, self).__init__() 12 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 13 | 14 | torch.nn.init.xavier_uniform_( 15 | self.linear_layer.weight, 16 | gain=torch.nn.init.calculate_gain(w_init_gain)) 17 | 18 | def forward(self, x): 19 | return self.linear_layer(x) 20 | 21 | 22 | class ConvNorm(torch.nn.Module): 23 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 24 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 25 | super(ConvNorm, self).__init__() 26 | if padding is None: 27 | assert (kernel_size % 2 == 1) 28 | padding = int(dilation * (kernel_size - 1) / 2) 29 | 30 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 31 | kernel_size=kernel_size, stride=stride, 32 | padding=padding, dilation=dilation, 33 | bias=bias) 34 | 35 | torch.nn.init.xavier_uniform_( 36 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 37 | 38 | def forward(self, signal): 39 | conv_signal = self.conv(signal) 40 | return conv_signal 41 | 42 | 43 | class Encoder(nn.Module): 44 | """Encoder module: 45 | """ 46 | 47 | def __init__(self, dim_neck, dim_emb, freq): 48 | super(Encoder, self).__init__() 49 | self.dim_neck = dim_neck 50 | self.freq = freq 51 | 52 | convolutions = [] 53 | for i in range(3): 54 | conv_layer = nn.Sequential( 55 | ConvNorm(80 + dim_emb if i == 0 else 512, 56 | 512, 57 | kernel_size=5, stride=1, 58 | padding=2, 59 | dilation=1, w_init_gain='relu'), 60 | nn.BatchNorm1d(512)) 61 | convolutions.append(conv_layer) 62 | self.convolutions = nn.ModuleList(convolutions) 63 | 64 | self.lstm = nn.LSTM(512, dim_neck, 2, batch_first=True, bidirectional=True) 65 | 66 | def forward(self, x, c_org): 67 | x = x.squeeze(1).transpose(2, 1) 68 | c_org = c_org.unsqueeze(-1).expand(-1, -1, x.size(-1)) 69 | x = torch.cat((x, c_org), dim=1) 70 | 71 | for conv in self.convolutions: 72 | x = F.relu(conv(x)) 73 | x = x.transpose(1, 2) 74 | 75 | self.lstm.flatten_parameters() 76 | outputs, _ = self.lstm(x) 77 | out_forward = outputs[:, :, :self.dim_neck] 78 | out_backward = outputs[:, :, self.dim_neck:] 79 | 80 | codes = [] 81 | for i in range(0, outputs.size(1), self.freq): 82 | codes.append(torch.cat((out_forward[:, i + self.freq - 1, :], out_backward[:, i, :]), dim=-1)) 83 | 84 | return codes 85 | 86 | 87 | class Decoder(nn.Module): 88 | """Decoder module: 89 | """ 90 | 91 | def __init__(self, dim_neck, dim_emb, dim_pre): 92 | super(Decoder, self).__init__() 93 | 94 | self.lstm1 = nn.LSTM(dim_neck * 2 + dim_emb, dim_pre, 1, batch_first=True) 95 | 96 | convolutions = [] 97 | for i in range(3): 98 | conv_layer = nn.Sequential( 99 | ConvNorm(dim_pre, 100 | dim_pre, 101 | kernel_size=5, stride=1, 102 | padding=2, 103 | dilation=1, w_init_gain='relu'), 104 | nn.BatchNorm1d(dim_pre)) 105 | convolutions.append(conv_layer) 106 | self.convolutions = nn.ModuleList(convolutions) 107 | 108 | self.lstm2 = nn.LSTM(dim_pre, 1024, 2, batch_first=True) 109 | 110 | self.linear_projection = LinearNorm(1024, 80) 111 | 112 | def forward(self, x): 113 | 114 | # self.lstm1.flatten_parameters() 115 | x, _ = self.lstm1(x) 116 | x = x.transpose(1, 2) 117 | 118 | for conv in self.convolutions: 119 | x = F.relu(conv(x)) 120 | x = x.transpose(1, 2) 121 | 122 | outputs, _ = self.lstm2(x) 123 | 124 | decoder_output = self.linear_projection(outputs) 125 | 126 | return decoder_output 127 | 128 | 129 | class Postnet(nn.Module): 130 | """Postnet 131 | - Five 1-d convolution with 512 channels and 
132 |     """
133 | 
134 |     def __init__(self):
135 |         super(Postnet, self).__init__()
136 |         self.convolutions = nn.ModuleList()
137 | 
138 |         self.convolutions.append(
139 |             nn.Sequential(
140 |                 ConvNorm(80, 512,
141 |                          kernel_size=5, stride=1,
142 |                          padding=2,
143 |                          dilation=1, w_init_gain='tanh'),
144 |                 nn.BatchNorm1d(512))
145 |         )
146 | 
147 |         for i in range(1, 5 - 1):
148 |             self.convolutions.append(
149 |                 nn.Sequential(
150 |                     ConvNorm(512,
151 |                              512,
152 |                              kernel_size=5, stride=1,
153 |                              padding=2,
154 |                              dilation=1, w_init_gain='tanh'),
155 |                     nn.BatchNorm1d(512))
156 |             )
157 | 
158 |         self.convolutions.append(
159 |             nn.Sequential(
160 |                 ConvNorm(512, 80,
161 |                          kernel_size=5, stride=1,
162 |                          padding=2,
163 |                          dilation=1, w_init_gain='linear'),
164 |                 nn.BatchNorm1d(80))
165 |         )
166 | 
167 |     def forward(self, x):
168 |         for i in range(len(self.convolutions) - 1):
169 |             x = torch.tanh(self.convolutions[i](x))
170 | 
171 |         x = self.convolutions[-1](x)
172 | 
173 |         return x
174 | 
175 | 
176 | class Generator(nn.Module):
177 |     """Generator network."""
178 | 
179 |     def __init__(self, hparams):
180 |         super(Generator, self).__init__()
181 |         self.dim_neck = hparams.dim_neck
182 |         self.dim_emb = hparams.dim_emb
183 |         self.dim_pre = hparams.dim_pre
184 |         self.freq = hparams.freq
185 | 
186 |         self.encoder = Encoder(self.dim_neck, self.dim_emb, self.freq)
187 |         self.decoder = Decoder(self.dim_neck, self.dim_emb, self.dim_pre)
188 |         self.postnet = Postnet()
189 | 
190 |     def forward(self, x, c_org, c_trg):
191 | 
192 |         codes = self.encoder(x, c_org)
193 |         if c_trg is None:  # encoder-only mode, used for the code reconstruction loss
194 |             return torch.cat(codes, dim=-1)
195 | 
196 |         tmp = []
197 |         for code in codes:
198 |             tmp.append(code.unsqueeze(1).expand(-1, int(x.size(1) / len(codes)), -1))  # upsample codes back to frame rate
199 |         code_exp = torch.cat(tmp, dim=1)
200 | 
201 |         encoder_outputs = torch.cat((code_exp, c_trg.unsqueeze(1).expand(-1, x.size(1), -1)), dim=-1)
202 | 
203 |         mel_outputs = self.decoder(encoder_outputs)
204 | 
205 |         mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
206 |         mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)  # postnet predicts a residual
207 | 
208 |         mel_outputs = mel_outputs.unsqueeze(1)
209 |         mel_outputs_postnet = mel_outputs_postnet.unsqueeze(1)
210 | 
211 |         return mel_outputs, mel_outputs_postnet, torch.cat(codes, dim=-1)
212 | 
213 | def get_network(hparams, **kwargs):
214 |     gen = Generator(hparams)
215 | 
216 |     networks = {'net': gen}
217 |     return networks
218 | 
--------------------------------------------------------------------------------
/solver/autovc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from collections import OrderedDict
5 | from torch import optim
6 | from itertools import chain
7 | from solver.base import Base
8 | import argparse
9 | from hparams import hparams
10 | from utils import prepare_dirs
11 | from data_utils import prepare_dataloaders
12 | from plotting_utils import plot_spectrogram_to_numpy, plot_f0_to_numpy
13 | import importlib
14 | import os
15 | 
16 | class AutoVC(Base):
17 |     def __init__(self,
18 |                  architecture,
19 |                  model_dir,
20 |                  log_dir,
21 |                  sample_dir):
22 |         super(AutoVC, self).__init__(model_dir, log_dir, sample_dir)
23 | 
24 |         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
25 |         #self.device = 'cpu'
26 | 
27 |         # load architecture
28 |         arch = importlib.import_module(architecture)
29 | 
30 |         nets = arch.get_network(hparams)
31 |         self.net = nets['net'].to(self.device)
32 | 
33 |         opt_args = {'lr': hparams.initial_learning_rate, 'betas': (hparams.adam_beta1, hparams.adam_beta2)}
34 |         g_params = self.net.parameters()
35 |         optim_g = optim.Adam(filter(lambda p: p.requires_grad, g_params), **opt_args)
36 | 
37 |         self.optim = {
38 |             'g': optim_g
39 |         }
40 | 
41 |         self.last_epoch = 0
42 |         self.load_strict = True
43 | 
44 |     def _get_stats(self, dict_, mode):
45 |         stats = OrderedDict({})
46 |         for key in dict_.keys():
47 |             stats[key] = np.mean(dict_[key])
48 |         return stats
49 | 
50 |     def reset_grad(self):
51 |         """Reset the gradient buffers."""
52 |         self.optim['g'].zero_grad()
53 | 
54 |     def _train(self):
55 |         self.net.train()
56 | 
57 |     def _eval(self):
58 |         self.net.eval()
59 | 
60 |     def train_on_instance(self,
61 |                           x_real,
62 |                           p_real,
63 |                           emb_org,
64 |                           **kwargs):
65 |         self._train()
66 |         # Identity mapping loss
67 |         x_identic, x_identic_psnt, code_real = self.net(x_real, emb_org, emb_org)
68 |         g_loss_id = F.mse_loss(x_real, x_identic)
69 |         g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt)
70 | 
71 |         # Code semantic loss.
72 |         code_reconst = self.net(x_identic_psnt, emb_org, None)
73 |         g_loss_cd = F.l1_loss(code_real, code_reconst)
74 | 
75 |         # Backward and optimize.
76 |         g_loss = g_loss_id + g_loss_id_psnt + g_loss_cd
77 |         self.reset_grad()
78 |         g_loss.backward()
79 |         self.optim['g'].step()
80 | 
81 |         ## ----------------------------------------------
82 |         ## Collecting losses and outputs
83 |         ## ----------------------------------------------
84 |         losses = {
85 |             'G/loss_id': g_loss_id.item(),
86 |             'G/loss_id_psnt': g_loss_id_psnt.item(),
87 |             'G/loss_cd': g_loss_cd.item()
88 |         }
89 | 
90 |         outputs = {
91 |             'GT': x_real.detach().cpu(),
92 |             'recon': x_identic_psnt.detach().cpu(),
93 |         }
94 | 
95 |         return losses, outputs
96 | 
97 |     def eval_on_instance(self,
98 |                          x_real,
99 |                          p_real,
100 |                          emb_org,
101 |                          **kwargs):
102 |         self._eval()
103 |         with torch.no_grad():
104 |             # Identity mapping loss
105 |             x_identic, x_identic_psnt, code_real = self.net(x_real, emb_org, emb_org)
106 |             g_loss_id = F.mse_loss(x_real, x_identic)
107 |             g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt)
108 | 
109 |             # Code semantic loss.
110 |             code_reconst = self.net(x_identic_psnt, emb_org, None)
111 |             g_loss_cd = F.l1_loss(code_real, code_reconst)
112 | 
113 |             # Total loss (no backward pass during evaluation).
114 |             g_loss = g_loss_id + g_loss_id_psnt + g_loss_cd
115 | 
116 |         ## ----------------------------------------------
117 |         ## Collecting losses and outputs
118 |         ## ----------------------------------------------
119 |         losses = {
120 |             'G/loss_id': g_loss_id.item(),
121 |             'G/loss_id_psnt': g_loss_id_psnt.item(),
122 |             'G/loss_cd': g_loss_cd.item()
123 |         }
124 | 
125 |         outputs = {
126 |             'GT': x_real.detach().cpu(),
127 |             'recon': x_identic_psnt.detach().cpu(),
128 |         }
129 | 
130 |         return losses, outputs
131 | 
132 |     def prepare_batch(self, batch):
133 |         if len(batch) != 3:
134 |             raise Exception("Expected batch to have three elements: " +
135 |                             "mel, quantized_p, spk")
136 | 
137 |         x = batch["mel"].to(self.device)
138 |         quantized_p = batch["quantized_p"].to(self.device)
139 |         spk = batch["spk"].to(self.device)
140 | 
141 |         return [x, quantized_p, spk]
142 | 
143 |     def save(self, filename, epoch):
144 |         dd = {}
145 |         # Save the models.
146 |         dd['net'] = self.net.state_dict()
147 |         # Save the models' optim state.
148 |         for key in self.optim:
149 |             dd['optim_%s' % key] = self.optim[key].state_dict()
150 |         dd['epoch'] = epoch
151 |         dd['global_epoch'] = self.global_epoch
152 |         dd['global_step'] = self.global_step
153 |         torch.save(dd, filename)
154 | 
155 |     def load(self, filename):
156 |         # Map checkpoint tensors onto the solver's device so that a
157 |         # checkpoint saved on GPU can still be restored on a CPU-only
158 |         # machine (and vice versa).
159 |         dd = torch.load(filename, map_location=self.device)
160 | 
161 |         # Load the models.
162 |         self.net.load_state_dict(dd['net'], strict=self.load_strict)
163 | 
164 |         # Load the models' optim state.
165 |         for key in self.optim:
166 |             self.optim[key].load_state_dict(dd['optim_%s' % key])
167 |         self.last_epoch = dd['epoch']
168 |         self.global_epoch = dd['global_epoch']
169 |         self.global_step = dd['global_step']
170 | 
171 |     def summary(self, outputs, epoch):
172 |         self.logger.image_summary("mel-spectrogram",
173 |                                   plot_spectrogram_to_numpy("ground-truth", outputs['GT'][0].numpy().T),
174 |                                   epoch)
175 |         self.logger.image_summary("mel-spectrogram",
176 |                                   plot_spectrogram_to_numpy("reconstruction", outputs['recon'][0].numpy().T),
177 |                                   epoch)
178 | 
179 | 
180 | if __name__=="__main__":
181 |     def parse_args():
182 |         parser = argparse.ArgumentParser(description="")
183 |         parser.add_argument('--name', type=str, default="test")
184 |         parser.add_argument('--save_path', type=str, default="/hd0/voice_mixer/VAEAutoVC/")
185 |         parser.add_argument('--data_dir', type=str, default='/hd0/voice_mixer/preprocessed/VCTK20_f0_norm_all/seen')
186 |         args = parser.parse_args()
187 |         return args
188 | 
189 |     args = parse_args()
190 | 
191 |     model_dir, log_dir, sample_dir = prepare_dirs(args)
192 | 
193 |     train_loader, val_loader, testset = prepare_dataloaders(args.data_dir, hparams)
194 | 
195 |     # AutoVC (not the undefined FastVC) is the solver defined above; the
196 |     # architecture module path is assumed from the repository layout.
197 |     solver = AutoVC('architectures.arch_autovc',
198 |                     model_dir,
199 |                     log_dir,
200 |                     sample_dir)
201 | 
202 |     solver.train(
203 |         train_loader,
204 |         val_loader,
205 |         testset,
206 |         hparams.nepochs,
207 |         hparams.save_every,
208 |         verbose=True)
--------------------------------------------------------------------------------
/architectures/arch_autovc_f0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | 
7 | class LinearNorm(torch.nn.Module):
8 |     def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
9 |         super(LinearNorm, self).__init__()
10 |         self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
11 | 
12 |         torch.nn.init.xavier_uniform_(
13 |             self.linear_layer.weight,
14 |             gain=torch.nn.init.calculate_gain(w_init_gain))
15 | 
16 |     def forward(self, x):
17 |         return self.linear_layer(x)
18 | 
19 | 
20 | class ConvNorm(torch.nn.Module):
21 |     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
22 |                  padding=None, dilation=1, bias=True, w_init_gain='linear'):
23 |         super(ConvNorm, self).__init__()
24 |         if padding is None:
25 |             assert (kernel_size % 2 == 1)
26 |             padding = int(dilation * (kernel_size - 1) / 2)
27 | 
28 |         self.conv = torch.nn.Conv1d(in_channels, out_channels,
29 |                                     kernel_size=kernel_size, stride=stride,
30 |                                     padding=padding, dilation=dilation,
31 |                                     bias=bias)
32 | 
33 |         torch.nn.init.xavier_uniform_(
34 |             self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
35 | 
36 |     def forward(self, signal):
37 |         conv_signal = self.conv(signal)
38 |         return conv_signal
39 | 
40 | 
41 | class Encoder(nn.Module):
42 |     """Encoder module:
43 |     """
44 | 
45 |     def __init__(self, dim_neck, dim_emb, freq):
46 |         super(Encoder, self).__init__()
47 |         self.dim_neck = dim_neck
48 |         self.freq = freq
49 | 
50 |         convolutions = []
51 |         for i in range(3):
52 |             conv_layer = nn.Sequential(
53 |                 ConvNorm(80 + dim_emb if i == 0 else 512,
54 |                          512,
55 |                          kernel_size=5, stride=1,
56 |                          padding=2,
57 |                          dilation=1, w_init_gain='relu'),
58 |                 nn.BatchNorm1d(512))
59 |             convolutions.append(conv_layer)
60 |         self.convolutions = nn.ModuleList(convolutions)
61 | 
62 |         self.lstm = nn.LSTM(512, dim_neck, 2, batch_first=True, bidirectional=True)
63 | 
64 |     def forward(self, x, c_org):
65 |         x = x.squeeze(1).transpose(2, 1)
66 |         c_org = c_org.unsqueeze(-1).expand(-1, -1, x.size(-1))
67 |         x = torch.cat((x, c_org), dim=1)
68 | 
69 |         for conv in self.convolutions:
70 |             x = F.relu(conv(x))
71 |         x = x.transpose(1, 2)
72 | 
73 |         self.lstm.flatten_parameters()
74 |         outputs, _ = self.lstm(x)
75 |         out_forward = outputs[:, :, :self.dim_neck]
76 |         out_backward = outputs[:, :, self.dim_neck:]
77 | 
78 |         codes = []
79 |         for i in range(0, outputs.size(1), self.freq):
80 |             codes.append(torch.cat((out_forward[:, i + self.freq - 1, :], out_backward[:, i, :]), dim=-1))
81 | 
82 |         return codes
83 | 
84 | 
85 | class Decoder(nn.Module):
86 |     """Decoder module:
87 |     """
88 | 
89 |     def __init__(self, dim_neck, dim_emb, dim_pitch, dim_pre):
90 |         super(Decoder, self).__init__()
91 | 
92 |         self.lstm1 = nn.LSTM(dim_neck * 2 + dim_emb + (dim_pitch+1), dim_pre, 1, batch_first=True)
93 | 
94 |         convolutions = []
95 |         for i in range(3):
96 |             conv_layer = nn.Sequential(
97 |                 ConvNorm(dim_pre,
98 |                          dim_pre,
99 |                          kernel_size=5, stride=1,
100 |                          padding=2,
101 |                          dilation=1, w_init_gain='relu'),
102 |                 nn.BatchNorm1d(dim_pre))
103 |             convolutions.append(conv_layer)
104 |         self.convolutions = nn.ModuleList(convolutions)
105 | 
106 |         self.lstm2 = nn.LSTM(dim_pre, 1024, 2, batch_first=True)
107 | 
108 |         self.linear_projection = LinearNorm(1024, 80)
109 | 
110 |     def forward(self, x):
111 | 
112 |         # self.lstm1.flatten_parameters()
113 |         x, _ = self.lstm1(x)
114 |         x = x.transpose(1, 2)
115 | 
116 |         for conv in self.convolutions:
117 |             x = F.relu(conv(x))
118 |         x = x.transpose(1, 2)
119 | 
120 |         outputs, _ = self.lstm2(x)
121 | 
122 |         decoder_output = self.linear_projection(outputs)
123 | 
124 |         return decoder_output
125 | 
126 | 
127 | class Postnet(nn.Module):
128 |     """Postnet
129 |     - Five 1-d convolutions with 512 channels and kernel size 5
130 |     """
131 | 
132 |     def __init__(self):
133 |         super(Postnet, self).__init__()
134 |         self.convolutions = nn.ModuleList()
135 | 
136 |         self.convolutions.append(
137 |             nn.Sequential(
138 |                 ConvNorm(80, 512,
139 |                          kernel_size=5, stride=1,
140 |                          padding=2,
141 |                          dilation=1, w_init_gain='tanh'),
142 |                 nn.BatchNorm1d(512))
143 |         )
144 | 
145 |         for i in range(1, 5 - 1):
146 |             self.convolutions.append(
147 |                 nn.Sequential(
148 |                     ConvNorm(512,
149 |                              512,
150 |                              kernel_size=5, stride=1,
151 |                              padding=2,
152 |                              dilation=1, w_init_gain='tanh'),
153 |                     nn.BatchNorm1d(512))
154 |             )
155 | 
156 |         self.convolutions.append(
157 |             nn.Sequential(
158 |                 ConvNorm(512, 80,
159 |                          kernel_size=5, stride=1,
160 |                          padding=2,
161 |                          dilation=1, w_init_gain='linear'),
162 |                 nn.BatchNorm1d(80))
163 |         )
164 | 
165 |     def forward(self, x):
166 |         for i in range(len(self.convolutions) - 1):
167 |             x = torch.tanh(self.convolutions[i](x))
168 | 
169 |         x = self.convolutions[-1](x)
170 | 
171 |         return x
172 | 
173 | 
174 | class Generator(nn.Module):
175 |     """Generator network."""
176 | 
177 |     def __init__(self, hparams):
178 |         super(Generator, self).__init__()
179 |         self.dim_neck = hparams.dim_neck
180 |         self.dim_emb = hparams.dim_emb
181 |         self.dim_pre = hparams.dim_pre
182 |         self.pitch_bin = hparams.pitch_bin
183 |         self.freq = hparams.freq
184 | 
185 |         self.encoder = Encoder(self.dim_neck, self.dim_emb, self.freq)
186 |         self.decoder = Decoder(self.dim_neck, self.dim_emb, self.pitch_bin, self.dim_pre)
187 |         self.postnet = Postnet()
188 | 
189 |     def forward(self, x, f0_src, c_org, c_trg):
190 |         codes = self.encoder(x, c_org)
191 |         if c_trg is None:
192 |             return torch.cat(codes, dim=-1)
193 | 
194 |         tmp = []
195 |         for code in codes:
196 |             tmp.append(code.unsqueeze(1).expand(-1, int(x.size(1) / len(codes)), -1))
197 |         code_exp = torch.cat(tmp, dim=1)
198 | 
199 |         encoder_outputs = torch.cat((code_exp, c_trg.unsqueeze(1).expand(-1, x.size(1), -1), f0_src), dim=-1)
200 | 
201 |         mel_outputs = self.decoder(encoder_outputs)
202 | 
203 |         mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
204 |         mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)
205 | 
206 |         mel_outputs = mel_outputs.unsqueeze(1)
207 |         mel_outputs_postnet = mel_outputs_postnet.unsqueeze(1)
208 | 
209 |         return mel_outputs, mel_outputs_postnet, torch.cat(codes, dim=-1)
210 | 
211 | 
212 | def get_network(hparams, **kwargs):
213 |     gen = Generator(hparams)
214 | 
215 |     networks = {'net': gen}
216 |     return networks
--------------------------------------------------------------------------------
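
The F0-conditioned Generator above differs from the plain AutoVC one only in the extra f0_src tensor concatenated into the decoder input, so its forward pass can be exercised end to end with random tensors. The smoke test below is a minimal sketch, not part of the repository: the hyperparameter values are invented stand-ins for whatever hparams.py actually defines, the repository root is assumed to be on PYTHONPATH with architectures importable as a package, and f0_src is assumed to carry pitch_bin + 1 channels, mirroring the (dim_pitch + 1) term in the Decoder's first LSTM.

import torch
from types import SimpleNamespace

from architectures.arch_autovc_f0 import get_network

# Invented hyperparameters -- the real values live in hparams.py.
hp = SimpleNamespace(dim_neck=32, dim_emb=256, dim_pre=512,
                     pitch_bin=256, freq=32)
net = get_network(hp)['net']

B, T = 2, 128                             # T should be a multiple of freq
x = torch.randn(B, T, 80)                 # mel-spectrogram, 80 bins
f0 = torch.randn(B, T, hp.pitch_bin + 1)  # quantized F0, +1 channel as in the Decoder
spk = torch.randn(B, hp.dim_emb)          # speaker embedding

# Source speaker == target speaker reproduces the identity-mapping setup
# that train_on_instance() optimizes.
mel, mel_psnt, codes = net(x, f0, spk, spk)
print(mel.shape, mel_psnt.shape, codes.shape)
# (2, 1, 128, 80) (2, 1, 128, 80) (2, 256)

With freq=32 and T=128, the encoder emits four bottleneck codes of 2 * dim_neck = 64 units each, which is why the concatenated code vector comes out as (2, 256); passing None as the last argument instead returns only that code vector, which is how the solvers compute the code reconstruction loss.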