├── Eval └── pitch_periodicity.py ├── LICENSE ├── README.md ├── commons.py ├── configs ├── periodwave_22050hz.json ├── periodwave_24000hz.json ├── periodwave_encodec.json ├── periodwave_encodec_turbo.json ├── periodwave_turbo_22050hz.json └── periodwave_turbo_24000hz.json ├── dataset_codec.py ├── encodec_feature_extractor.py ├── extract_energy.py ├── filelist_gen.py ├── inference.py ├── inference_large.py ├── inference_large_with_evaluation.py ├── inference_periodwave_encodec_universal_test_sound.py ├── inference_periodwave_encodec_universal_test_sound_step2.py ├── inference_periodwave_encodec_universal_test_speech.py ├── inference_periodwave_encodec_universal_test_speech_step2.py ├── inference_periodwave_encodec_universal_test_vocal.py ├── inference_periodwave_encodec_universal_test_vocal_step2.py ├── inference_with_FreeU.py ├── inference_with_TTS.py ├── inference_with_evaluation.py ├── meldataset_prior_length.py ├── model ├── base.py ├── bigvganv2_discriminator.py ├── commons.py ├── convnext.py ├── diffusion_module.py ├── ms_mel_loss.py ├── ms_stftd.py ├── periodwave.py ├── periodwave_encodec.py ├── periodwave_encodec_freeu.py ├── periodwave_encodec_freeu_utils.py ├── periodwave_encodec_turbo.py ├── periodwave_encodec_utils.py ├── periodwave_freeu.py ├── periodwave_large.py ├── periodwave_large_utils.py ├── periodwave_turbo.py ├── periodwave_utils.py ├── periodwave_utils_freeu.py └── utils.py ├── periodwave.png ├── requirements.txt ├── stats_libritts_24000hz ├── energy_max_train.npy └── energy_min_train.npy ├── stats_lj_22050hz ├── energy_max_train.npy └── energy_min_train.npy ├── test └── Triviul_feat._The_Fiend_-_Widow.stem.vocals_part180.wav ├── train_periodwave.py ├── train_periodwave_encodec.py ├── train_periodwave_turbo.py ├── train_periodwave_turbo_encodec.py └── utils.py /Eval/pitch_periodicity.py: -------------------------------------------------------------------------------- 1 | import torchcrepe 2 | import torch 3 | import functools 4 | 5 | def from_audio(audio, target_length, hopsize): 6 | """Preprocess pitch from audio""" 7 | 8 | # Resample hopsize 9 | # Estimate pitch 10 | 11 | audio = audio.unsqueeze(0) 12 | pitch, periodicity = torchcrepe.predict( 13 | audio, 14 | sample_rate=torchcrepe.SAMPLE_RATE, 15 | hop_length=hopsize, 16 | fmin=50, 17 | fmax=550, 18 | model='full', 19 | return_periodicity=True, 20 | batch_size=1024, 21 | device=audio.device, 22 | pad=False) 23 | 24 | # Set low energy frames to unvoiced 25 | periodicity = torchcrepe.threshold.Silence()( 26 | periodicity, 27 | audio, 28 | torchcrepe.SAMPLE_RATE, 29 | hop_length=hopsize, 30 | pad=False) 31 | 32 | # Potentially resize due to resampled integer hopsize 33 | if pitch.shape[1] != target_length: 34 | interp_fn = functools.partial( 35 | torch.nn.functional.interpolate, 36 | size=target_length, 37 | mode='linear', 38 | align_corners=False) 39 | pitch = 2 ** interp_fn(torch.log2(pitch)[None]).squeeze(0) 40 | periodicity = interp_fn(periodicity[None]).squeeze(0) 41 | 42 | return pitch, periodicity 43 | 44 | def p_p_F(threshold, true_pitch, true_periodicity, pred_pitch, pred_periodicity): 45 | 
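    # Corpus-level agreement metrics between a reference and a generated pitch track:
    # pitch RMSE in cents over frames that both tracks mark as voiced, periodicity RMSE
    # over all frames, and the F1 score of the voiced/unvoiced decisions.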
true_threshold = threshold(true_pitch, true_periodicity) 46 | pred_threshold = threshold(pred_pitch, pred_periodicity) 47 | true_voiced = ~torch.isnan(true_threshold) 48 | pred_voiced = ~torch.isnan(pred_threshold) 49 | 50 | # Update periodicity rmse 51 | count = true_pitch.shape[1] 52 | periodicity_total = (true_periodicity - pred_periodicity).pow(2).sum() 53 | 54 | # Update pitch rmse 55 | voiced = true_voiced & pred_voiced 56 | voiced_sum = voiced.sum() 57 | 58 | difference_cents = 1200 * (torch.log2(true_pitch[voiced]) - 59 | torch.log2(pred_pitch[voiced])) 60 | pitch_total = difference_cents.pow(2).sum() 61 | 62 | # Update voiced/unvoiced precision and recall 63 | true_positives = (true_voiced & pred_voiced).sum() 64 | false_positives = (~true_voiced & pred_voiced).sum() 65 | false_negatives = (true_voiced & ~pred_voiced).sum() 66 | 67 | pitch_rmse = torch.sqrt(pitch_total / voiced_sum) 68 | periodicity_rmse = torch.sqrt(periodicity_total / count) 69 | precision = true_positives / (true_positives + false_positives) 70 | recall = true_positives / (true_positives + false_negatives) 71 | f1 = 2 * precision * recall / (precision + recall) 72 | 73 | return pitch_rmse.nan_to_num().item(), periodicity_rmse.item(), f1.nan_to_num().item() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sang-Hoon Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size*dilation - dilation)/2) 16 | 17 | def sequence_mask(length, max_length=None): 18 | if max_length is None: 19 | max_length = length.max() 20 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 21 | return x.unsqueeze(0) < length.unsqueeze(1) 22 | 23 | 24 | def clip_grad_value_(parameters, clip_value, norm_type=2): 25 | if isinstance(parameters, torch.Tensor): 26 | parameters = [parameters] 27 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 28 | norm_type = float(norm_type) 29 | if clip_value is not None: 30 | clip_value = float(clip_value) 31 | 32 | total_norm = 0 33 | for p in parameters: 34 | param_norm = p.grad.data.norm(norm_type) 35 | total_norm += param_norm.item() ** norm_type 36 | if clip_value is not None: 37 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 38 | total_norm = total_norm ** (1. / norm_type) 39 | return total_norm 40 | -------------------------------------------------------------------------------- /configs/periodwave_22050hz.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 50000, 5 | "save_interval": 10000, 6 | "seed": 1234, 7 | "epochs": 20000, 8 | "learning_rate": 2e-4, 9 | "batch_size": 32, 10 | "fp16_run": false, 11 | "segment_size": 32768 12 | }, 13 | 14 | "data": { 15 | "train_filelist_path": "filelist_lj/train_wav.txt", 16 | "test_filelist_path": "filelist_lj/val_wav.txt", 17 | "max_wav_value": 32768.0, 18 | "sampling_rate": 22050, 19 | "filter_length": 1024, 20 | "hop_length": 256, 21 | "win_length": 1024, 22 | "n_mel_channels": 80, 23 | "mel_fmin": 0, 24 | "mel_fmax": 8000, 25 | "energy_max": "stats_lj_22050hz/energy_max_train.npy", 26 | "energy_min": "stats_lj_22050hz/energy_min_train.npy" 27 | }, 28 | 29 | "model": { 30 | "periods": [1,2,3,5,7], 31 | "noise_scale": 0.25, 32 | "final_dim": 32, 33 | "hidden_dim":512 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /configs/periodwave_24000hz.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 50000, 5 | "save_interval": 10000, 6 | "seed": 1234, 7 | "epochs": 20000, 8 | "learning_rate": 2e-4, 9 | "batch_size": 32, 10 | "fp16_run": false, 11 | "segment_size": 32768 12 | }, 13 | 14 | "data": { 15 | "train_filelist_path": "filelists_24k/train_wav.txt", 16 | "test_filelist_path": "filelists_24k/val_wav.txt", 17 | "max_wav_value": 32768.0, 18 | "sampling_rate": 24000, 19 | "filter_length": 1024, 20 | "hop_length": 256, 21 | "win_length": 1024, 22 | "n_mel_channels": 100, 23 | "mel_fmin": 0, 24 | "mel_fmax": 12000, 25 | "energy_max": "stats_libritts_24000hz/energy_max_train.npy", 26 | "energy_min": "stats_libritts_24000hz/energy_min_train.npy" 27 | }, 28 | 29 | "model": { 30 | "periods": [1,2,3,5,7], 31 | "noise_scale": 0.5, 32 | "final_dim": 32, 33 | "hidden_dim":512 34 
| } 35 | } 36 | -------------------------------------------------------------------------------- /configs/periodwave_encodec.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000000, 5 | "save_interval": 50000, 6 | "seed": 1234, 7 | "epochs": 20000, 8 | "learning_rate": 2e-4, 9 | "batch_size": 32, 10 | "fp16_run": false, 11 | "segment_size": 48000 12 | }, 13 | "data": { 14 | "train_filelist_path": "filelists_24k/train_wav.txt", 15 | "test_filelist_path": "filelists_24k/val_wav.txt", 16 | "text_cleaners":["english_cleaners2"], 17 | "max_wav_value": 32768.0, 18 | "sampling_rate": 24000, 19 | "filter_length": 1280, 20 | "hop_length": 320, 21 | "win_length": 1280, 22 | "n_mel_channels": 100, 23 | "mel_fmin": 0, 24 | "mel_fmax": 12000 25 | }, 26 | "model": { 27 | "periods": [1,2,3,5,7], 28 | "noise_scale": 0.25, 29 | "final_dim": 32, 30 | "hidden_dim": 512 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /configs/periodwave_encodec_turbo.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000000, 5 | "save_interval": 10000, 6 | "seed": 1234, 7 | "epochs": 20000, 8 | "learning_rate": 2e-5, 9 | "batch_size": 8, 10 | "fp16_run": false, 11 | "segment_size": 48000, 12 | "tuning_steps": 4, 13 | "finetuning_temperature": 1, 14 | "w_stft":45, 15 | "pretrain_path": "./logs/periodwave_encodec/G_950000.pth" 16 | }, 17 | "data": { 18 | "train_filelist_path": "filelists_24k/train_wav.txt", 19 | "test_filelist_path": "filelists_24k/val_wav.txt", 20 | "text_cleaners":["english_cleaners2"], 21 | "max_wav_value": 32768.0, 22 | "sampling_rate": 24000, 23 | "filter_length": 1280, 24 | "hop_length": 320, 25 | "win_length": 1280, 26 | "n_mel_channels": 100, 27 | "mel_fmin": 0, 28 | "mel_fmax": 12000 29 | }, 30 | "model": { 31 | "periods": [1,2,3,5,7], 32 | "noise_scale": 0.25, 33 | "final_dim": 32, 34 | "hidden_dim": 512 35 | }, 36 | "discriminator": { 37 | "cqtd_filters": 128, 38 | "cqtd_max_filters": 1024, 39 | "cqtd_filters_scale": 1, 40 | "cqtd_dilations": [1, 2, 4], 41 | "cqtd_hop_lengths": [512, 256, 256], 42 | "cqtd_n_octaves": [9, 9, 9], 43 | "cqtd_bins_per_octaves": [24, 36, 48], 44 | "sampling_rate":24000, 45 | "mpd_reshapes": [2, 3, 5, 7, 11], 46 | "use_spectral_norm": false, 47 | "discriminator_channel_mult": 1 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /configs/periodwave_turbo_22050hz.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 10000, 5 | "save_interval": 10000, 6 | "seed": 1234, 7 | "epochs": 20000, 8 | "learning_rate": 2e-5, 9 | "batch_size": 8, 10 | "fp16_run": false, 11 | "segment_size": 32768, 12 | "tuning_steps": 4, 13 | "finetuning_temperature": 1, 14 | "w_stft":45, 15 | "pretrain_path": "./logs/periodwave_lj_22050hz/G_1000000.pth" 16 | }, 17 | 18 | "data": { 19 | "train_filelist_path": "filelist_lj/train_wav.txt", 20 | "test_filelist_path": "filelist_lj/val_wav.txt", 21 | "max_wav_value": 32768.0, 22 | "sampling_rate": 22050, 23 | "filter_length": 1024, 24 | "hop_length": 256, 25 | "win_length": 1024, 26 | "n_mel_channels": 80, 27 | "mel_fmin": 0, 28 | "mel_fmax": 8000, 29 | "energy_max": "stats_lj_22050hz/energy_max_train.npy", 30 | "energy_min": 
"stats_lj_22050hz/energy_min_train.npy" 31 | }, 32 | 33 | "model": { 34 | "periods": [1,2,3,5,7], 35 | "noise_scale": 0.25, 36 | "final_dim": 32, 37 | "hidden_dim":512 38 | }, 39 | 40 | "discriminator": { 41 | "cqtd_filters": 128, 42 | "cqtd_max_filters": 1024, 43 | "cqtd_filters_scale": 1, 44 | "cqtd_dilations": [1, 2, 4], 45 | "cqtd_hop_lengths": [512, 256, 256], 46 | "cqtd_n_octaves": [9, 9, 9], 47 | "cqtd_bins_per_octaves": [24, 36, 48], 48 | "sampling_rate":22050, 49 | "mpd_reshapes": [2, 3, 5, 7, 11], 50 | "use_spectral_norm": false, 51 | "discriminator_channel_mult": 1 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /configs/periodwave_turbo_24000hz.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 10000, 5 | "save_interval": 10000, 6 | "seed": 1234, 7 | "epochs": 20000, 8 | "learning_rate": 2e-5, 9 | "batch_size": 8, 10 | "fp16_run": false, 11 | "segment_size": 32768, 12 | "tuning_steps": 4, 13 | "finetuning_temperature": 1, 14 | "w_stft":45, 15 | "pretrain_path": "./logs/periodwave_libritts_24000hz/G_1000000.pth" 16 | }, 17 | 18 | "data": { 19 | "train_filelist_path": "filelists_24k/train_wav.txt", 20 | "test_filelist_path": "filelists_24k/val_wav.txt", 21 | "max_wav_value": 32768.0, 22 | "sampling_rate": 24000, 23 | "filter_length": 1024, 24 | "hop_length": 256, 25 | "win_length": 1024, 26 | "n_mel_channels": 100, 27 | "mel_fmin": 0, 28 | "mel_fmax": 12000, 29 | "energy_max": "stats_libritts_24000hz/energy_max_train.npy", 30 | "energy_min": "stats_libritts_24000hz/energy_min_train.npy" 31 | }, 32 | 33 | "model": { 34 | "periods": [1,2,3,5,7], 35 | "noise_scale": 0.5, 36 | "final_dim": 32, 37 | "hidden_dim":512 38 | }, 39 | 40 | "discriminator": { 41 | "cqtd_filters": 128, 42 | "cqtd_max_filters": 1024, 43 | "cqtd_filters_scale": 1, 44 | "cqtd_dilations": [1, 2, 4], 45 | "cqtd_hop_lengths": [512, 256, 256], 46 | "cqtd_n_octaves": [9, 9, 9], 47 | "cqtd_bins_per_octaves": [24, 36, 48], 48 | "sampling_rate":24000, 49 | "mpd_reshapes": [2, 3, 5, 7, 11], 50 | "use_spectral_norm": false, 51 | "discriminator_channel_mult": 1 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /dataset_codec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | 4 | # Adapted from https://github.com/jik876/hifi-gan under the MIT license. 5 | # LICENSE is in incl_licenses directory. 6 | 7 | import math 8 | import os 9 | import random 10 | import torch 11 | import torch.utils.data 12 | import numpy as np 13 | from librosa.util import normalize 14 | from scipy.io.wavfile import read 15 | from librosa.filters import mel as librosa_mel_fn 16 | import pathlib 17 | from tqdm import tqdm 18 | 19 | MAX_WAV_VALUE = 32768.0 20 | 21 | 22 | def load_wav(full_path, sr_target): 23 | sampling_rate, data = read(full_path) 24 | if sampling_rate != sr_target: 25 | raise RuntimeError("Sampling rate of the file {} is {} Hz, but the model requires {} Hz". 
26 | format(full_path, sampling_rate, sr_target)) 27 | return data, sampling_rate 28 | 29 | 30 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 31 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 32 | 33 | 34 | def dynamic_range_decompression(x, C=1): 35 | return np.exp(x) / C 36 | 37 | 38 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 39 | return torch.log(torch.clamp(x, min=clip_val) * C) 40 | 41 | 42 | def dynamic_range_decompression_torch(x, C=1): 43 | return torch.exp(x) / C 44 | 45 | 46 | def spectral_normalize_torch(magnitudes): 47 | output = dynamic_range_compression_torch(magnitudes) 48 | return output 49 | 50 | 51 | def spectral_de_normalize_torch(magnitudes): 52 | output = dynamic_range_decompression_torch(magnitudes) 53 | return output 54 | 55 | 56 | mel_basis = {} 57 | hann_window = {} 58 | 59 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 60 | if torch.min(y) < -1.: 61 | print('min value is ', torch.min(y)) 62 | if torch.max(y) > 1.: 63 | print('max value is ', torch.max(y)) 64 | 65 | global mel_basis, hann_window 66 | if str(fmax)+'_'+str(y.device) not in mel_basis: # cache the mel filterbank and window per (fmax, device) 67 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 68 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 69 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 70 | 71 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 72 | y = y.squeeze(1) 73 | 74 | # complex tensor as default, then use view_as_real for future pytorch compatibility 75 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 76 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 77 | spec = torch.view_as_real(spec) 78 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 79 | 80 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) 81 | spec = spectral_normalize_torch(spec) 82 | 83 | return spec 84 | 85 | def parse_filelist(filelist_path): 86 | with open(filelist_path, 'r') as f: 87 | filelist = [line.strip() for line in f.readlines()] 88 | return filelist 89 | 90 | class MelDataset(torch.utils.data.Dataset): 91 | def __init__(self, training_files, hparams, segment_size, n_fft, num_mels, 92 | hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1, 93 | device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None): 94 | self.audio_files = parse_filelist(training_files) 95 | random.seed(1234) 96 | if shuffle: 97 | random.shuffle(self.audio_files) 98 | self.hparams = hparams 99 | self.segment_size = segment_size 100 | self.sampling_rate = sampling_rate 101 | self.split = split 102 | self.n_fft = n_fft 103 | self.num_mels = num_mels 104 | self.hop_size = hop_size 105 | self.win_size = win_size 106 | self.fmin = fmin 107 | self.fmax = fmax 108 | self.fmax_loss = fmax_loss 109 | self.cached_wav = None 110 | self.n_cache_reuse = n_cache_reuse 111 | self._cache_ref_count = 0 112 | self.device = device 113 | self.fine_tuning = fine_tuning 114 | self.base_mels_path = base_mels_path 115 | 116 | def __getitem__(self, index): 117 | 118 | filename = self.audio_files[index] 119 | if self._cache_ref_count == 0: 120 | audio, sampling_rate = load_wav(filename, self.sampling_rate) 121 | audio = audio / MAX_WAV_VALUE 122 | if not self.fine_tuning: 123 | audio = normalize(audio) * 0.95 124 | self.cached_wav = audio 125 | if
sampling_rate != self.sampling_rate: 126 | raise ValueError("{} SR doesn't match target {} SR".format( 127 | sampling_rate, self.sampling_rate)) 128 | self._cache_ref_count = self.n_cache_reuse 129 | else: 130 | audio = self.cached_wav 131 | self._cache_ref_count -= 1 132 | 133 | audio = torch.FloatTensor(audio) 134 | audio = audio.unsqueeze(0) 135 | 136 | if self.split: 137 | if audio.size(1) >= self.segment_size: 138 | max_audio_start = audio.size(1) - self.segment_size 139 | audio_start = random.randint(0, max_audio_start) 140 | audio = audio[:, audio_start:audio_start+self.segment_size] 141 | audio_length = torch.LongTensor([self.segment_size]) 142 | else: 143 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 144 | audio_length = torch.LongTensor([audio.size(1)]) 145 | 146 | else: # validation step 147 | # match audio length to self.hop_size * n for evaluation 148 | if (audio.size(1) % self.hop_size) != 0: 149 | audio = audio[:, :-(audio.size(1) % self.hop_size)] 150 | audio_length = torch.LongTensor([audio.size(1)]) 151 | 152 | 153 | return (audio.squeeze(0), audio_length) 154 | 155 | def __len__(self): 156 | return len(self.audio_files) -------------------------------------------------------------------------------- /encodec_feature_extractor.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/gemelo-ai/vocos/blob/main/vocos/feature_extractors.py 2 | # Licensed under the MIT license. 3 | from typing import List 4 | 5 | import torch 6 | import torchaudio 7 | from encodec import EncodecModel 8 | from torch import nn 9 | 10 | class EncodecFeatures(nn.Module): 11 | def __init__( 12 | self, 13 | encodec_model: str = "encodec_24khz", 14 | bandwidth: float = 6.0, 15 | train_codebooks: bool = False, 16 | ): 17 | super().__init__() 18 | if encodec_model == "encodec_24khz": 19 | encodec = EncodecModel.encodec_model_24khz 20 | else: 21 | raise ValueError( 22 | f"Unsupported encodec_model: {encodec_model}. The only supported option is 'encodec_24khz'." 23 | ) 24 | self.encodec = encodec(pretrained=True) 25 | for param in self.encodec.parameters(): 26 | param.requires_grad = False 27 | self.encodec.eval() 28 | 29 | self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth( 30 | self.encodec.frame_rate, bandwidth=bandwidth 31 | ) 32 | codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0) 33 | self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks) 34 | self.bandwidth = bandwidth 35 | self.encodec.eval() # Force eval mode as Pytorch Lightning automatically sets child modules to training mode 36 | self.encodec.set_target_bandwidth(self.bandwidth) 37 | 38 | @torch.no_grad() 39 | def get_encodec_codes(self, audio): 40 | audio = audio.unsqueeze(1) 41 | emb = self.encodec.encoder(audio) 42 | codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth) 43 | return codes 44 | 45 | def forward(self, audio: torch.Tensor): 46 | 47 | codes = self.get_encodec_codes(audio) 48 | # Instead of summing codebook embeddings in a loop over quantizers, all VQ codebooks are stacked into a single `self.codebook_weights` 49 | # table with per-quantizer offsets given by the number of bins, and the result is summed in one vectorized operation.
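        # `codes` has shape [num_q, B, T]; adding an offset of `bins` per quantizer maps each
        # code to a unique row of the concatenated codebook table, so a single embedding lookup
        # followed by a sum over the quantizer axis reproduces the summed VQ embeddings.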
50 | offsets = torch.arange( 51 | 0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device 52 | ) 53 | embeddings_idxs = codes + offsets.view(-1, 1, 1) 54 | features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0) 55 | return features.transpose(1, 2) 56 | -------------------------------------------------------------------------------- /extract_energy.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | import torch.utils.data 6 | import numpy as np 7 | from librosa.util import normalize 8 | from scipy.io.wavfile import read 9 | from librosa.filters import mel as librosa_mel_fn 10 | import pathlib 11 | from tqdm import tqdm 12 | 13 | MAX_WAV_VALUE = 32768.0 14 | def load_wav(full_path, sr_target): 15 | sampling_rate, data = read(full_path) 16 | if sampling_rate != sr_target: 17 | raise RuntimeError("Sampling rate of the file {} is {} Hz, but the model requires {} Hz". 18 | format(full_path, sampling_rate, sr_target)) 19 | return data, sampling_rate 20 | 21 | 22 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 23 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 24 | 25 | 26 | def dynamic_range_decompression(x, C=1): 27 | return np.exp(x) / C 28 | 29 | 30 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 31 | return torch.log(torch.clamp(x, min=clip_val) * C) 32 | 33 | 34 | def dynamic_range_decompression_torch(x, C=1): 35 | return torch.exp(x) / C 36 | 37 | 38 | def spectral_normalize_torch(magnitudes): 39 | output = dynamic_range_compression_torch(magnitudes) 40 | return output 41 | 42 | 43 | def spectral_de_normalize_torch(magnitudes): 44 | output = dynamic_range_decompression_torch(magnitudes) 45 | return output 46 | 47 | 48 | mel_basis = {} 49 | hann_window = {} 50 | 51 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 52 | if torch.min(y) < -1.: 53 | print('min value is ', torch.min(y)) 54 | if torch.max(y) > 1.: 55 | print('max value is ', torch.max(y)) 56 | 57 | global mel_basis, hann_window 58 | if str(fmax)+'_'+str(y.device) not in mel_basis: # cache the mel filterbank and window per (fmax, device) 59 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 60 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 61 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 62 | 63 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 64 | y = y.squeeze(1) 65 | 66 | # complex tensor as default, then use view_as_real for future pytorch compatibility 67 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 68 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 69 | spec = torch.view_as_real(spec) 70 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 71 | 72 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) 73 | spec = spectral_normalize_torch(spec) 74 | 75 | return spec 76 | 77 | def parse_filelist(filelist_path): 78 | with open(filelist_path, 'r') as f: 79 | filelist = [line.strip() for line in f.readlines()] 80 | return filelist 81 | 82 | 83 | audio_files = parse_filelist("filelists_24k/train_wav.txt") 84 | 85 | energy_list = [] 86 | 87 | print("INFO: computing training set waveform statistics for PriorGrad training...") 88 | 89 | for i in tqdm(range(len(audio_files))): 90 | filename =
audio_files[i] 91 | audio, sr = load_wav(filename, 24000) 92 | if 24000 != sr: 93 | raise ValueError(f'Invalid sample rate {sr}.') 94 | audio = audio / MAX_WAV_VALUE 95 | audio = normalize(audio) * 0.95 96 | 97 | audio = torch.FloatTensor(audio).cuda() 98 | audio = audio.unsqueeze(0) 99 | # match audio length to self.hop_size * n for evaluation 100 | if (audio.size(1) % 256) != 0: 101 | audio = audio[:, :-(audio.size(1) % 256)] 102 | 103 | mel = mel_spectrogram(audio, 1024, 100, 104 | 24000, 256, 1024, 0, 12000, 105 | center=False) 106 | assert audio.shape[1] == mel.shape[2] * 256, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 107 | 108 | energy = (mel.exp()).sum(1).sqrt() 109 | 110 | 111 | energy_list.append(energy.squeeze(0)) 112 | 113 | 114 | 115 | energy_list = torch.cat(energy_list) 116 | 117 | 118 | energy_max = energy_list.max().cpu().numpy() 119 | energy_min = energy_list.min().cpu().numpy() 120 | os.makedirs("stats_libritts_24000hz", exist_ok=True) 121 | print("INFO: stats computed: max energy {} min energy {}".format(energy_max, energy_min)) 122 | np.save("stats_libritts_24000hz/energy_max_train.npy", energy_max) 123 | np.save("stats_libritts_24000hz/energy_min_train.npy", energy_min) 124 | 125 | -------------------------------------------------------------------------------- /filelist_gen.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import argparse 3 | import torchaudio 4 | import torch 5 | import tqdm 6 | import os 7 | 8 | def main(args): 9 | 10 | ############################## 11 | 12 | 13 | path = args.input_dir 14 | wavs_train = [] 15 | wavs_train += sorted(glob.glob(path+'/train-clean-100/**/*.wav', recursive=True)) 16 | wavs_train += sorted(glob.glob(path+'/train-clean-360/**/*.wav', recursive=True)) 17 | wavs_train += sorted(glob.glob(path+'/train-other-500/**/*.wav', recursive=True)) 18 | 19 | with open('filelists_24k/train_wav.txt', 'w') as f: 20 | for wav in wavs_train: 21 | f.write(wav+'\n') 22 | 23 | 24 | with open(args.input_validation_file, 'r', encoding='utf-8') as fi: 25 | validation_files = [os.path.join(path, x.split('|')[0] + '.wav') 26 | for x in fi.read().split('\n') if len(x) > 0] 27 | with open(args.input_validation_file2, 'r', encoding='utf-8') as fi: 28 | validation_files += [os.path.join(path, x.split('|')[0] + '.wav') 29 | for x in fi.read().split('\n') if len(x) > 0] 30 | print("first validation file: {}".format(validation_files[0])) 31 | 32 | with open('filelists_24k/val_wav.txt', 'w') as f: 33 | for wav in validation_files: 34 | f.write(wav + '\n') 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('-i', '--input_dir', default='./LibriTTS') 40 | parser.add_argument('--input_validation_file', default='./filelists_24k/dev-clean.txt') 41 | parser.add_argument('--input_validation_file2', default='./filelists_24k/dev-other.txt') 42 | a = parser.parse_args() 43 | 44 | main(a) -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import tqdm 5 | import numpy as np 6 | from torch.nn import functional as F 7 | from scipy.io.wavfile import write 8 | import utils 9 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 10 | from librosa.util import normalize 11 | from model.periodwave import FlowMatch 12 | 13 | h = None 14 | device = None 15 |
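# Example invocation (checkpoint path is illustrative; all hyperparameters are read from the
# config.json saved next to the checkpoint, see main() below):
#   python inference.py --ckpt logs/periodwave_libritts_24000hz/G_1000000.pth --iter 4 --solver euler --noise_scale 1
#
# PeriodWave samples its initial noise from an energy-based prior: the frame-level mel energy
# is min-max normalized with the training-set statistics computed by extract_energy.py and
# used as the per-frame standard deviation of the Gaussian prior (target_std below).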
16 | def get_param_num(model): 17 | num_param = sum(param.numel() for param in model.parameters()) 18 | return num_param 19 | 20 | def inference(a): 21 | torch.manual_seed(1234) 22 | np.random.seed(1234) 23 | 24 | os.makedirs(a.output_dir, exist_ok=True) 25 | model = FlowMatch(hps.data.n_mel_channels, 26 | hps.model.periods, 27 | hps.model.noise_scale).cuda() 28 | 29 | num_param = get_param_num(model) 30 | print('[Model] number of Parameters:', num_param) 31 | 32 | _ = model.eval() 33 | _ = utils.load_checkpoint(a.ckpt, model, None) 34 | 35 | model.estimator.remove_weight_norm() 36 | 37 | energy_max = float(np.load(hps.data.energy_max, allow_pickle=True)) 38 | energy_min = float(np.load(hps.data.energy_min, allow_pickle=True)) 39 | std_min = 0.1 40 | 41 | 42 | 43 | wavs_test = parse_filelist(hps.data.test_filelist_path) 44 | 45 | for source_path in tqdm.tqdm(wavs_test, desc="synthesizing each utterance"): 46 | 47 | audio, _ = load_wav(source_path, hps.data.sampling_rate) 48 | audio = audio / MAX_WAV_VALUE 49 | audio = normalize(audio) * 0.95 50 | 51 | audio = torch.FloatTensor(audio) 52 | audio = audio.unsqueeze(0) 53 | if (audio.size(1) % hps.data.hop_length) != 0: 54 | audio = audio[:, :-(audio.size(1) % hps.data.hop_length)] 55 | 56 | file_name = os.path.splitext(os.path.basename(source_path))[0] 57 | audio = audio.cuda() 58 | 59 | mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.n_mel_channels, 60 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 61 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 62 | energy = (mel.exp()).sum(1).sqrt() 63 | target_std = torch.clamp((energy - energy_min) / (energy_max - energy_min), std_min, None) 64 | target_std = torch.repeat_interleave(target_std, hps.data.hop_length, dim=1) # expand the frame-level prior std to sample level 65 | 66 | assert audio.shape[1] == mel.shape[2] * hps.data.hop_length, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 67 | 68 | with torch.no_grad(): 69 | 70 | resynthesis_audio = model(audio, mel, target_std.unsqueeze(0), n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver) 71 | 72 | if torch.abs(resynthesis_audio).max() >= 0.95: 73 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 74 | 75 | 76 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 77 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 78 | resynthesis_audio = resynthesis_audio.cpu().numpy().astype('int16') 79 | 80 | file_name = os.path.splitext(os.path.basename(source_path))[0] 81 | file_name = "{}.wav".format(file_name) 82 | 83 | output_file = os.path.join('periodwave'+'_'+str(a.solver)+'_'+str(a.iter)+'_'+str(a.noise_scale), file_name) 84 | 85 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 86 | write(output_file, hps.data.sampling_rate, resynthesis_audio) 87 | 88 | def main(): 89 | print('Initializing Inference Process..') 90 | 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument('--input_dir', default='gt') 93 | parser.add_argument('--output_dir', default='test') 94 | parser.add_argument('--ckpt', default='logs/periodwave_turbo_4_msmel_45_mel_gan_2e5/G_274000.pth') 95 | parser.add_argument('--iter', default=4, type=int) 96 | parser.add_argument('--noise_scale', default=1, type=float) 97 | parser.add_argument('--solver', default='euler', help="euler midpoint heun rk4") 98 | a = parser.parse_args() 99 | 100 | global hps, device 101 | hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 102 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 103 |
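    # All data/model hyperparameters come from the config.json stored beside the checkpoint,
    # so --ckpt alone determines the sampling rate, hop length, and mel settings used in inference().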
104 | inference(a) 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /inference_large.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import tqdm 5 | import numpy as np 6 | from torch.nn import functional as F 7 | from scipy.io.wavfile import write 8 | import utils 9 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 10 | from librosa.util import normalize 11 | from model.periodwave_large import FlowMatch 12 | 13 | h = None 14 | device = None 15 | 16 | def get_param_num(model): 17 | num_param = sum(param.numel() for param in model.parameters()) 18 | return num_param 19 | 20 | def inference(a): 21 | torch.manual_seed(1234) 22 | np.random.seed(1234) 23 | 24 | os.makedirs(a.output_dir, exist_ok=True) 25 | model = FlowMatch(hps.data.n_mel_channels, 26 | hps.model.periods, 27 | hps.model.noise_scale).cuda() 28 | 29 | num_param = get_param_num(model) 30 | print('[Model] number of Parameters:', num_param) 31 | 32 | _ = model.eval() 33 | _ = utils.load_checkpoint(a.ckpt, model, None) 34 | 35 | model.estimator.remove_weight_norm() 36 | 37 | energy_max = float(np.load(hps.data.energy_max, allow_pickle=True)) 38 | energy_min = float(np.load(hps.data.energy_min, allow_pickle=True)) 39 | std_min = 0.1 40 | 41 | 42 | wavs_test = parse_filelist(hps.data.test_filelist_path) 43 | 44 | for source_path in tqdm.tqdm(wavs_test, desc="synthesizing each utterance"): 45 | 46 | audio, _ = load_wav(source_path, hps.data.sampling_rate) 47 | audio = audio / MAX_WAV_VALUE 48 | audio = normalize(audio) * 0.95 49 | 50 | audio = torch.FloatTensor(audio) 51 | audio = audio.unsqueeze(0) 52 | if (audio.size(1) % hps.data.hop_length) != 0: 53 | audio = audio[:, :-(audio.size(1) % hps.data.hop_length)] 54 | 55 | file_name = os.path.splitext(os.path.basename(source_path))[0] 56 | audio = audio.cuda() 57 | 58 | mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.n_mel_channels, 59 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 60 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 61 | energy = (mel.exp()).sum(1).sqrt() 62 | target_std = torch.clamp((energy - energy_min) / (energy_max - energy_min), std_min, None) 63 | target_std = torch.repeat_interleave(target_std, 256, dim=1) 64 | 65 | assert audio.shape[1] == mel.shape[2] * hps.data.hop_length, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 66 | 67 | with torch.no_grad(): 68 | 69 | resynthesis_audio = model(audio, mel, target_std.unsqueeze(0), n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver) 70 | 71 | if torch.abs(resynthesis_audio).max() >= 0.95: 72 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 73 | 74 | 75 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 76 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 77 | resynthesis_audio = resynthesis_audio.cpu().numpy().astype('int16') 78 | 79 | file_name = os.path.splitext(os.path.basename(source_path))[0] 80 | file_name = "{}.wav".format(file_name) 81 | 82 | output_file = os.path.join('periodwave_turbo_large'+'_'+str(a.solver)+'_'+str(a.iter)+'_'+str(a.noise_scale), file_name) 83 | 84 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 85 | write(output_file, 24000, resynthesis_audio) 86 | 87 | def main(): 88 | print('Initializing Inference Process..') 89 | 90 | parser = 
argparse.ArgumentParser() 91 | parser.add_argument('--input_dir', default='gt') 92 | parser.add_argument('--output_dir', default='test') 93 | parser.add_argument('--ckpt', default='logs/periodwave_turbo_8_large/G_379000.pth') 94 | parser.add_argument('--iter', default=4, type=int) 95 | parser.add_argument('--noise_scale', default=1, type=float) 96 | parser.add_argument('--solver', default='euler', help="euler midpoint heun rk4") 97 | a = parser.parse_args() 98 | 99 | global hps, device 100 | hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 102 | 103 | inference(a) 104 | 105 | if __name__ == '__main__': 106 | main() 107 | -------------------------------------------------------------------------------- /inference_large_with_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from glob import glob 6 | import tqdm 7 | import numpy as np 8 | from torch.nn import functional as F 9 | import commons 10 | from scipy.io.wavfile import write 11 | import torchaudio 12 | import utils 13 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 14 | from librosa.util import normalize 15 | 16 | import auraloss 17 | from pesq import pesq 18 | import torchcrepe 19 | from Eval.pitch_periodicity import from_audio, p_p_F 20 | 21 | from model.periodwave_large import FlowMatch 22 | 23 | h = None 24 | device = None 25 | 26 | def get_param_num(model): 27 | num_param = sum(param.numel() for param in model.parameters()) 28 | return num_param 29 | 30 | def inference(a): 31 | torch.manual_seed(1234) 32 | np.random.seed(1234) 33 | 34 | os.makedirs(a.output_dir, exist_ok=True) 35 | model = FlowMatch(hps.data.n_mel_channels, 36 | hps.model.periods, 37 | hps.model.noise_scale).cuda() 38 | 39 | num_param = get_param_num(model) 40 | print('[Model] number of Parameters:', num_param) 41 | 42 | _ = model.eval() 43 | _ = utils.load_checkpoint(a.ckpt, model, None) 44 | 45 | model.estimator.remove_weight_norm() 46 | 47 | threshold = torchcrepe.threshold.Hysteresis() 48 | 49 | energy_max = float(np.load(hps.data.energy_max, allow_pickle=True)) 50 | energy_min = float(np.load(hps.data.energy_min, allow_pickle=True)) 51 | std_min = 0.1 52 | 53 | predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).cuda() 54 | 55 | pesq_resampler = torchaudio.transforms.Resample(hps.data.sampling_rate, 16000).cuda() 56 | loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") 57 | 58 | 59 | wavs_test = parse_filelist(hps.data.test_filelist_path) 60 | 61 | 62 | i = 0 63 | 64 | mel_error = 0 65 | pesq_wb = 0 66 | pesq_nb = 0 67 | 68 | pitch_total = 0 69 | periodicity_total = 0 70 | f1_total = 0 71 | utmos = 0 72 | val_mrstft_tot = 0 73 | 74 | for source_path in tqdm.tqdm(wavs_test, desc="synthesizing each utterance"): 75 | 76 | audio, _ = load_wav(source_path, hps.data.sampling_rate) 77 | audio = audio / MAX_WAV_VALUE 78 | audio = normalize(audio) * 0.95 79 | 80 | audio = torch.FloatTensor(audio) 81 | audio = audio.unsqueeze(0) 82 | if (audio.size(1) % hps.data.hop_length) != 0: 83 | audio = audio[:, :-(audio.size(1) % hps.data.hop_length)] 84 | 85 | file_name = os.path.splitext(os.path.basename(source_path))[0] 86 | audio = audio.cuda() 87 | 88 | mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.n_mel_channels, 89 | 
hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 90 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 91 | energy = (mel.exp()).sum(1).sqrt() 92 | target_std = torch.clamp((energy - energy_min) / (energy_max - energy_min), std_min, None) 93 | target_std = torch.repeat_interleave(target_std, 256, dim=1) 94 | 95 | assert audio.shape[1] == mel.shape[2] * hps.data.hop_length, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 96 | 97 | with torch.no_grad(): 98 | 99 | resynthesis_audio = model(audio, mel, target_std.unsqueeze(0), n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver) 100 | 101 | # resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 102 | 103 | if torch.abs(resynthesis_audio).max() >= 0.95: 104 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 105 | 106 | mel_hat = mel_spectrogram(resynthesis_audio.squeeze(1), hps.data.filter_length, hps.data.n_mel_channels, 107 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 108 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 109 | 110 | mel_error += F.l1_loss(mel, mel_hat).item() 111 | 112 | y_16k = pesq_resampler(audio) 113 | y_g_hat_16k = pesq_resampler(resynthesis_audio.squeeze(1)) 114 | 115 | hopsize = int(256 * (torchcrepe.SAMPLE_RATE / 24000)) 116 | padding = int((1024 - hopsize) // 2) 117 | 118 | audio_for_pitch = torch.nn.functional.pad( 119 | y_16k[None], 120 | (padding, padding), 121 | mode='reflect').squeeze(0) 122 | 123 | gen_audio_for_pitch = torch.nn.functional.pad( 124 | y_g_hat_16k[None], 125 | (padding, padding), 126 | mode='reflect').squeeze(0) 127 | 128 | ori_audio_len = audio.shape[-1]//256 129 | true_pitch, true_periodicity = from_audio(audio_for_pitch.squeeze(), ori_audio_len, hopsize) 130 | fake_pitch, fake_periodicity = from_audio(gen_audio_for_pitch.squeeze(), ori_audio_len, hopsize) 131 | 132 | pitch, periodicity, f1 = p_p_F(threshold, true_pitch, true_periodicity, fake_pitch, fake_periodicity) 133 | 134 | pitch_total += pitch 135 | f1_total += f1 136 | 137 | periodicity_total += periodicity 138 | 139 | utmos += predictor(y_g_hat_16k, 16000) 140 | y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 141 | y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 142 | 143 | pesq_wb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') 144 | # pesq_nb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'nb') 145 | 146 | # MRSTFT calculation 147 | val_mrstft_tot += loss_mrstft(resynthesis_audio, audio).item() 148 | 149 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 150 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 151 | resynthesis_audio = resynthesis_audio.cpu().numpy().astype('int16') 152 | 153 | file_name = os.path.splitext(os.path.basename(source_path))[0] 154 | file_name = "{}.wav".format(file_name) 155 | 156 | output_file = os.path.join('periodwave_turbo_large_libritts_dev'+'_'+str(a.solver)+'_'+str(a.iter)+'_'+str(a.noise_scale), file_name) 157 | 158 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 159 | write(output_file, 24000, resynthesis_audio) 160 | 161 | i +=1 162 | 163 | mel_error = mel_error/i 164 | pesq_wb = pesq_wb/i 165 | pitch_total = pitch_total/i 166 | periodicity_total = periodicity_total/i 167 | f1_total = f1_total/i 168 | utmos = utmos/i 169 | val_mrstft_tot = val_mrstft_tot/i 170 | 171 | with open(os.path.join('periodwave_turbo_large_libritts_dev'+'_'+str(a.solver)+'_'+str(a.iter)+'_'+str(a.noise_scale), 
'score_list.txt'), 'w') as f: 172 | f.write('periodwave_turbo_large_libritts_dev Solver:{}\nIter: {}\nNoise_scale: {}\n'.format(a.solver, a.iter, a.noise_scale)) 173 | f.write('UTMOS: {}\n'.format(utmos)) 174 | f.write('Mel L1 distance: {}\nMR-STFT: {}\n'.format(mel_error, val_mrstft_tot)) 175 | f.write('PESQ Wide Band: {}\nPESQ Narrow Band {}\n'.format(pesq_wb, pesq_nb)) 176 | f.write('Pitch: {}\nPeriodicity: {}\nV/UV F1: {}\n'.format(pitch_total, periodicity_total, f1_total)) 177 | 178 | 179 | def main(): 180 | print('Initializing Inference Process..') 181 | 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument('--input_dir', default='gt') 184 | parser.add_argument('--output_dir', default='test') 185 | parser.add_argument('--ckpt', default='logs/periodwave_turbo_8_large/G_379000.pth') 186 | parser.add_argument('--iter', default=4, type=int) 187 | parser.add_argument('--noise_scale', default=1, type=float) 188 | parser.add_argument('--solver', default='euler', help="euler midpoint heun rk4") 189 | a = parser.parse_args() 190 | 191 | global hps, device 192 | hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 193 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 194 | 195 | inference(a) 196 | 197 | if __name__ == '__main__': 198 | main() -------------------------------------------------------------------------------- /inference_periodwave_encodec_universal_test_sound.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from glob import glob 6 | import tqdm 7 | import numpy as np 8 | from torch.nn import functional as F 9 | import commons 10 | from scipy.io.wavfile import write 11 | import torchaudio 12 | import utils 13 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 14 | from librosa.util import normalize 15 | 16 | import auraloss 17 | from pesq import pesq 18 | import torchcrepe 19 | from Eval.pitch_periodicity import from_audio, p_p_F 20 | 21 | from model.periodwave_encodec_freeu import FlowMatch 22 | from encodec_feature_extractor import EncodecFeatures 23 | 24 | h = None 25 | device = None 26 | 27 | def get_param_num(model): 28 | num_param = sum(param.numel() for param in model.parameters()) 29 | return num_param 30 | 31 | def inference(a): 32 | torch.manual_seed(1234) 33 | np.random.seed(1234) 34 | 35 | os.makedirs(a.output_dir, exist_ok=True) 36 | model = FlowMatch(hps.data.n_mel_channels, 37 | hps.model.periods, 38 | hps.model.noise_scale, 39 | hps.model.final_dim, 40 | hps.model.hidden_dim,).cuda() 41 | 42 | num_param = get_param_num(model) 43 | print('[Model] number of Parameters:', num_param) 44 | 45 | _ = model.eval() 46 | _ = utils.load_checkpoint(a.ckpt, model, None) 47 | 48 | model.estimator.remove_weight_norm() 49 | 50 | 51 | Encodec = EncodecFeatures(bandwidth=a.bw).cuda() 52 | # 6.0 (Default, we trained the model with the feature of 6.0) 53 | # 1.5, 3.0, 6.0, 12.0 54 | # 12.0 (Not used during training but our model can generate higher quality audio with 12.0) 55 | 56 | pesq_resampler = torchaudio.transforms.Resample(hps.data.sampling_rate, 16000).cuda() 57 | loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") 58 | 59 | wavs_test = [] 60 | wavs_test += sorted(glob('audio_reconstruct_universal_testset_v2/sound_effect/**/*.wav', recursive=True)) 61 | 62 | 63 | i = 0 64 | pitch_eval = False 65 | mel_error = 0 66 | pesq_wb = 0 67 | pesq_nb = 0 
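    # Metric accumulators: each metric is summed per utterance over the test set and averaged after the loop.
    # Mel L1 is additionally reported per band (mel_L/mel_M/mel_H: bins 0-60, 60-80, 80-99 of the 100 mel bins).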
68 | val_mrstft_tot = 0 69 | mel_L = 0 70 | mel_M = 0 71 | mel_H = 0 72 | pitch_count= 0 73 | pseq_error = 0 74 | 75 | for source_path in tqdm.tqdm(wavs_test, desc="synthesizing each utterance"): 76 | 77 | audio, _ = load_wav(source_path, hps.data.sampling_rate) 78 | audio = audio / MAX_WAV_VALUE 79 | audio = normalize(audio) * 0.95 80 | 81 | audio = torch.FloatTensor(audio) 82 | audio = audio.unsqueeze(0) 83 | 84 | audio = F.pad(audio, (0, ((audio.size(1) // 3840)+1)*3840 - audio.size(1)), 'constant') 85 | 86 | file_name = os.path.splitext(os.path.basename(source_path))[0] 87 | audio = audio.cuda() 88 | 89 | mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.n_mel_channels, 90 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 91 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 92 | 93 | 94 | assert audio.shape[1] == mel.shape[2] * hps.data.hop_length, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 95 | 96 | with torch.no_grad(): 97 | embs = Encodec(audio) 98 | resynthesis_audio = model(audio, embs, n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver, sway=a.sway, sway_coef=a.sway_coef, s_w=a.s_w, b_w=a.b_w) 99 | 100 | # resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 101 | 102 | if torch.abs(resynthesis_audio).max() >= 0.95: 103 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 104 | 105 | mel_hat = mel_spectrogram(resynthesis_audio.squeeze(1), hps.data.filter_length, hps.data.n_mel_channels, 106 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 107 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 108 | 109 | mel_error += F.l1_loss(mel, mel_hat).item() 110 | 111 | mel_L +=F.l1_loss(mel[:,:61,:], mel_hat[:,:61,:]).item() 112 | mel_M +=F.l1_loss(mel[:,60:81,:], mel_hat[:,60:81,:]).item() 113 | mel_H +=F.l1_loss(mel[:,80:100,:], mel_hat[:,80:100,:]).item() 114 | 115 | y_16k = pesq_resampler(audio) 116 | y_g_hat_16k = pesq_resampler(resynthesis_audio.squeeze(1)) 117 | 118 | y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 119 | y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 120 | 121 | 122 | try: 123 | pesq_wb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') 124 | except: 125 | pseq_error +=1 126 | # pesq_nb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'nb') 127 | 128 | # MRSTFT calculation 129 | val_mrstft_tot += loss_mrstft(resynthesis_audio, audio).item() 130 | 131 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 132 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 133 | resynthesis_audio = resynthesis_audio.cpu().numpy().astype('int16') 134 | 135 | file_name = os.path.splitext(os.path.basename(source_path))[0] 136 | file_name = "{}.wav".format(file_name) 137 | 138 | output_file = os.path.join('periodwave_encodec_base_turbo_final_590k_rfwave_sound'+'_'+str(a.solver)+'_'+str(a.bw)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w)+'_sway_'+str(a.sway), file_name) 139 | 140 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 141 | write(output_file, 24000, resynthesis_audio) 142 | 143 | i +=1 144 | 145 | mel_error = mel_error/i 146 | pesq_wb = pesq_wb/(i-pseq_error) 147 | 148 | val_mrstft_tot = val_mrstft_tot/i 149 | mel_L = mel_L/i 150 | mel_M = mel_M/i 151 | mel_H = mel_H/i 152 | 153 | with 
open(os.path.join('periodwave_encodec_base_turbo_final_590k_rfwave_sound'+'_'+str(a.solver)+'_'+str(a.bw)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w)+'_sway_'+str(a.sway), 'score_list.txt'), 'w') as f: 154 | f.write('periodwave_encodec_base_turbo_final_590k_rfwave_sound Solver:{}\nIter: {}\nNoise_scale: {}\n'.format(a.solver, a.iter, a.noise_scale)) 155 | f.write('bw: {}\n'.format(a.bw)) 156 | f.write('s_w: {}\nb_w: {}\n'.format(a.s_w, a.b_w)) 157 | f.write('Sway: {}\n'.format(a.sway)) 158 | f.write('Mel L1 distance: {}\nMR-STFT: {}\n'.format(mel_error, val_mrstft_tot)) 159 | f.write('PESQ Wide Band: {}\nPESQ Narrow Band {}\n'.format(pesq_wb, pesq_nb)) 160 | f.write('mel_L: {}\nmel_M: {}\nmel_H: {}\n'.format(mel_L, mel_M, mel_H)) 161 | 162 | 163 | def main(): 164 | print('Initializing Inference Process..') 165 | 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument('--input_dir', default='gt') 168 | parser.add_argument('--output_dir', default='test') 169 | parser.add_argument('--ckpt', default='logs/periodwave_encodec_turbo_universe_mel45_from_speechonly470k/G_590000.pth') 170 | parser.add_argument('--bw', default=6.0, type=float) 171 | parser.add_argument('--iter', default=4, type=int) 172 | parser.add_argument('--noise_scale', default=1, type=float) 173 | parser.add_argument('--solver', default='euler', help="euler midpoint heun rk4") 174 | parser.add_argument('--s_w', default=1, type=float) 175 | parser.add_argument('--b_w', default=1, type=float) 176 | parser.add_argument('--sway', default=False, type=bool) 177 | parser.add_argument('--sway_coef', default=-1.0, type=float) 178 | a = parser.parse_args() 179 | 180 | global hps, device 181 | hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 182 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 183 | 184 | inference(a) 185 | 186 | if __name__ == '__main__': 187 | main() -------------------------------------------------------------------------------- /inference_periodwave_encodec_universal_test_sound_step2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from glob import glob 6 | import tqdm 7 | import numpy as np 8 | from torch.nn import functional as F 9 | import commons 10 | from scipy.io.wavfile import write 11 | import torchaudio 12 | import utils 13 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 14 | from librosa.util import normalize 15 | 16 | import auraloss 17 | from pesq import pesq 18 | import torchcrepe 19 | from Eval.pitch_periodicity import from_audio, p_p_F 20 | 21 | from model.periodwave_encodec_freeu import FlowMatch 22 | from encodec_feature_extractor import EncodecFeatures 23 | 24 | h = None 25 | device = None 26 | 27 | def get_param_num(model): 28 | num_param = sum(param.numel() for param in model.parameters()) 29 | return num_param 30 | 31 | def inference(a): 32 | torch.manual_seed(1234) 33 | np.random.seed(1234) 34 | 35 | os.makedirs(a.output_dir, exist_ok=True) 36 | model = FlowMatch(hps.data.n_mel_channels, 37 | hps.model.periods, 38 | hps.model.noise_scale, 39 | hps.model.final_dim, 40 | hps.model.hidden_dim,).cuda() 41 | 42 | num_param = get_param_num(model) 43 | print('[Model] number of Parameters:', num_param) 44 | 45 | _ = model.eval() 46 | _ = utils.load_checkpoint(a.ckpt, model, None) 47 | 48 | model.estimator.remove_weight_norm() 49 | 50 | Encodec = 
EncodecFeatures(bandwidth=a.bw).cuda() 51 | # 6.0 (Default, we trained the model with the feature of 6.0) 52 | # 1.5, 3.0, 6.0, 12.0 53 | # 12.0 (Not used during training but our model can generate higher quality audio with 12.0) 54 | 55 | pesq_resampler = torchaudio.transforms.Resample(hps.data.sampling_rate, 16000).cuda() 56 | loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") 57 | 58 | wavs_test = [] 59 | wavs_test += sorted(glob('audio_reconstruct_universal_testset_v2/sound_effect/**/*.wav', recursive=True)) 60 | 61 | i = 0 62 | pitch_eval = False 63 | mel_error = 0 64 | pesq_wb = 0 65 | pesq_nb = 0 66 | 67 | val_mrstft_tot = 0 68 | mel_L = 0 69 | mel_M = 0 70 | mel_H = 0 71 | pitch_count= 0 72 | pseq_error = 0 73 | 74 | for source_path in tqdm.tqdm(wavs_test, desc="synthesizing each utterance"): 75 | 76 | audio, _ = load_wav(source_path, hps.data.sampling_rate) 77 | audio = audio / MAX_WAV_VALUE 78 | audio = normalize(audio) * 0.95 79 | 80 | audio = torch.FloatTensor(audio) 81 | audio = audio.unsqueeze(0) 82 | 83 | audio = F.pad(audio, (0, ((audio.size(1) // 3840)+1)*3840 - audio.size(1)), 'constant') 84 | 85 | file_name = os.path.splitext(os.path.basename(source_path))[0] 86 | audio = audio.cuda() 87 | 88 | mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.n_mel_channels, 89 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 90 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 91 | 92 | 93 | assert audio.shape[1] == mel.shape[2] * hps.data.hop_length, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 94 | 95 | with torch.no_grad(): 96 | embs = Encodec(audio) 97 | resynthesis_audio = model(audio, embs, n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver, sway=a.sway, sway_coef=a.sway_coef, s_w=a.s_w, b_w=a.b_w) 98 | 99 | # resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 100 | 101 | if torch.abs(resynthesis_audio).max() >= 0.95: 102 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 103 | 104 | mel_hat = mel_spectrogram(resynthesis_audio.squeeze(1), hps.data.filter_length, hps.data.n_mel_channels, 105 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 106 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 107 | 108 | mel_error += F.l1_loss(mel, mel_hat).item() 109 | 110 | mel_L +=F.l1_loss(mel[:,:61,:], mel_hat[:,:61,:]).item() 111 | mel_M +=F.l1_loss(mel[:,60:81,:], mel_hat[:,60:81,:]).item() 112 | mel_H +=F.l1_loss(mel[:,80:100,:], mel_hat[:,80:100,:]).item() 113 | 114 | y_16k = pesq_resampler(audio) 115 | y_g_hat_16k = pesq_resampler(resynthesis_audio.squeeze(1)) 116 | 117 | y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 118 | y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 119 | 120 | try: 121 | pesq_wb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') 122 | except: 123 | pseq_error +=1 124 | # pesq_nb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'nb') 125 | 126 | # MRSTFT calculation 127 | val_mrstft_tot += loss_mrstft(resynthesis_audio, audio).item() 128 | 129 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 130 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 131 | resynthesis_audio = resynthesis_audio.cpu().numpy().astype('int16') 132 | 133 | file_name = os.path.splitext(os.path.basename(source_path))[0] 134 | file_name = "{}.wav".format(file_name) 135 | 136 | output_file = 
os.path.join('periodwave_encodec_base_turbo_step2_400k_rfwave_sound'+'_'+str(a.solver)+'_'+str(a.bw)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w)+'_sway_'+str(a.sway), file_name) 137 | 138 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 139 | write(output_file, 24000, resynthesis_audio) 140 | 141 | i +=1 142 | 143 | mel_error = mel_error/i 144 | pesq_wb = pesq_wb/(i-pseq_error) 145 | val_mrstft_tot = val_mrstft_tot/i 146 | mel_L = mel_L/i 147 | mel_M = mel_M/i 148 | mel_H = mel_H/i 149 | 150 | with open(os.path.join('periodwave_encodec_base_turbo_step2_400k_rfwave_sound'+'_'+str(a.solver)+'_'+str(a.bw)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w)+'_sway_'+str(a.sway), 'score_list.txt'), 'w') as f: 151 | f.write('periodwave_encodec_base_turbo_step2_400k_rfwave_sound Solver:{}\nIter: {}\nNoise_scale: {}\n'.format(a.solver, a.iter, a.noise_scale)) 152 | f.write('bw: {}\n'.format(a.bw)) 153 | f.write('s_w: {}\nb_w: {}\n'.format(a.s_w, a.b_w)) 154 | f.write('Sway: {}\n'.format(a.sway)) 155 | f.write('Mel L1 distance: {}\nMR-STFT: {}\n'.format(mel_error, val_mrstft_tot)) 156 | f.write('PESQ Wide Band: {}\nPESQ Narrow Band {}\n'.format(pesq_wb, pesq_nb)) 157 | f.write('mel_L: {}\nmel_M: {}\nmel_H: {}\n'.format(mel_L, mel_M, mel_H)) 158 | 159 | 160 | def main(): 161 | print('Initializing Inference Process..') 162 | 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument('--input_dir', default='gt') 165 | parser.add_argument('--output_dir', default='test') 166 | parser.add_argument('--ckpt', default='logs/periodwave_encodec_turbo_universe_cont_step2/G_400000.pth') 167 | parser.add_argument('--bw', default=6.0, type=float) 168 | parser.add_argument('--iter', default=2, type=int) 169 | parser.add_argument('--noise_scale', default=1, type=float) 170 | parser.add_argument('--solver', default='euler', help="euler midpoint heun rk4") 171 | parser.add_argument('--s_w', default=1, type=float) 172 | parser.add_argument('--b_w', default=1, type=float) 173 | parser.add_argument('--sway', default=False, type=bool) 174 | parser.add_argument('--sway_coef', default=-1.0, type=float) 175 | a = parser.parse_args() 176 | 177 | global hps, device 178 | hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 179 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 180 | 181 | inference(a) 182 | 183 | if __name__ == '__main__': 184 | main() -------------------------------------------------------------------------------- /inference_periodwave_encodec_universal_test_speech_step2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from glob import glob 6 | import tqdm 7 | import numpy as np 8 | from torch.nn import functional as F 9 | import commons 10 | from scipy.io.wavfile import write 11 | import torchaudio 12 | import utils 13 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 14 | from librosa.util import normalize 15 | 16 | import auraloss 17 | from pesq import pesq 18 | import torchcrepe 19 | from Eval.pitch_periodicity import from_audio, p_p_F 20 | 21 | from model.periodwave_encodec_freeu import FlowMatch 22 | from encodec_feature_extractor import EncodecFeatures 23 | 24 | h = None 25 | device = None 26 | 27 | def get_param_num(model): 28 | num_param = sum(param.numel() for param in model.parameters()) 29 | return num_param 30 | 31 | def 
inference(a): 32 | torch.manual_seed(1234) 33 | np.random.seed(1234) 34 | 35 | os.makedirs(a.output_dir, exist_ok=True) 36 | model = FlowMatch(hps.data.n_mel_channels, 37 | hps.model.periods, 38 | hps.model.noise_scale, 39 | hps.model.final_dim, 40 | hps.model.hidden_dim,).cuda() 41 | 42 | num_param = get_param_num(model) 43 | print('[Model] number of Parameters:', num_param) 44 | 45 | _ = model.eval() 46 | _ = utils.load_checkpoint(a.ckpt, model, None) 47 | 48 | model.estimator.remove_weight_norm() 49 | 50 | threshold = torchcrepe.threshold.Hysteresis() 51 | 52 | Encodec = EncodecFeatures(bandwidth=a.bw).cuda() 53 | # 6.0 (Default, we trained the model with the feature of 6.0) 54 | # 1.5, 3.0, 6.0, 12.0 55 | # 12.0 (Not used during training but our model can generate higher quality audio with 12.0) 56 | 57 | pesq_resampler = torchaudio.transforms.Resample(hps.data.sampling_rate, 16000).cuda() 58 | loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") 59 | utmos_predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).cuda() 60 | ssl_mos_predictor = torch.hub.load("unilight/sheet:v0.1.0", "default", trust_repo=True, force_reload=True) 61 | ssl_mos_predictor.model.cuda() 62 | 63 | wavs_test = [] 64 | wavs_test += sorted(glob('audio_reconstruct_universal_testset_v2/speech/**/*.wav', recursive=True)) 65 | 66 | i = 0 67 | pitch_eval = True 68 | mel_error = 0 69 | pesq_wb = 0 70 | pesq_nb = 0 71 | 72 | pitch_total = 0 73 | periodicity_total = 0 74 | f1_total = 0 75 | utmos = 0 76 | ssl_mos = 0 77 | val_mrstft_tot = 0 78 | mel_L = 0 79 | mel_M = 0 80 | mel_H = 0 81 | pitch_count= 0 82 | pseq_error = 0 83 | 84 | for source_path in tqdm.tqdm(wavs_test, desc="synthesizing each utterance"): 85 | 86 | audio, _ = load_wav(source_path, hps.data.sampling_rate) 87 | audio = audio / MAX_WAV_VALUE 88 | audio = normalize(audio) * 0.95 89 | 90 | audio = torch.FloatTensor(audio) 91 | audio = audio.unsqueeze(0) 92 | 93 | audio = F.pad(audio, (0, ((audio.size(1) // 3840)+1)*3840 - audio.size(1)), 'constant') 94 | 95 | file_name = os.path.splitext(os.path.basename(source_path))[0] 96 | audio = audio.cuda() 97 | 98 | mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.n_mel_channels, 99 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 100 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 101 | 102 | 103 | assert audio.shape[1] == mel.shape[2] * hps.data.hop_length, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 104 | 105 | with torch.no_grad(): 106 | embs = Encodec(audio) 107 | resynthesis_audio = model(audio, embs, n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver, sway=a.sway, sway_coef=a.sway_coef, s_w=a.s_w, b_w=a.b_w) 108 | 109 | # resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 110 | 111 | if torch.abs(resynthesis_audio).max() >= 0.95: 112 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 113 | 114 | mel_hat = mel_spectrogram(resynthesis_audio.squeeze(1), hps.data.filter_length, hps.data.n_mel_channels, 115 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 116 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 117 | 118 | mel_error += F.l1_loss(mel, mel_hat).item() 119 | 120 | mel_L +=F.l1_loss(mel[:,:61,:], mel_hat[:,:61,:]).item() 121 | mel_M +=F.l1_loss(mel[:,60:81,:], mel_hat[:,60:81,:]).item() 122 | mel_H +=F.l1_loss(mel[:,80:100,:], mel_hat[:,80:100,:]).item() 123 | 124 | y_16k = pesq_resampler(audio) 
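        # Both the reference and the resynthesized waveforms are resampled to 16 kHz here,
        # because PESQ, UTMOS/SSL-MOS, and torchcrepe all operate on 16 kHz input.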
125 | y_g_hat_16k = pesq_resampler(resynthesis_audio.squeeze(1)) 126 | 127 | utmos += utmos_predictor(y_g_hat_16k, 16000) 128 | ssl_mos += ssl_mos_predictor.predict(wav=y_g_hat_16k.squeeze()) 129 | 130 | y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 131 | y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 132 | 133 | 134 | if pitch_eval == True: 135 | hopsize = int(256 * (torchcrepe.SAMPLE_RATE / 24000)) 136 | padding = int((1024 - hopsize) // 2) 137 | 138 | audio_for_pitch = torch.nn.functional.pad( 139 | y_16k[None], 140 | (padding, padding), 141 | mode='reflect').squeeze(0) 142 | 143 | gen_audio_for_pitch = torch.nn.functional.pad( 144 | y_g_hat_16k[None], 145 | (padding, padding), 146 | mode='reflect').squeeze(0) 147 | 148 | ori_audio_len = audio.shape[-1]//256 149 | true_pitch, true_periodicity = from_audio(audio_for_pitch.squeeze(), ori_audio_len, hopsize) 150 | fake_pitch, fake_periodicity = from_audio(gen_audio_for_pitch.squeeze(), ori_audio_len, hopsize) 151 | 152 | pitch, periodicity, f1 = p_p_F(threshold, true_pitch, true_periodicity, fake_pitch, fake_periodicity) 153 | 154 | pitch_total += pitch 155 | f1_total += f1 156 | 157 | periodicity_total += periodicity 158 | pitch_count +=1 159 | try: 160 | pesq_wb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') 161 | except: 162 | pseq_error +=1 163 | # pesq_nb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'nb') 164 | 165 | # MRSTFT calculation 166 | val_mrstft_tot += loss_mrstft(resynthesis_audio, audio).item() 167 | 168 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 169 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 170 | resynthesis_audio = resynthesis_audio.cpu().numpy().astype('int16') 171 | 172 | file_name = os.path.splitext(os.path.basename(source_path))[0] 173 | file_name = "{}.wav".format(file_name) 174 | 175 | output_file = os.path.join('periodwave_encodec_base_turbo_step2_400k_rfwave_speech'+'_'+str(a.solver)+'_'+str(a.bw)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w)+'_sway_'+str(a.sway), file_name) 176 | 177 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 178 | write(output_file, 24000, resynthesis_audio) 179 | 180 | i +=1 181 | 182 | mel_error = mel_error/i 183 | pesq_wb = pesq_wb/(i-pseq_error) 184 | pitch_total = pitch_total/pitch_count 185 | periodicity_total = periodicity_total/pitch_count 186 | f1_total = f1_total/pitch_count 187 | val_mrstft_tot = val_mrstft_tot/i 188 | mel_L = mel_L/i 189 | mel_M = mel_M/i 190 | mel_H = mel_H/i 191 | utmos = utmos/i 192 | ssl_mos = ssl_mos/i 193 | with open(os.path.join('periodwave_encodec_base_turbo_step2_400k_rfwave_speech'+'_'+str(a.solver)+'_'+str(a.bw)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w)+'_sway_'+str(a.sway), 'score_list.txt'), 'w') as f: 194 | f.write('periodwave_encodec_base_turbo_step2_400k_rfwave_speech Solver:{}\nIter: {}\nNoise_scale: {}\n'.format(a.solver, a.iter, a.noise_scale)) 195 | f.write('bw: {}\n'.format(a.bw)) 196 | f.write('s_w: {}\nb_w: {}\n'.format(a.s_w, a.b_w)) 197 | f.write('Sway: {}\n'.format(a.sway)) 198 | f.write('UTMOS: {}\n'.format(utmos.item())) 199 | f.write('SSL-MOS: {}\n'.format(ssl_mos)) 200 | f.write('Mel L1 distance: {}\nMR-STFT: {}\n'.format(mel_error, val_mrstft_tot)) 201 | f.write('PESQ Wide Band: {}\nPESQ Narrow Band {}\n'.format(pesq_wb, pesq_nb)) 202 | f.write('Pitch: {}\nPeriodicity: {}\nV/UV F1: {}\n'.format(pitch_total, periodicity_total, f1_total)) 203 | f.write('mel_L: {}\nmel_M: {}\nmel_H: 
{}\n'.format(mel_L, mel_M, mel_H)) 204 | 205 | 206 | def main(): 207 | print('Initializing Inference Process..') 208 | 209 | parser = argparse.ArgumentParser() 210 | parser.add_argument('--input_dir', default='gt') 211 | parser.add_argument('--output_dir', default='test') 212 | parser.add_argument('--ckpt', default='logs/periodwave_encodec_turbo_universe_cont_step2/G_400000.pth') 213 | parser.add_argument('--bw', default=6.0, type=float) 214 | parser.add_argument('--iter', default=2, type=int) 215 | parser.add_argument('--noise_scale', default=1, type=float) 216 | parser.add_argument('--solver', default='euler', help="euler midpoint heun rk4") 217 | parser.add_argument('--s_w', default=1, type=float) 218 | parser.add_argument('--b_w', default=1, type=float) 219 | parser.add_argument('--sway', default=False, type=bool) 220 | parser.add_argument('--sway_coef', default=-1.0, type=float) 221 | a = parser.parse_args() 222 | 223 | global hps, device 224 | hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 225 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 226 | 227 | inference(a) 228 | 229 | if __name__ == '__main__': 230 | main() -------------------------------------------------------------------------------- /inference_periodwave_encodec_universal_test_vocal.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from glob import glob 6 | import tqdm 7 | import numpy as np 8 | from torch.nn import functional as F 9 | import commons 10 | from scipy.io.wavfile import write 11 | import torchaudio 12 | import utils 13 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 14 | from librosa.util import normalize 15 | 16 | import auraloss 17 | from pesq import pesq 18 | import torchcrepe 19 | from Eval.pitch_periodicity import from_audio, p_p_F 20 | 21 | from model.periodwave_encodec_freeu import FlowMatch 22 | from encodec_feature_extractor import EncodecFeatures 23 | 24 | h = None 25 | device = None 26 | 27 | def get_param_num(model): 28 | num_param = sum(param.numel() for param in model.parameters()) 29 | return num_param 30 | 31 | def inference(a): 32 | torch.manual_seed(1234) 33 | np.random.seed(1234) 34 | 35 | os.makedirs(a.output_dir, exist_ok=True) 36 | model = FlowMatch(hps.data.n_mel_channels, 37 | hps.model.periods, 38 | hps.model.noise_scale, 39 | hps.model.final_dim, 40 | hps.model.hidden_dim,).cuda() 41 | 42 | num_param = get_param_num(model) 43 | print('[Model] number of Parameters:', num_param) 44 | 45 | _ = model.eval() 46 | _ = utils.load_checkpoint(a.ckpt, model, None) 47 | 48 | model.estimator.remove_weight_norm() 49 | 50 | threshold = torchcrepe.threshold.Hysteresis() 51 | 52 | Encodec = EncodecFeatures(bandwidth=a.bw).cuda() 53 | # 6.0 (Default, we trained the model with the feature of 6.0) 54 | # 1.5, 3.0, 6.0, 12.0 55 | # 12.0 (Not used during training but our model can generate higher quality audio with 12.0) 56 | 57 | pesq_resampler = torchaudio.transforms.Resample(hps.data.sampling_rate, 16000).cuda() 58 | loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") 59 | 60 | 61 | wavs_test = [] 62 | wavs_test += sorted(glob('audio_reconstruct_universal_testset_v2/vocals/**/*.wav', recursive=True)) 63 | 64 | 65 | i = 0 66 | pitch_eval = True 67 | mel_error = 0 68 | pesq_wb = 0 69 | pesq_nb = 0 70 | 71 | pitch_total = 0 72 | periodicity_total = 0 73 | f1_total = 0 74 | 
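    # Running sums for the objective metrics; each accumulator is totalled over the
    # test set and divided by the utterance count (i, or pitch_count for the F0 metrics)
    # after the loop.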
75 | val_mrstft_tot = 0 76 | mel_L = 0 77 | mel_M = 0 78 | mel_H = 0 79 | pitch_count= 0 80 | pseq_error = 0 81 | 82 | for source_path in tqdm.tqdm(wavs_test, desc="synthesizing each utterance"): 83 | 84 | audio, _ = load_wav(source_path, hps.data.sampling_rate) 85 | audio = audio / MAX_WAV_VALUE 86 | audio = normalize(audio) * 0.95 87 | 88 | audio = torch.FloatTensor(audio) 89 | audio = audio.unsqueeze(0) 90 | 91 | audio = F.pad(audio, (0, ((audio.size(1) // 3840)+1)*3840 - audio.size(1)), 'constant') 92 | 93 | file_name = os.path.splitext(os.path.basename(source_path))[0] 94 | audio = audio.cuda() 95 | 96 | mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.n_mel_channels, 97 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 98 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 99 | 100 | 101 | assert audio.shape[1] == mel.shape[2] * hps.data.hop_length, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 102 | 103 | with torch.no_grad(): 104 | embs = Encodec(audio) 105 | resynthesis_audio = model(audio, embs, n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver, sway=a.sway, sway_coef=a.sway_coef, s_w=a.s_w, b_w=a.b_w) 106 | 107 | # resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 108 | 109 | if torch.abs(resynthesis_audio).max() >= 0.95: 110 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 111 | 112 | mel_hat = mel_spectrogram(resynthesis_audio.squeeze(1), hps.data.filter_length, hps.data.n_mel_channels, 113 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 114 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 115 | 116 | mel_error += F.l1_loss(mel, mel_hat).item() 117 | 118 | mel_L +=F.l1_loss(mel[:,:61,:], mel_hat[:,:61,:]).item() 119 | mel_M +=F.l1_loss(mel[:,60:81,:], mel_hat[:,60:81,:]).item() 120 | mel_H +=F.l1_loss(mel[:,80:100,:], mel_hat[:,80:100,:]).item() 121 | 122 | y_16k = pesq_resampler(audio) 123 | y_g_hat_16k = pesq_resampler(resynthesis_audio.squeeze(1)) 124 | 125 | y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 126 | y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 127 | 128 | 129 | if pitch_eval == True: 130 | hopsize = int(256 * (torchcrepe.SAMPLE_RATE / 24000)) 131 | padding = int((1024 - hopsize) // 2) 132 | 133 | audio_for_pitch = torch.nn.functional.pad( 134 | y_16k[None], 135 | (padding, padding), 136 | mode='reflect').squeeze(0) 137 | 138 | gen_audio_for_pitch = torch.nn.functional.pad( 139 | y_g_hat_16k[None], 140 | (padding, padding), 141 | mode='reflect').squeeze(0) 142 | 143 | ori_audio_len = audio.shape[-1]//256 144 | true_pitch, true_periodicity = from_audio(audio_for_pitch.squeeze(), ori_audio_len, hopsize) 145 | fake_pitch, fake_periodicity = from_audio(gen_audio_for_pitch.squeeze(), ori_audio_len, hopsize) 146 | 147 | pitch, periodicity, f1 = p_p_F(threshold, true_pitch, true_periodicity, fake_pitch, fake_periodicity) 148 | 149 | pitch_total += pitch 150 | f1_total += f1 151 | 152 | periodicity_total += periodicity 153 | pitch_count +=1 154 | try: 155 | pesq_wb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') 156 | except: 157 | pseq_error +=1 158 | # pesq_nb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'nb') 159 | 160 | # MRSTFT calculation 161 | val_mrstft_tot += loss_mrstft(resynthesis_audio, audio).item() 162 | 163 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 164 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 165 | resynthesis_audio = 
resynthesis_audio.cpu().numpy().astype('int16') 166 | 167 | file_name = os.path.splitext(os.path.basename(source_path))[0] 168 | file_name = "{}.wav".format(file_name) 169 | 170 | output_file = os.path.join('periodwave_encodec_base_turbo_final_590k_rfwave_vocal'+'_'+str(a.solver)+'_'+str(a.bw)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w)+'_sway_'+str(a.sway), file_name) 171 | 172 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 173 | write(output_file, 24000, resynthesis_audio) 174 | 175 | i +=1 176 | 177 | mel_error = mel_error/i 178 | pesq_wb = pesq_wb/(i-pseq_error) 179 | pitch_total = pitch_total/pitch_count 180 | periodicity_total = periodicity_total/pitch_count 181 | f1_total = f1_total/pitch_count 182 | val_mrstft_tot = val_mrstft_tot/i 183 | mel_L = mel_L/i 184 | mel_M = mel_M/i 185 | mel_H = mel_H/i 186 | 187 | with open(os.path.join('periodwave_encodec_base_turbo_final_590k_rfwave_vocal'+'_'+str(a.solver)+'_'+str(a.bw)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w)+'_sway_'+str(a.sway), 'score_list.txt'), 'w') as f: 188 | f.write('periodwave_encodec_base_turbo_final_590k_rfwave_vocal Solver:{}\nIter: {}\nNoise_scale: {}\n'.format(a.solver, a.iter, a.noise_scale)) 189 | f.write('bw: {}\n'.format(a.bw)) 190 | f.write('s_w: {}\nb_w: {}\n'.format(a.s_w, a.b_w)) 191 | f.write('Sway: {}\n'.format(a.sway)) 192 | f.write('Mel L1 distance: {}\nMR-STFT: {}\n'.format(mel_error, val_mrstft_tot)) 193 | f.write('PESQ Wide Band: {}\nPESQ Narrow Band {}\n'.format(pesq_wb, pesq_nb)) 194 | f.write('Pitch: {}\nPeriodicity: {}\nV/UV F1: {}\n'.format(pitch_total, periodicity_total, f1_total)) 195 | f.write('mel_L: {}\nmel_M: {}\nmel_H: {}\n'.format(mel_L, mel_M, mel_H)) 196 | 197 | def main(): 198 | print('Initializing Inference Process..') 199 | 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument('--input_dir', default='gt') 202 | parser.add_argument('--output_dir', default='test') 203 | parser.add_argument('--ckpt', default='logs/periodwave_encodec_turbo_universe_mel45_from_speechonly470k/G_590000.pth') 204 | parser.add_argument('--iter', default=4, type=int) 205 | parser.add_argument('--noise_scale', default=1, type=float) 206 | parser.add_argument('--bw', default=6.0, type=float) 207 | parser.add_argument('--solver', default='euler', help="euler midpoint heun rk4") 208 | parser.add_argument('--s_w', default=1, type=float) 209 | parser.add_argument('--b_w', default=1, type=float) 210 | parser.add_argument('--sway', default=False, type=bool) 211 | parser.add_argument('--sway_coef', default=-1.0, type=float) 212 | a = parser.parse_args() 213 | 214 | global hps, device 215 | hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 216 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 217 | 218 | inference(a) 219 | 220 | if __name__ == '__main__': 221 | main() -------------------------------------------------------------------------------- /inference_periodwave_encodec_universal_test_vocal_step2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from glob import glob 6 | import tqdm 7 | import numpy as np 8 | from torch.nn import functional as F 9 | import commons 10 | from scipy.io.wavfile import write 11 | import torchaudio 12 | import utils 13 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 14 | from librosa.util import 
normalize 15 | 16 | import auraloss 17 | from pesq import pesq 18 | import torchcrepe 19 | from Eval.pitch_periodicity import from_audio, p_p_F 20 | 21 | from model.periodwave_encodec_freeu import FlowMatch 22 | from encodec_feature_extractor import EncodecFeatures 23 | 24 | h = None 25 | device = None 26 | 27 | def get_param_num(model): 28 | num_param = sum(param.numel() for param in model.parameters()) 29 | return num_param 30 | 31 | def inference(a): 32 | torch.manual_seed(1234) 33 | np.random.seed(1234) 34 | 35 | os.makedirs(a.output_dir, exist_ok=True) 36 | model = FlowMatch(hps.data.n_mel_channels, 37 | hps.model.periods, 38 | hps.model.noise_scale, 39 | hps.model.final_dim, 40 | hps.model.hidden_dim,).cuda() 41 | 42 | num_param = get_param_num(model) 43 | print('[Model] number of Parameters:', num_param) 44 | 45 | _ = model.eval() 46 | _ = utils.load_checkpoint(a.ckpt, model, None) 47 | 48 | model.estimator.remove_weight_norm() 49 | 50 | threshold = torchcrepe.threshold.Hysteresis() 51 | 52 | Encodec = EncodecFeatures(bandwidth=a.bw).cuda() 53 | # 6.0 (Default, we trained the model with the feature of 6.0) 54 | # 1.5, 3.0, 6.0, 12.0 55 | # 12.0 (Not used during training but our model can generate higher quality audio with 12.0) 56 | 57 | pesq_resampler = torchaudio.transforms.Resample(hps.data.sampling_rate, 16000).cuda() 58 | loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") 59 | 60 | 61 | wavs_test = [] 62 | wavs_test += sorted(glob('audio_reconstruct_universal_testset_v2/vocals/**/*.wav', recursive=True)) 63 | 64 | 65 | i = 0 66 | pitch_eval = True 67 | mel_error = 0 68 | pesq_wb = 0 69 | pesq_nb = 0 70 | 71 | pitch_total = 0 72 | periodicity_total = 0 73 | f1_total = 0 74 | utmos = 0 75 | ssl_mos = 0 76 | val_mrstft_tot = 0 77 | mel_L = 0 78 | mel_M = 0 79 | mel_H = 0 80 | pitch_count= 0 81 | pseq_error = 0 82 | 83 | for source_path in tqdm.tqdm(wavs_test, desc="synthesizing each utterance"): 84 | 85 | audio, _ = load_wav(source_path, hps.data.sampling_rate) 86 | audio = audio / MAX_WAV_VALUE 87 | audio = normalize(audio) * 0.95 88 | 89 | audio = torch.FloatTensor(audio) 90 | audio = audio.unsqueeze(0) 91 | 92 | audio = F.pad(audio, (0, ((audio.size(1) // 3840)+1)*3840 - audio.size(1)), 'constant') 93 | 94 | file_name = os.path.splitext(os.path.basename(source_path))[0] 95 | audio = audio.cuda() 96 | 97 | mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.n_mel_channels, 98 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 99 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 100 | 101 | 102 | assert audio.shape[1] == mel.shape[2] * hps.data.hop_length, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 103 | 104 | with torch.no_grad(): 105 | embs = Encodec(audio) 106 | resynthesis_audio = model(audio, embs, n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver, sway=a.sway, sway_coef=a.sway_coef, s_w=a.s_w, b_w=a.b_w) 107 | 108 | # resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 109 | 110 | if torch.abs(resynthesis_audio).max() >= 0.95: 111 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 112 | 113 | mel_hat = mel_spectrogram(resynthesis_audio.squeeze(1), hps.data.filter_length, hps.data.n_mel_channels, 114 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 115 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 116 | 117 | mel_error += F.l1_loss(mel, mel_hat).item() 118 | 119 | mel_L +=F.l1_loss(mel[:,:61,:], 
mel_hat[:,:61,:]).item() 120 | mel_M +=F.l1_loss(mel[:,60:81,:], mel_hat[:,60:81,:]).item() 121 | mel_H +=F.l1_loss(mel[:,80:100,:], mel_hat[:,80:100,:]).item() 122 | 123 | y_16k = pesq_resampler(audio) 124 | y_g_hat_16k = pesq_resampler(resynthesis_audio.squeeze(1)) 125 | 126 | y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 127 | y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 128 | 129 | 130 | if pitch_eval == True: 131 | hopsize = int(256 * (torchcrepe.SAMPLE_RATE / 24000)) 132 | padding = int((1024 - hopsize) // 2) 133 | 134 | audio_for_pitch = torch.nn.functional.pad( 135 | y_16k[None], 136 | (padding, padding), 137 | mode='reflect').squeeze(0) 138 | 139 | gen_audio_for_pitch = torch.nn.functional.pad( 140 | y_g_hat_16k[None], 141 | (padding, padding), 142 | mode='reflect').squeeze(0) 143 | 144 | ori_audio_len = audio.shape[-1]//256 145 | true_pitch, true_periodicity = from_audio(audio_for_pitch.squeeze(), ori_audio_len, hopsize) 146 | fake_pitch, fake_periodicity = from_audio(gen_audio_for_pitch.squeeze(), ori_audio_len, hopsize) 147 | 148 | pitch, periodicity, f1 = p_p_F(threshold, true_pitch, true_periodicity, fake_pitch, fake_periodicity) 149 | 150 | pitch_total += pitch 151 | f1_total += f1 152 | 153 | periodicity_total += periodicity 154 | pitch_count +=1 155 | try: 156 | pesq_wb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') 157 | except: 158 | pseq_error +=1 159 | # pesq_nb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'nb') 160 | 161 | # MRSTFT calculation 162 | val_mrstft_tot += loss_mrstft(resynthesis_audio, audio).item() 163 | 164 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 165 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 166 | resynthesis_audio = resynthesis_audio.cpu().numpy().astype('int16') 167 | 168 | file_name = os.path.splitext(os.path.basename(source_path))[0] 169 | file_name = "{}.wav".format(file_name) 170 | 171 | output_file = os.path.join('periodwave_encodec_base_turbo_step2_400k_rfwave_vocal'+'_'+str(a.solver)+'_'+str(a.bw)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w)+'_sway_'+str(a.sway), file_name) 172 | 173 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 174 | write(output_file, 24000, resynthesis_audio) 175 | 176 | i +=1 177 | 178 | mel_error = mel_error/i 179 | pesq_wb = pesq_wb/(i-pseq_error) 180 | pitch_total = pitch_total/pitch_count 181 | periodicity_total = periodicity_total/pitch_count 182 | f1_total = f1_total/pitch_count 183 | val_mrstft_tot = val_mrstft_tot/i 184 | mel_L = mel_L/i 185 | mel_M = mel_M/i 186 | mel_H = mel_H/i 187 | 188 | with open(os.path.join('periodwave_encodec_base_turbo_step2_400k_rfwave_vocal'+'_'+str(a.solver)+'_'+str(a.bw)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w)+'_sway_'+str(a.sway), 'score_list.txt'), 'w') as f: 189 | f.write('periodwave_encodec_base_turbo_step2_400k_rfwave_vocal Solver:{}\nIter: {}\nNoise_scale: {}\n'.format(a.solver, a.iter, a.noise_scale)) 190 | f.write('bw: {}\n'.format(a.bw)) 191 | f.write('s_w: {}\nb_w: {}\n'.format(a.s_w, a.b_w)) 192 | f.write('Sway: {}\n'.format(a.sway)) 193 | f.write('Mel L1 distance: {}\nMR-STFT: {}\n'.format(mel_error, val_mrstft_tot)) 194 | f.write('PESQ Wide Band: {}\nPESQ Narrow Band {}\n'.format(pesq_wb, pesq_nb)) 195 | f.write('Pitch: {}\nPeriodicity: {}\nV/UV F1: {}\n'.format(pitch_total, periodicity_total, f1_total)) 196 | f.write('mel_L: {}\nmel_M: {}\nmel_H: {}\n'.format(mel_L, mel_M, mel_H)) 197 | 198 | 199 | def main(): 200 | 
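    # Entry point: parse the CLI options (checkpoint path, EnCodec bandwidth, ODE solver,
    # number of sampling steps, FreeU weights s_w/b_w, sway sampling), load the matching
    # config.json stored next to the checkpoint, and run inference on the vocal test set.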
print('Initializing Inference Process..') 201 | 202 | parser = argparse.ArgumentParser() 203 | parser.add_argument('--input_dir', default='gt') 204 | parser.add_argument('--output_dir', default='test') 205 | parser.add_argument('--ckpt', default='logs/periodwave_encodec_turbo_universe_cont_step2/G_400000.pth') 206 | parser.add_argument('--bw', default=6.0, type=float) 207 | parser.add_argument('--iter', default=2, type=int) 208 | parser.add_argument('--noise_scale', default=1, type=float) 209 | parser.add_argument('--solver', default='euler', help="euler midpoint heun rk4") 210 | parser.add_argument('--s_w', default=1, type=float) 211 | parser.add_argument('--b_w', default=1, type=float) 212 | parser.add_argument('--sway', default=False, type=bool) 213 | parser.add_argument('--sway_coef', default=-1.0, type=float) 214 | a = parser.parse_args() 215 | 216 | global hps, device 217 | hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 218 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 219 | 220 | inference(a) 221 | 222 | if __name__ == '__main__': 223 | main() -------------------------------------------------------------------------------- /inference_with_FreeU.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import tqdm 5 | import numpy as np 6 | from torch.nn import functional as F 7 | from scipy.io.wavfile import write 8 | import utils 9 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 10 | from librosa.util import normalize 11 | from model.periodwave_freeu import FlowMatch 12 | 13 | h = None 14 | device = None 15 | 16 | def get_param_num(model): 17 | num_param = sum(param.numel() for param in model.parameters()) 18 | return num_param 19 | 20 | def inference(a): 21 | torch.manual_seed(1234) 22 | np.random.seed(1234) 23 | 24 | os.makedirs(a.output_dir, exist_ok=True) 25 | model = FlowMatch(hps.data.n_mel_channels, 26 | hps.model.periods, 27 | hps.model.noise_scale).cuda() 28 | 29 | num_param = get_param_num(model) 30 | print('[Model] number of Parameters:', num_param) 31 | 32 | _ = model.eval() 33 | _ = utils.load_checkpoint(a.ckpt, model, None) 34 | 35 | model.estimator.remove_weight_norm() 36 | 37 | energy_max = float(np.load(hps.data.energy_max, allow_pickle=True)) 38 | energy_min = float(np.load(hps.data.energy_min, allow_pickle=True)) 39 | std_min = 0.1 40 | 41 | 42 | 43 | wavs_test = parse_filelist(hps.data.test_filelist_path) 44 | 45 | # wavs_test = [] 46 | # random.seed(1234) 47 | 48 | # temp1 = glob('/workspace/sb/dataset/LibriTTS_24k/LibriTTS/test-clean/**/*.wav', recursive=True) 49 | # temp2 = glob('/workspace/sb/dataset/LibriTTS_24k/LibriTTS/test-other/**/*.wav', recursive=True) 50 | 51 | # random.shuffle(temp1) 52 | # random.shuffle(temp2) 53 | # wavs_test += temp1[:50] 54 | # wavs_test += temp2[:50] 55 | 56 | for source_path in tqdm.tqdm(wavs_test, desc="synthesizing each utterance"): 57 | 58 | audio, _ = load_wav(source_path, hps.data.sampling_rate) 59 | audio = audio / MAX_WAV_VALUE 60 | audio = normalize(audio) * 0.95 61 | 62 | audio = torch.FloatTensor(audio) 63 | audio = audio.unsqueeze(0) 64 | if (audio.size(1) % hps.data.hop_length) != 0: 65 | audio = audio[:, :-(audio.size(1) % hps.data.hop_length)] 66 | 67 | file_name = os.path.splitext(os.path.basename(source_path))[0] 68 | audio = audio.cuda() 69 | 70 | mel = mel_spectrogram(audio, hps.data.filter_length, 
hps.data.n_mel_channels, 71 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 72 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 73 | energy = (mel.exp()).sum(1).sqrt() 74 | target_std = torch.clamp((energy - energy_min) / (energy_max - energy_min), std_min, None) 75 | target_std = torch.repeat_interleave(target_std, 256, dim=1) 76 | 77 | assert audio.shape[1] == mel.shape[2] * hps.data.hop_length, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 78 | 79 | with torch.no_grad(): 80 | 81 | resynthesis_audio = model(audio, mel, target_std.unsqueeze(0), n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver, s_w=a.s_w, b_w=a.b_w) 82 | 83 | if torch.abs(resynthesis_audio).max() >= 0.95: 84 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 85 | 86 | 87 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 88 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 89 | resynthesis_audio = resynthesis_audio.cpu().numpy().astype('int16') 90 | 91 | file_name = os.path.splitext(os.path.basename(source_path))[0] 92 | file_name = "{}.wav".format(file_name) 93 | 94 | output_file = os.path.join('periodwave'+'_'+str(a.solver)+'_'+str(a.iter)+'_'+str(a.noise_scale)+'_'+str(a.s_w)+'_'+str(a.b_w), file_name) 95 | 96 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 97 | write(output_file, 24000, resynthesis_audio) 98 | 99 | def main(): 100 | print('Initializing Inference Process..') 101 | 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('--input_dir', default='gt') 104 | parser.add_argument('--output_dir', default='test') 105 | parser.add_argument('--ckpt', default='logs/periodwave_libritts_24000hz/G_1000000.pth') 106 | parser.add_argument('--iter', default=16, type=int) 107 | parser.add_argument('--noise_scale', default=0.667, type=float) 108 | parser.add_argument('--solver', default='midpoint', help="euler midpoint heun rk4") 109 | parser.add_argument('--s_w', default=0.9, type=float) 110 | parser.add_argument('--b_w', default=1.1, type=float) 111 | a = parser.parse_args() 112 | 113 | global hps, device 114 | hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 115 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 116 | 117 | inference(a) 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /inference_with_TTS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from glob import glob 5 | import tqdm 6 | import numpy as np 7 | from scipy.io.wavfile import write 8 | import torchaudio 9 | import utils 10 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 11 | 12 | 13 | from model.periodwave import FlowMatch 14 | 15 | h = None 16 | device = None 17 | 18 | def get_param_num(model): 19 | num_param = sum(param.numel() for param in model.parameters()) 20 | return num_param 21 | 22 | def inference(a): 23 | torch.manual_seed(1234) 24 | np.random.seed(1234) 25 | 26 | os.makedirs(a.output_dir, exist_ok=True) 27 | model = FlowMatch(hps.data.n_mel_channels, 28 | hps.model.periods, 29 | hps.model.noise_scale).cuda() 30 | 31 | num_param = get_param_num(model) 32 | print('[Model] number of Parameters:', num_param) 33 | 34 | _ = model.eval() 35 | _ = utils.load_checkpoint(a.ckpt, model, None) 36 | 37 | 
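    # Fold weight normalization back into the plain convolution weights for faster
    # inference; this mutates the estimator in place, so it must run after the
    # checkpoint has been loaded.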
model.estimator.remove_weight_norm() 38 | 39 | energy_max = float(np.load(hps.data.energy_max, allow_pickle=True)) 40 | energy_min = float(np.load(hps.data.energy_min, allow_pickle=True)) 41 | std_min = 0.1 42 | 43 | predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).cuda() 44 | 45 | pesq_resampler = torchaudio.transforms.Resample(hps.data.sampling_rate, 16000).cuda() 46 | 47 | mels = glob('ardit_dmd_b1/*.mel', recursive=True) 48 | 49 | i = 0 50 | utmos = 0 51 | 52 | for mel_path in tqdm.tqdm(mels, desc="synthesizing each utterance"): 53 | 54 | file_name = os.path.splitext(os.path.basename(mel_path))[0] 55 | mel = torch.load(mel_path) 56 | mel = torch.FloatTensor(mel).cuda() 57 | mel = mel.transpose(1,2) 58 | audio = torch.zeros((1, mel.shape[-1]*256)).cuda() 59 | 60 | energy = (mel.exp()).sum(1).sqrt() 61 | target_std = torch.clamp((energy - energy_min) / (energy_max - energy_min), std_min, None) 62 | target_std = torch.repeat_interleave(target_std, 256, dim=1) 63 | 64 | with torch.no_grad(): 65 | 66 | resynthesis_audio = model(audio, mel, target_std.unsqueeze(0), n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver) 67 | 68 | if torch.abs(resynthesis_audio).max() >= 0.95: 69 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 70 | 71 | 72 | y_g_hat_16k = pesq_resampler(resynthesis_audio.squeeze(1)) 73 | utmos += predictor(y_g_hat_16k, 16000) 74 | 75 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 76 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 77 | resynthesis_audio = resynthesis_audio.cpu().numpy().astype('int16') 78 | 79 | file_name = os.path.splitext(os.path.basename(mel_path))[0] 80 | file_name = "{}.wav".format(file_name) 81 | 82 | output_file = os.path.join('periodwave_ardit'+'_'+str(a.solver)+'_'+str(a.iter)+'_'+str(a.noise_scale), file_name) 83 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 84 | write(output_file, 24000, resynthesis_audio) 85 | 86 | i +=1 87 | utmos = utmos/i 88 | 89 | with open(os.path.join('periodwave_ardit'+'_'+str(a.solver)+'_'+str(a.iter)+'_'+str(a.noise_scale), 'score_list.txt'), 'w') as f: 90 | f.write('periodwave_ardit Solver:{}\nIter: {}\nNoise_scale: {}\n'.format(a.solver, a.iter, a.noise_scale)) 91 | f.write('UTMOS: {}\n'.format(utmos)) 92 | 93 | 94 | def main(): 95 | print('Initializing Inference Process..') 96 | 97 | parser = argparse.ArgumentParser() 98 | 99 | parser.add_argument('--input_dir', default='gt') 100 | parser.add_argument('--output_dir', default='test') 101 | parser.add_argument('--ckpt', default='logs/periodwave_turbo_4_msmel_45_mel_gan_2e5/G_274000.pth') 102 | parser.add_argument('--iter', default=4, type=int) 103 | parser.add_argument('--noise_scale', default=1, type=float) 104 | parser.add_argument('--solver', default='euler', help="euler midpoint heun rk4") 105 | a = parser.parse_args() 106 | 107 | global hps, device 108 | hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 109 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 110 | 111 | inference(a) 112 | 113 | if __name__ == '__main__': 114 | main() 115 | -------------------------------------------------------------------------------- /inference_with_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from glob import glob 6 | import tqdm 7 | import numpy as np 8 | from torch.nn 
import functional as F 9 | import commons 10 | from scipy.io.wavfile import write 11 | import torchaudio 12 | import utils 13 | from meldataset_prior_length import mel_spectrogram, load_wav, MAX_WAV_VALUE, parse_filelist 14 | from librosa.util import normalize 15 | 16 | import auraloss 17 | from pesq import pesq 18 | import torchcrepe 19 | from Eval.pitch_periodicity import from_audio, p_p_F 20 | 21 | from model.periodwave import FlowMatch 22 | 23 | h = None 24 | device = None 25 | 26 | def get_param_num(model): 27 | num_param = sum(param.numel() for param in model.parameters()) 28 | return num_param 29 | 30 | def inference(a): 31 | torch.manual_seed(1234) 32 | np.random.seed(1234) 33 | 34 | os.makedirs(a.output_dir, exist_ok=True) 35 | model = FlowMatch(hps.data.n_mel_channels, 36 | hps.model.periods, 37 | hps.model.noise_scale).cuda() 38 | 39 | num_param = get_param_num(model) 40 | print('[Model] number of Parameters:', num_param) 41 | 42 | _ = model.eval() 43 | _ = utils.load_checkpoint(a.ckpt, model, None) 44 | 45 | model.estimator.remove_weight_norm() 46 | 47 | threshold = torchcrepe.threshold.Hysteresis() 48 | 49 | energy_max = float(np.load(hps.data.energy_max, allow_pickle=True)) 50 | energy_min = float(np.load(hps.data.energy_min, allow_pickle=True)) 51 | std_min = 0.1 52 | 53 | predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).cuda() 54 | 55 | pesq_resampler = torchaudio.transforms.Resample(hps.data.sampling_rate, 16000).cuda() 56 | loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") 57 | 58 | 59 | wavs_test = parse_filelist(hps.data.test_filelist_path) 60 | 61 | i = 0 62 | 63 | mel_error = 0 64 | pesq_wb = 0 65 | pesq_nb = 0 66 | 67 | pitch_total = 0 68 | periodicity_total = 0 69 | f1_total = 0 70 | utmos = 0 71 | val_mrstft_tot = 0 72 | 73 | for source_path in tqdm.tqdm(wavs_test, desc="synthesizing each utterance"): 74 | 75 | audio, _ = load_wav(source_path, hps.data.sampling_rate) 76 | audio = audio / MAX_WAV_VALUE 77 | audio = normalize(audio) * 0.95 78 | 79 | audio = torch.FloatTensor(audio) 80 | audio = audio.unsqueeze(0) 81 | if (audio.size(1) % hps.data.hop_length) != 0: 82 | audio = audio[:, :-(audio.size(1) % hps.data.hop_length)] 83 | 84 | file_name = os.path.splitext(os.path.basename(source_path))[0] 85 | audio = audio.cuda() 86 | 87 | mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.n_mel_channels, 88 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 89 | hps.data.mel_fmin, hps.data.mel_fmax, center=False) 90 | energy = (mel.exp()).sum(1).sqrt() 91 | target_std = torch.clamp((energy - energy_min) / (energy_max - energy_min), std_min, None) 92 | target_std = torch.repeat_interleave(target_std, 256, dim=1) 93 | 94 | assert audio.shape[1] == mel.shape[2] * hps.data.hop_length, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 95 | 96 | with torch.no_grad(): 97 | 98 | resynthesis_audio = model(audio, mel, target_std.unsqueeze(0), n_timesteps=a.iter, temperature=a.noise_scale, solver=a.solver) 99 | 100 | # resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 101 | 102 | if torch.abs(resynthesis_audio).max() >= 0.95: 103 | resynthesis_audio = (resynthesis_audio / (torch.abs(resynthesis_audio).max())) * 0.95 104 | 105 | mel_hat = mel_spectrogram(resynthesis_audio.squeeze(1), hps.data.filter_length, hps.data.n_mel_channels, 106 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 107 | hps.data.mel_fmin, hps.data.mel_fmax, 
center=False) 108 | 109 | mel_error += F.l1_loss(mel, mel_hat).item() 110 | 111 | y_16k = pesq_resampler(audio) 112 | y_g_hat_16k = pesq_resampler(resynthesis_audio.squeeze(1)) 113 | 114 | hopsize = int(256 * (torchcrepe.SAMPLE_RATE / 24000)) 115 | padding = int((1024 - hopsize) // 2) 116 | 117 | audio_for_pitch = torch.nn.functional.pad( 118 | y_16k[None], 119 | (padding, padding), 120 | mode='reflect').squeeze(0) 121 | 122 | gen_audio_for_pitch = torch.nn.functional.pad( 123 | y_g_hat_16k[None], 124 | (padding, padding), 125 | mode='reflect').squeeze(0) 126 | 127 | ori_audio_len = audio.shape[-1]//256 128 | true_pitch, true_periodicity = from_audio(audio_for_pitch.squeeze(), ori_audio_len, hopsize) 129 | fake_pitch, fake_periodicity = from_audio(gen_audio_for_pitch.squeeze(), ori_audio_len, hopsize) 130 | 131 | pitch, periodicity, f1 = p_p_F(threshold, true_pitch, true_periodicity, fake_pitch, fake_periodicity) 132 | 133 | pitch_total += pitch 134 | f1_total += f1 135 | 136 | periodicity_total += periodicity 137 | 138 | utmos += predictor(y_g_hat_16k, 16000) 139 | y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 140 | y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 141 | 142 | pesq_wb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') 143 | # pesq_nb += pesq(16000, y_int_16k, y_g_hat_int_16k, 'nb') 144 | 145 | # MRSTFT calculation 146 | val_mrstft_tot += loss_mrstft(resynthesis_audio, audio).item() 147 | 148 | resynthesis_audio = resynthesis_audio.squeeze()[:audio.shape[-1]] 149 | resynthesis_audio = resynthesis_audio * MAX_WAV_VALUE 150 | resynthesis_audio = resynthesis_audio.cpu().numpy().astype('int16') 151 | 152 | file_name = os.path.splitext(os.path.basename(source_path))[0] 153 | file_name = "{}.wav".format(file_name) 154 | 155 | output_file = os.path.join('periodwave_turbo_libritts_dev'+'_'+str(a.solver)+'_'+str(a.iter)+'_'+str(a.noise_scale), file_name) 156 | 157 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 158 | write(output_file, 24000, resynthesis_audio) 159 | 160 | i +=1 161 | 162 | mel_error = mel_error/i 163 | pesq_wb = pesq_wb/i 164 | pitch_total = pitch_total/i 165 | periodicity_total = periodicity_total/i 166 | f1_total = f1_total/i 167 | utmos = utmos/i 168 | val_mrstft_tot = val_mrstft_tot/i 169 | 170 | with open(os.path.join('periodwave_turbo_libritts_dev'+'_'+str(a.solver)+'_'+str(a.iter)+'_'+str(a.noise_scale), 'score_list.txt'), 'w') as f: 171 | f.write('periodwave_turbo_libritts_dev Solver:{}\nIter: {}\nNoise_scale: {}\n'.format(a.solver, a.iter, a.noise_scale)) 172 | f.write('UTMOS: {}\n'.format(utmos)) 173 | f.write('Mel L1 distance: {}\nMR-STFT: {}\n'.format(mel_error, val_mrstft_tot)) 174 | f.write('PESQ Wide Band: {}\nPESQ Narrow Band {}\n'.format(pesq_wb, pesq_nb)) 175 | f.write('Pitch: {}\nPeriodicity: {}\nV/UV F1: {}\n'.format(pitch_total, periodicity_total, f1_total)) 176 | 177 | 178 | def main(): 179 | print('Initializing Inference Process..') 180 | 181 | parser = argparse.ArgumentParser() 182 | parser.add_argument('--input_dir', default='gt') 183 | parser.add_argument('--output_dir', default='test') 184 | parser.add_argument('--ckpt', default='logs/periodwave_turbo_4_msmel_45_mel_gan_2e5/G_274000.pth') 185 | parser.add_argument('--iter', default=4, type=int) 186 | parser.add_argument('--noise_scale', default=1, type=float) 187 | parser.add_argument('--solver', default='euler', help="euler midpoint heun rk4") 188 | a = parser.parse_args() 189 | 190 | global hps, device 191 | hps = 
utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json')) 192 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 193 | 194 | inference(a) 195 | 196 | if __name__ == '__main__': 197 | main() -------------------------------------------------------------------------------- /meldataset_prior_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | 4 | # Adapted from https://github.com/jik876/hifi-gan under the MIT license. 5 | # LICENSE is in incl_licenses directory. 6 | 7 | import math 8 | import os 9 | import random 10 | import torch 11 | import torch.utils.data 12 | import numpy as np 13 | from librosa.util import normalize 14 | from scipy.io.wavfile import read 15 | from librosa.filters import mel as librosa_mel_fn 16 | import pathlib 17 | from tqdm import tqdm 18 | 19 | MAX_WAV_VALUE = 32768.0 20 | 21 | 22 | def load_wav(full_path, sr_target): 23 | sampling_rate, data = read(full_path) 24 | if sampling_rate != sr_target: 25 | raise RuntimeError("Sampling rate of the file {} is {} Hz, but the model requires {} Hz". 26 | format(full_path, sampling_rate, sr_target)) 27 | return data, sampling_rate 28 | 29 | 30 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 31 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 32 | 33 | 34 | def dynamic_range_decompression(x, C=1): 35 | return np.exp(x) / C 36 | 37 | 38 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 39 | return torch.log(torch.clamp(x, min=clip_val) * C) 40 | 41 | 42 | def dynamic_range_decompression_torch(x, C=1): 43 | return torch.exp(x) / C 44 | 45 | 46 | def spectral_normalize_torch(magnitudes): 47 | output = dynamic_range_compression_torch(magnitudes) 48 | return output 49 | 50 | 51 | def spectral_de_normalize_torch(magnitudes): 52 | output = dynamic_range_decompression_torch(magnitudes) 53 | return output 54 | 55 | 56 | mel_basis = {} 57 | hann_window = {} 58 | 59 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 60 | if torch.min(y) < -1.: 61 | print('min value is ', torch.min(y)) 62 | if torch.max(y) > 1.: 63 | print('max value is ', torch.max(y)) 64 | 65 | global mel_basis, hann_window 66 | if fmax not in mel_basis: 67 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 68 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 69 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 70 | 71 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 72 | y = y.squeeze(1) 73 | 74 | # complex tensor as default, then use view_as_real for future pytorch compatibility 75 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 76 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 77 | spec = torch.view_as_real(spec) 78 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 79 | 80 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) 81 | spec = spectral_normalize_torch(spec) 82 | 83 | return spec 84 | 85 | def parse_filelist(filelist_path): 86 | with open(filelist_path, 'r') as f: 87 | filelist = [line.strip() for line in f.readlines()] 88 | return filelist 89 | 90 | class MelDataset(torch.utils.data.Dataset): 91 | def __init__(self, training_files, hparams, segment_size, 
n_fft, num_mels, 92 | hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1, 93 | device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None): 94 | self.audio_files = parse_filelist(training_files) 95 | random.seed(1234) 96 | if shuffle: 97 | random.shuffle(self.audio_files) 98 | self.hparams = hparams 99 | self.segment_size = segment_size 100 | self.sampling_rate = sampling_rate 101 | self.split = split 102 | self.n_fft = n_fft 103 | self.num_mels = num_mels 104 | self.hop_size = hop_size 105 | self.win_size = win_size 106 | self.fmin = fmin 107 | self.fmax = fmax 108 | self.fmax_loss = fmax_loss 109 | self.cached_wav = None 110 | self.n_cache_reuse = n_cache_reuse 111 | self._cache_ref_count = 0 112 | self.device = device 113 | self.fine_tuning = fine_tuning 114 | self.base_mels_path = base_mels_path 115 | 116 | # self.energy_max = float(np.load('stats_priorgrad/energy_max_train.npy', allow_pickle=True)) 117 | # self.energy_min = float(np.load('stats_priorgrad/energy_min_train.npy', allow_pickle=True)) 118 | self.energy_max = float(np.load(hparams.data.energy_max, allow_pickle=True)) 119 | self.energy_min = float(np.load(hparams.data.energy_min, allow_pickle=True)) 120 | 121 | self.std_min = 0.1 122 | print("INFO: loaded frame-level waveform stats : max {} min {}".format(self.energy_max, self.energy_min)) 123 | # print("INFO: checking dataset integrity...") 124 | # for i in tqdm(range(len(self.audio_files))): 125 | # assert os.path.exists(self.audio_files[i]), "{} not found".format(self.audio_files[i]) 126 | 127 | def __getitem__(self, index): 128 | 129 | filename = self.audio_files[index] 130 | if self._cache_ref_count == 0: 131 | audio, sampling_rate = load_wav(filename, self.sampling_rate) 132 | audio = audio / MAX_WAV_VALUE 133 | if not self.fine_tuning: 134 | audio = normalize(audio) * 0.95 135 | self.cached_wav = audio 136 | if sampling_rate != self.sampling_rate: 137 | raise ValueError("{} SR doesn't match target {} SR".format( 138 | sampling_rate, self.sampling_rate)) 139 | self._cache_ref_count = self.n_cache_reuse 140 | else: 141 | audio = self.cached_wav 142 | self._cache_ref_count -= 1 143 | 144 | audio = torch.FloatTensor(audio) 145 | audio = audio.unsqueeze(0) 146 | 147 | if self.split: 148 | if audio.size(1) >= self.segment_size: 149 | max_audio_start = audio.size(1) - self.segment_size 150 | audio_start = random.randint(0, max_audio_start) 151 | audio = audio[:, audio_start:audio_start+self.segment_size] 152 | audio_length = torch.LongTensor([self.segment_size]) 153 | else: 154 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 155 | audio_length = torch.LongTensor([audio.size(1)]) 156 | mel = mel_spectrogram(audio, self.n_fft, self.num_mels, 157 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax, 158 | center=False) 159 | else: # validation step 160 | # match audio length to self.hop_size * n for evaluation 161 | if (audio.size(1) % self.hop_size) != 0: 162 | audio = audio[:, :-(audio.size(1) % self.hop_size)] 163 | audio_length = torch.LongTensor([audio.size(1)]) 164 | mel = mel_spectrogram(audio, self.n_fft, self.num_mels, 165 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax, 166 | center=False) 167 | assert audio.shape[1] == mel.shape[2] * self.hop_size, "audio shape {} mel shape {}".format(audio.shape, mel.shape) 168 | 169 | energy = (mel.exp()).sum(1).sqrt() 170 | target_std = torch.clamp((energy - self.energy_min) / (self.energy_max - 
self.energy_min), self.std_min, None) 171 | target_std = torch.repeat_interleave(target_std, self.hop_size, dim=1) 172 | return (mel.squeeze(), audio.squeeze(0), target_std, audio_length) 173 | 174 | def __len__(self): 175 | return len(self.audio_files) -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class BaseModule(torch.nn.Module): 6 | def __init__(self): 7 | super(BaseModule, self).__init__() 8 | 9 | @property 10 | def nparams(self): 11 | num_params = 0 12 | for name, param in self.named_parameters(): 13 | if param.requires_grad: 14 | num_params += np.prod(param.detach().cpu().numpy().shape) 15 | return num_params 16 | 17 | 18 | def relocate_input(self, x: list): 19 | device = next(self.parameters()).device 20 | for i in range(len(x)): 21 | if isinstance(x[i], torch.Tensor) and x[i].device != device: 22 | x[i] = x[i].to(device) 23 | return x 24 | -------------------------------------------------------------------------------- /model/commons.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def init_weights(m, mean=0.0, std=0.01): 4 | classname = m.__class__.__name__ 5 | if classname.find("Conv") != -1: 6 | m.weight.data.normal_(mean, std) 7 | 8 | 9 | def get_padding(kernel_size, dilation=1): 10 | return int((kernel_size*dilation - dilation)/2) 11 | 12 | def sequence_mask(length, max_length=None): 13 | if max_length is None: 14 | max_length = length.max() 15 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 16 | 17 | return x.unsqueeze(0) < length.unsqueeze(1) 18 | -------------------------------------------------------------------------------- /model/convnext.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 6 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 7 | 8 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 9 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 10 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 11 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 12 | 'survival rate' as the argument. 13 | 14 | """ 15 | if drop_prob == 0. or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0 and scale_by_keep: 21 | random_tensor.div_(keep_prob) 22 | return x * random_tensor 23 | 24 | 25 | class DropPath(nn.Module): 26 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
27 | """ 28 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 29 | super(DropPath, self).__init__() 30 | self.drop_prob = drop_prob 31 | self.scale_by_keep = scale_by_keep 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 35 | 36 | def extra_repr(self): 37 | return f'drop_prob={round(self.drop_prob,3):0.3f}' 38 | def get_padding(kernel_size, dilation=1): 39 | return int((kernel_size*dilation - dilation)/2) 40 | 41 | class GRN(nn.Module): 42 | """ GRN (Global Response Normalization) layer 43 | """ 44 | def __init__(self, dim): 45 | super().__init__() 46 | self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) 47 | self.beta = nn.Parameter(torch.zeros(1, 1, dim)) 48 | 49 | def forward(self, x): 50 | Gx = torch.norm(x, p=2, dim=(1), keepdim=True) 51 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 52 | return self.gamma * (x * Nx) + self.beta + x 53 | 54 | class GRN_2D(nn.Module): 55 | """ GRN (Global Response Normalization) layer 56 | """ 57 | def __init__(self, dim): 58 | super().__init__() 59 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 60 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 61 | 62 | def forward(self, x): 63 | Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True) 64 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 65 | return self.gamma * (x * Nx) + self.beta + x 66 | 67 | class ConvNeXtV2Block(nn.Module): 68 | """ 69 | We change this module from https://github.com/gemelo-ai/vocos/blob/main/vocos/modules.py for ConvNext-v2 Block 70 | ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. 71 | 72 | Args: 73 | dim (int): Number of input channels. 74 | intermediate_dim (int): Dimensionality of the intermediate layer. 75 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 76 | Defaults to None. 77 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. 78 | None means non-conditional LayerNorm. Defaults to None. 79 | """ 80 | def __init__( 81 | self, 82 | dim: int, 83 | intermediate_dim: int, 84 | drop_path: Optional[float] = 0 85 | ): 86 | super().__init__() 87 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 88 | self.norm = nn.LayerNorm(dim, eps=1e-6) 89 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 90 | self.act = nn.GELU() 91 | self.grn = GRN(intermediate_dim) 92 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 93 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 94 | 95 | def forward(self, x: torch.Tensor) -> torch.Tensor: 96 | residual = x 97 | x = self.dwconv(x) 98 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 99 | 100 | x = self.norm(x) 101 | x = self.pwconv1(x) 102 | x = self.act(x) 103 | x = self.grn(x) 104 | x = self.pwconv2(x) 105 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 106 | 107 | x = residual + self.drop_path(x) 108 | return x 109 | 110 | def modulate(x, shift, scale): 111 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 112 | def modulate2(x, shift, scale): 113 | return x * (1 + scale.unsqueeze(1).unsqueeze(2)) + shift.unsqueeze(1).unsqueeze(2) 114 | 115 | class ConvNeXtV2Block2D(nn.Module): 116 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. 117 | 118 | Args: 119 | dim (int): Number of input channels. 
120 | intermediate_dim (int): Dimensionality of the intermediate layer. 121 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 122 | Defaults to None. 123 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. 124 | None means non-conditional LayerNorm. Defaults to None. 125 | """ 126 | def __init__( 127 | self, 128 | dim: int, 129 | intermediate_dim: int, 130 | drop_path: Optional[float] = 0 131 | ): 132 | super().__init__() 133 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=(7,1), padding=(3,0), groups=dim) # depthwise conv 134 | self.norm = nn.LayerNorm(dim, eps=1e-6) 135 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 136 | self.act = nn.GELU() 137 | self.grn = GRN(intermediate_dim) 138 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 139 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 140 | 141 | def forward(self, x: torch.Tensor) -> torch.Tensor: 142 | b,c,t,p = x.shape 143 | residual = x 144 | 145 | x = self.dwconv(x) 146 | x = x.reshape(b,c, t*p) 147 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 148 | 149 | x = self.norm(x) 150 | x = self.pwconv1(x) 151 | x = self.act(x) 152 | x = self.grn(x) 153 | x = self.pwconv2(x) 154 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 155 | x = x.reshape(b,c,t,p) 156 | x = residual + self.drop_path(x) 157 | return x -------------------------------------------------------------------------------- /model/diffusion_module.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from model.base import BaseModule 4 | 5 | class Mish(BaseModule): 6 | def forward(self, x): 7 | return x * torch.tanh(torch.nn.functional.softplus(x)) 8 | 9 | class SinusoidalPosEmb(BaseModule): 10 | def __init__(self, dim): 11 | super(SinusoidalPosEmb, self).__init__() 12 | self.dim = dim 13 | 14 | def forward(self, x): 15 | device = x.device 16 | half_dim = self.dim // 2 17 | emb = math.log(10000) / (half_dim - 1) 18 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) 19 | emb = 1000.0 * x.unsqueeze(1) * emb.unsqueeze(0) 20 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 21 | return emb 22 | -------------------------------------------------------------------------------- /model/ms_mel_loss.py: -------------------------------------------------------------------------------- 1 | # From BigVGAN-v2 2 | # Copyright (c) 2024 NVIDIA CORPORATION. 3 | # Licensed under the MIT license. 4 | 5 | # Adapted from https://github.com/jik876/hifi-gan under the MIT license. 6 | # LICENSE is in incl_licenses directory. 7 | 8 | 9 | import torch 10 | 11 | import torch.nn as nn 12 | from librosa.filters import mel as librosa_mel_fn 13 | from scipy import signal 14 | import typing 15 | from collections import namedtuple 16 | import math 17 | import functools 18 | from typing import List 19 | 20 | # Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license. 21 | # LICENSE is in incl_licenses directory. 22 | class MultiScaleMelSpectrogramLoss(nn.Module): 23 | """Compute distance between mel spectrograms. Can be used 24 | in a multi-scale way. 
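Each scale pairs one STFT window length with one mel resolution; the per-scale distances between the (log-)mel spectrograms are summed into a single loss value.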
25 | 26 | Parameters 27 | ---------- 28 | n_mels : List[int] 29 | Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320] 30 | window_lengths : List[int], optional 31 | Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048] 32 | loss_fn : typing.Callable, optional 33 | Distance used to compare each pair of spectrograms, by default nn.L1Loss() 34 | clamp_eps : float, optional 35 | Lower clamp applied to magnitudes before taking the log, by default 1e-5 36 | mag_weight : float, optional 37 | Weight of the raw-magnitude portion of the loss, by default 0.0 (no amplification of the magnitude term) 38 | log_weight : float, optional 39 | Weight of the log-magnitude portion of the loss, by default 1.0 40 | pow : float, optional 41 | Power to raise magnitude to before taking log, by default 1.0 42 | weight : float, optional 43 | Weight of this loss, by default 1.0 44 | match_stride : bool, optional 45 | Whether to match the stride of convolutional layers, by default False 46 | 47 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 48 | Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py 49 | """ 50 | 51 | def __init__( 52 | self, 53 | sampling_rate: int, 54 | n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320], 55 | window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048], 56 | loss_fn: typing.Callable = nn.L1Loss(), 57 | clamp_eps: float = 1e-5, 58 | mag_weight: float = 0.0, 59 | log_weight: float = 1.0, 60 | pow: float = 1.0, 61 | weight: float = 1.0, 62 | match_stride: bool = False, 63 | mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0], 64 | mel_fmax: List[float] = [None, None, None, None, None, None, None], 65 | window_type: str = 'hann', 66 | ): 67 | super().__init__() 68 | self.sampling_rate = sampling_rate 69 | 70 | STFTParams = namedtuple( 71 | "STFTParams", 72 | ["window_length", "hop_length", "window_type", "match_stride"], 73 | ) 74 | 75 | self.stft_params = [ 76 | STFTParams( 77 | window_length=w, 78 | hop_length=w // 4, 79 | match_stride=match_stride, 80 | window_type=window_type, 81 | ) 82 | for w in window_lengths 83 | ] 84 | self.n_mels = n_mels 85 | self.loss_fn = loss_fn 86 | self.clamp_eps = clamp_eps 87 | self.log_weight = log_weight 88 | self.mag_weight = mag_weight 89 | self.weight = weight 90 | self.mel_fmin = mel_fmin 91 | self.mel_fmax = mel_fmax 92 | self.pow = pow 93 | 94 | @staticmethod 95 | @functools.lru_cache(None) 96 | def get_window( 97 | window_type, window_length, 98 | ): 99 | return signal.get_window(window_type, window_length) 100 | 101 | @staticmethod 102 | @functools.lru_cache(None) 103 | def get_mel_filters( 104 | sr, n_fft, n_mels, fmin, fmax 105 | ): 106 | return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) 107 | 108 | def mel_spectrogram( 109 | self, wav, n_mels, fmin, fmax, window_length, hop_length, match_stride, window_type 110 | ): 111 | # mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from: 112 | # https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py 113 | B, C, T = wav.shape 114 | 115 | if match_stride: 116 | assert ( 117 | hop_length == window_length // 4 118 | ), "For match_stride, hop must equal n_fft // 4" 119 | right_pad = math.ceil(T / hop_length) * hop_length - T 120 | pad = (window_length - hop_length) // 2 121 | else: 122 | right_pad = 0 123 | pad = 0 124 | 125 | wav = torch.nn.functional.pad( 126 | wav, (pad, pad + right_pad), mode='reflect' 127 | ) 128 | 129 | window = self.get_window(window_type, window_length) 130 | window = torch.from_numpy(window).to(wav.device).float() 131 | 132 | stft = torch.stft( 133 | wav.reshape(-1, T), 134 | n_fft=window_length, 135 | hop_length=hop_length, 136 | window=window, 137 | return_complex=True, 138 | center=True, 139 | ) 140 | _, nf, nt = stft.shape 141 | stft = stft.reshape(B, C, nf, nt) 142 | if match_stride: 143 | # Drop first two and last two frames, which are added 144 | # because of padding. Now num_frames * hop_length = num_samples. 145 | stft = stft[..., 2:-2] 146 | magnitude = torch.abs(stft) 147 | 148 | nf = magnitude.shape[2] 149 | mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax) 150 | mel_basis = torch.from_numpy(mel_basis).to(wav.device) 151 | mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T 152 | mel_spectrogram = mel_spectrogram.transpose(-1, 2) 153 | 154 | return mel_spectrogram 155 | 156 | def forward( 157 | self, 158 | x: torch.Tensor, 159 | y: torch.Tensor 160 | ) -> torch.Tensor: 161 | """Computes mel loss between an estimate and a reference 162 | signal. 163 | 164 | Parameters 165 | ---------- 166 | x : torch.Tensor 167 | Estimate signal 168 | y : torch.Tensor 169 | Reference signal 170 | 171 | Returns 172 | ------- 173 | torch.Tensor 174 | Mel loss. 175 | """ 176 | 177 | loss = 0.0 178 | for n_mels, fmin, fmax, s in zip( 179 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params 180 | ): 181 | kwargs = { 182 | "n_mels": n_mels, 183 | "fmin": fmin, 184 | "fmax": fmax, 185 | "window_length": s.window_length, 186 | "hop_length": s.hop_length, 187 | "match_stride": s.match_stride, 188 | "window_type": s.window_type, 189 | } 190 | 191 | x_mels = self.mel_spectrogram(x, **kwargs) 192 | y_mels = self.mel_spectrogram(y, **kwargs) 193 | x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0)) 194 | y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0)) 195 | 196 | loss += self.log_weight * self.loss_fn(x_logmels, y_logmels) 197 | loss += self.mag_weight * self.loss_fn(x_mels, y_mels)  # raw-magnitude term compares linear mels, not log-mels 198 | 199 | return loss -------------------------------------------------------------------------------- /model/ms_stftd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.nn.utils import weight_norm, spectral_norm 5 | import torchaudio 6 | from einops import rearrange 7 | 8 | 9 | class DiscriminatorR(torch.nn.Module): 10 | def __init__(self, resolution, use_spectral_norm=False): 11 | super(DiscriminatorR, self).__init__() 12 | norm_f = spectral_norm if use_spectral_norm else weight_norm 13 | 14 | n_fft, hop_length, win_length = resolution 15 | self.spec_transform = torchaudio.transforms.Spectrogram( 16 | n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=torch.hann_window, 17 | normalized=True, center=False, pad_mode=None, power=None) 18 | 19 | self.convs = nn.ModuleList([ 20 | norm_f(nn.Conv2d(2, 32, (3, 9), padding=(1, 4))), 21 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 22 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(2,1), padding=(2, 4))), 23 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(4,1), padding=(4, 4))), 24 | norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), 25 | ]) 26 | self.conv_post = 
norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) 27 | 28 | def forward(self, y): 29 | fmap = [] 30 | 31 | x = self.spec_transform(y) # [B, 2, Freq, Frames, 2] 32 | x = torch.cat([x.real, x.imag], dim=1) 33 | x = rearrange(x, 'b c w t -> b c t w') 34 | 35 | for l in self.convs: 36 | x = l(x) 37 | x = F.leaky_relu(x, 0.1) 38 | fmap.append(x) 39 | x = self.conv_post(x) 40 | fmap.append(x) 41 | x = torch.flatten(x, 1, -1) 42 | 43 | return x, fmap 44 | 45 | 46 | class MultiScaleSTFTDiscriminator(torch.nn.Module): 47 | def __init__(self, use_spectral_norm=False): 48 | super(MultiScaleSTFTDiscriminator, self).__init__() 49 | 50 | resolutions = [[2048, 512, 2048], [1024, 256, 1024], [512, 128, 512], [256, 64, 256], [128, 32, 128]] 51 | 52 | discs = [DiscriminatorR(resolutions[i], use_spectral_norm=use_spectral_norm) for i in range(len(resolutions))] 53 | 54 | self.discriminators = nn.ModuleList(discs) 55 | 56 | def forward(self, y, y_hat): 57 | y_d_rs = [] 58 | y_d_gs = [] 59 | fmap_rs = [] 60 | fmap_gs = [] 61 | for i, d in enumerate(self.discriminators): 62 | y_d_r, fmap_r = d(y) 63 | y_d_g, fmap_g = d(y_hat) 64 | y_d_rs.append(y_d_r) 65 | y_d_gs.append(y_d_g) 66 | fmap_rs.append(fmap_r) 67 | fmap_gs.append(fmap_g) 68 | 69 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 70 | -------------------------------------------------------------------------------- /model/periodwave.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from model.base import BaseModule 4 | from model.diffusion_module import SinusoidalPosEmb 5 | 6 | import torch.nn as nn 7 | from model.periodwave_utils import MultiPeriodGenerator, MelSpectrogramUpsampler, FinalBlock 8 | 9 | LRELU_SLOPE = 0.1 10 | from math import sqrt 11 | 12 | class VectorFieldEstimator(BaseModule): 13 | def __init__(self, n_mel, periods, final_dim=32, hidden_dim=512): 14 | super(VectorFieldEstimator, self).__init__() 15 | 16 | self.len_periods = len(periods) 17 | self.hidden_dim = hidden_dim 18 | ### Mel condition 19 | self.MelCond = MelSpectrogramUpsampler(n_mel, periods, hidden_dim) 20 | ### Time Condition 21 | self.time_pos_emb = SinusoidalPosEmb(hidden_dim//2) # We recommend using hidden_dim instead of hidden_dim//2. Please refer the code of encodec version 22 | self.period_emb = nn.Embedding(self.len_periods, hidden_dim//2) # We recommend using hidden_dim instead of hidden_dim//2. Please refer the code of encodec version 23 | torch.nn.init.normal_(self.period_emb.weight, 0.0, (hidden_dim//2) ** -0.5) # We recommend using hidden_dim instead of hidden_dim//2. Please refer the code of encodec version 24 | 25 | self.period_token = torch.LongTensor([i for i in range(self.len_periods)]).cuda() 26 | 27 | self.mlp = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim*4), 28 | nn.SiLU(), torch.nn.Linear(hidden_dim*4, hidden_dim)) 29 | 30 | ### Multi-period Audio U-net 31 | self.mpg = MultiPeriodGenerator(periods=periods, final_dim=final_dim, hidden_dim=hidden_dim) 32 | 33 | self.proj_layer = FinalBlock(final_dim) 34 | def remove_weight_norm(self): 35 | print('Removing weight norm...') 36 | self.MelCond.remove_weight_norm() 37 | self.mpg.remove_weight_norm() 38 | self.proj_layer.remove_weight_norm() 39 | 40 | def forward(self, x, mel, t): 41 | 42 | # Mel condition 43 | cond = self.MelCond(mel) 44 | # Time Condition 45 | t = self.time_pos_emb(t) 46 | 47 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) # We accidentally used hidden_dim instead of hidden_dim//2. 
Although there is no issue in our model, we recommend changing #21, #22, #23 lines following encodec version. 48 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 49 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 50 | 51 | t = self.mlp(t) 52 | 53 | xs = self.mpg(x, t, cond) 54 | 55 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 56 | x = self.proj_layer(x) 57 | 58 | return x 59 | def mel_encoder(self, mel): 60 | 61 | # Mel condition 62 | mel = self.MelCond(mel) 63 | 64 | return mel 65 | 66 | def decoder(self, x, cond, t): 67 | 68 | # Time Condition 69 | t = self.time_pos_emb(t) 70 | 71 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 72 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 73 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 74 | 75 | t = self.mlp(t) 76 | 77 | xs = self.mpg(x, t, cond) 78 | 79 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 80 | x = self.proj_layer(x) 81 | 82 | return x 83 | class FlowMatch(BaseModule): 84 | def __init__(self, n_mel=100, periods=[1,2,3,5,7], noise_scale=0.25, final_dim=32, hidden_dim=512): 85 | super().__init__() 86 | self.sigma_min = 1e-4 87 | self.noise_scale = noise_scale 88 | self.estimator = VectorFieldEstimator(n_mel, periods, final_dim, hidden_dim) 89 | 90 | @torch.no_grad() 91 | def forward(self, y, mel, target_std, n_timesteps, temperature=1.0, solver='euler'): 92 | y = y.unsqueeze(1) 93 | z = torch.randn_like(y)*self.noise_scale*target_std *temperature 94 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mel.device) 95 | mel = torch.cat([mel, target_std[:, :, ::256]], dim=1) # Concat target std 96 | mel = self.estimator.mel_encoder(mel) 97 | if solver=='euler': 98 | return self.solve_euler(z, t_span=t_span, mel=mel) 99 | elif solver=='midpoint': 100 | return self.solve_midpoint(z, t_span=t_span, mel=mel) 101 | 102 | def solve_euler(self, x, t_span, mel): 103 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 104 | t = t.reshape(1) 105 | 106 | sol = [] 107 | steps = 1 108 | while steps <= len(t_span) - 1: 109 | dphi_dt = self.estimator.decoder(x, mel, t) 110 | 111 | x = x + dt * dphi_dt 112 | t = t + dt 113 | sol.append(x) 114 | if steps < len(t_span) - 1: 115 | dt = t_span[steps + 1] - t 116 | steps += 1 117 | 118 | return sol[-1] 119 | 120 | def solve_midpoint(self, x, t_span, mel): 121 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 122 | t = t.reshape(1) 123 | 124 | sol = [] 125 | steps = 1 126 | while steps <= len(t_span) - 1: 127 | 128 | dphi_dt = self.estimator.decoder(x, mel, t) 129 | half_dt = 0.5 * dt 130 | x_mid = x + half_dt * dphi_dt 131 | 132 | x = x + dt * self.estimator.decoder(x_mid, mel, t+half_dt) 133 | t = t + dt 134 | sol.append(x) 135 | if steps < len(t_span) - 1: 136 | dt = t_span[steps + 1] - t 137 | steps += 1 138 | 139 | return sol[-1] 140 | 141 | def compute_loss(self, x1, mel, target_std, length): 142 | b, _, t = mel.shape 143 | mel = torch.cat([mel, target_std[:, :, ::256]], dim=1) # Concat target std 144 | 145 | x_mask = self.sequence_mask(length, x1.size(1)).to(x1.dtype) 146 | 147 | t = torch.rand([b, 1, 1], device=mel.device, dtype=mel.dtype) 148 | 149 | x1 = x1.unsqueeze(1)*x_mask 150 | x0 = torch.randn_like(x1)*self.noise_scale*target_std*x_mask # B, 1, T 151 | 152 | # Flow from x0 to x1 153 | y = (1 - (1 - self.sigma_min) * t) * x0 + t * x1 154 | # Gradient vector field 155 | u = x1 - (1 - self.sigma_min) * x0 156 | 157 | pred = self.estimator(y, 
mel, t.squeeze()) 158 | loss = self.scaled_mse_loss(pred, u, target_std, x_mask) 159 | 160 | return loss 161 | 162 | def scaled_mse_loss(self, decoder_output, target, target_std, x_mask): 163 | # inverse of diagonal matrix is 1/x for each element 164 | # from PriorGrad 165 | sigma_inv = torch.reciprocal(target_std) 166 | mse_loss = (((decoder_output - target) * sigma_inv) ** 2) 167 | mse_loss = (mse_loss*x_mask).sum() / x_mask.sum() 168 | return mse_loss 169 | 170 | def sequence_mask(self, length, max_length=None): 171 | if max_length is None: 172 | max_length = length.max() 173 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 174 | return x.unsqueeze(0) < length.unsqueeze(1) 175 | -------------------------------------------------------------------------------- /model/periodwave_encodec.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from model.base import BaseModule 4 | from model.diffusion_module import SinusoidalPosEmb 5 | import torch.nn.functional as F 6 | 7 | import torch.nn as nn 8 | from model.periodwave_encodec_utils import MultiPeriodGenerator, MelSpectrogramUpsampler, FinalBlock 9 | 10 | LRELU_SLOPE = 0.1 11 | from math import sqrt 12 | 13 | class VectorFieldEstimator(BaseModule): 14 | def __init__(self, n_mel=512, periods=[1,2,3,5,7], final_dim=32, hidden_dim=512): 15 | super(VectorFieldEstimator, self).__init__() 16 | 17 | self.len_periods = len(periods) 18 | self.hidden_dim = hidden_dim 19 | ### Mel condition 20 | self.MelCond = MelSpectrogramUpsampler(n_mel, periods, hidden_dim) 21 | ### Time Condition 22 | self.time_pos_emb = SinusoidalPosEmb(hidden_dim) 23 | self.period_emb = nn.Embedding(self.len_periods, hidden_dim) 24 | torch.nn.init.normal_(self.period_emb.weight, 0.0, (hidden_dim) ** -0.5) 25 | 26 | self.period_token = torch.LongTensor([i for i in range(self.len_periods)]).cuda() 27 | 28 | self.mlp = torch.nn.Sequential(torch.nn.Linear(hidden_dim*2, hidden_dim*4), 29 | nn.SiLU(), torch.nn.Linear(hidden_dim*4, hidden_dim)) 30 | 31 | ### Multi-period Audio U-net 32 | self.mpg = MultiPeriodGenerator(periods=periods, final_dim=final_dim, hidden_dim=hidden_dim) 33 | 34 | self.proj_layer = FinalBlock(final_dim) 35 | def remove_weight_norm(self): 36 | print('Removing weight norm...') 37 | self.MelCond.remove_weight_norm() 38 | self.mpg.remove_weight_norm() 39 | self.proj_layer.remove_weight_norm() 40 | 41 | def forward(self, x, mel, t): 42 | 43 | # Mel condition 44 | cond = self.MelCond(mel) 45 | # Time Condition 46 | t = self.time_pos_emb(t) 47 | 48 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 49 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 50 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 51 | 52 | t = self.mlp(t) 53 | 54 | xs = self.mpg(x, t, cond) 55 | 56 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 57 | x = self.proj_layer(x) 58 | 59 | return x 60 | def mel_encoder(self, mel): 61 | 62 | # Mel condition 63 | mel = self.MelCond(mel) 64 | 65 | return mel 66 | 67 | def decoder(self, x, cond, t): 68 | 69 | # Time Condition 70 | t = self.time_pos_emb(t) 71 | 72 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 73 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 74 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 75 | 76 | t = self.mlp(t) 77 | 78 | xs = self.mpg(x, t, cond) 79 | 80 | x = torch.sum(torch.stack(xs), dim=0) / 
sqrt(self.len_periods) 81 | x = self.proj_layer(x) 82 | 83 | return x 84 | class FlowMatch(BaseModule): 85 | def __init__(self, n_mel=512, periods=[1,2,3,5,7], noise_scale=0.25, final_dim=32, hidden_dim=512): 86 | super().__init__() 87 | self.sigma_min = 1e-4 88 | self.noise_scale = noise_scale 89 | self.estimator = VectorFieldEstimator(512, periods, final_dim, hidden_dim) 90 | 91 | @torch.no_grad() 92 | def forward(self, y, embs, n_timesteps, temperature=1.0, solver='euler', sway=False, sway_coef=-1.0): 93 | y = y.unsqueeze(1) 94 | z = torch.randn_like(y)*self.noise_scale*temperature 95 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=embs.device) 96 | if sway == True: 97 | t_span = t_span + sway_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span) 98 | 99 | mel = self.estimator.mel_encoder(embs) 100 | if solver=='euler': 101 | return self.solve_euler(z, t_span=t_span, mel=mel) 102 | elif solver=='midpoint': 103 | return self.solve_midpoint(z, t_span=t_span, mel=mel) 104 | 105 | def solve_euler(self, x, t_span, mel): 106 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 107 | t = t.reshape(1) 108 | 109 | sol = [] 110 | steps = 1 111 | while steps <= len(t_span) - 1: 112 | dphi_dt = self.estimator.decoder(x, mel, t) 113 | 114 | x = x + dt * dphi_dt 115 | t = t + dt 116 | sol.append(x) 117 | if steps < len(t_span) - 1: 118 | dt = t_span[steps + 1] - t 119 | steps += 1 120 | 121 | return sol[-1] 122 | 123 | def solve_midpoint(self, x, t_span, mel): 124 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 125 | t = t.reshape(1) 126 | 127 | sol = [] 128 | steps = 1 129 | while steps <= len(t_span) - 1: 130 | 131 | dphi_dt = self.estimator.decoder(x, mel, t) 132 | half_dt = 0.5 * dt 133 | x_mid = x + half_dt * dphi_dt 134 | 135 | x = x + dt * self.estimator.decoder(x_mid, mel, t+half_dt) 136 | t = t + dt 137 | sol.append(x) 138 | if steps < len(t_span) - 1: 139 | dt = t_span[steps + 1] - t 140 | steps += 1 141 | 142 | return sol[-1] 143 | 144 | def compute_loss(self, x1, emb, length): 145 | b, _, t = emb.shape 146 | x_mask = self.sequence_mask(length, x1.size(1)).to(x1.dtype) 147 | 148 | t = torch.rand([b, 1, 1], device=emb.device, dtype=emb.dtype) 149 | 150 | x1 = x1.unsqueeze(1)*x_mask 151 | x0 = torch.randn_like(x1)*self.noise_scale*x_mask # B, 1, T 152 | 153 | # Flow from x0 to x1 154 | y = (1 - (1 - self.sigma_min) * t) * x0 + t * x1 155 | # Gradient vector field 156 | u = x1 - (1 - self.sigma_min) * x0 157 | 158 | pred = self.estimator(y, emb, t.squeeze()) 159 | 160 | loss = F.mse_loss(pred, u) 161 | 162 | return loss 163 | 164 | def sequence_mask(self, length, max_length=None): 165 | if max_length is None: 166 | max_length = length.max() 167 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 168 | return x.unsqueeze(0) < length.unsqueeze(1) 169 | -------------------------------------------------------------------------------- /model/periodwave_encodec_freeu.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from model.base import BaseModule 4 | from model.diffusion_module import SinusoidalPosEmb 5 | import torch.nn.functional as F 6 | 7 | import torch.nn as nn 8 | from model.periodwave_encodec_freeu_utils import MultiPeriodGenerator, MelSpectrogramUpsampler, FinalBlock 9 | 10 | LRELU_SLOPE = 0.1 11 | from math import sqrt 12 | 13 | class VectorFieldEstimator(BaseModule): 14 | def __init__(self, n_mel=512, periods=[1,2,3,5,7], final_dim=32, hidden_dim=512): 15 | 
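# Mirrors the estimator in model/periodwave_encodec.py; the main difference is that decoder() below threads the FreeU scaling weights s_w and b_w (skip- and backbone-feature scales) through the multi-period U-net.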
super(VectorFieldEstimator, self).__init__() 16 | 17 | self.len_periods = len(periods) 18 | self.hidden_dim = hidden_dim 19 | ### Mel condition 20 | self.MelCond = MelSpectrogramUpsampler(n_mel, periods, hidden_dim) 21 | ### Time Condition 22 | self.time_pos_emb = SinusoidalPosEmb(hidden_dim) 23 | self.period_emb = nn.Embedding(self.len_periods, hidden_dim) 24 | torch.nn.init.normal_(self.period_emb.weight, 0.0, (hidden_dim) ** -0.5) 25 | 26 | self.period_token = torch.LongTensor([i for i in range(self.len_periods)]).cuda() 27 | 28 | self.mlp = torch.nn.Sequential(torch.nn.Linear(hidden_dim*2, hidden_dim*4), 29 | nn.SiLU(), torch.nn.Linear(hidden_dim*4, hidden_dim)) 30 | 31 | ### Multi-period Audio U-net 32 | self.mpg = MultiPeriodGenerator(periods=periods, final_dim=final_dim, hidden_dim=hidden_dim) 33 | 34 | self.proj_layer = FinalBlock(final_dim) 35 | def remove_weight_norm(self): 36 | print('Removing weight norm...') 37 | self.MelCond.remove_weight_norm() 38 | self.mpg.remove_weight_norm() 39 | self.proj_layer.remove_weight_norm() 40 | 41 | def forward(self, x, mel, t): 42 | 43 | # Mel condition 44 | cond = self.MelCond(mel) 45 | # Time Condition 46 | t = self.time_pos_emb(t) 47 | 48 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 49 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 50 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 51 | 52 | t = self.mlp(t) 53 | 54 | xs = self.mpg(x, t, cond) 55 | 56 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 57 | x = self.proj_layer(x) 58 | 59 | return x 60 | def mel_encoder(self, mel): 61 | 62 | # Mel condition 63 | mel = self.MelCond(mel) 64 | 65 | return mel 66 | 67 | def decoder(self, x, cond, t, s_w=1, b_w=1): 68 | 69 | # Time Condition 70 | t = self.time_pos_emb(t) 71 | 72 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 73 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 74 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 75 | 76 | t = self.mlp(t) 77 | 78 | xs = self.mpg(x, t, cond, s_w=s_w, b_w=b_w) 79 | 80 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 81 | x = self.proj_layer(x) 82 | 83 | return x 84 | class FlowMatch(BaseModule): 85 | def __init__(self, n_mel=512, periods=[1,2,3,5,7], noise_scale=0.25, final_dim=32, hidden_dim=512): 86 | super().__init__() 87 | self.sigma_min = 1e-4 88 | self.noise_scale = noise_scale 89 | self.estimator = VectorFieldEstimator(512, periods, final_dim, hidden_dim) 90 | 91 | @torch.no_grad() 92 | def forward(self, y, embs, n_timesteps, temperature=1.0, solver='euler', s_w=1, b_w=1, sway=False, sway_coef=-1.0): 93 | y = y.unsqueeze(1) 94 | z = torch.randn_like(y)*self.noise_scale*temperature 95 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=embs.device) 96 | if sway == True: 97 | t_span = t_span + sway_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span) 98 | 99 | mel = self.estimator.mel_encoder(embs) 100 | if solver=='euler': 101 | return self.solve_euler(z, t_span=t_span, mel=mel, s_w=s_w, b_w=b_w) 102 | elif solver=='midpoint': 103 | return self.solve_midpoint(z, t_span=t_span, mel=mel, s_w=s_w, b_w=b_w) 104 | 105 | def solve_euler(self, x, t_span, mel, s_w=1, b_w=1): 106 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 107 | t = t.reshape(1) 108 | 109 | sol = [] 110 | steps = 1 111 | while steps <= len(t_span) - 1: 112 | dphi_dt = self.estimator.decoder(x, mel, t, s_w=s_w, b_w=b_w) 113 | 114 | x = x + dt * dphi_dt 115 | t = t + 
dt 116 | sol.append(x) 117 | if steps < len(t_span) - 1: 118 | dt = t_span[steps + 1] - t 119 | steps += 1 120 | 121 | return sol[-1] 122 | 123 | def solve_midpoint(self, x, t_span, mel, s_w=1, b_w=1): 124 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 125 | t = t.reshape(1) 126 | 127 | sol = [] 128 | steps = 1 129 | while steps <= len(t_span) - 1: 130 | 131 | dphi_dt = self.estimator.decoder(x, mel, t, s_w=s_w, b_w=b_w) 132 | half_dt = 0.5 * dt 133 | x_mid = x + half_dt * dphi_dt 134 | 135 | x = x + dt * self.estimator.decoder(x_mid, mel, t+half_dt, s_w=s_w, b_w=b_w) 136 | t = t + dt 137 | sol.append(x) 138 | if steps < len(t_span) - 1: 139 | dt = t_span[steps + 1] - t 140 | steps += 1 141 | 142 | return sol[-1] 143 | 144 | def compute_loss(self, x1, emb, length): 145 | b, _, t = emb.shape 146 | x_mask = self.sequence_mask(length, x1.size(1)).to(x1.dtype) 147 | 148 | t = torch.rand([b, 1, 1], device=emb.device, dtype=emb.dtype) 149 | 150 | x1 = x1.unsqueeze(1)*x_mask 151 | x0 = torch.randn_like(x1)*self.noise_scale*x_mask # B, 1, T 152 | 153 | # Flow from x0 to x1 154 | y = (1 - (1 - self.sigma_min) * t) * x0 + t * x1 155 | # Gradient vector field 156 | u = x1 - (1 - self.sigma_min) * x0 157 | 158 | pred = self.estimator(y, emb, t.squeeze()) 159 | 160 | loss = F.mse_loss(pred, u) 161 | 162 | return loss 163 | 164 | def sequence_mask(self, length, max_length=None): 165 | if max_length is None: 166 | max_length = length.max() 167 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 168 | return x.unsqueeze(0) < length.unsqueeze(1) 169 | -------------------------------------------------------------------------------- /model/periodwave_encodec_turbo.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from model.base import BaseModule 4 | from model.diffusion_module import SinusoidalPosEmb 5 | import torch.nn.functional as F 6 | 7 | import torch.nn as nn 8 | from model.periodwave_encodec_utils import MultiPeriodGenerator, MelSpectrogramUpsampler, FinalBlock 9 | 10 | LRELU_SLOPE = 0.1 11 | from math import sqrt 12 | from model.ms_mel_loss import MultiScaleMelSpectrogramLoss 13 | 14 | class VectorFieldEstimator(BaseModule): 15 | def __init__(self, n_mel=512, periods=[1,2,3,5,7], final_dim=32, hidden_dim=512): 16 | super(VectorFieldEstimator, self).__init__() 17 | 18 | self.len_periods = len(periods) 19 | self.hidden_dim = hidden_dim 20 | ### Mel condition 21 | self.MelCond = MelSpectrogramUpsampler(n_mel, periods, hidden_dim) 22 | ### Time Condition 23 | self.time_pos_emb = SinusoidalPosEmb(hidden_dim) 24 | self.period_emb = nn.Embedding(self.len_periods, hidden_dim) 25 | torch.nn.init.normal_(self.period_emb.weight, 0.0, (hidden_dim) ** -0.5) 26 | 27 | self.period_token = torch.LongTensor([i for i in range(self.len_periods)]).cuda() 28 | 29 | self.mlp = torch.nn.Sequential(torch.nn.Linear(hidden_dim*2, hidden_dim*4), 30 | nn.SiLU(), torch.nn.Linear(hidden_dim*4, hidden_dim)) 31 | 32 | ### Multi-period Audio U-net 33 | self.mpg = MultiPeriodGenerator(periods=periods, final_dim=final_dim, hidden_dim=hidden_dim) 34 | 35 | self.proj_layer = FinalBlock(final_dim) 36 | def remove_weight_norm(self): 37 | print('Removing weight norm...') 38 | self.MelCond.remove_weight_norm() 39 | self.mpg.remove_weight_norm() 40 | self.proj_layer.remove_weight_norm() 41 | 42 | def forward(self, x, mel, t): 43 | 44 | # Mel condition 45 | cond = self.MelCond(mel) 46 | # Time Condition 47 | t = self.time_pos_emb(t) 48 | 
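# t is now a [B, hidden_dim] sinusoidal time embedding; the block below scales the learned period embeddings by sqrt(hidden_dim), broadcasts them over the batch, and concatenates them with t so the MLP yields one conditioning vector per period branch.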
49 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 50 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 51 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 52 | 53 | t = self.mlp(t) 54 | 55 | xs = self.mpg(x, t, cond) 56 | 57 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 58 | x = self.proj_layer(x) 59 | 60 | return x 61 | def mel_encoder(self, mel): 62 | 63 | # Mel condition 64 | mel = self.MelCond(mel) 65 | 66 | return mel 67 | 68 | def decoder(self, x, cond, t): 69 | 70 | # Time Condition 71 | t = self.time_pos_emb(t) 72 | 73 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 74 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 75 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 76 | 77 | t = self.mlp(t) 78 | 79 | xs = self.mpg(x, t, cond) 80 | 81 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 82 | x = self.proj_layer(x) 83 | 84 | return x 85 | class FlowMatch(BaseModule): 86 | def __init__(self, n_mel=512, periods=[1,2,3,5,7], noise_scale=0.25, final_dim=32, hidden_dim=512): 87 | super().__init__() 88 | self.sigma_min = 1e-4 89 | self.noise_scale = noise_scale 90 | self.estimator = VectorFieldEstimator(512, periods, final_dim, hidden_dim) 91 | self.msmel = MultiScaleMelSpectrogramLoss(sampling_rate=24000) 92 | 93 | def forward(self, y, embs, n_timesteps, temperature=1.0, solver='euler'): 94 | 95 | if y.shape[1] != 1: 96 | y = y.unsqueeze(1) 97 | z = torch.randn_like(y)*self.noise_scale*temperature 98 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=embs.device) 99 | 100 | mel = self.estimator.mel_encoder(embs) 101 | if solver=='euler': 102 | return self.solve_euler(z, t_span=t_span, mel=mel) 103 | elif solver=='midpoint': 104 | return self.solve_midpoint(z, t_span=t_span, mel=mel) 105 | 106 | def solve_euler(self, x, t_span, mel): 107 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 108 | t = t.reshape(1) 109 | 110 | sol = [] 111 | steps = 1 112 | while steps <= len(t_span) - 1: 113 | dphi_dt = self.estimator.decoder(x, mel, t) 114 | 115 | x = x + dt * dphi_dt 116 | t = t + dt 117 | sol.append(x) 118 | if steps < len(t_span) - 1: 119 | dt = t_span[steps + 1] - t 120 | steps += 1 121 | 122 | return sol[-1] 123 | 124 | def solve_midpoint(self, x, t_span, mel): 125 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 126 | t = t.reshape(1) 127 | 128 | sol = [] 129 | steps = 1 130 | while steps <= len(t_span) - 1: 131 | 132 | dphi_dt = self.estimator.decoder(x, mel, t) 133 | half_dt = 0.5 * dt 134 | x_mid = x + half_dt * dphi_dt 135 | 136 | x = x + dt * self.estimator.decoder(x_mid, mel, t+half_dt) 137 | t = t + dt 138 | sol.append(x) 139 | if steps < len(t_span) - 1: 140 | dt = t_span[steps + 1] - t 141 | steps += 1 142 | 143 | return sol[-1] 144 | 145 | def compute_loss(self, x1, emb, length, tuning_steps, temperature): 146 | b, _, t = emb.shape 147 | x_mask = self.sequence_mask(length, x1.size(1)).to(x1.dtype) 148 | 149 | x1 = x1.unsqueeze(1)*x_mask 150 | 151 | pred_x1 = self.forward(x1, emb, n_timesteps=tuning_steps, temperature=temperature) 152 | 153 | loss = self.msmel(pred_x1*x_mask, x1) 154 | 155 | return loss, pred_x1 156 | 157 | def sequence_mask(self, length, max_length=None): 158 | if max_length is None: 159 | max_length = length.max() 160 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 161 | return x.unsqueeze(0) < length.unsqueeze(1) 162 | 
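# Hedged usage sketch (illustration only, not part of the original file): the batch size,
# step count, and the assumption that 225 EnCodec frames at 24 kHz (hop 320) condition
# 72000 samples are guesses for demonstration; a CUDA device is required by this module.
if __name__ == "__main__":
    model = FlowMatch().cuda()
    wav = torch.randn(1, 72000, device="cuda")      # [B, T] waveform, 3 s at 24 kHz
    emb = torch.randn(1, 512, 225, device="cuda")   # [B, 512, frames] codec features
    lens = torch.tensor([72000], device="cuda")
    # compute_loss runs the full few-step ODE solve (forward() here has no torch.no_grad),
    # then scores the generated waveform with the multi-scale mel loss, so gradients
    # reach every solver step during turbo fine-tuning.
    loss, pred_wav = model.compute_loss(wav, emb, lens, tuning_steps=2, temperature=1.0)
    print(loss.item(), pred_wav.shape)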
-------------------------------------------------------------------------------- /model/periodwave_freeu.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from model.base import BaseModule 4 | from model.diffusion_module import SinusoidalPosEmb 5 | import torch.nn as nn 6 | from model.periodwave_utils_freeu import MultiPeriodGenerator, MelSpectrogramUpsampler, FinalBlock 7 | 8 | LRELU_SLOPE = 0.1 9 | from math import sqrt 10 | 11 | class GradLogPEstimator(BaseModule): 12 | def __init__(self, n_mel, periods, final_dim=32, hidden_dim=512): 13 | super(GradLogPEstimator, self).__init__() 14 | 15 | self.len_periods = len(periods) 16 | self.hidden_dim = hidden_dim 17 | ### Mel condition 18 | self.MelCond = MelSpectrogramUpsampler(n_mel, periods, hidden_dim) 19 | ### Time Condition 20 | self.time_pos_emb = SinusoidalPosEmb(hidden_dim//2) 21 | self.period_emb = nn.Embedding(self.len_periods, hidden_dim//2) 22 | torch.nn.init.normal_(self.period_emb.weight, 0.0, (hidden_dim//2) ** -0.5) 23 | 24 | self.period_token = torch.LongTensor([i for i in range(self.len_periods)]).cuda() 25 | 26 | self.mlp = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim*4), 27 | nn.SiLU(), torch.nn.Linear(hidden_dim*4, hidden_dim)) 28 | 29 | ### Multi-period Audio U-net 30 | self.mpg = MultiPeriodGenerator(periods=periods, final_dim=final_dim, hidden_dim=hidden_dim) 31 | 32 | self.proj_layer = FinalBlock(final_dim) 33 | 34 | def remove_weight_norm(self): 35 | print('Removing weight norm...') 36 | self.MelCond.remove_weight_norm() 37 | self.mpg.remove_weight_norm() 38 | self.proj_layer.remove_weight_norm() 39 | 40 | def forward(self, x, mel, t): 41 | 42 | # Mel condition 43 | cond = self.MelCond(mel) 44 | # Time Condition 45 | t = self.time_pos_emb(t) 46 | 47 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 48 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 49 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 50 | 51 | t = self.mlp(t) 52 | 53 | xs = self.mpg(x, t, cond) 54 | 55 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 56 | x = self.proj_layer(x) 57 | 58 | return x 59 | def mel_encoder(self, mel): 60 | 61 | # Mel condition 62 | mel = self.MelCond(mel) 63 | 64 | return mel 65 | 66 | def decoder(self, x, cond, t, s_w=1, b_w=1): 67 | 68 | # Time Condition 69 | t = self.time_pos_emb(t) 70 | 71 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 72 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 73 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 74 | 75 | t = self.mlp(t) 76 | 77 | xs = self.mpg(x, t, cond, s_w=s_w, b_w=b_w) 78 | 79 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 80 | x = self.proj_layer(x) 81 | 82 | return x 83 | class FlowMatch(BaseModule): 84 | def __init__(self, n_mel=100, periods=[2,3,5,7,11], noise_scale=0.25): 85 | super().__init__() 86 | self.sigma_min = 1e-4 87 | self.noise_scale = noise_scale 88 | self.estimator = GradLogPEstimator(n_mel, periods) 89 | 90 | def forward(self, y, mel, target_std, n_timesteps, temperature=1.0, solver='euler', s_w=1, b_w=1): 91 | 92 | if y.shape[1] != 1: 93 | y = y.unsqueeze(1) 94 | z = torch.randn_like(y)*self.noise_scale*target_std *temperature 95 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mel.device) 96 | 97 | mel = torch.cat([mel, target_std[:, :, ::256]], dim=1) # Concat target std 98 | mel = 
self.estimator.mel_encoder(mel) 99 | if solver=='euler': 100 | return self.solve_euler(z, t_span=t_span, mel=mel, s_w=s_w, b_w=b_w) 101 | elif solver=='midpoint': 102 | return self.solve_midpoint(z, t_span=t_span, mel=mel, s_w=s_w, b_w=b_w) 103 | 104 | def solve_euler(self, x, t_span, mel, s_w=1, b_w=1): 105 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 106 | t = t.reshape(1) 107 | 108 | sol = [] 109 | steps = 1 110 | while steps <= len(t_span) - 1: 111 | dphi_dt = self.estimator.decoder(x, mel, t, s_w=s_w, b_w=b_w) 112 | 113 | x = x + dt * dphi_dt 114 | t = t + dt 115 | sol.append(x) 116 | if steps < len(t_span) - 1: 117 | dt = t_span[steps + 1] - t 118 | steps += 1 119 | 120 | return sol[-1] 121 | 122 | def solve_midpoint(self, x, t_span, mel, s_w=1, b_w=1): 123 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 124 | t = t.reshape(1) 125 | 126 | sol = [] 127 | steps = 1 128 | while steps <= len(t_span) - 1: 129 | 130 | dphi_dt = self.estimator.decoder(x, mel, t, s_w=s_w, b_w=b_w) 131 | half_dt = 0.5 * dt 132 | x_mid = x + half_dt * dphi_dt 133 | 134 | x = x + dt * self.estimator.decoder(x_mid, mel, t+half_dt, s_w=s_w, b_w=b_w) 135 | t = t + dt 136 | sol.append(x) 137 | if steps < len(t_span) - 1: 138 | dt = t_span[steps + 1] - t 139 | steps += 1 140 | 141 | return sol[-1] 142 | 143 | -------------------------------------------------------------------------------- /model/periodwave_large.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from model.base import BaseModule 4 | from model.diffusion_module import SinusoidalPosEmb 5 | 6 | import torch.nn as nn 7 | from model.periodwave_large_utils import MultiPeriodGenerator, MelSpectrogramUpsampler, FinalBlock 8 | 9 | LRELU_SLOPE = 0.1 10 | from math import sqrt 11 | 12 | class GradLogPEstimator(BaseModule): 13 | def __init__(self, n_mel, periods, final_dim=48, hidden_dim=768): 14 | super(GradLogPEstimator, self).__init__() 15 | 16 | self.len_periods = len(periods) 17 | self.hidden_dim = hidden_dim 18 | ### Mel condition 19 | self.MelCond = MelSpectrogramUpsampler(n_mel, periods, hidden_dim) 20 | ### Time Condition 21 | self.time_pos_emb = SinusoidalPosEmb(hidden_dim//2) 22 | self.period_emb = nn.Embedding(self.len_periods, hidden_dim//2) 23 | torch.nn.init.normal_(self.period_emb.weight, 0.0, (hidden_dim//2) ** -0.5) 24 | 25 | self.period_token = torch.LongTensor([i for i in range(self.len_periods)]).cuda() 26 | 27 | self.mlp = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim*4), 28 | nn.SiLU(), torch.nn.Linear(hidden_dim*4, hidden_dim)) 29 | 30 | ### Multi-period Audio U-net 31 | self.mpg = MultiPeriodGenerator(periods=periods, final_dim=final_dim, hidden_dim=hidden_dim) 32 | 33 | self.proj_layer = FinalBlock(final_dim) 34 | 35 | def remove_weight_norm(self): 36 | print('Removing weight norm...') 37 | self.MelCond.remove_weight_norm() 38 | self.mpg.remove_weight_norm() 39 | self.proj_layer.remove_weight_norm() 40 | 41 | def forward(self, x, mel, t): 42 | 43 | # Mel condition 44 | cond = self.MelCond(mel) 45 | # Time Condition 46 | t = self.time_pos_emb(t) 47 | 48 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 49 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 50 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 51 | 52 | t = self.mlp(t) 53 | 54 | xs = self.mpg(x, t, cond) 55 | 56 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 57 | x = self.proj_layer(x) 58 | 59 | 
return x 60 | def mel_encoder(self, mel): 61 | 62 | # Mel condition 63 | mel = self.MelCond(mel) 64 | 65 | return mel 66 | 67 | def decoder(self, x, cond, t): 68 | 69 | # Time Condition 70 | t = self.time_pos_emb(t) 71 | 72 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 73 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 74 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 75 | 76 | t = self.mlp(t) 77 | 78 | xs = self.mpg(x, t, cond) 79 | 80 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 81 | x = self.proj_layer(x) 82 | 83 | return x 84 | class FlowMatch(BaseModule): 85 | def __init__(self, n_mel=100, periods=[2,3,5,7,11], noise_scale=0.25): 86 | super().__init__() 87 | self.sigma_min = 1e-4 88 | self.noise_scale = noise_scale 89 | self.estimator = GradLogPEstimator(n_mel, periods) 90 | 91 | def forward(self, y, mel, target_std, n_timesteps, temperature=1.0, solver='euler'): 92 | 93 | if y.shape[1] != 1: 94 | y = y.unsqueeze(1) 95 | z = torch.randn_like(y)*self.noise_scale*target_std *temperature 96 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mel.device) 97 | 98 | mel = torch.cat([mel, target_std[:, :, ::256]], dim=1) # Concat target std 99 | mel = self.estimator.mel_encoder(mel) 100 | if solver=='euler': 101 | return self.solve_euler(z, t_span=t_span, mel=mel) 102 | elif solver=='midpoint': 103 | return self.solve_midpoint(z, t_span=t_span, mel=mel) 104 | 105 | def solve_euler(self, x, t_span, mel): 106 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 107 | t = t.reshape(1) 108 | 109 | sol = [] 110 | steps = 1 111 | while steps <= len(t_span) - 1: 112 | dphi_dt = self.estimator.decoder(x, mel, t) 113 | 114 | x = x + dt * dphi_dt 115 | t = t + dt 116 | sol.append(x) 117 | if steps < len(t_span) - 1: 118 | dt = t_span[steps + 1] - t 119 | steps += 1 120 | 121 | return sol[-1] 122 | 123 | def solve_midpoint(self, x, t_span, mel): 124 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 125 | t = t.reshape(1) 126 | 127 | sol = [] 128 | steps = 1 129 | while steps <= len(t_span) - 1: 130 | 131 | dphi_dt = self.estimator.decoder(x, mel, t) 132 | half_dt = 0.5 * dt 133 | x_mid = x + half_dt * dphi_dt 134 | 135 | x = x + dt * self.estimator.decoder(x_mid, mel, t+half_dt) 136 | t = t + dt 137 | sol.append(x) 138 | if steps < len(t_span) - 1: 139 | dt = t_span[steps + 1] - t 140 | steps += 1 141 | 142 | return sol[-1] 143 | -------------------------------------------------------------------------------- /model/periodwave_turbo.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from model.base import BaseModule 4 | from model.diffusion_module import SinusoidalPosEmb 5 | 6 | import torch.nn as nn 7 | from model.periodwave_utils import MultiPeriodGenerator, MelSpectrogramUpsampler, FinalBlock 8 | LRELU_SLOPE = 0.1 9 | from math import sqrt 10 | from model.ms_mel_loss import MultiScaleMelSpectrogramLoss 11 | 12 | class VectorFieldEstimator(BaseModule): 13 | def __init__(self, n_mel, periods, final_dim=32, hidden_dim=512): 14 | super(VectorFieldEstimator, self).__init__() 15 | 16 | self.len_periods = len(periods) 17 | self.hidden_dim = hidden_dim 18 | ### Mel condition 19 | self.MelCond = MelSpectrogramUpsampler(n_mel, periods, hidden_dim) 20 | ### Time Condition 21 | self.time_pos_emb = SinusoidalPosEmb(hidden_dim//2) 22 | self.period_emb = nn.Embedding(self.len_periods, hidden_dim//2) 23 | 
torch.nn.init.normal_(self.period_emb.weight, 0.0, (hidden_dim//2) ** -0.5) 24 | 25 | self.period_token = torch.LongTensor([i for i in range(self.len_periods)]).cuda() 26 | 27 | self.mlp = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim*4), 28 | nn.SiLU(), torch.nn.Linear(hidden_dim*4, hidden_dim)) 29 | 30 | ### Multi-period Audio U-net 31 | self.mpg = MultiPeriodGenerator(periods=periods, final_dim=final_dim, hidden_dim=hidden_dim) 32 | 33 | self.proj_layer = FinalBlock(final_dim) 34 | 35 | def remove_weight_norm(self): 36 | print('Removing weight norm...') 37 | self.MelCond.remove_weight_norm() 38 | self.mpg.remove_weight_norm() 39 | self.proj_layer.remove_weight_norm() 40 | 41 | def forward(self, x, mel, t): 42 | 43 | # Mel condition 44 | cond = self.MelCond(mel) 45 | # Time Condition 46 | t = self.time_pos_emb(t) 47 | 48 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 49 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 50 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 51 | 52 | t = self.mlp(t) 53 | 54 | xs = self.mpg(x, t, cond) 55 | 56 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 57 | x = self.proj_layer(x) 58 | 59 | return x 60 | def mel_encoder(self, mel): 61 | 62 | # Mel condition 63 | mel = self.MelCond(mel) 64 | 65 | return mel 66 | 67 | def decoder(self, x, cond, t): 68 | 69 | # Time Condition 70 | t = self.time_pos_emb(t) 71 | 72 | p = self.period_emb(self.period_token) * math.sqrt(self.hidden_dim) 73 | p = p.unsqueeze(0).expand(t.shape[0], self.len_periods, -1) 74 | t = torch.concat([t.unsqueeze(1).expand(-1, p.shape[1], -1), p], dim=2) 75 | 76 | t = self.mlp(t) 77 | 78 | xs = self.mpg(x, t, cond) 79 | 80 | x = torch.sum(torch.stack(xs), dim=0) / sqrt(self.len_periods) 81 | x = self.proj_layer(x) 82 | 83 | return x 84 | class FlowMatch(BaseModule): 85 | def __init__(self, n_mel=100, periods=[1,2,3,5,7], noise_scale=0.25, final_dim=32, hidden_dim=512, sampling_rate=24000): 86 | super().__init__() 87 | self.sigma_min = 1e-4 88 | self.noise_scale = noise_scale 89 | self.estimator = VectorFieldEstimator(n_mel, periods, final_dim, hidden_dim) 90 | self.msmel = MultiScaleMelSpectrogramLoss(sampling_rate=sampling_rate) 91 | 92 | def forward(self, y, mel, target_std, n_timesteps, temperature=1.0, solver='euler'): 93 | 94 | if y.shape[1] != 1: 95 | y = y.unsqueeze(1) 96 | z = torch.randn_like(y)*self.noise_scale*target_std *temperature 97 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mel.device) 98 | 99 | mel = torch.cat([mel, target_std[:, :, ::256]], dim=1) # Concat target std 100 | mel = self.estimator.mel_encoder(mel) 101 | if solver=='euler': 102 | return self.solve_euler(z, t_span=t_span, mel=mel) 103 | elif solver=='midpoint': 104 | return self.solve_midpoint(z, t_span=t_span, mel=mel) 105 | 106 | def solve_euler(self, x, t_span, mel): 107 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 108 | t = t.reshape(1) 109 | 110 | sol = [] 111 | steps = 1 112 | while steps <= len(t_span) - 1: 113 | dphi_dt = self.estimator.decoder(x, mel, t) 114 | 115 | x = x + dt * dphi_dt 116 | t = t + dt 117 | sol.append(x) 118 | if steps < len(t_span) - 1: 119 | dt = t_span[steps + 1] - t 120 | steps += 1 121 | 122 | return sol[-1] 123 | 124 | def solve_midpoint(self, x, t_span, mel): 125 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 126 | t = t.reshape(1) 127 | 128 | sol = [] 129 | steps = 1 130 | while steps <= len(t_span) - 1: 131 | 132 | dphi_dt = self.estimator.decoder(x, 
mel, t) 133 | half_dt = 0.5 * dt 134 | x_mid = x + half_dt * dphi_dt 135 | 136 | x = x + dt * self.estimator.decoder(x_mid, mel, t+half_dt) 137 | t = t + dt 138 | sol.append(x) 139 | if steps < len(t_span) - 1: 140 | dt = t_span[steps + 1] - t 141 | steps += 1 142 | 143 | return sol[-1] 144 | 145 | def compute_loss(self, x1, mel, target_std, length, tuning_steps, temperature): 146 | b, _, t = mel.shape 147 | 148 | x_mask = self.sequence_mask(length, x1.size(1)).to(x1.dtype) 149 | 150 | x1 = x1.unsqueeze(1)*x_mask 151 | 152 | pred_x1 = self.forward(x1, mel, target_std, n_timesteps=tuning_steps, temperature=temperature) 153 | 154 | loss = self.msmel(pred_x1*x_mask, x1) 155 | 156 | return loss, pred_x1 157 | 158 | def scaled_mse_loss(self, decoder_output, target, target_std, x_mask): 159 | # inverse of diagonal matrix is 1/x for each element 160 | sigma_inv = torch.reciprocal(target_std) 161 | mse_loss = (((decoder_output - target) * sigma_inv) ** 2) 162 | mse_loss = (mse_loss*x_mask).sum() / x_mask.sum() 163 | return mse_loss 164 | 165 | def sequence_mask(self, length, max_length=None): 166 | if max_length is None: 167 | max_length = length.max() 168 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 169 | return x.unsqueeze(0) < length.unsqueeze(1) 170 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import logging 5 | import torch 6 | import torchaudio 7 | import numpy as np 8 | 9 | from librosa.filters import mel as librosa_mel_fn 10 | from model.base import BaseModule 11 | 12 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 13 | logger = logging 14 | 15 | def mse_loss(x, y, mask, n_feats): 16 | loss = torch.sum(((x - y)**2) * mask) 17 | return loss / (torch.sum(mask) * n_feats) 18 | 19 | def sequence_mask(length, max_length=None): 20 | if max_length is None: 21 | max_length = length.max() 22 | x = torch.arange(int(max_length), dtype=length.dtype, device=length.device) 23 | return x.unsqueeze(0) < length.unsqueeze(1) 24 | 25 | def convert_pad_shape(pad_shape): 26 | l = pad_shape[::-1] 27 | pad_shape = [item for sublist in l for item in sublist] 28 | return pad_shape 29 | 30 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 31 | while True: 32 | if length % (2**num_downsamplings_in_unet) == 0: 33 | return length 34 | length += 1 35 | 36 | class PseudoInversion(BaseModule): 37 | def __init__(self, n_mels, sampling_rate, n_fft): 38 | super(PseudoInversion, self).__init__() 39 | self.n_mels = n_mels 40 | self.sampling_rate = sampling_rate 41 | self.n_fft = n_fft 42 | mel_basis = librosa_mel_fn(sampling_rate, n_fft, n_mels, 0, 8000) 43 | mel_basis_inverse = np.linalg.pinv(mel_basis) 44 | mel_basis_inverse = torch.from_numpy(mel_basis_inverse).float() 45 | self.register_buffer("mel_basis_inverse", mel_basis_inverse) 46 | 47 | def forward(self, log_mel_spectrogram): 48 | mel_spectrogram = torch.exp(log_mel_spectrogram) 49 | stftm = torch.matmul(self.mel_basis_inverse, mel_spectrogram) 50 | return stftm 51 | 52 | class InitialReconstruction(BaseModule): 53 | def __init__(self, n_fft, hop_size): 54 | super(InitialReconstruction, self).__init__() 55 | self.n_fft = n_fft 56 | self.hop_size = hop_size 57 | window = torch.hann_window(n_fft).float() 58 | self.register_buffer("window", window) 59 | 60 | def forward(self, stftm): 61 | real_part = torch.ones_like(stftm, 
device=stftm.device) 62 | imag_part = torch.zeros_like(stftm, device=stftm.device) 63 | stft = torch.stack([real_part, imag_part], -1)*stftm.unsqueeze(-1) 64 | istft = torch.istft(torch.view_as_complex(stft.contiguous()), n_fft=self.n_fft,  # torch.istft replaces torchaudio.functional.istft, which was removed from torchaudio; it expects a complex spectrogram 65 | hop_length=self.hop_size, win_length=self.n_fft, 66 | window=self.window, center=True) 67 | return istft.unsqueeze(1) 68 | 69 | def load_checkpoint(checkpoint_path, model, optimizer=None): 70 | assert os.path.isfile(checkpoint_path) 71 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 72 | iteration = checkpoint_dict['iteration'] 73 | learning_rate = checkpoint_dict['learning_rate'] 74 | if optimizer is not None: 75 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 76 | saved_state_dict = checkpoint_dict['model'] 77 | if hasattr(model, 'module'): 78 | state_dict = model.module.state_dict() 79 | else: 80 | state_dict = model.state_dict() 81 | new_state_dict = {} 82 | for k, v in state_dict.items(): 83 | try: 84 | new_state_dict[k] = saved_state_dict[k] 85 | except KeyError: 86 | logger.info("%s is not in the checkpoint" % k) 87 | new_state_dict[k] = v 88 | if hasattr(model, 'module'): 89 | model.module.load_state_dict(new_state_dict) 90 | else: 91 | model.load_state_dict(new_state_dict) 92 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 93 | checkpoint_path, iteration)) 94 | return model, optimizer, learning_rate, iteration -------------------------------------------------------------------------------- /periodwave.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sh-lee-prml/PeriodWave/c5c40736daacc3ccc831f4d22d246a832886425a/periodwave.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | auraloss==0.4.0 2 | einops==0.8.1 3 | encodec==0.1.1 4 | huggingface_hub==0.26.2 5 | librosa==0.9.2 6 | matplotlib==3.10.0 7 | nnAudio==0.3.3 8 | numpy==2.2.2 9 | pesq==0.0.4 10 | scipy==1.15.1 11 | torch==2.4.0 12 | torchaudio==2.4.0 13 | torchcrepe==0.0.23 14 | tqdm==4.66.4 15 | -------------------------------------------------------------------------------- /stats_libritts_24000hz/energy_max_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sh-lee-prml/PeriodWave/c5c40736daacc3ccc831f4d22d246a832886425a/stats_libritts_24000hz/energy_max_train.npy -------------------------------------------------------------------------------- /stats_libritts_24000hz/energy_min_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sh-lee-prml/PeriodWave/c5c40736daacc3ccc831f4d22d246a832886425a/stats_libritts_24000hz/energy_min_train.npy -------------------------------------------------------------------------------- /stats_lj_22050hz/energy_max_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sh-lee-prml/PeriodWave/c5c40736daacc3ccc831f4d22d246a832886425a/stats_lj_22050hz/energy_max_train.npy -------------------------------------------------------------------------------- /stats_lj_22050hz/energy_min_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sh-lee-prml/PeriodWave/c5c40736daacc3ccc831f4d22d246a832886425a/stats_lj_22050hz/energy_min_train.npy 
-------------------------------------------------------------------------------- /test/Triviul_feat._The_Fiend_-_Widow.stem.vocals_part180.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sh-lee-prml/PeriodWave/c5c40736daacc3ccc831f4d22d246a832886425a/test/Triviul_feat._The_Fiend_-_Widow.stem.vocals_part180.wav -------------------------------------------------------------------------------- /train_periodwave.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.nn import functional as F 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | from torch.cuda.amp import autocast, GradScaler 9 | 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | import torchaudio 14 | import random 15 | import commons 16 | import utils 17 | from meldataset_prior_length import MelDataset, mel_spectrogram, MAX_WAV_VALUE 18 | from torch.utils.data.distributed import DistributedSampler 19 | from torch.utils.data import DataLoader 20 | import auraloss 21 | from pesq import pesq 22 | from model.periodwave import FlowMatch 23 | 24 | torch.backends.cudnn.benchmark = True 25 | global_step = 0 26 | 27 | def get_param_num(model): 28 | num_param = sum(param.numel() for param in model.parameters()) 29 | return num_param 30 | 31 | def main(): 32 | """Assume Single Node Multi GPUs Training Only""" 33 | assert torch.cuda.is_available(), "CPU training is not allowed." 34 | 35 | n_gpus = torch.cuda.device_count() 36 | port = 50000 + random.randint(0, 100) 37 | os.environ['MASTER_ADDR'] = 'localhost' 38 | os.environ['MASTER_PORT'] = str(port) 39 | 40 | hps = utils.get_hparams() 41 | if n_gpus > 1: 42 | mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) 43 | else: 44 | run(0, n_gpus, hps) 45 | 46 | 47 | def run(rank, n_gpus, hps): 48 | global global_step 49 | if rank == 0: 50 | logger = utils.get_logger(hps.model_dir) 51 | logger.info(hps) 52 | utils.check_git_hash(hps.model_dir) 53 | writer = SummaryWriter(log_dir=hps.model_dir) 54 | if n_gpus > 1: 55 | dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) 56 | 57 | torch.manual_seed(hps.train.seed) 58 | torch.cuda.set_device(rank) 59 | device = torch.device('cuda:{:d}'.format(rank)) 60 | 61 | train_dataset = MelDataset(hps.data.train_filelist_path, hps, hps.train.segment_size, hps.data.filter_length, hps.data.n_mel_channels, 62 | hps.data.hop_length, hps.data.win_length, hps.data.sampling_rate, hps.data.mel_fmin, hps.data.mel_fmax, n_cache_reuse=0, 63 | shuffle=False if n_gpus > 1 else True, device=device) 64 | 65 | train_sampler = DistributedSampler(train_dataset) if n_gpus > 1 else None 66 | train_loader = DataLoader( 67 | train_dataset, batch_size=hps.train.batch_size, num_workers=16, shuffle=False, 68 | sampler=train_sampler, drop_last=True, pin_memory=True 69 | ) 70 | 71 | if rank == 0: 72 | test_dataset = MelDataset(hps.data.test_filelist_path, hps, hps.train.segment_size, hps.data.filter_length, hps.data.n_mel_channels, 73 | hps.data.hop_length, hps.data.win_length, hps.data.sampling_rate, hps.data.mel_fmin, hps.data.mel_fmax, n_cache_reuse=0, split=False, shuffle=False, 74 | device=device) 75 | eval_loader = DataLoader(test_dataset, batch_size=1) 76 | 77 | model = FlowMatch(hps.data.n_mel_channels, 78 | hps.model.periods, 79 | 
hps.model.noise_scale, 80 | hps.model.final_dim, 81 | hps.model.hidden_dim).cuda() 82 | 83 | if rank == 0: 84 | num_param = get_param_num(model) 85 | print('number of Parameters:', num_param) 86 | 87 | optimizer = torch.optim.AdamW( 88 | model.parameters(), 89 | hps.train.learning_rate) 90 | 91 | if n_gpus > 1: 92 | model = DDP(model, device_ids=[rank]) 93 | 94 | try: 95 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), model, 96 | optimizer) 97 | 98 | global_step = (epoch_str - 1) * len(train_loader) 99 | except Exception:  # no usable checkpoint found; start training from scratch 100 | epoch_str = 1 101 | global_step = 0 102 | 103 | scaler = GradScaler(enabled=hps.train.fp16_run)  # note: created for fp16 runs, but the loop below calls backward()/step() directly 104 | 105 | for epoch in range(epoch_str, hps.train.epochs + 1): 106 | if rank == 0: 107 | train_and_evaluate(rank, epoch, hps, model, optimizer, 108 | scaler, 109 | [train_loader, eval_loader], logger, writer, n_gpus) 110 | else: 111 | train_and_evaluate(rank, epoch, hps, model, optimizer, 112 | scaler, 113 | [train_loader, None], None, None, n_gpus) 114 | 115 | def train_and_evaluate(rank, epoch, hps, nets, optims, scaler, loaders, logger, writers, n_gpus): 116 | model = nets 117 | optimizer = optims 118 | train_loader, eval_loader = loaders 119 | 120 | if writers is not None: 121 | writer = writers 122 | 123 | global global_step 124 | 125 | if n_gpus > 1: 126 | train_loader.sampler.set_epoch(epoch) 127 | 128 | model.train() 129 | for batch_idx, (mel, y, target_std, length) in enumerate(train_loader): 130 | y = y.cuda(rank, non_blocking=True) 131 | mel = mel.cuda(rank, non_blocking=True) 132 | target_std = target_std.cuda(rank, non_blocking=True) 133 | length = length.cuda(rank, non_blocking=True) 134 | 135 | optimizer.zero_grad() 136 | if n_gpus > 1: 137 | loss_fm = model.module.compute_loss(y, mel, target_std, length) 138 | else: 139 | loss_fm = model.compute_loss(y, mel, target_std, length) 140 | 141 | loss_gen_all = loss_fm 142 | 143 | loss_gen_all.backward() 144 | grad_norm_g = commons.clip_grad_value_(model.parameters(), None) 145 | optimizer.step() 146 | 147 | if rank == 0: 148 | if global_step % hps.train.log_interval == 0: 149 | lr = optimizer.param_groups[0]['lr'] 150 | losses = [loss_fm] 151 | logger.info('Train Epoch: {} [{:.0f}%]'.format( 152 | epoch, 153 | 100.
* batch_idx / len(train_loader))) 154 | logger.info([x.item() for x in losses] + [global_step, lr]) 155 | 156 | scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} 157 | scalar_dict.update({"loss/g/fm": loss_fm}) 158 | 159 | utils.summarize( 160 | writer=writer, 161 | global_step=global_step, 162 | scalars=scalar_dict) 163 | 164 | if global_step % hps.train.eval_interval == 0: 165 | torch.cuda.empty_cache() 166 | evaluate(hps, model, eval_loader, writer) 167 | 168 | if global_step % hps.train.save_interval == 0: 169 | utils.save_checkpoint(model, optimizer, hps.train.learning_rate, epoch, 170 | os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) 171 | 172 | global_step += 1 173 | 174 | if rank == 0: 175 | logger.info('====> Epoch: {}'.format(epoch)) 176 | 177 | def evaluate(hps, model, eval_loader, writer_eval): 178 | model.eval() 179 | image_dict = {} 180 | audio_dict = {} 181 | 182 | # modules for evaluation metrics 183 | pesq_resampler = torchaudio.transforms.Resample(hps.data.sampling_rate, 16000).cuda() 184 | loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") 185 | 186 | val_err_tot = 0 187 | val_pesq_tot = 0 188 | val_mrstft_tot = 0 189 | 190 | with torch.no_grad(): 191 | for batch_idx, (mel, y, target_std,_) in enumerate(eval_loader): 192 | 193 | y = y.cuda(0) 194 | mel = mel.cuda(0) 195 | target_std = target_std.cuda(0) 196 | 197 | y_gen = model(y, mel, target_std, n_timesteps=16, temperature=1.0) 198 | 199 | if torch.abs(y_gen).max() >= 0.95: 200 | y_gen = (y_gen / (torch.abs(y_gen).max())) * 0.95 201 | 202 | y_gen_mel = mel_spectrogram(y_gen.squeeze(1), hps.data.filter_length, hps.data.n_mel_channels, 203 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 204 | hps.data.mel_fmin, hps.data.mel_fmax) 205 | 206 | val_err_tot += F.l1_loss(mel, y_gen_mel).item() 207 | 208 | y_16k = pesq_resampler(y) 209 | y_g_hat_16k = pesq_resampler(y_gen.squeeze(1)) 210 | y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 211 | y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 212 | val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') 213 | 214 | # MRSTFT calculation 215 | val_mrstft_tot += loss_mrstft(y_gen, y).item() 216 | 217 | if batch_idx <= 4: 218 | 219 | plot_mel = torch.cat([mel, y_gen_mel], dim=1) 220 | plot_mel = plot_mel.clip(min=-10, max=10) 221 | 222 | image_dict.update({ 223 | "gen/mel_{}".format(batch_idx): utils.plot_spectrogram_to_numpy(plot_mel.squeeze().cpu().numpy()), 224 | }) 225 | audio_dict.update({ 226 | "gen/audio_{}_gen".format(batch_idx): y_gen.squeeze(), 227 | 228 | }) 229 | if global_step == 0: 230 | audio_dict.update({"gt/audio_{}".format(batch_idx): y.squeeze()}) 231 | 232 | val_err_tot /= (batch_idx+1) 233 | val_pesq_tot /= (batch_idx+1) 234 | val_mrstft_tot /= (batch_idx+1) 235 | 236 | scalar_dict = {"val/mel": val_err_tot, "val/pesq": val_pesq_tot, "val/mrstft": val_mrstft_tot} 237 | utils.summarize( 238 | writer=writer_eval, 239 | global_step=global_step, 240 | images=image_dict, 241 | audios=audio_dict, 242 | audio_sampling_rate=hps.data.sampling_rate, 243 | scalars=scalar_dict 244 | ) 245 | model.train() 246 | 247 | 248 | if __name__ == "__main__": 249 | main() -------------------------------------------------------------------------------- /train_periodwave_encodec.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.nn import functional as F 4 | from 
torch.nn.parallel import DistributedDataParallel as DDP 5 | 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | from torch.cuda.amp import autocast, GradScaler 9 | 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | import torchaudio 14 | import random 15 | import commons 16 | import utils 17 | from dataset_codec import MelDataset, mel_spectrogram, MAX_WAV_VALUE 18 | from torch.utils.data.distributed import DistributedSampler 19 | from torch.utils.data import DataLoader 20 | import auraloss 21 | from pesq import pesq 22 | 23 | from model.periodwave_encodec import FlowMatch 24 | from encodec_feature_extractor import EncodecFeatures 25 | 26 | torch.backends.cudnn.benchmark = True 27 | global_step = 0 28 | 29 | def get_param_num(model): 30 | num_param = sum(param.numel() for param in model.parameters()) 31 | return num_param 32 | 33 | def main(): 34 | """Assume Single Node Multi GPUs Training Only""" 35 | assert torch.cuda.is_available(), "CPU training is not allowed." 36 | 37 | n_gpus = torch.cuda.device_count() 38 | port = 50000 + random.randint(0, 100) 39 | os.environ['MASTER_ADDR'] = 'localhost' 40 | os.environ['MASTER_PORT'] = str(port) 41 | 42 | hps = utils.get_hparams() 43 | if n_gpus > 1: 44 | mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) 45 | else: 46 | run(0, n_gpus, hps) 47 | 48 | 49 | def run(rank, n_gpus, hps): 50 | global global_step 51 | if rank == 0: 52 | logger = utils.get_logger(hps.model_dir) 53 | logger.info(hps) 54 | utils.check_git_hash(hps.model_dir) 55 | writer = SummaryWriter(log_dir=hps.model_dir) 56 | if n_gpus > 1: 57 | dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) 58 | 59 | torch.manual_seed(hps.train.seed) 60 | torch.cuda.set_device(rank) 61 | device = torch.device('cuda:{:d}'.format(rank)) 62 | 63 | train_dataset = MelDataset(hps.data.train_filelist_path, hps, hps.train.segment_size, hps.data.filter_length, hps.data.n_mel_channels, 64 | hps.data.hop_length, hps.data.win_length, hps.data.sampling_rate, hps.data.mel_fmin, hps.data.mel_fmax, n_cache_reuse=0, 65 | shuffle=False if n_gpus > 1 else True, device=device) 66 | 67 | train_sampler = DistributedSampler(train_dataset) if n_gpus > 1 else None 68 | train_loader = DataLoader( 69 | train_dataset, batch_size=hps.train.batch_size, num_workers=4, shuffle=False, 70 | sampler=train_sampler, drop_last=True, pin_memory=True, persistent_workers=True 71 | ) 72 | 73 | if rank == 0: 74 | test_dataset = MelDataset(hps.data.test_filelist_path, hps, hps.train.segment_size, hps.data.filter_length, hps.data.n_mel_channels, 75 | hps.data.hop_length, hps.data.win_length, hps.data.sampling_rate, hps.data.mel_fmin, hps.data.mel_fmax, n_cache_reuse=0, split=False, shuffle=False, 76 | device=device) 77 | eval_loader = DataLoader(test_dataset, num_workers=1, batch_size=1, pin_memory=True, persistent_workers=True) 78 | 79 | Encodec = EncodecFeatures().cuda() 80 | 81 | model = FlowMatch(hps.data.n_mel_channels, 82 | hps.model.periods, 83 | hps.model.noise_scale, 84 | hps.model.final_dim, 85 | hps.model.hidden_dim 86 | ).cuda() 87 | 88 | if rank == 0: 89 | num_param = get_param_num(model) 90 | print('number of Parameters:', num_param) 91 | 92 | optimizer = torch.optim.AdamW( 93 | model.parameters(), 94 | hps.train.learning_rate) 95 | 96 | if n_gpus > 1: 97 | model = DDP(model, device_ids=[rank]) 98 | 99 | try: 100 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, 
"G_*.pth"), model, 101 | optimizer) 102 | 103 | global_step = (epoch_str - 1) * len(train_loader) 104 | except: 105 | epoch_str = 1 106 | global_step = 0 107 | 108 | scaler = GradScaler(enabled=hps.train.fp16_run) 109 | 110 | for epoch in range(epoch_str, hps.train.epochs + 1): 111 | if rank == 0: 112 | train_and_evaluate(rank, epoch, hps, model, optimizer, 113 | scaler, 114 | [train_loader, eval_loader], logger, writer, n_gpus, Encodec) 115 | else: 116 | train_and_evaluate(rank, epoch, hps, model, optimizer, 117 | scaler, 118 | [train_loader, None], None, None, n_gpus, Encodec) 119 | 120 | def train_and_evaluate(rank, epoch, hps, nets, optims, scaler, loaders, logger, writers, n_gpus, Encodec): 121 | model = nets 122 | optimizer = optims 123 | train_loader, eval_loader = loaders 124 | 125 | if writers is not None: 126 | writer = writers 127 | 128 | global global_step 129 | 130 | if n_gpus > 1: 131 | train_loader.sampler.set_epoch(epoch) 132 | 133 | model.train() 134 | for batch_idx, (y, length) in enumerate(train_loader): 135 | y = y.cuda(rank, non_blocking=True) 136 | length = length.cuda(rank, non_blocking=True) 137 | 138 | optimizer.zero_grad() 139 | 140 | with torch.no_grad(): 141 | embs = Encodec(y) # [B, K = 8, T] 142 | 143 | if n_gpus > 1: 144 | loss_fm = model.module.compute_loss(y, embs, length) 145 | else: 146 | loss_fm = model.compute_loss(y, embs, length) 147 | 148 | loss_gen_all = loss_fm 149 | 150 | loss_gen_all.backward() 151 | grad_norm_g = commons.clip_grad_value_(model.parameters(), None) 152 | optimizer.step() 153 | 154 | if rank == 0: 155 | if global_step % hps.train.log_interval == 0: 156 | lr = optimizer.param_groups[0]['lr'] 157 | losses = [loss_fm] 158 | logger.info('Train Epoch: {} [{:.0f}%]'.format( 159 | epoch, 160 | 100. 
* batch_idx / len(train_loader))) 161 | logger.info([x.item() for x in losses] + [global_step, lr]) 162 | 163 | scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} 164 | scalar_dict.update({"loss/g/fm": loss_fm}) 165 | 166 | utils.summarize( 167 | writer=writer, 168 | global_step=global_step, 169 | scalars=scalar_dict) 170 | 171 | # if global_step % hps.train.eval_interval == 0: 172 | # torch.cuda.empty_cache() 173 | # evaluate(hps, model, eval_loader, writer) 174 | 175 | if global_step % hps.train.save_interval == 0: 176 | utils.save_checkpoint(model, optimizer, hps.train.learning_rate, epoch, 177 | os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) 178 | 179 | global_step += 1 180 | 181 | if rank == 0: 182 | logger.info('====> Epoch: {}'.format(epoch)) 183 | 184 | def evaluate(hps, model, eval_loader, writer_eval): 185 | model.eval() 186 | image_dict = {} 187 | audio_dict = {} 188 | 189 | # modules for evaluation metrics 190 | pesq_resampler = torchaudio.transforms.Resample(hps.data.sampling_rate, 16000).cuda() 191 | loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") 192 | 193 | val_err_tot = 0 194 | val_pesq_tot = 0 195 | val_mrstft_tot = 0 196 | 197 | with torch.no_grad(): 198 | for batch_idx, (mel, y, _) in enumerate(eval_loader): 199 | 200 | y = y.cuda(0) 201 | mel = mel.cuda(0) 202 | 203 | 204 | y_gen = model(y, mel, n_timesteps=16, temperature=1.0) 205 | 206 | if torch.abs(y_gen).max() >= 0.95: 207 | y_gen = (y_gen / (torch.abs(y_gen).max())) * 0.95 208 | 209 | y_gen_mel = mel_spectrogram(y_gen.squeeze(1), hps.data.filter_length, hps.data.n_mel_channels, 210 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 211 | hps.data.mel_fmin, hps.data.mel_fmax) 212 | 213 | val_err_tot += F.l1_loss(mel, y_gen_mel).item() 214 | 215 | y_16k = pesq_resampler(y) 216 | y_g_hat_16k = pesq_resampler(y_gen.squeeze(1)) 217 | y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 218 | y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 219 | val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') 220 | 221 | # MRSTFT calculation 222 | val_mrstft_tot += loss_mrstft(y_gen, y).item() 223 | 224 | if batch_idx <= 4: 225 | 226 | plot_mel = torch.cat([mel, y_gen_mel], dim=1) 227 | plot_mel = plot_mel.clip(min=-10, max=10) 228 | 229 | image_dict.update({ 230 | "gen/mel_{}".format(batch_idx): utils.plot_spectrogram_to_numpy(plot_mel.squeeze().cpu().numpy()), 231 | }) 232 | audio_dict.update({ 233 | "gen/audio_{}_gen".format(batch_idx): y_gen.squeeze(), 234 | 235 | }) 236 | if global_step == 0: 237 | audio_dict.update({"gt/audio_{}".format(batch_idx): y.squeeze()}) 238 | 239 | val_err_tot /= (batch_idx+1) 240 | val_pesq_tot /= (batch_idx+1) 241 | val_mrstft_tot /= (batch_idx+1) 242 | 243 | scalar_dict = {"val/mel": val_err_tot, "val/pesq": val_pesq_tot, "val/mrstft": val_mrstft_tot} 244 | utils.summarize( 245 | writer=writer_eval, 246 | global_step=global_step, 247 | images=image_dict, 248 | audios=audio_dict, 249 | audio_sampling_rate=hps.data.sampling_rate, 250 | scalars=scalar_dict 251 | ) 252 | model.train() 253 | 254 | 255 | if __name__ == "__main__": 256 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from 
scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 21 | iteration = checkpoint_dict['iteration'] 22 | learning_rate = checkpoint_dict['learning_rate'] 23 | if optimizer is not None: 24 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 25 | saved_state_dict = checkpoint_dict['model'] 26 | if hasattr(model, 'module'): 27 | state_dict = model.module.state_dict() 28 | else: 29 | state_dict = model.state_dict() 30 | new_state_dict = {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except KeyError:  # keys missing from the checkpoint keep their freshly initialized weights 35 | logger.info("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | if hasattr(model, 'module'): 38 | model.module.load_state_dict(new_state_dict) 39 | else: 40 | model.load_state_dict(new_state_dict) 41 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 42 | checkpoint_path, iteration)) 43 | return model, optimizer, learning_rate, iteration 44 | 45 | 46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 47 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 48 | iteration, checkpoint_path)) 49 | if hasattr(model, 'module'): 50 | state_dict = model.module.state_dict() 51 | else: 52 | state_dict = model.state_dict() 53 | torch.save({'model': state_dict, 54 | 'iteration': iteration, 55 | 'optimizer': optimizer.state_dict(), 56 | 'learning_rate': learning_rate}, checkpoint_path) 57 | 58 | 59 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 60 | for k, v in scalars.items(): 61 | writer.add_scalar(k, v, global_step) 62 | for k, v in histograms.items(): 63 | writer.add_histogram(k, v, global_step) 64 | for k, v in images.items(): 65 | writer.add_image(k, v, global_step, dataformats='HWC') 66 | for k, v in audios.items(): 67 | writer.add_audio(k, v, global_step, audio_sampling_rate) 68 | 69 | 70 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 71 | f_list = glob.glob(os.path.join(dir_path, regex)) 72 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 73 | x = f_list[-1] 74 | print(x) 75 | return x 76 | 77 | 78 | def plot_spectrogram_to_numpy(spectrogram): 79 | global MATPLOTLIB_FLAG 80 | if not MATPLOTLIB_FLAG: 81 | import matplotlib 82 | matplotlib.use("Agg") 83 | MATPLOTLIB_FLAG = True 84 | mpl_logger = logging.getLogger('matplotlib') 85 | mpl_logger.setLevel(logging.WARNING) 86 | import matplotlib.pylab as plt 87 | import numpy as np 88 | 89 | fig, ax = plt.subplots(figsize=(10, 2)) 90 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 91 | interpolation='none') 92 | plt.colorbar(im, ax=ax) 93 | plt.xlabel("Frames") 94 | plt.ylabel("Channels") 95 | plt.tight_layout() 96 | 97 | fig.canvas.draw() 98 | # tostring_rgb was removed in the pinned matplotlib 3.10 and np.fromstring is deprecated; read the RGBA buffer and drop alpha 99 | data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (4,))[..., :3] 100 | plt.close() 101 | return data 102 | 103 | 104 | def plot_alignment_to_numpy(alignment, info=None): 105 | global MATPLOTLIB_FLAG 106 | if not MATPLOTLIB_FLAG: 107 | import matplotlib 108 | matplotlib.use("Agg") 109 | MATPLOTLIB_FLAG = True 110 | mpl_logger = logging.getLogger('matplotlib') 111 |
mpl_logger.setLevel(logging.WARNING) 112 | import matplotlib.pylab as plt 113 | import numpy as np 114 | 115 | fig, ax = plt.subplots(figsize=(6, 4)) 116 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 117 | interpolation='none') 118 | fig.colorbar(im, ax=ax) 119 | xlabel = 'Decoder timestep' 120 | if info is not None: 121 | xlabel += '\n\n' + info 122 | plt.xlabel(xlabel) 123 | plt.ylabel('Encoder timestep') 124 | plt.tight_layout() 125 | 126 | fig.canvas.draw() 127 | # same fix as in plot_spectrogram_to_numpy: use the RGBA buffer, then drop the alpha channel 128 | data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (4,))[..., :3] 129 | plt.close() 130 | return data 131 | 132 | 133 | def load_wav_to_torch(full_path): 134 | sampling_rate, data = read(full_path) 135 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 136 | 137 | 138 | def load_filepaths_and_text(filename, split="|"): 139 | with open(filename, encoding='utf-8') as f: 140 | filepaths_and_text = [line.strip().split(split) for line in f] 141 | return filepaths_and_text 142 | 143 | 144 | def get_hparams(init=True): 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument('-c', '--config', type=str, default="./configs/vctk.json", 147 | help='JSON file for configuration')  # note: the default config is not shipped in configs/, so pass -c explicitly 148 | parser.add_argument('-m', '--model', type=str, required=True, 149 | help='Model name') 150 | 151 | args = parser.parse_args() 152 | model_dir = os.path.join("./logs", args.model) 153 | 154 | if not os.path.exists(model_dir): 155 | os.makedirs(model_dir) 156 | 157 | config_path = args.config 158 | config_save_path = os.path.join(model_dir, "config.json") 159 | if init: 160 | with open(config_path, "r") as f: 161 | data = f.read() 162 | with open(config_save_path, "w") as f: 163 | f.write(data) 164 | else: 165 | with open(config_save_path, "r") as f: 166 | data = f.read() 167 | config = json.loads(data) 168 | 169 | hparams = HParams(**config) 170 | hparams.model_dir = model_dir 171 | return hparams 172 | 173 | 174 | def get_hparams_from_dir(model_dir): 175 | config_save_path = os.path.join(model_dir, "config.json") 176 | with open(config_save_path, "r") as f: 177 | data = f.read() 178 | config = json.loads(data) 179 | 180 | hparams = HParams(**config) 181 | hparams.model_dir = model_dir 182 | return hparams 183 | 184 | 185 | def get_hparams_from_file(config_path): 186 | with open(config_path, "r") as f: 187 | data = f.read() 188 | config = json.loads(data) 189 | 190 | hparams = HParams(**config) 191 | return hparams 192 | 193 | 194 | def check_git_hash(model_dir): 195 | source_dir = os.path.dirname(os.path.realpath(__file__)) 196 | if not os.path.exists(os.path.join(source_dir, ".git")): 197 | logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format( 198 | source_dir 199 | )) 200 | return 201 | 202 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 203 | 204 | path = os.path.join(model_dir, "githash") 205 | if os.path.exists(path): 206 | saved_hash = open(path).read() 207 | if saved_hash != cur_hash: 208 | logger.warning("git hash values are different.
{}(saved) != {}(current)".format( 209 | saved_hash[:8], cur_hash[:8])) 210 | else: 211 | open(path, "w").write(cur_hash) 212 | 213 | 214 | def get_logger(model_dir, filename="train.log"): 215 | global logger 216 | logger = logging.getLogger(os.path.basename(model_dir)) 217 | logger.setLevel(logging.DEBUG) 218 | 219 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 220 | if not os.path.exists(model_dir): 221 | os.makedirs(model_dir) 222 | h = logging.FileHandler(os.path.join(model_dir, filename)) 223 | h.setLevel(logging.DEBUG) 224 | h.setFormatter(formatter) 225 | logger.addHandler(h) 226 | return logger 227 | 228 | 229 | def parse_filelist(filelist_path): 230 | with open(filelist_path, 'r') as f: 231 | filelist = [line.strip() for line in f.readlines()] 232 | return filelist 233 | 234 | 235 | def parse_filelist_and_spk_id(filelist_path, split="|"): 236 | with open(filelist_path, encoding='utf-8') as f: 237 | filepaths_and_spkid = [line.strip().split(split) for line in f] 238 | return filepaths_and_spkid 239 | 240 | 241 | class HParams(): 242 | def __init__(self, **kwargs): 243 | for k, v in kwargs.items(): 244 | if isinstance(v, dict):  # nested config sections become nested HParams 245 | v = HParams(**v) 246 | self[k] = v 247 | 248 | def keys(self): 249 | return self.__dict__.keys() 250 | 251 | def items(self): 252 | return self.__dict__.items() 253 | 254 | def values(self): 255 | return self.__dict__.values() 256 | 257 | def __len__(self): 258 | return len(self.__dict__) 259 | 260 | def __getitem__(self, key): 261 | return getattr(self, key) 262 | 263 | def __setitem__(self, key, value): 264 | return setattr(self, key, value) 265 | 266 | def __contains__(self, key): 267 | return key in self.__dict__ 268 | 269 | def __repr__(self): 270 | return self.__dict__.__repr__() 271 | --------------------------------------------------------------------------------
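Editor's example: a small usage sketch of the HParams wrapper defined in utils.py above. It assumes the stock configs/periodwave_24000hz.json layout with data, train, and model sections (these key names are the ones the training scripts read); it is illustrative, not part of the repository.

import utils

hps = utils.get_hparams_from_file('configs/periodwave_24000hz.json')
print(hps.data.sampling_rate)   # nested dicts are wrapped in HParams, so attribute access works
print(hps['train'].batch_size)  # __getitem__ delegates to getattr
assert 'model' in hps           # __contains__ checks the underlying __dict__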