├── figure
│   ├── 1.png
│   ├── 2.png
│   └── 3.png
├── requirements.txt
├── env.py
├── LICENSE
├── config.json
├── utils.py
├── README.md
├── inference.py
├── dataset.py
├── train.py
└── models.py
/figure/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/redmist328/APNet2/HEAD/figure/1.png
--------------------------------------------------------------------------------
/figure/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/redmist328/APNet2/HEAD/figure/2.png
--------------------------------------------------------------------------------
/figure/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/redmist328/APNet2/HEAD/figure/3.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.1+cu111
2 | numpy==1.21.6
3 | librosa==0.9.1
4 | tensorboard==2.8.0
5 | soundfile==0.10.3
6 | matplotlib==3.1.3
--------------------------------------------------------------------------------
/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 | t_path = os.path.join(path, config_name)
13 | if config != t_path:
14 | os.makedirs(path, exist_ok=True)
15 | shutil.copyfile(config, os.path.join(path, config_name))
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 redmist
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "input_training_wav_list": "../../datasets/LJ_22050/LJ_train",
3 | "input_validation_wav_list": "../../datasets/LJ_22050/LJ_val",
4 |     "test_input_wavs_dir": "../../datasets/LJ_22050/LJ_test",
5 |     "test_input_mels_dir": "./",
6 |     "test_mel_load": 0,
7 |     "test_output_dir": "output",
8 |
9 | "batch_size": 16,
10 | "learning_rate": 0.0002,
11 | "adam_b1": 0.8,
12 | "adam_b2": 0.99,
13 | "lr_decay": 0.999,
14 | "seed": 1234,
15 | "training_epochs": 3100,
16 |     "stdout_interval": 20,
17 | "checkpoint_interval": 1000,
18 | "summary_interval": 100,
19 | "validation_interval": 250,
20 | "checkpoint_path": "cp_APNet",
21 | "checkpoint_file_load": "cp_APNet/g_01000000",
22 |
23 | "ASP_channel": 512,
24 | "ASP_resblock_kernel_sizes": [3,7,11],
25 | "ASP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
26 | "ASP_input_conv_kernel_size": 7,
27 | "ASP_output_conv_kernel_size": 7,
28 |
29 | "PSP_channel": 512,
30 | "PSP_resblock_kernel_sizes": [3,7,11],
31 | "PSP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
32 | "PSP_input_conv_kernel_size": 7,
33 | "PSP_output_R_conv_kernel_size": 7,
34 | "PSP_output_I_conv_kernel_size": 7,
35 |
36 | "segment_size": 8192,
37 | "num_mels": 80,
38 | "n_fft": 1024,
39 | "hop_size": 256,
40 | "win_size": 1024,
41 |
42 | "sampling_rate": 22050,
43 |
44 | "fmin": 0,
45 | "fmax": 8000,
46 |     "meloss": null,
47 | "num_workers": 4
48 | }
49 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import matplotlib
4 | import torch
5 | from torch.nn.utils import weight_norm
6 | matplotlib.use("Agg")
7 | import matplotlib.pyplot as plt
8 | import shutil
9 |
10 | class AttrDict(dict):
11 | def __init__(self, *args, **kwargs):
12 | super(AttrDict, self).__init__(*args, **kwargs)
13 | self.__dict__ = self
14 |
15 |
16 | def build_env(config, config_name, path):
17 | t_path = os.path.join(path, config_name)
18 | if config != t_path:
19 | os.makedirs(path, exist_ok=True)
20 | shutil.copyfile(config, os.path.join(path, config_name))
21 |
22 | def plot_spectrogram(spectrogram):
23 | fig, ax = plt.subplots(figsize=(10, 2))
24 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
25 | interpolation='none')
26 | plt.colorbar(im, ax=ax)
27 |
28 | fig.canvas.draw()
29 |     plt.close(fig)
30 |
31 | return fig
32 |
33 |
34 | def init_weights(m, mean=0.0, std=0.01):
35 | classname = m.__class__.__name__
36 | if classname.find("Conv") != -1:
37 | m.weight.data.normal_(mean, std)
38 |
39 |
40 | def apply_weight_norm(m):
41 | classname = m.__class__.__name__
42 | if classname.find("Conv") != -1:
43 | weight_norm(m)
44 |
45 |
46 | def get_padding(kernel_size, dilation=1):
47 | return int((kernel_size*dilation - dilation)/2)
48 |
49 |
50 | def load_checkpoint(filepath, device):
51 | assert os.path.isfile(filepath)
52 | print("Loading '{}'".format(filepath))
53 | checkpoint_dict = torch.load(filepath, map_location=device)
54 | print("Complete.")
55 | return checkpoint_dict
56 |
57 |
58 | def save_checkpoint(filepath, obj):
59 | print("Saving checkpoint to {}".format(filepath))
60 | torch.save(obj, filepath)
61 | print("Complete.")
62 |
63 |
64 | def scan_checkpoint(cp_dir, prefix):
65 | pattern = os.path.join(cp_dir, prefix + '????????')
66 | cp_list = glob.glob(pattern)
67 | if len(cp_list) == 0:
68 | return None
69 | return sorted(cp_list)[-1]
70 |
71 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # APNet2: High-quality and High-efficiency Neural Vocoder with Direct Prediction of Amplitude and Phase Spectra
2 | ### Hui-Peng Du, Ye-Xin Lu, Yang Ai, Zhen-Hua Ling
3 | In our [paper](https://arxiv.org/pdf/2311.11545.pdf), we proposed APNet2: High-quality and High-efficiency Neural Vocoder with Direct Prediction of Amplitude and Phase Spectra.
4 | We provide our implementation as open source in this repository.
5 |
6 | **Abstract:**
7 | In our previous work, we proposed a neural vocoder called APNet, which directly predicts speech amplitude and phase spectra with a 5 ms frame shift in parallel from the input acoustic features, and then reconstructs the 16 kHz speech waveform using inverse short-time Fourier transform (ISTFT).
8 | APNet demonstrates the capability to generate synthesized speech of comparable quality to the HiFi-GAN vocoder but with a considerably improved inference speed.
9 | However, the performance of the APNet vocoder is constrained by the waveform sampling rate and spectral frame shift, limiting its practicality for high-quality speech synthesis.
10 | Therefore, this paper proposes an improved iteration of APNet, named APNet2.
11 | The proposed APNet2 vocoder adopts ConvNeXt v2 as the backbone network for amplitude and phase predictions, expecting to enhance the modeling capability.
12 | Additionally, we introduce a multi-resolution discriminator (MRD) into the GAN-based losses and optimize the form of certain losses.
13 | Under a common configuration with a waveform sampling rate of 22.05 kHz and a spectral frame shift of 256 points (i.e., approximately 11.6 ms), our proposed APNet2 vocoder outperformed the original APNet and Vocos vocoders in terms of synthesized speech quality.
14 | The synthesized speech quality of APNet2 is also comparable to that of HiFi-GAN and iSTFTNet, while offering a significantly faster inference speed.
15 |
16 | Audio samples can be found [here](https://redmist328.github.io/APNet2_demo/).
17 |
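As the abstract describes, the vocoder predicts a log-amplitude spectrum and a phase spectrum and reconstructs the waveform with ISTFT. The toy sketch below illustrates only that reconstruction step; it mirrors `Generator.forward` in `models.py` but uses `torch.complex` for clarity, and the tensor shapes are illustrative.
```
import torch

n_fft, hop_size, win_size = 1024, 256, 1024
frames = 32                                        # illustrative number of frames
logamp = torch.zeros(1, n_fft // 2 + 1, frames)    # predicted log-amplitude spectrum
pha = torch.zeros(1, n_fft // 2 + 1, frames)       # predicted (wrapped) phase spectrum

# Combine amplitude and phase into a complex spectrum, then invert it.
rea = torch.exp(logamp) * torch.cos(pha)
imag = torch.exp(logamp) * torch.sin(pha)
spec = torch.complex(rea, imag)
audio = torch.istft(spec, n_fft, hop_length=hop_size, win_length=win_size,
                    window=torch.hann_window(win_size), center=True)
print(audio.shape)  # torch.Size([1, 7936]); with center=True the length is hop_size * (frames - 1)
```
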
18 | ## Requirements
19 | Install the dependencies listed in [requirements.txt](https://github.com/redmist328/APNet2/blob/main/requirements.txt).
20 |
21 | ## Training
22 | ```
23 | python train.py
24 | ```
25 | Checkpoints and a copy of the configuration file are saved in the `cp_APNet` directory by default.
26 | You can adjust the training and inference configuration by editing the parameters in [config.json](https://github.com/redmist328/APNet2/blob/main/config.json); the data paths you will most likely need to change are shown below.
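For example, these entries in the shipped `config.json` point at the authors' LJSpeech layout; replace them with your own directories of `.wav` files (note that despite the `_list` suffix, `dataset.py` treats them as directories and scans them with `os.listdir`):
```
"input_training_wav_list": "../../datasets/LJ_22050/LJ_train",
"input_validation_wav_list": "../../datasets/LJ_22050/LJ_val",
"checkpoint_path": "cp_APNet",
"sampling_rate": 22050
```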
27 | ## Inference
28 | You can download a pretrained model trained on the LJSpeech dataset [here](http://home.ustc.edu.cn/~redmist/APNet2/). Inference is also controlled by `config.json`: `test_input_wavs_dir` (or `test_input_mels_dir` when `test_mel_load` is 1), `checkpoint_file_load`, and `test_output_dir`.
29 | ```
30 | python inference.py
31 | ```
32 |
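If you want to call the vocoder from your own code rather than through `inference.py`, a minimal copy-synthesis sketch looks roughly like the following; `example.wav` is a placeholder path, and the checkpoint location comes from `checkpoint_file_load` in `config.json` (e.g. a downloaded `g_*` file).
```
import json
import torch
import librosa

from utils import AttrDict
from dataset import mel_spectrogram
from models import Generator

# Load the configuration the same way train.py and inference.py do.
with open('config.json') as f:
    h = AttrDict(json.loads(f.read()))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator(h).to(device)
state_dict = torch.load(h.checkpoint_file_load, map_location=device)
generator.load_state_dict(state_dict['generator'])
generator.eval()

# Compute a mel spectrogram from a waveform and resynthesize it.
wav, _ = librosa.load('example.wav', sr=h.sampling_rate, mono=True)
mel = mel_spectrogram(torch.FloatTensor(wav).unsqueeze(0).to(device),
                      h.n_fft, h.num_mels, h.sampling_rate,
                      h.hop_size, h.win_size, h.fmin, h.fmax)
with torch.no_grad():
    _, _, _, _, audio = generator(mel)  # audio has shape [1, 1, samples]
```
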
33 | ## Model Structure
34 | 
35 |
36 | ## Comparison with other models
37 | 
38 |
39 | ## Acknowledgements
40 | We referred to [HiFiGAN](https://github.com/jik876/hifi-gan), [NSPP](https://github.com/YangAi520/NSPP), [APNet](https://github.com/YangAi520/APNet),
41 | and [Vocos](https://github.com/charactr-platform/vocos) in our implementation.
42 |
43 | ## Citation
44 | ```
45 | @article{du2023apnet2,
46 | title={APNet2: High-quality and High-efficiency Neural Vocoder with Direct Prediction of Amplitude and Phase Spectra},
47 | author={Du, Hui-Peng and Lu, Ye-Xin and Ai, Yang and Ling, Zhen-Hua},
48 | journal={arXiv preprint arXiv:2311.11545},
49 | year={2023}
50 | }
51 | ```
52 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import glob
4 | import os
5 | import argparse
6 | import json
7 | import torch
8 | from utils import AttrDict
9 | from dataset import mel_spectrogram, load_wav
10 | from models import Generator
11 | import soundfile as sf
12 | import librosa
13 | import numpy as np
14 | import time
15 | h = None
16 | device = None
17 |
18 |
19 | def load_checkpoint(filepath, device):
20 | assert os.path.isfile(filepath)
21 | print("Loading '{}'".format(filepath))
22 | checkpoint_dict = torch.load(filepath, map_location=device)
23 | print("Complete.")
24 | return checkpoint_dict
25 |
26 |
27 | def get_mel(x):
28 | return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
29 |
30 |
31 | def scan_checkpoint(cp_dir, prefix):
32 | pattern = os.path.join(cp_dir, prefix + '*')
33 | cp_list = glob.glob(pattern)
34 | if len(cp_list) == 0:
35 | return ''
36 | return sorted(cp_list)[-1]
37 |
38 |
39 | def inference(h):
40 | generator = Generator(h).to(device)
41 |
42 | state_dict_g = load_checkpoint(h.checkpoint_file_load, device)
43 | generator.load_state_dict(state_dict_g['generator'])
44 |
45 | filelist = sorted(os.listdir(h.test_input_mels_dir if h.test_mel_load else h.test_input_wavs_dir))
46 |
47 | os.makedirs(h.test_output_dir, exist_ok=True)
48 |
49 | generator.eval()
50 |     l = 0  # total number of generated samples (used for the real-time-factor report below)
51 | with torch.no_grad():
52 | starttime = time.time()
53 | for i, filename in enumerate(filelist):
54 |
55 |             # Load a pre-computed mel spectrogram (.npy) when test_mel_load is set; otherwise compute it from the wav.
56 |             if h.test_mel_load:
57 |                 mel = np.load(os.path.join(h.test_input_mels_dir, filename))
58 |                 x = torch.FloatTensor(mel).to(device)
59 |                 x = x.transpose(1, 2)
60 |             else:
61 |                 raw_wav, _ = librosa.load(os.path.join(h.test_input_wavs_dir, filename), sr=h.sampling_rate, mono=True)
62 |                 raw_wav = torch.FloatTensor(raw_wav).to(device)
63 |                 x = get_mel(raw_wav.unsqueeze(0))
64 |
65 | logamp_g, pha_g, _, _, y_g = generator(x)
66 | audio = y_g.squeeze()
67 | # logamp = logamp_g.squeeze()
68 | # pha = pha_g.squeeze()
69 | audio = audio.cpu().numpy()
70 | # logamp = logamp.cpu().numpy()
71 | # pha = pha.cpu().numpy()
72 | audiolen=len(audio)
73 | sf.write(os.path.join(h.test_output_dir, filename.split('.')[0]+'.wav'), audio, h.sampling_rate,'PCM_16')
74 |
75 | # print(pp)
76 | l+=audiolen
77 |
78 | # write(output_file, h.sampling_rate, audio)
79 | # print(output_file)
80 |         end = time.time()
81 |         print(end - starttime)                          # total wall-clock inference time (s)
82 |         print(l / h.sampling_rate)                      # total duration of generated audio (s)
83 |         print(l / h.sampling_rate / (end - starttime))  # real-time factor (generated seconds per wall-clock second)
84 |
85 | # np.save(os.path.join(h.test_output_dir, filename.split('.')[0]+'_logamp.npy'), logamp)
86 | # np.save(os.path.join(h.test_output_dir, filename.split('.')[0]+'_pha.npy'), pha)
87 | # if i==9:
88 | # break
89 |
90 | def main():
91 | print('Initializing Inference Process..')
92 |
93 | config_file = 'config.json'
94 |
95 | with open(config_file) as f:
96 | data = f.read()
97 |
98 | global h
99 | json_config = json.loads(data)
100 | h = AttrDict(json_config)
101 |
102 | torch.manual_seed(h.seed)
103 | global device
104 | if torch.cuda.is_available():
105 | torch.cuda.manual_seed(h.seed)
106 | device = torch.device('cuda')
107 | else:
108 |         device = torch.device('cpu')
109 |     # device = torch.device('cpu')  # uncomment to force CPU-only inference
110 | inference(h)
111 |
112 |
113 | if __name__ == '__main__':
114 | main()
115 |
116 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | import torch.utils.data
6 | import numpy as np
7 | from librosa.util import normalize
8 | from librosa.filters import mel as librosa_mel_fn
9 | import librosa
10 | import torchaudio
11 | import torch.nn as nn
12 |
13 | def load_wav(full_path, sample_rate):
14 | data, _ = librosa.load(full_path, sr=sample_rate, mono=True)
15 | return data
16 |
17 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
18 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
19 |
20 | def dynamic_range_decompression(x, C=1):
21 | return np.exp(x) / C
22 |
23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
24 | return torch.log(torch.clamp(x, min=clip_val) * C)
25 |
26 | def dynamic_range_decompression_torch(x, C=1):
27 | return torch.exp(x) / C
28 |
29 | def spectral_normalize_torch(magnitudes):
30 | output = dynamic_range_compression_torch(magnitudes)
31 | return output
32 |
33 | def spectral_de_normalize_torch(magnitudes):
34 | output = dynamic_range_decompression_torch(magnitudes)
35 | return output
36 |
37 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=True):
38 |
39 |     mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
40 | mel_basis = torch.from_numpy(mel).float().to(y.device)
41 | hann_window = torch.hann_window(win_size).to(y.device)
42 |
43 |     spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
44 |
45 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
46 |
47 | spec = torch.matmul(mel_basis, spec)
48 | spec = spectral_normalize_torch(spec)
49 |
50 |     return spec  # [batch_size, num_mels, frames]
51 |
52 | def amp_pha_specturm(y, n_fft, hop_size, win_size):
53 |
54 | hann_window=torch.hann_window(win_size).to(y.device)
55 |
56 | stft_spec=torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,center=True) #[batch_size, n_fft//2+1, frames, 2]
57 |
58 | rea=stft_spec[:,:,:,0] #[batch_size, n_fft//2+1, frames]
59 | imag=stft_spec[:,:,:,1] #[batch_size, n_fft//2+1, frames]
60 |
61 | log_amplitude=torch.log(torch.abs(torch.sqrt(torch.pow(rea,2)+torch.pow(imag,2)))+1e-5) #[batch_size, n_fft//2+1, frames]
62 | phase=torch.atan2(imag,rea) #[batch_size, n_fft//2+1, frames]
63 |
64 | return log_amplitude, phase, rea, imag
65 |
66 | def get_dataset_filelist(input_training_wav_list,input_validation_wav_list):
67 | training_files=[]
68 | filelist=os.listdir(input_training_wav_list)
69 | for files in filelist:
70 |
71 | src=os.path.join(input_training_wav_list,files)
72 | training_files.append(src)
73 |
74 | validation_files=[]
75 | filelist=os.listdir(input_validation_wav_list)
76 | for files in filelist:
77 | src=os.path.join(input_validation_wav_list,files)
78 | validation_files.append(src)
79 |
80 | return training_files, validation_files
81 |
82 |
83 | class Dataset(torch.utils.data.Dataset):
84 | def __init__(self, training_files, segment_size, n_fft, num_mels,
85 | hop_size, win_size, sampling_rate, fmin, fmax,meloss, split=True, shuffle=True, n_cache_reuse=1,
86 | device=None):
87 | self.audio_files = training_files
88 | random.seed(1234)
89 | if shuffle:
90 | random.shuffle(self.audio_files)
91 | self.segment_size = segment_size
92 | self.sampling_rate = sampling_rate
93 | self.split = split
94 | self.n_fft = n_fft
95 | self.num_mels = num_mels
96 | self.hop_size = hop_size
97 | self.win_size = win_size
98 | self.fmin = fmin
99 | self.fmax = fmax
100 | self.cached_wav = None
101 | self.n_cache_reuse = n_cache_reuse
102 | self._cache_ref_count = 0
103 | self.device = device
104 | self.meloss=meloss
105 |
106 | def __getitem__(self, index):
107 | filename = self.audio_files[index]
108 | if self._cache_ref_count == 0:
109 | audio = load_wav(filename, self.sampling_rate)
110 | self.cached_wav = audio
111 | self._cache_ref_count = self.n_cache_reuse
112 | else:
113 | audio = self.cached_wav
114 | self._cache_ref_count -= 1
115 |
116 | audio = torch.FloatTensor(audio) #[T]
117 | audio = audio.unsqueeze(0) #[1,T]
118 |
119 | if self.split:
120 | if audio.size(1) >= self.segment_size:
121 | max_audio_start = audio.size(1) - self.segment_size
122 | audio_start = random.randint(0, max_audio_start)
123 | audio = audio[:, audio_start: audio_start + self.segment_size] #[1,T]
124 | else:
125 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
126 |
127 | mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
128 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
129 | center=True)
130 | meloss1 = mel_spectrogram(audio, self.n_fft, self.num_mels,
131 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.meloss,
132 | center=True)
133 | log_amplitude, phase, rea, imag = amp_pha_specturm(audio, self.n_fft, self.hop_size, self.win_size) #[1,n_fft/2+1,frames]
134 |
135 |
136 | return (mel.squeeze(), log_amplitude.squeeze(), phase.squeeze(), rea.squeeze(), imag.squeeze(), audio.squeeze(0),meloss1.squeeze())
137 |
138 | def __len__(self):
139 | return len(self.audio_files)
140 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import itertools
4 | import os
5 | import time
6 | import argparse
7 | import json
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.utils.tensorboard import SummaryWriter
11 | from torch.utils.data import DistributedSampler, DataLoader
12 | import torch.multiprocessing as mp
13 | from torch.distributed import init_process_group
14 | from torch.nn.parallel import DistributedDataParallel
15 | from dataset import Dataset, mel_spectrogram, amp_pha_specturm, get_dataset_filelist
16 | from models import Generator, MultiPeriodDiscriminator, feature_loss, generator_loss,\
17 | discriminator_loss, amplitude_loss, phase_loss, STFT_consistency_loss,MultiResolutionDiscriminator
18 | from utils import AttrDict, build_env, plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
19 |
20 | torch.backends.cudnn.benchmark = True
21 |
22 |
23 | def train(h):
24 |
25 | torch.cuda.manual_seed(h.seed)
26 | device = torch.device('cuda:{:d}'.format(0))
27 |
28 | generator = Generator(h).to(device)
29 | mpd = MultiPeriodDiscriminator().to(device)
30 | mrd = MultiResolutionDiscriminator().to(device)
31 |
32 | print(generator)
33 | os.makedirs(h.checkpoint_path, exist_ok=True)
34 | print("checkpoints directory : ", h.checkpoint_path)
35 |
36 | if os.path.isdir(h.checkpoint_path):
37 | cp_g = scan_checkpoint(h.checkpoint_path, 'g_')
38 | cp_do = scan_checkpoint(h.checkpoint_path, 'do_')
39 |
40 | steps = 0
41 | if cp_g is None or cp_do is None:
42 | state_dict_do = None
43 | last_epoch = -1
44 | else:
45 | state_dict_g = load_checkpoint(cp_g, device)
46 | state_dict_do = load_checkpoint(cp_do, device)
47 | generator.load_state_dict(state_dict_g['generator'])
48 | mpd.load_state_dict(state_dict_do['mpd'])
49 | mrd.load_state_dict(state_dict_do['mrd'])
50 | steps = state_dict_do['steps'] + 1
51 | last_epoch = state_dict_do['epoch']
52 |
53 | optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
54 | optim_d = torch.optim.AdamW(itertools.chain(mrd.parameters(), mpd.parameters()),
55 | h.learning_rate, betas=[h.adam_b1, h.adam_b2])
56 |
57 | if state_dict_do is not None:
58 | optim_g.load_state_dict(state_dict_do['optim_g'])
59 | optim_d.load_state_dict(state_dict_do['optim_d'])
60 |
61 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
62 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
63 |
64 | training_filelist, validation_filelist = get_dataset_filelist(h.input_training_wav_list, h.input_validation_wav_list)
65 |
66 | trainset = Dataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
67 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, h.meloss,n_cache_reuse=0,
68 | shuffle=True, device=device)
69 |
70 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
71 | sampler=None,
72 | batch_size=h.batch_size,
73 | pin_memory=True,
74 | drop_last=True)
75 |
76 | validset = Dataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
77 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax,h.meloss, False, False, n_cache_reuse=0,
78 | device=device)
79 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
80 | sampler=None,
81 | batch_size=1,
82 | pin_memory=True,
83 | drop_last=True)
84 |
85 | sw = SummaryWriter(os.path.join(h.checkpoint_path, 'logs'))
86 |
87 | generator.train()
88 | mpd.train()
89 | mrd.train()
90 |
91 | for epoch in range(max(0, last_epoch), h.training_epochs):
92 |
93 | start = time.time()
94 | print("Epoch: {}".format(epoch+1))
95 |
96 | for i, batch in enumerate(train_loader):
97 | start_b = time.time()
98 | x, logamp, pha, rea, imag, y,meloss = batch
99 | x = torch.autograd.Variable(x.to(device, non_blocking=True))
100 | y = torch.autograd.Variable(y.to(device, non_blocking=True))
101 | logamp = torch.autograd.Variable(logamp.to(device, non_blocking=True))
102 | pha = torch.autograd.Variable(pha.to(device, non_blocking=True))
103 | rea = torch.autograd.Variable(rea.to(device, non_blocking=True))
104 | imag = torch.autograd.Variable(imag.to(device, non_blocking=True))
105 | y = y.unsqueeze(1)
106 | meloss = torch.autograd.Variable(meloss.to(device, non_blocking=True))
107 | logamp_g, pha_g, rea_g, imag_g, y_g = generator(x)
108 | y_g_mel = mel_spectrogram(y_g.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
109 | h.fmin, h.meloss)
110 |
111 | optim_d.zero_grad()
112 |
113 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g.detach())
114 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
115 |
116 | y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g.detach())
117 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
118 |
119 |             L_D = loss_disc_s * 0.1 + loss_disc_f  # total discriminator loss (MRD term down-weighted)
120 |
121 | L_D.backward()
122 | optim_d.step()
123 |
124 | # Generator
125 | optim_g.zero_grad()
126 |
127 | # Losses defined on log amplitude spectra
128 | L_A = amplitude_loss(logamp, logamp_g)
129 |
130 | L_IP, L_GD, L_PTD = phase_loss(pha, pha_g, h.n_fft, pha.size()[-1])
131 | # Losses defined on phase spectra
132 | L_P = L_IP + L_GD + L_PTD
133 |
134 | _, _, rea_g_final, imag_g_final = amp_pha_specturm(y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size)
135 | L_C = STFT_consistency_loss(rea_g, rea_g_final, imag_g, imag_g_final)
136 | L_R = F.l1_loss(rea, rea_g)
137 | L_I = F.l1_loss(imag, imag_g)
138 | # Losses defined on reconstructed STFT spectra
139 | L_S = L_C + 2.25 * (L_R + L_I)
140 |
141 | y_df_r, y_df_g, fmap_f_r, fmap_f_g = mpd(y, y_g)
142 | y_ds_r, y_ds_g, fmap_s_r, fmap_s_g = mrd(y, y_g)
143 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
144 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
145 | loss_gen_f, losses_gen_f = generator_loss(y_df_g)
146 | loss_gen_s, losses_gen_s = generator_loss(y_ds_g)
147 | L_GAN_G = loss_gen_s *0.1+ loss_gen_f
148 | L_FM = loss_fm_s *0.1+ loss_fm_f
149 | L_Mel = F.l1_loss(meloss, y_g_mel)
150 | # Losses defined on final waveforms
151 | L_W = L_GAN_G + L_FM + 45 * L_Mel
152 |
153 |             L_G = 45 * L_A + 100 * L_P + 20 * L_S + L_W  # total generator loss (weighted sum of amplitude, phase, STFT and waveform losses)
154 |
155 | L_G.backward()
156 | optim_g.step()
157 |
158 | # STDOUT logging
159 | if steps % h.stdout_interval == 0:
160 | with torch.no_grad():
161 | A_error = amplitude_loss(logamp, logamp_g).item()
162 | IP_error, GD_error, PTD_error = phase_loss(pha, pha_g, h.n_fft, pha.size()[-1])
163 | IP_error = IP_error.item()
164 | GD_error = GD_error.item()
165 | PTD_error = PTD_error.item()
166 | C_error = STFT_consistency_loss(rea_g, rea_g_final, imag_g, imag_g_final).item()
167 | R_error = F.l1_loss(rea, rea_g).item()
168 | I_error = F.l1_loss(imag, imag_g).item()
169 |                     Mel_error = F.l1_loss(meloss, y_g_mel).item()  # same mel settings as L_Mel above
170 |
171 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, Amplitude Loss : {:4.3f}, Instantaneous Phase Loss : {:4.3f}, Group Delay Loss : {:4.3f}, Phase Time Difference Loss : {:4.3f}, STFT Consistency Loss : {:4.3f}, Real Part Loss : {:4.3f}, Imaginary Part Loss : {:4.3f}, Mel Spectrogram Loss : {:4.3f}, s/b : {:4.3f}'.
172 | format(steps, L_G, A_error, IP_error, GD_error, PTD_error, C_error, R_error, I_error, Mel_error, time.time() - start_b))
173 |
174 | # checkpointing
175 | if steps % h.checkpoint_interval == 0 and steps != 0:
176 | checkpoint_path = "{}/g_{:08d}".format(h.checkpoint_path, steps)
177 | save_checkpoint(checkpoint_path,
178 | {'generator': generator.state_dict()})
179 | checkpoint_path = "{}/do_{:08d}".format(h.checkpoint_path, steps)
180 | save_checkpoint(checkpoint_path,
181 | {'mpd': mpd.state_dict(),
182 | 'mrd': mrd.state_dict(),
183 | 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
184 | 'epoch': epoch})
185 |
186 | # Tensorboard summary logging
187 | if steps % h.summary_interval == 0:
188 | sw.add_scalar("Training/Generator_Total_Loss", L_G, steps)
189 | sw.add_scalar("Training/Mel_Spectrogram_Loss", Mel_error, steps)
190 |
191 | # Validation
192 | if steps % h.validation_interval == 0: # and steps != 0:
193 | generator.eval()
194 | torch.cuda.empty_cache()
195 | val_A_err_tot = 0
196 | val_IP_err_tot = 0
197 | val_GD_err_tot = 0
198 | val_PTD_err_tot = 0
199 | val_C_err_tot = 0
200 | val_R_err_tot = 0
201 | val_I_err_tot = 0
202 | val_Mel_err_tot = 0
203 | with torch.no_grad():
204 | for j, batch in enumerate(validation_loader):
205 | x, logamp, pha, rea, imag, y ,meloss= batch
206 | logamp_g, pha_g, rea_g, imag_g, y_g = generator(x.to(device))
207 | mel = x
208 | mel = torch.autograd.Variable(mel.to(device, non_blocking=True))
209 | logamp = torch.autograd.Variable(logamp.to(device, non_blocking=True))
210 | pha = torch.autograd.Variable(pha.to(device, non_blocking=True))
211 | rea = torch.autograd.Variable(rea.to(device, non_blocking=True))
212 | imag = torch.autograd.Variable(imag.to(device, non_blocking=True))
213 | meloss = torch.autograd.Variable(meloss.to(device, non_blocking=True))
214 | y_g_mel = mel_spectrogram(y_g.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,h.hop_size, h.win_size,h.fmin, h.meloss)
215 |
216 | _, _, rea_g_final, imag_g_final = amp_pha_specturm(y_g.squeeze(1), h.n_fft, h.hop_size, h.win_size)
217 | val_A_err_tot += amplitude_loss(logamp, logamp_g).item()
218 | val_IP_err, val_GD_err, val_PTD_err = phase_loss(pha, pha_g, h.n_fft, pha.size()[-1])
219 | val_IP_err_tot += val_IP_err.item()
220 | val_GD_err_tot += val_GD_err.item()
221 | val_PTD_err_tot += val_PTD_err.item()
222 | val_C_err_tot += STFT_consistency_loss(rea_g, rea_g_final, imag_g, imag_g_final).item()
223 | val_R_err_tot += F.l1_loss(rea, rea_g).item()
224 | val_I_err_tot += F.l1_loss(imag, imag_g).item()
225 | val_Mel_err_tot += F.l1_loss(meloss, y_g_mel).item()
226 |
227 | # if j <= 4:
228 | # if steps == 0:
229 | # sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
230 | # sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
231 |
232 | # sw.add_audio('generated/y_g_{}'.format(j), y_g[0], steps, h.sampling_rate)
233 | # y_g_spec = mel_spectrogram(y_g.squeeze(1), h.n_fft, h.num_mels,
234 | # h.sampling_rate, h.hop_size, h.win_size,
235 | # h.fmin, h.fmax)
236 | # sw.add_figure('generated/y_g_spec_{}'.format(j),
237 | # plot_spectrogram(y_g_spec.squeeze(0).cpu().numpy()), steps)
238 |
239 | val_A_err = val_A_err_tot / (j+1)
240 | val_IP_err = val_IP_err_tot / (j+1)
241 | val_GD_err = val_GD_err_tot / (j+1)
242 | val_PTD_err = val_PTD_err_tot / (j+1)
243 | val_C_err = val_C_err_tot / (j+1)
244 | val_R_err = val_R_err_tot / (j+1)
245 | val_I_err = val_I_err_tot / (j+1)
246 | val_Mel_err = val_Mel_err_tot / (j+1)
247 | sw.add_scalar("Validation/Amplitude_Loss", val_A_err, steps)
248 | sw.add_scalar("Validation/Instantaneous_Phase_Loss", val_IP_err, steps)
249 | sw.add_scalar("Validation/Group_Delay_Loss", val_GD_err, steps)
250 | sw.add_scalar("Validation/Phase_Time_Difference_Loss", val_PTD_err, steps)
251 | sw.add_scalar("Validation/STFT_Consistency_Loss", val_C_err, steps)
252 | sw.add_scalar("Validation/Real_Part_Loss", val_R_err, steps)
253 | sw.add_scalar("Validation/Imaginary_Part_Loss", val_I_err, steps)
254 | sw.add_scalar("Validation/Mel_Spectrogram_loss", val_Mel_err, steps)
255 |
256 | generator.train()
257 |
258 | steps += 1
259 |
260 | scheduler_g.step()
261 | scheduler_d.step()
262 |
263 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
264 |
265 |
266 | def main():
267 | print('Initializing Training Process..')
268 |
269 | config_file = 'config.json'
270 |
271 | with open(config_file) as f:
272 | data = f.read()
273 |
274 | json_config = json.loads(data)
275 | h = AttrDict(json_config)
276 | build_env(config_file, 'config.json', h.checkpoint_path)
277 |
278 | torch.manual_seed(h.seed)
279 | if torch.cuda.is_available():
280 | torch.cuda.manual_seed(h.seed)
281 | else:
282 | pass
283 |
284 | train(h)
285 |
286 |
287 | if __name__ == '__main__':
288 | main()
289 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5 | from torch.nn.utils import weight_norm, spectral_norm
6 | from utils import init_weights, get_padding
7 | import numpy as np
8 | LRELU_SLOPE = 0.1
9 |
10 |
11 | class GRN(nn.Module):
12 | """ GRN (Global Response Normalization) layer
13 | """
14 | def __init__(self, dim):
15 | super().__init__()
16 | self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
17 | self.beta = nn.Parameter(torch.zeros(1, 1, dim))
18 |
19 | def forward(self, x):
20 | Gx = torch.norm(x, p=2, dim=1, keepdim=True)
21 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
22 | return self.gamma * (x * Nx) + self.beta + x
23 |
24 | class ConvNeXtBlock(nn.Module):
25 | def __init__(
26 | self,
27 | dim: int,
28 | intermediate_dim: int,
29 | layer_scale_init_value= None,
30 | adanorm_num_embeddings = None,
31 | ):
32 | super().__init__()
33 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
34 | self.adanorm = adanorm_num_embeddings is not None
35 |
36 | self.norm = nn.LayerNorm(dim, eps=1e-6)
37 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
38 | self.act = nn.GELU()
39 | self.grn = GRN(intermediate_dim)
40 | self.pwconv2 = nn.Linear(intermediate_dim, dim)
41 |
42 | def forward(self, x, cond_embedding_id = None) :
43 | residual = x
44 | x = self.dwconv(x)
45 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
46 | if self.adanorm:
47 | assert cond_embedding_id is not None
48 | x = self.norm(x, cond_embedding_id)
49 | else:
50 | x = self.norm(x)
51 | x = self.pwconv1(x)
52 | x = self.act(x)
53 | x = self.grn(x)
54 | x = self.pwconv2(x)
55 |
56 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
57 |
58 | x = residual + x
59 | return x
60 | class Generator(torch.nn.Module):
61 | def __init__(self, h):
62 | super(Generator, self).__init__()
63 | self.h = h
64 | self.ASP_num_kernels = len(h.ASP_resblock_kernel_sizes)
65 | self.PSP_num_kernels = len(h.PSP_resblock_kernel_sizes)
66 |
67 | self.ASP_input_conv = Conv1d(h.num_mels, h.ASP_channel, h.ASP_input_conv_kernel_size, 1,
68 | padding=get_padding(h.ASP_input_conv_kernel_size, 1))
69 | self.PSP_input_conv = Conv1d(h.num_mels, h.PSP_channel, h.PSP_input_conv_kernel_size, 1,
70 | padding=get_padding(h.PSP_input_conv_kernel_size, 1))
71 |
72 | self.ASP_output_conv = Conv1d(h.ASP_channel, h.n_fft//2+1, h.ASP_output_conv_kernel_size, 1,
73 | padding=get_padding(h.ASP_output_conv_kernel_size, 1))
74 | self.PSP_output_R_conv = Conv1d(512, h.n_fft//2+1, h.PSP_output_R_conv_kernel_size, 1,
75 | padding=get_padding(h.PSP_output_R_conv_kernel_size, 1))
76 | self.PSP_output_I_conv = Conv1d(512, h.n_fft//2+1, h.PSP_output_I_conv_kernel_size, 1,
77 | padding=get_padding(h.PSP_output_I_conv_kernel_size, 1))
78 |
79 | self.dim=512
80 | self.num_layers=8
81 | self.adanorm_num_embeddings=None
82 | self.intermediate_dim=1536
83 | self.norm = nn.LayerNorm(self.dim, eps=1e-6)
84 | self.norm2 = nn.LayerNorm(self.dim, eps=1e-6)
85 | layer_scale_init_value = 1 / self.num_layers
86 | self.convnext = nn.ModuleList(
87 | [
88 | ConvNeXtBlock(
89 | dim=self.dim,
90 | intermediate_dim=self.intermediate_dim,
91 | layer_scale_init_value=layer_scale_init_value,
92 | adanorm_num_embeddings=self.adanorm_num_embeddings,
93 | )
94 | for _ in range(self.num_layers)
95 | ]
96 | )
97 | self.convnext2 = nn.ModuleList(
98 | [
99 | ConvNeXtBlock(
100 | dim=self.dim,
101 | intermediate_dim=self.intermediate_dim,
102 | layer_scale_init_value=layer_scale_init_value,
103 | adanorm_num_embeddings=self.adanorm_num_embeddings,
104 | )
105 | for _ in range(self.num_layers)
106 | ]
107 | )
108 | self.final_layer_norm = nn.LayerNorm(self.dim, eps=1e-6)
109 | self.final_layer_norm2 = nn.LayerNorm(self.dim, eps=1e-6)
110 | self.apply(self._init_weights)
111 |
112 | def _init_weights(self, m):
113 | if isinstance(m, (nn.Conv1d, nn.Linear)):
114 | nn.init.trunc_normal_(m.weight, std=0.02)
115 | nn.init.constant_(m.bias, 0)
116 |
117 | def forward(self, mel):
118 |
119 | logamp = self.ASP_input_conv(mel)
120 | logamp = self.norm2(logamp.transpose(1, 2))
121 | logamp = logamp.transpose(1, 2)
122 | for conv_block in self.convnext2:
123 | logamp = conv_block(logamp, cond_embedding_id=None)
124 | logamp = self.final_layer_norm2(logamp.transpose(1, 2))
125 | logamp = logamp.transpose(1, 2)
126 | logamp = self.ASP_output_conv(logamp)
127 |
128 |
129 | pha = self.PSP_input_conv(mel)
130 | pha = self.norm(pha.transpose(1, 2))
131 | pha = pha.transpose(1, 2)
132 | for conv_block in self.convnext:
133 | pha = conv_block(pha, cond_embedding_id=None)
134 | pha = self.final_layer_norm(pha.transpose(1, 2))
135 | pha = pha.transpose(1, 2)
136 | R = self.PSP_output_R_conv(pha)
137 | I = self.PSP_output_I_conv(pha)
138 |
139 | pha = torch.atan2(I,R)
140 |
141 | rea = torch.exp(logamp)*torch.cos(pha)
142 | imag = torch.exp(logamp)*torch.sin(pha)
143 |
144 |         spec = torch.complex(rea, imag)  # complex STFT reconstructed from predicted amplitude and phase
145 |
146 | audio = torch.istft(spec, self.h.n_fft, hop_length=self.h.hop_size, win_length=self.h.win_size, window=torch.hann_window(self.h.win_size).to(mel.device), center=True)
147 |
148 | return logamp, pha, rea, imag, audio.unsqueeze(1)
149 |
150 | class DiscriminatorP(torch.nn.Module):
151 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
152 | super(DiscriminatorP, self).__init__()
153 | self.period = period
154 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
155 | self.convs = nn.ModuleList([
156 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
157 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
158 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
159 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
160 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
161 | ])
162 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
163 |
164 | def forward(self, x):
165 | fmap = []
166 |
167 | # 1d to 2d
168 | b, c, t = x.shape
169 | if t % self.period != 0: # pad first
170 | n_pad = self.period - (t % self.period)
171 | x = F.pad(x, (0, n_pad), "reflect")
172 | t = t + n_pad
173 | x = x.view(b, c, t // self.period, self.period)
174 |
175 | for l in self.convs:
176 | x = l(x)
177 | x = F.leaky_relu(x, LRELU_SLOPE)
178 | fmap.append(x)
179 | x = self.conv_post(x)
180 | fmap.append(x)
181 | x = torch.flatten(x, 1, -1)
182 |
183 | return x, fmap
184 |
185 |
186 | class MultiPeriodDiscriminator(torch.nn.Module):
187 | def __init__(self):
188 | super(MultiPeriodDiscriminator, self).__init__()
189 | self.discriminators = nn.ModuleList([
190 | DiscriminatorP(2),
191 | DiscriminatorP(3),
192 | DiscriminatorP(5),
193 | DiscriminatorP(7),
194 | DiscriminatorP(11),
195 | ])
196 |
197 | def forward(self, y, y_hat):
198 | y_d_rs = []
199 | y_d_gs = []
200 | fmap_rs = []
201 | fmap_gs = []
202 | for i, d in enumerate(self.discriminators):
203 | y_d_r, fmap_r = d(y)
204 | y_d_g, fmap_g = d(y_hat)
205 | y_d_rs.append(y_d_r)
206 | fmap_rs.append(fmap_r)
207 | y_d_gs.append(y_d_g)
208 | fmap_gs.append(fmap_g)
209 |
210 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
211 |
212 | def phase_loss(phase_r, phase_g, n_fft, frames):
213 |
214 | MSELoss = torch.nn.MSELoss()
215 |
216 | GD_matrix = torch.triu(torch.ones(n_fft//2+1,n_fft//2+1),diagonal=1)-torch.triu(torch.ones(n_fft//2+1,n_fft//2+1),diagonal=2)-torch.eye(n_fft//2+1)
217 | GD_matrix = GD_matrix.to(phase_g.device)
218 |
219 | GD_r = torch.matmul(phase_r.permute(0,2,1), GD_matrix)
220 | GD_g = torch.matmul(phase_g.permute(0,2,1), GD_matrix)
221 |
222 | PTD_matrix = torch.triu(torch.ones(frames,frames),diagonal=1)-torch.triu(torch.ones(frames,frames),diagonal=2)-torch.eye(frames)
223 | PTD_matrix = PTD_matrix.to(phase_g.device)
224 |
225 | PTD_r = torch.matmul(phase_r, PTD_matrix)
226 | PTD_g = torch.matmul(phase_g, PTD_matrix)
227 |
228 | IP_loss = torch.mean(anti_wrapping_function(phase_r-phase_g))
229 | GD_loss = torch.mean(anti_wrapping_function(GD_r-GD_g))
230 | PTD_loss = torch.mean(anti_wrapping_function(PTD_r-PTD_g))
231 |
232 |
233 | return IP_loss, GD_loss, PTD_loss
234 | class MultiResolutionDiscriminator(nn.Module):
235 | def __init__(
236 | self,
237 | resolutions= ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
238 | num_embeddings: int = None,
239 | ):
240 | super().__init__()
241 | self.discriminators = nn.ModuleList(
242 | [DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
243 | )
244 |
245 | def forward(
246 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
247 | ) :
248 | y_d_rs = []
249 | y_d_gs = []
250 | fmap_rs = []
251 | fmap_gs = []
252 |
253 | for d in self.discriminators:
254 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
255 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
256 | y_d_rs.append(y_d_r)
257 | fmap_rs.append(fmap_r)
258 | y_d_gs.append(y_d_g)
259 | fmap_gs.append(fmap_g)
260 |
261 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
262 |
263 |
264 | class DiscriminatorR(nn.Module):
265 | def __init__(
266 | self,
267 | resolution,
268 | channels: int = 64,
269 | in_channels: int = 1,
270 | num_embeddings: int = None,
271 | lrelu_slope: float = 0.1,
272 | ):
273 | super().__init__()
274 | self.resolution = resolution
275 | self.in_channels = in_channels
276 | self.lrelu_slope = lrelu_slope
277 | self.convs = nn.ModuleList(
278 | [
279 | weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
280 | weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
281 | weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
282 | weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
283 | weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
284 | ]
285 | )
286 | if num_embeddings is not None:
287 | self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
288 | torch.nn.init.zeros_(self.emb.weight)
289 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
290 |
291 | def forward(
292 | self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None) :
293 | fmap = []
294 | x=x.squeeze(1)
295 |
296 | x = self.spectrogram(x)
297 | x = x.unsqueeze(1)
298 | for l in self.convs:
299 | x = l(x)
300 | x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
301 | fmap.append(x)
302 | if cond_embedding_id is not None:
303 | emb = self.emb(cond_embedding_id)
304 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
305 | else:
306 | h = 0
307 | x = self.conv_post(x)
308 | fmap.append(x)
309 | x += h
310 | x = torch.flatten(x, 1, -1)
311 |
312 | return x, fmap
313 |
314 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
315 | n_fft, hop_length, win_length = self.resolution
316 | magnitude_spectrogram = torch.stft(
317 | x,
318 | n_fft=n_fft,
319 | hop_length=hop_length,
320 | win_length=win_length,
321 | window=None, # interestingly rectangular window kind of works here
322 | center=True,
323 | return_complex=True,
324 | ).abs()
325 |
326 | return magnitude_spectrogram
327 |
328 | def anti_wrapping_function(x):
329 | return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
330 |
331 | def amplitude_loss(log_amplitude_r, log_amplitude_g):
332 |
333 | MSELoss = torch.nn.MSELoss()
334 |
335 | amplitude_loss = MSELoss(log_amplitude_r, log_amplitude_g)
336 |
337 | return amplitude_loss
338 |
339 |
340 | def feature_loss(fmap_r, fmap_g):
341 | loss = 0
342 | for dr, dg in zip(fmap_r, fmap_g):
343 | for rl, gl in zip(dr, dg):
344 | loss += torch.mean(torch.abs(rl - gl))
345 |
346 | return loss
347 |
348 |
349 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
350 | loss = 0
351 | r_losses = []
352 | g_losses = []
353 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
354 | r_loss = torch.mean(torch.clamp(1 - dr, min=0))
355 | g_loss = torch.mean(torch.clamp(1 + dg, min=0))
356 | loss += r_loss + g_loss
357 | r_losses.append(r_loss.item())
358 | g_losses.append(g_loss.item())
359 |
360 | return loss, r_losses, g_losses
361 |
362 |
363 | def generator_loss(disc_outputs):
364 | loss = 0
365 | gen_losses = []
366 | for dg in disc_outputs:
367 | l = torch.mean(torch.clamp(1 - dg, min=0))
368 | gen_losses.append(l)
369 | loss += l
370 |
371 | return loss, gen_losses
372 |
373 |
374 | def STFT_consistency_loss(rea_r, rea_g, imag_r, imag_g):
375 |
376 | C_loss=torch.mean(torch.mean((rea_r-rea_g)**2+(imag_r-imag_g)**2,(1,2)))
377 |
378 | return C_loss
379 |
380 |
--------------------------------------------------------------------------------