├── README.md
├── configs
│   ├── audio_config.yaml
│   └── training_config.yaml
├── data
│   ├── __init__.py
│   ├── utils.py
│   └── vocoder_dataset.py
├── models
│   ├── __init__.py
│   └── universal_vocoder.py
├── preprocess.py
├── reconstruct.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
# Universal Vocoder

This is a restructured and rewritten version of [bshall/UniversalVocoding](https://github.com/bshall/UniversalVocoding).
The main difference is that the model is compiled into a [TorchScript](https://pytorch.org/docs/stable/jit.html) module during training, so it can be loaded for inference anywhere without Python dependencies.

## Generate waveforms using pretrained models

Since the pretrained models are exported as TorchScript, you can load a trained model anywhere.
You can also generate multiple waveforms in parallel, e.g.

```python
import torch

vocoder = torch.jit.load("vocoder.pt")

mels = [
    torch.randn(100, 80),
    torch.randn(200, 80),
    torch.randn(300, 80),
]  # (length, mel_dim)

with torch.no_grad():
    wavs = vocoder.generate(mels)
```

Empirically, with the default architecture you can generate 30 samples at the same time on a GTX 1080 Ti.

## Train from scratch

Multiple directories containing audio files can be processed at the same time, e.g.

```bash
python preprocess.py \
    VCTK-Corpus \
    LibriTTS/train-clean-100 \
    preprocessed  # the output directory of preprocessed data
```

Then train the model with the preprocessed data, e.g.

```bash
python train.py preprocessed
```

With the default settings, training to 100K steps takes around 12 hours on an RTX 2080 Ti.
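
## Reconstruct waveforms from audio files

You can also run copy synthesis through a trained model with `reconstruct.py`, which extracts log Mel spectrograms from existing audio files and resynthesizes them. For example (the checkpoint and audio file names below are placeholders, and the output directory must already exist), reconstructed files are written with a `.rec.wav` suffix:

```bash
python reconstruct.py \
    vocoder.pt \
    sample_1.wav sample_2.wav \
    -o outputs \
    --audio_config configs/audio_config.yaml
```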

## References

- [Towards achieving robust universal neural vocoding](https://arxiv.org/abs/1811.06292)

--------------------------------------------------------------------------------
/configs/audio_config.yaml:
--------------------------------------------------------------------------------
sample_rate: 16000
preemph: 0.97
hop_len: 200
win_len: 800
n_fft: 2048
n_mels: 80
f_min: 50
--------------------------------------------------------------------------------
/configs/training_config.yaml:
--------------------------------------------------------------------------------
frames_per_sample: 40
frames_per_slice: 8
bits: 9
conditioning_channels: 128
embedding_dim: 256
rnn_channels: 896
fc_channels: 512
batch_size: 32
n_steps: 100000
valid_every: 1000
valid_ratio: 0.05
save_every: 10000
learning_rate: 0.0005
decay_every: 20000
decay_gamma: 0.5
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
from .utils import *
from .vocoder_dataset import VocoderDataset
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
"""Utilities for data manipulation."""

from typing import Union
from pathlib import Path

import librosa
import numpy as np
from scipy.signal import lfilter


def load_wav(audio_path: Union[str, Path], sample_rate: int) -> np.ndarray:
    """Load and preprocess waveform."""
    wav = librosa.load(audio_path, sr=sample_rate)[0]
    wav = wav / (np.abs(wav).max() + 1e-6)
    return wav


def mulaw_encode(x: np.ndarray, n_channels: int) -> np.ndarray:
    """Encode signal based on mu-law companding."""
    assert x.max() < 1.0 and x.min() > -1.0
    mu = n_channels - 1
    x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    x_mu = np.floor((x_mu + 1) / 2 * mu + 0.5).astype(np.int64)
    return x_mu


def mulaw_decode(x_mu: np.ndarray, n_channels: int) -> np.ndarray:
    """Decode mu-law encoded signal."""
    mu = n_channels - 1
    x = np.sign(x_mu) / mu * ((1 + mu) ** np.abs(x_mu) - 1)
    return x


def log_mel_spectrogram(
    x: np.ndarray,
    preemph: float,
    sample_rate: int,
    n_mels: int,
    n_fft: int,
    hop_length: int,
    win_length: int,
    f_min: int,
) -> np.ndarray:
    """Create a log Mel spectrogram from a raw audio signal."""
    x = lfilter([1, -preemph], [1], x)
    magnitude = np.abs(
        librosa.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    )
    mel_fb = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=f_min)
    mel_spec = np.dot(mel_fb, magnitude)
    log_mel_spec = np.log(mel_spec + 1e-9)
    return log_mel_spec.T
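
As a quick illustration of how these utilities fit together (this sketch is not part of the original files), the snippet below extracts features and round-trips a waveform through mu-law companding, using the default values from `configs/audio_config.yaml` and `bits: 9` from the training config; `example.wav` is a placeholder file name. Note that `mulaw_decode` expects the companded signal rescaled to [-1, 1], matching the expansion step in `UniversalVocoder.generate`.

```python
import numpy as np

from data import load_wav, log_mel_spectrogram, mulaw_encode, mulaw_decode

wav = load_wav("example.wav", sample_rate=16000)  # peak-normalized float waveform
mel = log_mel_spectrogram(
    wav, preemph=0.97, sample_rate=16000, n_mels=80,
    n_fft=2048, hop_length=200, win_length=800, f_min=50,
)  # shape: (n_frames, 80)

# 9-bit mu-law companding maps samples in (-1, 1) to integers in [0, 511].
quantized = mulaw_encode(wav, n_channels=2 ** 9)

# Rescale codes back to [-1, 1] before expansion.
restored = mulaw_decode(2 * quantized / (2 ** 9 - 1) - 1, n_channels=2 ** 9)

print(mel.shape, quantized.min(), quantized.max(), np.abs(wav - restored).max())
```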

--------------------------------------------------------------------------------
/data/vocoder_dataset.py:
--------------------------------------------------------------------------------
"""Vocoder dataset."""

import json
from random import randint
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

from .utils import mulaw_encode


class VocoderDataset(Dataset):
    """Sample a segment of an utterance for training the vocoder."""

    def __init__(
        self, data_dir, metadata_path, frames_per_sample, frames_per_slice, bits
    ):
        with open(metadata_path, "r") as f:
            metadata = json.load(f)

        self.data_dir = Path(data_dir)
        self.sample_rate = metadata["sample_rate"]
        self.hop_len = metadata["hop_len"]
        self.n_mels = metadata["n_mels"]
        self.n_pad = (frames_per_sample - frames_per_slice) // 2
        self.frames_per_sample = frames_per_sample
        self.frames_per_slice = frames_per_slice
        self.bits = bits
        self.uttr_infos = [
            uttr_info
            for uttr_info in metadata["utterances"]
            if uttr_info["mel_len"] > frames_per_sample
        ]

    def __len__(self):
        return len(self.uttr_infos)

    def __getitem__(self, index):
        uttr_info = self.uttr_infos[index]
        features = np.load(self.data_dir / uttr_info["feature_path"])
        wav = features["wav"]
        mel = features["mel"]

        wav = np.pad(wav, (0, (len(mel) * self.hop_len - len(wav))), "constant")
        mel = np.pad(mel, ((self.n_pad,), (0,)), "constant")
        wav = np.pad(wav, (self.n_pad * self.hop_len,), "constant")
        wav = mulaw_encode(wav, 2 ** self.bits)

        pos = randint(0, len(mel) - self.frames_per_sample)
        mel_seg = mel[pos : pos + self.frames_per_sample, :]

        pos1 = pos + self.n_pad
        pos2 = pos1 + self.frames_per_slice
        wav_seg = wav[pos1 * self.hop_len : pos2 * self.hop_len + 1]

        return torch.FloatTensor(mel_seg), torch.LongTensor(wav_seg)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .universal_vocoder import *
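
For reference, a minimal sketch (not part of the original code) of how `VocoderDataset` is consumed during training, mirroring `train.py` and assuming the default hyperparameters and a `preprocessed` directory produced by `preprocess.py`:

```python
from pathlib import Path

from torch.utils.data import DataLoader

from data import VocoderDataset

dataset = VocoderDataset(
    "preprocessed",
    Path("preprocessed") / "metadata.json",
    frames_per_sample=40,  # mel frames per training sample (conditioning context)
    frames_per_slice=8,    # mel frames actually predicted per sample
    bits=9,                # mu-law quantization bits
)
loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

mels, wavs = next(iter(loader))
# mels: (32, 40, 80) float mel segments
# wavs: (32, 8 * 200 + 1) mu-law codes, i.e. frames_per_slice * hop_len + 1 samples
print(mels.shape, wavs.shape)
```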
--------------------------------------------------------------------------------
/models/universal_vocoder.py:
--------------------------------------------------------------------------------
"""Universal vocoder"""

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence


class UniversalVocoder(nn.Module):
    """Universal vocoding"""

    def __init__(
        self,
        sample_rate,
        frames_per_sample,
        frames_per_slice,
        mel_dim,
        mel_rnn_dim,
        emb_dim,
        wav_rnn_dim,
        affine_dim,
        bits,
        hop_length,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.frames_per_slice = frames_per_slice
        self.pad = (frames_per_sample - frames_per_slice) // 2
        self.wav_rnn_dim = wav_rnn_dim
        self.quant_dim = 2 ** bits
        self.hop_len = hop_length

        self.mel_rnn = nn.GRU(
            mel_dim, mel_rnn_dim, num_layers=2, batch_first=True, bidirectional=True
        )
        self.embedding = nn.Embedding(self.quant_dim, emb_dim)
        self.wav_rnn = nn.GRU(emb_dim + 2 * mel_rnn_dim, wav_rnn_dim, batch_first=True)
        self.affine = nn.Sequential(
            nn.Linear(wav_rnn_dim, affine_dim),
            nn.ReLU(),
            nn.Linear(affine_dim, self.quant_dim),
        )

    def forward(self, wavs, mels):
        """Generate waveform from mel spectrogram with teacher-forcing."""
        mel_embs, _ = self.mel_rnn(mels)
        mel_embs = mel_embs.transpose(1, 2)
        mel_embs = mel_embs[:, :, self.pad : self.pad + self.frames_per_slice]

        conditions = F.interpolate(mel_embs, scale_factor=float(self.hop_len))
        conditions = conditions.transpose(1, 2)

        wav_embs = self.embedding(wavs)
        wav_outs, _ = self.wav_rnn(torch.cat((wav_embs, conditions), dim=2))

        return self.affine(wav_outs)

    @torch.jit.export
    def generate(self, mels: List[Tensor]) -> List[Tensor]:
        """Generate waveform from mel spectrogram.

        Args:
            mels: list of tensor of shape (mel_len, mel_dim)

        Returns:
            wavs: list of tensor of shape (wav_len)
        """

        batch_size = len(mels)
        device = mels[0].device

        mel_lens = [len(mel) for mel in mels]
        wav_lens = [mel_len * self.hop_len for mel_len in mel_lens]
        max_mel_len = max(mel_lens)
        max_wav_len = max_mel_len * self.hop_len

        pad_mels = pad_sequence(mels, batch_first=True)
        pack_mels = pack_padded_sequence(
            pad_mels, torch.tensor(mel_lens), batch_first=True, enforce_sorted=False
        )
        pack_mel_embs, _ = self.mel_rnn(pack_mels)
        mel_embs, _ = pad_packed_sequence(
            pack_mel_embs, batch_first=True
        )  # (batch, max_mel_len, emb_dim)

        mel_embs = mel_embs.transpose(1, 2)
        conditions = F.interpolate(mel_embs, scale_factor=float(self.hop_len))
        conditions = conditions.transpose(1, 2)  # (batch, max_wav_len, emb_dim)

        hid = torch.zeros(1, batch_size, self.wav_rnn_dim, device=device)
        wav = torch.full(
            (batch_size,), self.quant_dim // 2, dtype=torch.long, device=device,
        )
        wavs = torch.empty(batch_size, max_wav_len, dtype=torch.float, device=device,)

        for i, condition in enumerate(torch.unbind(conditions, dim=1)):
            wav_emb = self.embedding(wav)
            wav_rnn_input = torch.cat((wav_emb, condition), dim=1).unsqueeze(1)
            _, hid = self.wav_rnn(wav_rnn_input, hid)
            logit = self.affine(hid.squeeze(0))
            posterior = F.softmax(logit, dim=1)
            wav = torch.multinomial(posterior, 1).squeeze(1)
            wavs[:, i] = 2 * wav / (self.quant_dim - 1.0) - 1.0

        mu = self.quant_dim - 1
        wavs = torch.true_divide(torch.sign(wavs), mu) * (
            (1 + mu) ** torch.abs(wavs) - 1
        )
        wavs = [
            wav[:length] for wav, length in zip(torch.unbind(wavs, dim=0), wav_lens)
        ]

        return wavs
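
Below is a rough sketch (not part of the original code) of how the model is constructed, driven with teacher forcing, and exported, mirroring `train.py` with the default hyperparameters from the config files:

```python
import torch
import torch.nn.functional as F

from models import UniversalVocoder

model = UniversalVocoder(
    sample_rate=16000,
    frames_per_sample=40,
    frames_per_slice=8,
    mel_dim=80,        # n_mels
    mel_rnn_dim=128,   # conditioning_channels
    emb_dim=256,       # embedding_dim
    wav_rnn_dim=896,   # rnn_channels
    affine_dim=512,    # fc_channels
    bits=9,
    hop_length=200,
)

# Teacher-forced step: inputs are the wav codes shifted by one against the targets.
mels = torch.randn(4, 40, 80)                        # (batch, frames_per_sample, mel_dim)
wavs = torch.randint(0, 2 ** 9, (4, 8 * 200 + 1))    # mu-law codes, (batch, slice_samples + 1)
logits = model(wavs[:, :-1], mels)                   # (batch, slice_samples, 512)
loss = F.cross_entropy(logits.transpose(1, 2), wavs[:, 1:])

# TorchScript export, as done during training, so it can be loaded without Python.
scripted = torch.jit.script(model)
torch.jit.save(scripted, "vocoder.pt")
```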
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""Preprocess script"""

import os
import json
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from itertools import chain
from pathlib import Path
from tempfile import mkstemp

import numpy as np
from jsonargparse import ArgumentParser, ActionConfigFile
from librosa.util import find_files
from tqdm import tqdm

from data import load_wav, log_mel_spectrogram


def parse_args():
    """Parse command-line arguments."""
    parser = ArgumentParser()
    parser.add_argument("data_dirs", type=str, nargs="+")
    parser.add_argument("out_dir", type=str)
    parser.add_argument("-w", "--n_workers", type=int, default=cpu_count())

    parser.add_argument("--sample_rate", type=int, default=16000)
    parser.add_argument("--preemph", type=float, default=0.97)
    parser.add_argument("--hop_len", type=int, default=200)
    parser.add_argument("--win_len", type=int, default=800)
    parser.add_argument("--n_fft", type=int, default=2048)
    parser.add_argument("--n_mels", type=int, default=80)
    parser.add_argument("--f_min", type=int, default=50)
    parser.add_argument("--audio_config", action=ActionConfigFile)

    args = vars(parser.parse_args())

    return args


def load_process_save(
    audio_path, save_dir, sample_rate, preemph, hop_len, win_len, n_fft, n_mels, f_min,
):
    """Load an audio file, process, and save npz object."""

    wav = load_wav(audio_path, sample_rate)
    mel = log_mel_spectrogram(
        wav, preemph, sample_rate, n_mels, n_fft, hop_len, win_len, f_min
    )

    fd, temp_file_path = mkstemp(suffix=".npz", prefix="utterance-", dir=save_dir)
    np.savez_compressed(temp_file_path, wav=wav, mel=mel)
    os.close(fd)

    return {
        "feature_path": Path(temp_file_path).name,
        "audio_path": audio_path,
        "wav_len": len(wav),
        "mel_len": len(mel),
    }


def main(
    data_dirs,
    out_dir,
    n_workers,
    sample_rate,
    preemph,
    hop_len,
    win_len,
    n_fft,
    n_mels,
    f_min,
    **kwargs,
):
    """Preprocess audio files into features for training."""

    audio_paths = chain.from_iterable([find_files(data_dir) for data_dir in data_dirs])

    save_dir = Path(out_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    executor = ProcessPoolExecutor(max_workers=n_workers)

    futures = []
    for audio_path in audio_paths:
        futures.append(
            executor.submit(
                load_process_save,
                audio_path,
                save_dir,
                sample_rate,
                preemph,
                hop_len,
                win_len,
                n_fft,
                n_mels,
                f_min,
            )
        )

    infos = {
        "sample_rate": sample_rate,
        "preemph": preemph,
        "hop_len": hop_len,
        "win_len": win_len,
        "n_fft": n_fft,
        "n_mels": n_mels,
        "f_min": f_min,
        "utterances": [future.result() for future in tqdm(futures, ncols=0)],
    }

    with open(save_dir / "metadata.json", "w") as f:
        json.dump(infos, f, indent=2)


if __name__ == "__main__":
    main(**parse_args())
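
For orientation (this note is not part of the original script), the resulting `metadata.json` stores the audio settings plus one entry per utterance, which is exactly what `VocoderDataset` reads back; the sketch below assumes a `preprocessed` output directory and a `mkstemp`-generated feature file name:

```python
import json
from pathlib import Path

with open(Path("preprocessed") / "metadata.json", "r") as f:
    metadata = json.load(f)

# Top-level keys mirror the audio settings; "utterances" lists one dict per file:
# {"feature_path": "utterance-....npz", "audio_path": "...", "wav_len": ..., "mel_len": ...}
print(metadata["sample_rate"], metadata["hop_len"], metadata["n_mels"])
print(len(metadata["utterances"]), metadata["utterances"][0]["feature_path"])
```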
--------------------------------------------------------------------------------
/reconstruct.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""Reconstruct waveform from log mel spectrogram."""

from warnings import filterwarnings
from pathlib import Path
from functools import partial
from multiprocessing import Pool, cpu_count

import torch
import soundfile as sf
from jsonargparse import ArgumentParser, ActionConfigFile

from data import load_wav, log_mel_spectrogram


def parse_args():
    """Parse command-line arguments."""
    parser = ArgumentParser()
    parser.add_argument("ckpt_path", type=str)
    parser.add_argument("audio_paths", type=str, nargs="+")
    parser.add_argument("-o", "--output_dir", type=str, default=".")

    parser.add_argument("--sample_rate", type=int, default=16000)
    parser.add_argument("--preemph", type=float, default=0.97)
    parser.add_argument("--hop_len", type=int, default=200)
    parser.add_argument("--win_len", type=int, default=800)
    parser.add_argument("--n_fft", type=int, default=2048)
    parser.add_argument("--n_mels", type=int, default=80)
    parser.add_argument("--f_min", type=int, default=50)
    parser.add_argument("--audio_config", action=ActionConfigFile)
    return vars(parser.parse_args())


def main(
    ckpt_path,
    audio_paths,
    output_dir,
    sample_rate,
    preemph,
    hop_len,
    win_len,
    n_fft,
    n_mels,
    f_min,
    **kwargs,
):
    """Main function."""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = torch.jit.load(ckpt_path)
    model.to(device)

    path2wav = partial(load_wav, sample_rate=sample_rate)
    wav2mel = partial(
        log_mel_spectrogram,
        preemph=preemph,
        sample_rate=sample_rate,
        n_mels=n_mels,
        n_fft=n_fft,
        hop_length=hop_len,
        win_length=win_len,
        f_min=f_min,
    )

    with Pool(cpu_count()) as pool:
        wavs = pool.map(path2wav, audio_paths)
        mels = pool.map(wav2mel, wavs)

    print("mels length:", [len(mel) for mel in mels])

    mel_tensors = [torch.FloatTensor(mel).to(device) for mel in mels]

    with torch.no_grad():
        wavs = model.generate(mel_tensors)
        wavs = [wav.detach().cpu().numpy() for wav in wavs]

    for wav, audio_path in zip(wavs, audio_paths):
        wav_path_name = Path(audio_path).name
        wav_path = Path(output_dir, wav_path_name).with_suffix(".rec.wav")
        sf.write(wav_path, wav, sample_rate)


if __name__ == "__main__":
    filterwarnings("ignore")
    main(**parse_args())
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""Train universal vocoder."""

from datetime import datetime
from pathlib import Path
from warnings import filterwarnings

import tqdm
import torch
from torch.nn.functional import cross_entropy
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from jsonargparse import ArgumentParser, ActionConfigFile

from data import VocoderDataset
from models import UniversalVocoder


def parse_args():
    """Parse command-line arguments."""
    parser = ArgumentParser()
    parser.add_argument("data_dir", type=str)
    parser.add_argument("--n_workers", type=int, default=8)
    parser.add_argument("--save_dir", type=str, default=".")
    parser.add_argument("--comment", type=str)

    parser.add_argument("--frames_per_sample", type=int, default=40)
    parser.add_argument("--frames_per_slice", type=int, default=8)
    parser.add_argument("--bits", type=int, default=9)
    parser.add_argument("--conditioning_channels", type=int, default=128)
    parser.add_argument("--embedding_dim", type=int, default=256)
    parser.add_argument("--rnn_channels", type=int, default=896)
    parser.add_argument("--fc_channels", type=int, default=512)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--n_steps", type=int, default=100000)
    parser.add_argument("--valid_every", type=int, default=1000)
    parser.add_argument("--valid_ratio", type=float, default=0.1)
    parser.add_argument("--save_every", type=int, default=10000)
    parser.add_argument("--learning_rate", type=float, default=4e-4)
    parser.add_argument("--decay_every", type=int, default=20000)
    parser.add_argument("--decay_gamma", type=float, default=0.5)
    parser.add_argument("--training_config", action=ActionConfigFile)

    return parser.parse_args()


def main(
    data_dir,
    n_workers,
    save_dir,
    comment,
    frames_per_sample,
    frames_per_slice,
    bits,
    conditioning_channels,
    embedding_dim,
    rnn_channels,
    fc_channels,
    batch_size,
    n_steps,
    valid_every,
    valid_ratio,
    save_every,
    learning_rate,
    decay_every,
    decay_gamma,
    **kwargs,
):
    """Main function."""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = VocoderDataset(
        data_dir,
        Path(data_dir) / "metadata.json",
        frames_per_sample,
        frames_per_slice,
        bits,
    )
    lengths = [
        trainlen := int((1 - valid_ratio) * len(dataset)),
        len(dataset) - trainlen,
    ]
    trainset, validset = random_split(dataset, lengths)
    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=n_workers,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        validset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=n_workers,
        pin_memory=True,
    )

    model = UniversalVocoder(
        sample_rate=dataset.sample_rate,
        frames_per_sample=frames_per_sample,
        frames_per_slice=frames_per_slice,
        mel_dim=dataset.n_mels,
        mel_rnn_dim=conditioning_channels,
        emb_dim=embedding_dim,
        wav_rnn_dim=rnn_channels,
        affine_dim=fc_channels,
        bits=bits,
        hop_length=dataset.hop_len,
    )
    model.to(device)
    model = torch.jit.script(model)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, decay_every, decay_gamma)

    if comment is not None:
        log_dir = "logs/"
        log_dir += datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        log_dir += "_" + comment
        writer = SummaryWriter(log_dir)

    train_iterator = iter(train_loader)
    losses = []
    pbar = tqdm.tqdm(total=valid_every * train_loader.batch_size, ncols=0, desc="Train")

    for step in range(n_steps):
        try:
            mels, wavs = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_loader)
            mels, wavs = next(train_iterator)

        mels = mels.to(device)
        wavs = wavs.to(device)

        outs = model(wavs[:, :-1], mels)

        loss = cross_entropy(outs.transpose(1, 2), wavs[:, 1:])

        losses.append(loss.item())
        pbar.set_postfix(step=step + 1, loss=loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        pbar.update(train_loader.batch_size)

        if (step + 1) % valid_every == 0:
            pbar.close()

            train_loss = sum(losses) / len(losses)
            print(f"[train] loss = {train_loss:.4f}")
            losses = []

            pbar = tqdm.tqdm(
                total=len(valid_loader.dataset), ncols=0, leave=False, desc="Valid"
            )
            for mels, wavs in valid_loader:
                mels = mels.to(device)
                wavs = wavs.to(device)
                with torch.no_grad():
                    outs = model(wavs[:, :-1], mels)
                    loss = cross_entropy(outs.transpose(1, 2), wavs[:, 1:])
                losses.append(loss.item())
                pbar.update(valid_loader.batch_size)
            pbar.close()

            valid_loss = sum(losses) / len(losses)
            print(f"[valid] loss = {valid_loss:.4f}")
            losses = []

            if comment is not None:
                writer.add_scalar("Loss/train", train_loss, step + 1)
                writer.add_scalar("Loss/valid", valid_loss, step + 1)

            pbar = tqdm.tqdm(
                total=valid_every * train_loader.batch_size, ncols=0, desc="Train"
            )

        if (step + 1) % save_every == 0:
            save_dir_path = Path(save_dir)
            save_dir_path.mkdir(parents=True, exist_ok=True)
            checkpoint_path = save_dir_path / f"vocoder-ckpt-{step+1}.pt"
            torch.jit.save(model.cpu(), str(checkpoint_path))
            model.to(device)


if __name__ == "__main__":
    filterwarnings("ignore")
    main(**vars(parse_args()))
--------------------------------------------------------------------------------