├── .gitignore
├── figure
│   ├── compare.png
│   ├── overall.png
│   └── compare_table.png
├── requirements.txt
├── env.py
├── LICENSE
├── config.json
├── config2.json
├── config2_pghi.json
├── utils.py
├── config_pghi.json
├── README.md
├── inference.py
├── dataset.py
├── train.py
├── train2.py
├── train2_pghi.py
├── train_pghi.py
├── models.py
├── models2.py
├── models2_pghi.py
└── models_pghi.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | cp_*
2 | *.pyc
3 | \.vscode
4 | test.ipynb

--------------------------------------------------------------------------------
/figure/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BakerBunker/FreeV/HEAD/figure/compare.png

--------------------------------------------------------------------------------
/figure/overall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BakerBunker/FreeV/HEAD/figure/overall.png

--------------------------------------------------------------------------------
/figure/compare_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BakerBunker/FreeV/HEAD/figure/compare_table.png

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.1+cu111
2 | numpy==1.21.6
3 | librosa==0.9.1
4 | tensorboard==2.8.0
5 | soundfile==0.10.3
6 | matplotlib==3.1.3
7 | # for pghi
8 | pghipy

--------------------------------------------------------------------------------
/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | 
4 | 
5 | class AttrDict(dict):
6 |     def __init__(self, *args, **kwargs):
7 |         super(AttrDict, self).__init__(*args, **kwargs)
8 |         self.__dict__ = self
9 | 
10 | 
11 | def build_env(config, config_name, path):
12 |     t_path = os.path.join(path, config_name)
13 |     if config != t_path:
14 |         os.makedirs(path, exist_ok=True)
15 |         shutil.copyfile(config, os.path.join(path, config_name))
16 | 

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 redmist
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_training_wav_list": "/path/to/train_set", 3 | "input_validation_wav_list": "/path/to/test_set", 4 | "test_input_wavs_dir":"../../datasets/LJ_22050/LJ_test", 5 | "test_input_mels_dir":"./", 6 | "test_mel_load": 0, 7 | "test_output_dir": "/path/to/output", 8 | 9 | "batch_size": 16, 10 | "learning_rate": 0.0002, 11 | "adam_b1": 0.8, 12 | "adam_b2": 0.99, 13 | "lr_decay": 0.999, 14 | "seed": 1234, 15 | "training_epochs": 3100, 16 | "stdout_interval":20, 17 | "checkpoint_interval": 5000, 18 | "summary_interval": 100, 19 | "validation_interval": 1000, 20 | "checkpoint_path": "cp_APNet_22k", 21 | "checkpoint_file_load": "cp_APNet_22k/g_01000000", 22 | 23 | "ASP_channel": 512, 24 | "ASP_resblock_kernel_sizes": [3,7,11], 25 | "ASP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 26 | "ASP_input_conv_kernel_size": 7, 27 | "ASP_output_conv_kernel_size": 7, 28 | 29 | "PSP_channel": 512, 30 | "PSP_resblock_kernel_sizes": [3,7,11], 31 | "PSP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 32 | "PSP_input_conv_kernel_size": 7, 33 | "PSP_output_R_conv_kernel_size": 7, 34 | "PSP_output_I_conv_kernel_size": 7, 35 | 36 | "segment_size": 8192, 37 | "num_mels": 80, 38 | "n_fft": 1024, 39 | "hop_size": 256, 40 | "win_size": 1024, 41 | 42 | "sampling_rate": 22050, 43 | 44 | "fmin": 0, 45 | "fmax": 8000, 46 | "meloss":null, 47 | "num_workers": 4 48 | } 49 | -------------------------------------------------------------------------------- /config2.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_training_wav_list": "/path/to/train_set", 3 | "input_validation_wav_list": "/path/to/test_set", 4 | "test_input_wavs_dir":"../../datasets/LJ_22050/LJ_test", 5 | "test_input_mels_dir":"./", 6 | "test_mel_load": 0, 7 | "test_output_dir": "/path/to/output", 8 | 9 | "batch_size": 16, 10 | "learning_rate": 0.0002, 11 | "adam_b1": 0.8, 12 | "adam_b2": 0.99, 13 | "lr_decay": 0.999, 14 | "seed": 1234, 15 | "training_epochs": 3100, 16 | "stdout_interval":20, 17 | "checkpoint_interval": 5000, 18 | "summary_interval": 100, 19 | "validation_interval": 1000, 20 | "checkpoint_path": "cp_APNet_revabs_22k", 21 | "checkpoint_file_load": "cp_APNet_revabs_22k/g_01000000", 22 | 23 | "ASP_channel": 513, 24 | "ASP_resblock_kernel_sizes": [3,7,11], 25 | "ASP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 26 | "ASP_input_conv_kernel_size": 7, 27 | "ASP_output_conv_kernel_size": 7, 28 | 29 | "PSP_channel": 512, 30 | "PSP_resblock_kernel_sizes": [3,7,11], 31 | "PSP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 32 | "PSP_input_conv_kernel_size": 7, 33 | "PSP_output_R_conv_kernel_size": 7, 34 | "PSP_output_I_conv_kernel_size": 7, 35 | 36 | "segment_size": 8192, 37 | "num_mels": 80, 38 | "n_fft": 1024, 39 | "hop_size": 256, 40 | "win_size": 1024, 41 | 42 | "sampling_rate": 22050, 43 | 44 | "fmin": 0, 45 | "fmax": 8000, 46 | "meloss":null, 47 | "num_workers": 4 48 | } 49 | -------------------------------------------------------------------------------- /config2_pghi.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "input_training_wav_list": "/path/to/train_set", 3 | "input_validation_wav_list": "/path/to/test_set", 4 | "test_input_wavs_dir":"../../datasets/LJ_22050/LJ_test", 5 | "test_input_mels_dir":"./", 6 | "test_mel_load": 0, 7 | "test_output_dir": "/path/to/output", 8 | 9 | "batch_size": 16, 10 | "learning_rate": 0.0002, 11 | "adam_b1": 0.8, 12 | "adam_b2": 0.99, 13 | "lr_decay": 0.999, 14 | "seed": 1234, 15 | "training_epochs": 3100, 16 | "stdout_interval":20, 17 | "checkpoint_interval": 5000, 18 | "summary_interval": 100, 19 | "validation_interval": 1000, 20 | "checkpoint_path": "cp_APNet_revabs_22k_pghi", 21 | "checkpoint_file_load": "cp_APNet_revabs_22k/g_01000000", 22 | 23 | "ASP_channel": 513, 24 | "ASP_resblock_kernel_sizes": [3,7,11], 25 | "ASP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 26 | "ASP_input_conv_kernel_size": 7, 27 | "ASP_output_conv_kernel_size": 7, 28 | 29 | "PSP_channel": 512, 30 | "PSP_resblock_kernel_sizes": [3,7,11], 31 | "PSP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 32 | "PSP_input_conv_kernel_size": 7, 33 | "PSP_output_R_conv_kernel_size": 7, 34 | "PSP_output_I_conv_kernel_size": 7, 35 | 36 | "segment_size": 8192, 37 | "num_mels": 80, 38 | "n_fft": 1024, 39 | "hop_size": 256, 40 | "win_size": 1024, 41 | 42 | "sampling_rate": 22050, 43 | 44 | "fmin": 0, 45 | "fmax": 8000, 46 | "meloss":null, 47 | "num_workers": 4 48 | } 49 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | import shutil 9 | 10 | class AttrDict(dict): 11 | def __init__(self, *args, **kwargs): 12 | super(AttrDict, self).__init__(*args, **kwargs) 13 | self.__dict__ = self 14 | 15 | 16 | def build_env(config, config_name, path): 17 | t_path = os.path.join(path, config_name) 18 | if config != t_path: 19 | os.makedirs(path, exist_ok=True) 20 | shutil.copyfile(config, os.path.join(path, config_name)) 21 | 22 | def plot_spectrogram(spectrogram): 23 | fig, ax = plt.subplots(figsize=(10, 2)) 24 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 25 | interpolation='none') 26 | plt.colorbar(im, ax=ax) 27 | 28 | fig.canvas.draw() 29 | plt.close() 30 | 31 | return fig 32 | 33 | 34 | def init_weights(m, mean=0.0, std=0.01): 35 | classname = m.__class__.__name__ 36 | if classname.find("Conv") != -1: 37 | m.weight.data.normal_(mean, std) 38 | 39 | 40 | def apply_weight_norm(m): 41 | classname = m.__class__.__name__ 42 | if classname.find("Conv") != -1: 43 | weight_norm(m) 44 | 45 | 46 | def get_padding(kernel_size, dilation=1): 47 | return int((kernel_size*dilation - dilation)/2) 48 | 49 | 50 | def load_checkpoint(filepath, device): 51 | assert os.path.isfile(filepath) 52 | print("Loading '{}'".format(filepath)) 53 | checkpoint_dict = torch.load(filepath, map_location=device) 54 | print("Complete.") 55 | return checkpoint_dict 56 | 57 | 58 | def save_checkpoint(filepath, obj): 59 | print("Saving checkpoint to {}".format(filepath)) 60 | torch.save(obj, filepath) 61 | print("Complete.") 62 | 63 | 64 | def scan_checkpoint(cp_dir, prefix): 65 | pattern = os.path.join(cp_dir, prefix + '????????') 66 | cp_list = glob.glob(pattern) 67 | if len(cp_list) == 0: 68 | return None 69 | return 
sorted(cp_list)[-1]
70 | 
71 | 

--------------------------------------------------------------------------------
/config_pghi.json:
--------------------------------------------------------------------------------
1 | {
2 | "input_training_wav_list": "/path/to/train_set",
3 | "input_validation_wav_list": "/path/to/test_set",
4 | "test_input_wavs_dir": "../../datasets/LJ_22050/LJ_test",
5 | "test_input_mels_dir": "./",
6 | "test_mel_load": 0,
7 | "test_output_dir": "/path/to/output",
8 | 
9 | "batch_size": 64,
10 | "learning_rate": 0.0004,
11 | "adam_b1": 0.8,
12 | "adam_b2": 0.99,
13 | "lr_decay": 0.999,
14 | "seed": 1234,
15 | "training_epochs": 3100,
16 | "stdout_interval": 20,
17 | "checkpoint_interval": 5000,
18 | "summary_interval": 100,
19 | "validation_interval": 1000,
20 | "checkpoint_path": "cp_APNet_revabs_pghi",
21 | "checkpoint_file_load": "cp_APNet_revabs_pghi/g_01000000",
22 | 
23 | "ASP_channel": 513,
24 | "ASP_resblock_kernel_sizes": [3,7,11],
25 | "ASP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
26 | "ASP_input_conv_kernel_size": 7,
27 | "ASP_output_conv_kernel_size": 7,
28 | 
29 | "PSP_channel": 513,
30 | "PSP_resblock_kernel_sizes": [3,7,11],
31 | "PSP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
32 | "PSP_input_conv_kernel_size": 7,
33 | "PSP_output_R_conv_kernel_size": 7,
34 | "PSP_output_I_conv_kernel_size": 7,
35 | 
36 | "segment_size": 8192,
37 | "num_mels": 80,
38 | "n_fft": 1024,
39 | "hop_size": 256,
40 | "win_size": 1024,
41 | 
42 | "sampling_rate": 22050,
43 | 
44 | "fmin": 0,
45 | "fmax": 8000,
46 | "meloss": null,
47 | "num_workers": 4
48 | }
49 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FreeV: Free Lunch For Vocoders Through Pseudo Inversed Mel Filter
2 | 
3 | The core idea in one line of code (a runnable sketch appears after the Training section below):
4 | ```python
5 | model_input = (mel_spec @ mel_filter.pinverse()).abs().clamp_min(1e-5)
6 | ```
7 | 
8 | Official repository of the paper: [FreeV: Free Lunch For Vocoders Through Pseudo Inversed Mel Filter](https://arxiv.org/abs/2406.08196)
9 | 
10 | **Audio samples** at: [https://bakerbunker.github.io/FreeV/](https://bakerbunker.github.io/FreeV/)
11 | 
12 | **Model checkpoints** and **tensorboard training logs** are available at: [huggingface](https://huggingface.co/Bakerbunker/FreeV_Model_Logs)
13 | 
14 | ## Requirements
15 | ```bash
16 | git clone https://github.com/BakerBunker/FreeV.git
17 | cd FreeV
18 | pip install -r requirements.txt
19 | ```
20 | 
21 | ## Configs
22 | 
23 | I tried using [PGHI (Phase Gradient Heap Integration)](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=7890450) as the phase spectrum initialization, but sadly it didn't work.
24 | 
25 | Below are the configs and train scripts for the different settings; run e.g. `diff config.json config2.json` or `diff train.py train2.py` to see the differences between them.
26 | 
27 | | Model | Config File | Train Script |
28 | | --- | --- | --- |
29 | | APNet2 | config.json | train.py |
30 | | APNet2 w/pghi | config_pghi.json | train_pghi.py |
31 | | FreeV | config2.json | train2.py |
32 | | FreeV w/pghi | config2_pghi.json | train2_pghi.py |
33 | 
34 | ## Training
35 | ```
36 | python train.py
37 | ```
38 | Use `train2.py`, `train_pghi.py`, or `train2_pghi.py` for the other settings in the table above. Checkpoints and a copy of the configuration file are saved in the `checkpoint_path` directory set in `config.json`; adjust training and inference behavior by editing the parameters there.
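For reference, here is a minimal, self-contained sketch of the pseudo-inverse projection behind the one-liner above, using the hyperparameters from `config.json`. The random `amp` tensor is only a stand-in for a real amplitude spectrogram, and the log compression/decompression that `dataset.py` applies around this step is omitted:

```python
import torch
from librosa.filters import mel as librosa_mel_fn

# Analysis mel filterbank, shape [num_mels, n_fft // 2 + 1] = [80, 513].
mel_basis = torch.from_numpy(
    librosa_mel_fn(sr=22050, n_fft=1024, n_mels=80, fmin=0, fmax=8000)
).float()

amp = torch.rand(513, 100)   # stand-in amplitude spectrogram, [n_fft // 2 + 1, frames]
mel_spec = mel_basis @ amp   # mel spectrogram, [80, frames]

# FreeV's model input: project the mel spectrogram back onto the linear-frequency
# axis with the fixed pseudo-inverse of the mel filter, then clamp for numerical safety.
model_input = (mel_basis.pinverse() @ mel_spec).abs().clamp_min(1e-5)   # [513, frames]
```

Because the pseudo-inverse is a fixed linear map, this initialization adds essentially no parameters or compute, which is where the "free lunch" comes from.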
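For the w/pghi settings, here is a sketch of how `dataset.py` derives the phase initialization with the `pghipy` package (`inv_mel` stands in for the pseudo-inverse amplitude estimate from the sketch above, and `win_size`/`hop_size` take their `config.json` values):

```python
import torch
from pghipy import pghi

win_size, hop_size = 1024, 256                   # values from config.json
inv_mel = torch.rand(513, 100).clamp_min(1e-5)   # stand-in amplitude estimate, [bins, frames]

# Estimate a phase spectrogram from the amplitude estimate via Phase Gradient
# Heap Integration; the transposes mirror Dataset.__getitem__ in dataset.py.
phase = torch.tensor(pghi(inv_mel.T.numpy(), win_size, hop_size)).T

# Wrap the result into (-pi, pi] so it matches torch.angle() conventions.
phase = torch.polar(torch.ones_like(phase), phase).angle()
```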
39 | 
40 | ## Inference
41 | Download the pretrained model (trained on the LJSpeech dataset) from [huggingface](https://huggingface.co/Bakerbunker/FreeV_Model_Logs).
42 | 
43 | Set the test paths in the config, then adapt and run `inference.py` to synthesize audio.
44 | 
45 | ## Model Structure
46 | ![model](./figure/overall.png)
47 | 
48 | ## Comparison with other models
49 | ![compare](./figure/compare.png)
50 | 
51 | ![compare_table](./figure/compare_table.png)
52 | 
53 | ## Acknowledgements
54 | We referred to [APNet2](https://github.com/redmist328/APNet2) when implementing this.
55 | 
56 | See the code changes at this [commit](https://github.com/BakerBunker/FreeV/commit/95e1e5cb3fe2b0360a30f39167e3e3ffd8097980).
57 | 
58 | ## Citation
59 | ```bibtex
60 | @misc{lv2024freevfreelunchvocoders,
61 |       title={FreeV: Free Lunch For Vocoders Through Pseudo Inversed Mel Filter},
62 |       author={Yuanjun Lv and Hai Li and Ying Yan and Junhui Liu and Danming Xie and Lei Xie},
63 |       year={2024},
64 |       eprint={2406.08196},
65 |       archivePrefix={arXiv},
66 |       primaryClass={cs.SD},
67 |       url={https://arxiv.org/abs/2406.08196},
68 | }
69 | ```
70 | 

--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 | 
3 | import glob
4 | import os
5 | import argparse
6 | import json
7 | import torch
8 | from utils import AttrDict
9 | from dataset import mel_spectrogram, load_wav
10 | from models import Generator
11 | import soundfile as sf
12 | import librosa
13 | import numpy as np
14 | import time
15 | h = None
16 | device = None
17 | 
18 | 
19 | def load_checkpoint(filepath, device):
20 |     assert os.path.isfile(filepath)
21 |     print("Loading '{}'".format(filepath))
22 |     checkpoint_dict = torch.load(filepath, map_location=device)
23 |     print("Complete.")
24 |     return checkpoint_dict
25 | 
26 | 
27 | def get_mel(x):
28 |     return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
29 | 
30 | 
31 | def scan_checkpoint(cp_dir, prefix):
32 |     pattern = os.path.join(cp_dir, prefix + '*')
33 |     cp_list = glob.glob(pattern)
34 |     if len(cp_list) == 0:
35 |         return ''
36 |     return sorted(cp_list)[-1]
37 | 
38 | 
39 | def inference(h):
40 |     generator = Generator(h).to(device)
41 | 
42 |     state_dict_g = load_checkpoint(h.checkpoint_file_load, device)
43 |     generator.load_state_dict(state_dict_g['generator'])
44 | 
45 |     filelist = sorted(os.listdir(h.test_input_mels_dir if h.test_mel_load else h.test_input_wavs_dir))
46 | 
47 |     os.makedirs(h.test_output_dir, exist_ok=True)
48 | 
49 |     generator.eval()
50 |     l = 0
51 |     with torch.no_grad():
52 |         starttime = time.time()
53 |         for i, filename in enumerate(filelist):
54 | 
55 |             # load a precomputed mel spectrogram, or compute one from the wav
56 |             if h.test_mel_load:
57 |                 mel = np.load(os.path.join(h.test_input_mels_dir, filename))
58 |                 x = torch.FloatTensor(mel).to(device)
59 |                 x = x.transpose(1, 2)
60 |             else:
61 |                 raw_wav, _ = librosa.load(os.path.join(h.test_input_wavs_dir, filename), sr=h.sampling_rate, mono=True)
62 |                 raw_wav = torch.FloatTensor(raw_wav).to(device)
63 |                 x = get_mel(raw_wav.unsqueeze(0))
64 | 
65 |             logamp_g, pha_g, _, _, y_g = generator(x)
66 |             audio = y_g.squeeze()
67 |             # logamp = logamp_g.squeeze()
68 |             # pha = pha_g.squeeze()
69 |             audio = audio.cpu().numpy()
70 |             # logamp = logamp.cpu().numpy()
71 |             # pha = pha.cpu().numpy()
72 |             audiolen = len(audio)
73 |             sf.write(os.path.join(h.test_output_dir, filename.split('.')[0] + '.wav'), audio, h.sampling_rate, 'PCM_16')
74 | 
75 |             # print(pp)
76 |             l += audiolen
77 | 
78 |         # write(output_file, h.sampling_rate, audio)
79 |         # print(output_file)
80 |         end = time.time()
81 |         print(end - starttime)  # total synthesis time (s)
82 |         print(l / h.sampling_rate)  # total generated audio (s)
83 |         print(l / h.sampling_rate / (end - starttime))  # real-time factor
84 | 
85 |         # np.save(os.path.join(h.test_output_dir, filename.split('.')[0]+'_logamp.npy'), logamp)
86 |         # np.save(os.path.join(h.test_output_dir, filename.split('.')[0]+'_pha.npy'), pha)
87 |         # if i==9:
88 |         #     break
89 | 
90 | def main():
91 |     print('Initializing Inference Process..')
92 | 
93 |     config_file = 'config.json'
94 | 
95 |     with open(config_file) as f:
96 |         data = f.read()
97 | 
98 |     global h
99 |     json_config = json.loads(data)
100 |     h = AttrDict(json_config)
101 | 
102 |     torch.manual_seed(h.seed)
103 |     global device
104 |     if torch.cuda.is_available():
105 |         torch.cuda.manual_seed(h.seed)
106 |         device = torch.device('cuda')
107 |     else:
108 |         device = torch.device('cpu')
109 |     inference(h)
110 | 
111 | 
112 | if __name__ == '__main__':
113 |     main()
114 | 

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | import torch.utils.data
6 | import numpy as np
7 | from librosa.util import normalize
8 | from librosa.filters import mel as librosa_mel_fn
9 | import librosa
10 | import torchaudio
11 | import torch.nn as nn
12 | from pghipy import pghi
13 | 
14 | 
15 | def load_wav(full_path, sample_rate):
16 |     data, _ = librosa.load(full_path, sr=sample_rate, mono=True)
17 |     return data
18 | 
19 | 
20 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
21 |     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
22 | 
23 | 
24 | def dynamic_range_decompression(x, C=1):
25 |     return np.exp(x) / C
26 | 
27 | 
28 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
29 |     return torch.log(torch.clamp(x, min=clip_val) * C)
30 | 
31 | 
32 | def dynamic_range_decompression_torch(x, C=1):
33 |     return torch.exp(x) / C
34 | 
35 | 
36 | def spectral_normalize_torch(magnitudes):
37 |     output = dynamic_range_compression_torch(magnitudes)
38 |     return output
39 | 
40 | 
41 | def spectral_de_normalize_torch(magnitudes):
42 |     output = dynamic_range_decompression_torch(magnitudes)
43 |     return output
44 | 
45 | # caches keyed by the STFT/mel parameters, so filterbanks and windows are built once per setting
46 | mel_window = {}
47 | inv_mel_window = {}
48 | 
49 | 
50 | def param_string(sampling_rate, n_fft, num_mels, fmin, fmax, win_size, device):
51 |     return f"{sampling_rate}-{n_fft}-{num_mels}-{fmin}-{fmax}-{win_size}-{device}"
52 | 
53 | 
54 | def mel_spectrogram(
55 |     y,
56 |     n_fft,
57 |     num_mels,
58 |     sampling_rate,
59 |     hop_size,
60 |     win_size,
61 |     fmin,
62 |     fmax,
63 |     center=True,
64 |     in_dataset=False,
65 | ):
66 |     global mel_window
67 |     device = torch.device("cpu") if in_dataset else y.device
68 |     ps = param_string(sampling_rate, n_fft, num_mels, fmin, fmax, win_size, device)
69 |     if ps in mel_window:
70 |         mel_basis, hann_window = mel_window[ps]
71 |         # print(mel_basis, hann_window)
72 |         # mel_basis, hann_window = mel_basis.to(y.device), hann_window.to(y.device)
73 |     else:
74 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
75 |         mel_basis = torch.from_numpy(mel).float().to(device)
76 |         hann_window = torch.hann_window(win_size).to(device)
77 |         mel_window[ps] = (mel_basis.clone(), hann_window.clone())
78 | 
79 |     spec = torch.stft(
80 |         y.to(device),
81 |         n_fft,
82 |         hop_length=hop_size,
83 |         win_length=win_size,
84 |         window=hann_window.to(device),
85 |         center=center,
86 |         return_complex=True,
87 |     )
88 | 
89 |     spec = mel_basis.to(device) @ spec.abs()
90 |     spec = spectral_normalize_torch(spec)
91 | 
92 |     return spec  # [batch_size, num_mels, frames]
93 | 
94 | 
95 | def inverse_mel(
96 |     mel,
97 |     n_fft,
98 |     num_mels,
99 |     sampling_rate,
100 |     hop_size,
101 |     win_size,
102 |     fmin,
103 |     fmax,
104 |     in_dataset=False,
105 | ):
106 |     global inv_mel_window, mel_window
107 |     device = torch.device("cpu") if in_dataset else mel.device
108 |     ps = param_string(sampling_rate, n_fft, num_mels, fmin, fmax, win_size, device)
109 |     if ps in inv_mel_window:
110 |         inv_basis = inv_mel_window[ps]
111 |     else:
112 |         if ps in mel_window:
113 |             mel_basis, _ = mel_window[ps]
114 |         else:
115 |             mel_np = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
116 |             mel_basis = torch.from_numpy(mel_np).float().to(device)
117 |             hann_window = torch.hann_window(win_size).to(device)
118 |             mel_window[ps] = (mel_basis.clone(), hann_window.clone())
119 |         inv_basis = mel_basis.pinverse()
120 |         inv_mel_window[ps] = inv_basis.clone()
121 |     return inv_basis.to(device) @ spectral_de_normalize_torch(mel.to(device))
122 | 
123 | 
124 | def amp_pha_specturm(y, n_fft, hop_size, win_size):
125 |     hann_window = torch.hann_window(win_size).to(y.device)
126 | 
127 |     stft_spec = torch.stft(
128 |         y,
129 |         n_fft,
130 |         hop_length=hop_size,
131 |         win_length=win_size,
132 |         window=hann_window,
133 |         center=True,
134 |         return_complex=True,
135 |     )  # complex [batch_size, n_fft//2+1, frames]
136 | 
137 |     log_amplitude = torch.log(
138 |         stft_spec.abs() + 1e-5
139 |     )  # [batch_size, n_fft//2+1, frames]
140 |     phase = stft_spec.angle()  # [batch_size, n_fft//2+1, frames]
141 | 
142 |     return log_amplitude, phase, stft_spec.real, stft_spec.imag
143 | 
144 | 
145 | def get_dataset_filelist(input_training_wav_list, input_validation_wav_list):
146 |     training_files = []
147 |     filelist = os.listdir(input_training_wav_list)
148 |     for files in filelist:
149 |         src = os.path.join(input_training_wav_list, files)
150 |         training_files.append(src)
151 | 
152 |     validation_files = []
153 |     filelist = os.listdir(input_validation_wav_list)
154 |     for files in filelist:
155 |         src = os.path.join(input_validation_wav_list, files)
156 |         validation_files.append(src)
157 | 
158 |     return training_files, validation_files
159 | 
160 | 
161 | class Dataset(torch.utils.data.Dataset):
162 |     def __init__(
163 |         self,
164 |         training_files,
165 |         segment_size,
166 |         n_fft,
167 |         num_mels,
168 |         hop_size,
169 |         win_size,
170 |         sampling_rate,
171 |         fmin,
172 |         fmax,
173 |         meloss,
174 |         split=True,
175 |         shuffle=True,
176 |         n_cache_reuse=1,
177 |         device=None,
178 |         inv_mel=False,
179 |         use_pghi=False,
180 |     ):
181 |         self.audio_files = training_files
182 |         random.seed(1234)
183 |         if shuffle:
184 |             random.shuffle(self.audio_files)
185 |         self.segment_size = segment_size
186 |         self.sampling_rate = sampling_rate
187 |         self.split = split
188 |         self.n_fft = n_fft
189 |         self.num_mels = num_mels
190 |         self.hop_size = hop_size
191 |         self.win_size = win_size
192 |         self.fmin = fmin
193 |         self.fmax = fmax
194 |         self.cached_wav = None
195 |         self.n_cache_reuse = n_cache_reuse
196 |         self._cache_ref_count = 0
197 |         self.device = device
198 |         self.meloss = meloss
199 |         self.inv_mel = inv_mel
200 |         self.pghi = use_pghi
201 | 
202 |     def __getitem__(self, index):
203 |         filename = self.audio_files[index]
204 |         if self._cache_ref_count == 0:
205 |             audio = 
load_wav(filename, self.sampling_rate) 206 | self.cached_wav = audio 207 | self._cache_ref_count = self.n_cache_reuse 208 | else: 209 | audio = self.cached_wav 210 | self._cache_ref_count -= 1 211 | 212 | audio = torch.FloatTensor(audio) # [T] 213 | audio = audio.unsqueeze(0) # [1,T] 214 | 215 | if self.split: 216 | if audio.size(1) >= self.segment_size: 217 | max_audio_start = audio.size(1) - self.segment_size 218 | audio_start = random.randint(0, max_audio_start) 219 | audio = audio[:, audio_start : audio_start + self.segment_size] # [1,T] 220 | else: 221 | audio = torch.nn.functional.pad( 222 | audio, (0, self.segment_size - audio.size(1)), "constant" 223 | ) 224 | 225 | mel = mel_spectrogram( 226 | audio, 227 | self.n_fft, 228 | self.num_mels, 229 | self.sampling_rate, 230 | self.hop_size, 231 | self.win_size, 232 | self.fmin, 233 | self.fmax, 234 | center=True, 235 | in_dataset=True, 236 | ) 237 | meloss1 = mel_spectrogram( 238 | audio, 239 | self.n_fft, 240 | self.num_mels, 241 | self.sampling_rate, 242 | self.hop_size, 243 | self.win_size, 244 | self.fmin, 245 | self.meloss, 246 | center=True, 247 | in_dataset=True, 248 | ) 249 | log_amplitude, phase, rea, imag = amp_pha_specturm( 250 | audio, self.n_fft, self.hop_size, self.win_size 251 | ) # [1,n_fft/2+1,frames] 252 | inv_mel = ( 253 | inverse_mel( 254 | mel, 255 | self.n_fft, 256 | self.num_mels, 257 | self.sampling_rate, 258 | self.hop_size, 259 | self.win_size, 260 | self.fmin, 261 | self.fmax, 262 | ) 263 | .abs() 264 | .clamp_min(1e-5) 265 | .squeeze() 266 | if self.inv_mel 267 | else torch.tensor([0]) 268 | ) 269 | if self.pghi: 270 | pghid = torch.tensor( 271 | pghi(inv_mel.squeeze(0).T.numpy(), self.win_size, self.hop_size) 272 | ).T 273 | pghid = torch.polar(torch.ones_like(pghid), pghid).angle() 274 | else: 275 | pghid = torch.tensor([0]) 276 | 277 | # print(pghid) 278 | return ( 279 | mel.squeeze(), 280 | log_amplitude.squeeze(), 281 | phase.squeeze(), 282 | rea.squeeze(), 283 | imag.squeeze(), 284 | audio.squeeze(0), 285 | meloss1.squeeze(), 286 | inv_mel, 287 | pghid, 288 | ) 289 | 290 | def __len__(self): 291 | return len(self.audio_files) 292 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.simplefilter(action="ignore", category=FutureWarning) 4 | import itertools 5 | import os 6 | import time 7 | import argparse 8 | import json 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | import torch.multiprocessing as mp 14 | from torch.distributed import init_process_group 15 | from torch.nn.parallel import DistributedDataParallel 16 | from dataset import Dataset, mel_spectrogram, amp_pha_specturm, get_dataset_filelist 17 | from models import ( 18 | Generator, 19 | MultiPeriodDiscriminator, 20 | feature_loss, 21 | generator_loss, 22 | discriminator_loss, 23 | amplitude_loss, 24 | phase_loss, 25 | STFT_consistency_loss, 26 | MultiResolutionDiscriminator, 27 | ) 28 | from utils import ( 29 | AttrDict, 30 | build_env, 31 | plot_spectrogram, 32 | scan_checkpoint, 33 | load_checkpoint, 34 | save_checkpoint, 35 | ) 36 | 37 | torch.backends.cudnn.benchmark = True 38 | 39 | 40 | def train(h): 41 | torch.cuda.manual_seed(h.seed) 42 | device = torch.device("cuda:{:d}".format(0)) 43 | 44 | generator = Generator(h).to(device) 45 | mpd = 
MultiPeriodDiscriminator().to(device) 46 | mrd = MultiResolutionDiscriminator().to(device) 47 | 48 | print(generator) 49 | os.makedirs(h.checkpoint_path, exist_ok=True) 50 | print("checkpoints directory : ", h.checkpoint_path) 51 | 52 | if os.path.isdir(h.checkpoint_path): 53 | cp_g = scan_checkpoint(h.checkpoint_path, "g_") 54 | cp_do = scan_checkpoint(h.checkpoint_path, "do_") 55 | 56 | steps = 0 57 | if cp_g is None or cp_do is None: 58 | state_dict_do = None 59 | last_epoch = -1 60 | else: 61 | state_dict_g = load_checkpoint(cp_g, device) 62 | state_dict_do = load_checkpoint(cp_do, device) 63 | generator.load_state_dict(state_dict_g["generator"]) 64 | mpd.load_state_dict(state_dict_do["mpd"]) 65 | mrd.load_state_dict(state_dict_do["mrd"]) 66 | steps = state_dict_do["steps"] + 1 67 | last_epoch = state_dict_do["epoch"] 68 | 69 | optim_g = torch.optim.AdamW( 70 | generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2] 71 | ) 72 | optim_d = torch.optim.AdamW( 73 | itertools.chain(mrd.parameters(), mpd.parameters()), 74 | h.learning_rate, 75 | betas=[h.adam_b1, h.adam_b2], 76 | ) 77 | 78 | if state_dict_do is not None: 79 | optim_g.load_state_dict(state_dict_do["optim_g"]) 80 | optim_d.load_state_dict(state_dict_do["optim_d"]) 81 | 82 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR( 83 | optim_g, gamma=h.lr_decay, last_epoch=last_epoch 84 | ) 85 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR( 86 | optim_d, gamma=h.lr_decay, last_epoch=last_epoch 87 | ) 88 | 89 | training_filelist, validation_filelist = get_dataset_filelist( 90 | h.input_training_wav_list, h.input_validation_wav_list 91 | ) 92 | 93 | trainset = Dataset( 94 | training_filelist, 95 | h.segment_size, 96 | h.n_fft, 97 | h.num_mels, 98 | h.hop_size, 99 | h.win_size, 100 | h.sampling_rate, 101 | h.fmin, 102 | h.fmax, 103 | h.meloss, 104 | n_cache_reuse=0, 105 | shuffle=True, 106 | device=device, 107 | ) 108 | 109 | train_loader = DataLoader( 110 | trainset, 111 | num_workers=h.num_workers, 112 | shuffle=True, 113 | sampler=None, 114 | batch_size=h.batch_size, 115 | pin_memory=True, 116 | drop_last=True, 117 | ) 118 | 119 | validset = Dataset( 120 | validation_filelist, 121 | h.segment_size, 122 | h.n_fft, 123 | h.num_mels, 124 | h.hop_size, 125 | h.win_size, 126 | h.sampling_rate, 127 | h.fmin, 128 | h.fmax, 129 | h.meloss, 130 | False, 131 | False, 132 | n_cache_reuse=0, 133 | device=device, 134 | ) 135 | validation_loader = DataLoader( 136 | validset, 137 | num_workers=1, 138 | shuffle=False, 139 | sampler=None, 140 | batch_size=1, 141 | pin_memory=True, 142 | drop_last=True, 143 | ) 144 | 145 | sw = SummaryWriter(os.path.join(h.checkpoint_path, "logs")) 146 | 147 | generator.train() 148 | mpd.train() 149 | mrd.train() 150 | 151 | for epoch in range(max(0, last_epoch), h.training_epochs): 152 | start = time.time() 153 | print("Epoch: {}".format(epoch + 1)) 154 | 155 | for i, batch in enumerate(train_loader): 156 | start_b = time.time() 157 | x, logamp, pha, rea, imag, y, meloss, inv_mel, pghid = map( 158 | lambda x: x.to(device, non_blocking=True), batch 159 | ) 160 | y = y.unsqueeze(1) 161 | logamp_g, pha_g, rea_g, imag_g, y_g = generator(x) 162 | y_g_mel = mel_spectrogram( 163 | y_g.squeeze(1), 164 | h.n_fft, 165 | h.num_mels, 166 | h.sampling_rate, 167 | h.hop_size, 168 | h.win_size, 169 | h.fmin, 170 | h.meloss, 171 | ) 172 | 173 | optim_d.zero_grad() 174 | 175 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g.detach()) 176 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( 177 | 
y_df_hat_r, y_df_hat_g 178 | ) 179 | 180 | y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g.detach()) 181 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( 182 | y_ds_hat_r, y_ds_hat_g 183 | ) 184 | 185 | L_D = loss_disc_s * 0.1 + loss_disc_f 186 | 187 | L_D.backward() 188 | optim_d.step() 189 | 190 | # Generator 191 | optim_g.zero_grad() 192 | 193 | # Losses defined on log amplitude spectra 194 | L_A = amplitude_loss(logamp, logamp_g) 195 | 196 | L_IP, L_GD, L_PTD = phase_loss(pha, pha_g, h.n_fft, pha.size()[-1]) 197 | # Losses defined on phase spectra 198 | L_P = L_IP + L_GD + L_PTD 199 | 200 | _, _, rea_g_final, imag_g_final = amp_pha_specturm( 201 | y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size 202 | ) 203 | L_C = STFT_consistency_loss(rea_g, rea_g_final, imag_g, imag_g_final) 204 | L_R = F.l1_loss(rea, rea_g) 205 | L_I = F.l1_loss(imag, imag_g) 206 | # Losses defined on reconstructed STFT spectra 207 | L_S = L_C + 2.25 * (L_R + L_I) 208 | 209 | y_df_r, y_df_g, fmap_f_r, fmap_f_g = mpd(y, y_g) 210 | y_ds_r, y_ds_g, fmap_s_r, fmap_s_g = mrd(y, y_g) 211 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 212 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) 213 | loss_gen_f, losses_gen_f = generator_loss(y_df_g) 214 | loss_gen_s, losses_gen_s = generator_loss(y_ds_g) 215 | L_GAN_G = loss_gen_s * 0.1 + loss_gen_f 216 | L_FM = loss_fm_s * 0.1 + loss_fm_f 217 | L_Mel = F.l1_loss(meloss, y_g_mel) 218 | # Losses defined on final waveforms 219 | L_W = L_GAN_G + L_FM + 45 * L_Mel 220 | 221 | L_G = 45 * L_A + 100 * L_P + 20 * L_S + L_W 222 | 223 | L_G.backward() 224 | optim_g.step() 225 | 226 | # STDOUT logging 227 | if steps % h.stdout_interval == 0: 228 | with torch.no_grad(): 229 | A_error = amplitude_loss(logamp, logamp_g).item() 230 | IP_error, GD_error, PTD_error = phase_loss( 231 | pha, pha_g, h.n_fft, pha.size()[-1] 232 | ) 233 | IP_error = IP_error.item() 234 | GD_error = GD_error.item() 235 | PTD_error = PTD_error.item() 236 | C_error = STFT_consistency_loss( 237 | rea_g, rea_g_final, imag_g, imag_g_final 238 | ).item() 239 | R_error = F.l1_loss(rea, rea_g).item() 240 | I_error = F.l1_loss(imag, imag_g).item() 241 | Mel_error = F.l1_loss(x, y_g_mel).item() 242 | 243 | print( 244 | "Steps : {:d}, Gen Loss Total : {:4.3f}, Amplitude Loss : {:4.3f}, Instantaneous Phase Loss : {:4.3f}, Group Delay Loss : {:4.3f}, Phase Time Difference Loss : {:4.3f}, STFT Consistency Loss : {:4.3f}, Real Part Loss : {:4.3f}, Imaginary Part Loss : {:4.3f}, Mel Spectrogram Loss : {:4.3f}, s/b : {:4.3f}".format( 245 | steps, 246 | L_G, 247 | A_error, 248 | IP_error, 249 | GD_error, 250 | PTD_error, 251 | C_error, 252 | R_error, 253 | I_error, 254 | Mel_error, 255 | time.time() - start_b, 256 | ) 257 | ) 258 | 259 | # checkpointing 260 | if steps % h.checkpoint_interval == 0 and steps != 0: 261 | checkpoint_path = "{}/g_{:08d}".format(h.checkpoint_path, steps) 262 | save_checkpoint(checkpoint_path, {"generator": generator.state_dict()}) 263 | checkpoint_path = "{}/do_{:08d}".format(h.checkpoint_path, steps) 264 | save_checkpoint( 265 | checkpoint_path, 266 | { 267 | "mpd": mpd.state_dict(), 268 | "mrd": mrd.state_dict(), 269 | "optim_g": optim_g.state_dict(), 270 | "optim_d": optim_d.state_dict(), 271 | "steps": steps, 272 | "epoch": epoch, 273 | }, 274 | ) 275 | 276 | # Tensorboard summary logging 277 | if steps % h.summary_interval == 0: 278 | sw.add_scalar("Training/Generator_Total_Loss", L_G, steps) 279 | sw.add_scalar("Training/Mel_Spectrogram_Loss", Mel_error, steps) 280 | 281 | # Validation 282 | 
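            # Every h.validation_interval steps: average the amplitude, phase, STFT-consistency,
            # and mel losses over the whole validation set and log sample audio/spectrograms.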
if steps % h.validation_interval == 0: # and steps != 0: 283 | generator.eval() 284 | torch.cuda.empty_cache() 285 | val_A_err_tot = 0 286 | val_IP_err_tot = 0 287 | val_GD_err_tot = 0 288 | val_PTD_err_tot = 0 289 | val_C_err_tot = 0 290 | val_R_err_tot = 0 291 | val_I_err_tot = 0 292 | val_Mel_err_tot = 0 293 | with torch.no_grad(): 294 | for j, batch in enumerate(validation_loader): 295 | x, logamp, pha, rea, imag, y, meloss, inv_mel, pghid = map( 296 | lambda x: x.to(device, non_blocking=True), batch 297 | ) 298 | logamp_g, pha_g, rea_g, imag_g, y_g = generator(x.to(device)) 299 | y_g_mel = mel_spectrogram( 300 | y_g.squeeze(1), 301 | h.n_fft, 302 | h.num_mels, 303 | h.sampling_rate, 304 | h.hop_size, 305 | h.win_size, 306 | h.fmin, 307 | h.meloss, 308 | ) 309 | 310 | _, _, rea_g_final, imag_g_final = amp_pha_specturm( 311 | y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size 312 | ) 313 | val_A_err_tot += amplitude_loss(logamp, logamp_g).item() 314 | val_IP_err, val_GD_err, val_PTD_err = phase_loss( 315 | pha, pha_g, h.n_fft, pha.size()[-1] 316 | ) 317 | val_IP_err_tot += val_IP_err.item() 318 | val_GD_err_tot += val_GD_err.item() 319 | val_PTD_err_tot += val_PTD_err.item() 320 | val_C_err_tot += STFT_consistency_loss( 321 | rea_g, rea_g_final, imag_g, imag_g_final 322 | ).item() 323 | val_R_err_tot += F.l1_loss(rea, rea_g).item() 324 | val_I_err_tot += F.l1_loss(imag, imag_g).item() 325 | val_Mel_err_tot += F.l1_loss(meloss, y_g_mel).item() 326 | 327 | if j <= 4: 328 | if steps == 0: 329 | sw.add_audio( 330 | "gt/y_{}".format(j), y[0], steps, h.sampling_rate 331 | ) 332 | sw.add_figure( 333 | "gt/y_spec_{}".format(j), 334 | plot_spectrogram(x[0].cpu()), 335 | steps, 336 | ) 337 | 338 | sw.add_audio( 339 | "generated/y_g_{}".format(j), 340 | y_g[0], 341 | steps, 342 | h.sampling_rate, 343 | ) 344 | y_g_spec = mel_spectrogram( 345 | y_g.squeeze(1), 346 | h.n_fft, 347 | h.num_mels, 348 | h.sampling_rate, 349 | h.hop_size, 350 | h.win_size, 351 | h.fmin, 352 | h.fmax, 353 | ) 354 | sw.add_figure( 355 | "generated/y_g_spec_{}".format(j), 356 | plot_spectrogram(y_g_spec.squeeze(0).cpu().numpy()), 357 | steps, 358 | ) 359 | 360 | val_A_err = val_A_err_tot / (j + 1) 361 | val_IP_err = val_IP_err_tot / (j + 1) 362 | val_GD_err = val_GD_err_tot / (j + 1) 363 | val_PTD_err = val_PTD_err_tot / (j + 1) 364 | val_C_err = val_C_err_tot / (j + 1) 365 | val_R_err = val_R_err_tot / (j + 1) 366 | val_I_err = val_I_err_tot / (j + 1) 367 | val_Mel_err = val_Mel_err_tot / (j + 1) 368 | sw.add_scalar("Validation/Amplitude_Loss", val_A_err, steps) 369 | sw.add_scalar( 370 | "Validation/Instantaneous_Phase_Loss", val_IP_err, steps 371 | ) 372 | sw.add_scalar("Validation/Group_Delay_Loss", val_GD_err, steps) 373 | sw.add_scalar( 374 | "Validation/Phase_Time_Difference_Loss", val_PTD_err, steps 375 | ) 376 | sw.add_scalar("Validation/STFT_Consistency_Loss", val_C_err, steps) 377 | sw.add_scalar("Validation/Real_Part_Loss", val_R_err, steps) 378 | sw.add_scalar("Validation/Imaginary_Part_Loss", val_I_err, steps) 379 | sw.add_scalar("Validation/Mel_Spectrogram_loss", val_Mel_err, steps) 380 | 381 | generator.train() 382 | 383 | steps += 1 384 | 385 | scheduler_g.step() 386 | scheduler_d.step() 387 | 388 | print( 389 | "Time taken for epoch {} is {} sec\n".format( 390 | epoch + 1, int(time.time() - start) 391 | ) 392 | ) 393 | 394 | 395 | def main(): 396 | print("Initializing Training Process..") 397 | 398 | config_file = "config.json" 399 | 400 | with open(config_file) as f: 401 | data = f.read() 402 | 403 | 
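    # Parse the JSON config and expose its keys as attributes (h.batch_size, h.n_fft, ...).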
json_config = json.loads(data) 404 | h = AttrDict(json_config) 405 | build_env(config_file, "config.json", h.checkpoint_path) 406 | 407 | torch.manual_seed(h.seed) 408 | if torch.cuda.is_available(): 409 | torch.cuda.manual_seed(h.seed) 410 | else: 411 | pass 412 | 413 | train(h) 414 | 415 | 416 | if __name__ == "__main__": 417 | main() 418 | -------------------------------------------------------------------------------- /train2.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.simplefilter(action="ignore", category=FutureWarning) 4 | import itertools 5 | import os 6 | import time 7 | import argparse 8 | import json 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | import torch.multiprocessing as mp 14 | from torch.distributed import init_process_group 15 | from torch.nn.parallel import DistributedDataParallel 16 | from dataset import Dataset, mel_spectrogram, amp_pha_specturm, get_dataset_filelist 17 | from models2 import ( 18 | Generator, 19 | MultiPeriodDiscriminator, 20 | feature_loss, 21 | generator_loss, 22 | discriminator_loss, 23 | amplitude_loss, 24 | phase_loss, 25 | STFT_consistency_loss, 26 | MultiResolutionDiscriminator, 27 | ) 28 | from utils import ( 29 | AttrDict, 30 | build_env, 31 | plot_spectrogram, 32 | scan_checkpoint, 33 | load_checkpoint, 34 | save_checkpoint, 35 | ) 36 | 37 | torch.backends.cudnn.benchmark = True 38 | 39 | 40 | def train(h): 41 | torch.cuda.manual_seed(h.seed) 42 | device = torch.device("cuda:{:d}".format(0)) 43 | 44 | generator = Generator(h).to(device) 45 | mpd = MultiPeriodDiscriminator().to(device) 46 | mrd = MultiResolutionDiscriminator().to(device) 47 | 48 | print(generator) 49 | os.makedirs(h.checkpoint_path, exist_ok=True) 50 | print("checkpoints directory : ", h.checkpoint_path) 51 | 52 | if os.path.isdir(h.checkpoint_path): 53 | cp_g = scan_checkpoint(h.checkpoint_path, "g_") 54 | cp_do = scan_checkpoint(h.checkpoint_path, "do_") 55 | 56 | steps = 0 57 | if cp_g is None or cp_do is None: 58 | state_dict_do = None 59 | last_epoch = -1 60 | else: 61 | state_dict_g = load_checkpoint(cp_g, device) 62 | state_dict_do = load_checkpoint(cp_do, device) 63 | generator.load_state_dict(state_dict_g["generator"]) 64 | mpd.load_state_dict(state_dict_do["mpd"]) 65 | mrd.load_state_dict(state_dict_do["mrd"]) 66 | steps = state_dict_do["steps"] + 1 67 | last_epoch = state_dict_do["epoch"] 68 | 69 | optim_g = torch.optim.AdamW( 70 | generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2] 71 | ) 72 | optim_d = torch.optim.AdamW( 73 | itertools.chain(mrd.parameters(), mpd.parameters()), 74 | h.learning_rate, 75 | betas=[h.adam_b1, h.adam_b2], 76 | ) 77 | 78 | if state_dict_do is not None: 79 | optim_g.load_state_dict(state_dict_do["optim_g"]) 80 | optim_d.load_state_dict(state_dict_do["optim_d"]) 81 | 82 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR( 83 | optim_g, gamma=h.lr_decay, last_epoch=last_epoch 84 | ) 85 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR( 86 | optim_d, gamma=h.lr_decay, last_epoch=last_epoch 87 | ) 88 | 89 | training_filelist, validation_filelist = get_dataset_filelist( 90 | h.input_training_wav_list, h.input_validation_wav_list 91 | ) 92 | 93 | trainset = Dataset( 94 | training_filelist, 95 | h.segment_size, 96 | h.n_fft, 97 | h.num_mels, 98 | h.hop_size, 99 | h.win_size, 100 | h.sampling_rate, 101 | 
h.fmin, 102 | h.fmax, 103 | h.meloss, 104 | n_cache_reuse=0, 105 | shuffle=True, 106 | inv_mel=True, 107 | device=device, 108 | ) 109 | 110 | train_loader = DataLoader( 111 | trainset, 112 | num_workers=h.num_workers, 113 | shuffle=True, 114 | sampler=None, 115 | batch_size=h.batch_size, 116 | pin_memory=True, 117 | drop_last=True, 118 | ) 119 | 120 | validset = Dataset( 121 | validation_filelist, 122 | h.segment_size, 123 | h.n_fft, 124 | h.num_mels, 125 | h.hop_size, 126 | h.win_size, 127 | h.sampling_rate, 128 | h.fmin, 129 | h.fmax, 130 | h.meloss, 131 | False, 132 | False, 133 | n_cache_reuse=0, 134 | device=device, 135 | inv_mel=True, 136 | ) 137 | validation_loader = DataLoader( 138 | validset, 139 | num_workers=1, 140 | shuffle=False, 141 | sampler=None, 142 | batch_size=1, 143 | pin_memory=True, 144 | drop_last=True, 145 | ) 146 | 147 | sw = SummaryWriter(os.path.join(h.checkpoint_path, "logs")) 148 | 149 | generator.train() 150 | mpd.train() 151 | mrd.train() 152 | 153 | for epoch in range(max(0, last_epoch), h.training_epochs): 154 | start = time.time() 155 | print("Epoch: {}".format(epoch + 1)) 156 | 157 | for i, batch in enumerate(train_loader): 158 | start_b = time.time() 159 | x, logamp, pha, rea, imag, y, meloss, inv_mel, pghid = map( 160 | lambda x: x.to(device, non_blocking=True), batch 161 | ) 162 | y = y.unsqueeze(1) 163 | logamp_g, pha_g, rea_g, imag_g, y_g = generator(x) 164 | y_g_mel = mel_spectrogram( 165 | y_g.squeeze(1), 166 | h.n_fft, 167 | h.num_mels, 168 | h.sampling_rate, 169 | h.hop_size, 170 | h.win_size, 171 | h.fmin, 172 | h.meloss, 173 | ) 174 | 175 | optim_d.zero_grad() 176 | 177 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g.detach()) 178 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( 179 | y_df_hat_r, y_df_hat_g 180 | ) 181 | 182 | y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g.detach()) 183 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( 184 | y_ds_hat_r, y_ds_hat_g 185 | ) 186 | 187 | L_D = loss_disc_s * 0.1 + loss_disc_f 188 | 189 | L_D.backward() 190 | optim_d.step() 191 | 192 | # Generator 193 | optim_g.zero_grad() 194 | 195 | # Losses defined on log amplitude spectra 196 | L_A = amplitude_loss(logamp, logamp_g) 197 | 198 | L_IP, L_GD, L_PTD = phase_loss(pha, pha_g, h.n_fft, pha.size()[-1]) 199 | # Losses defined on phase spectra 200 | L_P = L_IP + L_GD + L_PTD 201 | 202 | _, _, rea_g_final, imag_g_final = amp_pha_specturm( 203 | y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size 204 | ) 205 | L_C = STFT_consistency_loss(rea_g, rea_g_final, imag_g, imag_g_final) 206 | L_R = F.l1_loss(rea, rea_g) 207 | L_I = F.l1_loss(imag, imag_g) 208 | # Losses defined on reconstructed STFT spectra 209 | L_S = L_C + 2.25 * (L_R + L_I) 210 | 211 | y_df_r, y_df_g, fmap_f_r, fmap_f_g = mpd(y, y_g) 212 | y_ds_r, y_ds_g, fmap_s_r, fmap_s_g = mrd(y, y_g) 213 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 214 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) 215 | loss_gen_f, losses_gen_f = generator_loss(y_df_g) 216 | loss_gen_s, losses_gen_s = generator_loss(y_ds_g) 217 | L_GAN_G = loss_gen_s * 0.1 + loss_gen_f 218 | L_FM = loss_fm_s * 0.1 + loss_fm_f 219 | L_Mel = F.l1_loss(meloss, y_g_mel) 220 | # Losses defined on final waveforms 221 | L_W = L_GAN_G + L_FM + 45 * L_Mel 222 | 223 | L_G = 45 * L_A + 100 * L_P + 20 * L_S + L_W 224 | 225 | L_G.backward() 226 | optim_g.step() 227 | 228 | # STDOUT logging 229 | if steps % h.stdout_interval == 0: 230 | with torch.no_grad(): 231 | A_error = amplitude_loss(logamp, logamp_g).item() 232 | 
IP_error, GD_error, PTD_error = phase_loss( 233 | pha, pha_g, h.n_fft, pha.size()[-1] 234 | ) 235 | IP_error = IP_error.item() 236 | GD_error = GD_error.item() 237 | PTD_error = PTD_error.item() 238 | C_error = STFT_consistency_loss( 239 | rea_g, rea_g_final, imag_g, imag_g_final 240 | ).item() 241 | R_error = F.l1_loss(rea, rea_g).item() 242 | I_error = F.l1_loss(imag, imag_g).item() 243 | Mel_error = F.l1_loss(x, y_g_mel).item() 244 | 245 | print( 246 | "Steps : {:d}, Gen Loss Total : {:4.3f}, Amplitude Loss : {:4.3f}, Instantaneous Phase Loss : {:4.3f}, Group Delay Loss : {:4.3f}, Phase Time Difference Loss : {:4.3f}, STFT Consistency Loss : {:4.3f}, Real Part Loss : {:4.3f}, Imaginary Part Loss : {:4.3f}, Mel Spectrogram Loss : {:4.3f}, s/b : {:4.3f}".format( 247 | steps, 248 | L_G, 249 | A_error, 250 | IP_error, 251 | GD_error, 252 | PTD_error, 253 | C_error, 254 | R_error, 255 | I_error, 256 | Mel_error, 257 | time.time() - start_b, 258 | ) 259 | ) 260 | 261 | # checkpointing 262 | if steps % h.checkpoint_interval == 0 and steps != 0: 263 | checkpoint_path = "{}/g_{:08d}".format(h.checkpoint_path, steps) 264 | save_checkpoint(checkpoint_path, {"generator": generator.state_dict()}) 265 | checkpoint_path = "{}/do_{:08d}".format(h.checkpoint_path, steps) 266 | save_checkpoint( 267 | checkpoint_path, 268 | { 269 | "mpd": mpd.state_dict(), 270 | "mrd": mrd.state_dict(), 271 | "optim_g": optim_g.state_dict(), 272 | "optim_d": optim_d.state_dict(), 273 | "steps": steps, 274 | "epoch": epoch, 275 | }, 276 | ) 277 | 278 | # Tensorboard summary logging 279 | if steps % h.summary_interval == 0: 280 | sw.add_scalar("Training/Generator_Total_Loss", L_G, steps) 281 | sw.add_scalar("Training/Mel_Spectrogram_Loss", Mel_error, steps) 282 | 283 | # Validation 284 | if steps % h.validation_interval == 0: # and steps != 0: 285 | generator.eval() 286 | torch.cuda.empty_cache() 287 | val_A_err_tot = 0 288 | val_IP_err_tot = 0 289 | val_GD_err_tot = 0 290 | val_PTD_err_tot = 0 291 | val_C_err_tot = 0 292 | val_R_err_tot = 0 293 | val_I_err_tot = 0 294 | val_Mel_err_tot = 0 295 | with torch.no_grad(): 296 | for j, batch in enumerate(validation_loader): 297 | x, logamp, pha, rea, imag, y, meloss, inv_mel, pghid = map( 298 | lambda x: x.to(device, non_blocking=True), batch 299 | ) 300 | logamp_g, pha_g, rea_g, imag_g, y_g = generator(x.to(device)) 301 | y_g_mel = mel_spectrogram( 302 | y_g.squeeze(1), 303 | h.n_fft, 304 | h.num_mels, 305 | h.sampling_rate, 306 | h.hop_size, 307 | h.win_size, 308 | h.fmin, 309 | h.meloss, 310 | ) 311 | 312 | _, _, rea_g_final, imag_g_final = amp_pha_specturm( 313 | y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size 314 | ) 315 | val_A_err_tot += amplitude_loss(logamp, logamp_g).item() 316 | val_IP_err, val_GD_err, val_PTD_err = phase_loss( 317 | pha, pha_g, h.n_fft, pha.size()[-1] 318 | ) 319 | val_IP_err_tot += val_IP_err.item() 320 | val_GD_err_tot += val_GD_err.item() 321 | val_PTD_err_tot += val_PTD_err.item() 322 | val_C_err_tot += STFT_consistency_loss( 323 | rea_g, rea_g_final, imag_g, imag_g_final 324 | ).item() 325 | val_R_err_tot += F.l1_loss(rea, rea_g).item() 326 | val_I_err_tot += F.l1_loss(imag, imag_g).item() 327 | val_Mel_err_tot += F.l1_loss(meloss, y_g_mel).item() 328 | 329 | if j <= 4: 330 | if steps == 0: 331 | sw.add_audio( 332 | "gt/y_{}".format(j), y[0], steps, h.sampling_rate 333 | ) 334 | sw.add_figure( 335 | "gt/y_spec_{}".format(j), 336 | plot_spectrogram(x[0].cpu()), 337 | steps, 338 | ) 339 | 340 | sw.add_audio( 341 | "generated/y_g_{}".format(j), 342 
| y_g[0], 343 | steps, 344 | h.sampling_rate, 345 | ) 346 | y_g_spec = mel_spectrogram( 347 | y_g.squeeze(1), 348 | h.n_fft, 349 | h.num_mels, 350 | h.sampling_rate, 351 | h.hop_size, 352 | h.win_size, 353 | h.fmin, 354 | h.fmax, 355 | ) 356 | sw.add_figure( 357 | "generated/y_g_spec_{}".format(j), 358 | plot_spectrogram(y_g_spec.squeeze(0).cpu().numpy()), 359 | steps, 360 | ) 361 | 362 | val_A_err = val_A_err_tot / (j + 1) 363 | val_IP_err = val_IP_err_tot / (j + 1) 364 | val_GD_err = val_GD_err_tot / (j + 1) 365 | val_PTD_err = val_PTD_err_tot / (j + 1) 366 | val_C_err = val_C_err_tot / (j + 1) 367 | val_R_err = val_R_err_tot / (j + 1) 368 | val_I_err = val_I_err_tot / (j + 1) 369 | val_Mel_err = val_Mel_err_tot / (j + 1) 370 | sw.add_scalar("Validation/Amplitude_Loss", val_A_err, steps) 371 | sw.add_scalar( 372 | "Validation/Instantaneous_Phase_Loss", val_IP_err, steps 373 | ) 374 | sw.add_scalar("Validation/Group_Delay_Loss", val_GD_err, steps) 375 | sw.add_scalar( 376 | "Validation/Phase_Time_Difference_Loss", val_PTD_err, steps 377 | ) 378 | sw.add_scalar("Validation/STFT_Consistency_Loss", val_C_err, steps) 379 | sw.add_scalar("Validation/Real_Part_Loss", val_R_err, steps) 380 | sw.add_scalar("Validation/Imaginary_Part_Loss", val_I_err, steps) 381 | sw.add_scalar("Validation/Mel_Spectrogram_loss", val_Mel_err, steps) 382 | 383 | generator.train() 384 | 385 | steps += 1 386 | 387 | scheduler_g.step() 388 | scheduler_d.step() 389 | 390 | print( 391 | "Time taken for epoch {} is {} sec\n".format( 392 | epoch + 1, int(time.time() - start) 393 | ) 394 | ) 395 | 396 | 397 | def main(): 398 | print("Initializing Training Process..") 399 | 400 | config_file = "config2.json" 401 | 402 | with open(config_file) as f: 403 | data = f.read() 404 | 405 | json_config = json.loads(data) 406 | h = AttrDict(json_config) 407 | build_env(config_file, "config2.json", h.checkpoint_path) 408 | 409 | torch.manual_seed(h.seed) 410 | if torch.cuda.is_available(): 411 | torch.cuda.manual_seed(h.seed) 412 | else: 413 | pass 414 | 415 | train(h) 416 | 417 | 418 | if __name__ == "__main__": 419 | main() 420 | -------------------------------------------------------------------------------- /train2_pghi.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.simplefilter(action="ignore", category=FutureWarning) 4 | import itertools 5 | import os 6 | import time 7 | import argparse 8 | import json 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | import torch.multiprocessing as mp 14 | from torch.distributed import init_process_group 15 | from torch.nn.parallel import DistributedDataParallel 16 | from dataset import Dataset, mel_spectrogram, amp_pha_specturm, get_dataset_filelist 17 | from models2_pghi import ( 18 | Generator, 19 | MultiPeriodDiscriminator, 20 | feature_loss, 21 | generator_loss, 22 | discriminator_loss, 23 | amplitude_loss, 24 | phase_loss, 25 | STFT_consistency_loss, 26 | MultiResolutionDiscriminator, 27 | ) 28 | from utils import ( 29 | AttrDict, 30 | build_env, 31 | plot_spectrogram, 32 | scan_checkpoint, 33 | load_checkpoint, 34 | save_checkpoint, 35 | ) 36 | 37 | torch.backends.cudnn.benchmark = True 38 | 39 | 40 | def train(h): 41 | torch.cuda.manual_seed(h.seed) 42 | device = torch.device("cuda:{:d}".format(0)) 43 | 44 | generator = Generator(h).to(device) 45 | mpd = 
MultiPeriodDiscriminator().to(device) 46 | mrd = MultiResolutionDiscriminator().to(device) 47 | 48 | print(generator) 49 | os.makedirs(h.checkpoint_path, exist_ok=True) 50 | print("checkpoints directory : ", h.checkpoint_path) 51 | 52 | if os.path.isdir(h.checkpoint_path): 53 | cp_g = scan_checkpoint(h.checkpoint_path, "g_") 54 | cp_do = scan_checkpoint(h.checkpoint_path, "do_") 55 | 56 | steps = 0 57 | if cp_g is None or cp_do is None: 58 | state_dict_do = None 59 | last_epoch = -1 60 | else: 61 | state_dict_g = load_checkpoint(cp_g, device) 62 | state_dict_do = load_checkpoint(cp_do, device) 63 | generator.load_state_dict(state_dict_g["generator"]) 64 | mpd.load_state_dict(state_dict_do["mpd"]) 65 | mrd.load_state_dict(state_dict_do["mrd"]) 66 | steps = state_dict_do["steps"] + 1 67 | last_epoch = state_dict_do["epoch"] 68 | 69 | optim_g = torch.optim.AdamW( 70 | generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2] 71 | ) 72 | optim_d = torch.optim.AdamW( 73 | itertools.chain(mrd.parameters(), mpd.parameters()), 74 | h.learning_rate, 75 | betas=[h.adam_b1, h.adam_b2], 76 | ) 77 | 78 | if state_dict_do is not None: 79 | optim_g.load_state_dict(state_dict_do["optim_g"]) 80 | optim_d.load_state_dict(state_dict_do["optim_d"]) 81 | 82 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR( 83 | optim_g, gamma=h.lr_decay, last_epoch=last_epoch 84 | ) 85 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR( 86 | optim_d, gamma=h.lr_decay, last_epoch=last_epoch 87 | ) 88 | 89 | training_filelist, validation_filelist = get_dataset_filelist( 90 | h.input_training_wav_list, h.input_validation_wav_list 91 | ) 92 | 93 | trainset = Dataset( 94 | training_filelist, 95 | h.segment_size, 96 | h.n_fft, 97 | h.num_mels, 98 | h.hop_size, 99 | h.win_size, 100 | h.sampling_rate, 101 | h.fmin, 102 | h.fmax, 103 | h.meloss, 104 | n_cache_reuse=0, 105 | shuffle=True, 106 | inv_mel=True, 107 | use_pghi=True, 108 | device=device, 109 | ) 110 | 111 | train_loader = DataLoader( 112 | trainset, 113 | num_workers=h.num_workers, 114 | shuffle=True, 115 | sampler=None, 116 | batch_size=h.batch_size, 117 | pin_memory=True, 118 | drop_last=True, 119 | ) 120 | 121 | validset = Dataset( 122 | validation_filelist, 123 | h.segment_size, 124 | h.n_fft, 125 | h.num_mels, 126 | h.hop_size, 127 | h.win_size, 128 | h.sampling_rate, 129 | h.fmin, 130 | h.fmax, 131 | h.meloss, 132 | False, 133 | False, 134 | n_cache_reuse=0, 135 | device=device, 136 | inv_mel=True, 137 | use_pghi=True, 138 | ) 139 | validation_loader = DataLoader( 140 | validset, 141 | num_workers=1, 142 | shuffle=False, 143 | sampler=None, 144 | batch_size=1, 145 | pin_memory=True, 146 | drop_last=True, 147 | ) 148 | 149 | sw = SummaryWriter(os.path.join(h.checkpoint_path, "logs")) 150 | 151 | generator.train() 152 | mpd.train() 153 | mrd.train() 154 | 155 | for epoch in range(max(0, last_epoch), h.training_epochs): 156 | start = time.time() 157 | print("Epoch: {}".format(epoch + 1)) 158 | 159 | for i, batch in enumerate(train_loader): 160 | start_b = time.time() 161 | x, logamp, pha, rea, imag, y, meloss, inv_mel, pghid = map( 162 | lambda x: x.to(device, non_blocking=True), batch 163 | ) 164 | y = y.unsqueeze(1) 165 | logamp_g, pha_g, rea_g, imag_g, y_g = generator(x, pghi=pghid) 166 | y_g_mel = mel_spectrogram( 167 | y_g.squeeze(1), 168 | h.n_fft, 169 | h.num_mels, 170 | h.sampling_rate, 171 | h.hop_size, 172 | h.win_size, 173 | h.fmin, 174 | h.meloss, 175 | ) 176 | 177 | optim_d.zero_grad() 178 | 179 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, 
y_g.detach()) 180 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( 181 | y_df_hat_r, y_df_hat_g 182 | ) 183 | 184 | y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g.detach()) 185 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( 186 | y_ds_hat_r, y_ds_hat_g 187 | ) 188 | 189 | L_D = loss_disc_s * 0.1 + loss_disc_f 190 | 191 | L_D.backward() 192 | optim_d.step() 193 | 194 | # Generator 195 | optim_g.zero_grad() 196 | 197 | # Losses defined on log amplitude spectra 198 | L_A = amplitude_loss(logamp, logamp_g) 199 | 200 | L_IP, L_GD, L_PTD = phase_loss(pha, pha_g, h.n_fft, pha.size()[-1]) 201 | # Losses defined on phase spectra 202 | L_P = L_IP + L_GD + L_PTD 203 | 204 | _, _, rea_g_final, imag_g_final = amp_pha_specturm( 205 | y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size 206 | ) 207 | L_C = STFT_consistency_loss(rea_g, rea_g_final, imag_g, imag_g_final) 208 | L_R = F.l1_loss(rea, rea_g) 209 | L_I = F.l1_loss(imag, imag_g) 210 | # Losses defined on reconstructed STFT spectra 211 | L_S = L_C + 2.25 * (L_R + L_I) 212 | 213 | y_df_r, y_df_g, fmap_f_r, fmap_f_g = mpd(y, y_g) 214 | y_ds_r, y_ds_g, fmap_s_r, fmap_s_g = mrd(y, y_g) 215 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 216 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) 217 | loss_gen_f, losses_gen_f = generator_loss(y_df_g) 218 | loss_gen_s, losses_gen_s = generator_loss(y_ds_g) 219 | L_GAN_G = loss_gen_s * 0.1 + loss_gen_f 220 | L_FM = loss_fm_s * 0.1 + loss_fm_f 221 | L_Mel = F.l1_loss(meloss, y_g_mel) 222 | # Losses defined on final waveforms 223 | L_W = L_GAN_G + L_FM + 45 * L_Mel 224 | 225 | L_G = 45 * L_A + 100 * L_P + 20 * L_S + L_W 226 | 227 | L_G.backward() 228 | optim_g.step() 229 | 230 | # STDOUT logging 231 | if steps % h.stdout_interval == 0: 232 | with torch.no_grad(): 233 | A_error = amplitude_loss(logamp, logamp_g).item() 234 | IP_error, GD_error, PTD_error = phase_loss( 235 | pha, pha_g, h.n_fft, pha.size()[-1] 236 | ) 237 | IP_error = IP_error.item() 238 | GD_error = GD_error.item() 239 | PTD_error = PTD_error.item() 240 | C_error = STFT_consistency_loss( 241 | rea_g, rea_g_final, imag_g, imag_g_final 242 | ).item() 243 | R_error = F.l1_loss(rea, rea_g).item() 244 | I_error = F.l1_loss(imag, imag_g).item() 245 | Mel_error = F.l1_loss(x, y_g_mel).item() 246 | 247 | print( 248 | "Steps : {:d}, Gen Loss Total : {:4.3f}, Amplitude Loss : {:4.3f}, Instantaneous Phase Loss : {:4.3f}, Group Delay Loss : {:4.3f}, Phase Time Difference Loss : {:4.3f}, STFT Consistency Loss : {:4.3f}, Real Part Loss : {:4.3f}, Imaginary Part Loss : {:4.3f}, Mel Spectrogram Loss : {:4.3f}, s/b : {:4.3f}".format( 249 | steps, 250 | L_G, 251 | A_error, 252 | IP_error, 253 | GD_error, 254 | PTD_error, 255 | C_error, 256 | R_error, 257 | I_error, 258 | Mel_error, 259 | time.time() - start_b, 260 | ) 261 | ) 262 | 263 | # checkpointing 264 | if steps % h.checkpoint_interval == 0 and steps != 0: 265 | checkpoint_path = "{}/g_{:08d}".format(h.checkpoint_path, steps) 266 | save_checkpoint(checkpoint_path, {"generator": generator.state_dict()}) 267 | checkpoint_path = "{}/do_{:08d}".format(h.checkpoint_path, steps) 268 | save_checkpoint( 269 | checkpoint_path, 270 | { 271 | "mpd": mpd.state_dict(), 272 | "mrd": mrd.state_dict(), 273 | "optim_g": optim_g.state_dict(), 274 | "optim_d": optim_d.state_dict(), 275 | "steps": steps, 276 | "epoch": epoch, 277 | }, 278 | ) 279 | 280 | # Tensorboard summary logging 281 | if steps % h.summary_interval == 0: 282 | sw.add_scalar("Training/Generator_Total_Loss", L_G, steps) 283 | 
sw.add_scalar("Training/Mel_Spectrogram_Loss", Mel_error, steps) 284 | 285 | # Validation 286 | if steps % h.validation_interval == 0: # and steps != 0: 287 | generator.eval() 288 | torch.cuda.empty_cache() 289 | val_A_err_tot = 0 290 | val_IP_err_tot = 0 291 | val_GD_err_tot = 0 292 | val_PTD_err_tot = 0 293 | val_C_err_tot = 0 294 | val_R_err_tot = 0 295 | val_I_err_tot = 0 296 | val_Mel_err_tot = 0 297 | with torch.no_grad(): 298 | for j, batch in enumerate(validation_loader): 299 | x, logamp, pha, rea, imag, y, meloss, inv_mel, pghid = map( 300 | lambda x: x.to(device, non_blocking=True), batch 301 | ) 302 | logamp_g, pha_g, rea_g, imag_g, y_g = generator( 303 | x.to(device), pghi=pghid.to(device) 304 | ) 305 | y_g_mel = mel_spectrogram( 306 | y_g.squeeze(1), 307 | h.n_fft, 308 | h.num_mels, 309 | h.sampling_rate, 310 | h.hop_size, 311 | h.win_size, 312 | h.fmin, 313 | h.meloss, 314 | ) 315 | 316 | _, _, rea_g_final, imag_g_final = amp_pha_specturm( 317 | y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size 318 | ) 319 | val_A_err_tot += amplitude_loss(logamp, logamp_g).item() 320 | val_IP_err, val_GD_err, val_PTD_err = phase_loss( 321 | pha, pha_g, h.n_fft, pha.size()[-1] 322 | ) 323 | val_IP_err_tot += val_IP_err.item() 324 | val_GD_err_tot += val_GD_err.item() 325 | val_PTD_err_tot += val_PTD_err.item() 326 | val_C_err_tot += STFT_consistency_loss( 327 | rea_g, rea_g_final, imag_g, imag_g_final 328 | ).item() 329 | val_R_err_tot += F.l1_loss(rea, rea_g).item() 330 | val_I_err_tot += F.l1_loss(imag, imag_g).item() 331 | val_Mel_err_tot += F.l1_loss(meloss, y_g_mel).item() 332 | 333 | if j <= 4: 334 | if steps == 0: 335 | sw.add_audio( 336 | "gt/y_{}".format(j), y[0], steps, h.sampling_rate 337 | ) 338 | sw.add_figure( 339 | "gt/y_spec_{}".format(j), 340 | plot_spectrogram(x[0].cpu()), 341 | steps, 342 | ) 343 | 344 | sw.add_audio( 345 | "generated/y_g_{}".format(j), 346 | y_g[0], 347 | steps, 348 | h.sampling_rate, 349 | ) 350 | y_g_spec = mel_spectrogram( 351 | y_g.squeeze(1), 352 | h.n_fft, 353 | h.num_mels, 354 | h.sampling_rate, 355 | h.hop_size, 356 | h.win_size, 357 | h.fmin, 358 | h.fmax, 359 | ) 360 | sw.add_figure( 361 | "generated/y_g_spec_{}".format(j), 362 | plot_spectrogram(y_g_spec.squeeze(0).cpu().numpy()), 363 | steps, 364 | ) 365 | 366 | val_A_err = val_A_err_tot / (j + 1) 367 | val_IP_err = val_IP_err_tot / (j + 1) 368 | val_GD_err = val_GD_err_tot / (j + 1) 369 | val_PTD_err = val_PTD_err_tot / (j + 1) 370 | val_C_err = val_C_err_tot / (j + 1) 371 | val_R_err = val_R_err_tot / (j + 1) 372 | val_I_err = val_I_err_tot / (j + 1) 373 | val_Mel_err = val_Mel_err_tot / (j + 1) 374 | sw.add_scalar("Validation/Amplitude_Loss", val_A_err, steps) 375 | sw.add_scalar( 376 | "Validation/Instantaneous_Phase_Loss", val_IP_err, steps 377 | ) 378 | sw.add_scalar("Validation/Group_Delay_Loss", val_GD_err, steps) 379 | sw.add_scalar( 380 | "Validation/Phase_Time_Difference_Loss", val_PTD_err, steps 381 | ) 382 | sw.add_scalar("Validation/STFT_Consistency_Loss", val_C_err, steps) 383 | sw.add_scalar("Validation/Real_Part_Loss", val_R_err, steps) 384 | sw.add_scalar("Validation/Imaginary_Part_Loss", val_I_err, steps) 385 | sw.add_scalar("Validation/Mel_Spectrogram_loss", val_Mel_err, steps) 386 | 387 | generator.train() 388 | 389 | steps += 1 390 | 391 | scheduler_g.step() 392 | scheduler_d.step() 393 | 394 | print( 395 | "Time taken for epoch {} is {} sec\n".format( 396 | epoch + 1, int(time.time() - start) 397 | ) 398 | ) 399 | 400 | 401 | def main(): 402 | print("Initializing Training 
Process..") 403 | 404 | config_file = "config2_pghi.json" 405 | 406 | with open(config_file) as f: 407 | data = f.read() 408 | 409 | json_config = json.loads(data) 410 | h = AttrDict(json_config) 411 | build_env(config_file, "config2.json", h.checkpoint_path) 412 | 413 | torch.manual_seed(h.seed) 414 | if torch.cuda.is_available(): 415 | torch.cuda.manual_seed(h.seed) 416 | else: 417 | pass 418 | 419 | train(h) 420 | 421 | 422 | if __name__ == "__main__": 423 | main() 424 | -------------------------------------------------------------------------------- /train_pghi.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.simplefilter(action="ignore", category=FutureWarning) 4 | import itertools 5 | import os 6 | import time 7 | import argparse 8 | import json 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | import torch.multiprocessing as mp 14 | from torch.distributed import init_process_group 15 | from torch.nn.parallel import DistributedDataParallel 16 | from dataset import Dataset, mel_spectrogram, amp_pha_specturm, get_dataset_filelist 17 | from models_pghi import ( 18 | Generator, 19 | MultiPeriodDiscriminator, 20 | feature_loss, 21 | generator_loss, 22 | discriminator_loss, 23 | amplitude_loss, 24 | phase_loss, 25 | STFT_consistency_loss, 26 | MultiResolutionDiscriminator, 27 | ) 28 | from utils import ( 29 | AttrDict, 30 | build_env, 31 | plot_spectrogram, 32 | scan_checkpoint, 33 | load_checkpoint, 34 | save_checkpoint, 35 | ) 36 | 37 | torch.backends.cudnn.benchmark = True 38 | 39 | 40 | def train(h): 41 | torch.cuda.manual_seed(h.seed) 42 | device = torch.device("cuda:{:d}".format(0)) 43 | 44 | generator = Generator(h).to(device) 45 | mpd = MultiPeriodDiscriminator().to(device) 46 | mrd = MultiResolutionDiscriminator().to(device) 47 | 48 | print(generator) 49 | os.makedirs(h.checkpoint_path, exist_ok=True) 50 | print("checkpoints directory : ", h.checkpoint_path) 51 | 52 | if os.path.isdir(h.checkpoint_path): 53 | cp_g = scan_checkpoint(h.checkpoint_path, "g_") 54 | cp_do = scan_checkpoint(h.checkpoint_path, "do_") 55 | 56 | steps = 0 57 | if cp_g is None or cp_do is None: 58 | state_dict_do = None 59 | last_epoch = -1 60 | else: 61 | state_dict_g = load_checkpoint(cp_g, device) 62 | state_dict_do = load_checkpoint(cp_do, device) 63 | generator.load_state_dict(state_dict_g["generator"]) 64 | mpd.load_state_dict(state_dict_do["mpd"]) 65 | mrd.load_state_dict(state_dict_do["mrd"]) 66 | steps = state_dict_do["steps"] + 1 67 | last_epoch = state_dict_do["epoch"] 68 | 69 | optim_g = torch.optim.AdamW( 70 | generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2] 71 | ) 72 | optim_d = torch.optim.AdamW( 73 | itertools.chain(mrd.parameters(), mpd.parameters()), 74 | h.learning_rate, 75 | betas=[h.adam_b1, h.adam_b2], 76 | ) 77 | 78 | if state_dict_do is not None: 79 | optim_g.load_state_dict(state_dict_do["optim_g"]) 80 | optim_d.load_state_dict(state_dict_do["optim_d"]) 81 | 82 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR( 83 | optim_g, gamma=h.lr_decay, last_epoch=last_epoch 84 | ) 85 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR( 86 | optim_d, gamma=h.lr_decay, last_epoch=last_epoch 87 | ) 88 | 89 | training_filelist, validation_filelist = get_dataset_filelist( 90 | h.input_training_wav_list, h.input_validation_wav_list 91 | ) 92 | 93 | trainset = Dataset( 94 
| training_filelist, 95 | h.segment_size, 96 | h.n_fft, 97 | h.num_mels, 98 | h.hop_size, 99 | h.win_size, 100 | h.sampling_rate, 101 | h.fmin, 102 | h.fmax, 103 | h.meloss, 104 | n_cache_reuse=0, 105 | shuffle=True, 106 | device=device, 107 | inv_mel=True, 108 | use_pghi=True, 109 | ) 110 | 111 | train_loader = DataLoader( 112 | trainset, 113 | num_workers=h.num_workers, 114 | shuffle=True, 115 | sampler=None, 116 | batch_size=h.batch_size, 117 | pin_memory=True, 118 | drop_last=True, 119 | ) 120 | 121 | validset = Dataset( 122 | validation_filelist, 123 | h.segment_size, 124 | h.n_fft, 125 | h.num_mels, 126 | h.hop_size, 127 | h.win_size, 128 | h.sampling_rate, 129 | h.fmin, 130 | h.fmax, 131 | h.meloss, 132 | False, 133 | False, 134 | n_cache_reuse=0, 135 | device=device, 136 | inv_mel=True, 137 | use_pghi=True, 138 | ) 139 | validation_loader = DataLoader( 140 | validset, 141 | num_workers=1, 142 | shuffle=False, 143 | sampler=None, 144 | batch_size=1, 145 | pin_memory=True, 146 | drop_last=True, 147 | ) 148 | 149 | sw = SummaryWriter(os.path.join(h.checkpoint_path, "logs")) 150 | 151 | generator.train() 152 | mpd.train() 153 | mrd.train() 154 | 155 | for epoch in range(max(0, last_epoch), h.training_epochs): 156 | start = time.time() 157 | print("Epoch: {}".format(epoch + 1)) 158 | 159 | for i, batch in enumerate(train_loader): 160 | start_b = time.time() 161 | x, logamp, pha, rea, imag, y, meloss, inv_mel, pghid = map( 162 | lambda x: x.to(device, non_blocking=True), batch 163 | ) 164 | y = y.unsqueeze(1) 165 | logamp_g, pha_g, rea_g, imag_g, y_g = generator( 166 | x, inv_mel=inv_mel, pghi=pghid 167 | ) 168 | y_g_mel = mel_spectrogram( 169 | y_g.squeeze(1), 170 | h.n_fft, 171 | h.num_mels, 172 | h.sampling_rate, 173 | h.hop_size, 174 | h.win_size, 175 | h.fmin, 176 | h.meloss, 177 | ) 178 | 179 | optim_d.zero_grad() 180 | 181 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g.detach()) 182 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( 183 | y_df_hat_r, y_df_hat_g 184 | ) 185 | 186 | y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g.detach()) 187 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( 188 | y_ds_hat_r, y_ds_hat_g 189 | ) 190 | 191 | L_D = loss_disc_s * 0.1 + loss_disc_f 192 | 193 | L_D.backward() 194 | optim_d.step() 195 | 196 | # Generator 197 | optim_g.zero_grad() 198 | 199 | # Losses defined on log amplitude spectra 200 | L_A = amplitude_loss(logamp, logamp_g) 201 | 202 | L_IP, L_GD, L_PTD = phase_loss(pha, pha_g, h.n_fft, pha.size()[-1]) 203 | # Losses defined on phase spectra 204 | L_P = L_IP + L_GD + L_PTD 205 | 206 | _, _, rea_g_final, imag_g_final = amp_pha_specturm( 207 | y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size 208 | ) 209 | L_C = STFT_consistency_loss(rea_g, rea_g_final, imag_g, imag_g_final) 210 | L_R = F.l1_loss(rea, rea_g) 211 | L_I = F.l1_loss(imag, imag_g) 212 | # Losses defined on reconstructed STFT spectra 213 | L_S = L_C + 2.25 * (L_R + L_I) 214 | 215 | y_df_r, y_df_g, fmap_f_r, fmap_f_g = mpd(y, y_g) 216 | y_ds_r, y_ds_g, fmap_s_r, fmap_s_g = mrd(y, y_g) 217 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 218 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) 219 | loss_gen_f, losses_gen_f = generator_loss(y_df_g) 220 | loss_gen_s, losses_gen_s = generator_loss(y_ds_g) 221 | L_GAN_G = loss_gen_s * 0.1 + loss_gen_f 222 | L_FM = loss_fm_s * 0.1 + loss_fm_f 223 | L_Mel = F.l1_loss(meloss, y_g_mel) 224 | # Losses defined on final waveforms 225 | L_W = L_GAN_G + L_FM + 45 * L_Mel 226 | 227 | L_G = 45 * L_A + 100 * L_P + 20 * L_S 
+ L_W 228 | 229 | L_G.backward() 230 | optim_g.step() 231 | 232 | # STDOUT logging 233 | if steps % h.stdout_interval == 0: 234 | with torch.no_grad(): 235 | A_error = amplitude_loss(logamp, logamp_g).item() 236 | IP_error, GD_error, PTD_error = phase_loss( 237 | pha, pha_g, h.n_fft, pha.size()[-1] 238 | ) 239 | IP_error = IP_error.item() 240 | GD_error = GD_error.item() 241 | PTD_error = PTD_error.item() 242 | C_error = STFT_consistency_loss( 243 | rea_g, rea_g_final, imag_g, imag_g_final 244 | ).item() 245 | R_error = F.l1_loss(rea, rea_g).item() 246 | I_error = F.l1_loss(imag, imag_g).item() 247 | Mel_error = F.l1_loss(x, y_g_mel).item() 248 | 249 | print( 250 | "Steps : {:d}, Gen Loss Total : {:4.3f}, Amplitude Loss : {:4.3f}, Instantaneous Phase Loss : {:4.3f}, Group Delay Loss : {:4.3f}, Phase Time Difference Loss : {:4.3f}, STFT Consistency Loss : {:4.3f}, Real Part Loss : {:4.3f}, Imaginary Part Loss : {:4.3f}, Mel Spectrogram Loss : {:4.3f}, s/b : {:4.3f}".format( 251 | steps, 252 | L_G, 253 | A_error, 254 | IP_error, 255 | GD_error, 256 | PTD_error, 257 | C_error, 258 | R_error, 259 | I_error, 260 | Mel_error, 261 | time.time() - start_b, 262 | ) 263 | ) 264 | 265 | # checkpointing 266 | if steps % h.checkpoint_interval == 0 and steps != 0: 267 | checkpoint_path = "{}/g_{:08d}".format(h.checkpoint_path, steps) 268 | save_checkpoint(checkpoint_path, {"generator": generator.state_dict()}) 269 | checkpoint_path = "{}/do_{:08d}".format(h.checkpoint_path, steps) 270 | save_checkpoint( 271 | checkpoint_path, 272 | { 273 | "mpd": mpd.state_dict(), 274 | "mrd": mrd.state_dict(), 275 | "optim_g": optim_g.state_dict(), 276 | "optim_d": optim_d.state_dict(), 277 | "steps": steps, 278 | "epoch": epoch, 279 | }, 280 | ) 281 | 282 | # Tensorboard summary logging 283 | if steps % h.summary_interval == 0: 284 | sw.add_scalar("Training/Generator_Total_Loss", L_G, steps) 285 | sw.add_scalar("Training/Mel_Spectrogram_Loss", Mel_error, steps) 286 | 287 | # Validation 288 | if steps % h.validation_interval == 0: # and steps != 0: 289 | generator.eval() 290 | torch.cuda.empty_cache() 291 | val_A_err_tot = 0 292 | val_IP_err_tot = 0 293 | val_GD_err_tot = 0 294 | val_PTD_err_tot = 0 295 | val_C_err_tot = 0 296 | val_R_err_tot = 0 297 | val_I_err_tot = 0 298 | val_Mel_err_tot = 0 299 | with torch.no_grad(): 300 | for j, batch in enumerate(validation_loader): 301 | x, logamp, pha, rea, imag, y, meloss, inv_mel, pghid = map( 302 | lambda x: x.to(device, non_blocking=True), batch 303 | ) 304 | logamp_g, pha_g, rea_g, imag_g, y_g = generator( 305 | x, inv_mel=inv_mel, pghi=pghid 306 | ) 307 | x = x.cpu() 308 | y_g_mel = mel_spectrogram( 309 | y_g.squeeze(1), 310 | h.n_fft, 311 | h.num_mels, 312 | h.sampling_rate, 313 | h.hop_size, 314 | h.win_size, 315 | h.fmin, 316 | h.meloss, 317 | ) 318 | 319 | _, _, rea_g_final, imag_g_final = amp_pha_specturm( 320 | y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size 321 | ) 322 | val_A_err_tot += amplitude_loss(logamp, logamp_g).item() 323 | val_IP_err, val_GD_err, val_PTD_err = phase_loss( 324 | pha, pha_g, h.n_fft, pha.size()[-1] 325 | ) 326 | val_IP_err_tot += val_IP_err.item() 327 | val_GD_err_tot += val_GD_err.item() 328 | val_PTD_err_tot += val_PTD_err.item() 329 | val_C_err_tot += STFT_consistency_loss( 330 | rea_g, rea_g_final, imag_g, imag_g_final 331 | ).item() 332 | val_R_err_tot += F.l1_loss(rea, rea_g).item() 333 | val_I_err_tot += F.l1_loss(imag, imag_g).item() 334 | val_Mel_err_tot += F.l1_loss(meloss, y_g_mel).item() 335 | 336 | if j <= 4: 337 | if steps 
== 0: 338 | sw.add_audio( 339 | "gt/y_{}".format(j), y[0], steps, h.sampling_rate 340 | ) 341 | sw.add_figure( 342 | "gt/y_spec_{}".format(j), 343 | plot_spectrogram(x[0]), 344 | steps, 345 | ) 346 | 347 | sw.add_audio( 348 | "generated/y_g_{}".format(j), 349 | y_g[0], 350 | steps, 351 | h.sampling_rate, 352 | ) 353 | y_g_spec = mel_spectrogram( 354 | y_g.squeeze(1), 355 | h.n_fft, 356 | h.num_mels, 357 | h.sampling_rate, 358 | h.hop_size, 359 | h.win_size, 360 | h.fmin, 361 | h.fmax, 362 | ) 363 | sw.add_figure( 364 | "generated/y_g_spec_{}".format(j), 365 | plot_spectrogram(y_g_spec.squeeze(0).cpu().numpy()), 366 | steps, 367 | ) 368 | 369 | val_A_err = val_A_err_tot / (j + 1) 370 | val_IP_err = val_IP_err_tot / (j + 1) 371 | val_GD_err = val_GD_err_tot / (j + 1) 372 | val_PTD_err = val_PTD_err_tot / (j + 1) 373 | val_C_err = val_C_err_tot / (j + 1) 374 | val_R_err = val_R_err_tot / (j + 1) 375 | val_I_err = val_I_err_tot / (j + 1) 376 | val_Mel_err = val_Mel_err_tot / (j + 1) 377 | sw.add_scalar("Validation/Amplitude_Loss", val_A_err, steps) 378 | sw.add_scalar( 379 | "Validation/Instantaneous_Phase_Loss", val_IP_err, steps 380 | ) 381 | sw.add_scalar("Validation/Group_Delay_Loss", val_GD_err, steps) 382 | sw.add_scalar( 383 | "Validation/Phase_Time_Difference_Loss", val_PTD_err, steps 384 | ) 385 | sw.add_scalar("Validation/STFT_Consistency_Loss", val_C_err, steps) 386 | sw.add_scalar("Validation/Real_Part_Loss", val_R_err, steps) 387 | sw.add_scalar("Validation/Imaginary_Part_Loss", val_I_err, steps) 388 | sw.add_scalar("Validation/Mel_Spectrogram_loss", val_Mel_err, steps) 389 | 390 | generator.train() 391 | 392 | steps += 1 393 | 394 | scheduler_g.step() 395 | scheduler_d.step() 396 | 397 | print( 398 | "Time taken for epoch {} is {} sec\n".format( 399 | epoch + 1, int(time.time() - start) 400 | ) 401 | ) 402 | 403 | 404 | def main(): 405 | print("Initializing Training Process..") 406 | 407 | config_file = "config_pghi.json" 408 | 409 | with open(config_file) as f: 410 | data = f.read() 411 | 412 | json_config = json.loads(data) 413 | h = AttrDict(json_config) 414 | build_env(config_file, "config_pghi.json", h.checkpoint_path) 415 | 416 | torch.manual_seed(h.seed) 417 | if torch.cuda.is_available(): 418 | torch.cuda.manual_seed(h.seed) 419 | else: 420 | pass 421 | 422 | train(h) 423 | 424 | 425 | if __name__ == "__main__": 426 | main() 427 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, spectral_norm 6 | from utils import init_weights, get_padding 7 | from dataset import inverse_mel 8 | import numpy as np 9 | 10 | LRELU_SLOPE = 0.1 11 | 12 | 13 | class GRN(nn.Module): 14 | """GRN (Global Response Normalization) layer""" 15 | 16 | def __init__(self, dim): 17 | super().__init__() 18 | self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) 19 | self.beta = nn.Parameter(torch.zeros(1, 1, dim)) 20 | 21 | def forward(self, x): 22 | Gx = torch.norm(x, p=2, dim=1, keepdim=True) 23 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 24 | return self.gamma * (x * Nx) + self.beta + x 25 | 26 | 27 | class ConvNeXtBlock(nn.Module): 28 | def __init__( 29 | self, 30 | dim: int, 31 | intermediate_dim: int, 32 | layer_scale_init_value=None, 33 | adanorm_num_embeddings=None, 34 | ): 35 | 
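# ConvNeXt-V2-style 1-D block: depthwise 7-tap conv -> LayerNorm ->
# pointwise expansion (dim -> intermediate_dim) -> GELU -> GRN ->
# pointwise projection back to `dim`, all wrapped in a residual connection.
# Note that `layer_scale_init_value` is accepted but unused in this variant,
# and the adanorm branch in forward() would fail with a plain nn.LayerNorm;
# the Generator below always passes adanorm_num_embeddings=None, so that
# branch is never taken.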
super().__init__() 36 | self.dwconv = nn.Conv1d( 37 | dim, dim, kernel_size=7, padding=3, groups=dim 38 | ) # depthwise conv 39 | self.adanorm = adanorm_num_embeddings is not None 40 | 41 | self.norm = nn.LayerNorm(dim, eps=1e-6) 42 | self.pwconv1 = nn.Linear( 43 | dim, intermediate_dim 44 | ) # pointwise/1x1 convs, implemented with linear layers 45 | self.act = nn.GELU() 46 | self.grn = GRN(intermediate_dim) 47 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 48 | 49 | def forward(self, x, cond_embedding_id=None): 50 | residual = x 51 | x = self.dwconv(x) 52 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 53 | if self.adanorm: 54 | assert cond_embedding_id is not None 55 | x = self.norm(x, cond_embedding_id) 56 | else: 57 | x = self.norm(x) 58 | x = self.pwconv1(x) 59 | x = self.act(x) 60 | x = self.grn(x) 61 | x = self.pwconv2(x) 62 | 63 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 64 | 65 | x = residual + x 66 | return x 67 | 68 | 69 | class Generator(torch.nn.Module): 70 | def __init__(self, h): 71 | super(Generator, self).__init__() 72 | self.h = h 73 | self.ASP_num_kernels = len(h.ASP_resblock_kernel_sizes) 74 | self.PSP_num_kernels = len(h.PSP_resblock_kernel_sizes) 75 | 76 | self.ASP_input_conv = Conv1d( 77 | h.num_mels, 78 | h.ASP_channel, 79 | h.ASP_input_conv_kernel_size, 80 | 1, 81 | padding=get_padding(h.ASP_input_conv_kernel_size, 1), 82 | ) 83 | self.PSP_input_conv = Conv1d( 84 | h.num_mels, 85 | h.PSP_channel, 86 | h.PSP_input_conv_kernel_size, 87 | 1, 88 | padding=get_padding(h.PSP_input_conv_kernel_size, 1), 89 | ) 90 | 91 | self.ASP_output_conv = Conv1d( 92 | h.ASP_channel, 93 | h.n_fft // 2 + 1, 94 | h.ASP_output_conv_kernel_size, 95 | 1, 96 | padding=get_padding(h.ASP_output_conv_kernel_size, 1), 97 | ) 98 | self.PSP_output_R_conv = Conv1d( 99 | 512, 100 | h.n_fft // 2 + 1, 101 | h.PSP_output_R_conv_kernel_size, 102 | 1, 103 | padding=get_padding(h.PSP_output_R_conv_kernel_size, 1), 104 | ) 105 | self.PSP_output_I_conv = Conv1d( 106 | 512, 107 | h.n_fft // 2 + 1, 108 | h.PSP_output_I_conv_kernel_size, 109 | 1, 110 | padding=get_padding(h.PSP_output_I_conv_kernel_size, 1), 111 | ) 112 | 113 | self.dim = 512 114 | self.num_layers = 8 115 | self.adanorm_num_embeddings = None 116 | self.intermediate_dim = 1536 117 | self.norm = nn.LayerNorm(self.dim, eps=1e-6) 118 | self.norm2 = nn.LayerNorm(self.dim, eps=1e-6) 119 | layer_scale_init_value = 1 / self.num_layers 120 | self.convnext = nn.ModuleList( 121 | [ 122 | ConvNeXtBlock( 123 | dim=self.dim, 124 | intermediate_dim=self.intermediate_dim, 125 | layer_scale_init_value=layer_scale_init_value, 126 | adanorm_num_embeddings=self.adanorm_num_embeddings, 127 | ) 128 | for _ in range(self.num_layers) 129 | ] 130 | ) 131 | self.convnext2 = nn.ModuleList( 132 | [ 133 | ConvNeXtBlock( 134 | dim=self.dim, 135 | intermediate_dim=self.intermediate_dim, 136 | layer_scale_init_value=layer_scale_init_value, 137 | adanorm_num_embeddings=self.adanorm_num_embeddings, 138 | ) 139 | for _ in range(self.num_layers) 140 | ] 141 | ) 142 | self.final_layer_norm = nn.LayerNorm(self.dim, eps=1e-6) 143 | self.final_layer_norm2 = nn.LayerNorm(self.dim, eps=1e-6) 144 | self.apply(self._init_weights) 145 | 146 | def _init_weights(self, m): 147 | if isinstance(m, (nn.Conv1d, nn.Linear)): 148 | nn.init.trunc_normal_(m.weight, std=0.02) 149 | nn.init.constant_(m.bias, 0) 150 | 151 | def forward(self, mel): 152 | logamp = self.ASP_input_conv(mel) 153 | logamp = self.norm2(logamp.transpose(1, 2)) 154 | logamp = logamp.transpose(1, 2) 155 | for 
conv_block in self.convnext2: 156 | logamp = conv_block(logamp, cond_embedding_id=None) 157 | logamp = self.final_layer_norm2(logamp.transpose(1, 2)) 158 | logamp = logamp.transpose(1, 2) 159 | logamp = self.ASP_output_conv(logamp) 160 | 161 | pha = self.PSP_input_conv(mel) 162 | pha = self.norm(pha.transpose(1, 2)) 163 | pha = pha.transpose(1, 2) 164 | for conv_block in self.convnext: 165 | pha = conv_block(pha, cond_embedding_id=None) 166 | pha = self.final_layer_norm(pha.transpose(1, 2)) 167 | pha = pha.transpose(1, 2) 168 | R = self.PSP_output_R_conv(pha) 169 | I = self.PSP_output_I_conv(pha) 170 | 171 | pha = torch.atan2(I, R) 172 | 173 | rea = torch.exp(logamp) * torch.cos(pha) 174 | imag = torch.exp(logamp) * torch.sin(pha) 175 | 176 | spec = torch.complex(rea, imag) 177 | # spec = torch.cat((rea.unsqueeze(-1), imag.unsqueeze(-1)), -1) 178 | 179 | audio = torch.istft( 180 | spec, 181 | self.h.n_fft, 182 | hop_length=self.h.hop_size, 183 | win_length=self.h.win_size, 184 | window=torch.hann_window(self.h.win_size).to(mel.device), 185 | center=True, 186 | ) 187 | 188 | return logamp, pha, rea, imag, audio.unsqueeze(1) 189 | 190 | 191 | class DiscriminatorP(torch.nn.Module): 192 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 193 | super(DiscriminatorP, self).__init__() 194 | self.period = period 195 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 196 | self.convs = nn.ModuleList( 197 | [ 198 | norm_f( 199 | Conv2d( 200 | 1, 201 | 32, 202 | (kernel_size, 1), 203 | (stride, 1), 204 | padding=(get_padding(5, 1), 0), 205 | ) 206 | ), 207 | norm_f( 208 | Conv2d( 209 | 32, 210 | 128, 211 | (kernel_size, 1), 212 | (stride, 1), 213 | padding=(get_padding(5, 1), 0), 214 | ) 215 | ), 216 | norm_f( 217 | Conv2d( 218 | 128, 219 | 512, 220 | (kernel_size, 1), 221 | (stride, 1), 222 | padding=(get_padding(5, 1), 0), 223 | ) 224 | ), 225 | norm_f( 226 | Conv2d( 227 | 512, 228 | 1024, 229 | (kernel_size, 1), 230 | (stride, 1), 231 | padding=(get_padding(5, 1), 0), 232 | ) 233 | ), 234 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 235 | ] 236 | ) 237 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 238 | 239 | def forward(self, x): 240 | fmap = [] 241 | 242 | # 1d to 2d 243 | b, c, t = x.shape 244 | if t % self.period != 0: # pad first 245 | n_pad = self.period - (t % self.period) 246 | x = F.pad(x, (0, n_pad), "reflect") 247 | t = t + n_pad 248 | x = x.view(b, c, t // self.period, self.period) 249 | 250 | for l in self.convs: 251 | x = l(x) 252 | x = F.leaky_relu(x, LRELU_SLOPE) 253 | fmap.append(x) 254 | x = self.conv_post(x) 255 | fmap.append(x) 256 | x = torch.flatten(x, 1, -1) 257 | 258 | return x, fmap 259 | 260 | 261 | class MultiPeriodDiscriminator(torch.nn.Module): 262 | def __init__(self): 263 | super(MultiPeriodDiscriminator, self).__init__() 264 | self.discriminators = nn.ModuleList( 265 | [ 266 | DiscriminatorP(2), 267 | DiscriminatorP(3), 268 | DiscriminatorP(5), 269 | DiscriminatorP(7), 270 | DiscriminatorP(11), 271 | ] 272 | ) 273 | 274 | def forward(self, y, y_hat): 275 | y_d_rs = [] 276 | y_d_gs = [] 277 | fmap_rs = [] 278 | fmap_gs = [] 279 | for i, d in enumerate(self.discriminators): 280 | y_d_r, fmap_r = d(y) 281 | y_d_g, fmap_g = d(y_hat) 282 | y_d_rs.append(y_d_r) 283 | fmap_rs.append(fmap_r) 284 | y_d_gs.append(y_d_g) 285 | fmap_gs.append(fmap_g) 286 | 287 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 288 | 289 | 290 | def phase_loss(phase_r, phase_g, n_fft, frames): 291 | GD_matrix = 
( 292 | torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=1) 293 | - torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=2) 294 | - torch.eye(n_fft // 2 + 1) 295 | ) 296 | GD_matrix = GD_matrix.to(phase_g.device) 297 | 298 | GD_r = torch.matmul(phase_r.permute(0, 2, 1), GD_matrix) 299 | GD_g = torch.matmul(phase_g.permute(0, 2, 1), GD_matrix) 300 | 301 | PTD_matrix = ( 302 | torch.triu(torch.ones(frames, frames), diagonal=1) 303 | - torch.triu(torch.ones(frames, frames), diagonal=2) 304 | - torch.eye(frames) 305 | ) 306 | PTD_matrix = PTD_matrix.to(phase_g.device) 307 | 308 | PTD_r = torch.matmul(phase_r, PTD_matrix) 309 | PTD_g = torch.matmul(phase_g, PTD_matrix) 310 | 311 | IP_loss = torch.mean(anti_wrapping_function(phase_r - phase_g)) 312 | GD_loss = torch.mean(anti_wrapping_function(GD_r - GD_g)) 313 | PTD_loss = torch.mean(anti_wrapping_function(PTD_r - PTD_g)) 314 | 315 | return IP_loss, GD_loss, PTD_loss 316 | 317 | 318 | class MultiResolutionDiscriminator(nn.Module): 319 | def __init__( 320 | self, 321 | resolutions=((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)), 322 | num_embeddings: int = None, 323 | ): 324 | super().__init__() 325 | self.discriminators = nn.ModuleList( 326 | [ 327 | DiscriminatorR(resolution=r, num_embeddings=num_embeddings) 328 | for r in resolutions 329 | ] 330 | ) 331 | 332 | def forward( 333 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None 334 | ): 335 | y_d_rs = [] 336 | y_d_gs = [] 337 | fmap_rs = [] 338 | fmap_gs = [] 339 | 340 | for d in self.discriminators: 341 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) 342 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) 343 | y_d_rs.append(y_d_r) 344 | fmap_rs.append(fmap_r) 345 | y_d_gs.append(y_d_g) 346 | fmap_gs.append(fmap_g) 347 | 348 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 349 | 350 | 351 | class DiscriminatorR(nn.Module): 352 | def __init__( 353 | self, 354 | resolution, 355 | channels: int = 64, 356 | in_channels: int = 1, 357 | num_embeddings: int = None, 358 | lrelu_slope: float = 0.1, 359 | ): 360 | super().__init__() 361 | self.resolution = resolution 362 | self.in_channels = in_channels 363 | self.lrelu_slope = lrelu_slope 364 | self.convs = nn.ModuleList( 365 | [ 366 | weight_norm( 367 | nn.Conv2d( 368 | in_channels, 369 | channels, 370 | kernel_size=(7, 5), 371 | stride=(2, 2), 372 | padding=(3, 2), 373 | ) 374 | ), 375 | weight_norm( 376 | nn.Conv2d( 377 | channels, 378 | channels, 379 | kernel_size=(5, 3), 380 | stride=(2, 1), 381 | padding=(2, 1), 382 | ) 383 | ), 384 | weight_norm( 385 | nn.Conv2d( 386 | channels, 387 | channels, 388 | kernel_size=(5, 3), 389 | stride=(2, 2), 390 | padding=(2, 1), 391 | ) 392 | ), 393 | weight_norm( 394 | nn.Conv2d( 395 | channels, channels, kernel_size=3, stride=(2, 1), padding=1 396 | ) 397 | ), 398 | weight_norm( 399 | nn.Conv2d( 400 | channels, channels, kernel_size=3, stride=(2, 2), padding=1 401 | ) 402 | ), 403 | ] 404 | ) 405 | if num_embeddings is not None: 406 | self.emb = torch.nn.Embedding( 407 | num_embeddings=num_embeddings, embedding_dim=channels 408 | ) 409 | torch.nn.init.zeros_(self.emb.weight) 410 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1))) 411 | 412 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): 413 | fmap = [] 414 | x = x.squeeze(1) 415 | 416 | x = self.spectrogram(x) 417 | x = x.unsqueeze(1) 418 | for l in self.convs: 419 | x = l(x) 420 | x = torch.nn.functional.leaky_relu(x, 
self.lrelu_slope) 421 | fmap.append(x) 422 | if cond_embedding_id is not None: 423 | emb = self.emb(cond_embedding_id) 424 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) 425 | else: 426 | h = 0 427 | x = self.conv_post(x) 428 | fmap.append(x) 429 | x += h 430 | x = torch.flatten(x, 1, -1) 431 | 432 | return x, fmap 433 | 434 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor: 435 | n_fft, hop_length, win_length = self.resolution 436 | magnitude_spectrogram = torch.stft( 437 | x, 438 | n_fft=n_fft, 439 | hop_length=hop_length, 440 | win_length=win_length, 441 | window=None, # interestingly rectangular window kind of works here 442 | center=True, 443 | return_complex=True, 444 | ).abs() 445 | 446 | return magnitude_spectrogram 447 | 448 | 449 | def anti_wrapping_function(x): 450 | return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi) 451 | 452 | 453 | def amplitude_loss(log_amplitude_r, log_amplitude_g): 454 | MSELoss = torch.nn.MSELoss() 455 | 456 | amplitude_loss = MSELoss(log_amplitude_r, log_amplitude_g) 457 | 458 | return amplitude_loss 459 | 460 | 461 | def feature_loss(fmap_r, fmap_g): 462 | loss = 0 463 | for dr, dg in zip(fmap_r, fmap_g): 464 | for rl, gl in zip(dr, dg): 465 | loss += torch.mean(torch.abs(rl - gl)) 466 | 467 | return loss 468 | 469 | 470 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 471 | loss = 0 472 | r_losses = [] 473 | g_losses = [] 474 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 475 | r_loss = torch.mean(torch.clamp(1 - dr, min=0)) 476 | g_loss = torch.mean(torch.clamp(1 + dg, min=0)) 477 | loss += r_loss + g_loss 478 | r_losses.append(r_loss.item()) 479 | g_losses.append(g_loss.item()) 480 | 481 | return loss, r_losses, g_losses 482 | 483 | 484 | def generator_loss(disc_outputs): 485 | loss = 0 486 | gen_losses = [] 487 | for dg in disc_outputs: 488 | l = torch.mean(torch.clamp(1 - dg, min=0)) 489 | gen_losses.append(l) 490 | loss += l 491 | 492 | return loss, gen_losses 493 | 494 | 495 | def STFT_consistency_loss(rea_r, rea_g, imag_r, imag_g): 496 | C_loss = torch.mean( 497 | torch.mean((rea_r - rea_g) ** 2 + (imag_r - imag_g) ** 2, (1, 2)) 498 | ) 499 | 500 | return C_loss 501 | -------------------------------------------------------------------------------- /models2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, spectral_norm 6 | from utils import init_weights, get_padding 7 | from dataset import inverse_mel 8 | import numpy as np 9 | 10 | LRELU_SLOPE = 0.1 11 | 12 | 13 | class GRN(nn.Module): 14 | """GRN (Global Response Normalization) layer""" 15 | 16 | def __init__(self, dim): 17 | super().__init__() 18 | self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) 19 | self.beta = nn.Parameter(torch.zeros(1, 1, dim)) 20 | 21 | def forward(self, x): 22 | Gx = torch.norm(x, p=2, dim=1, keepdim=True) 23 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 24 | return self.gamma * (x * Nx) + self.beta + x 25 | 26 | 27 | class ConvNeXtBlock(nn.Module): 28 | def __init__( 29 | self, 30 | dim: int, 31 | intermediate_dim: int, 32 | layer_scale_init_value=None, 33 | adanorm_num_embeddings=None, 34 | ): 35 | super().__init__() 36 | self.dwconv = nn.Conv1d( 37 | dim, dim, kernel_size=7, padding=3, groups=dim 38 | ) # depthwise conv 39 | self.adanorm = adanorm_num_embeddings is not 
None 40 | 41 | self.norm = nn.LayerNorm(dim, eps=1e-6) 42 | self.pwconv1 = nn.Linear( 43 | dim, intermediate_dim 44 | ) # pointwise/1x1 convs, implemented with linear layers 45 | self.act = nn.GELU() 46 | self.grn = GRN(intermediate_dim) 47 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 48 | 49 | def forward(self, x, cond_embedding_id=None): 50 | residual = x 51 | x = self.dwconv(x) 52 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 53 | if self.adanorm: 54 | assert cond_embedding_id is not None 55 | x = self.norm(x, cond_embedding_id) 56 | else: 57 | x = self.norm(x) 58 | x = self.pwconv1(x) 59 | x = self.act(x) 60 | x = self.grn(x) 61 | x = self.pwconv2(x) 62 | 63 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 64 | 65 | x = residual + x 66 | return x 67 | 68 | 69 | class Generator(torch.nn.Module): 70 | def __init__(self, h): 71 | super(Generator, self).__init__() 72 | self.h = h 73 | self.ASP_num_kernels = len(h.ASP_resblock_kernel_sizes) 74 | self.PSP_num_kernels = len(h.PSP_resblock_kernel_sizes) 75 | 76 | # self.ASP_input_conv = Conv1d( 77 | # h.num_mels, 78 | # h.ASP_channel, 79 | # h.ASP_input_conv_kernel_size, 80 | # 1, 81 | # padding=get_padding(h.ASP_input_conv_kernel_size, 1), 82 | # ) 83 | self.PSP_input_conv = Conv1d( 84 | h.num_mels, 85 | h.PSP_channel, 86 | h.PSP_input_conv_kernel_size, 87 | 1, 88 | padding=get_padding(h.PSP_input_conv_kernel_size, 1), 89 | ) 90 | 91 | # self.ASP_output_conv = Conv1d( 92 | # h.ASP_channel, 93 | # h.n_fft // 2 + 1, 94 | # h.ASP_output_conv_kernel_size, 95 | # 1, 96 | # padding=get_padding(h.ASP_output_conv_kernel_size, 1), 97 | # ) 98 | self.PSP_output_R_conv = Conv1d( 99 | 512, 100 | h.n_fft // 2 + 1, 101 | h.PSP_output_R_conv_kernel_size, 102 | 1, 103 | padding=get_padding(h.PSP_output_R_conv_kernel_size, 1), 104 | ) 105 | self.PSP_output_I_conv = Conv1d( 106 | 512, 107 | h.n_fft // 2 + 1, 108 | h.PSP_output_I_conv_kernel_size, 109 | 1, 110 | padding=get_padding(h.PSP_output_I_conv_kernel_size, 1), 111 | ) 112 | 113 | self.dim = 512 114 | self.num_layers = 8 115 | self.adanorm_num_embeddings = None 116 | self.intermediate_dim = 1536 117 | self.norm = nn.LayerNorm(self.dim, eps=1e-6) 118 | self.norm2 = nn.LayerNorm(self.dim, eps=1e-6) 119 | layer_scale_init_value = 1 / self.num_layers 120 | self.convnext = nn.ModuleList( 121 | [ 122 | ConvNeXtBlock( 123 | dim=self.dim, 124 | intermediate_dim=self.intermediate_dim, 125 | layer_scale_init_value=layer_scale_init_value, 126 | adanorm_num_embeddings=self.adanorm_num_embeddings, 127 | ) 128 | for _ in range(self.num_layers) 129 | ] 130 | ) 131 | self.convnext2 = nn.ModuleList( 132 | [ 133 | ConvNeXtBlock( 134 | dim=self.h.ASP_channel, 135 | intermediate_dim=self.intermediate_dim, 136 | layer_scale_init_value=layer_scale_init_value, 137 | adanorm_num_embeddings=self.adanorm_num_embeddings, 138 | ) 139 | # for _ in range(self.num_layers) 140 | for _ in range(1) 141 | ] 142 | ) 143 | self.final_layer_norm = nn.LayerNorm(self.dim, eps=1e-6) 144 | self.final_layer_norm2 = nn.LayerNorm(self.dim, eps=1e-6) 145 | self.apply(self._init_weights) 146 | 147 | def _init_weights(self, m): 148 | if isinstance(m, (nn.Conv1d, nn.Linear)): 149 | nn.init.trunc_normal_(m.weight, std=0.02) 150 | nn.init.constant_(m.bias, 0) 151 | 152 | def forward(self, mel, inv_mel=None): 153 | if inv_mel is None: 154 | inv_amp = ( 155 | inverse_mel( 156 | mel, 157 | self.h.n_fft, 158 | self.h.num_mels, 159 | self.h.sampling_rate, 160 | self.h.hop_size, 161 | self.h.win_size, 162 | self.h.fmin, 163 | self.h.fmax, 164 | ) 165 | 
.abs() 166 | .clamp_min(1e-5) 167 | ) 168 | else: 169 | inv_amp = inv_mel 170 | logamp = inv_amp.log() 171 | # logamp = self.ASP_input_conv(logamp) 172 | for conv_block in self.convnext2: 173 | logamp = conv_block(logamp, cond_embedding_id=None) 174 | # logamp = self.final_layer_norm2(logamp.transpose(1, 2)) 175 | # logamp = logamp.transpose(1, 2) 176 | # logamp = self.ASP_output_conv(logamp) 177 | 178 | pha = self.PSP_input_conv(mel) 179 | pha = self.norm(pha.transpose(1, 2)) 180 | pha = pha.transpose(1, 2) 181 | for conv_block in self.convnext: 182 | pha = conv_block(pha, cond_embedding_id=None) 183 | pha = self.final_layer_norm(pha.transpose(1, 2)) 184 | pha = pha.transpose(1, 2) 185 | R = self.PSP_output_R_conv(pha) 186 | I = self.PSP_output_I_conv(pha) 187 | 188 | pha = torch.atan2(I, R) 189 | 190 | rea = torch.exp(logamp) * torch.cos(pha) 191 | imag = torch.exp(logamp) * torch.sin(pha) 192 | 193 | spec = torch.complex(rea, imag) 194 | # spec = torch.cat((rea.unsqueeze(-1), imag.unsqueeze(-1)), -1) 195 | 196 | audio = torch.istft( 197 | spec, 198 | self.h.n_fft, 199 | hop_length=self.h.hop_size, 200 | win_length=self.h.win_size, 201 | window=torch.hann_window(self.h.win_size).to(mel.device), 202 | center=True, 203 | ) 204 | 205 | return logamp, pha, rea, imag, audio.unsqueeze(1) 206 | 207 | 208 | class DiscriminatorP(torch.nn.Module): 209 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 210 | super(DiscriminatorP, self).__init__() 211 | self.period = period 212 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 213 | self.convs = nn.ModuleList( 214 | [ 215 | norm_f( 216 | Conv2d( 217 | 1, 218 | 32, 219 | (kernel_size, 1), 220 | (stride, 1), 221 | padding=(get_padding(5, 1), 0), 222 | ) 223 | ), 224 | norm_f( 225 | Conv2d( 226 | 32, 227 | 128, 228 | (kernel_size, 1), 229 | (stride, 1), 230 | padding=(get_padding(5, 1), 0), 231 | ) 232 | ), 233 | norm_f( 234 | Conv2d( 235 | 128, 236 | 512, 237 | (kernel_size, 1), 238 | (stride, 1), 239 | padding=(get_padding(5, 1), 0), 240 | ) 241 | ), 242 | norm_f( 243 | Conv2d( 244 | 512, 245 | 1024, 246 | (kernel_size, 1), 247 | (stride, 1), 248 | padding=(get_padding(5, 1), 0), 249 | ) 250 | ), 251 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 252 | ] 253 | ) 254 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 255 | 256 | def forward(self, x): 257 | fmap = [] 258 | 259 | # 1d to 2d 260 | b, c, t = x.shape 261 | if t % self.period != 0: # pad first 262 | n_pad = self.period - (t % self.period) 263 | x = F.pad(x, (0, n_pad), "reflect") 264 | t = t + n_pad 265 | x = x.view(b, c, t // self.period, self.period) 266 | 267 | for l in self.convs: 268 | x = l(x) 269 | x = F.leaky_relu(x, LRELU_SLOPE) 270 | fmap.append(x) 271 | x = self.conv_post(x) 272 | fmap.append(x) 273 | x = torch.flatten(x, 1, -1) 274 | 275 | return x, fmap 276 | 277 | 278 | class MultiPeriodDiscriminator(torch.nn.Module): 279 | def __init__(self): 280 | super(MultiPeriodDiscriminator, self).__init__() 281 | self.discriminators = nn.ModuleList( 282 | [ 283 | DiscriminatorP(2), 284 | DiscriminatorP(3), 285 | DiscriminatorP(5), 286 | DiscriminatorP(7), 287 | DiscriminatorP(11), 288 | ] 289 | ) 290 | 291 | def forward(self, y, y_hat): 292 | y_d_rs = [] 293 | y_d_gs = [] 294 | fmap_rs = [] 295 | fmap_gs = [] 296 | for i, d in enumerate(self.discriminators): 297 | y_d_r, fmap_r = d(y) 298 | y_d_g, fmap_g = d(y_hat) 299 | y_d_rs.append(y_d_r) 300 | fmap_rs.append(fmap_r) 301 | y_d_gs.append(y_d_g) 
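# the logits drive the hinge GAN losses; the per-layer feature maps
# collected here drive the L1 feature-matching loss in train*.py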
302 | fmap_gs.append(fmap_g) 303 | 304 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 305 | 306 | 307 | def phase_loss(phase_r, phase_g, n_fft, frames): 308 | GD_matrix = ( 309 | torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=1) 310 | - torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=2) 311 | - torch.eye(n_fft // 2 + 1) 312 | ) 313 | GD_matrix = GD_matrix.to(phase_g.device) 314 | 315 | GD_r = torch.matmul(phase_r.permute(0, 2, 1), GD_matrix) 316 | GD_g = torch.matmul(phase_g.permute(0, 2, 1), GD_matrix) 317 | 318 | PTD_matrix = ( 319 | torch.triu(torch.ones(frames, frames), diagonal=1) 320 | - torch.triu(torch.ones(frames, frames), diagonal=2) 321 | - torch.eye(frames) 322 | ) 323 | PTD_matrix = PTD_matrix.to(phase_g.device) 324 | 325 | PTD_r = torch.matmul(phase_r, PTD_matrix) 326 | PTD_g = torch.matmul(phase_g, PTD_matrix) 327 | 328 | IP_loss = torch.mean(anti_wrapping_function(phase_r - phase_g)) 329 | GD_loss = torch.mean(anti_wrapping_function(GD_r - GD_g)) 330 | PTD_loss = torch.mean(anti_wrapping_function(PTD_r - PTD_g)) 331 | 332 | return IP_loss, GD_loss, PTD_loss 333 | 334 | 335 | class MultiResolutionDiscriminator(nn.Module): 336 | def __init__( 337 | self, 338 | resolutions=((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)), 339 | num_embeddings: int = None, 340 | ): 341 | super().__init__() 342 | self.discriminators = nn.ModuleList( 343 | [ 344 | DiscriminatorR(resolution=r, num_embeddings=num_embeddings) 345 | for r in resolutions 346 | ] 347 | ) 348 | 349 | def forward( 350 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None 351 | ): 352 | y_d_rs = [] 353 | y_d_gs = [] 354 | fmap_rs = [] 355 | fmap_gs = [] 356 | 357 | for d in self.discriminators: 358 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) 359 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) 360 | y_d_rs.append(y_d_r) 361 | fmap_rs.append(fmap_r) 362 | y_d_gs.append(y_d_g) 363 | fmap_gs.append(fmap_g) 364 | 365 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 366 | 367 | 368 | class DiscriminatorR(nn.Module): 369 | def __init__( 370 | self, 371 | resolution, 372 | channels: int = 64, 373 | in_channels: int = 1, 374 | num_embeddings: int = None, 375 | lrelu_slope: float = 0.1, 376 | ): 377 | super().__init__() 378 | self.resolution = resolution 379 | self.in_channels = in_channels 380 | self.lrelu_slope = lrelu_slope 381 | self.convs = nn.ModuleList( 382 | [ 383 | weight_norm( 384 | nn.Conv2d( 385 | in_channels, 386 | channels, 387 | kernel_size=(7, 5), 388 | stride=(2, 2), 389 | padding=(3, 2), 390 | ) 391 | ), 392 | weight_norm( 393 | nn.Conv2d( 394 | channels, 395 | channels, 396 | kernel_size=(5, 3), 397 | stride=(2, 1), 398 | padding=(2, 1), 399 | ) 400 | ), 401 | weight_norm( 402 | nn.Conv2d( 403 | channels, 404 | channels, 405 | kernel_size=(5, 3), 406 | stride=(2, 2), 407 | padding=(2, 1), 408 | ) 409 | ), 410 | weight_norm( 411 | nn.Conv2d( 412 | channels, channels, kernel_size=3, stride=(2, 1), padding=1 413 | ) 414 | ), 415 | weight_norm( 416 | nn.Conv2d( 417 | channels, channels, kernel_size=3, stride=(2, 2), padding=1 418 | ) 419 | ), 420 | ] 421 | ) 422 | if num_embeddings is not None: 423 | self.emb = torch.nn.Embedding( 424 | num_embeddings=num_embeddings, embedding_dim=channels 425 | ) 426 | torch.nn.init.zeros_(self.emb.weight) 427 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1))) 428 | 429 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): 430 | fmap = [] 431 | x = 
x.squeeze(1) 432 | 433 | x = self.spectrogram(x) 434 | x = x.unsqueeze(1) 435 | for l in self.convs: 436 | x = l(x) 437 | x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) 438 | fmap.append(x) 439 | if cond_embedding_id is not None: 440 | emb = self.emb(cond_embedding_id) 441 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) 442 | else: 443 | h = 0 444 | x = self.conv_post(x) 445 | fmap.append(x) 446 | x += h 447 | x = torch.flatten(x, 1, -1) 448 | 449 | return x, fmap 450 | 451 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor: 452 | n_fft, hop_length, win_length = self.resolution 453 | magnitude_spectrogram = torch.stft( 454 | x, 455 | n_fft=n_fft, 456 | hop_length=hop_length, 457 | win_length=win_length, 458 | window=None, # interestingly rectangular window kind of works here 459 | center=True, 460 | return_complex=True, 461 | ).abs() 462 | 463 | return magnitude_spectrogram 464 | 465 | 466 | def anti_wrapping_function(x): 467 | return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi) 468 | 469 | 470 | def amplitude_loss(log_amplitude_r, log_amplitude_g): 471 | MSELoss = torch.nn.MSELoss() 472 | 473 | amplitude_loss = MSELoss(log_amplitude_r, log_amplitude_g) 474 | 475 | return amplitude_loss 476 | 477 | 478 | def feature_loss(fmap_r, fmap_g): 479 | loss = 0 480 | for dr, dg in zip(fmap_r, fmap_g): 481 | for rl, gl in zip(dr, dg): 482 | loss += torch.mean(torch.abs(rl - gl)) 483 | 484 | return loss 485 | 486 | 487 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 488 | loss = 0 489 | r_losses = [] 490 | g_losses = [] 491 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 492 | r_loss = torch.mean(torch.clamp(1 - dr, min=0)) 493 | g_loss = torch.mean(torch.clamp(1 + dg, min=0)) 494 | loss += r_loss + g_loss 495 | r_losses.append(r_loss.item()) 496 | g_losses.append(g_loss.item()) 497 | 498 | return loss, r_losses, g_losses 499 | 500 | 501 | def generator_loss(disc_outputs): 502 | loss = 0 503 | gen_losses = [] 504 | for dg in disc_outputs: 505 | l = torch.mean(torch.clamp(1 - dg, min=0)) 506 | gen_losses.append(l) 507 | loss += l 508 | 509 | return loss, gen_losses 510 | 511 | 512 | def STFT_consistency_loss(rea_r, rea_g, imag_r, imag_g): 513 | C_loss = torch.mean( 514 | torch.mean((rea_r - rea_g) ** 2 + (imag_r - imag_g) ** 2, (1, 2)) 515 | ) 516 | 517 | return C_loss 518 | -------------------------------------------------------------------------------- /models2_pghi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, spectral_norm 6 | from utils import init_weights, get_padding 7 | from dataset import inverse_mel 8 | import numpy as np 9 | 10 | LRELU_SLOPE = 0.1 11 | 12 | 13 | class GRN(nn.Module): 14 | """GRN (Global Response Normalization) layer""" 15 | 16 | def __init__(self, dim): 17 | super().__init__() 18 | self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) 19 | self.beta = nn.Parameter(torch.zeros(1, 1, dim)) 20 | 21 | def forward(self, x): 22 | Gx = torch.norm(x, p=2, dim=1, keepdim=True) 23 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 24 | return self.gamma * (x * Nx) + self.beta + x 25 | 26 | 27 | class ConvNeXtBlock(nn.Module): 28 | def __init__( 29 | self, 30 | dim: int, 31 | intermediate_dim: int, 32 | layer_scale_init_value=None, 33 | adanorm_num_embeddings=None, 34 | ): 35 | 
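# Same ConvNeXt block as in models.py / models2.py; it appears to be
# duplicated verbatim so that each model variant stays a self-contained file.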
super().__init__() 36 | self.dwconv = nn.Conv1d( 37 | dim, dim, kernel_size=7, padding=3, groups=dim 38 | ) # depthwise conv 39 | self.adanorm = adanorm_num_embeddings is not None 40 | 41 | self.norm = nn.LayerNorm(dim, eps=1e-6) 42 | self.pwconv1 = nn.Linear( 43 | dim, intermediate_dim 44 | ) # pointwise/1x1 convs, implemented with linear layers 45 | self.act = nn.GELU() 46 | self.grn = GRN(intermediate_dim) 47 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 48 | 49 | def forward(self, x, cond_embedding_id=None): 50 | residual = x 51 | x = self.dwconv(x) 52 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 53 | if self.adanorm: 54 | assert cond_embedding_id is not None 55 | x = self.norm(x, cond_embedding_id) 56 | else: 57 | x = self.norm(x) 58 | x = self.pwconv1(x) 59 | x = self.act(x) 60 | x = self.grn(x) 61 | x = self.pwconv2(x) 62 | 63 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 64 | 65 | x = residual + x 66 | return x 67 | 68 | 69 | class Generator(torch.nn.Module): 70 | def __init__(self, h): 71 | super(Generator, self).__init__() 72 | self.h = h 73 | self.ASP_num_kernels = len(h.ASP_resblock_kernel_sizes) 74 | self.PSP_num_kernels = len(h.PSP_resblock_kernel_sizes) 75 | 76 | # self.ASP_input_conv = Conv1d( 77 | # h.num_mels, 78 | # h.ASP_channel, 79 | # h.ASP_input_conv_kernel_size, 80 | # 1, 81 | # padding=get_padding(h.ASP_input_conv_kernel_size, 1), 82 | # ) 83 | self.PSP_input_conv = Conv1d( 84 | 2 * self.h.ASP_channel, 85 | h.PSP_channel, 86 | 1, 87 | ) 88 | # self.PSP_input_conv2 = Conv1d( 89 | # h.PSP_channel, 90 | # h.PSP_channel, 91 | # h.PSP_input_conv_kernel_size, 92 | # 1, 93 | # padding=get_padding(h.PSP_input_conv_kernel_size, 1), 94 | # ) 95 | 96 | # self.ASP_output_conv = Conv1d( 97 | # h.ASP_channel, 98 | # h.n_fft // 2 + 1, 99 | # h.ASP_output_conv_kernel_size, 100 | # 1, 101 | # padding=get_padding(h.ASP_output_conv_kernel_size, 1), 102 | # ) 103 | self.PSP_output_R_conv = Conv1d( 104 | 512, 105 | h.n_fft // 2 + 1, 106 | h.PSP_output_R_conv_kernel_size, 107 | 1, 108 | padding=get_padding(h.PSP_output_R_conv_kernel_size, 1), 109 | ) 110 | self.PSP_output_I_conv = Conv1d( 111 | 512, 112 | h.n_fft // 2 + 1, 113 | h.PSP_output_I_conv_kernel_size, 114 | 1, 115 | padding=get_padding(h.PSP_output_I_conv_kernel_size, 1), 116 | ) 117 | 118 | self.dim = 512 119 | self.num_layers = 8 120 | self.adanorm_num_embeddings = None 121 | self.intermediate_dim = 1536 122 | self.norm = nn.LayerNorm(self.dim, eps=1e-6) 123 | self.norm2 = nn.LayerNorm(self.dim, eps=1e-6) 124 | layer_scale_init_value = 1 / self.num_layers 125 | self.convnext = nn.ModuleList( 126 | [ 127 | ConvNeXtBlock( 128 | dim=self.dim, 129 | intermediate_dim=self.intermediate_dim, 130 | layer_scale_init_value=layer_scale_init_value, 131 | adanorm_num_embeddings=self.adanorm_num_embeddings, 132 | ) 133 | for _ in range(self.num_layers) 134 | ] 135 | ) 136 | self.convnext2 = nn.ModuleList( 137 | [ 138 | ConvNeXtBlock( 139 | dim=self.h.ASP_channel, 140 | intermediate_dim=self.intermediate_dim, 141 | layer_scale_init_value=layer_scale_init_value, 142 | adanorm_num_embeddings=self.adanorm_num_embeddings, 143 | ) 144 | # for _ in range(self.num_layers) 145 | for _ in range(1) 146 | ] 147 | ) 148 | self.final_layer_norm = nn.LayerNorm(self.dim, eps=1e-6) 149 | self.final_layer_norm2 = nn.LayerNorm(self.dim, eps=1e-6) 150 | self.apply(self._init_weights) 151 | 152 | def _init_weights(self, m): 153 | if isinstance(m, (nn.Conv1d, nn.Linear)): 154 | nn.init.trunc_normal_(m.weight, std=0.02) 155 | 
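# truncated-normal weights (std=0.02) plus the zero bias below, applied
# to every Conv1d/Linear module via self.apply(...) at construction time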
nn.init.constant_(m.bias, 0) 156 | 157 | def forward(self, mel, inv_mel=None, pghi=None): 158 | if inv_mel is None: 159 | inv_amp = ( 160 | inverse_mel( 161 | mel, 162 | self.h.n_fft, 163 | self.h.num_mels, 164 | self.h.sampling_rate, 165 | self.h.hop_size, 166 | self.h.win_size, 167 | self.h.fmin, 168 | self.h.fmax, 169 | ) 170 | .abs() 171 | .clamp_min(1e-5) 172 | ) 173 | else: 174 | inv_amp = inv_mel 175 | logamp = inv_amp.log() 176 | # logamp = self.ASP_input_conv(logamp) 177 | for conv_block in self.convnext2: 178 | logamp = conv_block(logamp, cond_embedding_id=None) 179 | # logamp = self.final_layer_norm2(logamp.transpose(1, 2)) 180 | # logamp = logamp.transpose(1, 2) 181 | # logamp = self.ASP_output_conv(logamp) 182 | 183 | pha = self.PSP_input_conv(torch.cat((inv_amp, pghi), dim=1)) 184 | # pha = self.norm(pha.transpose(1, 2)) 185 | # pha = pha.transpose(1, 2) 186 | for conv_block in self.convnext: 187 | pha = conv_block(pha, cond_embedding_id=None) 188 | pha = self.final_layer_norm(pha.transpose(1, 2)) 189 | pha = pha.transpose(1, 2) 190 | R = self.PSP_output_R_conv(pha) 191 | I = self.PSP_output_I_conv(pha) 192 | 193 | pha = torch.atan2(I, R) 194 | 195 | rea = torch.exp(logamp) * torch.cos(pha) 196 | imag = torch.exp(logamp) * torch.sin(pha) 197 | 198 | spec = torch.complex(rea, imag) 199 | # spec = torch.cat((rea.unsqueeze(-1), imag.unsqueeze(-1)), -1) 200 | 201 | audio = torch.istft( 202 | spec, 203 | self.h.n_fft, 204 | hop_length=self.h.hop_size, 205 | win_length=self.h.win_size, 206 | window=torch.hann_window(self.h.win_size).to(mel.device), 207 | center=True, 208 | ) 209 | 210 | return logamp, pha, rea, imag, audio.unsqueeze(1) 211 | 212 | 213 | class DiscriminatorP(torch.nn.Module): 214 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 215 | super(DiscriminatorP, self).__init__() 216 | self.period = period 217 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 218 | self.convs = nn.ModuleList( 219 | [ 220 | norm_f( 221 | Conv2d( 222 | 1, 223 | 32, 224 | (kernel_size, 1), 225 | (stride, 1), 226 | padding=(get_padding(5, 1), 0), 227 | ) 228 | ), 229 | norm_f( 230 | Conv2d( 231 | 32, 232 | 128, 233 | (kernel_size, 1), 234 | (stride, 1), 235 | padding=(get_padding(5, 1), 0), 236 | ) 237 | ), 238 | norm_f( 239 | Conv2d( 240 | 128, 241 | 512, 242 | (kernel_size, 1), 243 | (stride, 1), 244 | padding=(get_padding(5, 1), 0), 245 | ) 246 | ), 247 | norm_f( 248 | Conv2d( 249 | 512, 250 | 1024, 251 | (kernel_size, 1), 252 | (stride, 1), 253 | padding=(get_padding(5, 1), 0), 254 | ) 255 | ), 256 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 257 | ] 258 | ) 259 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 260 | 261 | def forward(self, x): 262 | fmap = [] 263 | 264 | # 1d to 2d 265 | b, c, t = x.shape 266 | if t % self.period != 0: # pad first 267 | n_pad = self.period - (t % self.period) 268 | x = F.pad(x, (0, n_pad), "reflect") 269 | t = t + n_pad 270 | x = x.view(b, c, t // self.period, self.period) 271 | 272 | for l in self.convs: 273 | x = l(x) 274 | x = F.leaky_relu(x, LRELU_SLOPE) 275 | fmap.append(x) 276 | x = self.conv_post(x) 277 | fmap.append(x) 278 | x = torch.flatten(x, 1, -1) 279 | 280 | return x, fmap 281 | 282 | 283 | class MultiPeriodDiscriminator(torch.nn.Module): 284 | def __init__(self): 285 | super(MultiPeriodDiscriminator, self).__init__() 286 | self.discriminators = nn.ModuleList( 287 | [ 288 | DiscriminatorP(2), 289 | DiscriminatorP(3), 290 | DiscriminatorP(5), 291 | 
DiscriminatorP(7), 292 | DiscriminatorP(11), 293 | ] 294 | ) 295 | 296 | def forward(self, y, y_hat): 297 | y_d_rs = [] 298 | y_d_gs = [] 299 | fmap_rs = [] 300 | fmap_gs = [] 301 | for i, d in enumerate(self.discriminators): 302 | y_d_r, fmap_r = d(y) 303 | y_d_g, fmap_g = d(y_hat) 304 | y_d_rs.append(y_d_r) 305 | fmap_rs.append(fmap_r) 306 | y_d_gs.append(y_d_g) 307 | fmap_gs.append(fmap_g) 308 | 309 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 310 | 311 | 312 | def phase_loss(phase_r, phase_g, n_fft, frames): 313 | GD_matrix = ( 314 | torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=1) 315 | - torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=2) 316 | - torch.eye(n_fft // 2 + 1) 317 | ) 318 | GD_matrix = GD_matrix.to(phase_g.device) 319 | 320 | GD_r = torch.matmul(phase_r.permute(0, 2, 1), GD_matrix) 321 | GD_g = torch.matmul(phase_g.permute(0, 2, 1), GD_matrix) 322 | 323 | PTD_matrix = ( 324 | torch.triu(torch.ones(frames, frames), diagonal=1) 325 | - torch.triu(torch.ones(frames, frames), diagonal=2) 326 | - torch.eye(frames) 327 | ) 328 | PTD_matrix = PTD_matrix.to(phase_g.device) 329 | 330 | PTD_r = torch.matmul(phase_r, PTD_matrix) 331 | PTD_g = torch.matmul(phase_g, PTD_matrix) 332 | 333 | IP_loss = torch.mean(anti_wrapping_function(phase_r - phase_g)) 334 | GD_loss = torch.mean(anti_wrapping_function(GD_r - GD_g)) 335 | PTD_loss = torch.mean(anti_wrapping_function(PTD_r - PTD_g)) 336 | 337 | return IP_loss, GD_loss, PTD_loss 338 | 339 | 340 | class MultiResolutionDiscriminator(nn.Module): 341 | def __init__( 342 | self, 343 | resolutions=((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)), 344 | num_embeddings: int = None, 345 | ): 346 | super().__init__() 347 | self.discriminators = nn.ModuleList( 348 | [ 349 | DiscriminatorR(resolution=r, num_embeddings=num_embeddings) 350 | for r in resolutions 351 | ] 352 | ) 353 | 354 | def forward( 355 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None 356 | ): 357 | y_d_rs = [] 358 | y_d_gs = [] 359 | fmap_rs = [] 360 | fmap_gs = [] 361 | 362 | for d in self.discriminators: 363 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) 364 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) 365 | y_d_rs.append(y_d_r) 366 | fmap_rs.append(fmap_r) 367 | y_d_gs.append(y_d_g) 368 | fmap_gs.append(fmap_g) 369 | 370 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 371 | 372 | 373 | class DiscriminatorR(nn.Module): 374 | def __init__( 375 | self, 376 | resolution, 377 | channels: int = 64, 378 | in_channels: int = 1, 379 | num_embeddings: int = None, 380 | lrelu_slope: float = 0.1, 381 | ): 382 | super().__init__() 383 | self.resolution = resolution 384 | self.in_channels = in_channels 385 | self.lrelu_slope = lrelu_slope 386 | self.convs = nn.ModuleList( 387 | [ 388 | weight_norm( 389 | nn.Conv2d( 390 | in_channels, 391 | channels, 392 | kernel_size=(7, 5), 393 | stride=(2, 2), 394 | padding=(3, 2), 395 | ) 396 | ), 397 | weight_norm( 398 | nn.Conv2d( 399 | channels, 400 | channels, 401 | kernel_size=(5, 3), 402 | stride=(2, 1), 403 | padding=(2, 1), 404 | ) 405 | ), 406 | weight_norm( 407 | nn.Conv2d( 408 | channels, 409 | channels, 410 | kernel_size=(5, 3), 411 | stride=(2, 2), 412 | padding=(2, 1), 413 | ) 414 | ), 415 | weight_norm( 416 | nn.Conv2d( 417 | channels, channels, kernel_size=3, stride=(2, 1), padding=1 418 | ) 419 | ), 420 | weight_norm( 421 | nn.Conv2d( 422 | channels, channels, kernel_size=3, stride=(2, 2), padding=1 423 | ) 424 | ), 425 | ] 426 | ) 427 | if 
num_embeddings is not None: 428 | self.emb = torch.nn.Embedding( 429 | num_embeddings=num_embeddings, embedding_dim=channels 430 | ) 431 | torch.nn.init.zeros_(self.emb.weight) 432 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1))) 433 | 434 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): 435 | fmap = [] 436 | x = x.squeeze(1) 437 | 438 | x = self.spectrogram(x) 439 | x = x.unsqueeze(1) 440 | for l in self.convs: 441 | x = l(x) 442 | x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) 443 | fmap.append(x) 444 | if cond_embedding_id is not None: 445 | emb = self.emb(cond_embedding_id) 446 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) 447 | else: 448 | h = 0 449 | x = self.conv_post(x) 450 | fmap.append(x) 451 | x += h 452 | x = torch.flatten(x, 1, -1) 453 | 454 | return x, fmap 455 | 456 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor: 457 | n_fft, hop_length, win_length = self.resolution 458 | magnitude_spectrogram = torch.stft( 459 | x, 460 | n_fft=n_fft, 461 | hop_length=hop_length, 462 | win_length=win_length, 463 | window=None, # interestingly rectangular window kind of works here 464 | center=True, 465 | return_complex=True, 466 | ).abs() 467 | 468 | return magnitude_spectrogram 469 | 470 | 471 | def anti_wrapping_function(x): 472 | return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi) 473 | 474 | 475 | def amplitude_loss(log_amplitude_r, log_amplitude_g): 476 | MSELoss = torch.nn.MSELoss() 477 | 478 | amplitude_loss = MSELoss(log_amplitude_r, log_amplitude_g) 479 | 480 | return amplitude_loss 481 | 482 | 483 | def feature_loss(fmap_r, fmap_g): 484 | loss = 0 485 | for dr, dg in zip(fmap_r, fmap_g): 486 | for rl, gl in zip(dr, dg): 487 | loss += torch.mean(torch.abs(rl - gl)) 488 | 489 | return loss 490 | 491 | 492 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 493 | loss = 0 494 | r_losses = [] 495 | g_losses = [] 496 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 497 | r_loss = torch.mean(torch.clamp(1 - dr, min=0)) 498 | g_loss = torch.mean(torch.clamp(1 + dg, min=0)) 499 | loss += r_loss + g_loss 500 | r_losses.append(r_loss.item()) 501 | g_losses.append(g_loss.item()) 502 | 503 | return loss, r_losses, g_losses 504 | 505 | 506 | def generator_loss(disc_outputs): 507 | loss = 0 508 | gen_losses = [] 509 | for dg in disc_outputs: 510 | l = torch.mean(torch.clamp(1 - dg, min=0)) 511 | gen_losses.append(l) 512 | loss += l 513 | 514 | return loss, gen_losses 515 | 516 | 517 | def STFT_consistency_loss(rea_r, rea_g, imag_r, imag_g): 518 | C_loss = torch.mean( 519 | torch.mean((rea_r - rea_g) ** 2 + (imag_r - imag_g) ** 2, (1, 2)) 520 | ) 521 | 522 | return C_loss 523 | -------------------------------------------------------------------------------- /models_pghi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, spectral_norm 6 | from utils import init_weights, get_padding 7 | from dataset import inverse_mel 8 | import numpy as np 9 | from einops import rearrange 10 | 11 | LRELU_SLOPE = 0.1 12 | 13 | 14 | class GRN(nn.Module): 15 | """GRN (Global Response Normalization) layer""" 16 | 17 | def __init__(self, dim): 18 | super().__init__() 19 | self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) 20 | self.beta = nn.Parameter(torch.zeros(1, 
21 | 
22 |     def forward(self, x):
23 |         Gx = torch.norm(x, p=2, dim=1, keepdim=True)
24 |         Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
25 |         return self.gamma * (x * Nx) + self.beta + x
26 | 
27 | 
28 | class ConvNeXtBlock(nn.Module):
29 |     def __init__(
30 |         self,
31 |         dim: int,
32 |         intermediate_dim: int,
33 |         layer_scale_init_value=None,  # accepted but unused: no layer-scale parameter is created in this block
34 |         adanorm_num_embeddings=None,
35 |     ):
36 |         super().__init__()
37 |         self.dwconv = nn.Conv1d(
38 |             dim, dim, kernel_size=7, padding=3, groups=dim
39 |         )  # depthwise conv
40 |         self.adanorm = adanorm_num_embeddings is not None
41 | 
42 |         self.norm = nn.LayerNorm(dim, eps=1e-6)
43 |         self.pwconv1 = nn.Linear(
44 |             dim, intermediate_dim
45 |         )  # pointwise/1x1 convs, implemented with linear layers
46 |         self.act = nn.GELU()
47 |         self.grn = GRN(intermediate_dim)
48 |         self.pwconv2 = nn.Linear(intermediate_dim, dim)
49 | 
50 |     def forward(self, x, cond_embedding_id=None):
51 |         residual = x
52 |         x = self.dwconv(x)
53 |         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
54 |         if self.adanorm:
55 |             assert cond_embedding_id is not None
56 |             x = self.norm(x, cond_embedding_id)  # NOTE: unreachable in this file (adanorm_num_embeddings is always None); plain LayerNorm would reject the second argument
57 |         else:
58 |             x = self.norm(x)
59 |         x = self.pwconv1(x)
60 |         x = self.act(x)
61 |         x = self.grn(x)
62 |         x = self.pwconv2(x)
63 | 
64 |         x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
65 | 
66 |         x = residual + x
67 |         return x
68 | 
69 | 
70 | class Generator(torch.nn.Module):
71 |     def __init__(self, h):
72 |         super(Generator, self).__init__()
73 |         self.h = h
74 |         self.ASP_num_kernels = len(h.ASP_resblock_kernel_sizes)
75 |         self.PSP_num_kernels = len(h.PSP_resblock_kernel_sizes)
76 | 
77 |         # self.ASP_input_conv = Conv1d(
78 |         #     h.num_mels,
79 |         #     h.ASP_channel,
80 |         #     h.ASP_input_conv_kernel_size,
81 |         #     1,
82 |         #     padding=get_padding(h.ASP_input_conv_kernel_size, 1),
83 |         # )
84 |         self.PSP_input_conv = Conv1d(
85 |             513 * 2,  # 2 x (n_fft // 2 + 1): inverted amplitude and PGHI phase concatenated along frequency
86 |             h.PSP_channel,
87 |             h.PSP_input_conv_kernel_size,
88 |             1,
89 |             padding=get_padding(h.PSP_input_conv_kernel_size, 1),
90 |         )
91 | 
92 |         # self.ASP_output_conv = Conv1d(
93 |         #     h.ASP_channel,
94 |         #     h.n_fft // 2 + 1,
95 |         #     h.ASP_output_conv_kernel_size,
96 |         #     1,
97 |         #     padding=get_padding(h.ASP_output_conv_kernel_size, 1),
98 |         # )
99 |         self.PSP_output_R_conv = Conv1d(
100 |             513,
101 |             h.n_fft // 2 + 1,
102 |             h.PSP_output_R_conv_kernel_size,
103 |             1,
104 |             padding=get_padding(h.PSP_output_R_conv_kernel_size, 1),
105 |         )
106 |         self.PSP_output_I_conv = Conv1d(
107 |             513,
108 |             h.n_fft // 2 + 1,
109 |             h.PSP_output_I_conv_kernel_size,
110 |             1,
111 |             padding=get_padding(h.PSP_output_I_conv_kernel_size, 1),
112 |         )
113 | 
114 |         self.dim = 513  # n_fft // 2 + 1 for n_fft = 1024
115 |         self.num_layers = 8
116 |         self.adanorm_num_embeddings = None
117 |         self.intermediate_dim = 1536
118 |         self.norm = nn.LayerNorm(self.dim, eps=1e-6)
119 |         self.norm2 = nn.LayerNorm(self.dim, eps=1e-6)
120 |         layer_scale_init_value = 1 / self.num_layers
121 |         self.convnext = nn.ModuleList(
122 |             [
123 |                 ConvNeXtBlock(
124 |                     dim=self.h.PSP_channel,
125 |                     intermediate_dim=self.intermediate_dim,
126 |                     layer_scale_init_value=layer_scale_init_value,
127 |                     adanorm_num_embeddings=self.adanorm_num_embeddings,
128 |                 )
129 |                 # for _ in range(self.num_layers)
130 |                 for _ in range(1)
131 |             ]
132 |         )
133 |         self.convnext2 = nn.ModuleList(
134 |             [
135 |                 ConvNeXtBlock(
136 |                     dim=self.h.ASP_channel,
137 |                     intermediate_dim=self.intermediate_dim,
138 |                     layer_scale_init_value=layer_scale_init_value,
139 |                     adanorm_num_embeddings=self.adanorm_num_embeddings,
140 |                 )
141 |                 # for _ in range(self.num_layers)
142 |                 for _ in range(1)
143 |             ]
144 |         )
145 |         # self.convnext_mix = nn.ModuleList(
146 |         #     [
147 |         #         ConvNeXtBlock(
148 |         #             dim=self.h.ASP_channel * 2,
149 |         #             intermediate_dim=self.intermediate_dim,
150 |         #             layer_scale_init_value=layer_scale_init_value,
151 |         #             adanorm_num_embeddings=self.adanorm_num_embeddings,
152 |         #         )
153 |         #         # for _ in range(self.num_layers)
154 |         #         for _ in range(1)
155 |         #     ]
156 |         # )
157 |         self.final_layer_norm = nn.LayerNorm(self.dim, eps=1e-6)
158 |         self.final_layer_norm2 = nn.LayerNorm(self.dim, eps=1e-6)
159 |         self.apply(self._init_weights)
160 | 
161 |     def _init_weights(self, m):
162 |         if isinstance(m, (nn.Conv1d, nn.Linear)):
163 |             nn.init.trunc_normal_(m.weight, std=0.02)
164 |             nn.init.constant_(m.bias, 0)
165 | 
166 |     def forward(self, mel, inv_mel=None, pghi=None):
167 |         if inv_mel is None:
168 |             inv_amp = (
169 |                 inverse_mel(
170 |                     mel,
171 |                     self.h.n_fft,
172 |                     self.h.num_mels,
173 |                     self.h.sampling_rate,
174 |                     self.h.hop_size,
175 |                     self.h.win_size,
176 |                     self.h.fmin,
177 |                     self.h.fmax,
178 |                 )
179 |                 .abs()
180 |                 .clamp_min(1e-5)
181 |             )
182 |         else:
183 |             inv_amp = inv_mel
184 |         logamp = inv_amp.log()
185 |         # logamp = self.ASP_input_conv(logamp)
186 |         for conv_block in self.convnext2:
187 |             logamp = conv_block(logamp, cond_embedding_id=None)
188 |         # logamp = self.final_layer_norm2(logamp.transpose(1, 2))
189 |         # logamp = logamp.transpose(1, 2)
190 |         # logamp = self.ASP_output_conv(logamp)
191 | 
192 |         # pha = self.PSP_input_conv(mel)
193 |         # pha = self.norm(pha.transpose(1, 2))
194 |         # pha = pha.transpose(1, 2)
195 |         # pha = self.PSP_input_conv(torch.cat((pghi, mel), dim=-2))
196 |         # pha = pghi
197 |         # pha = logamp.detach()
198 |         # for conv_block in self.convnext:
199 |         #     pha = conv_block(pha, cond_embedding_id=None)
200 |         # pha = self.final_layer_norm(pha.transpose(1, 2))
201 |         # pha = pha.transpose(1, 2)
202 |         # R = self.PSP_output_R_conv(pha)
203 |         # I = self.PSP_output_I_conv(pha)
204 | 
205 |         # pha = torch.atan2(I, R)
206 |         _spec = torch.cat((inv_amp, pghi), dim=-2)  # inv_amp, not inv_mel: inv_mel may be None when the amplitude was derived from the mel above
207 |         # _spec = torch.polar(inv_mel, pghi)
208 |         # _spec = rearrange(torch.view_as_real(_spec), "b c t d -> b (c d) t", d=2)
209 |         _spec = self.PSP_input_conv(_spec)
210 |         for conv_block in self.convnext:
211 |             _spec = conv_block(_spec, cond_embedding_id=None)
212 |         pha = self.final_layer_norm(_spec.transpose(1, 2))
213 |         pha = pha.transpose(1, 2)
214 |         R = self.PSP_output_R_conv(pha)
215 |         I = self.PSP_output_I_conv(pha)
216 |         pha = torch.atan2(I, R)
217 | 
218 |         spec = torch.polar(logamp.exp(), pha)
219 |         # spec = torch.cat((rea.unsqueeze(-1), imag.unsqueeze(-1)), -1)
220 | 
221 |         audio = torch.istft(
222 |             spec,
223 |             self.h.n_fft,
224 |             hop_length=self.h.hop_size,
225 |             win_length=self.h.win_size,
226 |             window=torch.hann_window(self.h.win_size).to(mel.device),
227 |             center=True,
228 |         )
229 | 
230 |         return logamp, pha, spec.real, spec.imag, audio.unsqueeze(1)
231 | 
232 | 
233 | class DiscriminatorP(torch.nn.Module):
234 |     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
235 |         super(DiscriminatorP, self).__init__()
236 |         self.period = period
237 |         norm_f = weight_norm if not use_spectral_norm else spectral_norm
238 |         self.convs = nn.ModuleList(
239 |             [
240 |                 norm_f(
241 |                     Conv2d(
242 |                         1,
243 |                         32,
244 |                         (kernel_size, 1),
245 |                         (stride, 1),
246 |                         padding=(get_padding(5, 1), 0),
247 |                     )
248 |                 ),
249 |                 norm_f(
250 |                     Conv2d(
251 |                         32,
252 |                         128,
253 |                         (kernel_size, 1),
254 |                         (stride, 1),
255 |                         padding=(get_padding(5, 1), 0),
256 |                     )
257 |                 ),
258 |                 norm_f(
259 |                     Conv2d(
260 |                         128,
261 |                         512,
262 |                         (kernel_size, 1),
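                        # (kernel_size, 1) kernels with (stride, 1) strides convolve along time only, so each of the period columns is scored independently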
263 |                         (stride, 1),
264 |                         padding=(get_padding(5, 1), 0),
265 |                     )
266 |                 ),
267 |                 norm_f(
268 |                     Conv2d(
269 |                         512,
270 |                         1024,
271 |                         (kernel_size, 1),
272 |                         (stride, 1),
273 |                         padding=(get_padding(5, 1), 0),
274 |                     )
275 |                 ),
276 |                 norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
277 |             ]
278 |         )
279 |         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
280 | 
281 |     def forward(self, x):
282 |         fmap = []
283 | 
284 |         # 1d to 2d
285 |         b, c, t = x.shape
286 |         if t % self.period != 0:  # pad first
287 |             n_pad = self.period - (t % self.period)
288 |             x = F.pad(x, (0, n_pad), "reflect")
289 |             t = t + n_pad
290 |         x = x.view(b, c, t // self.period, self.period)
291 | 
292 |         for l in self.convs:
293 |             x = l(x)
294 |             x = F.leaky_relu(x, LRELU_SLOPE)
295 |             fmap.append(x)
296 |         x = self.conv_post(x)
297 |         fmap.append(x)
298 |         x = torch.flatten(x, 1, -1)
299 | 
300 |         return x, fmap
301 | 
302 | 
303 | class MultiPeriodDiscriminator(torch.nn.Module):
304 |     def __init__(self):
305 |         super(MultiPeriodDiscriminator, self).__init__()
306 |         self.discriminators = nn.ModuleList(
307 |             [
308 |                 DiscriminatorP(2),
309 |                 DiscriminatorP(3),
310 |                 DiscriminatorP(5),
311 |                 DiscriminatorP(7),
312 |                 DiscriminatorP(11),
313 |             ]
314 |         )
315 | 
316 |     def forward(self, y, y_hat):
317 |         y_d_rs = []
318 |         y_d_gs = []
319 |         fmap_rs = []
320 |         fmap_gs = []
321 |         for i, d in enumerate(self.discriminators):
322 |             y_d_r, fmap_r = d(y)
323 |             y_d_g, fmap_g = d(y_hat)
324 |             y_d_rs.append(y_d_r)
325 |             fmap_rs.append(fmap_r)
326 |             y_d_gs.append(y_d_g)
327 |             fmap_gs.append(fmap_g)
328 | 
329 |         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
330 | 
331 | 
332 | def phase_loss(phase_r, phase_g, n_fft, frames):
333 |     GD_matrix = (
334 |         torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=1)
335 |         - torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=2)
336 |         - torch.eye(n_fft // 2 + 1)
337 |     )
338 |     GD_matrix = GD_matrix.to(phase_g.device)
339 | 
340 |     GD_r = torch.matmul(phase_r.permute(0, 2, 1), GD_matrix)
341 |     GD_g = torch.matmul(phase_g.permute(0, 2, 1), GD_matrix)
342 | 
343 |     PTD_matrix = (
344 |         torch.triu(torch.ones(frames, frames), diagonal=1)
345 |         - torch.triu(torch.ones(frames, frames), diagonal=2)
346 |         - torch.eye(frames)
347 |     )
348 |     PTD_matrix = PTD_matrix.to(phase_g.device)
349 | 
350 |     PTD_r = torch.matmul(phase_r, PTD_matrix)
351 |     PTD_g = torch.matmul(phase_g, PTD_matrix)
352 | 
353 |     IP_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
354 |     GD_loss = torch.mean(anti_wrapping_function(GD_r - GD_g))
355 |     PTD_loss = torch.mean(anti_wrapping_function(PTD_r - PTD_g))
356 | 
357 |     return IP_loss, GD_loss, PTD_loss
358 | 
359 | 
360 | class MultiResolutionDiscriminator(nn.Module):
361 |     def __init__(
362 |         self,
363 |         resolutions=((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
364 |         num_embeddings: int = None,
365 |     ):
366 |         super().__init__()
367 |         self.discriminators = nn.ModuleList(
368 |             [
369 |                 DiscriminatorR(resolution=r, num_embeddings=num_embeddings)
370 |                 for r in resolutions
371 |             ]
372 |         )
373 | 
374 |     def forward(
375 |         self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
376 |     ):
377 |         y_d_rs = []
378 |         y_d_gs = []
379 |         fmap_rs = []
380 |         fmap_gs = []
381 | 
382 |         for d in self.discriminators:
383 |             y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
384 |             y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
385 |             y_d_rs.append(y_d_r)
386 |             fmap_rs.append(fmap_r)
387 |             y_d_gs.append(y_d_g)
388 |             fmap_gs.append(fmap_g)
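            # one logit tensor and one feature-map list per STFT resolution; the fmaps feed feature_loss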
389 | 
390 |         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
391 | 
392 | 
393 | class DiscriminatorR(nn.Module):
394 |     def __init__(
395 |         self,
396 |         resolution,
397 |         channels: int = 64,
398 |         in_channels: int = 1,
399 |         num_embeddings: int = None,
400 |         lrelu_slope: float = 0.1,
401 |     ):
402 |         super().__init__()
403 |         self.resolution = resolution
404 |         self.in_channels = in_channels
405 |         self.lrelu_slope = lrelu_slope
406 |         self.convs = nn.ModuleList(
407 |             [
408 |                 weight_norm(
409 |                     nn.Conv2d(
410 |                         in_channels,
411 |                         channels,
412 |                         kernel_size=(7, 5),
413 |                         stride=(2, 2),
414 |                         padding=(3, 2),
415 |                     )
416 |                 ),
417 |                 weight_norm(
418 |                     nn.Conv2d(
419 |                         channels,
420 |                         channels,
421 |                         kernel_size=(5, 3),
422 |                         stride=(2, 1),
423 |                         padding=(2, 1),
424 |                     )
425 |                 ),
426 |                 weight_norm(
427 |                     nn.Conv2d(
428 |                         channels,
429 |                         channels,
430 |                         kernel_size=(5, 3),
431 |                         stride=(2, 2),
432 |                         padding=(2, 1),
433 |                     )
434 |                 ),
435 |                 weight_norm(
436 |                     nn.Conv2d(
437 |                         channels, channels, kernel_size=3, stride=(2, 1), padding=1
438 |                     )
439 |                 ),
440 |                 weight_norm(
441 |                     nn.Conv2d(
442 |                         channels, channels, kernel_size=3, stride=(2, 2), padding=1
443 |                     )
444 |                 ),
445 |             ]
446 |         )
447 |         if num_embeddings is not None:
448 |             self.emb = torch.nn.Embedding(
449 |                 num_embeddings=num_embeddings, embedding_dim=channels
450 |             )
451 |             torch.nn.init.zeros_(self.emb.weight)
452 |         self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
453 | 
454 |     def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
455 |         fmap = []
456 |         x = x.squeeze(1)
457 | 
458 |         x = self.spectrogram(x)
459 |         x = x.unsqueeze(1)
460 |         for l in self.convs:
461 |             x = l(x)
462 |             x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
463 |             fmap.append(x)
464 |         if cond_embedding_id is not None:
465 |             emb = self.emb(cond_embedding_id)
466 |             h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
467 |         else:
468 |             h = 0
469 |         x = self.conv_post(x)
470 |         fmap.append(x)
471 |         x += h
472 |         x = torch.flatten(x, 1, -1)
473 | 
474 |         return x, fmap
475 | 
476 |     def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
477 |         n_fft, hop_length, win_length = self.resolution
478 |         magnitude_spectrogram = torch.stft(
479 |             x,
480 |             n_fft=n_fft,
481 |             hop_length=hop_length,
482 |             win_length=win_length,
483 |             window=None,  # interestingly rectangular window kind of works here
484 |             center=True,
485 |             return_complex=True,
486 |         ).abs()
487 | 
488 |         return magnitude_spectrogram
489 | 
490 | 
491 | def anti_wrapping_function(x):
492 |     return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
493 | 
494 | 
495 | def amplitude_loss(log_amplitude_r, log_amplitude_g):
496 |     MSELoss = torch.nn.MSELoss()
497 | 
498 |     amplitude_loss = MSELoss(log_amplitude_r, log_amplitude_g)
499 | 
500 |     return amplitude_loss
501 | 
502 | 
503 | def feature_loss(fmap_r, fmap_g):
504 |     loss = 0
505 |     for dr, dg in zip(fmap_r, fmap_g):
506 |         for rl, gl in zip(dr, dg):
507 |             loss += torch.mean(torch.abs(rl - gl))
508 | 
509 |     return loss
510 | 
511 | 
512 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
513 |     loss = 0
514 |     r_losses = []
515 |     g_losses = []
516 |     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
517 |         r_loss = torch.mean(torch.clamp(1 - dr, min=0))
518 |         g_loss = torch.mean(torch.clamp(1 + dg, min=0))
519 |         loss += r_loss + g_loss
520 |         r_losses.append(r_loss.item())
521 |         g_losses.append(g_loss.item())
522 | 
523 |     return loss, r_losses, g_losses
524 | 
525 | 
526 | def generator_loss(disc_outputs):
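    # hinge-style generator loss: each logit for generated audio is pushed above 1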
527 |     loss = 0
528 |     gen_losses = []
529 |     for dg in disc_outputs:
530 |         l = torch.mean(torch.clamp(1 - dg, min=0))
531 |         gen_losses.append(l)
532 |         loss += l
533 | 
534 |     return loss, gen_losses
535 | 
536 | 
537 | def STFT_consistency_loss(rea_r, rea_g, imag_r, imag_g):
538 |     C_loss = torch.mean(
539 |         torch.mean((rea_r - rea_g) ** 2 + (imag_r - imag_g) ** 2, (1, 2))
540 |     )
541 | 
542 |     return C_loss
543 | 
--------------------------------------------------------------------------------
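For orientation, here is a minimal smoke-test sketch for the models_pghi.py pieces above. It is hypothetical and not part of the repository: it assumes config_pghi.json mirrors the keys shown in config.json with ASP_channel and PSP_channel set to 513 (matching the n_fft // 2 + 1 dimensions hard-coded in Generator), and it substitutes random tensors for the real mel, inverse-mel, and PGHI features produced by dataset.py.

# Hypothetical smoke test -- not part of the repository.
# Assumes config_pghi.json exists and sets ASP_channel == PSP_channel == 513.
import json
import math

import torch

from env import AttrDict
from models_pghi import (
    Generator,
    MultiPeriodDiscriminator,
    amplitude_loss,
    feature_loss,
    generator_loss,
    phase_loss,
)

with open("config_pghi.json") as f:
    h = AttrDict(json.load(f))

generator = Generator(h)
mpd = MultiPeriodDiscriminator()

frames = h.segment_size // h.hop_size + 1  # centered STFT frame count
freq_bins = h.n_fft // 2 + 1  # 513 for n_fft = 1024

# Random stand-ins for dataset.py features: mel spectrogram,
# pseudo-inverted amplitude, and PGHI phase estimate.
mel = torch.randn(1, h.num_mels, frames)
inv_mel = torch.rand(1, freq_bins, frames).clamp_min(1e-5)
pghi = (torch.rand(1, freq_bins, frames) * 2 - 1) * math.pi

logamp, pha, rea, imag, audio = generator(mel, inv_mel=inv_mel, pghi=pghi)
print(audio.shape)  # expected: (1, 1, segment_size)

# Generator-side losses against a (here random) reference waveform.
y = torch.randn_like(audio)
_, y_d_gs, fmap_rs, fmap_gs = mpd(y, audio)
loss_gen, _ = generator_loss(y_d_gs)
loss_fm = feature_loss(fmap_rs, fmap_gs)
loss_amp = amplitude_loss(inv_mel.log(), logamp)
ip, gd, ptd = phase_loss(pha.detach(), pha, h.n_fft, frames)

With segment_size 8192 and hop_size 256 the generator should emit a (1, 1, 8192) waveform; the loss calls here only confirm that shapes line up, not that the values are meaningful.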