├── png └── model.png ├── .gitignore ├── LICENSE ├── README.md ├── preprocessing.py ├── synthesize.py ├── modules.py ├── data.py ├── model.py ├── train.py └── train_apex.py /png/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ksw0306/FloWaveNet/HEAD/png/model.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | DATASETS 2 | ljspeech 3 | log 4 | loss 5 | params 6 | samples 7 | 8 | 9 | __pycache__ 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sungwon Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FloWaveNet : A Generative Flow for Raw Audio 2 | 3 | This is a PyTorch implementation of our work ["FloWaveNet : A Generative Flow for Raw Audio"](https://arxiv.org/abs/1811.02155). (We'll update soon.) 4 | 5 | 6 | 7 | For the purpose of parallel sampling, we propose FloWaveNet, a flow-based generative model for raw audio synthesis. 8 | FloWaveNet can generate audio samples as fast as ClariNet and Parallel WaveNet, while its training procedure is simple and stable thanks to a single-stage pipeline. Our generated audio samples are available at [https://ksw0306.github.io/flowavenet-demo/](https://ksw0306.github.io/flowavenet-demo/). Also, our implementation of ClariNet (Gaussian WaveNet and Gaussian IAF) is available at [https://github.com/ksw0306/ClariNet](https://github.com/ksw0306/ClariNet). 9 | 10 | 11 | # Requirements 12 | 13 | - PyTorch 0.4.1 14 | - Python 3.6 15 | - Librosa 16 | 17 | # Examples 18 | 19 | #### Step 1. Download Dataset 20 | 21 | - LJSpeech : [https://keithito.com/LJ-Speech-Dataset/](https://keithito.com/LJ-Speech-Dataset/) 22 | 23 | #### Step 2. Preprocessing (Preparing Mel Spectrograms) 24 | 25 | `python preprocessing.py --in_dir ljspeech --out_dir DATASETS/ljspeech` 26 | 27 | #### Step 3. 
Train 28 | 29 | ##### Single-GPU training 30 | 31 | `python train.py --model_name flowavenet --batch_size 2 --n_block 8 --n_flow 6 --n_layer 2 --block_per_split 4` 32 | 33 | ##### Multi-GPU training 34 | 35 | `python train.py --model_name flowavenet --batch_size 8 --n_block 8 --n_flow 6 --n_layer 2 --block_per_split 4 --num_gpu 4` 36 | 37 | 38 | NVIDIA TITAN V (12GB VRAM) : batch size 2 per GPU 39 | 40 | NVIDIA Tesla V100 (32GB VRAM) : batch size 8 per GPU 41 | 42 | 43 | #### Step 4. Synthesize 44 | 45 | `--load_step CHECKPOINT` : the global training step of the pre-trained model (the step number also appears in the checkpoint filename) 46 | 47 | `--temp` : sampling temperature (standard deviation of the prior); the latent is drawn as z ~ N(0, TEMPERATURE^2) 48 | 49 | Example: `python synthesize.py --model_name flowavenet --n_block 8 --n_flow 6 --n_layer 2 --load_step 100000 --temp 0.8 --num_samples 10 --block_per_split 4` 50 | 51 | 52 | 53 | # Sample Link 54 | 55 | Sample Link : [https://ksw0306.github.io/flowavenet-demo/](https://ksw0306.github.io/flowavenet-demo/) 56 | 57 | Our implementation of ClariNet (Gaussian WaveNet, Gaussian IAF) : [https://github.com/ksw0306/ClariNet](https://github.com/ksw0306/ClariNet) 58 | 59 | - Results 1 : Model comparisons (WaveNet (MoL, Gaussian), ClariNet and FloWaveNet) 60 | 61 | - Results 2 : Effect of temperature on the audio quality trade-off (Temperature T : 0.0 ~ 1.0, Model : FloWaveNet) 62 | 63 | - Results 3 : Analysis of ClariNet loss terms (Loss functions : 1. Only KL 2. KL + Frame 3. Only Frame) 64 | 65 | - Results 4 : Causality of WaveNet dilated convolutions (FloWaveNet : non-causal WaveNet affine coupling layers, FloWaveNet_causal : causal WaveNet affine coupling layers) 66 | 67 | 68 | # Reference 69 | 70 | - WaveNet vocoder : [https://github.com/r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder) 71 | - glow-pytorch : [https://github.com/rosinality/glow-pytorch](https://github.com/rosinality/glow-pytorch) 72 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor 2 | from functools import partial 3 | import numpy as np 4 | import os 5 | import librosa 6 | from multiprocessing import cpu_count 7 | import argparse 8 | 9 | 10 | def build_from_path(in_dir, out_dir, num_workers=1): 11 | executor = ProcessPoolExecutor(max_workers=num_workers) 12 | futures = [] 13 | index = 1 14 | with open(os.path.join(in_dir, 'metadata.csv'), encoding='utf-8') as f: 15 | for line in f: 16 | parts = line.strip().split('|') 17 | wav_path = os.path.join(in_dir, 'wavs', '%s.wav' % parts[0]) 18 | text = parts[2] 19 | futures.append(executor.submit( 20 | partial(_process_utterance, out_dir, index, wav_path, text))) 21 | index += 1 22 | return [future.result() for future in futures] 23 | 24 | 25 | def _process_utterance(out_dir, index, wav_path, text): 26 | # Load the audio into a numpy array and peak-normalize it: 27 | wav, sr = librosa.load(wav_path, sr=22050) 28 | 29 | wav = wav / np.abs(wav).max() * 0.999 30 | out = wav 31 | constant_values = 0.0 32 | out_dtype = np.float32 33 | n_fft = 1024 34 | hop_length = 256 35 | reference = 20.0 36 | min_db = -100 37 | 38 | # Compute a mel-scale spectrogram from the normalized wav: 39 | # (N, D) 40 | mel_spectrogram = librosa.feature.melspectrogram(wav, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=80, 41 | fmin=125, fmax=7600).T 42 | 43 | # mel_spectrogram = np.round(mel_spectrogram, decimals=2) 44 | 
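# Convert the power mel spectrogram to decibels relative to the fixed reference, then rescale to [0, 1]:
#   mel_db   = 20 * log10(max(mel, 1e-4)) - reference     (reference = 20.0)
#   mel_norm = clip((mel_db - min_db) / (-min_db), 0, 1)  (min_db = -100)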
mel_spectrogram = 20 * np.log10(np.maximum(1e-4, mel_spectrogram)) - reference 45 | mel_spectrogram = np.clip((mel_spectrogram - min_db) / (-min_db), 0, 1) 46 | 47 | pad = (out.shape[0] // hop_length + 1) * hop_length - out.shape[0] 48 | pad_l = pad // 2 49 | pad_r = pad // 2 + pad % 2 50 | 51 | # zero-pad the waveform so its length can be aligned with the mel frames 52 | out = np.pad(out, (pad_l, pad_r), mode="constant", constant_values=constant_values) 53 | N = mel_spectrogram.shape[0] 54 | assert len(out) >= N * hop_length 55 | 56 | # time resolution adjustment 57 | # ensure length of raw audio is a multiple of hop_length so that we can use 58 | # transposed convolution to upsample 59 | out = out[:N * hop_length] 60 | assert len(out) % hop_length == 0 61 | 62 | timesteps = len(out) 63 | 64 | # Write the spectrograms to disk: 65 | audio_filename = 'ljspeech-audio-%05d.npy' % index 66 | mel_filename = 'ljspeech-mel-%05d.npy' % index 67 | np.save(os.path.join(out_dir, audio_filename), 68 | out.astype(out_dtype), allow_pickle=False) 69 | np.save(os.path.join(out_dir, mel_filename), 70 | mel_spectrogram.astype(np.float32), allow_pickle=False) 71 | 72 | # Return a tuple describing this training example: 73 | return audio_filename, mel_filename, timesteps, text 74 | 75 | 76 | def preprocess(in_dir, out_dir, num_workers): 77 | os.makedirs(out_dir, exist_ok=True) 78 | metadata = build_from_path(in_dir, out_dir, num_workers) 79 | write_metadata(metadata, out_dir) 80 | 81 | 82 | def write_metadata(metadata, out_dir): 83 | with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f: 84 | for m in metadata: 85 | f.write('|'.join([str(x) for x in m]) + '\n') 86 | frames = sum([m[2] for m in metadata]) 87 | sr = 22050 88 | hours = frames / sr / 3600 89 | print('Wrote %d utterances, %d time steps (%.2f hours)' % (len(metadata), frames, hours)) 90 | print('Max input length: %d' % max(len(m[3]) for m in metadata)) 91 | print('Max output length: %d' % max(m[2] for m in metadata)) 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser(description='Preprocessing', 96 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 97 | parser.add_argument('--in_dir', '-i', type=str, default='./', help='In Directory') 98 | parser.add_argument('--out_dir', '-o', type=str, default='./', help='Out Directory') 99 | args = parser.parse_args() 100 | 101 | num_workers = cpu_count() 102 | preprocess(args.in_dir, args.out_dir, num_workers) 103 | -------------------------------------------------------------------------------- /synthesize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from data import LJspeechDataset, collate_fn_synthesize 4 | from model import Flowavenet 5 | from torch.distributions.normal import Normal 6 | import numpy as np 7 | import librosa 8 | import os 9 | import argparse 10 | import time 11 | 12 | torch.backends.cudnn.benchmark = False 13 | np.set_printoptions(precision=4) 14 | parser = argparse.ArgumentParser(description='Synthesize with FloWaveNet trained on LJSpeech', 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | parser.add_argument('--data_path', type=str, default='./DATASETS/ljspeech/', help='Dataset Path') 17 | parser.add_argument('--sample_path', type=str, default='./samples', help='Sample Path') 18 | parser.add_argument('--model_name', type=str, default='flowavenet', help='Model Name') 19 | parser.add_argument('--num_samples', type=int, default=10, help='# of audio samples') 20 | 
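# --load_step is the global training step of the checkpoint to load (see load_checkpoint below),
# and --temp scales the standard deviation of the latent prior used for sampling.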
parser.add_argument('--load_step', type=int, default=0, help='Load Step') 21 | parser.add_argument('--temp', type=float, default=0.8, help='Temperature') 22 | parser.add_argument('--load', '-l', type=str, default='./params', help='Checkpoint path to resume / test.') 23 | parser.add_argument('--n_layer', type=int, default=2, help='Number of layers') 24 | parser.add_argument('--n_flow', type=int, default=6, help='Number of layers') 25 | parser.add_argument('--n_block', type=int, default=8, help='Number of layers') 26 | parser.add_argument('--cin_channels', type=int, default=80, help='Cin Channels') 27 | parser.add_argument('--block_per_split', type=int, default=4, help='Block per split') 28 | parser.add_argument('--num_workers', type=int, default=0, help='Number of workers') 29 | parser.add_argument('--log', type=str, default='./log', help='Log folder.') 30 | args = parser.parse_args() 31 | 32 | if not os.path.isdir(args.sample_path): 33 | os.makedirs(args.sample_path) 34 | if not os.path.isdir(os.path.join(args.sample_path, args.model_name)): 35 | os.makedirs(os.path.join(args.sample_path, args.model_name)) 36 | 37 | use_cuda = torch.cuda.is_available() 38 | device = torch.device("cuda" if use_cuda else "cpu") 39 | 40 | # LOAD DATASETS 41 | test_dataset = LJspeechDataset(args.data_path, False, 0.1) 42 | synth_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn_synthesize, 43 | num_workers=args.num_workers, pin_memory=True) 44 | 45 | 46 | def build_model(): 47 | model = Flowavenet(in_channel=1, 48 | cin_channel=args.cin_channels, 49 | n_block=args.n_block, 50 | n_flow=args.n_flow, 51 | n_layer=args.n_layer, 52 | affine=True, 53 | pretrained=True, 54 | block_per_split=args.block_per_split) 55 | return model 56 | 57 | 58 | def synthesize(model): 59 | global global_step 60 | for batch_idx, (x, c) in enumerate(synth_loader): 61 | if batch_idx < args.num_samples: 62 | x, c = x.to(device), c.to(device) 63 | 64 | q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size())) 65 | z = q_0.sample() * args.temp 66 | torch.cuda.synchronize() 67 | start_time = time.time() 68 | 69 | with torch.no_grad(): 70 | y_gen = model.reverse(z, c).squeeze() 71 | torch.cuda.synchronize() 72 | print('{} seconds'.format(time.time() - start_time)) 73 | wav = y_gen.to(torch.device("cpu")).data.numpy() 74 | wav_name = '{}/{}/generate_{}_{}_{}.wav'.format(args.sample_path, args.model_name, 75 | global_step, batch_idx, args.temp) 76 | librosa.output.write_wav(wav_name, wav, sr=22050) 77 | print('{} Saved!'.format(wav_name)) 78 | 79 | 80 | def load_checkpoint(step, model): 81 | checkpoint_path = os.path.join(args.load, args.model_name, "checkpoint_step{:09d}.pth".format(step)) 82 | print("Load checkpoint from: {}".format(checkpoint_path)) 83 | checkpoint = torch.load(checkpoint_path) 84 | # generalized load procedure for both single-gpu and DataParallel models 85 | # https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3 86 | try: 87 | model.load_state_dict(checkpoint["state_dict"]) 88 | except RuntimeError: 89 | print("INFO: this model is trained with DataParallel. 
Creating new state_dict without module...") 90 | state_dict = checkpoint["state_dict"] 91 | from collections import OrderedDict 92 | new_state_dict = OrderedDict() 93 | for k, v in state_dict.items(): 94 | name = k[7:] # remove `module.` 95 | new_state_dict[name] = v 96 | model.load_state_dict(new_state_dict) 97 | return model 98 | 99 | 100 | if __name__ == "__main__": 101 | step = args.load_step 102 | global_step = step 103 | model = build_model() 104 | model = load_checkpoint(step, model) 105 | model = model.to(device) 106 | model.eval() 107 | 108 | with torch.no_grad(): 109 | synthesize(model) 110 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class Conv(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, causal=True): 8 | super(Conv, self).__init__() 9 | 10 | self.causal = causal 11 | if self.causal: 12 | self.padding = dilation * (kernel_size - 1) 13 | else: 14 | self.padding = dilation * (kernel_size - 1) // 2 15 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding) 16 | self.conv = nn.utils.weight_norm(self.conv) 17 | nn.init.kaiming_normal_(self.conv.weight) 18 | 19 | def forward(self, tensor): 20 | out = self.conv(tensor) 21 | if self.causal and self.padding is not 0: 22 | out = out[:, :, :-self.padding] 23 | return out 24 | 25 | 26 | class ZeroConv1d(nn.Module): 27 | def __init__(self, in_channel, out_channel): 28 | super().__init__() 29 | 30 | self.conv = nn.Conv1d(in_channel, out_channel, 1, padding=0) 31 | self.conv.weight.data.zero_() 32 | self.conv.bias.data.zero_() 33 | self.scale = nn.Parameter(torch.zeros(1, out_channel, 1)) 34 | 35 | def forward(self, x): 36 | out = self.conv(x) 37 | out = out * torch.exp(self.scale * 3) 38 | return out 39 | 40 | 41 | class ResBlock(nn.Module): 42 | def __init__(self, in_channels, out_channels, skip_channels, kernel_size, dilation, 43 | cin_channels=None, local_conditioning=True, causal=False): 44 | super(ResBlock, self).__init__() 45 | self.causal = causal 46 | self.local_conditioning = local_conditioning 47 | self.cin_channels = cin_channels 48 | self.skip = True if skip_channels is not None else False 49 | 50 | self.filter_conv = Conv(in_channels, out_channels, kernel_size, dilation, causal) 51 | self.gate_conv = Conv(in_channels, out_channels, kernel_size, dilation, causal) 52 | self.res_conv = nn.Conv1d(out_channels, in_channels, kernel_size=1) 53 | self.res_conv = nn.utils.weight_norm(self.res_conv) 54 | nn.init.kaiming_normal_(self.res_conv.weight) 55 | if self.skip: 56 | self.skip_conv = nn.Conv1d(out_channels, skip_channels, kernel_size=1) 57 | self.skip_conv = nn.utils.weight_norm(self.skip_conv) 58 | nn.init.kaiming_normal_(self.skip_conv.weight) 59 | 60 | if self.local_conditioning: 61 | self.filter_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1) 62 | self.gate_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1) 63 | self.filter_conv_c = nn.utils.weight_norm(self.filter_conv_c) 64 | self.gate_conv_c = nn.utils.weight_norm(self.gate_conv_c) 65 | nn.init.kaiming_normal_(self.filter_conv_c.weight) 66 | nn.init.kaiming_normal_(self.gate_conv_c.weight) 67 | 68 | def forward(self, tensor, c=None): 69 | h_filter = self.filter_conv(tensor) 70 | h_gate = self.gate_conv(tensor) 71 | 72 | if self.local_conditioning: 73 | h_filter 
+= self.filter_conv_c(c) 74 | h_gate += self.gate_conv_c(c) 75 | 76 | out = torch.tanh(h_filter) * torch.sigmoid(h_gate) 77 | 78 | res = self.res_conv(out) 79 | skip = self.skip_conv(out) if self.skip else None 80 | return (tensor + res) * math.sqrt(0.5), skip 81 | 82 | 83 | class Wavenet(nn.Module): 84 | def __init__(self, in_channels=1, out_channels=2, num_blocks=1, num_layers=6, 85 | residual_channels=256, gate_channels=256, skip_channels=256, 86 | kernel_size=3, cin_channels=80, causal=True): 87 | super(Wavenet, self).__init__() 88 | 89 | self.skip = True if skip_channels is not None else False 90 | self.front_conv = nn.Sequential( 91 | Conv(in_channels, residual_channels, 3, causal=causal), 92 | nn.ReLU() 93 | ) 94 | 95 | self.res_blocks = nn.ModuleList() 96 | for b in range(num_blocks): 97 | for n in range(num_layers): 98 | self.res_blocks.append(ResBlock(residual_channels, gate_channels, skip_channels, 99 | kernel_size, dilation=2**n, 100 | cin_channels=cin_channels, local_conditioning=True, 101 | causal=causal)) 102 | 103 | last_channels = skip_channels if self.skip else residual_channels 104 | self.final_conv = nn.Sequential( 105 | nn.ReLU(), 106 | Conv(last_channels, last_channels, 1, causal=causal), 107 | nn.ReLU(), 108 | ZeroConv1d(last_channels, out_channels) 109 | ) 110 | 111 | def forward(self, x, c=None): 112 | h = self.front_conv(x) 113 | skip = 0 114 | for i, f in enumerate(self.res_blocks): 115 | if self.skip: 116 | h, s = f(h, c) 117 | skip += s 118 | else: 119 | h, _ = f(h, c) 120 | if self.skip: 121 | out = self.final_conv(skip) 122 | else: 123 | out = self.final_conv(h) 124 | return out 125 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import os 5 | 6 | max_time_steps = 16000 7 | upsample_conditional_features = True 8 | hop_length = 256 9 | 10 | 11 | class LJspeechDataset(Dataset): 12 | def __init__(self, data_root, train=True, test_size=0.05): 13 | self.data_root = data_root 14 | self.lengths = [] 15 | self.train = train 16 | self.test_size = test_size 17 | 18 | self.paths = [self.collect_files(0), self.collect_files(1)] 19 | 20 | def __len__(self): 21 | return len(self.paths[0]) 22 | 23 | def __getitem__(self, idx): 24 | wav = np.load(self.paths[0][idx]) 25 | mel = np.load(self.paths[1][idx]) 26 | return wav, mel 27 | 28 | def interest_indices(self, paths): 29 | test_num_samples = int(self.test_size * len(paths)) 30 | train_indices, test_indices = range(0, len(paths) - test_num_samples), \ 31 | range(len(paths) - test_num_samples, len(paths)) 32 | return train_indices if self.train else test_indices 33 | 34 | def collect_files(self, col): 35 | meta = os.path.join(self.data_root, "train.txt") 36 | with open(meta, "rb") as f: 37 | lines = f.readlines() 38 | l = lines[0].decode("utf-8").split("|") 39 | assert len(l) == 4 40 | self.lengths = list( 41 | map(lambda l: int(l.decode("utf-8").split("|")[2]), lines)) 42 | 43 | paths = list(map(lambda l: l.decode("utf-8").split("|")[col], lines)) 44 | paths = list(map(lambda f: os.path.join(self.data_root, f), paths)) 45 | 46 | # Filter by train/test 47 | indices = self.interest_indices(paths) 48 | paths = list(np.array(paths)[indices]) 49 | self.lengths = list(np.array(self.lengths)[indices]) 50 | self.lengths = list(map(int, self.lengths)) 51 | return paths 52 | 53 | 54 | def _pad(seq, max_len, constant_values=0): 
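# Right-pad a 1-D sequence to max_len with a constant value so variable-length items can be batched.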
55 | return np.pad(seq, (0, max_len - len(seq)), 56 | mode='constant', constant_values=constant_values) 57 | 58 | 59 | def _pad_2d(x, max_len, b_pad=0): 60 | x = np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)], 61 | mode="constant", constant_values=0) 62 | return x 63 | 64 | 65 | def collate_fn(batch): 66 | """ 67 | Create batch 68 | 69 | Args : batch(tuple) : List of tuples / (x, c) x : list of (T,) c : list of (T, D) 70 | 71 | Returns : Tuple of batch / Network inputs x (B, C, T), Network targets (B, T, 1) 72 | """ 73 | 74 | local_conditioning = len(batch[0]) >= 2 75 | 76 | if local_conditioning: 77 | new_batch = [] 78 | for idx in range(len(batch)): 79 | x, c = batch[idx] 80 | if upsample_conditional_features: 81 | assert len(x) % len(c) == 0 and len(x) // len(c) == hop_length 82 | 83 | max_steps = max_time_steps - max_time_steps % hop_length # To ensure Divisibility 84 | 85 | if len(x) > max_steps: 86 | max_time_frames = max_steps // hop_length 87 | s = np.random.randint(0, len(c) - max_time_frames) 88 | ts = s * hop_length 89 | x = x[ts:ts + hop_length * max_time_frames] 90 | c = c[s:s + max_time_frames] 91 | assert len(x) % len(c) == 0 and len(x) // len(c) == hop_length 92 | else: 93 | pass 94 | new_batch.append((x, c)) 95 | batch = new_batch 96 | else: 97 | pass 98 | 99 | input_lengths = [len(x[0]) for x in batch] 100 | max_input_len = max(input_lengths) 101 | 102 | # x_batch : [B, T, 1] 103 | x_batch = np.array([_pad_2d(x[0].reshape(-1, 1), max_input_len) for x in batch], dtype=np.float32) 104 | assert len(x_batch.shape) == 3 105 | if local_conditioning: 106 | max_len = max([len(x[1]) for x in batch]) 107 | c_batch = np.array([_pad_2d(x[1], max_len) for x in batch], dtype=np.float32) 108 | assert len(c_batch.shape) == 3 109 | # (B x C x T') 110 | c_batch = torch.tensor(c_batch).transpose(1, 2).contiguous() 111 | del max_len 112 | else: 113 | c_batch = None 114 | 115 | # Convert to channel first i.e., (B, C, T) / C = 1 116 | x_batch = torch.tensor(x_batch).transpose(1, 2).contiguous() 117 | return x_batch, c_batch 118 | 119 | 120 | def collate_fn_synthesize(batch): 121 | """ 122 | Create batch 123 | 124 | Args : batch(tuple) : List of tuples / (x, c) x : list of (T,) c : list of (T, D) 125 | 126 | Returns : Tuple of batch / Network inputs x (B, C, T), Network targets (B, T, 1) 127 | """ 128 | 129 | local_conditioning = len(batch[0]) >= 2 130 | 131 | if local_conditioning: 132 | new_batch = [] 133 | for idx in range(len(batch)): 134 | x, c = batch[idx] 135 | if upsample_conditional_features: 136 | assert len(x) % len(c) == 0 and len(x) // len(c) == hop_length 137 | new_batch.append((x, c)) 138 | batch = new_batch 139 | else: 140 | pass 141 | 142 | input_lengths = [len(x[0]) for x in batch] 143 | max_input_len = max(input_lengths) 144 | 145 | x_batch = np.array([_pad_2d(x[0].reshape(-1, 1), max_input_len) for x in batch], dtype=np.float32) 146 | assert len(x_batch.shape) == 3 147 | 148 | if local_conditioning: 149 | max_len = max([len(x[1]) for x in batch]) 150 | c_batch = np.array([_pad_2d(x[1], max_len) for x in batch], dtype=np.float32) 151 | assert len(c_batch.shape) == 3 152 | # (B x C x T') 153 | c_batch = torch.tensor(c_batch).transpose(1, 2).contiguous() 154 | else: 155 | c_batch = None 156 | 157 | # Convert to channel first i.e., (B, C, T) / C = 1 158 | x_batch = torch.tensor(x_batch).transpose(1, 2).contiguous() 159 | return x_batch, c_batch 160 | -------------------------------------------------------------------------------- /model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from math import log, pi 4 | from modules import Wavenet 5 | import math 6 | 7 | logabs = lambda x: torch.log(torch.abs(x)) 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, in_channel, logdet=True, pretrained=False): 12 | super().__init__() 13 | 14 | self.loc = nn.Parameter(torch.zeros(1, in_channel, 1)) 15 | self.scale = nn.Parameter(torch.ones(1, in_channel, 1)) 16 | 17 | self.initialized = pretrained 18 | self.logdet = logdet 19 | 20 | def initialize(self, x): 21 | with torch.no_grad(): 22 | flatten = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1) 23 | mean = ( 24 | flatten.mean(1) 25 | .unsqueeze(1) 26 | .unsqueeze(2) 27 | .permute(1, 0, 2) 28 | ) 29 | std = ( 30 | flatten.std(1) 31 | .unsqueeze(1) 32 | .unsqueeze(2) 33 | .permute(1, 0, 2) 34 | ) 35 | 36 | self.loc.data.copy_(-mean) 37 | self.scale.data.copy_(1 / (std + 1e-6)) 38 | 39 | def forward(self, x): 40 | B, _, T = x.size() 41 | 42 | if not self.initialized: 43 | self.initialize(x) 44 | self.initialized = True 45 | 46 | log_abs = logabs(self.scale) 47 | 48 | logdet = torch.sum(log_abs) * B * T 49 | 50 | if self.logdet: 51 | return self.scale * (x + self.loc), logdet 52 | 53 | else: 54 | return self.scale * (x + self.loc) 55 | 56 | def reverse(self, output): 57 | return output / self.scale - self.loc 58 | 59 | 60 | class AffineCoupling(nn.Module): 61 | def __init__(self, in_channel, cin_channel, filter_size=256, num_layer=6, affine=True): 62 | super().__init__() 63 | 64 | self.affine = affine 65 | self.net = Wavenet(in_channels=in_channel//2, out_channels=in_channel if self.affine else in_channel//2, 66 | num_blocks=1, num_layers=num_layer, residual_channels=filter_size, 67 | gate_channels=filter_size, skip_channels=filter_size, 68 | kernel_size=3, cin_channels=cin_channel//2, causal=False) 69 | 70 | def forward(self, x, c=None): 71 | in_a, in_b = x.chunk(2, 1) 72 | c_a, c_b = c.chunk(2, 1) 73 | 74 | if self.affine: 75 | log_s, t = self.net(in_a, c_a).chunk(2, 1) 76 | 77 | out_b = (in_b - t) * torch.exp(-log_s) 78 | logdet = torch.sum(-log_s) 79 | else: 80 | net_out = self.net(in_a, c_a) 81 | out_b = in_b + net_out 82 | logdet = None 83 | return torch.cat([in_a, out_b], 1), logdet 84 | 85 | def reverse(self, output, c=None): 86 | out_a, out_b = output.chunk(2, 1) 87 | c_a, c_b = c.chunk(2, 1) 88 | 89 | if self.affine: 90 | log_s, t = self.net(out_a, c_a).chunk(2, 1) 91 | in_b = out_b * torch.exp(log_s) + t 92 | else: 93 | net_out = self.net(out_a, c_a) 94 | in_b = out_b - net_out 95 | 96 | return torch.cat([out_a, in_b], 1) 97 | 98 | 99 | def change_order(x, c=None): 100 | x_a, x_b = x.chunk(2, 1) 101 | c_a, c_b = c.chunk(2, 1) 102 | return torch.cat([x_b, x_a], 1), torch.cat([c_b, c_a], 1) 103 | 104 | 105 | class Flow(nn.Module): 106 | def __init__(self, in_channel, cin_channel, filter_size, num_layer, affine=True, pretrained=False): 107 | super().__init__() 108 | 109 | self.actnorm = ActNorm(in_channel, pretrained=pretrained) 110 | self.coupling = AffineCoupling(in_channel, cin_channel, filter_size=filter_size, 111 | num_layer=num_layer, affine=affine) 112 | 113 | def forward(self, x, c=None): 114 | out, logdet = self.actnorm(x) 115 | out, det = self.coupling(out, c) 116 | out, c = change_order(out, c) 117 | 118 | if det is not None: 119 | logdet = logdet + det 120 | 121 | return out, c, logdet 122 | 123 | def reverse(self, output, c=None): 124 | output, c = change_order(output, c) 
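# Invert the forward pass in reverse order: undo the channel swap, then the coupling, then ActNorm.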
125 | x = self.coupling.reverse(output, c) 126 | x = self.actnorm.reverse(x) 127 | return x, c 128 | 129 | 130 | def gaussian_log_p(x, mean, log_sd): 131 | return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd) 132 | 133 | 134 | def gaussian_sample(eps, mean, log_sd): 135 | return mean + torch.exp(log_sd) * eps 136 | 137 | 138 | class Block(nn.Module): 139 | def __init__(self, in_channel, cin_channel, n_flow, n_layer, affine=True, pretrained=False, split=False): 140 | super().__init__() 141 | 142 | self.split = split 143 | squeeze_dim = in_channel * 2 144 | squeeze_dim_c = cin_channel * 2 145 | 146 | self.flows = nn.ModuleList() 147 | for i in range(n_flow): 148 | self.flows.append(Flow(squeeze_dim, squeeze_dim_c, filter_size=256, num_layer=n_layer, affine=affine, 149 | pretrained=pretrained)) 150 | if self.split: 151 | self.prior = Wavenet(in_channels=squeeze_dim // 2, out_channels=squeeze_dim, 152 | num_blocks=1, num_layers=2, residual_channels=256, 153 | gate_channels=256, skip_channels=256, 154 | kernel_size=3, cin_channels=squeeze_dim_c, causal=False) 155 | 156 | def forward(self, x, c): 157 | b_size, n_channel, T = x.size() 158 | squeezed_x = x.view(b_size, n_channel, T // 2, 2).permute(0, 1, 3, 2) 159 | out = squeezed_x.contiguous().view(b_size, n_channel * 2, T // 2) 160 | squeezed_c = c.view(b_size, -1, T // 2, 2).permute(0, 1, 3, 2) 161 | c = squeezed_c.contiguous().view(b_size, -1, T // 2) 162 | logdet, log_p = 0, 0 163 | 164 | for flow in self.flows: 165 | out, c, det = flow(out, c) 166 | logdet = logdet + det 167 | if self.split: 168 | out, z = out.chunk(2, 1) 169 | # WaveNet prior 170 | mean, log_sd = self.prior(out, c).chunk(2, 1) 171 | log_p = gaussian_log_p(z, mean, log_sd).sum() 172 | return out, c, logdet, log_p 173 | 174 | def reverse(self, output, c, eps=None): 175 | if self.split: 176 | mean, log_sd = self.prior(output, c).chunk(2, 1) 177 | z_new = gaussian_sample(eps, mean, log_sd) 178 | 179 | x = torch.cat([output, z_new], 1) 180 | else: 181 | x = output 182 | 183 | for flow in self.flows[::-1]: 184 | x, c = flow.reverse(x, c) 185 | 186 | b_size, n_channel, T = x.size() 187 | 188 | unsqueezed_x = x.view(b_size, n_channel // 2, 2, T).permute(0, 1, 3, 2) 189 | unsqueezed_x = unsqueezed_x.contiguous().view(b_size, n_channel // 2, T * 2) 190 | unsqueezed_c = c.view(b_size, -1, 2, T).permute(0, 1, 3, 2) 191 | unsqueezed_c = unsqueezed_c.contiguous().view(b_size, -1, T * 2) 192 | 193 | return unsqueezed_x, unsqueezed_c 194 | 195 | 196 | class Flowavenet(nn.Module): 197 | def __init__(self, in_channel, cin_channel, n_block, n_flow, n_layer, affine=True, pretrained=False, 198 | block_per_split=8): 199 | super().__init__() 200 | self.block_per_split = block_per_split 201 | 202 | self.blocks = nn.ModuleList() 203 | self.n_block = n_block 204 | for i in range(self.n_block): 205 | split = False if (i + 1) % self.block_per_split or i == self.n_block - 1 else True 206 | self.blocks.append(Block(in_channel, cin_channel, n_flow, n_layer, affine=affine, 207 | pretrained=pretrained, split=split)) 208 | cin_channel *= 2 209 | if not split: 210 | in_channel *= 2 211 | 212 | self.upsample_conv = nn.ModuleList() 213 | for s in [16, 16]: 214 | convt = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s)) 215 | convt = nn.utils.weight_norm(convt) 216 | nn.init.kaiming_normal_(convt.weight) 217 | self.upsample_conv.append(convt) 218 | self.upsample_conv.append(nn.LeakyReLU(0.4)) 219 | 220 | def forward(self, x, c): 221 | B, _, T = x.size() 222 
| logdet, log_p_sum = 0, 0 223 | out = x 224 | c = self.upsample(c) 225 | for block in self.blocks: 226 | out, c, logdet_new, logp_new = block(out, c) 227 | logdet = logdet + logdet_new 228 | log_p_sum = log_p_sum + logp_new 229 | log_p_sum += 0.5 * (- log(2.0 * pi) - out.pow(2)).sum() 230 | logdet = logdet / (B * T) 231 | log_p = log_p_sum / (B * T) 232 | return log_p, logdet 233 | 234 | def reverse(self, z, c): 235 | _, _, T = z.size() 236 | _, _, t_c = c.size() 237 | if T != t_c: 238 | c = self.upsample(c) 239 | z_list = [] 240 | x = z 241 | for i in range(self.n_block): 242 | b_size, _, T = x.size() 243 | squeezed_x = x.view(b_size, -1, T // 2, 2).permute(0, 1, 3, 2) 244 | x = squeezed_x.contiguous().view(b_size, -1, T // 2) 245 | squeezed_c = c.view(b_size, -1, T // 2, 2).permute(0, 1, 3, 2) 246 | c = squeezed_c.contiguous().view(b_size, -1, T // 2) 247 | if not ((i + 1) % self.block_per_split or i == self.n_block - 1): 248 | x, z = x.chunk(2, 1) 249 | z_list.append(z) 250 | 251 | for i, block in enumerate(self.blocks[::-1]): 252 | index = self.n_block - i 253 | if not (index % self.block_per_split or index == self.n_block): 254 | x, c = block.reverse(x, c, z_list[index // self.block_per_split - 1]) 255 | else: 256 | x, c = block.reverse(x, c) 257 | return x 258 | 259 | def upsample(self, c): 260 | c = c.unsqueeze(1) 261 | for f in self.upsample_conv: 262 | c = f(c) 263 | c = c.squeeze(1) 264 | return c 265 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | from data import LJspeechDataset, collate_fn, collate_fn_synthesize 6 | from model import Flowavenet 7 | from torch.distributions.normal import Normal 8 | import numpy as np 9 | import librosa 10 | import os 11 | import argparse 12 | import time 13 | import json 14 | import gc 15 | 16 | torch.backends.cudnn.benchmark = True 17 | np.set_printoptions(precision=4) 18 | torch.manual_seed(1111) 19 | 20 | parser = argparse.ArgumentParser(description='Train FloWaveNet of LJSpeech', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--data_path', type=str, default='./DATASETS/ljspeech/', help='Dataset Path') 23 | parser.add_argument('--sample_path', type=str, default='./samples', help='Sample Path') 24 | parser.add_argument('--save', '-s', type=str, default='./params', help='Folder to save checkpoints.') 25 | parser.add_argument('--load', '-l', type=str, default='./params', help='Checkpoint path') 26 | parser.add_argument('--log', type=str, default='./log', help='Log folder.') 27 | parser.add_argument('--model_name', type=str, default='flowavenet', help='Model Name') 28 | parser.add_argument('--load_step', type=int, default=0, help='Load Step') 29 | parser.add_argument('--epochs', '-e', type=int, default=5000, help='Number of epochs to train.') 30 | parser.add_argument('--batch_size', '-b', type=int, default=2, help='Batch size.') 31 | parser.add_argument('--learning_rate', '-lr', type=float, default=0.001, help='The Learning Rate.') 32 | parser.add_argument('--loss', type=str, default='./loss', help='Folder to save loss') 33 | parser.add_argument('--n_layer', type=int, default=2, help='Number of layers') 34 | parser.add_argument('--n_flow', type=int, default=6, help='Number of layers') 35 | parser.add_argument('--n_block', type=int, default=8, help='Number of 
layers') 36 | parser.add_argument('--cin_channels', type=int, default=80, help='Cin Channels') 37 | parser.add_argument('--block_per_split', type=int, default=4, help='Block per split') 38 | parser.add_argument('--num_workers', type=int, default=2, help='Number of workers') 39 | parser.add_argument('--num_gpu', type=int, default=1, help='Number of GPUs to use. >1 uses DataParallel') 40 | args = parser.parse_args() 41 | 42 | # Init logger 43 | if not os.path.isdir(args.log): 44 | os.makedirs(args.log) 45 | 46 | # Checkpoint dir 47 | if not os.path.isdir(args.save): 48 | os.makedirs(args.save) 49 | if not os.path.isdir(args.loss): 50 | os.makedirs(args.loss) 51 | if not os.path.isdir(args.sample_path): 52 | os.makedirs(args.sample_path) 53 | if not os.path.isdir(os.path.join(args.sample_path, args.model_name)): 54 | os.makedirs(os.path.join(args.sample_path, args.model_name)) 55 | if not os.path.isdir(os.path.join(args.save, args.model_name)): 56 | os.makedirs(os.path.join(args.save, args.model_name)) 57 | 58 | use_cuda = torch.cuda.is_available() 59 | device = torch.device("cuda" if use_cuda else "cpu") 60 | 61 | # LOAD DATASETS 62 | train_dataset = LJspeechDataset(args.data_path, True, 0.1) 63 | test_dataset = LJspeechDataset(args.data_path, False, 0.1) 64 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, 65 | num_workers=args.num_workers, pin_memory=True) 66 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=collate_fn, 67 | num_workers=args.num_workers, pin_memory=True) 68 | synth_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn_synthesize, 69 | num_workers=args.num_workers, pin_memory=True) 70 | 71 | 72 | def build_model(): 73 | pretrained = True if args.load_step > 0 else False 74 | model = Flowavenet(in_channel=1, 75 | cin_channel=args.cin_channels, 76 | n_block=args.n_block, 77 | n_flow=args.n_flow, 78 | n_layer=args.n_layer, 79 | affine=True, 80 | pretrained=pretrained, 81 | block_per_split=args.block_per_split) 82 | return model 83 | 84 | 85 | def train(epoch, model, optimizer, scheduler): 86 | global global_step 87 | epoch_loss = 0.0 88 | running_loss = [0., 0., 0.] 89 | model.train() 90 | display_step = 100 91 | for batch_idx, (x, c) in enumerate(train_loader): 92 | scheduler.step() 93 | global_step += 1 94 | 95 | x, c = x.to(device), c.to(device) 96 | 97 | optimizer.zero_grad() 98 | log_p, logdet = model(x, c) 99 | log_p, logdet = torch.mean(log_p), torch.mean(logdet) 100 | 101 | loss = -(log_p + logdet) 102 | loss.backward() 103 | 104 | nn.utils.clip_grad_norm_(model.parameters(), 1.) 105 | optimizer.step() 106 | 107 | running_loss[0] += loss.item() / display_step 108 | running_loss[1] += log_p.item() / display_step 109 | running_loss[2] += logdet.item() / display_step 110 | 111 | epoch_loss += loss.item() 112 | if (batch_idx + 1) % display_step == 0: 113 | print('Global Step : {}, [{}, {}] [Log pdf, Log p(z), Log Det] : {}' 114 | .format(global_step, epoch, batch_idx + 1, np.array(running_loss))) 115 | running_loss = [0., 0., 0.] 116 | del x, c, log_p, logdet, loss 117 | del running_loss 118 | gc.collect() 119 | print('{} Epoch Training Loss : {:.4f}'.format(epoch, epoch_loss / (len(train_loader)))) 120 | return epoch_loss / len(train_loader) 121 | 122 | 123 | def evaluate(model): 124 | model.eval() 125 | running_loss = [0., 0., 0.] 126 | epoch_loss = 0. 
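# Accumulate the negative log-likelihood -(log p(z) + log|det|) over the held-out split.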
127 | display_step = 100 128 | for batch_idx, (x, c) in enumerate(test_loader): 129 | x, c = x.to(device), c.to(device) 130 | log_p, logdet = model(x, c) 131 | log_p, logdet = torch.mean(log_p), torch.mean(logdet) 132 | loss = -(log_p + logdet) 133 | 134 | running_loss[0] += loss.item() / display_step 135 | running_loss[1] += log_p.item() / display_step 136 | running_loss[2] += logdet.item() / display_step 137 | epoch_loss += loss.item() 138 | 139 | if (batch_idx + 1) % 100 == 0: 140 | print('Global Step : {}, [{}, {}] [Log pdf, Log p(z), Log Det] : {}' 141 | .format(global_step, epoch, batch_idx + 1, np.array(running_loss))) 142 | running_loss = [0., 0., 0.] 143 | del x, c, log_p, logdet, loss 144 | del running_loss 145 | epoch_loss /= len(test_loader) 146 | print('Evaluation Loss : {:.4f}'.format(epoch_loss)) 147 | return epoch_loss 148 | 149 | 150 | def synthesize(model): 151 | global global_step 152 | model.eval() 153 | for batch_idx, (x, c) in enumerate(synth_loader): 154 | if batch_idx == 0: 155 | x, c = x.to(device), c.to(device) 156 | 157 | q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size())) 158 | z = q_0.sample() 159 | 160 | start_time = time.time() 161 | with torch.no_grad(): 162 | if args.num_gpu == 1: 163 | y_gen = model.reverse(z, c).squeeze() 164 | else: 165 | y_gen = model.module.reverse(z, c).squeeze() 166 | wav = y_gen.to(torch.device("cpu")).data.numpy() 167 | wav_name = '{}/{}/generate_{}_{}.wav'.format(args.sample_path, args.model_name, global_step, batch_idx) 168 | print('{} seconds'.format(time.time() - start_time)) 169 | librosa.output.write_wav(wav_name, wav, sr=22050) 170 | print('{} Saved!'.format(wav_name)) 171 | del x, c, z, q_0, y_gen, wav 172 | 173 | 174 | def save_checkpoint(model, optimizer, scheduler, global_step, global_epoch): 175 | checkpoint_path = os.path.join(args.save, args.model_name, "checkpoint_step{:09d}.pth".format(global_step)) 176 | optimizer_state = optimizer.state_dict() 177 | scheduler_state = scheduler.state_dict() 178 | torch.save({"state_dict": model.state_dict(), 179 | "optimizer": optimizer_state, 180 | "scheduler": scheduler_state, 181 | "global_step": global_step, 182 | "global_epoch": global_epoch}, checkpoint_path) 183 | 184 | 185 | def load_checkpoint(step, model, optimizer, scheduler): 186 | global global_step 187 | global global_epoch 188 | 189 | checkpoint_path = os.path.join(args.save, args.model_name, "checkpoint_step{:09d}.pth".format(step)) 190 | print("Load checkpoint from: {}".format(checkpoint_path)) 191 | checkpoint = torch.load(checkpoint_path) 192 | 193 | # generalized load procedure for both single-gpu and DataParallel models 194 | # https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3 195 | try: 196 | model.load_state_dict(checkpoint["state_dict"]) 197 | except RuntimeError: 198 | print("INFO: this model is trained with DataParallel. 
Creating new state_dict without module...") 199 | state_dict = checkpoint["state_dict"] 200 | from collections import OrderedDict 201 | new_state_dict = OrderedDict() 202 | for k, v in state_dict.items(): 203 | name = k[7:] # remove `module.` 204 | new_state_dict[name] = v 205 | model.load_state_dict(new_state_dict) 206 | 207 | optimizer.load_state_dict(checkpoint["optimizer"]) 208 | scheduler.load_state_dict(checkpoint["scheduler"]) 209 | global_step = checkpoint["global_step"] 210 | global_epoch = checkpoint["global_epoch"] 211 | 212 | return model, optimizer, scheduler 213 | 214 | 215 | if __name__ == "__main__": 216 | model = build_model() 217 | model.to(device) 218 | 219 | pretrained = True if args.load_step > 0 else False 220 | if pretrained is False: 221 | # do ActNorm initialization first (if model.pretrained is True, this does nothing so no worries) 222 | x_seed, c_seed = next(iter(train_loader)) 223 | x_seed, c_seed = x_seed.to(device), c_seed.to(device) 224 | with torch.no_grad(): 225 | _, _ = model(x_seed, c_seed) 226 | del x_seed, c_seed, _ 227 | # then convert the model to DataParallel later (since ActNorm init from the DataParallel is wacky) 228 | 229 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) 230 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200000, gamma=0.5) 231 | criterion_frame = nn.MSELoss() 232 | 233 | global_step = 0 234 | global_epoch = 0 235 | load_step = args.load_step 236 | 237 | log = open(os.path.join(args.log, '{}.txt'.format(args.model_name)), 'w') 238 | state = {k: v for k, v in args._get_kwargs()} 239 | 240 | if load_step == 0: 241 | list_train_loss, list_loss = [], [] 242 | log.write(json.dumps(state) + '\n') 243 | test_loss = 100.0 244 | else: 245 | model, optimizer, scheduler = load_checkpoint(load_step, model, optimizer, scheduler) 246 | list_train_loss = np.load('{}/{}_train.npy'.format(args.loss, args.model_name)).tolist() 247 | list_loss = np.load('{}/{}.npy'.format(args.loss, args.model_name)).tolist() 248 | list_train_loss = list_train_loss[:global_epoch] 249 | list_loss = list_loss[:global_epoch] 250 | test_loss = np.min(list_loss) 251 | 252 | if args.num_gpu > 1: 253 | print("num_gpu > 1 detected. converting the model to DataParallel...") 254 | model = torch.nn.DataParallel(model) 255 | 256 | for epoch in range(global_epoch + 1, args.epochs + 1): 257 | training_epoch_loss = train(epoch, model, optimizer, scheduler) 258 | with torch.no_grad(): 259 | test_epoch_loss = evaluate(model) 260 | 261 | state['training_loss'] = training_epoch_loss 262 | state['eval_loss'] = test_epoch_loss 263 | state['epoch'] = epoch 264 | list_train_loss.append(training_epoch_loss) 265 | list_loss.append(test_epoch_loss) 266 | 267 | if test_loss > test_epoch_loss: 268 | test_loss = test_epoch_loss 269 | save_checkpoint(model, optimizer, scheduler, global_step, epoch) 270 | print('Epoch {} Model Saved! 
Loss : {:.4f}'.format(epoch, test_loss)) 271 | with torch.no_grad(): 272 | synthesize(model) 273 | np.save('{}/{}_train.npy'.format(args.loss, args.model_name), list_train_loss) 274 | np.save('{}/{}.npy'.format(args.loss, args.model_name), list_loss) 275 | 276 | log.write('%s\n' % json.dumps(state)) 277 | log.flush() 278 | print(state) 279 | gc.collect() 280 | 281 | log.close() 282 | -------------------------------------------------------------------------------- /train_apex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | from data import LJspeechDataset, collate_fn, collate_fn_synthesize 6 | from model import Flowavenet 7 | from torch.distributions.normal import Normal 8 | import numpy as np 9 | import librosa 10 | import argparse 11 | import time 12 | import json 13 | import gc 14 | import os 15 | from tqdm import tqdm 16 | from apex import amp 17 | from apex.parallel import DistributedDataParallel 18 | 19 | # Distributed Training implemented with Apex utilities https://github.com/NVIDIA/apex, 20 | # which handle some issues with specific nodes in the FloWaveNet architecture. 21 | 22 | # List of changes made in train.py: 23 | # 1. Determine local_rank and world_size for torch.distributed.init_process_group 24 | # 2. Set a current device with torch.cuda.set_device 25 | # 3. Wrap dataset with torch.utils.data.distributed.DistributedSampler 26 | # 4. Apply amp.scale_loss at each backward pass 27 | # 5. Clip gradient with amp.master_params 28 | # 6. Divide step_size by world_size (not sure if this is necessary) 29 | # 7. Initialize model and optimizer with amp.initialize 30 | # 8. Wrap model with apex.parallel.DistributedDataParallel 31 | # 9. 
Handle evaluation and messages on the first node using args.local_rank 32 | 33 | # For example, to run on 4 GPUs, use the following command: 34 | # python -m torch.distributed.launch --nproc_per_node=4 train_apex.py --num_workers 2 --epochs 1000 35 | 36 | torch.backends.cudnn.benchmark = True 37 | np.set_printoptions(precision=4) 38 | torch.manual_seed(1111) 39 | 40 | parser = argparse.ArgumentParser(description='Train FloWaveNet of LJSpeech on multiple GPUs with Apex', 41 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 42 | parser.add_argument('--local_rank', default=0, type=int) 43 | parser.add_argument('--data_path', type=str, default='./DATASETS/ljspeech/', help='Dataset Path') 44 | parser.add_argument('--sample_path', type=str, default='./samples', help='Sample Path') 45 | parser.add_argument('--save', '-s', type=str, default='./params', help='Folder to save checkpoints.') 46 | parser.add_argument('--load_step', '-l', type=int, default=0, help='Load Step') 47 | parser.add_argument('--log', type=str, default='./log', help='Log folder.') 48 | parser.add_argument('--model_name', type=str, default='flowavenet', help='Model Name') 49 | parser.add_argument('--epochs', '-e', type=int, default=5000, help='Number of epochs to train.') 50 | parser.add_argument('--batch_size', '-b', type=int, default=2, help='Batch size.') 51 | parser.add_argument('--learning_rate', '-lr', type=float, default=0.001, help='The Learning Rate.') 52 | parser.add_argument('--loss', type=str, default='./loss', help='Folder to save loss') 53 | parser.add_argument('--n_layer', type=int, default=2, help='Number of layers') 54 | parser.add_argument('--n_flow', type=int, default=6, help='Number of layers') 55 | parser.add_argument('--n_block', type=int, default=8, help='Number of layers') 56 | parser.add_argument('--cin_channels', type=int, default=80, help='Cin Channels') 57 | parser.add_argument('--block_per_split', type=int, default=4, help='Block per split') 58 | parser.add_argument('--num_workers', type=int, default=2, help='Number of workers') 59 | args = parser.parse_args() 60 | 61 | current_env = os.environ.copy() 62 | world_size = int(current_env['WORLD_SIZE']) 63 | 64 | torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=args.local_rank) 65 | torch.cuda.set_device(args.local_rank) 66 | 67 | if args.local_rank == 0: 68 | # Init logger 69 | if not os.path.isdir(args.log): 70 | os.makedirs(args.log) 71 | # Checkpoint dir 72 | if not os.path.isdir(args.save): 73 | os.makedirs(args.save) 74 | if not os.path.isdir(args.loss): 75 | os.makedirs(args.loss) 76 | if not os.path.isdir(args.sample_path): 77 | os.makedirs(args.sample_path) 78 | if not os.path.isdir(os.path.join(args.sample_path, args.model_name)): 79 | os.makedirs(os.path.join(args.sample_path, args.model_name)) 80 | if not os.path.isdir(os.path.join(args.save, args.model_name)): 81 | os.makedirs(os.path.join(args.save, args.model_name)) 82 | 83 | use_cuda = torch.cuda.is_available() 84 | device = torch.device("cuda" if use_cuda else "cpu") 85 | 86 | # LOAD DATASETS 87 | train_dataset = LJspeechDataset(args.data_path, True, 0.1) 88 | test_dataset = LJspeechDataset(args.data_path, False, 0.1) 89 | 90 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 91 | 92 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, drop_last=True, collate_fn=collate_fn, 93 | num_workers=args.num_workers, pin_memory=True) 94 | test_loader = DataLoader(test_dataset, 
batch_size=args.batch_size, collate_fn=collate_fn, 95 | num_workers=args.num_workers, pin_memory=True) 96 | synth_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn_synthesize, 97 | num_workers=args.num_workers, pin_memory=True) 98 | 99 | 100 | def build_model(): 101 | pretrained = True if args.load_step > 0 else False 102 | model = Flowavenet(in_channel=1, 103 | cin_channel=args.cin_channels, 104 | n_block=args.n_block, 105 | n_flow=args.n_flow, 106 | n_layer=args.n_layer, 107 | affine=True, 108 | pretrained=pretrained, 109 | block_per_split=args.block_per_split) 110 | return model 111 | 112 | 113 | def train(epoch, model, optimizer, scheduler): 114 | global global_step 115 | 116 | epoch_loss = 0.0 117 | running_num = 0 118 | running_loss = np.zeros(3) 119 | 120 | train_sampler.set_epoch(epoch) 121 | model.train() 122 | 123 | bar = tqdm(train_loader) if args.local_rank == 0 else train_loader 124 | 125 | for batch_idx, (x, c) in enumerate(bar): 126 | 127 | scheduler.step() 128 | global_step += 1 129 | 130 | x, c = x.to(device, non_blocking=True), c.to(device, non_blocking=True) 131 | 132 | optimizer.zero_grad() 133 | 134 | log_p, logdet = model(x, c) 135 | log_p, logdet = torch.mean(log_p), torch.mean(logdet) 136 | 137 | loss = -(log_p + logdet) 138 | 139 | with amp.scale_loss(loss, optimizer) as scaled_loss: 140 | scaled_loss.backward() 141 | 142 | nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.) 143 | 144 | optimizer.step() 145 | 146 | running_num += 1 147 | running_loss[0] += loss.item() 148 | running_loss[1] += log_p.item() 149 | running_loss[2] += logdet.item() 150 | 151 | epoch_loss += loss.item() 152 | 153 | if args.local_rank == 0: 154 | bar.set_description('{}/{}, [Log pdf, Log p(z), Log Det] : {}' 155 | .format(epoch, global_step, running_loss / running_num)) 156 | if (batch_idx + 1) % 100 == 0: 157 | running_num = 0 158 | running_loss = np.zeros(3) 159 | 160 | del x, c, log_p, logdet, loss 161 | del running_loss 162 | gc.collect() 163 | print('{}/{}/{} Training Loss : {:.4f}'.format(epoch, global_step, args.local_rank, epoch_loss / (len(train_loader)))) 164 | return epoch_loss / len(train_loader) 165 | 166 | 167 | def evaluate(model): 168 | model.eval() 169 | running_loss = [0., 0., 0.] 170 | epoch_loss = 0. 171 | display_step = 100 172 | for batch_idx, (x, c) in enumerate(test_loader): 173 | x, c = x.to(device), c.to(device) 174 | log_p, logdet = model(x, c) 175 | log_p, logdet = torch.mean(log_p), torch.mean(logdet) 176 | loss = -(log_p + logdet) 177 | 178 | running_loss[0] += loss.item() / display_step 179 | running_loss[1] += log_p.item() / display_step 180 | running_loss[2] += logdet.item() / display_step 181 | epoch_loss += loss.item() 182 | 183 | if (batch_idx + 1) % 100 == 0: 184 | print('Global Step : {}, [{}, {}] [Log pdf, Log p(z), Log Det] : {}' 185 | .format(global_step, epoch, batch_idx + 1, np.array(running_loss))) 186 | running_loss = [0., 0., 0.] 
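# Free per-batch tensors eagerly to keep GPU memory bounded during evaluation.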
187 | del x, c, log_p, logdet, loss 188 | del running_loss 189 | epoch_loss /= len(test_loader) 190 | print('Evaluation Loss : {:.4f}'.format(epoch_loss)) 191 | return epoch_loss 192 | 193 | 194 | def synthesize(model): 195 | global global_step 196 | model.eval() 197 | for batch_idx, (x, c) in enumerate(synth_loader): 198 | if batch_idx == 0: 199 | x, c = x.to(device), c.to(device) 200 | 201 | q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size())) 202 | z = q_0.sample() 203 | 204 | start_time = time.time() 205 | with torch.no_grad(): 206 | y_gen = model.module.reverse(z, c).squeeze() 207 | wav = y_gen.to(torch.device("cpu")).data.numpy() 208 | wav_name = '{}/{}/generate_{}_{}.wav'.format(args.sample_path, args.model_name, global_step, batch_idx) 209 | print('{} seconds'.format(time.time() - start_time)) 210 | librosa.output.write_wav(wav_name, wav, sr=22050) 211 | print('{} Saved!'.format(wav_name)) 212 | del x, c, z, q_0, y_gen, wav 213 | 214 | 215 | def save_checkpoint(model, optimizer, scheduler, global_step, global_epoch): 216 | checkpoint_path = os.path.join(args.save, args.model_name, "checkpoint_step{:09d}.pth".format(global_step)) 217 | optimizer_state = optimizer.state_dict() 218 | scheduler_state = scheduler.state_dict() 219 | torch.save({"state_dict": model.state_dict(), 220 | "optimizer": optimizer_state, 221 | "scheduler": scheduler_state, 222 | "global_step": global_step, 223 | "global_epoch": global_epoch}, checkpoint_path) 224 | 225 | 226 | def load_checkpoint(step, model, optimizer, scheduler): 227 | global global_step 228 | global global_epoch 229 | 230 | checkpoint_path = os.path.join(args.save, args.model_name, "checkpoint_step{:09d}.pth".format(step)) 231 | print("Rank {} load checkpoint from: {}".format(args.local_rank, checkpoint_path)) 232 | checkpoint = torch.load(checkpoint_path) 233 | 234 | # generalized load procedure for both single-gpu and DataParallel models 235 | # https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3 236 | try: 237 | model.load_state_dict(checkpoint["state_dict"]) 238 | except RuntimeError: 239 | print("INFO: this model is trained with DataParallel. 
Creating new state_dict without module...") 240 | state_dict = checkpoint["state_dict"] 241 | from collections import OrderedDict 242 | new_state_dict = OrderedDict() 243 | for k, v in state_dict.items(): 244 | name = k[7:] # remove `module.` 245 | new_state_dict[name] = v 246 | model.load_state_dict(new_state_dict) 247 | 248 | optimizer.load_state_dict(checkpoint["optimizer"]) 249 | scheduler.load_state_dict(checkpoint["scheduler"]) 250 | global_step = checkpoint["global_step"] 251 | global_epoch = checkpoint["global_epoch"] 252 | 253 | return model, optimizer, scheduler 254 | 255 | 256 | if __name__ == "__main__": 257 | 258 | model = build_model() 259 | model.to(device) 260 | 261 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) 262 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200000 // world_size, gamma=0.5) 263 | 264 | pretrained = True if args.load_step > 0 else False 265 | if pretrained is False: 266 | # do ActNorm initialization first (if model.pretrained is True, this does nothing so no worries) 267 | x_seed, c_seed = next(iter(train_loader)) 268 | x_seed, c_seed = x_seed.to(device), c_seed.to(device) 269 | with torch.no_grad(): 270 | _, _ = model(x_seed, c_seed) 271 | del x_seed, c_seed, _ 272 | # then convert the model to DataParallel later (since ActNorm init from the DataParallel is wacky) 273 | 274 | model, optimizer = amp.initialize(model, optimizer, opt_level="O0") 275 | model = DistributedDataParallel(model) 276 | 277 | global_step = 0 278 | global_epoch = 0 279 | 280 | if args.load_step == 0: 281 | list_train_loss, list_loss = [], [] 282 | test_loss = 100.0 283 | else: 284 | model, optimizer, scheduler = load_checkpoint(args.load_step, model, optimizer, scheduler) 285 | list_train_loss = np.load('{}/{}_train.npy'.format(args.loss, args.model_name)).tolist() 286 | list_loss = np.load('{}/{}.npy'.format(args.loss, args.model_name)).tolist() 287 | list_train_loss = list_train_loss[:global_epoch] 288 | list_loss = list_loss[:global_epoch] 289 | test_loss = np.min(list_loss) 290 | 291 | for epoch in range(global_epoch + 1, args.epochs + 1): 292 | 293 | training_epoch_loss = train(epoch, model, optimizer, scheduler) 294 | 295 | if args.local_rank > 0: 296 | gc.collect() 297 | continue 298 | 299 | with torch.no_grad(): 300 | test_epoch_loss = evaluate(model) 301 | 302 | if test_loss > test_epoch_loss: 303 | test_loss = test_epoch_loss 304 | save_checkpoint(model, optimizer, scheduler, global_step, epoch) 305 | print('Epoch {} Model Saved! Loss : {:.4f}'.format(epoch, test_loss)) 306 | with torch.no_grad(): 307 | synthesize(model) 308 | 309 | list_train_loss.append(training_epoch_loss) 310 | list_loss.append(test_epoch_loss) 311 | 312 | np.save('{}/{}_train.npy'.format(args.loss, args.model_name), list_train_loss) 313 | np.save('{}/{}.npy'.format(args.loss, args.model_name), list_loss) 314 | 315 | state = {k: v for k, v in args._get_kwargs()} 316 | state['training_loss'] = training_epoch_loss 317 | state['eval_loss'] = test_epoch_loss 318 | state['epoch'] = epoch 319 | 320 | with open(os.path.join(args.log, '%s.txt' % args.model_name), 'a') as log: 321 | log.write('%s\n' % json.dumps(state)) 322 | 323 | gc.collect() 324 | --------------------------------------------------------------------------------
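Below is a minimal end-to-end sketch (not part of the repository) of the temperature-controlled sampling described in Step 4 of the README: draw z ~ N(0, temp^2 * I) and run the inverse flow conditioned on a mel spectrogram. It assumes a mel file produced by preprocessing.py and a checkpoint saved by train.py already exist; the file names and paths are illustrative only.

import numpy as np
import torch
from model import Flowavenet

hop_length = 256   # must match preprocessing.py
temp = 0.8         # sampling temperature, z ~ N(0, temp^2)

# Illustrative paths; substitute a real preprocessed mel file and a real checkpoint.
mel = np.load('DATASETS/ljspeech/ljspeech-mel-00001.npy')   # (frames, 80), values in [0, 1]
c = torch.tensor(mel.T, dtype=torch.float32).unsqueeze(0)   # (1, 80, frames)

model = Flowavenet(in_channel=1, cin_channel=80, n_block=8, n_flow=6, n_layer=2,
                   affine=True, pretrained=True, block_per_split=4)
state = torch.load('params/flowavenet/checkpoint_step000100000.pth', map_location='cpu')['state_dict']
# If the checkpoint was saved from a DataParallel model, strip the 'module.' prefix first.
model.load_state_dict({k.replace('module.', '', 1): v for k, v in state.items()})
model.eval()

z = torch.randn(1, 1, c.size(2) * hop_length) * temp        # hop_length waveform samples per mel frame
with torch.no_grad():
    wav = model.reverse(z, c).squeeze().cpu().numpy()        # inverse flow: z -> waveform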