├── figure ├── 1.png ├── 2.png └── 3.png ├── requirements.txt ├── env.py ├── LICENSE ├── config.json ├── utils.py ├── README.md ├── inference.py ├── dataset.py ├── train.py └── models.py /figure/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/redmist328/APNet2/HEAD/figure/1.png -------------------------------------------------------------------------------- /figure/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/redmist328/APNet2/HEAD/figure/2.png -------------------------------------------------------------------------------- /figure/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/redmist328/APNet2/HEAD/figure/3.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1+cu111 2 | numpy==1.21.6 3 | librosa==0.9.1 4 | tensorboard==2.8.0 5 | soundfile==0.10.3 6 | matplotlib==3.1.3 -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 redmist 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_training_wav_list": "../../datasets/LJ_22050/LJ_train", 3 | "input_validation_wav_list": "../../datasets/LJ_22050/LJ_val", 4 | "test_input_wavs_dir":"../../datasets/LJ_22050/LJ_test", 5 | "test_input_mels_dir":"./", 6 | "test_mel_load": 0, 7 | "test_output_dir":"output", 8 | 9 | "batch_size": 16, 10 | "learning_rate": 0.0002, 11 | "adam_b1": 0.8, 12 | "adam_b2": 0.99, 13 | "lr_decay": 0.999, 14 | "seed": 1234, 15 | "training_epochs": 3100, 16 | "stdout_interval":20, 17 | "checkpoint_interval": 1000, 18 | "summary_interval": 100, 19 | "validation_interval": 250, 20 | "checkpoint_path": "cp_APNet", 21 | "checkpoint_file_load": "cp_APNet/g_01000000", 22 | 23 | "ASP_channel": 512, 24 | "ASP_resblock_kernel_sizes": [3,7,11], 25 | "ASP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 26 | "ASP_input_conv_kernel_size": 7, 27 | "ASP_output_conv_kernel_size": 7, 28 | 29 | "PSP_channel": 512, 30 | "PSP_resblock_kernel_sizes": [3,7,11], 31 | "PSP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 32 | "PSP_input_conv_kernel_size": 7, 33 | "PSP_output_R_conv_kernel_size": 7, 34 | "PSP_output_I_conv_kernel_size": 7, 35 | 36 | "segment_size": 8192, 37 | "num_mels": 80, 38 | "n_fft": 1024, 39 | "hop_size": 256, 40 | "win_size": 1024, 41 | 42 | "sampling_rate": 22050, 43 | 44 | "fmin": 0, 45 | "fmax": 8000, 46 | "meloss":null, 47 | "num_workers": 4 48 | } 49 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | import shutil 9 | 10 | class AttrDict(dict): 11 | def __init__(self, *args, **kwargs): 12 | super(AttrDict, self).__init__(*args, **kwargs) 13 | self.__dict__ = self 14 | 15 | 16 | def build_env(config, config_name, path): 17 | t_path = os.path.join(path, config_name) 18 | if config != t_path: 19 | os.makedirs(path, exist_ok=True) 20 | shutil.copyfile(config, os.path.join(path, config_name)) 21 | 22 | def plot_spectrogram(spectrogram): 23 | fig, ax = plt.subplots(figsize=(10, 2)) 24 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 25 | interpolation='none') 26 | plt.colorbar(im, ax=ax) 27 | 28 | fig.canvas.draw() 29 | plt.close() 30 | 31 | return fig 32 | 33 | 34 | def init_weights(m, mean=0.0, std=0.01): 35 | classname = m.__class__.__name__ 36 | if classname.find("Conv") != -1: 37 | m.weight.data.normal_(mean, std) 38 | 39 | 40 | def apply_weight_norm(m): 41 | classname = m.__class__.__name__ 42 | if classname.find("Conv") != -1: 43 | weight_norm(m) 44 | 45 | 46 | def get_padding(kernel_size, dilation=1): 47 | return int((kernel_size*dilation - dilation)/2) 48 | 49 | 50 | def load_checkpoint(filepath, device): 51 | assert os.path.isfile(filepath) 52 | print("Loading '{}'".format(filepath)) 53 | checkpoint_dict = torch.load(filepath, map_location=device) 54 | print("Complete.") 55 | return checkpoint_dict 56 | 57 | 58 | def save_checkpoint(filepath, obj): 59 | print("Saving checkpoint to {}".format(filepath)) 60 | torch.save(obj, filepath) 61 | print("Complete.") 62 | 63 | 64 | def scan_checkpoint(cp_dir, prefix): 65 | pattern = os.path.join(cp_dir, prefix + '????????') 66 | cp_list = 
glob.glob(pattern) 67 | if len(cp_list) == 0: 68 | return None 69 | return sorted(cp_list)[-1] 70 | 71 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # APNet2: High-quality and High-efficiency Neural Vocoder with Direct Prediction of Amplitude and Phase Spectra 2 | ### Hui-Peng Du, Ye-Xin Lu, Yang Ai, Zhen-Hua Ling 3 | In our [paper](https://arxiv.org/pdf/2311.11545.pdf), we proposed APNet2: High-quality and High-efficiency Neural Vocoder with Direct Prediction of Amplitude and Phase Spectra.
4 | We provide our implementation as open source in this repository. 5 | 6 | **Abstract:** 7 | In our previous work, we proposed a neural vocoder called APNet, which directly predicts speech amplitude and phase spectra with a 5 ms frame shift in parallel from the input acoustic features, and then reconstructs the 16 kHz speech waveform using inverse short-time Fourier transform (ISTFT). 8 | APNet demonstrates the capability to generate synthesized speech of comparable quality to the HiFi-GAN vocoder but with a considerably improved inference speed. 9 | However, the performance of the APNet vocoder is constrained by the waveform sampling rate and spectral frame shift, limiting its practicality for high-quality speech synthesis. 10 | Therefore, this paper proposes an improved iteration of APNet, named APNet2. 11 | The proposed APNet2 vocoder adopts ConvNeXt v2 as the backbone network for amplitude and phase predictions, expecting to enhance the modeling capability. 12 | Additionally, we introduce a multi-resolution discriminator (MRD) into the GAN-based losses and optimize the form of certain losses. 13 | At a common configuration with a waveform sampling rate of 22.05 kHz and spectral frame shift of 256 points (i.e., approximately 11.6ms), our proposed APNet2 vocoder outperformed the original APNet and Vocos vocoders in terms of synthesized speech quality. 14 | The synthesized speech quality of APNet2 is also comparable to that of HiFi-GAN and iSTFTNet, while offering a significantly faster inference speed. 15 | 16 | Audio samples can be found [here](https://redmist328.github.io/APNet2_demo/).
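The core reconstruction step described above is straightforward: the predicted log-amplitude and phase spectra are combined into a complex short-time spectrum and inverted with ISTFT. Below is a minimal, self-contained sketch of that step (illustration only; the actual implementation lives in the `Generator` in `models.py`, and the STFT settings match `config.json`):

```python
import math
import torch

# Illustration only: reconstructing a waveform from log-amplitude and phase spectra.
# STFT settings taken from config.json (n_fft=1024, hop_size=256, win_size=1024).
n_fft, hop_size, win_size, frames = 1024, 256, 1024, 100
logamp = torch.randn(1, n_fft // 2 + 1, frames)                    # predicted log-amplitude spectrum
pha = (torch.rand(1, n_fft // 2 + 1, frames) * 2 - 1) * math.pi    # predicted phase spectrum

rea = torch.exp(logamp) * torch.cos(pha)    # real part
imag = torch.exp(logamp) * torch.sin(pha)   # imaginary part
# models.py stacks rea/imag into a (..., 2) tensor for the pinned torch 1.8;
# torch.complex is used here so the snippet also runs on newer PyTorch.
audio = torch.istft(torch.complex(rea, imag), n_fft, hop_length=hop_size,
                    win_length=win_size, window=torch.hann_window(win_size), center=True)
print(audio.shape)  # torch.Size([1, (frames - 1) * hop_size])
```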
17 | 
18 | ## Requirements
19 | Install the dependencies listed in [requirements.txt](https://github.com/redmist328/APNet2/blob/main/requirements.txt).
20 | 
21 | ## Training
22 | ```
23 | python train.py
24 | ```
25 | Checkpoints and a copy of the configuration file are saved in the `cp_APNet` directory by default.
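Before training, make sure the dataset entries in `config.json` point at your data. Despite the `_list` suffix, `input_training_wav_list` and `input_validation_wav_list` are directories of `.wav` files (see `get_dataset_filelist` in `dataset.py`); the defaults assume LJSpeech resampled to 22.05 kHz:

```json
"input_training_wav_list": "../../datasets/LJ_22050/LJ_train",
"input_validation_wav_list": "../../datasets/LJ_22050/LJ_val",
"test_input_wavs_dir": "../../datasets/LJ_22050/LJ_test"
```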
26 | You can modify the training and inference configuration by modifying the parameters in the [config.json](https://github.com/redmist328/APNet2/blob/main/config.json). 27 | ## Inference 28 | You can download pretrained model on LJSpeech dataset at [here](http://home.ustc.edu.cn/~redmist/APNet2/). 29 | ``` 30 | python inference.py 31 | ``` 32 | 33 | ## Model Structure 34 | ![model](./figure/2.png) 35 | 36 | ## Comparison with other models 37 | ![comparison](./figure/3.png) 38 | 39 | ## Acknowledgements 40 | We referred to [HiFiGAN](https://github.com/jik876/hifi-gan), [NSPP](https://github.com/YangAi520/NSPP), [APNet](https://github.com/YangAi520/APNet) 41 | and [Vocos](https://github.com/charactr-platform/vocos) to implement this. 42 | 43 | ## Citation 44 | ``` 45 | @article{du2023apnet2, 46 | title={APNet2: High-quality and High-efficiency Neural Vocoder with Direct Prediction of Amplitude and Phase Spectra}, 47 | author={Du, Hui-Peng and Lu, Ye-Xin and Ai, Yang and Ling, Zhen-Hua}, 48 | journal={arXiv preprint arXiv:2311.11545}, 49 | year={2023} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import glob 4 | import os 5 | import argparse 6 | import json 7 | import torch 8 | from utils import AttrDict 9 | from dataset import mel_spectrogram, load_wav 10 | from models import Generator 11 | import soundfile as sf 12 | import librosa 13 | import numpy as np 14 | import time 15 | h = None 16 | device = None 17 | 18 | 19 | def load_checkpoint(filepath, device): 20 | assert os.path.isfile(filepath) 21 | print("Loading '{}'".format(filepath)) 22 | checkpoint_dict = torch.load(filepath, map_location=device) 23 | print("Complete.") 24 | return checkpoint_dict 25 | 26 | 27 | def get_mel(x): 28 | return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) 29 | 30 | 31 | def scan_checkpoint(cp_dir, prefix): 32 | pattern = os.path.join(cp_dir, prefix + '*') 33 | cp_list = glob.glob(pattern) 34 | if len(cp_list) == 0: 35 | return '' 36 | return sorted(cp_list)[-1] 37 | 38 | 39 | def inference(h): 40 | generator = Generator(h).to(device) 41 | 42 | state_dict_g = load_checkpoint(h.checkpoint_file_load, device) 43 | generator.load_state_dict(state_dict_g['generator']) 44 | 45 | filelist = sorted(os.listdir(h.test_input_mels_dir if h.test_mel_load else h.test_input_wavs_dir)) 46 | 47 | os.makedirs(h.test_output_dir, exist_ok=True) 48 | 49 | generator.eval() 50 | l=0 51 | with torch.no_grad(): 52 | starttime = time.time() 53 | for i, filename in enumerate(filelist): 54 | 55 | # if h.test_mel_load: 56 | if 1: 57 | mel = np.load(os.path.join(h.test_input_wavs_dir, filename)) 58 | x = torch.FloatTensor(mel).to(device) 59 | x=x.transpose(1,2) 60 | else: 61 | raw_wav, _ = librosa.load(os.path.join(h.test_input_wavs_dir, filename), sr=h.sampling_rate, mono=True) 62 | raw_wav = torch.FloatTensor(raw_wav).to(device) 63 | x = get_mel(raw_wav.unsqueeze(0)) 64 | 65 | logamp_g, pha_g, _, _, y_g = generator(x) 66 | audio = y_g.squeeze() 67 | # logamp = logamp_g.squeeze() 68 | # pha = pha_g.squeeze() 69 | audio = audio.cpu().numpy() 70 | # logamp = logamp.cpu().numpy() 71 | # pha = pha.cpu().numpy() 72 | audiolen=len(audio) 73 | sf.write(os.path.join(h.test_output_dir, filename.split('.')[0]+'.wav'), audio, h.sampling_rate,'PCM_16') 
74 | 75 | # print(pp) 76 | l+=audiolen 77 | 78 | # write(output_file, h.sampling_rate, audio) 79 | # print(output_file) 80 | end=time.time() 81 | print(end-starttime) 82 | print(l/22050) 83 | print(l/22050/(end-starttime)) 84 | 85 | # np.save(os.path.join(h.test_output_dir, filename.split('.')[0]+'_logamp.npy'), logamp) 86 | # np.save(os.path.join(h.test_output_dir, filename.split('.')[0]+'_pha.npy'), pha) 87 | # if i==9: 88 | # break 89 | 90 | def main(): 91 | print('Initializing Inference Process..') 92 | 93 | config_file = 'config.json' 94 | 95 | with open(config_file) as f: 96 | data = f.read() 97 | 98 | global h 99 | json_config = json.loads(data) 100 | h = AttrDict(json_config) 101 | 102 | torch.manual_seed(h.seed) 103 | global device 104 | if torch.cuda.is_available(): 105 | torch.cuda.manual_seed(h.seed) 106 | device = torch.device('cuda') 107 | else: 108 | device = torch.device('cpu') 109 | device = torch.device('cpu') 110 | inference(h) 111 | 112 | 113 | if __name__ == '__main__': 114 | main() 115 | 116 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | import torch.utils.data 6 | import numpy as np 7 | from librosa.util import normalize 8 | from librosa.filters import mel as librosa_mel_fn 9 | import librosa 10 | import torchaudio 11 | import torch.nn as nn 12 | 13 | def load_wav(full_path, sample_rate): 14 | data, _ = librosa.load(full_path, sr=sample_rate, mono=True) 15 | return data 16 | 17 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 18 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 19 | 20 | def dynamic_range_decompression(x, C=1): 21 | return np.exp(x) / C 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | def dynamic_range_decompression_torch(x, C=1): 27 | return torch.exp(x) / C 28 | 29 | def spectral_normalize_torch(magnitudes): 30 | output = dynamic_range_compression_torch(magnitudes) 31 | return output 32 | 33 | def spectral_de_normalize_torch(magnitudes): 34 | output = dynamic_range_decompression_torch(magnitudes) 35 | return output 36 | 37 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=True): 38 | 39 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 40 | mel_basis = torch.from_numpy(mel).float().to(y.device) 41 | hann_window = torch.hann_window(win_size).to(y.device) 42 | 43 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=True) 44 | 45 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 46 | 47 | spec = torch.matmul(mel_basis, spec) 48 | spec = spectral_normalize_torch(spec) 49 | 50 | return spec #[batch_size,n_fft/2+1,frames] 51 | 52 | def amp_pha_specturm(y, n_fft, hop_size, win_size): 53 | 54 | hann_window=torch.hann_window(win_size).to(y.device) 55 | 56 | stft_spec=torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,center=True) #[batch_size, n_fft//2+1, frames, 2] 57 | 58 | rea=stft_spec[:,:,:,0] #[batch_size, n_fft//2+1, frames] 59 | imag=stft_spec[:,:,:,1] #[batch_size, n_fft//2+1, frames] 60 | 61 | log_amplitude=torch.log(torch.abs(torch.sqrt(torch.pow(rea,2)+torch.pow(imag,2)))+1e-5) #[batch_size, n_fft//2+1, frames] 62 | phase=torch.atan2(imag,rea) #[batch_size, n_fft//2+1, frames] 63 | 64 | return log_amplitude, 
phase, rea, imag 65 | 66 | def get_dataset_filelist(input_training_wav_list,input_validation_wav_list): 67 | training_files=[] 68 | filelist=os.listdir(input_training_wav_list) 69 | for files in filelist: 70 | 71 | src=os.path.join(input_training_wav_list,files) 72 | training_files.append(src) 73 | 74 | validation_files=[] 75 | filelist=os.listdir(input_validation_wav_list) 76 | for files in filelist: 77 | src=os.path.join(input_validation_wav_list,files) 78 | validation_files.append(src) 79 | 80 | return training_files, validation_files 81 | 82 | 83 | class Dataset(torch.utils.data.Dataset): 84 | def __init__(self, training_files, segment_size, n_fft, num_mels, 85 | hop_size, win_size, sampling_rate, fmin, fmax,meloss, split=True, shuffle=True, n_cache_reuse=1, 86 | device=None): 87 | self.audio_files = training_files 88 | random.seed(1234) 89 | if shuffle: 90 | random.shuffle(self.audio_files) 91 | self.segment_size = segment_size 92 | self.sampling_rate = sampling_rate 93 | self.split = split 94 | self.n_fft = n_fft 95 | self.num_mels = num_mels 96 | self.hop_size = hop_size 97 | self.win_size = win_size 98 | self.fmin = fmin 99 | self.fmax = fmax 100 | self.cached_wav = None 101 | self.n_cache_reuse = n_cache_reuse 102 | self._cache_ref_count = 0 103 | self.device = device 104 | self.meloss=meloss 105 | 106 | def __getitem__(self, index): 107 | filename = self.audio_files[index] 108 | if self._cache_ref_count == 0: 109 | audio = load_wav(filename, self.sampling_rate) 110 | self.cached_wav = audio 111 | self._cache_ref_count = self.n_cache_reuse 112 | else: 113 | audio = self.cached_wav 114 | self._cache_ref_count -= 1 115 | 116 | audio = torch.FloatTensor(audio) #[T] 117 | audio = audio.unsqueeze(0) #[1,T] 118 | 119 | if self.split: 120 | if audio.size(1) >= self.segment_size: 121 | max_audio_start = audio.size(1) - self.segment_size 122 | audio_start = random.randint(0, max_audio_start) 123 | audio = audio[:, audio_start: audio_start + self.segment_size] #[1,T] 124 | else: 125 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 126 | 127 | mel = mel_spectrogram(audio, self.n_fft, self.num_mels, 128 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax, 129 | center=True) 130 | meloss1 = mel_spectrogram(audio, self.n_fft, self.num_mels, 131 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.meloss, 132 | center=True) 133 | log_amplitude, phase, rea, imag = amp_pha_specturm(audio, self.n_fft, self.hop_size, self.win_size) #[1,n_fft/2+1,frames] 134 | 135 | 136 | return (mel.squeeze(), log_amplitude.squeeze(), phase.squeeze(), rea.squeeze(), imag.squeeze(), audio.squeeze(0),meloss1.squeeze()) 137 | 138 | def __len__(self): 139 | return len(self.audio_files) 140 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter(action='ignore', category=FutureWarning) 3 | import itertools 4 | import os 5 | import time 6 | import argparse 7 | import json 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch.utils.data import DistributedSampler, DataLoader 12 | import torch.multiprocessing as mp 13 | from torch.distributed import init_process_group 14 | from torch.nn.parallel import DistributedDataParallel 15 | from dataset import Dataset, mel_spectrogram, amp_pha_specturm, get_dataset_filelist 
16 | from models import Generator, MultiPeriodDiscriminator, feature_loss, generator_loss,\ 17 | discriminator_loss, amplitude_loss, phase_loss, STFT_consistency_loss,MultiResolutionDiscriminator 18 | from utils import AttrDict, build_env, plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint 19 | 20 | torch.backends.cudnn.benchmark = True 21 | 22 | 23 | def train(h): 24 | 25 | torch.cuda.manual_seed(h.seed) 26 | device = torch.device('cuda:{:d}'.format(0)) 27 | 28 | generator = Generator(h).to(device) 29 | mpd = MultiPeriodDiscriminator().to(device) 30 | mrd = MultiResolutionDiscriminator().to(device) 31 | 32 | print(generator) 33 | os.makedirs(h.checkpoint_path, exist_ok=True) 34 | print("checkpoints directory : ", h.checkpoint_path) 35 | 36 | if os.path.isdir(h.checkpoint_path): 37 | cp_g = scan_checkpoint(h.checkpoint_path, 'g_') 38 | cp_do = scan_checkpoint(h.checkpoint_path, 'do_') 39 | 40 | steps = 0 41 | if cp_g is None or cp_do is None: 42 | state_dict_do = None 43 | last_epoch = -1 44 | else: 45 | state_dict_g = load_checkpoint(cp_g, device) 46 | state_dict_do = load_checkpoint(cp_do, device) 47 | generator.load_state_dict(state_dict_g['generator']) 48 | mpd.load_state_dict(state_dict_do['mpd']) 49 | mrd.load_state_dict(state_dict_do['mrd']) 50 | steps = state_dict_do['steps'] + 1 51 | last_epoch = state_dict_do['epoch'] 52 | 53 | optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) 54 | optim_d = torch.optim.AdamW(itertools.chain(mrd.parameters(), mpd.parameters()), 55 | h.learning_rate, betas=[h.adam_b1, h.adam_b2]) 56 | 57 | if state_dict_do is not None: 58 | optim_g.load_state_dict(state_dict_do['optim_g']) 59 | optim_d.load_state_dict(state_dict_do['optim_d']) 60 | 61 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) 62 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) 63 | 64 | training_filelist, validation_filelist = get_dataset_filelist(h.input_training_wav_list, h.input_validation_wav_list) 65 | 66 | trainset = Dataset(training_filelist, h.segment_size, h.n_fft, h.num_mels, 67 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, h.meloss,n_cache_reuse=0, 68 | shuffle=True, device=device) 69 | 70 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False, 71 | sampler=None, 72 | batch_size=h.batch_size, 73 | pin_memory=True, 74 | drop_last=True) 75 | 76 | validset = Dataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels, 77 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax,h.meloss, False, False, n_cache_reuse=0, 78 | device=device) 79 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False, 80 | sampler=None, 81 | batch_size=1, 82 | pin_memory=True, 83 | drop_last=True) 84 | 85 | sw = SummaryWriter(os.path.join(h.checkpoint_path, 'logs')) 86 | 87 | generator.train() 88 | mpd.train() 89 | mrd.train() 90 | 91 | for epoch in range(max(0, last_epoch), h.training_epochs): 92 | 93 | start = time.time() 94 | print("Epoch: {}".format(epoch+1)) 95 | 96 | for i, batch in enumerate(train_loader): 97 | start_b = time.time() 98 | x, logamp, pha, rea, imag, y,meloss = batch 99 | x = torch.autograd.Variable(x.to(device, non_blocking=True)) 100 | y = torch.autograd.Variable(y.to(device, non_blocking=True)) 101 | logamp = torch.autograd.Variable(logamp.to(device, non_blocking=True)) 102 | pha = torch.autograd.Variable(pha.to(device, non_blocking=True)) 103 | rea 
= torch.autograd.Variable(rea.to(device, non_blocking=True)) 104 | imag = torch.autograd.Variable(imag.to(device, non_blocking=True)) 105 | y = y.unsqueeze(1) 106 | meloss = torch.autograd.Variable(meloss.to(device, non_blocking=True)) 107 | logamp_g, pha_g, rea_g, imag_g, y_g = generator(x) 108 | y_g_mel = mel_spectrogram(y_g.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, 109 | h.fmin, h.meloss) 110 | 111 | optim_d.zero_grad() 112 | 113 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g.detach()) 114 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) 115 | 116 | y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g.detach()) 117 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) 118 | 119 | L_D = loss_disc_s*0.1 + loss_disc_f 120 | 121 | L_D.backward() 122 | optim_d.step() 123 | 124 | # Generator 125 | optim_g.zero_grad() 126 | 127 | # Losses defined on log amplitude spectra 128 | L_A = amplitude_loss(logamp, logamp_g) 129 | 130 | L_IP, L_GD, L_PTD = phase_loss(pha, pha_g, h.n_fft, pha.size()[-1]) 131 | # Losses defined on phase spectra 132 | L_P = L_IP + L_GD + L_PTD 133 | 134 | _, _, rea_g_final, imag_g_final = amp_pha_specturm(y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size) 135 | L_C = STFT_consistency_loss(rea_g, rea_g_final, imag_g, imag_g_final) 136 | L_R = F.l1_loss(rea, rea_g) 137 | L_I = F.l1_loss(imag, imag_g) 138 | # Losses defined on reconstructed STFT spectra 139 | L_S = L_C + 2.25 * (L_R + L_I) 140 | 141 | y_df_r, y_df_g, fmap_f_r, fmap_f_g = mpd(y, y_g) 142 | y_ds_r, y_ds_g, fmap_s_r, fmap_s_g = mrd(y, y_g) 143 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 144 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) 145 | loss_gen_f, losses_gen_f = generator_loss(y_df_g) 146 | loss_gen_s, losses_gen_s = generator_loss(y_ds_g) 147 | L_GAN_G = loss_gen_s *0.1+ loss_gen_f 148 | L_FM = loss_fm_s *0.1+ loss_fm_f 149 | L_Mel = F.l1_loss(meloss, y_g_mel) 150 | # Losses defined on final waveforms 151 | L_W = L_GAN_G + L_FM + 45 * L_Mel 152 | 153 | L_G = 45 * L_A + 100 * L_P + 20 * L_S + L_W 154 | 155 | L_G.backward() 156 | optim_g.step() 157 | 158 | # STDOUT logging 159 | if steps % h.stdout_interval == 0: 160 | with torch.no_grad(): 161 | A_error = amplitude_loss(logamp, logamp_g).item() 162 | IP_error, GD_error, PTD_error = phase_loss(pha, pha_g, h.n_fft, pha.size()[-1]) 163 | IP_error = IP_error.item() 164 | GD_error = GD_error.item() 165 | PTD_error = PTD_error.item() 166 | C_error = STFT_consistency_loss(rea_g, rea_g_final, imag_g, imag_g_final).item() 167 | R_error = F.l1_loss(rea, rea_g).item() 168 | I_error = F.l1_loss(imag, imag_g).item() 169 | Mel_error = F.l1_loss(x, y_g_mel).item() 170 | 171 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, Amplitude Loss : {:4.3f}, Instantaneous Phase Loss : {:4.3f}, Group Delay Loss : {:4.3f}, Phase Time Difference Loss : {:4.3f}, STFT Consistency Loss : {:4.3f}, Real Part Loss : {:4.3f}, Imaginary Part Loss : {:4.3f}, Mel Spectrogram Loss : {:4.3f}, s/b : {:4.3f}'. 
172 | format(steps, L_G, A_error, IP_error, GD_error, PTD_error, C_error, R_error, I_error, Mel_error, time.time() - start_b)) 173 | 174 | # checkpointing 175 | if steps % h.checkpoint_interval == 0 and steps != 0: 176 | checkpoint_path = "{}/g_{:08d}".format(h.checkpoint_path, steps) 177 | save_checkpoint(checkpoint_path, 178 | {'generator': generator.state_dict()}) 179 | checkpoint_path = "{}/do_{:08d}".format(h.checkpoint_path, steps) 180 | save_checkpoint(checkpoint_path, 181 | {'mpd': mpd.state_dict(), 182 | 'mrd': mrd.state_dict(), 183 | 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, 184 | 'epoch': epoch}) 185 | 186 | # Tensorboard summary logging 187 | if steps % h.summary_interval == 0: 188 | sw.add_scalar("Training/Generator_Total_Loss", L_G, steps) 189 | sw.add_scalar("Training/Mel_Spectrogram_Loss", Mel_error, steps) 190 | 191 | # Validation 192 | if steps % h.validation_interval == 0: # and steps != 0: 193 | generator.eval() 194 | torch.cuda.empty_cache() 195 | val_A_err_tot = 0 196 | val_IP_err_tot = 0 197 | val_GD_err_tot = 0 198 | val_PTD_err_tot = 0 199 | val_C_err_tot = 0 200 | val_R_err_tot = 0 201 | val_I_err_tot = 0 202 | val_Mel_err_tot = 0 203 | with torch.no_grad(): 204 | for j, batch in enumerate(validation_loader): 205 | x, logamp, pha, rea, imag, y ,meloss= batch 206 | logamp_g, pha_g, rea_g, imag_g, y_g = generator(x.to(device)) 207 | mel = x 208 | mel = torch.autograd.Variable(mel.to(device, non_blocking=True)) 209 | logamp = torch.autograd.Variable(logamp.to(device, non_blocking=True)) 210 | pha = torch.autograd.Variable(pha.to(device, non_blocking=True)) 211 | rea = torch.autograd.Variable(rea.to(device, non_blocking=True)) 212 | imag = torch.autograd.Variable(imag.to(device, non_blocking=True)) 213 | meloss = torch.autograd.Variable(meloss.to(device, non_blocking=True)) 214 | y_g_mel = mel_spectrogram(y_g.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,h.hop_size, h.win_size,h.fmin, h.meloss) 215 | 216 | _, _, rea_g_final, imag_g_final = amp_pha_specturm(y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size) 217 | val_A_err_tot += amplitude_loss(logamp, logamp_g).item() 218 | val_IP_err, val_GD_err, val_PTD_err = phase_loss(pha, pha_g, h.n_fft, pha.size()[-1]) 219 | val_IP_err_tot += val_IP_err.item() 220 | val_GD_err_tot += val_GD_err.item() 221 | val_PTD_err_tot += val_PTD_err.item() 222 | val_C_err_tot += STFT_consistency_loss(rea_g, rea_g_final, imag_g, imag_g_final).item() 223 | val_R_err_tot += F.l1_loss(rea, rea_g).item() 224 | val_I_err_tot += F.l1_loss(imag, imag_g).item() 225 | val_Mel_err_tot += F.l1_loss(meloss, y_g_mel).item() 226 | 227 | # if j <= 4: 228 | # if steps == 0: 229 | # sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate) 230 | # sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps) 231 | 232 | # sw.add_audio('generated/y_g_{}'.format(j), y_g[0], steps, h.sampling_rate) 233 | # y_g_spec = mel_spectrogram(y_g.squeeze(1), h.n_fft, h.num_mels, 234 | # h.sampling_rate, h.hop_size, h.win_size, 235 | # h.fmin, h.fmax) 236 | # sw.add_figure('generated/y_g_spec_{}'.format(j), 237 | # plot_spectrogram(y_g_spec.squeeze(0).cpu().numpy()), steps) 238 | 239 | val_A_err = val_A_err_tot / (j+1) 240 | val_IP_err = val_IP_err_tot / (j+1) 241 | val_GD_err = val_GD_err_tot / (j+1) 242 | val_PTD_err = val_PTD_err_tot / (j+1) 243 | val_C_err = val_C_err_tot / (j+1) 244 | val_R_err = val_R_err_tot / (j+1) 245 | val_I_err = val_I_err_tot / (j+1) 246 | val_Mel_err = val_Mel_err_tot / (j+1) 247 | 
sw.add_scalar("Validation/Amplitude_Loss", val_A_err, steps) 248 | sw.add_scalar("Validation/Instantaneous_Phase_Loss", val_IP_err, steps) 249 | sw.add_scalar("Validation/Group_Delay_Loss", val_GD_err, steps) 250 | sw.add_scalar("Validation/Phase_Time_Difference_Loss", val_PTD_err, steps) 251 | sw.add_scalar("Validation/STFT_Consistency_Loss", val_C_err, steps) 252 | sw.add_scalar("Validation/Real_Part_Loss", val_R_err, steps) 253 | sw.add_scalar("Validation/Imaginary_Part_Loss", val_I_err, steps) 254 | sw.add_scalar("Validation/Mel_Spectrogram_loss", val_Mel_err, steps) 255 | 256 | generator.train() 257 | 258 | steps += 1 259 | 260 | scheduler_g.step() 261 | scheduler_d.step() 262 | 263 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start))) 264 | 265 | 266 | def main(): 267 | print('Initializing Training Process..') 268 | 269 | config_file = 'config.json' 270 | 271 | with open(config_file) as f: 272 | data = f.read() 273 | 274 | json_config = json.loads(data) 275 | h = AttrDict(json_config) 276 | build_env(config_file, 'config.json', h.checkpoint_path) 277 | 278 | torch.manual_seed(h.seed) 279 | if torch.cuda.is_available(): 280 | torch.cuda.manual_seed(h.seed) 281 | else: 282 | pass 283 | 284 | train(h) 285 | 286 | 287 | if __name__ == '__main__': 288 | main() 289 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, spectral_norm 6 | from utils import init_weights, get_padding 7 | import numpy as np 8 | LRELU_SLOPE = 0.1 9 | 10 | 11 | class GRN(nn.Module): 12 | """ GRN (Global Response Normalization) layer 13 | """ 14 | def __init__(self, dim): 15 | super().__init__() 16 | self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) 17 | self.beta = nn.Parameter(torch.zeros(1, 1, dim)) 18 | 19 | def forward(self, x): 20 | Gx = torch.norm(x, p=2, dim=1, keepdim=True) 21 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 22 | return self.gamma * (x * Nx) + self.beta + x 23 | 24 | class ConvNeXtBlock(nn.Module): 25 | def __init__( 26 | self, 27 | dim: int, 28 | intermediate_dim: int, 29 | layer_scale_init_value= None, 30 | adanorm_num_embeddings = None, 31 | ): 32 | super().__init__() 33 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 34 | self.adanorm = adanorm_num_embeddings is not None 35 | 36 | self.norm = nn.LayerNorm(dim, eps=1e-6) 37 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 38 | self.act = nn.GELU() 39 | self.grn = GRN(intermediate_dim) 40 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 41 | 42 | def forward(self, x, cond_embedding_id = None) : 43 | residual = x 44 | x = self.dwconv(x) 45 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 46 | if self.adanorm: 47 | assert cond_embedding_id is not None 48 | x = self.norm(x, cond_embedding_id) 49 | else: 50 | x = self.norm(x) 51 | x = self.pwconv1(x) 52 | x = self.act(x) 53 | x = self.grn(x) 54 | x = self.pwconv2(x) 55 | 56 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 57 | 58 | x = residual + x 59 | return x 60 | class Generator(torch.nn.Module): 61 | def __init__(self, h): 62 | super(Generator, self).__init__() 63 | self.h = h 64 | self.ASP_num_kernels = len(h.ASP_resblock_kernel_sizes) 65 | 
self.PSP_num_kernels = len(h.PSP_resblock_kernel_sizes) 66 | 67 | self.ASP_input_conv = Conv1d(h.num_mels, h.ASP_channel, h.ASP_input_conv_kernel_size, 1, 68 | padding=get_padding(h.ASP_input_conv_kernel_size, 1)) 69 | self.PSP_input_conv = Conv1d(h.num_mels, h.PSP_channel, h.PSP_input_conv_kernel_size, 1, 70 | padding=get_padding(h.PSP_input_conv_kernel_size, 1)) 71 | 72 | self.ASP_output_conv = Conv1d(h.ASP_channel, h.n_fft//2+1, h.ASP_output_conv_kernel_size, 1, 73 | padding=get_padding(h.ASP_output_conv_kernel_size, 1)) 74 | self.PSP_output_R_conv = Conv1d(512, h.n_fft//2+1, h.PSP_output_R_conv_kernel_size, 1, 75 | padding=get_padding(h.PSP_output_R_conv_kernel_size, 1)) 76 | self.PSP_output_I_conv = Conv1d(512, h.n_fft//2+1, h.PSP_output_I_conv_kernel_size, 1, 77 | padding=get_padding(h.PSP_output_I_conv_kernel_size, 1)) 78 | 79 | self.dim=512 80 | self.num_layers=8 81 | self.adanorm_num_embeddings=None 82 | self.intermediate_dim=1536 83 | self.norm = nn.LayerNorm(self.dim, eps=1e-6) 84 | self.norm2 = nn.LayerNorm(self.dim, eps=1e-6) 85 | layer_scale_init_value = 1 / self.num_layers 86 | self.convnext = nn.ModuleList( 87 | [ 88 | ConvNeXtBlock( 89 | dim=self.dim, 90 | intermediate_dim=self.intermediate_dim, 91 | layer_scale_init_value=layer_scale_init_value, 92 | adanorm_num_embeddings=self.adanorm_num_embeddings, 93 | ) 94 | for _ in range(self.num_layers) 95 | ] 96 | ) 97 | self.convnext2 = nn.ModuleList( 98 | [ 99 | ConvNeXtBlock( 100 | dim=self.dim, 101 | intermediate_dim=self.intermediate_dim, 102 | layer_scale_init_value=layer_scale_init_value, 103 | adanorm_num_embeddings=self.adanorm_num_embeddings, 104 | ) 105 | for _ in range(self.num_layers) 106 | ] 107 | ) 108 | self.final_layer_norm = nn.LayerNorm(self.dim, eps=1e-6) 109 | self.final_layer_norm2 = nn.LayerNorm(self.dim, eps=1e-6) 110 | self.apply(self._init_weights) 111 | 112 | def _init_weights(self, m): 113 | if isinstance(m, (nn.Conv1d, nn.Linear)): 114 | nn.init.trunc_normal_(m.weight, std=0.02) 115 | nn.init.constant_(m.bias, 0) 116 | 117 | def forward(self, mel): 118 | 119 | logamp = self.ASP_input_conv(mel) 120 | logamp = self.norm2(logamp.transpose(1, 2)) 121 | logamp = logamp.transpose(1, 2) 122 | for conv_block in self.convnext2: 123 | logamp = conv_block(logamp, cond_embedding_id=None) 124 | logamp = self.final_layer_norm2(logamp.transpose(1, 2)) 125 | logamp = logamp.transpose(1, 2) 126 | logamp = self.ASP_output_conv(logamp) 127 | 128 | 129 | pha = self.PSP_input_conv(mel) 130 | pha = self.norm(pha.transpose(1, 2)) 131 | pha = pha.transpose(1, 2) 132 | for conv_block in self.convnext: 133 | pha = conv_block(pha, cond_embedding_id=None) 134 | pha = self.final_layer_norm(pha.transpose(1, 2)) 135 | pha = pha.transpose(1, 2) 136 | R = self.PSP_output_R_conv(pha) 137 | I = self.PSP_output_I_conv(pha) 138 | 139 | pha = torch.atan2(I,R) 140 | 141 | rea = torch.exp(logamp)*torch.cos(pha) 142 | imag = torch.exp(logamp)*torch.sin(pha) 143 | 144 | spec = torch.cat((rea.unsqueeze(-1),imag.unsqueeze(-1)),-1) 145 | 146 | audio = torch.istft(spec, self.h.n_fft, hop_length=self.h.hop_size, win_length=self.h.win_size, window=torch.hann_window(self.h.win_size).to(mel.device), center=True) 147 | 148 | return logamp, pha, rea, imag, audio.unsqueeze(1) 149 | 150 | class DiscriminatorP(torch.nn.Module): 151 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 152 | super(DiscriminatorP, self).__init__() 153 | self.period = period 154 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 
155 | self.convs = nn.ModuleList([ 156 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 157 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 158 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 159 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 160 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 161 | ]) 162 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 163 | 164 | def forward(self, x): 165 | fmap = [] 166 | 167 | # 1d to 2d 168 | b, c, t = x.shape 169 | if t % self.period != 0: # pad first 170 | n_pad = self.period - (t % self.period) 171 | x = F.pad(x, (0, n_pad), "reflect") 172 | t = t + n_pad 173 | x = x.view(b, c, t // self.period, self.period) 174 | 175 | for l in self.convs: 176 | x = l(x) 177 | x = F.leaky_relu(x, LRELU_SLOPE) 178 | fmap.append(x) 179 | x = self.conv_post(x) 180 | fmap.append(x) 181 | x = torch.flatten(x, 1, -1) 182 | 183 | return x, fmap 184 | 185 | 186 | class MultiPeriodDiscriminator(torch.nn.Module): 187 | def __init__(self): 188 | super(MultiPeriodDiscriminator, self).__init__() 189 | self.discriminators = nn.ModuleList([ 190 | DiscriminatorP(2), 191 | DiscriminatorP(3), 192 | DiscriminatorP(5), 193 | DiscriminatorP(7), 194 | DiscriminatorP(11), 195 | ]) 196 | 197 | def forward(self, y, y_hat): 198 | y_d_rs = [] 199 | y_d_gs = [] 200 | fmap_rs = [] 201 | fmap_gs = [] 202 | for i, d in enumerate(self.discriminators): 203 | y_d_r, fmap_r = d(y) 204 | y_d_g, fmap_g = d(y_hat) 205 | y_d_rs.append(y_d_r) 206 | fmap_rs.append(fmap_r) 207 | y_d_gs.append(y_d_g) 208 | fmap_gs.append(fmap_g) 209 | 210 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 211 | 212 | def phase_loss(phase_r, phase_g, n_fft, frames): 213 | 214 | MSELoss = torch.nn.MSELoss() 215 | 216 | GD_matrix = torch.triu(torch.ones(n_fft//2+1,n_fft//2+1),diagonal=1)-torch.triu(torch.ones(n_fft//2+1,n_fft//2+1),diagonal=2)-torch.eye(n_fft//2+1) 217 | GD_matrix = GD_matrix.to(phase_g.device) 218 | 219 | GD_r = torch.matmul(phase_r.permute(0,2,1), GD_matrix) 220 | GD_g = torch.matmul(phase_g.permute(0,2,1), GD_matrix) 221 | 222 | PTD_matrix = torch.triu(torch.ones(frames,frames),diagonal=1)-torch.triu(torch.ones(frames,frames),diagonal=2)-torch.eye(frames) 223 | PTD_matrix = PTD_matrix.to(phase_g.device) 224 | 225 | PTD_r = torch.matmul(phase_r, PTD_matrix) 226 | PTD_g = torch.matmul(phase_g, PTD_matrix) 227 | 228 | IP_loss = torch.mean(anti_wrapping_function(phase_r-phase_g)) 229 | GD_loss = torch.mean(anti_wrapping_function(GD_r-GD_g)) 230 | PTD_loss = torch.mean(anti_wrapping_function(PTD_r-PTD_g)) 231 | 232 | 233 | return IP_loss, GD_loss, PTD_loss 234 | class MultiResolutionDiscriminator(nn.Module): 235 | def __init__( 236 | self, 237 | resolutions= ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)), 238 | num_embeddings: int = None, 239 | ): 240 | super().__init__() 241 | self.discriminators = nn.ModuleList( 242 | [DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions] 243 | ) 244 | 245 | def forward( 246 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None 247 | ) : 248 | y_d_rs = [] 249 | y_d_gs = [] 250 | fmap_rs = [] 251 | fmap_gs = [] 252 | 253 | for d in self.discriminators: 254 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) 255 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) 256 | y_d_rs.append(y_d_r) 257 | 
fmap_rs.append(fmap_r) 258 | y_d_gs.append(y_d_g) 259 | fmap_gs.append(fmap_g) 260 | 261 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 262 | 263 | 264 | class DiscriminatorR(nn.Module): 265 | def __init__( 266 | self, 267 | resolution, 268 | channels: int = 64, 269 | in_channels: int = 1, 270 | num_embeddings: int = None, 271 | lrelu_slope: float = 0.1, 272 | ): 273 | super().__init__() 274 | self.resolution = resolution 275 | self.in_channels = in_channels 276 | self.lrelu_slope = lrelu_slope 277 | self.convs = nn.ModuleList( 278 | [ 279 | weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))), 280 | weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))), 281 | weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))), 282 | weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)), 283 | weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)), 284 | ] 285 | ) 286 | if num_embeddings is not None: 287 | self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) 288 | torch.nn.init.zeros_(self.emb.weight) 289 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1))) 290 | 291 | def forward( 292 | self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None) : 293 | fmap = [] 294 | x=x.squeeze(1) 295 | 296 | x = self.spectrogram(x) 297 | x = x.unsqueeze(1) 298 | for l in self.convs: 299 | x = l(x) 300 | x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) 301 | fmap.append(x) 302 | if cond_embedding_id is not None: 303 | emb = self.emb(cond_embedding_id) 304 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) 305 | else: 306 | h = 0 307 | x = self.conv_post(x) 308 | fmap.append(x) 309 | x += h 310 | x = torch.flatten(x, 1, -1) 311 | 312 | return x, fmap 313 | 314 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor: 315 | n_fft, hop_length, win_length = self.resolution 316 | magnitude_spectrogram = torch.stft( 317 | x, 318 | n_fft=n_fft, 319 | hop_length=hop_length, 320 | win_length=win_length, 321 | window=None, # interestingly rectangular window kind of works here 322 | center=True, 323 | return_complex=True, 324 | ).abs() 325 | 326 | return magnitude_spectrogram 327 | 328 | def anti_wrapping_function(x): 329 | return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi) 330 | 331 | def amplitude_loss(log_amplitude_r, log_amplitude_g): 332 | 333 | MSELoss = torch.nn.MSELoss() 334 | 335 | amplitude_loss = MSELoss(log_amplitude_r, log_amplitude_g) 336 | 337 | return amplitude_loss 338 | 339 | 340 | def feature_loss(fmap_r, fmap_g): 341 | loss = 0 342 | for dr, dg in zip(fmap_r, fmap_g): 343 | for rl, gl in zip(dr, dg): 344 | loss += torch.mean(torch.abs(rl - gl)) 345 | 346 | return loss 347 | 348 | 349 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 350 | loss = 0 351 | r_losses = [] 352 | g_losses = [] 353 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 354 | r_loss = torch.mean(torch.clamp(1 - dr, min=0)) 355 | g_loss = torch.mean(torch.clamp(1 + dg, min=0)) 356 | loss += r_loss + g_loss 357 | r_losses.append(r_loss.item()) 358 | g_losses.append(g_loss.item()) 359 | 360 | return loss, r_losses, g_losses 361 | 362 | 363 | def generator_loss(disc_outputs): 364 | loss = 0 365 | gen_losses = [] 366 | for dg in disc_outputs: 367 | l = torch.mean(torch.clamp(1 - dg, min=0)) 368 | gen_losses.append(l) 369 | loss += l 370 | 371 | 
return loss, gen_losses
372 | 
373 | 
374 | def STFT_consistency_loss(rea_r, rea_g, imag_r, imag_g):
375 |     # Mean squared error between the predicted real/imaginary spectra and those re-derived from the reconstructed waveform (STFT consistency).
376 |     C_loss = torch.mean(torch.mean((rea_r - rea_g) ** 2 + (imag_r - imag_g) ** 2, (1, 2)))
377 | 
378 |     return C_loss
379 | 
380 | 
--------------------------------------------------------------------------------
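For reference, a minimal smoke test of the `Generator` defined in `models.py` (not part of the repository; assumes `config.json` is in the working directory and the pinned `torch==1.8.1` from `requirements.txt`, since `torch.istft` on a real-valued `(..., 2)` tensor is rejected by newer PyTorch):

```python
# Minimal smoke test for the APNet2 Generator (sketch, not repository code).
import json
import torch
from models import Generator
from utils import AttrDict

with open("config.json") as f:
    h = AttrDict(json.loads(f.read()))

generator = Generator(h).eval()
mel = torch.randn(1, h.num_mels, 40)        # [batch, num_mels, frames]
with torch.no_grad():
    logamp, pha, rea, imag, audio = generator(mel)

print(logamp.shape, pha.shape)  # spectra: [1, n_fft//2 + 1, frames]
print(audio.shape)              # waveform: [1, 1, (frames - 1) * hop_size]
```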