├── stylemelgan ├── __init__.py ├── configs │ ├── __init__.py │ └── melgan_config.yaml ├── generator │ ├── __init__.py │ └── melgan.py ├── common.py ├── utils.py ├── inference.py ├── audio.py ├── preprocess.py ├── losses.py ├── discriminator.py ├── dataset.py └── train.py ├── requirements.txt ├── README.md └── .gitignore /stylemelgan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stylemelgan/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stylemelgan/generator/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.7 2 | PyYAML >= 5.1 3 | tensorboard 4 | librosa 5 | tqdm 6 | matplotlib -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | -- Under Construction -- 2 | 3 | This will become an unofficial implementation of StyleMelGAN. Currently implemented is an unofficial version of FB-MelGAN: 4 | https://arxiv.org/pdf/2005.05106.pdf. -------------------------------------------------------------------------------- /stylemelgan/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module, Conv1d, ConvTranspose1d 3 | from torch.nn.utils import weight_norm 4 | 5 | 6 | class WNConv1d(Module): 7 | 8 | def __init__(self, *args, **kwargs,) -> None: 9 | super().__init__() 10 | conv = Conv1d(*args, **kwargs) 11 | self.conv = weight_norm(conv) 12 | 13 | def forward(self, x: torch.Tensor) -> torch.Tensor: 14 | return self.conv(x) 15 | 16 | 17 | class WNConvTranspose1d(Module): 18 | 19 | def __init__(self, *args, **kwargs,) -> None: 20 | super().__init__() 21 | conv = ConvTranspose1d(*args, **kwargs) 22 | self.conv = weight_norm(conv) 23 | 24 | def forward(self, x: torch.Tensor) -> torch.Tensor: 25 | return self.conv(x) -------------------------------------------------------------------------------- /stylemelgan/configs/melgan_config.yaml: -------------------------------------------------------------------------------- 1 | model_type: 'melgan' 2 | model_name: 'bild_neurips_nostft' 3 | 4 | paths: 5 | train_dir: '/Users/cschaefe/datasets/asvoice2_splitted_train' 6 | val_dir: '/Users/cschaefe/datasets/asvoice2_splitted_val' 7 | checkpoints: 'stylemelgan/checkpoints' 8 | 9 | audio: 10 | sample_rate: 22050 11 | n_mels: 80 12 | n_fft: 1024 13 | win_length: 1024 14 | hop_length: 256 15 | fmin: 0 16 | fmax: 8000 17 | 18 | training: 19 | batch_size: 16 20 | pretraining_steps: 200000 21 | segment_len: 16000 22 | g_lr: 0.0001 23 | d_lr: 0.0001 24 | eval_steps: 10000 25 | epochs: 100000 26 | num_workers: 6 27 | 28 | model: 29 | channels: [512, 256, 128, 64, 32] 30 | res_layers: [5, 7, 8, 9] 31 | relu_slope: 0.2 32 | padding_val: -11.5129 -------------------------------------------------------------------------------- /stylemelgan/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | from typing import Dict, Any 4 | 5 | from torch.nn import Module 6 | from torch.nn.utils import remove_weight_norm 7 | 8 | 9 | def get_children(model: Module): 10 | children = list(model.children()) 11 | flat_children = [] 12 | if not children: 13 | return model 14 | else: 15 | for child in children: 16 | try: 17 | flat_children.extend(get_children(child)) 18 | except TypeError: 19 | flat_children.append(get_children(child)) 20 | return flat_children 21 | 22 | 23 | def remove_weight_norm_recursively(model: Module) -> None: 24 | layers = get_children(model) 25 | for l in layers: 26 | try: 27 | remove_weight_norm(l) 28 | except Exception as e: 29 | pass 30 | 31 | 32 | def read_config(path: str) -> Dict[str, Any]: 33 | with open(path, 'r') as stream: 34 | config = yaml.load(stream, Loader=yaml.FullLoader) 35 | return config 36 | 37 | 38 | def save_config(config: Dict[str, Any], path: str) -> None: 39 | with open(path, 'w+', encoding='utf-8') as stream: 40 | yaml.dump(config, stream, default_flow_style=False) -------------------------------------------------------------------------------- /stylemelgan/inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import tqdm 3 | import torch 4 | import argparse 5 | 6 | from stylemelgan.audio import Audio 7 | from stylemelgan.generator.melgan import MelganGenerator 8 | from stylemelgan.utils import remove_weight_norm_recursively 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--path', type=str, required=True, help='directory containing .mel files to vocode') 13 | parser.add_argument('--checkpoint', type=str, required=True, help='model checkpoint') 14 | args = parser.parse_args() 15 | 16 | checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu')) 17 | step = checkpoint['step'] 18 | g_model = MelganGenerator.from_config(checkpoint['config']) 19 | g_model.eval() 20 | remove_weight_norm_recursively(g_model) 21 | audio = Audio.from_config(checkpoint['config']) 22 | print(f'Loaded melgan with step {step}') 23 | 24 | mel_files = list(Path(args.path).glob('**/*.mel')) 25 | 26 | for mel_file in tqdm.tqdm(mel_files, total=len(mel_files)): 27 | mel = torch.load(mel_file) 28 | wav = g_model.inference(mel) 29 | wav = wav.squeeze().cpu().numpy() 30 | save_path = str(mel_file).replace('.mel', f'_voc_step_{step//1000}k.wav') 31 | audio.save_wav(wav, save_path) 32 | -------------------------------------------------------------------------------- /stylemelgan/audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import soundfile as sf 4 | from pathlib import Path 5 | from typing import Dict, Any, Union 6 | 7 | 8 | class Audio: 9 | 10 | def __init__(self, 11 | n_mels: int, 12 | sample_rate: int, 13 | hop_length: int, 14 | win_length: int, 15 | n_fft: int, 16 | fmin: float, 17 | fmax: float) -> None: 18 | 19 | self.n_mels = n_mels 20 | self.sample_rate = sample_rate 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.n_fft = n_fft 24 | self.fmin = fmin 25 | self.fmax = fmax 26 | 27 | @classmethod 28 | def from_config(cls, config: Dict[str, Any]) -> 'Audio': 29 | return Audio(**config['audio']) 30 | 31 | def load_wav(self, path: Union[str, Path]) -> np.array: 32 | wav, _ = librosa.load(path, sr=self.sample_rate) 33 | return wav 34 | 35 | def save_wav(self, wav: np.array, path: Union[str, Path]) -> None: 36 | wav = wav.astype(np.float32) 37 | sf.write(str(path), wav, samplerate=self.sample_rate) 38 | 39 | def wav_to_mel(self, y: np.array, normalize=True) -> np.array: 40 | spec = librosa.stft( 41 | y=y, 42 | n_fft=self.n_fft, 43 | hop_length=self.hop_length, 44 | win_length=self.win_length) 45 | spec = np.abs(spec) 46 | mel = librosa.feature.melspectrogram( 47 | S=spec, 48 | sr=self.sample_rate, 49 | n_fft=self.n_fft, 50 | n_mels=self.n_mels, 51 | fmin=self.fmin, 52 | fmax=self.fmax) 53 | if normalize: 54 | mel = self.normalize(mel) 55 | return mel 56 | 57 | def normalize(self, mel: np.array) -> np.array: 58 | mel = np.clip(mel, a_min=1.e-5, a_max=None) 59 | return np.log(mel) 60 | 61 | def denormalize(self, mel: np.array) -> np.array: 62 | return np.exp(mel) -------------------------------------------------------------------------------- /stylemelgan/preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from multiprocessing import cpu_count 3 | from multiprocessing.pool import Pool 4 | from pathlib import Path 5 | 6 | import torch 7 | import tqdm 8 | 9 | from stylemelgan.audio import Audio 10 | from stylemelgan.utils import read_config 11 | 12 | 13 | def valid_n_workers(num): 14 | n = int(num) 15 | if n < 1: 16 | raise argparse.ArgumentTypeError('%r must be an integer greater than 0' % num) 17 | return n 18 | 19 | 20 | class Preprocessor: 21 | 22 | def __init__(self, audio: Audio) -> None: 23 | self.audio = audio 24 | 25 | def __call__(self, file: Path) -> None: 26 | wav = self.audio.load_wav(file) 27 | mel = self.audio.wav_to_mel(wav) 28 | mel = torch.from_numpy(mel).unsqueeze(0).float() 29 | save_path = str(file).replace('.wav', '.mel') 30 | torch.save(mel, save_path) 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser(description='Arguments for mel preprocessing.') 35 | parser.add_argument('--num_workers', '-w', metavar='N', type=valid_n_workers, default=cpu_count() - 1, 36 | help='The number of worker threads to use for preprocessing.') 37 | parser.add_argument('--config', metavar='FILE', default='config.yaml', 38 | help='The config containing all hyperparams.') 39 | args = parser.parse_args() 40 | 41 | config = read_config(args.config) 42 | audio = Audio.from_config(config) 43 | preprocessor = Preprocessor(audio) 44 | train_data_path = Path(config['paths']['train_dir']) 45 | val_data_path = Path(config['paths']['val_dir']) 46 | train_files = list(train_data_path.glob('**/*.wav')) 47 | val_files = list(val_data_path.glob('**/*.wav')) 48 | all_files = train_files + val_files 49 | n_workers = max(1, args.num_workers) 50 | 51 | pool = Pool(processes=n_workers) 52 | 53 | for _ in tqdm.tqdm(pool.imap_unordered(preprocessor, all_files), total=len(all_files)): 54 | pass 55 | 56 | print('Preprocessing done.') 57 | 58 | -------------------------------------------------------------------------------- /stylemelgan/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Tomoki Hayashi 2 | # MIT License (https://opensource.org/licenses/MIT) 3 | # adapted from https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/losses/stft_loss.py 4 | from typing import Tuple 5 | import torch.nn.functional as F 6 | import torch 7 | from torch.nn import Module 8 | from distutils.version import LooseVersion 9 | 10 | is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7") 11 | 12 | 13 | def stft(x: torch.Tensor, 14 | n_fft: int, 15 | hop_length: int, 16 | win_length: int) -> torch.Tensor: 17 | window = torch.hann_window(win_length, device=x.device) 18 | if is_pytorch_17plus: 19 | x_stft = torch.stft( 20 | input=x, n_fft=n_fft, hop_length=hop_length, 21 | win_length=win_length, window=window, return_complex=False) 22 | else: 23 | x_stft = torch.stft( 24 | input=x, n_fft=n_fft, hop_length=hop_length, 25 | win_length=win_length, window=window) 26 | 27 | real = x_stft[..., 0] 28 | imag = x_stft[..., 1] 29 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 30 | 31 | 32 | class MultiResStftLoss(Module): 33 | 34 | def __init__(self) -> None: 35 | super().__init__() 36 | self.n_ffts = [1024, 2048, 512] 37 | self.hop_sizes = [120, 240, 50] 38 | self.win_lengths = [600, 1200, 240] 39 | 40 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 41 | norm_loss = 0. 42 | spec_loss = 0. 43 | for n_fft, hop_length, win_length in zip(self.n_ffts, self.hop_sizes, self.win_lengths): 44 | x_stft = stft(x=x, n_fft=n_fft, hop_length=hop_length, win_length=win_length) 45 | y_stft = stft(x=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) 46 | norm_loss += F.l1_loss(torch.log(x_stft), torch.log(y_stft)) 47 | spec_loss += torch.norm(y_stft - x_stft, p="fro") / torch.norm(y_stft, p="fro") 48 | return norm_loss / len(self.n_ffts), spec_loss / len(self.n_ffts) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ -------------------------------------------------------------------------------- /stylemelgan/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Sequential, LeakyReLU 4 | 5 | from stylemelgan.common import WNConv1d 6 | 7 | 8 | class Identity(nn.Module): 9 | def __init__(self): 10 | super(Identity, self).__init__() 11 | 12 | def forward(self, x): 13 | return x 14 | 15 | 16 | class Discriminator(nn.Module): 17 | 18 | def __init__(self, relu_slope: float = 0.2): 19 | super(Discriminator, self).__init__() 20 | 21 | self.discriminator = nn.ModuleList([ 22 | Sequential( 23 | WNConv1d(1, 16, kernel_size=15, stride=1, padding=7, padding_mode='reflect'), 24 | LeakyReLU(relu_slope, inplace=True) 25 | ), 26 | Sequential( 27 | WNConv1d(16, 64, kernel_size=41, stride=4, padding=20, groups=4), 28 | LeakyReLU(relu_slope, inplace=True) 29 | ), 30 | Sequential( 31 | WNConv1d(64, 256, kernel_size=41, stride=4, padding=20, groups=16), 32 | LeakyReLU(relu_slope, inplace=True) 33 | ), 34 | Sequential( 35 | WNConv1d(256, 1024, kernel_size=41, stride=4, padding=20, groups=64), 36 | LeakyReLU(relu_slope, inplace=True) 37 | ), 38 | Sequential( 39 | WNConv1d(1024, 1024, kernel_size=41, stride=4, padding=20, groups=256), 40 | LeakyReLU(relu_slope, inplace=True) 41 | ), 42 | Sequential( 43 | WNConv1d(1024, 1024, kernel_size=5, stride=1, padding=2), 44 | LeakyReLU(relu_slope, inplace=True) 45 | ), 46 | WNConv1d(1024, 1, kernel_size=3, stride=1, padding=1) 47 | ]) 48 | 49 | def forward(self, x): 50 | features = [] 51 | for module in self.discriminator: 52 | x = module(x) 53 | features.append(x) 54 | return features[:-1], features[-1] 55 | 56 | 57 | class MultiScaleDiscriminator(nn.Module): 58 | 59 | def __init__(self): 60 | super(MultiScaleDiscriminator, self).__init__() 61 | 62 | self.discriminators = nn.ModuleList( 63 | [Discriminator() for _ in range(3)] 64 | ) 65 | 66 | self.pooling = nn.ModuleList( 67 | [Identity()] + 68 | [nn.AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False) for _ in range(1, 3)] 69 | ) 70 | 71 | def forward(self, x): 72 | ret = list() 73 | 74 | for pool, disc in zip(self.pooling, self.discriminators): 75 | x = pool(x) 76 | ret.append(disc(x)) 77 | 78 | return ret # [(feat, score), (feat, score), (feat, score)] 79 | 80 | 81 | 82 | if __name__ == '__main__': 83 | model = Discriminator() 84 | 85 | x = torch.randn(3, 1, 22050) 86 | print(x.shape) 87 | 88 | features, score = model(x) 89 | for feat in features: 90 | print(feat.shape) 91 | print(score.shape) 92 | 93 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 94 | print(pytorch_total_params) -------------------------------------------------------------------------------- /stylemelgan/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Dict, Union 4 | 5 | import librosa 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | 10 | class AudioDataset(Dataset): 11 | 12 | def __init__(self, 13 | data_path: Path, 14 | hop_len: int, 15 | segment_len: Union[int, None], 16 | sample_rate: int, 17 | padding_val: float = -11.5129) -> None: 18 | mel_names = list(data_path.glob('**/*.mel')) 19 | self.data_path = data_path 20 | self.hop_len = hop_len 21 | self.segment_len = segment_len 22 | self.padding_val = padding_val 23 | self.sample_rate = sample_rate 24 | self.file_ids = [n.stem for n in mel_names] 25 | if segment_len is not None: 26 | self.mel_segment_len = segment_len // hop_len + 2 27 | 28 | def __len__(self): 29 | return len(self.file_ids) 30 | 31 | def __getitem__(self, item_id: int) -> Dict[str, torch.Tensor]: 32 | file_id = self.file_ids[item_id] 33 | mel_path = self.data_path / f'{file_id}.mel' 34 | wav_path = self.data_path / f'{file_id}.wav' 35 | wav, _ = librosa.load(wav_path, sr=self.sample_rate) 36 | wav = torch.tensor(wav).float() 37 | mel = torch.load(mel_path).squeeze(0) 38 | if self.segment_len is not None: 39 | mel_pad_len = 2 * self.mel_segment_len - mel.size(-1) 40 | if mel_pad_len > 0: 41 | mel_pad = torch.full((mel.size(0), mel_pad_len), fill_value=self.padding_val) 42 | mel = torch.cat([mel, mel_pad], dim=-1) 43 | wav_pad_len = mel.size(-1) * self.hop_len - wav.size(0) 44 | if wav_pad_len > 0: 45 | wav_pad = torch.zeros((wav_pad_len, )) 46 | wav = torch.cat([wav, wav_pad], dim=0) 47 | max_mel_start = mel.size(-1) - self.mel_segment_len 48 | mel_start = random.randint(0, max_mel_start) 49 | mel_end = mel_start + self.mel_segment_len 50 | mel = mel[:, mel_start:mel_end] 51 | wav_start = mel_start * self.hop_len 52 | wav_end = wav_start + self.segment_len 53 | wav = wav[wav_start:wav_end] 54 | wav = wav + (1 / 32768) * torch.randn_like(wav) 55 | wav = wav.unsqueeze(0) 56 | return {'mel': mel, 'wav': wav} 57 | 58 | 59 | def new_dataloader(data_path: Path, 60 | segment_len: int, 61 | hop_len: int, 62 | batch_size: int, 63 | sample_rate: int, 64 | num_workers: int = 0) -> DataLoader: 65 | 66 | dataset = AudioDataset(data_path=data_path, segment_len=segment_len, hop_len=hop_len, sample_rate=sample_rate) 67 | dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, 68 | num_workers=num_workers, pin_memory=True, drop_last=True) 69 | return dataloader 70 | 71 | 72 | if __name__ == '__main__': 73 | data_path = Path('/Users/cschaefe/datasets/asvoice2_splitted_train') 74 | dataloader = new_dataloader(data_path=data_path, segment_len=16000, hop_len=256, batch_size=2) 75 | for item in dataloader: 76 | print(item['mel'].size()) 77 | print(item['wav'].size()) 78 | -------------------------------------------------------------------------------- /stylemelgan/generator/melgan.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Any 2 | 3 | import torch 4 | from torch.nn import Module, ModuleList, Sequential, LeakyReLU, Tanh 5 | 6 | from stylemelgan.common import WNConv1d, WNConvTranspose1d 7 | from stylemelgan.utils import read_config 8 | 9 | 10 | class ResBlock(Module): 11 | 12 | def __init__(self, 13 | in_channels: int, 14 | out_channels: int, 15 | dilation: int, 16 | relu_slope: float = 0.2): 17 | super().__init__() 18 | self.conv_block = Sequential( 19 | LeakyReLU(relu_slope), 20 | WNConv1d(in_channels=in_channels, out_channels=out_channels, 21 | kernel_size=3, dilation=dilation, padding=dilation, 22 | padding_mode='reflect'), 23 | LeakyReLU(relu_slope), 24 | WNConv1d(in_channels=out_channels, out_channels=out_channels, 25 | kernel_size=1) 26 | ) 27 | self.residual = WNConv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | return self.residual(x) + self.conv_block(x) 31 | 32 | 33 | class ResStack(Module): 34 | 35 | def __init__(self, 36 | in_channels: int, 37 | out_channels: int, 38 | num_layers: int, 39 | relu_slope: float = 0.2): 40 | super().__init__() 41 | 42 | self.res_blocks = ModuleList([ 43 | ResBlock(in_channels=in_channels if i == 0 else out_channels, 44 | out_channels=out_channels, 45 | dilation=3 ** i, 46 | relu_slope=relu_slope) 47 | for i in range(num_layers) 48 | ]) 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | for res_block in self.res_blocks: 52 | x = res_block(x) 53 | return x 54 | 55 | 56 | class MelganGenerator(Module): 57 | 58 | def __init__(self, 59 | mel_channels: int, 60 | channels: Tuple = (512, 256, 128, 64, 32), 61 | res_layers: Tuple = (5, 7, 8, 9), 62 | relu_slope: float = 0.2, 63 | padding_val: float = -11.5129) -> None: 64 | super().__init__() 65 | 66 | self.padding_val = padding_val 67 | self.mel_channels = mel_channels 68 | self.hop_length = 256 69 | c_0, c_1, c_2, c_3, c_4 = channels 70 | r_0, r_1, r_2, r_3 = res_layers 71 | 72 | self.blocks = Sequential( 73 | WNConv1d(mel_channels, c_0, kernel_size=7, stride=1, padding=3, padding_mode='reflect'), 74 | LeakyReLU(relu_slope), 75 | WNConvTranspose1d(c_0, c_1, kernel_size=16, stride=8, padding=4), 76 | ResStack(c_1, c_1, num_layers=r_0), 77 | LeakyReLU(relu_slope), 78 | WNConvTranspose1d(c_1, c_2, kernel_size=16, stride=8, padding=4), 79 | ResStack(c_2, c_2, num_layers=r_1), 80 | LeakyReLU(relu_slope), 81 | WNConvTranspose1d(c_2, c_3, kernel_size=4, stride=2, padding=1), 82 | ResStack(c_3, c_3, num_layers=r_2), 83 | LeakyReLU(relu_slope), 84 | WNConvTranspose1d(c_3, c_4, kernel_size=4, stride=2, padding=1), 85 | ResStack(c_4, c_4, num_layers=r_3), 86 | LeakyReLU(relu_slope), 87 | WNConv1d(c_4, 1, kernel_size=7, padding=3, padding_mode='reflect'), 88 | Tanh() 89 | ) 90 | 91 | def forward(self, mel: torch.Tensor) -> torch.Tensor: 92 | mel = (mel.detach() + 5.0) / 5.0 93 | return self.blocks(mel) 94 | 95 | def inference(self, 96 | mel: torch.Tensor, 97 | pad_steps: int = 10) -> torch.Tensor: 98 | with torch.no_grad(): 99 | pad = torch.full((1, self.mel_channels, pad_steps), 100 | self.padding_val).to(mel.device) 101 | mel = torch.cat((mel, pad), dim=2) 102 | audio = self.forward(mel).squeeze() 103 | audio = audio[:-(self.hop_length * pad_steps)] 104 | return audio 105 | 106 | @classmethod 107 | def from_config(cls, config: Dict[str, Any]) -> 'MelganGenerator': 108 | return MelganGenerator(mel_channels=config['audio']['n_mels'], 109 | **config['model']) 110 | 111 | @classmethod 112 | def from_checkpoint(cls, file: str) -> 'MelganGenerator': 113 | checkpoint = torch.load(file, map_location=torch.device('cpu')) 114 | config = checkpoint['config'] 115 | model = MelganGenerator.from_config(config) 116 | model.load_state_dict(config['g_model']) 117 | return model 118 | 119 | 120 | if __name__ == '__main__': 121 | 122 | config = read_config('../stylemelgan/configs/melgan_config.yaml') 123 | model = MelganGenerator.from_config(config) 124 | x = torch.randn(3, 80, 1000) 125 | print(x.shape) 126 | 127 | y = model(x) 128 | print(y.shape) 129 | #assert y.shape == torch.Size([3, 1, 2560]) 130 | 131 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 132 | print(pytorch_total_params) 133 | -------------------------------------------------------------------------------- /stylemelgan/train.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from pathlib import Path 3 | import argparse 4 | import matplotlib as mpl 5 | import torch 6 | import torch.nn.functional as F 7 | import tqdm 8 | from matplotlib.figure import Figure 9 | from torch.cuda import is_available 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | from stylemelgan.audio import Audio 13 | from stylemelgan.dataset import new_dataloader, AudioDataset 14 | from stylemelgan.discriminator import MultiScaleDiscriminator 15 | from stylemelgan.generator.melgan import MelganGenerator 16 | from stylemelgan.losses import stft, MultiResStftLoss 17 | from stylemelgan.utils import read_config 18 | 19 | mpl.use('agg') # Use non-interactive backend by default 20 | import numpy as np 21 | import matplotlib.pyplot as plt 22 | 23 | 24 | def plot_mel(mel: np.array) -> Figure: 25 | mel = np.flip(mel, axis=0) 26 | fig = plt.figure(figsize=(12, 6), dpi=150) 27 | plt.imshow(mel, interpolation='nearest', aspect='auto') 28 | return fig 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--config', type=str, 34 | default='stylemelgan/configs/melgan_config.yaml', help='points to config.yaml') 35 | args = parser.parse_args() 36 | 37 | config = read_config(args.config) 38 | model_name = config['model_name'] 39 | audio = Audio.from_config(config) 40 | train_data_path = Path(config['paths']['train_dir']) 41 | val_data_path = Path(config['paths']['val_dir']) 42 | 43 | device = torch.device('cuda') if is_available() else torch.device('cpu') 44 | torch.backends.cudnn.benchmark = True 45 | 46 | step = 0 47 | 48 | g_model = MelganGenerator(audio.n_mels).to(device) 49 | d_model = MultiScaleDiscriminator().to(device) 50 | train_cfg = config['training'] 51 | g_optim = torch.optim.Adam(g_model.parameters(), lr=train_cfg['g_lr'], betas=(0.5, 0.9)) 52 | d_optim = torch.optim.Adam(d_model.parameters(), lr=train_cfg['d_lr'], betas=(0.5, 0.9)) 53 | 54 | multires_stft_loss = MultiResStftLoss().to(device) 55 | 56 | try: 57 | checkpoint = torch.load(f'checkpoints/latest_model__{model_name}.pt', map_location=device) 58 | g_model.load_state_dict(checkpoint['g_model']) 59 | g_optim.load_state_dict(checkpoint['g_optim']) 60 | d_model.load_state_dict(checkpoint['d_model']) 61 | d_optim.load_state_dict(checkpoint['d_optim']) 62 | step = checkpoint['step'] 63 | print(f'Loaded model with step {step}') 64 | except Exception as e: 65 | 'Initializing model from scratch.' 66 | 67 | train_cfg = config['training'] 68 | dataloader = new_dataloader(data_path=train_data_path, segment_len=train_cfg['segment_len'], 69 | hop_len=audio.hop_length, batch_size=train_cfg['batch_size'], 70 | num_workers=train_cfg['num_workers'], sample_rate=audio.sample_rate) 71 | val_dataset = AudioDataset(data_path=val_data_path, segment_len=None, hop_len=audio.hop_length, 72 | sample_rate=audio.sample_rate) 73 | 74 | stft = partial(stft, n_fft=1024, hop_length=256, win_length=1024) 75 | 76 | pretraining_steps = train_cfg['pretraining_steps'] 77 | 78 | summary_writer = SummaryWriter(log_dir=f'checkpoints/logs_{model_name}') 79 | 80 | best_stft = 9999 81 | 82 | for epoch in range(train_cfg['epochs']): 83 | pbar = tqdm.tqdm(enumerate(dataloader, 1), total=len(dataloader)) 84 | for i, data in pbar: 85 | step += 1 86 | mel = data['mel'].to(device) 87 | wav_real = data['wav'].to(device) 88 | 89 | wav_fake = g_model(mel)[:, :, :16000] 90 | 91 | d_loss = 0.0 92 | g_loss = 0.0 93 | stft_norm_loss = 0.0 94 | stft_spec_loss = 0.0 95 | 96 | if step > pretraining_steps: 97 | # discriminator 98 | d_fake = d_model(wav_fake.detach()) 99 | d_real = d_model(wav_real) 100 | for (_, score_fake), (_, score_real) in zip(d_fake, d_real): 101 | d_loss += F.relu(1.0 - score_real).mean() 102 | d_loss += F.relu(1.0 + score_fake).mean() 103 | d_optim.zero_grad() 104 | d_loss.backward() 105 | d_optim.step() 106 | 107 | # generator 108 | d_fake = d_model(wav_fake) 109 | for (feat_fake, score_fake), (feat_real, _) in zip(d_fake, d_real): 110 | g_loss += -score_fake.mean() 111 | for feat_fake_i, feat_real_i in zip(feat_fake, feat_real): 112 | g_loss += 10. * F.l1_loss(feat_fake_i, feat_real_i.detach()) 113 | 114 | factor = 1. if step < pretraining_steps else 0. 115 | stft_norm_loss, stft_spec_loss = multires_stft_loss(wav_fake.squeeze(1), wav_real.squeeze(1)) 116 | g_loss_all = g_loss + factor * (stft_norm_loss + stft_spec_loss) 117 | 118 | g_optim.zero_grad() 119 | g_loss_all.backward() 120 | g_optim.step() 121 | 122 | pbar.set_description(desc=f'Epoch: {epoch} | Step {step} ' 123 | f'| g_loss: {g_loss:#.4} ' 124 | f'| d_loss: {d_loss:#.4} ' 125 | f'| stft_norm_loss {stft_norm_loss:#.4} ' 126 | f'| stft_spec_loss {stft_spec_loss:#.4} ', refresh=True) 127 | 128 | summary_writer.add_scalar('generator_loss', g_loss, global_step=step) 129 | summary_writer.add_scalar('stft_norm_loss', stft_norm_loss, global_step=step) 130 | summary_writer.add_scalar('stft_spec_loss', stft_spec_loss, global_step=step) 131 | summary_writer.add_scalar('discriminator_loss', d_loss, global_step=step) 132 | 133 | if step % train_cfg['eval_steps'] == 0: 134 | g_model.eval() 135 | val_norm_loss = 0 136 | val_spec_loss = 0 137 | val_wavs = [] 138 | 139 | for i, val_data in enumerate(val_dataset): 140 | val_mel = val_data['mel'].to(device) 141 | val_mel = val_mel.unsqueeze(0) 142 | wav_fake = g_model.inference(val_mel, pad_steps=80).squeeze().cpu().numpy() 143 | wav_real = val_data['wav'].detach().squeeze().cpu().numpy() 144 | wav_f = torch.tensor(wav_fake).unsqueeze(0).to(device) 145 | wav_r = torch.tensor(wav_real).unsqueeze(0).to(device) 146 | val_wavs.append((wav_fake, wav_real)) 147 | size = min(wav_r.size(-1), wav_f.size(-1)) 148 | val_n, val_s = multires_stft_loss(wav_f[..., :size], wav_r[..., :size]) 149 | val_norm_loss += val_n 150 | val_spec_loss += val_s 151 | 152 | val_norm_loss /= len(val_dataset) 153 | val_spec_loss /= len(val_dataset) 154 | summary_writer.add_scalar('val_stft_norm_loss', val_norm_loss, global_step=step) 155 | summary_writer.add_scalar('val_stft_spec_loss', val_spec_loss, global_step=step) 156 | val_wavs.sort(key=lambda x: x[1].shape[0]) 157 | wav_fake, wav_real = val_wavs[-1] 158 | if val_norm_loss + val_spec_loss < best_stft: 159 | best_stft = val_norm_loss + val_spec_loss 160 | print(f'\nnew best stft: {best_stft}') 161 | torch.save({ 162 | 'g_model': g_model.state_dict(), 163 | 'g_optim': g_optim.state_dict(), 164 | 'd_model': d_model.state_dict(), 165 | 'd_optim': d_optim.state_dict(), 166 | 'config': config, 167 | 'step': step 168 | }, f'checkpoints/best_model_{model_name}.pt') 169 | summary_writer.add_audio('best_generated', wav_fake, sample_rate=audio.sample_rate, global_step=step) 170 | 171 | g_model.train() 172 | summary_writer.add_audio('generated', wav_fake, sample_rate=audio.sample_rate, global_step=step) 173 | summary_writer.add_audio('target', wav_real, sample_rate=audio.sample_rate, global_step=step) 174 | mel_fake = audio.wav_to_mel(wav_fake) 175 | mel_real = audio.wav_to_mel(wav_real) 176 | mel_fake_plot = plot_mel(mel_fake) 177 | mel_real_plot = plot_mel(mel_real) 178 | summary_writer.add_figure('mel_generated', mel_fake_plot, global_step=step) 179 | summary_writer.add_figure('mel_target', mel_real_plot, global_step=step) 180 | 181 | # epoch end 182 | torch.save({ 183 | 'g_model': g_model.state_dict(), 184 | 'g_optim': g_optim.state_dict(), 185 | 'd_model': d_model.state_dict(), 186 | 'd_optim': d_optim.state_dict(), 187 | 'config': config, 188 | 'step': step 189 | }, f'checkpoints/latest_model__{model_name}.pt') --------------------------------------------------------------------------------