├── Configs
│   └── config.yml
├── LICENSE
├── README.md
├── meldataset.py
├── model.py
├── optimizers.py
├── train.py
└── trainer.py
--------------------------------------------------------------------------------
/Configs/config.yml:
--------------------------------------------------------------------------------
log_dir: "Checkpoint"
save_freq: 10
device: "cuda"
epochs: 100
batch_size: 64
pretrained_model: ""
train_data: "Data/train_list.txt"
val_data: "Data/val_list.txt"
num_workers: 16


optimizer_params:
  lr: 0.0003

loss_params:
  lambda_f0: 0.1

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Aaron (Yinghao) Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# JDC-PitchExtractor
This repo contains the training code for the deep neural pitch extractor for Voice Conversion (VC) and TTS used in [StarGANv2-VC](https://github.com/yl4579/StarGANv2-VC) and [StyleTTS](https://github.com/yl4579/StyleTTS). It is the F0 network in StarGANv2-VC and the pitch extractor in StyleTTS.

## Pre-requisites
1. Python >= 3.7
2. Clone this repository:
```bash
git clone https://github.com/yl4579/PitchExtractor.git
cd PitchExtractor
```
3. Install Python requirements:
```bash
pip install SoundFile torchaudio torch pyyaml click matplotlib librosa pyworld
```
4. Prepare your own dataset and put `train_list.txt` and `val_list.txt` in the `Data` folder (see the Training section for more details).

## Training
```bash
python train.py --config_path ./Configs/config.yml
```
Please specify the training and validation data in the `config.yml` file. Each line of the data list must follow the format `filename.wav|anything`; see [train_list.txt](https://github.com/yl4579/StarGANv2-VC/blob/main/Data/train_list.txt) (a subset of VCTK) for an example. You can put anything after the filename, because the training labels are generated ad hoc.
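
For illustration, a data list is just a plain text file with one utterance per line; the paths and speaker tags below are made up, and only the part before the `|` is used:
```
Data/p225/p225_001.wav|p225
Data/p225/p225_002.wav|p225
Data/p226/p226_001.wav|p226
```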

Checkpoints and Tensorboard logs will be saved in `log_dir`. To speed up training, you may want to make `batch_size` as large as your GPU memory allows.

### IMPORTANT: DATA FOLDER NEEDS WRITE PERMISSION
Since both `harvest` and `dio` are relatively slow, the computed F0 ground truth is saved for later use. [meldataset.py](https://github.com/yl4579/PitchExtractor/blob/main/meldataset.py#L77-L89) writes the computed F0 curve to a `_f0.npy` file next to each `.wav` file, so write permission in your data folder is required.

### F0 Computation Details
In [meldataset.py](https://github.com/yl4579/PitchExtractor/blob/main/meldataset.py#L83-L87), the F0 curves are computed using [PyWorld](https://github.com/JeremyCCHsu/Python-Wrapper-for-World-Vocoder), with either `harvest` or `dio`. Both methods are acoustic-based and are unstable under certain conditions. `harvest` is tried first; when it fails (detected by the number of frames with non-zero F0 values), the ground truth F0 labels are recomputed with `dio`. If `dio` also fails, the computed F0 contains `NaN` values, which are replaced with a constant placeholder (`zero_value` in `meldataset.py`) at load time. This is supposed to occur only occasionally and should not affect training, because such samples act as label noise, and deep learning models are known to tolerate, and even benefit from, slightly noisy datasets. However, if many of your samples have this problem (say > 5%), please remove them from the training set so that the model does not learn from the failed labels.
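
To estimate how many of your samples are affected, you can scan the cached `_f0.npy` files after one pass over the dataset. The snippet below is a minimal sketch, not part of this repo; the `Data` root is an assumption, and the 5-frame threshold mirrors `bad_F0` in `meldataset.py`:
```python
# scan_f0.py -- count cached F0 curves that look like pyworld failures
from pathlib import Path

import numpy as np

BAD_F0 = 5  # same threshold as MelDataset.bad_F0: fewer than 5 voiced frames

f0_files = sorted(Path("Data").rglob("*_f0.npy"))  # adjust the root to your dataset
bad_files = []
for f in f0_files:
    f0 = np.load(f)
    # voiced frames are those that are neither NaN nor exactly zero
    voiced_frames = np.count_nonzero(~np.isnan(f0) & (f0 != 0))
    if voiced_frames < BAD_F0:
        bad_files.append(f)

print("%d of %d cached F0 curves look like failures" % (len(bad_files), len(f0_files)))
for f in bad_files:
    print(f)  # consider removing the matching .wav from your data lists
```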

### Data Augmentation
Data augmentation is not included in this code. For better voice conversion results, please add your own data augmentation in [meldataset.py](https://github.com/yl4579/PitchExtractor/blob/main/meldataset.py) with [audiomentations](https://github.com/iver56/audiomentations), for example along the lines sketched below.
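
A minimal, hypothetical sketch of such a hook (audiomentations must be installed separately; only pitch-preserving transforms are safe here, since the F0 labels are computed from the clean waveform and cached):
```python
# hypothetical augmentation hook for MelDataset: pitch-preserving transforms only,
# because the F0 ground truth is computed from the clean waveform and cached on disk
import numpy as np
import torch
from audiomentations import AddGaussianNoise, Compose, PolarityInversion

augment = Compose([
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    PolarityInversion(p=0.5),
])

def augment_wave(wave_tensor: torch.Tensor, sr: int) -> torch.Tensor:
    """Call inside MelDataset.path_to_mel_and_label after the F0 curve is loaded."""
    wave = augment(samples=wave_tensor.numpy().astype(np.float32), sample_rate=sr)
    return torch.from_numpy(wave).float()
```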

## References
- [keums/melodyExtraction_JDC](https://github.com/keums/melodyExtraction_JDC)
- [kan-bayashi/ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN)
--------------------------------------------------------------------------------
/meldataset.py:
--------------------------------------------------------------------------------
#coding: utf-8
"""
TODO:
- make TestDataset
- separate transforms
"""

import os
import os.path as osp
import time
import random
import numpy as np
import soundfile as sf
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader

import pyworld as pw

import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

np.random.seed(1)
random.seed(1)

SPECT_PARAMS = {
    "n_fft": 2048,
    "win_length": 1200,
    "hop_length": 300
}
MEL_PARAMS = {
    "n_mels": 80,
    "n_fft": 2048,
    "win_length": 1200,
    "hop_length": 300
}

class MelDataset(torch.utils.data.Dataset):
    def __init__(self,
                 data_list,
                 sr=24000,
                 data_augmentation=False,
                 validation=False,
                 verbose=True
                 ):

        # strip the trailing newline and keep only the path before '|'
        _data_list = [l[:-1].split('|') for l in data_list]
        self.data_list = [d[0] for d in _data_list]

        self.sr = sr
        self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)

        self.mean, self.std = -4, 4
        self.data_augmentation = data_augmentation and (not validation)
        self.max_mel_length = 192

        self.verbose = verbose

        # for silence detection
        self.zero_value = -10  # value substituted for NaN (failed) F0 frames
        self.bad_F0 = 5  # if fewer than 5 frames are non-zero, the F0 curve is considered failed; try another algorithm

    def __len__(self):
        return len(self.data_list)

    def path_to_mel_and_label(self, path):
        wave_tensor = self._load_tensor(path)

        # use pyworld to get F0
        output_file = path + "_f0.npy"
        # check if the file exists
        if os.path.isfile(output_file):  # if it exists, load it directly
            f0 = np.load(output_file)
        else:  # if not, compute F0 and cache it
            if self.verbose:
                print('Computing F0 for ' + path + '...')
            x = wave_tensor.numpy().astype("double")
            frame_period = MEL_PARAMS['hop_length'] * 1000 / self.sr
            _f0, t = pw.harvest(x, self.sr, frame_period=frame_period)
            if sum(_f0 != 0) < self.bad_F0:  # this happens when the algorithm fails
                _f0, t = pw.dio(x, self.sr, frame_period=frame_period)  # if harvest fails, try dio
            f0 = pw.stonemask(x, _f0, t, self.sr)
            # save the F0 info for later use
            np.save(output_file, f0)

        f0 = torch.from_numpy(f0).float()

        if self.data_augmentation:
            random_scale = 0.5 + 0.5 * np.random.random()
            wave_tensor = random_scale * wave_tensor

        mel_tensor = self.to_melspec(wave_tensor)
        mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mean) / self.std
        mel_length = mel_tensor.size(1)

        f0_zero = (f0 == 0)

        #######################################
        # You may want your own silence labels here
        # The more accurate the labels, the better the results
        is_silence = torch.zeros(f0.shape)
        is_silence[f0_zero] = 1
        #######################################

        if mel_length > self.max_mel_length:
            random_start = np.random.randint(0, mel_length - self.max_mel_length)
            mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]
            f0 = f0[random_start:random_start + self.max_mel_length]
            is_silence = is_silence[random_start:random_start + self.max_mel_length]

        if torch.any(torch.isnan(f0)):  # F0 computation failed
            f0[torch.isnan(f0)] = self.zero_value  # replace NaN values with zero_value

        return mel_tensor, f0, is_silence


    def __getitem__(self, idx):
        data = self.data_list[idx]
        mel_tensor, f0, is_silence = self.path_to_mel_and_label(data)
        return mel_tensor, f0, is_silence

    def _load_tensor(self, data):
        wave_path = data
        wave, sr = sf.read(wave_path)
        wave_tensor = torch.from_numpy(wave).float()
        return wave_tensor

class Collater(object):
    """
    Pads each item in a batch to max_mel_length along the time axis.

    Args:
        return_wave (bool): if true, also return the raw waveforms (unused here).
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        self.return_wave = return_wave
        self.min_mel_length = 192
        self.max_mel_length = 192
        self.mel_length_step = 16
        self.latent_dim = 16

    def __call__(self, batch):
        # each batch item is (mel, f0, is_silence)
        batch_size = len(batch)
        nmels = batch[0][0].size(0)
        mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
        f0s = torch.zeros((batch_size, self.max_mel_length)).float()
        is_silences = torch.zeros((batch_size, self.max_mel_length)).float()

        for bid, (mel, f0, is_silence) in enumerate(batch):
            mel_size = mel.size(1)
            mels[bid, :, :mel_size] = mel
            f0s[bid, :mel_size] = f0
            is_silences[bid, :mel_size] = is_silence

        if self.max_mel_length > self.min_mel_length:
            # dead code with the current settings (max_mel_length == min_mel_length)
            random_slice = np.random.randint(
                self.min_mel_length//self.mel_length_step,
                1+self.max_mel_length//self.mel_length_step) * self.mel_length_step + self.min_mel_length
            mels = mels[:, :, :random_slice]
            f0s = f0s[:, :random_slice]
            is_silences = is_silences[:, :random_slice]

        mels = mels.unsqueeze(1)
        return mels, f0s, is_silences


def build_dataloader(path_list,
                     validation=False,
                     batch_size=4,
                     num_workers=1,
                     device='cpu',
                     collate_config={},
                     dataset_config={}):

    dataset = MelDataset(path_list, validation=validation, **dataset_config)
    collate_fn = Collater(**collate_config)

    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=(not validation),
                             num_workers=num_workers,
                             drop_last=(not validation),
                             collate_fn=collate_fn,
                             pin_memory=(device != 'cpu'))

    return data_loader
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
"""
Implementation of model from:
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
Convolutional Recurrent Neural Networks" (2019)
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
"""
import torch
from torch import nn


class JDCNet(nn.Module):
    """
    Joint Detection and Classification Network model for singing voice melody.
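
    As used in this repo: the input is a batch of log-mel spectrograms of shape
    (batch, 1, time, 80), and the network is instantiated with num_class=1, so
    the classifier head regresses one F0 value per frame while the detector
    head produces one voicing logit per frame.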
    """
    def __init__(self, num_class=722, leaky_relu_slope=0.01):
        super().__init__()
        self.num_class = num_class

        # input: (b, 1, seq_len, 80), b = batch size
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False),  # out: (b, 64, seq_len, 80)
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),  # (b, 64, seq_len, 80)
        )

        # res blocks
        self.res_block1 = ResBlock(in_channels=64, out_channels=128)  # (b, 128, seq_len, 40)
        self.res_block2 = ResBlock(in_channels=128, out_channels=192)  # (b, 192, seq_len, 20)
        self.res_block3 = ResBlock(in_channels=192, out_channels=256)  # (b, 256, seq_len, 10)

        # pool block
        self.pool_block = nn.Sequential(
            nn.BatchNorm2d(num_features=256),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 4)),  # (b, 256, seq_len, 2)
            nn.Dropout(p=0.5),
        )

        # maxpool layers (for auxiliary network inputs)
        # in: (b, 64, seq_len, 80) from conv_block, out: (b, 64, seq_len, 2)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
        # in: (b, 128, seq_len, 40) from res_block1, out: (b, 128, seq_len, 2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
        # in: (b, 192, seq_len, 20) from res_block2, out: (b, 192, seq_len, 2)
        self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))

        # in: (b, 640, seq_len, 2), out: (b, 256, seq_len, 2)
        self.detector_conv = nn.Sequential(
            nn.Conv2d(640, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Dropout(p=0.5),
        )

        # input: (b, seq_len, 512) - reshaped from (b, 256, seq_len, 2)
        self.bilstm_classifier = nn.LSTM(
            input_size=512, hidden_size=256,
            batch_first=True, dropout=0.3, bidirectional=True)  # (b, seq_len, 512)

        # input: (b, seq_len, 512) - reshaped from (b, 256, seq_len, 2)
        self.bilstm_detector = nn.LSTM(
            input_size=512, hidden_size=256,
            batch_first=True, dropout=0.3, bidirectional=True)  # (b, seq_len, 512)

        # input: (b * seq_len, 512)
        self.classifier = nn.Linear(in_features=512, out_features=self.num_class)  # (b * seq_len, num_class)

        # input: (b * seq_len, 512)
        self.detector = nn.Linear(in_features=512, out_features=2)  # (b * seq_len, 2) - binary classifier

        # initialize weights
        self.apply(self.init_weights)

    def forward(self, x):
        """
        Returns:
            classification_prediction, detection_prediction
            sizes: (b, seq_len, num_class), (b, seq_len)
        """
        seq_len = x.shape[-2]
        ###############################
        # forward pass for classifier #
        ###############################
        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block(resblock3_out)

        # (b, 256, seq_len, 2) => (b, seq_len, 256, 2) => (b, seq_len, 512)
        classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
        classifier_out, _ = self.bilstm_classifier(classifier_out)  # ignore the hidden states

        classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * seq_len, 512)
        classifier_out = self.classifier(classifier_out)
        classifier_out = classifier_out.view((-1, seq_len, self.num_class))  # (b, seq_len, num_class)

        #############################
        # forward pass for detector #
        #############################
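        # the detector consumes multi-scale features: the pooled outputs of the
        # first conv block and of res_block1/res_block2, concatenated with
        # pool_block's output along the channel axis (64 + 128 + 192 + 256 = 640)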
        mp1_out = self.maxpool1(convblock_out)
        mp2_out = self.maxpool2(resblock1_out)
        mp3_out = self.maxpool3(resblock2_out)

        # out: (b, 640, seq_len, 2)
        concat_out = torch.cat((mp1_out, mp2_out, mp3_out, poolblock_out), dim=1)
        detector_out = self.detector_conv(concat_out)

        # (b, 256, seq_len, 2) => (b, seq_len, 256, 2) => (b, seq_len, 512)
        detector_out = detector_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
        detector_out, _ = self.bilstm_detector(detector_out)  # (b, seq_len, 512)

        detector_out = detector_out.contiguous().view((-1, 512))
        detector_out = self.detector(detector_out)
        detector_out = detector_out.view((-1, seq_len, 2)).sum(axis=-1)  # collapse the two logits into one voicing logit per frame - (b, seq_len)

        # sizes: (b, seq_len, num_class), (b, seq_len)
        # classifier output consists of the predicted pitch per frame
        # detector output consists of one voicing estimate per frame
        return classifier_out, detector_out

    @staticmethod
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
        elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
            for p in m.parameters():
                if p.data is None:
                    continue

                if len(p.shape) >= 2:
                    nn.init.orthogonal_(p.data)
                else:
                    nn.init.normal_(p.data)


class ResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
        super().__init__()
        self.downsample = in_channels != out_channels

        # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
        self.pre_conv = nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 2)),  # apply downsampling on the frequency axis only
        )

        # conv layers
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
        )

        # 1 x 1 convolution layer to match the feature dimensions
        self.conv1by1 = None
        if self.downsample:
            self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        x = self.pre_conv(x)
        if self.downsample:
            x = self.conv(x) + self.conv1by1(x)
        else:
            x = self.conv(x) + x
        return x
--------------------------------------------------------------------------------
/optimizers.py:
--------------------------------------------------------------------------------
#coding:utf-8
import os, sys
import os.path as osp
import numpy as np
import torch
from torch import nn
from torch.optim import Optimizer
from functools import reduce
from torch.optim import AdamW

class MultiOptimizer:
    def __init__(self, optimizers={}, schedulers={}):
        self.optimizers = optimizers
        self.schedulers = schedulers
        self.keys = list(optimizers.keys())
        self.param_groups = reduce(lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()])

    def state_dict(self):
        state_dicts = [(key, self.optimizers[key].state_dict())
                       for key in self.keys]
        return state_dicts

    def load_state_dict(self, state_dict):
        for key, val in state_dict:
            try:
                self.optimizers[key].load_state_dict(val)
            except:
                print("Could not load the state dict of optimizer %s" % key)


    def step(self, key=None):
        if key is not None:
            self.optimizers[key].step()
        else:
            _ = [self.optimizers[key].step() for key in self.keys]

    def zero_grad(self, key=None):
        if key is not None:
            self.optimizers[key].zero_grad()
        else:
            _ = [self.optimizers[key].zero_grad() for key in self.keys]

    def scheduler(self, *args, key=None):
        if key is not None:
            self.schedulers[key].step(*args)
        else:
            _ = [self.schedulers[key].step(*args) for key in self.keys]


def build_optimizer(parameters):
    optimizer, scheduler = _define_optimizer(parameters)
    return optimizer, scheduler

def _define_optimizer(params):
    optimizer_params = params['optimizer_params']
    sch_params = params['scheduler_params']
    optimizer = AdamW(
        params['params'],
        lr=optimizer_params.get('lr', 1e-4),
        weight_decay=optimizer_params.get('weight_decay', 5e-4),
        betas=(0.9, 0.98),
        eps=1e-9)
    scheduler = _define_scheduler(optimizer, sch_params)
    return optimizer, scheduler

def _define_scheduler(optimizer, params):
    print(params)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=params.get('max_lr', 5e-4),
        epochs=params.get('epochs', 200),
        steps_per_epoch=params.get('steps_per_epoch', 1000),
        pct_start=params.get('pct_start', 0.0),
        final_div_factor=5)

    return scheduler

def build_multi_optimizer(parameters_dict, scheduler_params):
    optim = dict([(key, AdamW(params, lr=1e-4, weight_decay=1e-6, betas=(0.9, 0.98), eps=1e-9))
                  for key, params in parameters_dict.items()])

    schedulers = dict([(key, _define_scheduler(opt, scheduler_params))
                       for key, opt in optim.items()])

    multi_optim = MultiOptimizer(optim, schedulers)
    return multi_optim
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from model import JDCNet
from meldataset import build_dataloader
from optimizers import build_optimizer
from trainer import Trainer

import time
import os
import os.path as osp
import re
import sys
import yaml
import shutil
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import click
from tqdm import tqdm

import logging
from logging import StreamHandler
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = StreamHandler()
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)

torch.backends.cudnn.benchmark = True

def get_data_path_list(train_path=None, val_path=None):
    if train_path is None:
        train_path = "Data/train_list.txt"
    if val_path is None:
        val_path = "Data/val_list.txt"

    with open(train_path, 'r') as f:
        train_list = f.readlines()
    with open(val_path, 'r') as f:
        val_list = f.readlines()

    # train_list = train_list[-500:]
    # val_list = train_list[:500]
    return train_list, val_list

@click.command()
@click.option('-p', '--config_path', default='./Configs/config.yml', type=str)
def main(config_path):
    config = yaml.safe_load(open(config_path))
    log_dir = config['log_dir']
    if not osp.exists(log_dir):
        os.makedirs(log_dir)
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))

    writer = SummaryWriter(log_dir + "/tensorboard")

    # write logs
    file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
    logger.addHandler(file_handler)

    batch_size = config.get('batch_size', 32)
    device = config.get('device', 'cpu')
    epochs = config.get('epochs', 100)
    save_freq = config.get('save_freq', 10)
    train_path = config.get('train_data', None)
    val_path = config.get('val_data', None)
    num_workers = config.get('num_workers', 8)

    train_list, val_list = get_data_path_list(train_path, val_path)

    train_dataloader = build_dataloader(train_list,
                                        batch_size=batch_size,
                                        num_workers=num_workers,
                                        dataset_config=config.get('dataset_params', {}),
                                        device=device)

    val_dataloader = build_dataloader(val_list,
                                      batch_size=batch_size,
                                      validation=True,
                                      num_workers=num_workers // 2,
                                      device=device,
                                      dataset_config=config.get('dataset_params', {}))

    # define model
    model = JDCNet(num_class=1)  # num_class = 1 means F0 regression

    scheduler_params = {
        "max_lr": float(config['optimizer_params'].get('lr', 5e-4)),
        "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }

    model.to(device)
    optimizer, scheduler = build_optimizer(
        {"params": model.parameters(), "optimizer_params": {}, "scheduler_params": scheduler_params})

    criterion = {'l1': nn.SmoothL1Loss(),  # F0 loss (regression)
                 'ce': nn.BCEWithLogitsLoss()  # silence loss (binary classification)
                 }

    loss_config = config['loss_params']

    trainer = Trainer(model=model,
                      criterion=criterion,
                      optimizer=optimizer,
                      scheduler=scheduler,
                      device=device,
                      train_dataloader=train_dataloader,
                      val_dataloader=val_dataloader,
                      loss_config=loss_config,
                      logger=logger)

    if config.get('pretrained_model', '') != '':
        trainer.load_checkpoint(config['pretrained_model'],
                                load_only_params=config.get('load_only_params', True))

    # compute all F0 curves for the training and validation data
    print('Checking if all F0 data is computed...')
    for _ in enumerate(train_dataloader):
        continue
    for _ in enumerate(val_dataloader):
        continue
    print('All F0 data is computed.')

    for epoch in range(1, epochs+1):
        train_results = trainer._train_epoch()
        eval_results = trainer._eval_epoch()
        results = train_results.copy()
        results.update(eval_results)
        logger.info('--- epoch %d ---' % epoch)
        for key, value in results.items():
            if isinstance(value, float):
                logger.info('%-15s: %.4f' % (key, value))
                writer.add_scalar(key, value, epoch)
            else:
                writer.add_figure(key, value, epoch)
        if (epoch % save_freq) == 0:
            trainer.save_checkpoint(osp.join(log_dir, 'epoch_%05d.pth' % epoch))

    return 0

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import os
import os.path as osp
import sys
import time
from collections import defaultdict

import numpy as np
import torch
from torch import nn
from PIL import Image
from tqdm import tqdm

import matplotlib.pyplot as plt

import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

class Trainer(object):
    def __init__(self,
                 model=None,
                 criterion=None,
                 optimizer=None,
                 scheduler=None,
                 config={},
                 loss_config={},
                 device=torch.device("cpu"),
                 logger=logger,
                 train_dataloader=None,
                 val_dataloader=None,
                 initial_steps=0,
                 initial_epochs=0):

        self.steps = initial_steps
        self.epochs = initial_epochs
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.config = config
        self.loss_config = loss_config
        self.device = device
        self.finish_train = False
        self.logger = logger
        self.fp16_run = False

    def save_checkpoint(self, checkpoint_path):
        """Save checkpoint.
        Args:
            checkpoint_path (str): Checkpoint path to be saved.
        """
        state_dict = {
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "steps": self.steps,
            "epochs": self.epochs,
        }
        state_dict["model"] = self.model.state_dict()

        if not os.path.exists(os.path.dirname(checkpoint_path)):
            os.makedirs(os.path.dirname(checkpoint_path))
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(self, checkpoint_path, load_only_params=False):
        """Load checkpoint.
        Args:
            checkpoint_path (str): Checkpoint path to be loaded.
            load_only_params (bool): Whether to load only model parameters.
        """
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self._load(state_dict["model"], self.model)

        if not load_only_params:
            self.steps = state_dict["steps"]
            self.epochs = state_dict["epochs"]
            self.optimizer.load_state_dict(state_dict["optimizer"])

            # overwrite scheduler arguments with the current config parameters
            state_dict["scheduler"].update(**self.config.get("scheduler_params", {}))
            self.scheduler.load_state_dict(state_dict["scheduler"])

    def _load(self, states, model, force_load=True):
        model_states = model.state_dict()
        for key, val in states.items():
            try:
                if key not in model_states:
                    continue
                if isinstance(val, nn.Parameter):
                    val = val.data

                if val.shape != model_states[key].shape:
                    self.logger.info("%s does not have the same shape" % key)
                    print(val.shape, model_states[key].shape)
                    if not force_load:
                        continue

                    min_shape = np.minimum(np.array(val.shape), np.array(model_states[key].shape))
                    slices = [slice(0, min_index) for min_index in min_shape]
                    model_states[key][slices].copy_(val[slices])
                else:
                    model_states[key].copy_(val)
            except:
                self.logger.info("does not exist: %s" % key)
                print("does not exist:", key)

    @staticmethod
    def get_gradient_norm(model):
        total_norm = 0
        for p in model.parameters():
            if p.grad is None:  # skip parameters without gradients
                continue
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2

        total_norm = np.sqrt(total_norm)
        return total_norm

    @staticmethod
    def length_to_mask(lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

    def _get_lr(self):
        for param_group in self.optimizer.param_groups:
            lr = param_group['lr']
            break
        return lr

    def run(self, batch):
        self.optimizer.zero_grad()
        batch = [b.to(self.device) for b in batch]

        x, f0, sil = batch
        # mels come in as (b, 1, 80, t); JDCNet expects (b, 1, t, 80)
        f0_pred, sil_pred = self.model(x.transpose(-1, -2))

        loss_f0 = self.loss_config['lambda_f0'] * self.criterion['l1'](f0_pred.squeeze(), f0)
        loss_sil = self.criterion['ce'](sil_pred, sil)
        loss = loss_f0 + loss_sil

        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        return {'loss': loss.item(),
                'f0': loss_f0.item(),
                'sil': loss_sil.item()}

    def _train_epoch(self):
        self.epochs += 1
        train_losses = defaultdict(list)
        self.model.train()
        for train_steps_per_epoch, batch in enumerate(tqdm(self.train_dataloader, desc="[train]"), 1):
            losses = self.run(batch)
            for key, value in losses.items():
                train_losses["train/%s" % key].append(value)

        train_losses = {key: np.mean(value) for key, value in train_losses.items()}
        train_losses['train/learning_rate'] = self._get_lr()
        return train_losses

    @torch.no_grad()
    def _eval_epoch(self):
        self.model.eval()
        eval_losses = defaultdict(list)
        eval_images = defaultdict(list)
        for eval_steps_per_epoch, batch in enumerate(tqdm(self.val_dataloader, desc="[eval]"), 1):
            batch = [b.to(self.device) for b in batch]
            x, f0, sil = batch

            f0_pred, sil_pred = self.model(x.transpose(-1, -2))

            loss_f0 = self.loss_config['lambda_f0'] * self.criterion['l1'](f0_pred.squeeze(), f0)
            loss_sil = self.criterion['ce'](sil_pred, sil)
            loss = loss_f0 + loss_sil

            eval_losses["eval/loss"].append(loss.item())
            eval_losses["eval/f0"].append(loss_f0.item())
            eval_losses["eval/sil"].append(loss_sil.item())

        eval_losses = {key: np.mean(value) for key, value in eval_losses.items()}
        eval_losses.update(eval_images)
        return eval_losses
--------------------------------------------------------------------------------