├── DNSMOS
│   ├── ONNX_models
│   │   ├── bak_ovr.onnx
│   │   ├── model_v8.onnx
│   │   ├── sig.onnx
│   │   └── sig_bak_ovr.onnx
│   ├── README.md
│   └── dnsmos_local.py
├── IUB_ind2.pickle
├── Librispeech_clean.csv
├── README.md
├── Tencent_ind2.pickle
├── VCTK_clean_test.csv
├── VCTK_clean_train.csv
├── VCTK_noisy_testSet_with_scores.pickle
├── VCTK_noisy_validationSet.pickle
├── VQScore.png
├── adv_wav.png
├── bin
│   └── train.py
├── clean_p232_005.wav
├── config
│   ├── QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github.yaml
│   └── SE_cbook_4096_1_128_lr_1m5_1m5_github.yaml
├── dataloader
│   └── dataset.py
├── exp
│   ├── QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github
│   │   └── checkpoint-dnsmos_ovr_CC=0.835.pkl
│   └── SE_cbook_4096_1_128_lr_1m5_1m5_github
│       ├── checkpoint-dnsmos_ovr=2.654.pkl
│       └── checkpoint-dnsmos_ovr=2.761_AT.pkl
├── inference.py
├── inference_folder.py
├── models
│   ├── VQVAE_models.py
│   └── vector_quantize_pytorch.py
├── noisy_p232_005.wav
├── requirements.txt
├── trainVQVAE.py
└── trainer
    ├── autoencoder.py
    ├── eval_dataset.py
    └── trainerAE.py

/DNSMOS/ONNX_models/bak_ovr.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/DNSMOS/ONNX_models/bak_ovr.onnx
--------------------------------------------------------------------------------
/DNSMOS/ONNX_models/model_v8.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/DNSMOS/ONNX_models/model_v8.onnx
--------------------------------------------------------------------------------
/DNSMOS/ONNX_models/sig.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/DNSMOS/ONNX_models/sig.onnx
--------------------------------------------------------------------------------
/DNSMOS/ONNX_models/sig_bak_ovr.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/DNSMOS/ONNX_models/sig_bak_ovr.onnx
--------------------------------------------------------------------------------
/DNSMOS/README.md:
--------------------------------------------------------------------------------
1 | # DNSMOS: A non-intrusive perceptual objective speech quality metric to evaluate noise suppressors
2 | 
3 | Human subjective evaluation is the "gold standard" for evaluating speech quality optimized for human perception. Perceptual objective metrics serve as a proxy for subjective scores. The conventional and widely used metrics require a reference clean speech signal, which is unavailable in real recordings. The no-reference approaches correlate poorly with human ratings and are not widely adopted in the research community. One of the biggest use cases of these perceptual objective metrics is to evaluate noise suppression algorithms. DNSMOS generalizes well in challenging test conditions, with a high correlation to human ratings when stack-ranking noise suppression methods. More details can be found in the [DNSMOS paper](https://arxiv.org/pdf/2010.15258.pdf).
4 | 
5 | ## Evaluation methodology:
6 | There are two ways to use DNSMOS:
7 | 1. Using the Web-API. The benefit here is that computation happens on the cloud and you will always have the latest models.
8 | 2.
Local evaluation using the models included in this GitHub repo. We will try to keep these models in sync with the cloud, but there are no guarantees.
9 | 
10 | ### To use the Web-API:
11 | Please complete the following form: https://forms.office.com/r/pRhyZ0mQy3
12 | We will send you the **AUTH_KEY** that you can insert in the **dnsmos.py** script.
13 | Example command for P.835 evaluation of test clips: python dnsmos.py --testset_dir <testset_dir> --method p835
14 | 
15 | ### To use the local evaluation method:
16 | Use the **dnsmos_local.py** script.
17 | 1. To compute a personalized MOS score (where an interfering speaker is penalized), provide the '-p' argument.
18 | Ex: python dnsmos_local.py -t C:\temp\SampleClips -o sample.csv -p
19 | 2. To compute a regular MOS score, omit the '-p' argument.
20 | Ex: python dnsmos_local.py -t C:\temp\SampleClips -o sample.csv
21 | 
22 | ## Citation:
23 | If you have used the API for your research and development purposes, please cite the [DNSMOS paper](https://arxiv.org/pdf/2010.15258.pdf):
24 | ```BibTex
25 | @inproceedings{reddy2021dnsmos,
26 |   title={Dnsmos: A non-intrusive perceptual objective speech quality metric to evaluate noise suppressors},
27 |   author={Reddy, Chandan KA and Gopal, Vishak and Cutler, Ross},
28 |   booktitle={ICASSP 2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
29 |   pages={6493--6497},
30 |   year={2021},
31 |   organization={IEEE}
32 | }
33 | ```
34 | 
35 | If you used DNSMOS P.835, please cite the [DNSMOS P.835](https://arxiv.org/pdf/2110.01763.pdf) paper:
36 | 
37 | ```BibTex
38 | @inproceedings{reddy2022dnsmos,
39 |   title={DNSMOS P.835: A non-intrusive perceptual objective speech quality metric to evaluate noise suppressors},
40 |   author={Reddy, Chandan KA and Gopal, Vishak and Cutler, Ross},
41 |   booktitle={ICASSP 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
42 |   year={2022},
43 |   organization={IEEE}
44 | }
45 | ```
46 | 
--------------------------------------------------------------------------------
/DNSMOS/dnsmos_local.py:
--------------------------------------------------------------------------------
1 | # Usage:
2 | # python dnsmos_local.py -t c:\temp\DNSChallenge4_Blindset -o DNSCh4_Blind.csv -p
3 | #
4 | 
5 | import argparse
6 | import concurrent.futures
7 | import glob
8 | import os
9 | 
10 | import librosa
11 | import numpy as np
12 | import numpy.polynomial.polynomial as poly
13 | import onnxruntime as ort
14 | import pandas as pd
15 | import soundfile as sf
16 | from requests import session
17 | from tqdm import tqdm
18 | 
19 | SAMPLING_RATE = 16000
20 | INPUT_LENGTH = 9.01
21 | 
22 | class ComputeScore:
23 |     def __init__(self, primary_model_path, p808_model_path) -> None:
24 |         self.onnx_sess = ort.InferenceSession(primary_model_path, providers=['CUDAExecutionProvider'])
25 |         self.p808_onnx_sess = ort.InferenceSession(p808_model_path, providers=['CUDAExecutionProvider'])
26 | 
27 |     def audio_melspec(self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True):
28 |         mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=frame_size+1, hop_length=hop_length, n_mels=n_mels)
29 |         if to_db:
30 |             mel_spec = (librosa.power_to_db(mel_spec, ref=np.max)+40)/40
31 |         return mel_spec.T
32 | 
33 |     def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS):
34 |         if is_personalized_MOS:
35 |             p_ovr = np.poly1d([-0.00533021, 0.005101 , 1.18058466, -0.11236046])
36 |             p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726])
37 |             p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611 , 0.96883132])
38 |         else:
39 |             p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
40 |             p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439 ])
41 |             p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])
42 | 
43 |         sig_poly = p_sig(sig)
44 |         bak_poly = p_bak(bak)
45 |         ovr_poly = p_ovr(ovr)
46 | 
47 |         return sig_poly, bak_poly, ovr_poly
48 | 
49 |     def __call__(self, audio, sampling_rate, is_personalized_MOS, is_normalized=False, is_p808=False):
50 |         fs = sampling_rate
51 |         '''
52 |         aud, input_fs = sf.read(fpath)
53 |         fs = sampling_rate
54 |         if input_fs != fs:
55 |             audio = librosa.resample(aud, input_fs, fs)
56 |         else:
57 |             audio = aud
58 |         '''
59 |         if is_normalized:
60 |             audio = audio/abs(audio).max()
61 | 
62 |         actual_audio_len = len(audio)
63 |         len_samples = int(INPUT_LENGTH*fs)
64 |         while len(audio) < len_samples:
65 |             audio = np.append(audio, audio)
66 | 
67 |         num_hops = int(np.floor(len(audio)/fs) - INPUT_LENGTH)+1
68 |         hop_len_samples = fs
69 |         predicted_mos_sig_seg_raw = []
70 |         predicted_mos_bak_seg_raw = []
71 |         predicted_mos_ovr_seg_raw = []
72 |         predicted_mos_sig_seg = []
73 |         predicted_mos_bak_seg = []
74 |         predicted_mos_ovr_seg = []
75 |         predicted_p808_mos = []
76 | 
77 |         for idx in range(num_hops):
78 |             audio_seg = audio[int(idx*hop_len_samples) : int((idx+INPUT_LENGTH)*hop_len_samples)]
79 |             if len(audio_seg) < len_samples:
80 |                 continue
81 | 
82 |             if is_p808:
83 |                 p808_input_features = np.array(self.audio_melspec(audio=audio_seg[:-160])).astype('float32')[np.newaxis, :, :]
84 |                 p808_oi = {'input_1': p808_input_features}
85 |                 p808_mos = self.p808_onnx_sess.run(None, p808_oi)[0][0][0]
86 | 
87 |             input_features = np.array(audio_seg).astype('float32')[np.newaxis,:]
88 |             oi = {'input_1': input_features}
89 |             mos_sig_raw,mos_bak_raw,mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0]
90 |             mos_sig,mos_bak,mos_ovr = self.get_polyfit_val(mos_sig_raw,mos_bak_raw,mos_ovr_raw,is_personalized_MOS)
91 |             predicted_mos_sig_seg_raw.append(mos_sig_raw)
92 |             predicted_mos_bak_seg_raw.append(mos_bak_raw)
93 |             predicted_mos_ovr_seg_raw.append(mos_ovr_raw)
94 |             predicted_mos_sig_seg.append(mos_sig)
95 |             predicted_mos_bak_seg.append(mos_bak)
96 |             predicted_mos_ovr_seg.append(mos_ovr)
97 |             if is_p808:
98 |                 predicted_p808_mos.append(p808_mos)
99 | 
100 |         clip_dict = {'len_in_sec': actual_audio_len/fs, 'sr':fs}
101 |         clip_dict['num_hops'] = num_hops
102 |         clip_dict['OVRL_raw'] = np.mean(predicted_mos_ovr_seg_raw)
103 |         clip_dict['SIG_raw'] = np.mean(predicted_mos_sig_seg_raw)
104 |         clip_dict['BAK_raw'] = np.mean(predicted_mos_bak_seg_raw)
105 |         clip_dict['OVRL'] = np.mean(predicted_mos_ovr_seg)
106 |         clip_dict['SIG'] = np.mean(predicted_mos_sig_seg)
107 |         clip_dict['BAK'] = np.mean(predicted_mos_bak_seg)
108 |         if is_p808:
109 |             clip_dict['P808_MOS'] = np.mean(predicted_p808_mos)
110 |         return clip_dict
111 | 
112 | def main(args):
113 |     models = glob.glob(os.path.join(args.testset_dir, "*"))
114 |     audio_clips_list = []
115 |     p808_model_path = os.path.join('DNSMOS', 'model_v8.onnx')
116 | 
117 |     if args.personalized_MOS:
118 |         primary_model_path = os.path.join('pDNSMOS', 'sig_bak_ovr.onnx')
119 |     else:
120 |         primary_model_path = os.path.join('DNSMOS', 'sig_bak_ovr.onnx')
121 | 
122 |     compute_score = ComputeScore(primary_model_path, p808_model_path)
123 | 
124 |     rows = []
125 |     clips = []
126 |     clips = glob.glob(os.path.join(args.testset_dir, "*.wav"))
127 |     is_personalized_eval = args.personalized_MOS
128 |     desired_fs = SAMPLING_RATE
129 |     for m in tqdm(models):
130 |         max_recursion_depth = 10
131 |         audio_path = os.path.join(args.testset_dir, m)
132 |         audio_clips_list = glob.glob(os.path.join(audio_path, "*.wav"))
133 |         while len(audio_clips_list) == 0 and max_recursion_depth > 0:
134 |             audio_path = os.path.join(audio_path, "**")
135 |             audio_clips_list = glob.glob(os.path.join(audio_path, "*.wav"))
136 |             max_recursion_depth -= 1
137 |         clips.extend(audio_clips_list)
138 | 
139 |     with concurrent.futures.ThreadPoolExecutor() as executor:
140 |         future_to_url = {executor.submit(compute_score, clip, desired_fs, is_personalized_eval): clip for clip in clips}
141 |         for future in tqdm(concurrent.futures.as_completed(future_to_url)):
142 |             clip = future_to_url[future]
143 |             try:
144 |                 data = future.result()
145 |             except Exception as exc:
146 |                 print('%r generated an exception: %s' % (clip, exc))
147 |             else:
148 |                 rows.append(data)
149 | 
150 |     df = pd.DataFrame(rows)
151 |     if args.csv_path:
152 |         csv_path = args.csv_path
153 |         df.to_csv(csv_path)
154 |     else:
155 |         print(df.describe())
156 | 
157 | if __name__=="__main__":
158 |     parser = argparse.ArgumentParser()
159 |     parser.add_argument('-t', "--testset_dir", default='.',
160 |                         help='Path to the dir containing audio clips in .wav to be evaluated')
161 |     parser.add_argument('-o', "--csv_path", default=None, help='Path to the CSV file that saves the results')
162 |     parser.add_argument('-p', "--personalized_MOS", action='store_true',
163 |                         help='Flag to indicate whether a personalized or regular MOS score is needed')
164 | 
165 |     args = parser.parse_args()
166 | 
167 |     main(args)
168 | 
--------------------------------------------------------------------------------
/IUB_ind2.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/IUB_ind2.pickle
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Self-Supervised Speech Quality Estimation and Enhancement Using Only Clean Speech (ICLR 2024)
2 | #### Szu-Wei Fu, Kuo-Hsuan Hung, Yu Tsao, Yu-Chiang Frank Wang
3 | 
4 | ## Update 2024/12/2
5 | Added the quality-estimation "inference_folder" code to evaluate all the utterances in a folder. We found that the inference speed of VQScore is quite fast (15.5 hours of speech takes less than 2 minutes on a single A100 GPU), making it suitable for filtering out noisy training data when training speech enhancement or TTS models.
6 | 
7 | 
8 | ### Introduction
9 | This work is about training a speech quality estimator and enhancement model WITHOUT any labeled (paired) data. Specifically, we only need CLEAN speech during training.
10 | 
11 | ![VQScore](VQScore.png)
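For intuition, the sketch below shows the core idea in miniature: a VQ-VAE codebook is learned from clean speech only, so encoder features of clean input sit close to their nearest codebook entries, while distortions push them away. This is a conceptual illustration only, not this repo's implementation; the function name, tensor shapes, and the use of plain cosine similarity as the score are assumptions.

```python
# Conceptual sketch of VQ-based quality scoring (hypothetical, not this repo's code).
import torch
import torch.nn.functional as F

def vq_score(z: torch.Tensor, codebook: torch.Tensor) -> float:
    """z: (T, D) encoder features of one utterance; codebook: (K, D) learned on clean speech."""
    dists = torch.cdist(z, codebook)      # (T, K) distance to every codebook entry
    zq = codebook[dists.argmin(dim=1)]    # (T, D) nearest (quantized) entry per frame
    # Average frame-wise cosine similarity: near 1 for clean-like input,
    # lower when distortions push features away from the clean-speech codebook.
    return F.cosine_similarity(z, zq, dim=1).mean().item()
```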
12 | 
13 | ## Environment
14 | CUDA Version: 12.2
15 | 
16 | python: 3.8
17 | 
18 | * Note: To use 'CUDAExecutionProvider' for accelerated DNSMOS ONNX model inference, please check the CUDA and ONNX Runtime version compatibility [here](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html).
19 | 
20 | ## Dataset used in the paper/code
21 | If you want to train from scratch, please download the datasets to the corresponding paths listed in the .csv and .pickle files.
22 | 
23 | Speech enhancement:
24 | 
25 | => Training: [clean speech of VoiceBank-DEMAND trainset](https://datashare.ed.ac.uk/handle/10283/2791) (its original sampling rate is 48 kHz; you have to down-sample it to 16 kHz, as in the resampling sketch after the Pretrained Models section)
26 | 
27 | => Validation: As in MetricGAN-U, [noisy speech (speakers p226 and p287) of VoiceBank-DEMAND trainset](https://datashare.ed.ac.uk/handle/10283/2791)
28 | 
29 | => Evaluation: [noisy speech of VoiceBank-DEMAND testset](https://datashare.ed.ac.uk/handle/10283/2791) and [DNS1 and DNS3](https://github.com/microsoft/DNS-Challenge)
30 | 
31 | Quality estimation (VQScore):
32 | 
33 | => Training: [LibriSpeech clean-460 hours](https://www.openslr.org/12)
34 | 
35 | => Validation: [noisy speech of VoiceBank-DEMAND testset](https://datashare.ed.ac.uk/handle/10283/2791)
36 | 
37 | => Evaluation: [Tencent and IUB](https://github.com/ConferencingSpeech/ConferencingSpeech2022/tree/main/Training/Dev%20datasets)
38 | 
39 | ## Training
40 | To train our speech enhancement model (using only clean speech), run a command like the following:
41 | ```shell
42 | python trainVQVAE.py \
43 | -c config/SE_cbook_4096_1_128_lr_1m5_1m5_github.yaml \
44 | --tag SE_cbook_4096_1_128_lr_1m5_1m5_github
45 | ```
46 | To train our speech quality estimator, VQScore, run a command like the following:
47 | ```shell
48 | python trainVQVAE.py \
49 | -c config/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github.yaml \
50 | --tag QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github
51 | ```
52 | 
53 | ## Inference
54 | Below are example commands for generating enhanced speech / estimated quality scores from the model,
55 | where '-c' is the path of the config file, '-m' is the path of the pre-trained model, and '-i' is the path of the input wav file.
56 | 
57 | * Note: Because our training data is 16 kHz clean speech, only 16 kHz speech input is supported.
58 | 
59 | ```shell
60 | python inference.py \
61 | -c ./config/SE_cbook_4096_1_128_lr_1m5_1m5_github.yaml \
62 | -m ./exp/SE_cbook_4096_1_128_lr_1m5_1m5_github/checkpoint-dnsmos_ovr=2.761_AT.pkl \
63 | -i ./noisy_p232_005.wav
64 | ```
65 | ```shell
66 | python inference.py \
67 | -c ./config/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github.yaml \
68 | -m ./exp/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github/checkpoint-dnsmos_ovr_CC=0.835.pkl \
69 | -i ./noisy_p232_005.wav
70 | ```
71 | ```shell
72 | python inference_folder.py \
73 | -c ./config/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github.yaml \
74 | -m ./exp/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github/checkpoint-dnsmos_ovr_CC=0.835.pkl \
75 | -i <path to the folder you want to evaluate>
76 | ```
77 | 
78 | 
79 | ## Pretrained Models
80 | We provide the checkpoints of the trained models in the corresponding ./exp/config_name folders.
81 | 
82 | * Note that the provided checkpoints are the models obtained after we reorganized the code, so the results are slightly different from those shown in the paper.
83 | * However, the overall trend should be similar.
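Since the VoiceBank-DEMAND wav files ship at 48 kHz and both training and inference here expect 16 kHz (see the Dataset and Inference notes above), a small down-sampling sketch may help; the folder paths are placeholders, and librosa/soundfile are assumed to be installed (both are already used by the DNSMOS scripts in this repo).

```python
# down_sample_to_16k.py: hypothetical helper, not part of this repo.
import glob
import os

import librosa
import soundfile as sf

SRC = "clean_trainset_wav_48k"   # placeholder: folder with the original 48 kHz wavs
DST = "clean_trainset_wav_16k"   # placeholder: output folder for the 16 kHz wavs
os.makedirs(DST, exist_ok=True)

for wav_path in glob.glob(os.path.join(SRC, "*.wav")):
    audio, sr = librosa.load(wav_path, sr=None)  # keep the native sampling rate
    audio_16k = librosa.resample(audio, orig_sr=sr, target_sr=16000)
    sf.write(os.path.join(DST, os.path.basename(wav_path)), audio_16k, 16000)
```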
84 | 
85 | ## Adversarial noise
86 | As shown in the following spectrogram, the applied adversarial noise doesn't have a fixed pattern like Gaussian noise does, so it may be a good choice for training a robust speech enhancement model.
87 | ![Adversarial noise spectrogram](adv_wav.png)
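For illustration, the following is a minimal FGSM-style sketch of how such an adversarial perturbation can be crafted against a reconstruction model. It is the generic textbook formulation under stated assumptions (`model` and `loss_fn` are differentiable placeholders), not necessarily the exact procedure behind the provided _AT checkpoint.

```python
# Generic FGSM-style adversarial perturbation (illustrative sketch only).
import torch

def adversarial_example(model, wav, loss_fn, epsilon=1e-3):
    """wav: (1, T) waveform tensor; returns a perturbed copy of it."""
    wav = wav.clone().detach().requires_grad_(True)
    loss = loss_fn(model(wav), wav)   # e.g., reconstruction loss of the VQ-VAE
    loss.backward()
    # Step along the gradient sign to maximally increase the loss.
    return (wav + epsilon * wav.grad.sign()).detach()
```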
88 | 89 | ## Collaboration 90 | I'm open to collaboration! If you find this Self-Supervised SE/QE topic interesting, please let me know (e-mail: szuweif@nvidia.com). 91 | 92 | ### Citation 93 | If you find the code useful in your research, please cite our ICLR paper :) 94 | 95 | ## References 96 | * [vector-quantize](https://github.com/lucidrains/vector-quantize-pytorch) (for VQ-VAE) 97 | * [DNSMOS](https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS) 98 | -------------------------------------------------------------------------------- /Tencent_ind2.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/Tencent_ind2.pickle -------------------------------------------------------------------------------- /VCTK_clean_test.csv: -------------------------------------------------------------------------------- 1 | ,Filename 2 | 0,/vctk_data/clean_testset_wav_16k/p232_001.wav 3 | 1,/vctk_data/clean_testset_wav_16k/p232_002.wav 4 | 2,/vctk_data/clean_testset_wav_16k/p232_003.wav 5 | 3,/vctk_data/clean_testset_wav_16k/p232_005.wav 6 | 4,/vctk_data/clean_testset_wav_16k/p232_006.wav 7 | 5,/vctk_data/clean_testset_wav_16k/p232_007.wav 8 | 6,/vctk_data/clean_testset_wav_16k/p232_009.wav 9 | 7,/vctk_data/clean_testset_wav_16k/p232_010.wav 10 | 8,/vctk_data/clean_testset_wav_16k/p232_011.wav 11 | 9,/vctk_data/clean_testset_wav_16k/p232_012.wav 12 | 10,/vctk_data/clean_testset_wav_16k/p232_013.wav 13 | 11,/vctk_data/clean_testset_wav_16k/p232_014.wav 14 | 12,/vctk_data/clean_testset_wav_16k/p232_015.wav 15 | 13,/vctk_data/clean_testset_wav_16k/p232_016.wav 16 | 14,/vctk_data/clean_testset_wav_16k/p232_017.wav 17 | 15,/vctk_data/clean_testset_wav_16k/p232_019.wav 18 | 16,/vctk_data/clean_testset_wav_16k/p232_020.wav 19 | 17,/vctk_data/clean_testset_wav_16k/p232_021.wav 20 | 18,/vctk_data/clean_testset_wav_16k/p232_022.wav 21 | 19,/vctk_data/clean_testset_wav_16k/p232_023.wav 22 | 20,/vctk_data/clean_testset_wav_16k/p232_024.wav 23 | 21,/vctk_data/clean_testset_wav_16k/p232_025.wav 24 | 22,/vctk_data/clean_testset_wav_16k/p232_027.wav 25 | 23,/vctk_data/clean_testset_wav_16k/p232_028.wav 26 | 24,/vctk_data/clean_testset_wav_16k/p232_029.wav 27 | 25,/vctk_data/clean_testset_wav_16k/p232_030.wav 28 | 26,/vctk_data/clean_testset_wav_16k/p232_031.wav 29 | 27,/vctk_data/clean_testset_wav_16k/p232_032.wav 30 | 28,/vctk_data/clean_testset_wav_16k/p232_033.wav 31 | 29,/vctk_data/clean_testset_wav_16k/p232_034.wav 32 | 30,/vctk_data/clean_testset_wav_16k/p232_035.wav 33 | 31,/vctk_data/clean_testset_wav_16k/p232_036.wav 34 | 32,/vctk_data/clean_testset_wav_16k/p232_037.wav 35 | 33,/vctk_data/clean_testset_wav_16k/p232_038.wav 36 | 34,/vctk_data/clean_testset_wav_16k/p232_039.wav 37 | 35,/vctk_data/clean_testset_wav_16k/p232_040.wav 38 | 36,/vctk_data/clean_testset_wav_16k/p232_041.wav 39 | 37,/vctk_data/clean_testset_wav_16k/p232_042.wav 40 | 38,/vctk_data/clean_testset_wav_16k/p232_043.wav 41 | 39,/vctk_data/clean_testset_wav_16k/p232_044.wav 42 | 40,/vctk_data/clean_testset_wav_16k/p232_045.wav 43 | 41,/vctk_data/clean_testset_wav_16k/p232_046.wav 44 | 42,/vctk_data/clean_testset_wav_16k/p232_047.wav 45 | 43,/vctk_data/clean_testset_wav_16k/p232_048.wav 46 | 44,/vctk_data/clean_testset_wav_16k/p232_049.wav 47 | 45,/vctk_data/clean_testset_wav_16k/p232_050.wav 48 | 46,/vctk_data/clean_testset_wav_16k/p232_051.wav 49 | 47,/vctk_data/clean_testset_wav_16k/p232_052.wav 50 | 
48,/vctk_data/clean_testset_wav_16k/p232_053.wav 51 | 49,/vctk_data/clean_testset_wav_16k/p232_054.wav 52 | 50,/vctk_data/clean_testset_wav_16k/p232_055.wav 53 | 51,/vctk_data/clean_testset_wav_16k/p232_056.wav 54 | 52,/vctk_data/clean_testset_wav_16k/p232_057.wav 55 | 53,/vctk_data/clean_testset_wav_16k/p232_058.wav 56 | 54,/vctk_data/clean_testset_wav_16k/p232_059.wav 57 | 55,/vctk_data/clean_testset_wav_16k/p232_060.wav 58 | 56,/vctk_data/clean_testset_wav_16k/p232_061.wav 59 | 57,/vctk_data/clean_testset_wav_16k/p232_062.wav 60 | 58,/vctk_data/clean_testset_wav_16k/p232_063.wav 61 | 59,/vctk_data/clean_testset_wav_16k/p232_064.wav 62 | 60,/vctk_data/clean_testset_wav_16k/p232_065.wav 63 | 61,/vctk_data/clean_testset_wav_16k/p232_066.wav 64 | 62,/vctk_data/clean_testset_wav_16k/p232_067.wav 65 | 63,/vctk_data/clean_testset_wav_16k/p232_068.wav 66 | 64,/vctk_data/clean_testset_wav_16k/p232_069.wav 67 | 65,/vctk_data/clean_testset_wav_16k/p232_070.wav 68 | 66,/vctk_data/clean_testset_wav_16k/p232_071.wav 69 | 67,/vctk_data/clean_testset_wav_16k/p232_072.wav 70 | 68,/vctk_data/clean_testset_wav_16k/p232_073.wav 71 | 69,/vctk_data/clean_testset_wav_16k/p232_074.wav 72 | 70,/vctk_data/clean_testset_wav_16k/p232_075.wav 73 | 71,/vctk_data/clean_testset_wav_16k/p232_076.wav 74 | 72,/vctk_data/clean_testset_wav_16k/p232_077.wav 75 | 73,/vctk_data/clean_testset_wav_16k/p232_078.wav 76 | 74,/vctk_data/clean_testset_wav_16k/p232_079.wav 77 | 75,/vctk_data/clean_testset_wav_16k/p232_080.wav 78 | 76,/vctk_data/clean_testset_wav_16k/p232_081.wav 79 | 77,/vctk_data/clean_testset_wav_16k/p232_082.wav 80 | 78,/vctk_data/clean_testset_wav_16k/p232_083.wav 81 | 79,/vctk_data/clean_testset_wav_16k/p232_084.wav 82 | 80,/vctk_data/clean_testset_wav_16k/p232_085.wav 83 | 81,/vctk_data/clean_testset_wav_16k/p232_086.wav 84 | 82,/vctk_data/clean_testset_wav_16k/p232_087.wav 85 | 83,/vctk_data/clean_testset_wav_16k/p232_088.wav 86 | 84,/vctk_data/clean_testset_wav_16k/p232_089.wav 87 | 85,/vctk_data/clean_testset_wav_16k/p232_090.wav 88 | 86,/vctk_data/clean_testset_wav_16k/p232_091.wav 89 | 87,/vctk_data/clean_testset_wav_16k/p232_092.wav 90 | 88,/vctk_data/clean_testset_wav_16k/p232_093.wav 91 | 89,/vctk_data/clean_testset_wav_16k/p232_094.wav 92 | 90,/vctk_data/clean_testset_wav_16k/p232_095.wav 93 | 91,/vctk_data/clean_testset_wav_16k/p232_096.wav 94 | 92,/vctk_data/clean_testset_wav_16k/p232_097.wav 95 | 93,/vctk_data/clean_testset_wav_16k/p232_098.wav 96 | 94,/vctk_data/clean_testset_wav_16k/p232_099.wav 97 | 95,/vctk_data/clean_testset_wav_16k/p232_100.wav 98 | 96,/vctk_data/clean_testset_wav_16k/p232_101.wav 99 | 97,/vctk_data/clean_testset_wav_16k/p232_102.wav 100 | 98,/vctk_data/clean_testset_wav_16k/p232_103.wav 101 | 99,/vctk_data/clean_testset_wav_16k/p232_104.wav 102 | 100,/vctk_data/clean_testset_wav_16k/p232_105.wav 103 | 101,/vctk_data/clean_testset_wav_16k/p232_106.wav 104 | 102,/vctk_data/clean_testset_wav_16k/p232_107.wav 105 | 103,/vctk_data/clean_testset_wav_16k/p232_108.wav 106 | 104,/vctk_data/clean_testset_wav_16k/p232_109.wav 107 | 105,/vctk_data/clean_testset_wav_16k/p232_110.wav 108 | 106,/vctk_data/clean_testset_wav_16k/p232_112.wav 109 | 107,/vctk_data/clean_testset_wav_16k/p232_113.wav 110 | 108,/vctk_data/clean_testset_wav_16k/p232_114.wav 111 | 109,/vctk_data/clean_testset_wav_16k/p232_115.wav 112 | 110,/vctk_data/clean_testset_wav_16k/p232_116.wav 113 | 111,/vctk_data/clean_testset_wav_16k/p232_117.wav 114 | 112,/vctk_data/clean_testset_wav_16k/p232_118.wav 115 | 
113,/vctk_data/clean_testset_wav_16k/p232_119.wav 116 | 114,/vctk_data/clean_testset_wav_16k/p232_120.wav 117 | 115,/vctk_data/clean_testset_wav_16k/p232_121.wav 118 | 116,/vctk_data/clean_testset_wav_16k/p232_123.wav 119 | 117,/vctk_data/clean_testset_wav_16k/p232_124.wav 120 | 118,/vctk_data/clean_testset_wav_16k/p232_125.wav 121 | 119,/vctk_data/clean_testset_wav_16k/p232_126.wav 122 | 120,/vctk_data/clean_testset_wav_16k/p232_127.wav 123 | 121,/vctk_data/clean_testset_wav_16k/p232_128.wav 124 | 122,/vctk_data/clean_testset_wav_16k/p232_129.wav 125 | 123,/vctk_data/clean_testset_wav_16k/p232_130.wav 126 | 124,/vctk_data/clean_testset_wav_16k/p232_131.wav 127 | 125,/vctk_data/clean_testset_wav_16k/p232_132.wav 128 | 126,/vctk_data/clean_testset_wav_16k/p232_133.wav 129 | 127,/vctk_data/clean_testset_wav_16k/p232_134.wav 130 | 128,/vctk_data/clean_testset_wav_16k/p232_135.wav 131 | 129,/vctk_data/clean_testset_wav_16k/p232_136.wav 132 | 130,/vctk_data/clean_testset_wav_16k/p232_137.wav 133 | 131,/vctk_data/clean_testset_wav_16k/p232_138.wav 134 | 132,/vctk_data/clean_testset_wav_16k/p232_139.wav 135 | 133,/vctk_data/clean_testset_wav_16k/p232_140.wav 136 | 134,/vctk_data/clean_testset_wav_16k/p232_141.wav 137 | 135,/vctk_data/clean_testset_wav_16k/p232_142.wav 138 | 136,/vctk_data/clean_testset_wav_16k/p232_143.wav 139 | 137,/vctk_data/clean_testset_wav_16k/p232_144.wav 140 | 138,/vctk_data/clean_testset_wav_16k/p232_145.wav 141 | 139,/vctk_data/clean_testset_wav_16k/p232_146.wav 142 | 140,/vctk_data/clean_testset_wav_16k/p232_147.wav 143 | 141,/vctk_data/clean_testset_wav_16k/p232_148.wav 144 | 142,/vctk_data/clean_testset_wav_16k/p232_150.wav 145 | 143,/vctk_data/clean_testset_wav_16k/p232_151.wav 146 | 144,/vctk_data/clean_testset_wav_16k/p232_152.wav 147 | 145,/vctk_data/clean_testset_wav_16k/p232_153.wav 148 | 146,/vctk_data/clean_testset_wav_16k/p232_154.wav 149 | 147,/vctk_data/clean_testset_wav_16k/p232_155.wav 150 | 148,/vctk_data/clean_testset_wav_16k/p232_156.wav 151 | 149,/vctk_data/clean_testset_wav_16k/p232_158.wav 152 | 150,/vctk_data/clean_testset_wav_16k/p232_159.wav 153 | 151,/vctk_data/clean_testset_wav_16k/p232_160.wav 154 | 152,/vctk_data/clean_testset_wav_16k/p232_161.wav 155 | 153,/vctk_data/clean_testset_wav_16k/p232_162.wav 156 | 154,/vctk_data/clean_testset_wav_16k/p232_163.wav 157 | 155,/vctk_data/clean_testset_wav_16k/p232_164.wav 158 | 156,/vctk_data/clean_testset_wav_16k/p232_165.wav 159 | 157,/vctk_data/clean_testset_wav_16k/p232_167.wav 160 | 158,/vctk_data/clean_testset_wav_16k/p232_169.wav 161 | 159,/vctk_data/clean_testset_wav_16k/p232_170.wav 162 | 160,/vctk_data/clean_testset_wav_16k/p232_171.wav 163 | 161,/vctk_data/clean_testset_wav_16k/p232_172.wav 164 | 162,/vctk_data/clean_testset_wav_16k/p232_173.wav 165 | 163,/vctk_data/clean_testset_wav_16k/p232_174.wav 166 | 164,/vctk_data/clean_testset_wav_16k/p232_175.wav 167 | 165,/vctk_data/clean_testset_wav_16k/p232_176.wav 168 | 166,/vctk_data/clean_testset_wav_16k/p232_177.wav 169 | 167,/vctk_data/clean_testset_wav_16k/p232_178.wav 170 | 168,/vctk_data/clean_testset_wav_16k/p232_179.wav 171 | 169,/vctk_data/clean_testset_wav_16k/p232_180.wav 172 | 170,/vctk_data/clean_testset_wav_16k/p232_181.wav 173 | 171,/vctk_data/clean_testset_wav_16k/p232_182.wav 174 | 172,/vctk_data/clean_testset_wav_16k/p232_183.wav 175 | 173,/vctk_data/clean_testset_wav_16k/p232_184.wav 176 | 174,/vctk_data/clean_testset_wav_16k/p232_185.wav 177 | 175,/vctk_data/clean_testset_wav_16k/p232_186.wav 178 | 
176,/vctk_data/clean_testset_wav_16k/p232_187.wav 179 | 177,/vctk_data/clean_testset_wav_16k/p232_188.wav 180 | 178,/vctk_data/clean_testset_wav_16k/p232_189.wav 181 | 179,/vctk_data/clean_testset_wav_16k/p232_190.wav 182 | 180,/vctk_data/clean_testset_wav_16k/p232_191.wav 183 | 181,/vctk_data/clean_testset_wav_16k/p232_193.wav 184 | 182,/vctk_data/clean_testset_wav_16k/p232_194.wav 185 | 183,/vctk_data/clean_testset_wav_16k/p232_195.wav 186 | 184,/vctk_data/clean_testset_wav_16k/p232_196.wav 187 | 185,/vctk_data/clean_testset_wav_16k/p232_197.wav 188 | 186,/vctk_data/clean_testset_wav_16k/p232_198.wav 189 | 187,/vctk_data/clean_testset_wav_16k/p232_199.wav 190 | 188,/vctk_data/clean_testset_wav_16k/p232_200.wav 191 | 189,/vctk_data/clean_testset_wav_16k/p232_201.wav 192 | 190,/vctk_data/clean_testset_wav_16k/p232_202.wav 193 | 191,/vctk_data/clean_testset_wav_16k/p232_203.wav 194 | 192,/vctk_data/clean_testset_wav_16k/p232_204.wav 195 | 193,/vctk_data/clean_testset_wav_16k/p232_205.wav 196 | 194,/vctk_data/clean_testset_wav_16k/p232_206.wav 197 | 195,/vctk_data/clean_testset_wav_16k/p232_207.wav 198 | 196,/vctk_data/clean_testset_wav_16k/p232_208.wav 199 | 197,/vctk_data/clean_testset_wav_16k/p232_209.wav 200 | 198,/vctk_data/clean_testset_wav_16k/p232_210.wav 201 | 199,/vctk_data/clean_testset_wav_16k/p232_211.wav 202 | 200,/vctk_data/clean_testset_wav_16k/p232_213.wav 203 | 201,/vctk_data/clean_testset_wav_16k/p232_214.wav 204 | 202,/vctk_data/clean_testset_wav_16k/p232_215.wav 205 | 203,/vctk_data/clean_testset_wav_16k/p232_216.wav 206 | 204,/vctk_data/clean_testset_wav_16k/p232_217.wav 207 | 205,/vctk_data/clean_testset_wav_16k/p232_218.wav 208 | 206,/vctk_data/clean_testset_wav_16k/p232_219.wav 209 | 207,/vctk_data/clean_testset_wav_16k/p232_220.wav 210 | 208,/vctk_data/clean_testset_wav_16k/p232_221.wav 211 | 209,/vctk_data/clean_testset_wav_16k/p232_223.wav 212 | 210,/vctk_data/clean_testset_wav_16k/p232_224.wav 213 | 211,/vctk_data/clean_testset_wav_16k/p232_225.wav 214 | 212,/vctk_data/clean_testset_wav_16k/p232_226.wav 215 | 213,/vctk_data/clean_testset_wav_16k/p232_227.wav 216 | 214,/vctk_data/clean_testset_wav_16k/p232_228.wav 217 | 215,/vctk_data/clean_testset_wav_16k/p232_229.wav 218 | 216,/vctk_data/clean_testset_wav_16k/p232_230.wav 219 | 217,/vctk_data/clean_testset_wav_16k/p232_231.wav 220 | 218,/vctk_data/clean_testset_wav_16k/p232_232.wav 221 | 219,/vctk_data/clean_testset_wav_16k/p232_234.wav 222 | 220,/vctk_data/clean_testset_wav_16k/p232_235.wav 223 | 221,/vctk_data/clean_testset_wav_16k/p232_236.wav 224 | 222,/vctk_data/clean_testset_wav_16k/p232_237.wav 225 | 223,/vctk_data/clean_testset_wav_16k/p232_238.wav 226 | 224,/vctk_data/clean_testset_wav_16k/p232_239.wav 227 | 225,/vctk_data/clean_testset_wav_16k/p232_240.wav 228 | 226,/vctk_data/clean_testset_wav_16k/p232_241.wav 229 | 227,/vctk_data/clean_testset_wav_16k/p232_242.wav 230 | 228,/vctk_data/clean_testset_wav_16k/p232_243.wav 231 | 229,/vctk_data/clean_testset_wav_16k/p232_244.wav 232 | 230,/vctk_data/clean_testset_wav_16k/p232_245.wav 233 | 231,/vctk_data/clean_testset_wav_16k/p232_246.wav 234 | 232,/vctk_data/clean_testset_wav_16k/p232_247.wav 235 | 233,/vctk_data/clean_testset_wav_16k/p232_248.wav 236 | 234,/vctk_data/clean_testset_wav_16k/p232_249.wav 237 | 235,/vctk_data/clean_testset_wav_16k/p232_250.wav 238 | 236,/vctk_data/clean_testset_wav_16k/p232_251.wav 239 | 237,/vctk_data/clean_testset_wav_16k/p232_252.wav 240 | 238,/vctk_data/clean_testset_wav_16k/p232_253.wav 241 | 
239,/vctk_data/clean_testset_wav_16k/p232_254.wav 242 | 240,/vctk_data/clean_testset_wav_16k/p232_255.wav 243 | 241,/vctk_data/clean_testset_wav_16k/p232_256.wav 244 | 242,/vctk_data/clean_testset_wav_16k/p232_257.wav 245 | 243,/vctk_data/clean_testset_wav_16k/p232_258.wav 246 | 244,/vctk_data/clean_testset_wav_16k/p232_259.wav 247 | 245,/vctk_data/clean_testset_wav_16k/p232_260.wav 248 | 246,/vctk_data/clean_testset_wav_16k/p232_261.wav 249 | 247,/vctk_data/clean_testset_wav_16k/p232_263.wav 250 | 248,/vctk_data/clean_testset_wav_16k/p232_264.wav 251 | 249,/vctk_data/clean_testset_wav_16k/p232_265.wav 252 | 250,/vctk_data/clean_testset_wav_16k/p232_266.wav 253 | 251,/vctk_data/clean_testset_wav_16k/p232_267.wav 254 | 252,/vctk_data/clean_testset_wav_16k/p232_268.wav 255 | 253,/vctk_data/clean_testset_wav_16k/p232_269.wav 256 | 254,/vctk_data/clean_testset_wav_16k/p232_270.wav 257 | 255,/vctk_data/clean_testset_wav_16k/p232_271.wav 258 | 256,/vctk_data/clean_testset_wav_16k/p232_272.wav 259 | 257,/vctk_data/clean_testset_wav_16k/p232_273.wav 260 | 258,/vctk_data/clean_testset_wav_16k/p232_274.wav 261 | 259,/vctk_data/clean_testset_wav_16k/p232_275.wav 262 | 260,/vctk_data/clean_testset_wav_16k/p232_276.wav 263 | 261,/vctk_data/clean_testset_wav_16k/p232_277.wav 264 | 262,/vctk_data/clean_testset_wav_16k/p232_278.wav 265 | 263,/vctk_data/clean_testset_wav_16k/p232_279.wav 266 | 264,/vctk_data/clean_testset_wav_16k/p232_280.wav 267 | 265,/vctk_data/clean_testset_wav_16k/p232_281.wav 268 | 266,/vctk_data/clean_testset_wav_16k/p232_282.wav 269 | 267,/vctk_data/clean_testset_wav_16k/p232_283.wav 270 | 268,/vctk_data/clean_testset_wav_16k/p232_284.wav 271 | 269,/vctk_data/clean_testset_wav_16k/p232_285.wav 272 | 270,/vctk_data/clean_testset_wav_16k/p232_286.wav 273 | 271,/vctk_data/clean_testset_wav_16k/p232_287.wav 274 | 272,/vctk_data/clean_testset_wav_16k/p232_288.wav 275 | 273,/vctk_data/clean_testset_wav_16k/p232_289.wav 276 | 274,/vctk_data/clean_testset_wav_16k/p232_290.wav 277 | 275,/vctk_data/clean_testset_wav_16k/p232_291.wav 278 | 276,/vctk_data/clean_testset_wav_16k/p232_292.wav 279 | 277,/vctk_data/clean_testset_wav_16k/p232_293.wav 280 | 278,/vctk_data/clean_testset_wav_16k/p232_294.wav 281 | 279,/vctk_data/clean_testset_wav_16k/p232_295.wav 282 | 280,/vctk_data/clean_testset_wav_16k/p232_296.wav 283 | 281,/vctk_data/clean_testset_wav_16k/p232_297.wav 284 | 282,/vctk_data/clean_testset_wav_16k/p232_298.wav 285 | 283,/vctk_data/clean_testset_wav_16k/p232_299.wav 286 | 284,/vctk_data/clean_testset_wav_16k/p232_300.wav 287 | 285,/vctk_data/clean_testset_wav_16k/p232_301.wav 288 | 286,/vctk_data/clean_testset_wav_16k/p232_302.wav 289 | 287,/vctk_data/clean_testset_wav_16k/p232_303.wav 290 | 288,/vctk_data/clean_testset_wav_16k/p232_305.wav 291 | 289,/vctk_data/clean_testset_wav_16k/p232_306.wav 292 | 290,/vctk_data/clean_testset_wav_16k/p232_307.wav 293 | 291,/vctk_data/clean_testset_wav_16k/p232_308.wav 294 | 292,/vctk_data/clean_testset_wav_16k/p232_309.wav 295 | 293,/vctk_data/clean_testset_wav_16k/p232_310.wav 296 | 294,/vctk_data/clean_testset_wav_16k/p232_311.wav 297 | 295,/vctk_data/clean_testset_wav_16k/p232_312.wav 298 | 296,/vctk_data/clean_testset_wav_16k/p232_313.wav 299 | 297,/vctk_data/clean_testset_wav_16k/p232_314.wav 300 | 298,/vctk_data/clean_testset_wav_16k/p232_315.wav 301 | 299,/vctk_data/clean_testset_wav_16k/p232_316.wav 302 | 300,/vctk_data/clean_testset_wav_16k/p232_317.wav 303 | 301,/vctk_data/clean_testset_wav_16k/p232_318.wav 304 | 
302,/vctk_data/clean_testset_wav_16k/p232_319.wav 305 | 303,/vctk_data/clean_testset_wav_16k/p232_320.wav 306 | 304,/vctk_data/clean_testset_wav_16k/p232_321.wav 307 | 305,/vctk_data/clean_testset_wav_16k/p232_322.wav 308 | 306,/vctk_data/clean_testset_wav_16k/p232_323.wav 309 | 307,/vctk_data/clean_testset_wav_16k/p232_324.wav 310 | 308,/vctk_data/clean_testset_wav_16k/p232_325.wav 311 | 309,/vctk_data/clean_testset_wav_16k/p232_326.wav 312 | 310,/vctk_data/clean_testset_wav_16k/p232_327.wav 313 | 311,/vctk_data/clean_testset_wav_16k/p232_328.wav 314 | 312,/vctk_data/clean_testset_wav_16k/p232_329.wav 315 | 313,/vctk_data/clean_testset_wav_16k/p232_330.wav 316 | 314,/vctk_data/clean_testset_wav_16k/p232_331.wav 317 | 315,/vctk_data/clean_testset_wav_16k/p232_332.wav 318 | 316,/vctk_data/clean_testset_wav_16k/p232_333.wav 319 | 317,/vctk_data/clean_testset_wav_16k/p232_334.wav 320 | 318,/vctk_data/clean_testset_wav_16k/p232_335.wav 321 | 319,/vctk_data/clean_testset_wav_16k/p232_336.wav 322 | 320,/vctk_data/clean_testset_wav_16k/p232_337.wav 323 | 321,/vctk_data/clean_testset_wav_16k/p232_338.wav 324 | 322,/vctk_data/clean_testset_wav_16k/p232_339.wav 325 | 323,/vctk_data/clean_testset_wav_16k/p232_340.wav 326 | 324,/vctk_data/clean_testset_wav_16k/p232_341.wav 327 | 325,/vctk_data/clean_testset_wav_16k/p232_342.wav 328 | 326,/vctk_data/clean_testset_wav_16k/p232_343.wav 329 | 327,/vctk_data/clean_testset_wav_16k/p232_344.wav 330 | 328,/vctk_data/clean_testset_wav_16k/p232_346.wav 331 | 329,/vctk_data/clean_testset_wav_16k/p232_347.wav 332 | 330,/vctk_data/clean_testset_wav_16k/p232_348.wav 333 | 331,/vctk_data/clean_testset_wav_16k/p232_349.wav 334 | 332,/vctk_data/clean_testset_wav_16k/p232_350.wav 335 | 333,/vctk_data/clean_testset_wav_16k/p232_351.wav 336 | 334,/vctk_data/clean_testset_wav_16k/p232_352.wav 337 | 335,/vctk_data/clean_testset_wav_16k/p232_353.wav 338 | 336,/vctk_data/clean_testset_wav_16k/p232_354.wav 339 | 337,/vctk_data/clean_testset_wav_16k/p232_355.wav 340 | 338,/vctk_data/clean_testset_wav_16k/p232_356.wav 341 | 339,/vctk_data/clean_testset_wav_16k/p232_357.wav 342 | 340,/vctk_data/clean_testset_wav_16k/p232_358.wav 343 | 341,/vctk_data/clean_testset_wav_16k/p232_359.wav 344 | 342,/vctk_data/clean_testset_wav_16k/p232_360.wav 345 | 343,/vctk_data/clean_testset_wav_16k/p232_361.wav 346 | 344,/vctk_data/clean_testset_wav_16k/p232_362.wav 347 | 345,/vctk_data/clean_testset_wav_16k/p232_363.wav 348 | 346,/vctk_data/clean_testset_wav_16k/p232_364.wav 349 | 347,/vctk_data/clean_testset_wav_16k/p232_365.wav 350 | 348,/vctk_data/clean_testset_wav_16k/p232_366.wav 351 | 349,/vctk_data/clean_testset_wav_16k/p232_367.wav 352 | 350,/vctk_data/clean_testset_wav_16k/p232_368.wav 353 | 351,/vctk_data/clean_testset_wav_16k/p232_369.wav 354 | 352,/vctk_data/clean_testset_wav_16k/p232_370.wav 355 | 353,/vctk_data/clean_testset_wav_16k/p232_371.wav 356 | 354,/vctk_data/clean_testset_wav_16k/p232_372.wav 357 | 355,/vctk_data/clean_testset_wav_16k/p232_373.wav 358 | 356,/vctk_data/clean_testset_wav_16k/p232_374.wav 359 | 357,/vctk_data/clean_testset_wav_16k/p232_375.wav 360 | 358,/vctk_data/clean_testset_wav_16k/p232_377.wav 361 | 359,/vctk_data/clean_testset_wav_16k/p232_378.wav 362 | 360,/vctk_data/clean_testset_wav_16k/p232_379.wav 363 | 361,/vctk_data/clean_testset_wav_16k/p232_380.wav 364 | 362,/vctk_data/clean_testset_wav_16k/p232_381.wav 365 | 363,/vctk_data/clean_testset_wav_16k/p232_382.wav 366 | 364,/vctk_data/clean_testset_wav_16k/p232_383.wav 367 | 
365,/vctk_data/clean_testset_wav_16k/p232_384.wav 368 | 366,/vctk_data/clean_testset_wav_16k/p232_385.wav 369 | 367,/vctk_data/clean_testset_wav_16k/p232_386.wav 370 | 368,/vctk_data/clean_testset_wav_16k/p232_387.wav 371 | 369,/vctk_data/clean_testset_wav_16k/p232_388.wav 372 | 370,/vctk_data/clean_testset_wav_16k/p232_389.wav 373 | 371,/vctk_data/clean_testset_wav_16k/p232_390.wav 374 | 372,/vctk_data/clean_testset_wav_16k/p232_391.wav 375 | 373,/vctk_data/clean_testset_wav_16k/p232_392.wav 376 | 374,/vctk_data/clean_testset_wav_16k/p232_393.wav 377 | 375,/vctk_data/clean_testset_wav_16k/p232_394.wav 378 | 376,/vctk_data/clean_testset_wav_16k/p232_396.wav 379 | 377,/vctk_data/clean_testset_wav_16k/p232_397.wav 380 | 378,/vctk_data/clean_testset_wav_16k/p232_398.wav 381 | 379,/vctk_data/clean_testset_wav_16k/p232_399.wav 382 | 380,/vctk_data/clean_testset_wav_16k/p232_400.wav 383 | 381,/vctk_data/clean_testset_wav_16k/p232_402.wav 384 | 382,/vctk_data/clean_testset_wav_16k/p232_403.wav 385 | 383,/vctk_data/clean_testset_wav_16k/p232_404.wav 386 | 384,/vctk_data/clean_testset_wav_16k/p232_405.wav 387 | 385,/vctk_data/clean_testset_wav_16k/p232_407.wav 388 | 386,/vctk_data/clean_testset_wav_16k/p232_409.wav 389 | 387,/vctk_data/clean_testset_wav_16k/p232_410.wav 390 | 388,/vctk_data/clean_testset_wav_16k/p232_411.wav 391 | 389,/vctk_data/clean_testset_wav_16k/p232_412.wav 392 | 390,/vctk_data/clean_testset_wav_16k/p232_413.wav 393 | 391,/vctk_data/clean_testset_wav_16k/p232_414.wav 394 | 392,/vctk_data/clean_testset_wav_16k/p232_415.wav 395 | 393,/vctk_data/clean_testset_wav_16k/p257_001.wav 396 | 394,/vctk_data/clean_testset_wav_16k/p257_002.wav 397 | 395,/vctk_data/clean_testset_wav_16k/p257_003.wav 398 | 396,/vctk_data/clean_testset_wav_16k/p257_004.wav 399 | 397,/vctk_data/clean_testset_wav_16k/p257_006.wav 400 | 398,/vctk_data/clean_testset_wav_16k/p257_007.wav 401 | 399,/vctk_data/clean_testset_wav_16k/p257_008.wav 402 | 400,/vctk_data/clean_testset_wav_16k/p257_009.wav 403 | 401,/vctk_data/clean_testset_wav_16k/p257_010.wav 404 | 402,/vctk_data/clean_testset_wav_16k/p257_011.wav 405 | 403,/vctk_data/clean_testset_wav_16k/p257_012.wav 406 | 404,/vctk_data/clean_testset_wav_16k/p257_013.wav 407 | 405,/vctk_data/clean_testset_wav_16k/p257_014.wav 408 | 406,/vctk_data/clean_testset_wav_16k/p257_015.wav 409 | 407,/vctk_data/clean_testset_wav_16k/p257_016.wav 410 | 408,/vctk_data/clean_testset_wav_16k/p257_017.wav 411 | 409,/vctk_data/clean_testset_wav_16k/p257_018.wav 412 | 410,/vctk_data/clean_testset_wav_16k/p257_019.wav 413 | 411,/vctk_data/clean_testset_wav_16k/p257_020.wav 414 | 412,/vctk_data/clean_testset_wav_16k/p257_022.wav 415 | 413,/vctk_data/clean_testset_wav_16k/p257_023.wav 416 | 414,/vctk_data/clean_testset_wav_16k/p257_024.wav 417 | 415,/vctk_data/clean_testset_wav_16k/p257_025.wav 418 | 416,/vctk_data/clean_testset_wav_16k/p257_026.wav 419 | 417,/vctk_data/clean_testset_wav_16k/p257_027.wav 420 | 418,/vctk_data/clean_testset_wav_16k/p257_028.wav 421 | 419,/vctk_data/clean_testset_wav_16k/p257_029.wav 422 | 420,/vctk_data/clean_testset_wav_16k/p257_030.wav 423 | 421,/vctk_data/clean_testset_wav_16k/p257_031.wav 424 | 422,/vctk_data/clean_testset_wav_16k/p257_032.wav 425 | 423,/vctk_data/clean_testset_wav_16k/p257_033.wav 426 | 424,/vctk_data/clean_testset_wav_16k/p257_034.wav 427 | 425,/vctk_data/clean_testset_wav_16k/p257_035.wav 428 | 426,/vctk_data/clean_testset_wav_16k/p257_036.wav 429 | 427,/vctk_data/clean_testset_wav_16k/p257_037.wav 430 | 
428,/vctk_data/clean_testset_wav_16k/p257_038.wav 431 | 429,/vctk_data/clean_testset_wav_16k/p257_039.wav 432 | 430,/vctk_data/clean_testset_wav_16k/p257_040.wav 433 | 431,/vctk_data/clean_testset_wav_16k/p257_041.wav 434 | 432,/vctk_data/clean_testset_wav_16k/p257_042.wav 435 | 433,/vctk_data/clean_testset_wav_16k/p257_043.wav 436 | 434,/vctk_data/clean_testset_wav_16k/p257_044.wav 437 | 435,/vctk_data/clean_testset_wav_16k/p257_045.wav 438 | 436,/vctk_data/clean_testset_wav_16k/p257_046.wav 439 | 437,/vctk_data/clean_testset_wav_16k/p257_047.wav 440 | 438,/vctk_data/clean_testset_wav_16k/p257_048.wav 441 | 439,/vctk_data/clean_testset_wav_16k/p257_049.wav 442 | 440,/vctk_data/clean_testset_wav_16k/p257_050.wav 443 | 441,/vctk_data/clean_testset_wav_16k/p257_051.wav 444 | 442,/vctk_data/clean_testset_wav_16k/p257_052.wav 445 | 443,/vctk_data/clean_testset_wav_16k/p257_053.wav 446 | 444,/vctk_data/clean_testset_wav_16k/p257_054.wav 447 | 445,/vctk_data/clean_testset_wav_16k/p257_055.wav 448 | 446,/vctk_data/clean_testset_wav_16k/p257_056.wav 449 | 447,/vctk_data/clean_testset_wav_16k/p257_057.wav 450 | 448,/vctk_data/clean_testset_wav_16k/p257_058.wav 451 | 449,/vctk_data/clean_testset_wav_16k/p257_059.wav 452 | 450,/vctk_data/clean_testset_wav_16k/p257_060.wav 453 | 451,/vctk_data/clean_testset_wav_16k/p257_061.wav 454 | 452,/vctk_data/clean_testset_wav_16k/p257_062.wav 455 | 453,/vctk_data/clean_testset_wav_16k/p257_063.wav 456 | 454,/vctk_data/clean_testset_wav_16k/p257_064.wav 457 | 455,/vctk_data/clean_testset_wav_16k/p257_065.wav 458 | 456,/vctk_data/clean_testset_wav_16k/p257_066.wav 459 | 457,/vctk_data/clean_testset_wav_16k/p257_067.wav 460 | 458,/vctk_data/clean_testset_wav_16k/p257_068.wav 461 | 459,/vctk_data/clean_testset_wav_16k/p257_069.wav 462 | 460,/vctk_data/clean_testset_wav_16k/p257_070.wav 463 | 461,/vctk_data/clean_testset_wav_16k/p257_071.wav 464 | 462,/vctk_data/clean_testset_wav_16k/p257_072.wav 465 | 463,/vctk_data/clean_testset_wav_16k/p257_073.wav 466 | 464,/vctk_data/clean_testset_wav_16k/p257_074.wav 467 | 465,/vctk_data/clean_testset_wav_16k/p257_075.wav 468 | 466,/vctk_data/clean_testset_wav_16k/p257_076.wav 469 | 467,/vctk_data/clean_testset_wav_16k/p257_077.wav 470 | 468,/vctk_data/clean_testset_wav_16k/p257_078.wav 471 | 469,/vctk_data/clean_testset_wav_16k/p257_079.wav 472 | 470,/vctk_data/clean_testset_wav_16k/p257_080.wav 473 | 471,/vctk_data/clean_testset_wav_16k/p257_081.wav 474 | 472,/vctk_data/clean_testset_wav_16k/p257_082.wav 475 | 473,/vctk_data/clean_testset_wav_16k/p257_083.wav 476 | 474,/vctk_data/clean_testset_wav_16k/p257_084.wav 477 | 475,/vctk_data/clean_testset_wav_16k/p257_085.wav 478 | 476,/vctk_data/clean_testset_wav_16k/p257_086.wav 479 | 477,/vctk_data/clean_testset_wav_16k/p257_087.wav 480 | 478,/vctk_data/clean_testset_wav_16k/p257_088.wav 481 | 479,/vctk_data/clean_testset_wav_16k/p257_089.wav 482 | 480,/vctk_data/clean_testset_wav_16k/p257_090.wav 483 | 481,/vctk_data/clean_testset_wav_16k/p257_091.wav 484 | 482,/vctk_data/clean_testset_wav_16k/p257_092.wav 485 | 483,/vctk_data/clean_testset_wav_16k/p257_093.wav 486 | 484,/vctk_data/clean_testset_wav_16k/p257_094.wav 487 | 485,/vctk_data/clean_testset_wav_16k/p257_095.wav 488 | 486,/vctk_data/clean_testset_wav_16k/p257_096.wav 489 | 487,/vctk_data/clean_testset_wav_16k/p257_097.wav 490 | 488,/vctk_data/clean_testset_wav_16k/p257_098.wav 491 | 489,/vctk_data/clean_testset_wav_16k/p257_099.wav 492 | 490,/vctk_data/clean_testset_wav_16k/p257_100.wav 493 | 
491,/vctk_data/clean_testset_wav_16k/p257_101.wav 494 | 492,/vctk_data/clean_testset_wav_16k/p257_102.wav 495 | 493,/vctk_data/clean_testset_wav_16k/p257_103.wav 496 | 494,/vctk_data/clean_testset_wav_16k/p257_104.wav 497 | 495,/vctk_data/clean_testset_wav_16k/p257_105.wav 498 | 496,/vctk_data/clean_testset_wav_16k/p257_106.wav 499 | 497,/vctk_data/clean_testset_wav_16k/p257_107.wav 500 | 498,/vctk_data/clean_testset_wav_16k/p257_108.wav 501 | 499,/vctk_data/clean_testset_wav_16k/p257_109.wav 502 | 500,/vctk_data/clean_testset_wav_16k/p257_110.wav 503 | 501,/vctk_data/clean_testset_wav_16k/p257_111.wav 504 | 502,/vctk_data/clean_testset_wav_16k/p257_112.wav 505 | 503,/vctk_data/clean_testset_wav_16k/p257_113.wav 506 | 504,/vctk_data/clean_testset_wav_16k/p257_114.wav 507 | 505,/vctk_data/clean_testset_wav_16k/p257_115.wav 508 | 506,/vctk_data/clean_testset_wav_16k/p257_116.wav 509 | 507,/vctk_data/clean_testset_wav_16k/p257_117.wav 510 | 508,/vctk_data/clean_testset_wav_16k/p257_118.wav 511 | 509,/vctk_data/clean_testset_wav_16k/p257_119.wav 512 | 510,/vctk_data/clean_testset_wav_16k/p257_120.wav 513 | 511,/vctk_data/clean_testset_wav_16k/p257_121.wav 514 | 512,/vctk_data/clean_testset_wav_16k/p257_122.wav 515 | 513,/vctk_data/clean_testset_wav_16k/p257_123.wav 516 | 514,/vctk_data/clean_testset_wav_16k/p257_124.wav 517 | 515,/vctk_data/clean_testset_wav_16k/p257_125.wav 518 | 516,/vctk_data/clean_testset_wav_16k/p257_126.wav 519 | 517,/vctk_data/clean_testset_wav_16k/p257_127.wav 520 | 518,/vctk_data/clean_testset_wav_16k/p257_128.wav 521 | 519,/vctk_data/clean_testset_wav_16k/p257_129.wav 522 | 520,/vctk_data/clean_testset_wav_16k/p257_130.wav 523 | 521,/vctk_data/clean_testset_wav_16k/p257_131.wav 524 | 522,/vctk_data/clean_testset_wav_16k/p257_132.wav 525 | 523,/vctk_data/clean_testset_wav_16k/p257_133.wav 526 | 524,/vctk_data/clean_testset_wav_16k/p257_135.wav 527 | 525,/vctk_data/clean_testset_wav_16k/p257_136.wav 528 | 526,/vctk_data/clean_testset_wav_16k/p257_137.wav 529 | 527,/vctk_data/clean_testset_wav_16k/p257_138.wav 530 | 528,/vctk_data/clean_testset_wav_16k/p257_139.wav 531 | 529,/vctk_data/clean_testset_wav_16k/p257_140.wav 532 | 530,/vctk_data/clean_testset_wav_16k/p257_141.wav 533 | 531,/vctk_data/clean_testset_wav_16k/p257_142.wav 534 | 532,/vctk_data/clean_testset_wav_16k/p257_143.wav 535 | 533,/vctk_data/clean_testset_wav_16k/p257_144.wav 536 | 534,/vctk_data/clean_testset_wav_16k/p257_145.wav 537 | 535,/vctk_data/clean_testset_wav_16k/p257_146.wav 538 | 536,/vctk_data/clean_testset_wav_16k/p257_147.wav 539 | 537,/vctk_data/clean_testset_wav_16k/p257_148.wav 540 | 538,/vctk_data/clean_testset_wav_16k/p257_149.wav 541 | 539,/vctk_data/clean_testset_wav_16k/p257_150.wav 542 | 540,/vctk_data/clean_testset_wav_16k/p257_151.wav 543 | 541,/vctk_data/clean_testset_wav_16k/p257_152.wav 544 | 542,/vctk_data/clean_testset_wav_16k/p257_153.wav 545 | 543,/vctk_data/clean_testset_wav_16k/p257_154.wav 546 | 544,/vctk_data/clean_testset_wav_16k/p257_155.wav 547 | 545,/vctk_data/clean_testset_wav_16k/p257_156.wav 548 | 546,/vctk_data/clean_testset_wav_16k/p257_157.wav 549 | 547,/vctk_data/clean_testset_wav_16k/p257_158.wav 550 | 548,/vctk_data/clean_testset_wav_16k/p257_159.wav 551 | 549,/vctk_data/clean_testset_wav_16k/p257_160.wav 552 | 550,/vctk_data/clean_testset_wav_16k/p257_161.wav 553 | 551,/vctk_data/clean_testset_wav_16k/p257_162.wav 554 | 552,/vctk_data/clean_testset_wav_16k/p257_163.wav 555 | 553,/vctk_data/clean_testset_wav_16k/p257_164.wav 556 | 
554,/vctk_data/clean_testset_wav_16k/p257_165.wav 557 | 555,/vctk_data/clean_testset_wav_16k/p257_166.wav 558 | 556,/vctk_data/clean_testset_wav_16k/p257_167.wav 559 | 557,/vctk_data/clean_testset_wav_16k/p257_168.wav 560 | 558,/vctk_data/clean_testset_wav_16k/p257_169.wav 561 | 559,/vctk_data/clean_testset_wav_16k/p257_170.wav 562 | 560,/vctk_data/clean_testset_wav_16k/p257_171.wav 563 | 561,/vctk_data/clean_testset_wav_16k/p257_172.wav 564 | 562,/vctk_data/clean_testset_wav_16k/p257_173.wav 565 | 563,/vctk_data/clean_testset_wav_16k/p257_174.wav 566 | 564,/vctk_data/clean_testset_wav_16k/p257_175.wav 567 | 565,/vctk_data/clean_testset_wav_16k/p257_176.wav 568 | 566,/vctk_data/clean_testset_wav_16k/p257_177.wav 569 | 567,/vctk_data/clean_testset_wav_16k/p257_178.wav 570 | 568,/vctk_data/clean_testset_wav_16k/p257_179.wav 571 | 569,/vctk_data/clean_testset_wav_16k/p257_180.wav 572 | 570,/vctk_data/clean_testset_wav_16k/p257_181.wav 573 | 571,/vctk_data/clean_testset_wav_16k/p257_182.wav 574 | 572,/vctk_data/clean_testset_wav_16k/p257_183.wav 575 | 573,/vctk_data/clean_testset_wav_16k/p257_184.wav 576 | 574,/vctk_data/clean_testset_wav_16k/p257_185.wav 577 | 575,/vctk_data/clean_testset_wav_16k/p257_186.wav 578 | 576,/vctk_data/clean_testset_wav_16k/p257_187.wav 579 | 577,/vctk_data/clean_testset_wav_16k/p257_188.wav 580 | 578,/vctk_data/clean_testset_wav_16k/p257_189.wav 581 | 579,/vctk_data/clean_testset_wav_16k/p257_190.wav 582 | 580,/vctk_data/clean_testset_wav_16k/p257_191.wav 583 | 581,/vctk_data/clean_testset_wav_16k/p257_192.wav 584 | 582,/vctk_data/clean_testset_wav_16k/p257_193.wav 585 | 583,/vctk_data/clean_testset_wav_16k/p257_194.wav 586 | 584,/vctk_data/clean_testset_wav_16k/p257_195.wav 587 | 585,/vctk_data/clean_testset_wav_16k/p257_196.wav 588 | 586,/vctk_data/clean_testset_wav_16k/p257_197.wav 589 | 587,/vctk_data/clean_testset_wav_16k/p257_198.wav 590 | 588,/vctk_data/clean_testset_wav_16k/p257_199.wav 591 | 589,/vctk_data/clean_testset_wav_16k/p257_200.wav 592 | 590,/vctk_data/clean_testset_wav_16k/p257_201.wav 593 | 591,/vctk_data/clean_testset_wav_16k/p257_202.wav 594 | 592,/vctk_data/clean_testset_wav_16k/p257_203.wav 595 | 593,/vctk_data/clean_testset_wav_16k/p257_204.wav 596 | 594,/vctk_data/clean_testset_wav_16k/p257_205.wav 597 | 595,/vctk_data/clean_testset_wav_16k/p257_206.wav 598 | 596,/vctk_data/clean_testset_wav_16k/p257_207.wav 599 | 597,/vctk_data/clean_testset_wav_16k/p257_208.wav 600 | 598,/vctk_data/clean_testset_wav_16k/p257_209.wav 601 | 599,/vctk_data/clean_testset_wav_16k/p257_210.wav 602 | 600,/vctk_data/clean_testset_wav_16k/p257_211.wav 603 | 601,/vctk_data/clean_testset_wav_16k/p257_212.wav 604 | 602,/vctk_data/clean_testset_wav_16k/p257_213.wav 605 | 603,/vctk_data/clean_testset_wav_16k/p257_214.wav 606 | 604,/vctk_data/clean_testset_wav_16k/p257_215.wav 607 | 605,/vctk_data/clean_testset_wav_16k/p257_216.wav 608 | 606,/vctk_data/clean_testset_wav_16k/p257_217.wav 609 | 607,/vctk_data/clean_testset_wav_16k/p257_218.wav 610 | 608,/vctk_data/clean_testset_wav_16k/p257_219.wav 611 | 609,/vctk_data/clean_testset_wav_16k/p257_220.wav 612 | 610,/vctk_data/clean_testset_wav_16k/p257_221.wav 613 | 611,/vctk_data/clean_testset_wav_16k/p257_222.wav 614 | 612,/vctk_data/clean_testset_wav_16k/p257_223.wav 615 | 613,/vctk_data/clean_testset_wav_16k/p257_224.wav 616 | 614,/vctk_data/clean_testset_wav_16k/p257_225.wav 617 | 615,/vctk_data/clean_testset_wav_16k/p257_226.wav 618 | 616,/vctk_data/clean_testset_wav_16k/p257_227.wav 619 | 
617,/vctk_data/clean_testset_wav_16k/p257_228.wav 620 | 618,/vctk_data/clean_testset_wav_16k/p257_229.wav 621 | 619,/vctk_data/clean_testset_wav_16k/p257_230.wav 622 | 620,/vctk_data/clean_testset_wav_16k/p257_231.wav 623 | 621,/vctk_data/clean_testset_wav_16k/p257_232.wav 624 | 622,/vctk_data/clean_testset_wav_16k/p257_233.wav 625 | 623,/vctk_data/clean_testset_wav_16k/p257_234.wav 626 | 624,/vctk_data/clean_testset_wav_16k/p257_235.wav 627 | 625,/vctk_data/clean_testset_wav_16k/p257_236.wav 628 | 626,/vctk_data/clean_testset_wav_16k/p257_237.wav 629 | 627,/vctk_data/clean_testset_wav_16k/p257_238.wav 630 | 628,/vctk_data/clean_testset_wav_16k/p257_239.wav 631 | 629,/vctk_data/clean_testset_wav_16k/p257_240.wav 632 | 630,/vctk_data/clean_testset_wav_16k/p257_241.wav 633 | 631,/vctk_data/clean_testset_wav_16k/p257_242.wav 634 | 632,/vctk_data/clean_testset_wav_16k/p257_243.wav 635 | 633,/vctk_data/clean_testset_wav_16k/p257_244.wav 636 | 634,/vctk_data/clean_testset_wav_16k/p257_245.wav 637 | 635,/vctk_data/clean_testset_wav_16k/p257_246.wav 638 | 636,/vctk_data/clean_testset_wav_16k/p257_247.wav 639 | 637,/vctk_data/clean_testset_wav_16k/p257_248.wav 640 | 638,/vctk_data/clean_testset_wav_16k/p257_249.wav 641 | 639,/vctk_data/clean_testset_wav_16k/p257_250.wav 642 | 640,/vctk_data/clean_testset_wav_16k/p257_251.wav 643 | 641,/vctk_data/clean_testset_wav_16k/p257_252.wav 644 | 642,/vctk_data/clean_testset_wav_16k/p257_253.wav 645 | 643,/vctk_data/clean_testset_wav_16k/p257_254.wav 646 | 644,/vctk_data/clean_testset_wav_16k/p257_255.wav 647 | 645,/vctk_data/clean_testset_wav_16k/p257_256.wav 648 | 646,/vctk_data/clean_testset_wav_16k/p257_257.wav 649 | 647,/vctk_data/clean_testset_wav_16k/p257_258.wav 650 | 648,/vctk_data/clean_testset_wav_16k/p257_259.wav 651 | 649,/vctk_data/clean_testset_wav_16k/p257_260.wav 652 | 650,/vctk_data/clean_testset_wav_16k/p257_261.wav 653 | 651,/vctk_data/clean_testset_wav_16k/p257_262.wav 654 | 652,/vctk_data/clean_testset_wav_16k/p257_263.wav 655 | 653,/vctk_data/clean_testset_wav_16k/p257_264.wav 656 | 654,/vctk_data/clean_testset_wav_16k/p257_265.wav 657 | 655,/vctk_data/clean_testset_wav_16k/p257_266.wav 658 | 656,/vctk_data/clean_testset_wav_16k/p257_267.wav 659 | 657,/vctk_data/clean_testset_wav_16k/p257_268.wav 660 | 658,/vctk_data/clean_testset_wav_16k/p257_269.wav 661 | 659,/vctk_data/clean_testset_wav_16k/p257_270.wav 662 | 660,/vctk_data/clean_testset_wav_16k/p257_271.wav 663 | 661,/vctk_data/clean_testset_wav_16k/p257_272.wav 664 | 662,/vctk_data/clean_testset_wav_16k/p257_273.wav 665 | 663,/vctk_data/clean_testset_wav_16k/p257_274.wav 666 | 664,/vctk_data/clean_testset_wav_16k/p257_275.wav 667 | 665,/vctk_data/clean_testset_wav_16k/p257_276.wav 668 | 666,/vctk_data/clean_testset_wav_16k/p257_277.wav 669 | 667,/vctk_data/clean_testset_wav_16k/p257_278.wav 670 | 668,/vctk_data/clean_testset_wav_16k/p257_279.wav 671 | 669,/vctk_data/clean_testset_wav_16k/p257_280.wav 672 | 670,/vctk_data/clean_testset_wav_16k/p257_281.wav 673 | 671,/vctk_data/clean_testset_wav_16k/p257_282.wav 674 | 672,/vctk_data/clean_testset_wav_16k/p257_283.wav 675 | 673,/vctk_data/clean_testset_wav_16k/p257_284.wav 676 | 674,/vctk_data/clean_testset_wav_16k/p257_285.wav 677 | 675,/vctk_data/clean_testset_wav_16k/p257_286.wav 678 | 676,/vctk_data/clean_testset_wav_16k/p257_287.wav 679 | 677,/vctk_data/clean_testset_wav_16k/p257_288.wav 680 | 678,/vctk_data/clean_testset_wav_16k/p257_289.wav 681 | 679,/vctk_data/clean_testset_wav_16k/p257_290.wav 682 | 
680,/vctk_data/clean_testset_wav_16k/p257_291.wav 683 | 681,/vctk_data/clean_testset_wav_16k/p257_292.wav 684 | 682,/vctk_data/clean_testset_wav_16k/p257_293.wav 685 | 683,/vctk_data/clean_testset_wav_16k/p257_294.wav 686 | 684,/vctk_data/clean_testset_wav_16k/p257_295.wav 687 | 685,/vctk_data/clean_testset_wav_16k/p257_296.wav 688 | 686,/vctk_data/clean_testset_wav_16k/p257_297.wav 689 | 687,/vctk_data/clean_testset_wav_16k/p257_298.wav 690 | 688,/vctk_data/clean_testset_wav_16k/p257_299.wav 691 | 689,/vctk_data/clean_testset_wav_16k/p257_300.wav 692 | 690,/vctk_data/clean_testset_wav_16k/p257_301.wav 693 | 691,/vctk_data/clean_testset_wav_16k/p257_302.wav 694 | 692,/vctk_data/clean_testset_wav_16k/p257_303.wav 695 | 693,/vctk_data/clean_testset_wav_16k/p257_304.wav 696 | 694,/vctk_data/clean_testset_wav_16k/p257_305.wav 697 | 695,/vctk_data/clean_testset_wav_16k/p257_306.wav 698 | 696,/vctk_data/clean_testset_wav_16k/p257_307.wav 699 | 697,/vctk_data/clean_testset_wav_16k/p257_308.wav 700 | 698,/vctk_data/clean_testset_wav_16k/p257_309.wav 701 | 699,/vctk_data/clean_testset_wav_16k/p257_310.wav 702 | 700,/vctk_data/clean_testset_wav_16k/p257_311.wav 703 | 701,/vctk_data/clean_testset_wav_16k/p257_312.wav 704 | 702,/vctk_data/clean_testset_wav_16k/p257_313.wav 705 | 703,/vctk_data/clean_testset_wav_16k/p257_314.wav 706 | 704,/vctk_data/clean_testset_wav_16k/p257_315.wav 707 | 705,/vctk_data/clean_testset_wav_16k/p257_316.wav 708 | 706,/vctk_data/clean_testset_wav_16k/p257_317.wav 709 | 707,/vctk_data/clean_testset_wav_16k/p257_318.wav 710 | 708,/vctk_data/clean_testset_wav_16k/p257_319.wav 711 | 709,/vctk_data/clean_testset_wav_16k/p257_320.wav 712 | 710,/vctk_data/clean_testset_wav_16k/p257_321.wav 713 | 711,/vctk_data/clean_testset_wav_16k/p257_322.wav 714 | 712,/vctk_data/clean_testset_wav_16k/p257_323.wav 715 | 713,/vctk_data/clean_testset_wav_16k/p257_324.wav 716 | 714,/vctk_data/clean_testset_wav_16k/p257_325.wav 717 | 715,/vctk_data/clean_testset_wav_16k/p257_326.wav 718 | 716,/vctk_data/clean_testset_wav_16k/p257_327.wav 719 | 717,/vctk_data/clean_testset_wav_16k/p257_328.wav 720 | 718,/vctk_data/clean_testset_wav_16k/p257_329.wav 721 | 719,/vctk_data/clean_testset_wav_16k/p257_330.wav 722 | 720,/vctk_data/clean_testset_wav_16k/p257_331.wav 723 | 721,/vctk_data/clean_testset_wav_16k/p257_332.wav 724 | 722,/vctk_data/clean_testset_wav_16k/p257_333.wav 725 | 723,/vctk_data/clean_testset_wav_16k/p257_334.wav 726 | 724,/vctk_data/clean_testset_wav_16k/p257_335.wav 727 | 725,/vctk_data/clean_testset_wav_16k/p257_336.wav 728 | 726,/vctk_data/clean_testset_wav_16k/p257_337.wav 729 | 727,/vctk_data/clean_testset_wav_16k/p257_338.wav 730 | 728,/vctk_data/clean_testset_wav_16k/p257_339.wav 731 | 729,/vctk_data/clean_testset_wav_16k/p257_340.wav 732 | 730,/vctk_data/clean_testset_wav_16k/p257_341.wav 733 | 731,/vctk_data/clean_testset_wav_16k/p257_342.wav 734 | 732,/vctk_data/clean_testset_wav_16k/p257_343.wav 735 | 733,/vctk_data/clean_testset_wav_16k/p257_344.wav 736 | 734,/vctk_data/clean_testset_wav_16k/p257_345.wav 737 | 735,/vctk_data/clean_testset_wav_16k/p257_346.wav 738 | 736,/vctk_data/clean_testset_wav_16k/p257_347.wav 739 | 737,/vctk_data/clean_testset_wav_16k/p257_348.wav 740 | 738,/vctk_data/clean_testset_wav_16k/p257_349.wav 741 | 739,/vctk_data/clean_testset_wav_16k/p257_350.wav 742 | 740,/vctk_data/clean_testset_wav_16k/p257_351.wav 743 | 741,/vctk_data/clean_testset_wav_16k/p257_352.wav 744 | 742,/vctk_data/clean_testset_wav_16k/p257_353.wav 745 | 
743,/vctk_data/clean_testset_wav_16k/p257_354.wav 746 | 744,/vctk_data/clean_testset_wav_16k/p257_355.wav 747 | 745,/vctk_data/clean_testset_wav_16k/p257_356.wav 748 | 746,/vctk_data/clean_testset_wav_16k/p257_357.wav 749 | 747,/vctk_data/clean_testset_wav_16k/p257_358.wav 750 | 748,/vctk_data/clean_testset_wav_16k/p257_359.wav 751 | 749,/vctk_data/clean_testset_wav_16k/p257_360.wav 752 | 750,/vctk_data/clean_testset_wav_16k/p257_361.wav 753 | 751,/vctk_data/clean_testset_wav_16k/p257_362.wav 754 | 752,/vctk_data/clean_testset_wav_16k/p257_363.wav 755 | 753,/vctk_data/clean_testset_wav_16k/p257_364.wav 756 | 754,/vctk_data/clean_testset_wav_16k/p257_365.wav 757 | 755,/vctk_data/clean_testset_wav_16k/p257_366.wav 758 | 756,/vctk_data/clean_testset_wav_16k/p257_367.wav 759 | 757,/vctk_data/clean_testset_wav_16k/p257_368.wav 760 | 758,/vctk_data/clean_testset_wav_16k/p257_369.wav 761 | 759,/vctk_data/clean_testset_wav_16k/p257_370.wav 762 | 760,/vctk_data/clean_testset_wav_16k/p257_371.wav 763 | 761,/vctk_data/clean_testset_wav_16k/p257_372.wav 764 | 762,/vctk_data/clean_testset_wav_16k/p257_373.wav 765 | 763,/vctk_data/clean_testset_wav_16k/p257_374.wav 766 | 764,/vctk_data/clean_testset_wav_16k/p257_375.wav 767 | 765,/vctk_data/clean_testset_wav_16k/p257_376.wav 768 | 766,/vctk_data/clean_testset_wav_16k/p257_377.wav 769 | 767,/vctk_data/clean_testset_wav_16k/p257_378.wav 770 | 768,/vctk_data/clean_testset_wav_16k/p257_379.wav 771 | 769,/vctk_data/clean_testset_wav_16k/p257_380.wav 772 | 770,/vctk_data/clean_testset_wav_16k/p257_381.wav 773 | 771,/vctk_data/clean_testset_wav_16k/p257_382.wav 774 | 772,/vctk_data/clean_testset_wav_16k/p257_383.wav 775 | 773,/vctk_data/clean_testset_wav_16k/p257_384.wav 776 | 774,/vctk_data/clean_testset_wav_16k/p257_385.wav 777 | 775,/vctk_data/clean_testset_wav_16k/p257_386.wav 778 | 776,/vctk_data/clean_testset_wav_16k/p257_387.wav 779 | 777,/vctk_data/clean_testset_wav_16k/p257_388.wav 780 | 778,/vctk_data/clean_testset_wav_16k/p257_389.wav 781 | 779,/vctk_data/clean_testset_wav_16k/p257_390.wav 782 | 780,/vctk_data/clean_testset_wav_16k/p257_391.wav 783 | 781,/vctk_data/clean_testset_wav_16k/p257_392.wav 784 | 782,/vctk_data/clean_testset_wav_16k/p257_393.wav 785 | 783,/vctk_data/clean_testset_wav_16k/p257_394.wav 786 | 784,/vctk_data/clean_testset_wav_16k/p257_395.wav 787 | 785,/vctk_data/clean_testset_wav_16k/p257_396.wav 788 | 786,/vctk_data/clean_testset_wav_16k/p257_397.wav 789 | 787,/vctk_data/clean_testset_wav_16k/p257_398.wav 790 | 788,/vctk_data/clean_testset_wav_16k/p257_399.wav 791 | 789,/vctk_data/clean_testset_wav_16k/p257_400.wav 792 | 790,/vctk_data/clean_testset_wav_16k/p257_401.wav 793 | 791,/vctk_data/clean_testset_wav_16k/p257_402.wav 794 | 792,/vctk_data/clean_testset_wav_16k/p257_403.wav 795 | 793,/vctk_data/clean_testset_wav_16k/p257_404.wav 796 | 794,/vctk_data/clean_testset_wav_16k/p257_405.wav 797 | 795,/vctk_data/clean_testset_wav_16k/p257_406.wav 798 | 796,/vctk_data/clean_testset_wav_16k/p257_407.wav 799 | 797,/vctk_data/clean_testset_wav_16k/p257_408.wav 800 | 798,/vctk_data/clean_testset_wav_16k/p257_409.wav 801 | 799,/vctk_data/clean_testset_wav_16k/p257_410.wav 802 | 800,/vctk_data/clean_testset_wav_16k/p257_411.wav 803 | 801,/vctk_data/clean_testset_wav_16k/p257_412.wav 804 | 802,/vctk_data/clean_testset_wav_16k/p257_413.wav 805 | 803,/vctk_data/clean_testset_wav_16k/p257_414.wav 806 | 804,/vctk_data/clean_testset_wav_16k/p257_415.wav 807 | 805,/vctk_data/clean_testset_wav_16k/p257_416.wav 808 | 
806,/vctk_data/clean_testset_wav_16k/p257_417.wav 809 | 807,/vctk_data/clean_testset_wav_16k/p257_418.wav 810 | 808,/vctk_data/clean_testset_wav_16k/p257_419.wav 811 | 809,/vctk_data/clean_testset_wav_16k/p257_420.wav 812 | 810,/vctk_data/clean_testset_wav_16k/p257_421.wav 813 | 811,/vctk_data/clean_testset_wav_16k/p257_422.wav 814 | 812,/vctk_data/clean_testset_wav_16k/p257_423.wav 815 | 813,/vctk_data/clean_testset_wav_16k/p257_424.wav 816 | 814,/vctk_data/clean_testset_wav_16k/p257_425.wav 817 | 815,/vctk_data/clean_testset_wav_16k/p257_426.wav 818 | 816,/vctk_data/clean_testset_wav_16k/p257_427.wav 819 | 817,/vctk_data/clean_testset_wav_16k/p257_428.wav 820 | 818,/vctk_data/clean_testset_wav_16k/p257_429.wav 821 | 819,/vctk_data/clean_testset_wav_16k/p257_430.wav 822 | 820,/vctk_data/clean_testset_wav_16k/p257_431.wav 823 | 821,/vctk_data/clean_testset_wav_16k/p257_432.wav 824 | 822,/vctk_data/clean_testset_wav_16k/p257_433.wav 825 | 823,/vctk_data/clean_testset_wav_16k/p257_434.wav 826 | -------------------------------------------------------------------------------- /VCTK_noisy_testSet_with_scores.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/VCTK_noisy_testSet_with_scores.pickle -------------------------------------------------------------------------------- /VCTK_noisy_validationSet.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/VCTK_noisy_validationSet.pickle -------------------------------------------------------------------------------- /VQScore.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/VQScore.png -------------------------------------------------------------------------------- /adv_wav.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/adv_wav.png -------------------------------------------------------------------------------- /bin/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import sys 6 | import yaml 7 | import random 8 | import logging 9 | import torch 10 | import numpy as np 11 | 12 | class Train(object): 13 | def __init__( 14 | self, 15 | args, 16 | ): 17 | # set logger 18 | logging.basicConfig( 19 | level=logging.INFO, 20 | stream=sys.stdout, 21 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 22 | ) 23 | 24 | # Fix seed and make backends deterministic 25 | random.seed(args.seed) 26 | np.random.seed(args.seed) 27 | torch.manual_seed(args.seed) 28 | if not torch.cuda.is_available(): 29 | self.device = torch.device('cpu') 30 | logging.info(f"device: cpu") 31 | else: 32 | self.device = torch.device('cuda') 33 | logging.info(f"device: gpu") 34 | torch.cuda.manual_seed_all(args.seed) 35 | if args.disable_cudnn == "False": 36 | torch.backends.cudnn.benchmark = True 37 | 38 | # initialize config 39 | with open(args.config, 'r') as f: 40 | self.config = yaml.load(f, Loader=yaml.FullLoader) 41 | self.config.update(vars(args)) 42 | 43 | # initialize model folder 44 | expdir = os.path.join(args.exp_root, 
args.tag) 45 | os.makedirs(expdir, exist_ok=True) 46 | self.config["outdir"] = expdir 47 | 48 | # save config 49 | with open(os.path.join(expdir, "config.yml"), "w") as f: 50 | yaml.dump(self.config, f, Dumper=yaml.Dumper) 51 | for key, value in self.config.items(): 52 | logging.info(f"[TrainGAN] {key} = {value}") 53 | 54 | # initialize attribute 55 | self.resume = args.resume 56 | self.data_loader = None 57 | self.model = None 58 | self.criterion = None 59 | self.optimizer = None 60 | self.scheduler = None 61 | self.trainer = None 62 | 63 | 64 | def initialize_data_loader(self): 65 | pass 66 | 67 | 68 | def define_model(self): 69 | pass 70 | 71 | 72 | def define_trainer(self): 73 | pass 74 | 75 | 76 | def initialize_model(self): 77 | initial = self.config.get("initial", "") 78 | if len(self.resume) != 0: 79 | self.trainer.load_checkpoint(self.resume) 80 | logging.info(f"Successfully resumed from {self.resume}.") 81 | elif len(initial) != 0: 82 | self.trainer.load_checkpoint(initial, load_only_params=True) 83 | logging.info(f"Successfully initialized parameters from {initial}.") 84 | else: 85 | logging.info("Train from scratch") 86 | 87 | 88 | def run(self): 89 | try: 90 | self.trainer.run() 91 | finally: 92 | self.trainer.save_checkpoint( 93 | os.path.join(self.config["outdir"], f"checkpoint-{self.trainer.steps}steps.pkl") 94 | ) 95 | logging.info(f"Successfully saved checkpoint @ {self.trainer.steps}steps.") 96 | 97 | 98 | def _define_optimizer_scheduler(self): 99 | VQVAE_optimizer_class = getattr( 100 | torch.optim, 101 | self.config['VQVAE_optimizer_type']) 102 | 103 | self.optimizer = { 104 | 'VQVAE': VQVAE_optimizer_class( 105 | self.model['VQVAE'].parameters(), 106 | **self.config['VQVAE_optimizer_params'])} 107 | 108 | VQVAE_scheduler_class = getattr( 109 | torch.optim.lr_scheduler, 110 | self.config.get('VQVAE_scheduler_type', "StepLR")) 111 | 112 | self.scheduler = { 113 | 'VQVAE': VQVAE_scheduler_class( 114 | optimizer=self.optimizer['VQVAE'], 115 | **self.config['VQVAE_scheduler_params'])} 116 | 117 | def _show_setting(self): 118 | logging.info(self.model['VQVAE']) 119 | logging.info(self.optimizer['VQVAE']) 120 | logging.info(self.scheduler['VQVAE']) 121 | 122 | -------------------------------------------------------------------------------- /clean_p232_005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/clean_p232_005.wav -------------------------------------------------------------------------------- /config/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github.yaml: -------------------------------------------------------------------------------- 1 | name: QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github 2 | ########################################################### 3 | # DATA SETTING # 4 | ########################################################### 5 | # user defined data path 6 | sampling_rate: 16000 7 | data: 8 | path: "" 9 | subset: 10 | clean_train: "./Librispeech_clean.csv" 11 | clean_valid: "./VCTK_clean_test.csv" # not actually used 12 | 13 | ########################################################### 14 | # MODEL SETTING # 15 | ########################################################### 16 | task: Quality_Estimation # Speech_Enhancement or Quality_Estimation 17 | train_mode: autoencoder 18 | cos_loss: True 19 | input_transform: None 20 | 21 | 22 | VQVAE_params: 23 | codebook_size: 2048 24 | codebook_num: 1 
25 | codebook_dim: 32 26 | orthogonal_reg_weight: 0 27 | use_cosine_sim: True 28 | ema_update: True 29 | learnable_codebook: False 30 | stochastic_sample_codes: False 31 | sample_codebook_temp: 6 32 | straight_through: False 33 | reinmax: False 34 | kmeans_init: True 35 | threshold_ema_dead_code: -1000 36 | 37 | ########################################################### 38 | # LOSS WEIGHT SETTING # 39 | ########################################################### 40 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss. 41 | lambda_stft_loss: 45.0 # Loss weight of stft loss. 42 | 43 | ########################################################### 44 | # DATA LOADER SETTING # 45 | ########################################################### 46 | batch_size: 64 # Batch size. 47 | batch_length: 48000 # Length of each audio in batch. 48 | pin_memory: true # Whether to pin memory in Pytorch DataLoader. 49 | num_workers: 6 # Number of workers in Pytorch DataLoader. 50 | 51 | ########################################################### 52 | # OPTIMIZER & SCHEDULER SETTING # 53 | ########################################################### 54 | VQVAE_optimizer_type: Adam 55 | VQVAE_optimizer_params: 56 | lr: 1.0e-5 57 | betas: [0.5, 0.9] 58 | weight_decay: 0.0 59 | VQVAE_scheduler_type: StepLR 60 | VQVAE_scheduler_params: 61 | step_size: 200000 # Generator's scheduler step size. 62 | gamma: 1.0 63 | VQVAE_grad_norm: -1 64 | 65 | ########################################################### 66 | # INTERVAL SETTING # 67 | ########################################################### 68 | start_steps: # Number of steps to start training 69 | VQVAE: 0 70 | AT_training_start_steps: 60000000 71 | train_max_steps: 800000 # Number of training steps. 72 | save_interval_steps: 100000 # Interval steps to save checkpoint. 73 | eval_interval_steps: 20000 # Interval steps to evaluate the network. 74 | log_interval_steps: 20000 # Interval steps to record the training log. 
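For orientation, a minimal sketch (not a file in this repository) of how a config like the one above is consumed: the top-level keys steer the training/inference scripts, while the `VQVAE_params` block is unpacked directly into the model constructor, the same pattern `inference.py` below uses.

```python
# Sketch only: loads the QE config above and builds the quality-estimation model.
import yaml
import torch
from models.VQVAE_models import VQVAE_QE

with open("config/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github.yaml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# VQVAE_params maps one-to-one onto the constructor arguments (codebook_size=2048, codebook_dim=32, ...)
model = VQVAE_QE(**config['VQVAE_params']).to(device).eval()
print(config['task'])  # Quality_Estimation
```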
75 | -------------------------------------------------------------------------------- /config/SE_cbook_4096_1_128_lr_1m5_1m5_github.yaml: -------------------------------------------------------------------------------- 1 | name: SE_cbook_4096_1_128_lr_1m5_1m5_github 2 | ########################################################### 3 | # DATA SETTING # 4 | ########################################################### 5 | # user defined data path 6 | vctk_Clean_path: '/vctk_data/clean_testset_wav_16k/' 7 | DNS1_test: '/DNS1_test' 8 | DNS3_test: '/DNS3_test' 9 | 10 | sampling_rate: 16000 11 | data: 12 | path: "" 13 | subset: 14 | clean_train: "./VCTK_clean_train.csv" 15 | clean_valid: "./VCTK_clean_test.csv" # not actually used 16 | ########################################################### 17 | # MODEL SETTING # 18 | ########################################################### 19 | task: Speech_Enhancement # Speech_Enhancement or Quality_Estimation 20 | train_mode: autoencoder 21 | cos_loss: False 22 | input_transform: log1p 23 | 24 | adv_min_epsilon: 0.03 25 | adv_max_epsilon: 0.55 26 | 27 | VQVAE_params: 28 | codebook_size: 4096 29 | codebook_num: 1 30 | codebook_dim: 128 31 | orthogonal_reg_weight: 0 32 | use_cosine_sim: False 33 | ema_update: True 34 | learnable_codebook: False 35 | stochastic_sample_codes: False 36 | sample_codebook_temp: 6 37 | straight_through: False 38 | reinmax: False 39 | kmeans_init: True 40 | threshold_ema_dead_code: -1000 41 | 42 | ########################################################### 43 | # LOSS WEIGHT SETTING # 44 | ########################################################### 45 | lambda_vq_loss: 3.0 # Loss weight of vector quantize loss. 46 | lambda_ce_loss: 1.0 # Loss weight of cross-entropy loss. 47 | lambda_stft_loss: 45.0 # Loss weight of stft loss. 48 | 49 | ########################################################### 50 | # DATA LOADER SETTING # 51 | ########################################################### 52 | batch_size: 64 # Batch size. 53 | batch_length: 48000 # Length of each audio in batch. 54 | pin_memory: true # Whether to pin memory in Pytorch DataLoader. 55 | num_workers: 6 # Number of workers in Pytorch DataLoader. 56 | 57 | ########################################################### 58 | # OPTIMIZER & SCHEDULER SETTING # 59 | ########################################################### 60 | VQVAE_optimizer_type: Adam 61 | VQVAE_optimizer_params: 62 | lr: 1.0e-5 63 | betas: [0.5, 0.9] 64 | weight_decay: 0.0 65 | VQVAE_scheduler_type: StepLR 66 | VQVAE_scheduler_params: 67 | step_size: 200000 # Generator's scheduler step size. 68 | gamma: 1.0 69 | VQVAE_grad_norm: -1 70 | 71 | VQVAE_AT_optimizer_params: 72 | lr: 1.0e-5 73 | betas: [0.5, 0.9] 74 | weight_decay: 0.0 75 | 76 | ########################################################### 77 | # INTERVAL SETTING # 78 | ########################################################### 79 | start_steps: # Number of steps to start training 80 | VQVAE: 0 81 | AT_training_start_steps: 400000 82 | train_max_steps: 500000 # Number of training steps. 83 | save_interval_steps: 100000 # Interval steps to save checkpoint. 84 | eval_interval_steps: 20000 # Interval steps to evaluate the network. 85 | log_interval_steps: 20000 # Interval steps to record the training log. 
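Note the `input_transform: log1p` setting above: for the SE task the model operates on log-compressed STFT magnitudes, and `torch.expm1` (the exact inverse of `torch.log1p`) is applied to the decoder output before waveform resynthesis, as `inference.py` below does. A small self-contained check of that round trip:

```python
# Sketch of the log1p/expm1 round trip used when input_transform is 'log1p'.
import torch

mag = torch.rand(1, 100, 257) * 10      # magnitude spectrogram frames, shape [B, T, F]
model_input = torch.log1p(mag)          # log(1 + x) compresses large magnitudes
recovered = torch.expm1(model_input)    # exp(x) - 1 undoes the compression
assert torch.allclose(mag, recovered, atol=1e-5)
```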
86 | -------------------------------------------------------------------------------- /dataloader/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import os 7 | import soundfile as sf 8 | from torch.utils.data import Dataset 9 | 10 | class SingleDataset(Dataset): 11 | def __init__( 12 | self, 13 | data_path, 14 | files, 15 | query="*.wav", 16 | load_fn=sf.read, 17 | return_utt_id=False, 18 | subset_num=-1, 19 | batch_length=9600 20 | ): 21 | self.return_utt_id = return_utt_id 22 | self.load_fn = load_fn 23 | self.subset_num = subset_num 24 | self.data_path = data_path 25 | self.batch_length = batch_length 26 | self.filenames = self._load_list(files, query) 27 | # self.filenames = pd.read_csv(files)['Filename'] 28 | self.utt_ids = self._load_ids(self.filenames) 29 | 30 | 31 | def __getitem__(self, idx): 32 | utt_id = self.utt_ids[idx] 33 | data = self._data(idx) 34 | 35 | if self.return_utt_id: 36 | items = utt_id, data 37 | else: 38 | items = data 39 | 40 | 41 | return items 42 | 43 | 44 | def __len__(self): 45 | return len(self.filenames) 46 | 47 | 48 | def _load_list(self, files, query): 49 | # if isinstance(files, list): 50 | # filenames = files 51 | # else: 52 | # if os.path.exists(files): 53 | # filenames = sorted(find_files(files, query)) 54 | # else: 55 | # raise ValueError(f"{files} is not a list or an existing folder!") 56 | 57 | # if self.subset_num > 0: 58 | # filenames = filenames[:self.subset_num] 59 | # assert len(filenames) != 0, f"File list is empty!" 60 | # return filenames 61 | filenames = pd.read_csv(files)['Filename'].to_list() 62 | filenames = [os.path.join(self.data_path,filename) for filename in filenames] 63 | return filenames 64 | 65 | 66 | def _load_ids(self, filenames): 67 | utt_ids = [ 68 | os.path.splitext(os.path.basename(f))[0] for f in filenames 69 | ] 70 | return utt_ids 71 | 72 | 73 | def _data(self, idx): 74 | return self._load_data(self.filenames[idx], self.load_fn) 75 | 76 | 77 | def _load_data(self, filename, load_fn): 78 | if load_fn == sf.read: 79 | data = load_fn(filename, always_2d=True)[0][:,0] # T x C, 1 80 | data_shape = data.shape[0] 81 | if data.shape[0]<=self.batch_length: 82 | data = np.concatenate((data, np.zeros(self.batch_length-data.shape[0])))[None,:].astype(np.float32) 83 | else: 84 | start = np.random.randint(0,(data.shape[0]-self.batch_length+1)) 85 | data = data[None, start : start + self.batch_length].astype(np.float32) 86 | else: 87 | data = load_fn(filename) 88 | data_shape = len(data) # length of the loaded item 89 | return data, data_shape 90 | 91 | -------------------------------------------------------------------------------- /exp/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github/checkpoint-dnsmos_ovr_CC=0.835.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/exp/QE_cbook_size_2048_1_32_IN_input_encoder_z_Librispeech_clean_github/checkpoint-dnsmos_ovr_CC=0.835.pkl -------------------------------------------------------------------------------- /exp/SE_cbook_4096_1_128_lr_1m5_1m5_github/checkpoint-dnsmos_ovr=2.654.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/exp/SE_cbook_4096_1_128_lr_1m5_1m5_github/checkpoint-dnsmos_ovr=2.654.pkl 
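The `_load_data` method of `SingleDataset` in dataloader/dataset.py above implements the fixed-length batching used during training: clips shorter than `batch_length` are zero-padded at the end, longer clips get a random crop, so every item comes out with shape `[1, batch_length]`. A standalone restatement of that policy (the helper name `fixed_length` is hypothetical, not from the repo):

```python
# Sketch of the pad-or-crop policy in SingleDataset._load_data.
import numpy as np

def fixed_length(data: np.ndarray, batch_length: int) -> np.ndarray:
    if data.shape[0] <= batch_length:
        # too short: pad the tail with zeros
        data = np.concatenate((data, np.zeros(batch_length - data.shape[0])))[None, :]
    else:
        # too long: take a random crop of batch_length samples
        start = np.random.randint(0, data.shape[0] - batch_length + 1)
        data = data[None, start:start + batch_length]
    return data.astype(np.float32)

print(fixed_length(np.ones(10), 16).shape)   # (1, 16): padded
print(fixed_length(np.ones(100), 16).shape)  # (1, 16): randomly cropped
```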
-------------------------------------------------------------------------------- /exp/SE_cbook_4096_1_128_lr_1m5_1m5_github/checkpoint-dnsmos_ovr=2.761_AT.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/exp/SE_cbook_4096_1_128_lr_1m5_1m5_github/checkpoint-dnsmos_ovr=2.761_AT.pkl -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: szuweif 4 | """ 5 | import argparse 6 | import yaml 7 | import torch 8 | import torchaudio 9 | from models.VQVAE_models import VQVAE_SE, VQVAE_QE 10 | 11 | 12 | def resynthesize(enhanced_mag, noisy_inputs, hop_size): 13 | """Function for resynthesizing waveforms from enhanced mags. 14 | Arguments 15 | --------- 16 | enhanced_mag : torch.Tensor 17 | Predicted spectral magnitude, should be three dimensional. 18 | noisy_inputs : torch.Tensor 19 | The noisy waveforms before any processing, to extract phase. 20 | Returns 21 | ------- 22 | enhanced_wav : torch.Tensor 23 | The resynthesized waveforms of the enhanced magnitudes with noisy phase. 24 | """ 25 | 26 | # Extract noisy phase from inputs 27 | 28 | noisy_feats = torch.stft(noisy_inputs, n_fft=512, hop_length=hop_size, win_length=512, 29 | window=torch.hamming_window(512).to(noisy_inputs.device), 30 | center=True, 31 | pad_mode="constant", 32 | onesided=True, 33 | return_complex=False).transpose(2, 1) 34 | 35 | noisy_phase = torch.atan2(noisy_feats[:, :, :, 1], noisy_feats[:, :, :, 0])[:,0:enhanced_mag.shape[1],:] 36 | 37 | # Combine with enhanced magnitude 38 | predictions = torch.mul( 39 | torch.unsqueeze(enhanced_mag, -1), 40 | torch.cat( 41 | ( 42 | torch.unsqueeze(torch.cos(noisy_phase), -1), 43 | torch.unsqueeze(torch.sin(noisy_phase), -1), 44 | ), 45 | -1, 46 | ), 47 | ).permute(0, 2, 1, 3) 48 | 49 | # istft expects a complex-valued input 50 | complex_predictions = torch.complex(predictions[..., 0], predictions[..., 1]) 51 | pred_wavs = torch.istft(input=complex_predictions, n_fft=512, hop_length=hop_size, win_length=512, 52 | window=torch.hamming_window(512).to(noisy_inputs.device), 53 | center=True, 54 | onesided=True, 55 | length=noisy_inputs.shape[1]) 56 | 57 | return pred_wavs 58 | 59 | def stft_magnitude(x, hop_size, fft_size=512, win_length=512): 60 | if x.is_cuda: 61 | x_stft = torch.stft( 62 | x, fft_size, hop_size, win_length, window=torch.hann_window(win_length).to('cuda'), return_complex=False 63 | ) 64 | else: 65 | x_stft = torch.stft( 66 | x, fft_size, hop_size, win_length, window=torch.hann_window(win_length), return_complex=False 67 | ) 68 | real = x_stft[..., 0] 69 | imag = x_stft[..., 1] 70 | 71 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 72 | 73 | 74 | def cos_loss(SP_noisy, SP_y_noisy): 75 | eps=1e-5 76 | SP_noisy_norm = torch.norm(SP_noisy, p=2, dim=-1, keepdim=True)+eps 77 | SP_y_noisy_norm = torch.norm(SP_y_noisy, p=2, dim=-1, keepdim=True)+eps 78 | Cos_frame = torch.sum(SP_noisy/SP_noisy_norm * SP_y_noisy/SP_y_noisy_norm, dim=-1) # torch.Size([B, T, 1]) 79 | 80 | return -torch.mean(Cos_frame) 81 | 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('-c', '--config', type=str, required=True) 84 | parser.add_argument('-m','--path_of_model_weights', type=str, required=True) 85 | parser.add_argument('-i', '--input_wav_path', type=str, required=True) 86 | 
parser.add_argument('-e', '--enhanced_wav_path', type=str, default="./enhanced.wav") 87 | args = parser.parse_args() 88 | 89 | # initialize config 90 | with open(args.config, 'r') as f: 91 | config = yaml.load(f, Loader=yaml.FullLoader) 92 | 93 | if not torch.cuda.is_available(): 94 | device = torch.device('cpu') 95 | print("device: cpu") 96 | else: 97 | device = torch.device('cuda') 98 | print("device: gpu") 99 | torch.backends.cudnn.benchmark = True 100 | 101 | with torch.no_grad(): 102 | if config['task'] == "Speech_Enhancement": 103 | VQVAE = VQVAE_SE(**config['VQVAE_params']).to(device).eval() 104 | hop_size = 128 105 | 106 | VQVAE.load_state_dict(torch.load(args.path_of_model_weights)['model']['VQVAE']) 107 | wav_input, fs = torchaudio.load(args.input_wav_path) 108 | 109 | wav_input = wav_input.to(device) 110 | SP_input = stft_magnitude(wav_input, hop_size=hop_size) 111 | if config['input_transform'] == 'log1p': 112 | SP_input = torch.log1p(SP_input) 113 | 114 | z = VQVAE.CNN_1D_encoder(SP_input) 115 | zq, indices, vqloss, distance = VQVAE.quantizer(z, stochastic=False, update=False) 116 | SP_output = VQVAE.CNN_1D_decoder(zq) 117 | 118 | if config['input_transform'] == 'log1p': 119 | wav_output = resynthesize(torch.expm1(SP_output), wav_input, hop_size).cpu() 120 | else: 121 | wav_output = resynthesize(SP_output, wav_input, hop_size).cpu() 122 | 123 | torchaudio.save(args.enhanced_wav_path, wav_output, 16000) 124 | print('=================================================') 125 | print('enhanced wav is saved at: ' + args.enhanced_wav_path) 126 | 127 | elif config['task'] == "Quality_Estimation": 128 | VQVAE = VQVAE_QE(**config['VQVAE_params']).to(device).eval() 129 | hop_size = 256 130 | 131 | VQVAE.load_state_dict(torch.load(args.path_of_model_weights)['model']['VQVAE']) 132 | wav_input, fs = torchaudio.load(args.input_wav_path) 133 | 134 | wav_input = wav_input.to(device) 135 | SP_input = stft_magnitude(wav_input, hop_size=hop_size) 136 | if config['input_transform'] == 'log1p': 137 | SP_input = torch.log1p(SP_input) 138 | 139 | z = VQVAE.CNN_1D_encoder(SP_input) 140 | zq, indices, vqloss, distance = VQVAE.quantizer(z, stochastic=False, update=False) 141 | SP_output = VQVAE.CNN_1D_decoder(zq) 142 | 143 | VQScore_cos_z = -cos_loss(z.transpose(2, 1).cpu(), zq.cpu()).numpy() 144 | print('=================================================') 145 | print('VQScore_cos_z = ' + str(VQScore_cos_z)) -------------------------------------------------------------------------------- /inference_folder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: szuweif 4 | """ 5 | import os 6 | import argparse 7 | import yaml 8 | import torch 9 | import torchaudio 10 | import numpy as np 11 | import pandas as pd 12 | import time 13 | from models.VQVAE_models import VQVAE_SE, VQVAE_QE 14 | from pesq import pesq 15 | 16 | def get_filepaths(directory): 17 | """ 18 | Collect the full path of every file under `directory` by 19 | walking the directory tree with os.walk and joining each 20 | (dirpath, filename) pair into a single path. 21 | Returns a list of full file paths. 22 | """ 23 | file_paths = [] # List which will store all of the full filepaths. 24 | # Walk the tree. 25 | for root, directories, files in os.walk(directory): 26 | for filename in files: 27 | # Join the two strings in order to form the full filepath. 
filepath = os.path.join(root, filename) 28 | file_paths.append(filepath) # Add it to the list. 29 | return file_paths 30 | 31 | def resynthesize(enhanced_mag, noisy_inputs, hop_size): 32 | """Function for resynthesizing waveforms from enhanced mags. 33 | Arguments 34 | --------- 35 | enhanced_mag : torch.Tensor 36 | Predicted spectral magnitude, should be three dimensional. 37 | noisy_inputs : torch.Tensor 38 | The noisy waveforms before any processing, to extract phase. 39 | Returns 40 | ------- 41 | enhanced_wav : torch.Tensor 42 | The resynthesized waveforms of the enhanced magnitudes with noisy phase. 43 | """ 44 | 45 | # Extract noisy phase from inputs 46 | 47 | noisy_feats = torch.stft(noisy_inputs, n_fft=512, hop_length=hop_size, win_length=512, 48 | window=torch.hamming_window(512).to(noisy_inputs.device), 49 | center=True, 50 | pad_mode="constant", 51 | onesided=True, 52 | return_complex=False).transpose(2, 1) 53 | 54 | noisy_phase = torch.atan2(noisy_feats[:, :, :, 1], noisy_feats[:, :, :, 0])[:,0:enhanced_mag.shape[1],:] 55 | 56 | # Combine with enhanced magnitude 57 | predictions = torch.mul( 58 | torch.unsqueeze(enhanced_mag, -1), 59 | torch.cat( 60 | ( 61 | torch.unsqueeze(torch.cos(noisy_phase), -1), 62 | torch.unsqueeze(torch.sin(noisy_phase), -1), 63 | ), 64 | -1, 65 | ), 66 | ).permute(0, 2, 1, 3) 67 | 68 | # istft expects a complex-valued input 69 | complex_predictions = torch.complex(predictions[..., 0], predictions[..., 1]) 70 | pred_wavs = torch.istft(input=complex_predictions, n_fft=512, hop_length=hop_size, win_length=512, 71 | window=torch.hamming_window(512).to(noisy_inputs.device), 72 | center=True, 73 | onesided=True, 74 | length=noisy_inputs.shape[1]) 75 | 76 | return pred_wavs 77 | 78 | def stft_magnitude(x, hop_size, fft_size=512, win_length=512): 79 | if x.is_cuda: 80 | x_stft = torch.stft( 81 | x, fft_size, hop_size, win_length, window=torch.hann_window(win_length).to('cuda'), return_complex=False 82 | ) 83 | else: 84 | x_stft = torch.stft( 85 | x, fft_size, hop_size, win_length, window=torch.hann_window(win_length), return_complex=False 86 | ) 87 | real = x_stft[..., 0] 88 | imag = x_stft[..., 1] 89 | 90 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 91 | 92 | def cos_loss(SP_noisy, SP_y_noisy): 93 | eps=1e-5 94 | SP_noisy_norm = torch.norm(SP_noisy, p=2, dim=-1, keepdim=True)+eps 95 | SP_y_noisy_norm = torch.norm(SP_y_noisy, p=2, dim=-1, keepdim=True)+eps 96 | Cos_frame = torch.sum(SP_noisy/SP_noisy_norm * SP_y_noisy/SP_y_noisy_norm, dim=-1) # torch.Size([B, T, 1]) 97 | 98 | return -torch.mean(Cos_frame) 99 | 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('-c', '--config', type=str, required=True) 102 | parser.add_argument('-m','--path_of_model_weights', type=str, required=True) 103 | parser.add_argument('-i', '--path_of_input_audio_folder', type=str, required=True) 104 | parser.add_argument('-o', '--path_of_output_audio_folder', type=str, default="./enhanced/") 105 | args = parser.parse_args() 106 | 107 | # initialize config 108 | with open(args.config, 'r') as f: 109 | config = yaml.load(f, Loader=yaml.FullLoader) 110 | 111 | if not torch.cuda.is_available(): 112 | device = torch.device('cpu') 113 | print("device: cpu") 114 | else: 115 | device = torch.device('cuda') 116 | print("device: gpu") 117 | torch.backends.cudnn.benchmark = True 118 | 119 | with torch.no_grad(): 120 | if config['task'] == "Speech_Enhancement": 121 | VQVAE = VQVAE_SE(**config['VQVAE_params']).to(device).eval() 122 | hop_size = 128 123 | 124 
| VQVAE.load_state_dict(torch.load(args.path_of_model_weights)['model']['VQVAE']) 125 | if not os.path.exists(args.path_of_output_audio_folder): 126 | os.mkdir(args.path_of_output_audio_folder) 127 | 128 | threshold = 1 129 | cluster_size = VQVAE.quantizer.quantizer._codebook.cluster_size[0].cpu().numpy() 130 | preserved_num = np.sum(cluster_size>threshold) 131 | temp = torch.zeros([1, preserved_num, VQVAE.quantizer.quantizer._codebook.embed.shape[-1]]) # keep only frequently used codes; last dim matches the codebook dimension 132 | j=0 133 | for i in range(cluster_size.shape[0]): 134 | if cluster_size[i] > threshold: 135 | temp[:,j,:] = VQVAE.quantizer.quantizer._codebook.embed[:,i,:] 136 | j = j+1 137 | VQVAE.quantizer.quantizer._codebook.embed = temp.to(device) 138 | 139 | file_list = get_filepaths(args.path_of_input_audio_folder) 140 | estimated_pesq = [] 141 | for file in file_list: 142 | clean, fs = torchaudio.load(os.path.join(config['vctk_Clean_path'], file.split('/')[-1])) # clean reference for PESQ, path taken from the config 143 | 144 | wav_input, fs = torchaudio.load(file) 145 | wav_input = wav_input.to(device) 146 | SP_input = stft_magnitude(wav_input, hop_size=hop_size) 147 | if config['input_transform'] == 'log1p': 148 | SP_input = torch.log1p(SP_input) 149 | 150 | z = VQVAE.CNN_1D_encoder(SP_input) 151 | zq, indices, vqloss, distance = VQVAE.quantizer(z, stochastic=False, update=False) 152 | SP_output = VQVAE.CNN_1D_decoder(zq) 153 | 154 | if config['input_transform'] == 'log1p': 155 | wav_output = resynthesize(torch.expm1(SP_output), wav_input, hop_size).cpu() 156 | else: 157 | wav_output = resynthesize(SP_output, wav_input, hop_size).cpu() 158 | 159 | torchaudio.save(os.path.join(args.path_of_output_audio_folder, file.split('/')[-1]), wav_output, 16000) 160 | 161 | 162 | 163 | 164 | estimated_pesq.append(pesq(fs=16000, ref=clean[0].numpy(), deg=wav_output[0].numpy(), mode="wb")) 165 | print('Average PESQ:', np.mean(estimated_pesq)) 166 | 167 | elif config['task'] == "Quality_Estimation": 168 | hop_size = 256 169 | VQVAE = VQVAE_QE(**config['VQVAE_params']).to(device).eval() 170 | VQVAE.load_state_dict(torch.load(args.path_of_model_weights)['model']['VQVAE']) 171 | 172 | file_list = get_filepaths(args.path_of_input_audio_folder) 173 | num = 0 174 | original_VQ = [] 175 | start_time = time.time() 176 | for file in file_list: 177 | speech, fs = torchaudio.load(file) 178 | if fs != 16000: 179 | speech = torchaudio.functional.resample(speech, fs, 16000) 180 | speech = speech.to(device) 181 | 182 | SP_original = stft_magnitude(speech, hop_size=hop_size) 183 | if config['input_transform'] == 'log1p': 184 | SP_original = torch.log1p(SP_original) 185 | 186 | z = VQVAE.CNN_1D_encoder(SP_original) 187 | zq, indices, vqloss, distance = VQVAE.quantizer(z, stochastic=False, update=False) 188 | #SP_output = VQVAE.CNN_1D_decoder(zq) 189 | VQScore_cos_z_original = -cos_loss(z.transpose(2, 1).cpu(), zq.cpu()).numpy() 190 | 191 | original_VQ.append(VQScore_cos_z_original) 192 | 193 | sort_index = np.argsort(original_VQ) # in ascending order 194 | 195 | sorted_VQ = [original_VQ[i] for i in sort_index] 196 | sorted_file_list = [file_list[i] for i in sort_index] 197 | score_dict = {'filename':sorted_file_list, 'VQScore':sorted_VQ} 198 | 199 | df = pd.DataFrame.from_dict(score_dict) 200 | df.to_csv('VQscore.csv', index=False) 201 | end_time = time.time() 202 | 203 | print('Total number of files evaluated:', len(original_VQ)) 204 | print('Average VQScore:', 
np.mean(original_VQ)) 205 | print('The VQScore list (in ascending order) has been saved to ./VQscore.csv') 206 | print('The evaluation took around %.2f min' % ((end_time - start_time) / 60.)) 207 | 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /models/VQVAE_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Szu-Wei Fu 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from .vector_quantize_pytorch import VectorQuantize 9 | 10 | ######## Model for speech enhancement 11 | class CNN_1D_encoder_SE(torch.nn.Module): 12 | def __init__(self, codebook_dim): 13 | super().__init__() 14 | self.activation = torch.nn.LeakyReLU(negative_slope=0.3) 15 | 16 | self.conv_enc1 = torch.nn.Conv1d(in_channels=257, out_channels=200, kernel_size=7, stride=1, padding=3) 17 | self.conv_enc2 = torch.nn.Conv1d(in_channels=200, out_channels=200, kernel_size=7, stride=1, padding=3) 18 | self.conv_enc3 = torch.nn.Conv1d(in_channels=200, out_channels=150, kernel_size=7, stride=1, padding=3) 19 | self.conv_enc4 = torch.nn.Conv1d(in_channels=150, out_channels=150, kernel_size=7, stride=1, padding=3) 20 | self.conv_enc5 = torch.nn.Conv1d(in_channels=150, out_channels=codebook_dim, kernel_size=7, stride=1, padding=3) 21 | self.conv_enc6 = torch.nn.Conv1d(in_channels=codebook_dim, out_channels=codebook_dim, kernel_size=7, stride=1, padding=3) 22 | 23 | encoder_layer = torch.nn.TransformerEncoderLayer(d_model=codebook_dim, nhead=8, dim_feedforward=codebook_dim, dropout=0.4, 24 | activation='gelu', batch_first=True) # batch, seq, feature 25 | self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=2) 26 | 27 | def mean_removal(self, x): 28 | channel_mean = torch.mean(x, dim=-1, keepdim=True) 29 | return x-channel_mean 30 | 31 | def forward(self, x): # x.shape = torch.Size([B, T, 257]) 32 | x = (x.transpose(2, 1)) # x.shape = torch.Size([B, 257, T]) 33 | 34 | x = self.mean_removal(x) 35 | enc1 = self.mean_removal(self.activation(self.conv_enc1(x))) 36 | enc2 = self.mean_removal(self.activation(self.conv_enc2(enc1))) 37 | enc3 = self.mean_removal(self.activation(self.conv_enc3(enc1+enc2))) 38 | enc4 = self.mean_removal(self.activation(self.conv_enc4(enc3))) 39 | enc5 = self.mean_removal(self.activation(self.conv_enc5(enc3+enc4))) 40 | enc6 = self.mean_removal(self.activation(self.conv_enc6(enc5))) 41 | 42 | z = self.transformer_encoder((enc5+enc6).transpose(2, 1)) 43 | z = self.mean_removal(z.transpose(2, 1)) 44 | return z 45 | 46 | class CNN_1D_decoder_SE(torch.nn.Module): 47 | def __init__(self, codebook_dim): 48 | super().__init__() 49 | self.activation = torch.nn.LeakyReLU(negative_slope=0.3) 50 | 51 | self.conv_dec1 = torch.nn.Conv1d(in_channels=codebook_dim, out_channels=codebook_dim, kernel_size=7, stride=1, padding=3) 52 | self.conv_dec2 = torch.nn.Conv1d(in_channels=codebook_dim, out_channels=150, kernel_size=7, stride=1, padding=3) 53 | self.conv_dec3 = torch.nn.Conv1d(in_channels=150, out_channels=150, kernel_size=7, stride=1, padding=3) 54 | self.conv_dec4 = torch.nn.Conv1d(in_channels=150, out_channels=200, kernel_size=7, stride=1, padding=3) 55 | self.conv_dec5 = torch.nn.Conv1d(in_channels=200, out_channels=200, kernel_size=7, stride=1, padding=3) 56 | self.conv_dec6 = torch.nn.Conv1d(in_channels=200, out_channels=257, kernel_size=7, stride=1, padding=3) 57 | 58 | 
decoder_layer = torch.nn.TransformerEncoderLayer(d_model=codebook_dim, nhead=8, dim_feedforward=codebook_dim, dropout=0.4, 59 | activation='gelu', batch_first=True) # batch, seq, feature 60 | self.transformer_decoder = torch.nn.TransformerEncoder(decoder_layer, num_layers=2) 61 | 62 | def forward(self, zq): # zq.shape = torch.Size([B, T, codebook_dim]) 63 | zq = self.transformer_decoder(zq) 64 | 65 | dec1 = self.activation(self.conv_dec1(zq.transpose(2, 1))) 66 | dec2 = self.activation(self.conv_dec2(dec1)) 67 | dec3 = self.activation(self.conv_dec3(dec2)) 68 | dec4 = self.activation(self.conv_dec4(dec3+dec2)) 69 | dec5 = self.activation(self.conv_dec5(dec4)) 70 | out = F.relu(self.conv_dec6(dec5+dec4).transpose(2, 1)) 71 | return out 72 | 73 | class CNN_1D_quantizer_SE(torch.nn.Module): 74 | def __init__(self, codebook_size, codebook_dim, codebook_num, orthogonal_reg_weight, use_cosine_sim, ema_update, learnable_codebook, 75 | stochastic_sample_codes, sample_codebook_temp, straight_through, reinmax, kmeans_init, threshold_ema_dead_code): 76 | super().__init__() 77 | 78 | self.quantizer = VectorQuantize( 79 | dim = codebook_dim, 80 | codebook_size = codebook_size, 81 | use_cosine_sim = use_cosine_sim, 82 | orthogonal_reg_weight = orthogonal_reg_weight, # in paper, they recommended a value of 10 83 | decay = 0.8, # the exponential moving average decay, lower means the dictionary will change faster 84 | commitment_weight = 1, # the weight on the commitment loss 85 | kmeans_init = kmeans_init, # set to True 86 | kmeans_iters = 10, # number of kmeans iterations to calculate the centroids for the codebook on init 87 | heads = codebook_num, 88 | separate_codebook_per_head = True, 89 | ema_update = ema_update, 90 | learnable_codebook = learnable_codebook, 91 | stochastic_sample_codes = stochastic_sample_codes, 92 | sample_codebook_temp = sample_codebook_temp, 93 | straight_through = straight_through, 94 | reinmax = reinmax, 95 | threshold_ema_dead_code = threshold_ema_dead_code, 96 | ) 97 | 98 | def forward(self, z, stochastic, update=True, indices=None): # z.shape = torch.Size([B, codebook_dim, T]) 99 | if indices is None: 100 | zq, indices, vqloss, distance = self.quantizer(z.transpose(2, 1), stochastic, update=update) 101 | return zq, indices, vqloss, distance 102 | else: 103 | zq, cross_entropy_loss = self.quantizer(z.transpose(2, 1), stochastic, update=update, indices=indices) 104 | return zq, cross_entropy_loss 105 | 106 | 107 | ### VQVAE_SE #### 108 | class VQVAE_SE(torch.nn.Module): 109 | def __init__(self, codebook_size, codebook_dim, codebook_num, orthogonal_reg_weight, use_cosine_sim, ema_update, learnable_codebook, 110 | stochastic_sample_codes, sample_codebook_temp, straight_through, reinmax, kmeans_init, threshold_ema_dead_code): 111 | super().__init__() 112 | 113 | self.CNN_1D_encoder = CNN_1D_encoder_SE(codebook_dim) 114 | self.quantizer = CNN_1D_quantizer_SE(codebook_size, codebook_dim, codebook_num, orthogonal_reg_weight, use_cosine_sim, ema_update, learnable_codebook, 115 | stochastic_sample_codes, sample_codebook_temp, straight_through, reinmax, kmeans_init, threshold_ema_dead_code) 116 | self.CNN_1D_decoder = CNN_1D_decoder_SE(codebook_dim) 117 | 118 | 119 | ######## Model for quality estimation 120 | class CNN_1D_encoder_QE(torch.nn.Module): 121 | def __init__(self, codebook_dim): 122 | super().__init__() 123 | self.activation = torch.nn.LeakyReLU(negative_slope=0.3) 124 | 125 | # Normalization layers 126 | self.enc_In0 = torch.nn.InstanceNorm1d(257) 127 | self.enc_In1 = 
torch.nn.InstanceNorm1d(128) 128 | self.enc_In2 = torch.nn.InstanceNorm1d(128) 129 | self.enc_In3 = torch.nn.InstanceNorm1d(64) 130 | self.enc_In4 = torch.nn.InstanceNorm1d(64) 131 | self.enc_In5 = torch.nn.InstanceNorm1d(codebook_dim) 132 | self.enc_In6 = torch.nn.InstanceNorm1d(codebook_dim) 133 | self.enc_In7 = torch.nn.InstanceNorm1d(codebook_dim) 134 | 135 | ## Encoder 136 | self.conv_enc1 = torch.nn.Conv1d(in_channels=257, out_channels=128, kernel_size=7, stride=1, padding=3) 137 | self.conv_enc2 = torch.nn.Conv1d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) 138 | self.conv_enc3 = torch.nn.Conv1d(in_channels=128, out_channels=64, kernel_size=7, stride=1, padding=3) 139 | self.conv_enc4 = torch.nn.Conv1d(in_channels=64, out_channels=64, kernel_size=7, stride=1, padding=3) 140 | self.conv_enc5 = torch.nn.Conv1d(in_channels=64, out_channels=codebook_dim, kernel_size=7, stride=1, padding=3) 141 | self.conv_enc6 = torch.nn.Conv1d(in_channels=codebook_dim, out_channels=codebook_dim, kernel_size=7, stride=1, padding=3) 142 | 143 | def forward(self, x): # x.shape = torch.Size([B, T, 257]) 144 | x = self.enc_In0(x.transpose(2, 1)) # x.shape = torch.Size([B, 257, T]) 145 | 146 | enc1 = self.enc_In1(self.activation(self.conv_enc1(x))) # torch.Size([B, 128, T]) 147 | enc2 = self.enc_In2(self.activation(self.conv_enc2(enc1))) # torch.Size([B, 128, T]) 148 | enc3 = self.enc_In3(self.activation(self.conv_enc3(enc1+enc2))) # torch.Size([B, 64, T]) 149 | enc4 = self.enc_In4(self.activation(self.conv_enc4(enc3))) # torch.Size([B, 64, T]) 150 | enc5 = self.enc_In5(self.activation(self.conv_enc5(enc3+enc4))) # torch.Size([B, 32, T]) 151 | z = self.enc_In6(self.conv_enc6(enc5)) # torch.Size([B, 32, T]) 152 | return z 153 | 154 | class CNN_1D_decoder_QE(torch.nn.Module): 155 | def __init__(self, codebook_dim): 156 | super().__init__() 157 | self.activation = torch.nn.LeakyReLU(negative_slope=0.3) 158 | 159 | self.conv_dec1 = torch.nn.Conv1d(in_channels=codebook_dim, out_channels=codebook_dim, kernel_size=7, stride=1, padding=3) 160 | self.conv_dec2 = torch.nn.Conv1d(in_channels=codebook_dim, out_channels=64, kernel_size=7, stride=1, padding=3) 161 | self.conv_dec3 = torch.nn.Conv1d(in_channels=64, out_channels=64, kernel_size=7, stride=1, padding=3) 162 | self.conv_dec4 = torch.nn.Conv1d(in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3) 163 | self.conv_dec5 = torch.nn.Conv1d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) 164 | self.conv_dec6 = torch.nn.Conv1d(in_channels=128, out_channels=257, kernel_size=7, stride=1, padding=3) 165 | 166 | def forward(self, zq): # x.shape = torch.Size([B, T, 128]) 167 | dec1 = (self.activation(self.conv_dec1(zq.transpose(2, 1)))) 168 | dec2 = (self.activation(self.conv_dec2(dec1))) 169 | dec3 = (self.activation(self.conv_dec3(dec2))) 170 | dec4 = (self.activation(self.conv_dec4(dec3+dec2))) 171 | dec5 = (self.activation(self.conv_dec5(dec4))) 172 | out = F.relu(self.conv_dec6(dec5+dec4).transpose(2, 1)) # torch.Size([B, T, 257]) 173 | return out 174 | 175 | class CNN_1D_quantizer_QE(torch.nn.Module): 176 | def __init__(self, codebook_size, codebook_dim, codebook_num, orthogonal_reg_weight, use_cosine_sim, ema_update, learnable_codebook, 177 | stochastic_sample_codes, sample_codebook_temp, straight_through, reinmax, kmeans_init, threshold_ema_dead_code, 178 | ): 179 | super().__init__() 180 | 181 | self.quantizer = VectorQuantize( 182 | dim = codebook_dim, 183 | codebook_size = codebook_size, 184 | 
use_cosine_sim = use_cosine_sim, 185 | orthogonal_reg_weight = orthogonal_reg_weight, # in paper, they recommended a value of 10 186 | decay = 0.8, # the exponential moving average decay, lower means the dictionary will change faster 187 | commitment_weight = 1, # the weight on the commitment loss 188 | kmeans_init = kmeans_init, # set to True 189 | kmeans_iters = 10, # number of kmeans iterations to calculate the centroids for the codebook on init 190 | heads = codebook_num, 191 | separate_codebook_per_head = True, 192 | ema_update = ema_update, 193 | learnable_codebook = learnable_codebook, 194 | stochastic_sample_codes = stochastic_sample_codes, 195 | sample_codebook_temp = sample_codebook_temp, 196 | straight_through = straight_through, 197 | reinmax = reinmax, 198 | threshold_ema_dead_code = threshold_ema_dead_code 199 | ) 200 | 201 | def forward(self, z, stochastic, update=True, indices=None): # z.shape = torch.Size([B, codebook_dim, T]) 202 | if indices is None: 203 | zq, indices, vqloss, distance = self.quantizer(z.transpose(2, 1), stochastic, update=update) 204 | return zq, indices, vqloss, distance 205 | else: 206 | zq, cross_entropy_loss = self.quantizer(z.transpose(2, 1), stochastic, update=update, indices=indices) 207 | return zq, cross_entropy_loss 208 | 209 | ### VQVAE_QE #### 210 | class VQVAE_QE(torch.nn.Module): 211 | def __init__(self, codebook_size, codebook_dim, codebook_num, orthogonal_reg_weight, use_cosine_sim, ema_update, learnable_codebook, 212 | stochastic_sample_codes, sample_codebook_temp, straight_through, reinmax, kmeans_init, threshold_ema_dead_code, 213 | ): 214 | super().__init__() 215 | 216 | self.CNN_1D_encoder = CNN_1D_encoder_QE(codebook_dim) 217 | self.quantizer = CNN_1D_quantizer_QE(codebook_size, codebook_dim, codebook_num, orthogonal_reg_weight, use_cosine_sim, ema_update, learnable_codebook, 218 | stochastic_sample_codes, sample_codebook_temp, straight_through, reinmax, kmeans_init, threshold_ema_dead_code) 219 | self.CNN_1D_decoder = CNN_1D_decoder_QE(codebook_dim) 220 | -------------------------------------------------------------------------------- /models/vector_quantize_pytorch.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import copy 5 | from torch import nn, einsum 6 | import torch.nn.functional as F 7 | import torch.distributed as distributed 8 | import numpy as np 9 | from torch.optim import Optimizer 10 | from torch.cuda.amp import autocast 11 | 12 | from einops import rearrange, repeat, reduce, pack, unpack 13 | 14 | from typing import Callable 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | def default(val, d): 20 | return val if exists(val) else d 21 | 22 | def noop(*args, **kwargs): 23 | pass 24 | 25 | def identity(t): 26 | return t 27 | 28 | def l2norm(t): 29 | return F.normalize(t, p = 2, dim = -1) 30 | 31 | def log(t, eps = 1e-20): 32 | return torch.log(t.clamp(min = eps)) 33 | 34 | def ema_inplace(old, new, decay): 35 | is_mps = str(old.device).startswith('mps:') 36 | 37 | if not is_mps: 38 | old.lerp_(new, 1 - decay) 39 | else: 40 | old.mul_(decay).add_(new * (1 - decay)) 41 | 42 | def pack_one(t, pattern): 43 | return pack([t], pattern) 44 | 45 | def unpack_one(t, ps, pattern): 46 | return unpack(t, ps, pattern)[0] 47 | 48 | def uniform_init(*shape): 49 | t = torch.empty(shape) 50 | nn.init.kaiming_uniform_(t) 51 | return t 52 | 53 | def gumbel_noise(t): 54 | noise = torch.zeros_like(t).uniform_(0, 1) 55 | return -log(-log(noise)) 56 | 57 | def 
gumbel_sample( 58 | distance, 59 | temperature = 1., 60 | stochastic = False, 61 | straight_through = False, 62 | reinmax = False, 63 | dim = -1, 64 | training = True 65 | ): 66 | dtype, size = distance.dtype, distance.shape[dim] 67 | if stochastic: 68 | ''' 69 | pr = (logits/ temperature).softmax(dim = -1) # ref: https://bobondemon.github.io/2021/08/07/Gumbel-Max-Trick/ torch.Size([1, 24064, 2048]) 70 | ind = torch.zeros(1, pr.shape[1], dtype=int, device='cuda') 71 | for f in range(pr.shape[1]): 72 | ind[0,f] = np.random.choice(pr.shape[-1], size=1, p=pr[0,f,:].cpu().numpy())[0] 73 | ''' 74 | pr = (distance/ temperature).softmax(dim = -1) # ref: https://bobondemon.github.io/2021/08/07/Gumbel-Max-Trick/ torch.Size([1, 24064, 2048]) 75 | logits = torch.log(pr) 76 | sampling_logits = logits + gumbel_noise(logits) 77 | #sampling_logits = (logits / temperature) + gumbel_noise(logits) 78 | else: 79 | sampling_logits = distance 80 | ind = sampling_logits.argmax(dim = dim) # torch.Size([1, 24064]) 81 | one_hot = F.one_hot(ind, size).type(dtype) 82 | 83 | return ind, one_hot 84 | 85 | 86 | def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1): 87 | denom = x.sum(dim = dim, keepdim = True) 88 | return (x + eps) / (denom + n_categories * eps) 89 | 90 | def sample_vectors(samples, num): 91 | num_samples, device = samples.shape[0], samples.device 92 | if num_samples >= num: 93 | indices = torch.randperm(num_samples, device = device)[:num] 94 | else: 95 | indices = torch.randint(0, num_samples, (num,), device = device) 96 | 97 | return samples[indices] 98 | 99 | def batched_sample_vectors(samples, num): 100 | return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0) 101 | 102 | def pad_shape(shape, size, dim = 0): 103 | return [size if i == dim else s for i, s in enumerate(shape)] 104 | 105 | def sample_multinomial(total_count, probs): 106 | device = probs.device 107 | probs = probs.cpu() 108 | 109 | total_count = probs.new_full((), total_count) 110 | remainder = probs.new_ones(()) 111 | sample = torch.empty_like(probs, dtype = torch.long) 112 | 113 | for i, p in enumerate(probs): 114 | s = torch.binomial(total_count, p / remainder) 115 | sample[i] = s 116 | total_count -= s 117 | remainder -= p 118 | 119 | return sample.to(device) 120 | 121 | def all_gather_sizes(x, dim): 122 | size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device) 123 | all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] 124 | distributed.all_gather(all_sizes, size) 125 | return torch.stack(all_sizes) 126 | 127 | def all_gather_variably_sized(x, sizes, dim = 0): 128 | rank = distributed.get_rank() 129 | all_x = [] 130 | 131 | for i, size in enumerate(sizes): 132 | t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) 133 | distributed.broadcast(t, src = i, async_op = True) 134 | all_x.append(t) 135 | 136 | distributed.barrier() 137 | return all_x 138 | 139 | def sample_vectors_distributed(local_samples, num): 140 | local_samples = rearrange(local_samples, '1 ... 
-> ...') 141 | 142 | rank = distributed.get_rank() 143 | all_num_samples = all_gather_sizes(local_samples, dim = 0) 144 | 145 | if rank == 0: 146 | samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) 147 | else: 148 | samples_per_rank = torch.empty_like(all_num_samples) 149 | 150 | distributed.broadcast(samples_per_rank, src = 0) 151 | samples_per_rank = samples_per_rank.tolist() 152 | 153 | local_samples = sample_vectors(local_samples, samples_per_rank[rank]) 154 | all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0) 155 | out = torch.cat(all_samples, dim = 0) 156 | 157 | return rearrange(out, '... -> 1 ...') 158 | 159 | def batched_bincount(x, *, minlength): 160 | batch, dtype, device = x.shape[0], x.dtype, x.device 161 | target = torch.zeros(batch, minlength, dtype = dtype, device = device) 162 | values = torch.ones_like(x) 163 | target.scatter_add_(-1, x, values) 164 | return target 165 | 166 | def kmeans( 167 | samples, 168 | num_clusters, 169 | num_iters = 10, 170 | use_cosine_sim = False, 171 | sample_fn = batched_sample_vectors, 172 | all_reduce_fn = noop 173 | ): 174 | num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device 175 | 176 | means = sample_fn(samples, num_clusters) 177 | 178 | for _ in range(num_iters): 179 | if use_cosine_sim: 180 | dists = samples @ rearrange(means, 'h n d -> h d n') 181 | else: 182 | dists = -torch.cdist(samples, means, p = 2) 183 | 184 | buckets = torch.argmax(dists, dim = -1) 185 | bins = batched_bincount(buckets, minlength = num_clusters) 186 | all_reduce_fn(bins) 187 | 188 | zero_mask = bins == 0 189 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 190 | 191 | new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype) 192 | 193 | new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples) 194 | new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1') 195 | all_reduce_fn(new_means) 196 | 197 | if use_cosine_sim: 198 | new_means = l2norm(new_means) 199 | 200 | means = torch.where( 201 | rearrange(zero_mask, '... -> ... 
1'), 202 | means, 203 | new_means 204 | ) 205 | 206 | return means, bins 207 | 208 | def batched_embedding(indices, embeds): 209 | batch, dim = indices.shape[1], embeds.shape[-1] 210 | indices = repeat(indices, 'h b n -> h b n d', d = dim) 211 | embeds = repeat(embeds, 'h c d -> h b c d', b = batch) 212 | return embeds.gather(2, indices) 213 | 214 | # regularization losses 215 | 216 | def orthogonal_loss_fn(t): 217 | # eq (2) from https://arxiv.org/abs/2112.00384 218 | h, n = t.shape[:2] 219 | normed_codes = l2norm(t) 220 | cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes) 221 | return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n) 222 | 223 | # distance types 224 | 225 | class EuclideanCodebook(nn.Module): 226 | def __init__( 227 | self, 228 | dim, 229 | codebook_size, 230 | num_codebooks = 1, 231 | kmeans_init = False, 232 | kmeans_iters = 10, 233 | sync_kmeans = True, 234 | decay = 0.8, 235 | eps = 1e-5, 236 | threshold_ema_dead_code = 2, 237 | reset_cluster_size = None, 238 | use_ddp = False, 239 | learnable_codebook = False, 240 | gumbel_sample = gumbel_sample, 241 | sample_codebook_temp = 1., 242 | ema_update = True, 243 | affine_param = False, 244 | sync_affine_param = False, 245 | affine_param_batch_decay = 0.99, 246 | affine_param_codebook_decay = 0.9, 247 | ): 248 | super().__init__() 249 | self.transform_input = identity 250 | self.decay = decay 251 | self.ema_update = ema_update 252 | 253 | init_fn = uniform_init if not kmeans_init else torch.zeros 254 | embed = init_fn(num_codebooks, codebook_size, dim) 255 | 256 | self.codebook_size = codebook_size 257 | self.num_codebooks = num_codebooks 258 | 259 | self.kmeans_iters = kmeans_iters 260 | self.eps = eps 261 | self.threshold_ema_dead_code = threshold_ema_dead_code 262 | self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code) 263 | 264 | assert callable(gumbel_sample) 265 | self.gumbel_sample = gumbel_sample 266 | self.sample_codebook_temp = sample_codebook_temp 267 | 268 | assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now' 269 | 270 | self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors 271 | self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop 272 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop 273 | 274 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 275 | self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size)) 276 | self.register_buffer('embed_avg', embed.clone()) 277 | 278 | self.learnable_codebook = learnable_codebook 279 | if learnable_codebook: 280 | self.embed = nn.Parameter(embed) 281 | else: 282 | self.register_buffer('embed', embed) 283 | 284 | # affine related params 285 | 286 | self.affine_param = affine_param 287 | self.sync_affine_param = sync_affine_param 288 | 289 | if not affine_param: 290 | return 291 | 292 | self.affine_param_batch_decay = affine_param_batch_decay 293 | self.affine_param_codebook_decay = affine_param_codebook_decay 294 | 295 | self.register_buffer('batch_mean', None) 296 | self.register_buffer('batch_variance', None) 297 | 298 | self.register_buffer('codebook_mean_needs_init', torch.Tensor([True])) 299 | self.register_buffer('codebook_mean', torch.empty(num_codebooks, 1, dim)) 300 | self.register_buffer('codebook_variance_needs_init', torch.Tensor([True])) 301 | self.register_buffer('codebook_variance', 
torch.empty(num_codebooks, 1, dim)) 302 | 303 | @torch.jit.ignore 304 | def init_embed_(self, data, mask = None): 305 | if self.initted: 306 | return 307 | 308 | if exists(mask): 309 | c = data.shape[0] 310 | data = rearrange(data[mask], '(c n) d -> c n d', c = c) 311 | 312 | embed, cluster_size = kmeans( 313 | data, 314 | self.codebook_size, 315 | self.kmeans_iters, 316 | sample_fn = self.sample_fn, 317 | all_reduce_fn = self.kmeans_all_reduce_fn 318 | ) 319 | 320 | embed_sum = embed * rearrange(cluster_size, '... -> ... 1') 321 | 322 | self.embed.data.copy_(embed) 323 | self.embed_avg.data.copy_(embed_sum) 324 | self.cluster_size.data.copy_(cluster_size) 325 | self.initted.data.copy_(torch.Tensor([True])) 326 | 327 | @torch.jit.ignore 328 | def update_with_decay(self, buffer_name, new_value, decay): 329 | old_value = getattr(self, buffer_name) 330 | 331 | needs_init = getattr(self, buffer_name + "_needs_init", False) 332 | 333 | if needs_init: 334 | self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False])) 335 | 336 | if not exists(old_value) or needs_init: 337 | self.register_buffer(buffer_name, new_value.detach()) 338 | 339 | return 340 | 341 | value = old_value * decay + new_value.detach() * (1 - decay) 342 | self.register_buffer(buffer_name, value) 343 | 344 | @torch.jit.ignore 345 | def update_affine(self, data, embed, mask = None): 346 | assert self.affine_param 347 | 348 | var_fn = partial(torch.var, unbiased = False) 349 | 350 | # calculate codebook mean and variance 351 | 352 | embed = rearrange(embed, 'h ... d -> h (...) d') 353 | 354 | if self.training: 355 | self.update_with_decay('codebook_mean', reduce(embed, 'h n d -> h 1 d', 'mean'), self.affine_param_codebook_decay) 356 | self.update_with_decay('codebook_variance', reduce(embed, 'h n d -> h 1 d', var_fn), self.affine_param_codebook_decay) 357 | 358 | # prepare batch data, which depends on whether it has masking 359 | 360 | data = rearrange(data, 'h ... d -> h (...) 
d') 361 | 362 | if exists(mask): 363 | c = data.shape[0] 364 | data = rearrange(data[mask], '(c n) d -> c n d', c = c) 365 | 366 | # calculate batch mean and variance 367 | 368 | if not self.sync_affine_param: 369 | self.update_with_decay('batch_mean', reduce(data, 'h n d -> h 1 d', 'mean'), self.affine_param_batch_decay) 370 | self.update_with_decay('batch_variance', reduce(data, 'h n d -> h 1 d', var_fn), self.affine_param_batch_decay) 371 | return 372 | 373 | num_vectors, device, dtype = data.shape[-2], data.device, data.dtype 374 | 375 | # number of vectors, for denominator 376 | 377 | num_vectors = torch.tensor([num_vectors], device = device, dtype = dtype) 378 | distributed.all_reduce(num_vectors) 379 | 380 | # calculate distributed mean 381 | 382 | batch_sum = reduce(data, 'h n d -> h 1 d', 'sum') 383 | distributed.all_reduce(batch_sum) 384 | batch_mean = batch_sum / num_vectors 385 | 386 | self.update_with_decay('batch_mean', batch_mean, self.affine_param_batch_decay) 387 | 388 | # calculate distributed variance 389 | 390 | variance_numer = reduce((data - batch_mean) ** 2, 'h n d -> h 1 d', 'sum') 391 | distributed.all_reduce(variance_numer) 392 | batch_variance = variance_numer / num_vectors 393 | 394 | self.update_with_decay('batch_variance', batch_variance, self.affine_param_batch_decay) 395 | ''' 396 | def replace(self, batch_samples, batch_mask): 397 | for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))): 398 | if not torch.any(mask): 399 | continue 400 | 401 | sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item()) 402 | sampled = rearrange(sampled, '1 ... -> ...') 403 | 404 | self.embed.data[ind][mask] = sampled 405 | 406 | self.cluster_size.data[ind][mask] = self.reset_cluster_size 407 | self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size 408 | 409 | def expire_codes_(self, batch_samples): 410 | if self.threshold_ema_dead_code == 0: 411 | return 412 | 413 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 414 | 415 | if not torch.any(expired_codes): 416 | return 417 | 418 | batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d') 419 | self.replace(batch_samples, batch_mask = expired_codes) 420 | 421 | ''' 422 | @autocast(enabled = False) 423 | def forward( 424 | self, 425 | x, 426 | stochastic, 427 | update, 428 | sample_codebook_temp = None, 429 | mask = None 430 | ): 431 | needs_codebook_dim = x.ndim < 4 432 | sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp) 433 | 434 | x = x.float() 435 | 436 | if needs_codebook_dim: 437 | x = rearrange(x, '... 
-> 1 ...') 438 | 439 | dtype = x.dtype 440 | flatten, ps = pack_one(x, 'h * d') 441 | 442 | if exists(mask): 443 | mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1])) 444 | 445 | self.init_embed_(flatten, mask = mask) 446 | 447 | if self.affine_param: 448 | self.update_affine(flatten, self.embed, mask = mask) 449 | 450 | embed = self.embed if self.learnable_codebook else self.embed.detach() 451 | 452 | if self.affine_param: 453 | codebook_std = self.codebook_variance.clamp(min = 1e-5).sqrt() 454 | batch_std = self.batch_variance.clamp(min = 1e-5).sqrt() 455 | embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean 456 | 457 | dist = -torch.cdist(flatten, embed, p = 2) 458 | 459 | embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training, stochastic=stochastic) 460 | 461 | embed_ind = unpack_one(embed_ind, ps, 'h *') 462 | if self.training: 463 | unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c') 464 | quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed) 465 | else: 466 | quantize = batched_embedding(embed_ind, embed) 467 | 468 | if self.training and self.ema_update and update: 469 | 470 | if self.affine_param: 471 | flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean 472 | 473 | if exists(mask): 474 | embed_onehot[~mask] = 0. 475 | 476 | cluster_size = embed_onehot.sum(dim = 1) 477 | 478 | self.all_reduce_fn(cluster_size) 479 | ema_inplace(self.cluster_size.data, cluster_size, self.decay) 480 | 481 | embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) 482 | self.all_reduce_fn(embed_sum.contiguous()) 483 | ema_inplace(self.embed_avg.data, embed_sum, self.decay) 484 | 485 | cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True) 486 | 487 | embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1') 488 | self.embed.data.copy_(embed_normalized) 489 | #self.expire_codes_(x) 490 | 491 | 492 | if needs_codebook_dim: 493 | quantize, embed_ind = map(lambda t: rearrange(t, '1 ... 
-> ...'), (quantize, embed_ind)) 494 | 495 | dist = unpack_one(dist, ps, 'h * d') 496 | 497 | return quantize, embed_ind, dist 498 | 499 | class CosineSimCodebook(nn.Module): 500 | def __init__( 501 | self, 502 | dim, 503 | codebook_size, 504 | num_codebooks = 1, 505 | kmeans_init = False, 506 | kmeans_iters = 10, 507 | sync_kmeans = True, 508 | decay = 0.8, 509 | eps = 1e-5, 510 | threshold_ema_dead_code = 2, 511 | reset_cluster_size = None, 512 | use_ddp = False, 513 | learnable_codebook = False, 514 | gumbel_sample = gumbel_sample, 515 | sample_codebook_temp = 1., 516 | ema_update = True, 517 | ): 518 | super().__init__() 519 | self.transform_input = l2norm 520 | 521 | self.ema_update = ema_update 522 | self.decay = decay 523 | 524 | if not kmeans_init: 525 | embed = l2norm(uniform_init(num_codebooks, codebook_size, dim)) 526 | else: 527 | embed = torch.zeros(num_codebooks, codebook_size, dim) 528 | 529 | self.codebook_size = codebook_size 530 | self.num_codebooks = num_codebooks 531 | 532 | self.kmeans_iters = kmeans_iters 533 | self.eps = eps 534 | self.threshold_ema_dead_code = threshold_ema_dead_code 535 | self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code) 536 | 537 | assert callable(gumbel_sample) 538 | self.gumbel_sample = gumbel_sample 539 | self.sample_codebook_temp = sample_codebook_temp 540 | 541 | self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors 542 | self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop 543 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop 544 | 545 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 546 | self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size)) 547 | self.register_buffer('embed_avg', embed.clone()) 548 | 549 | self.learnable_codebook = learnable_codebook 550 | if learnable_codebook: 551 | self.embed = nn.Parameter(embed) 552 | else: 553 | self.register_buffer('embed', embed) 554 | 555 | @torch.jit.ignore 556 | def init_embed_(self, data, mask = None): 557 | if self.initted: 558 | return 559 | 560 | if exists(mask): 561 | c = data.shape[0] 562 | data = rearrange(data[mask], '(c n) d -> c n d', c = c) 563 | 564 | embed, cluster_size = kmeans( 565 | data, 566 | self.codebook_size, 567 | self.kmeans_iters, 568 | use_cosine_sim = True, 569 | sample_fn = self.sample_fn, 570 | all_reduce_fn = self.kmeans_all_reduce_fn 571 | ) 572 | 573 | embed_sum = embed * rearrange(cluster_size, '... -> ... 1') 574 | 575 | self.embed.data.copy_(embed) 576 | self.embed_avg.data.copy_(embed_sum) 577 | self.cluster_size.data.copy_(cluster_size) 578 | self.initted.data.copy_(torch.Tensor([True])) 579 | 580 | def replace(self, batch_samples, batch_mask): 581 | batch_samples = l2norm(batch_samples) 582 | 583 | for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))): 584 | if not torch.any(mask): 585 | continue 586 | 587 | sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item()) 588 | sampled = rearrange(sampled, '1 ... 
-> ...') 589 | 590 | self.embed.data[ind][mask] = sampled 591 | self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size 592 | self.cluster_size.data[ind][mask] = self.reset_cluster_size 593 | 594 | def expire_codes_(self, batch_samples): 595 | if self.threshold_ema_dead_code == 0: 596 | return 597 | 598 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 599 | 600 | if not torch.any(expired_codes): 601 | return 602 | 603 | batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d') 604 | self.replace(batch_samples, batch_mask = expired_codes) 605 | 606 | @autocast(enabled = False) 607 | def forward( 608 | self, 609 | x, 610 | stochastic, 611 | update, 612 | sample_codebook_temp = None, 613 | mask = None 614 | ): 615 | needs_codebook_dim = x.ndim < 4 616 | sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp) 617 | 618 | x = x.float() 619 | 620 | if needs_codebook_dim: 621 | x = rearrange(x, '... -> 1 ...') 622 | 623 | dtype = x.dtype 624 | 625 | flatten, ps = pack_one(x, 'h * d') 626 | 627 | if exists(mask): 628 | mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1])) 629 | 630 | self.init_embed_(flatten, mask = mask) 631 | 632 | embed = self.embed if self.learnable_codebook else self.embed.detach() 633 | 634 | dist = einsum('h n d, h c d -> h n c', flatten, embed) 635 | 636 | embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training, stochastic=stochastic) 637 | embed_ind = unpack_one(embed_ind, ps, 'h *') 638 | 639 | if self.training: 640 | unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c') 641 | quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed) 642 | else: 643 | quantize = batched_embedding(embed_ind, embed) 644 | 645 | if self.training and self.ema_update and update: 646 | if exists(mask): 647 | embed_onehot[~mask] = 0. 648 | 649 | bins = embed_onehot.sum(dim = 1) 650 | self.all_reduce_fn(bins) 651 | 652 | ema_inplace(self.cluster_size.data, bins, self.decay) 653 | 654 | embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) 655 | self.all_reduce_fn(embed_sum.contiguous()) 656 | ema_inplace(self.embed_avg.data, embed_sum, self.decay) 657 | 658 | cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True) 659 | 660 | embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1') 661 | embed_normalized = l2norm(embed_normalized) 662 | 663 | self.embed.data.copy_(l2norm(embed_normalized)) 664 | self.expire_codes_(x) 665 | ''' 666 | embed = self.embed if self.learnable_codebook else self.embed.detach() 667 | 668 | dist = einsum('h n d, h c d -> h n c', flatten, embed) 669 | 670 | embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training, stochastic=stochastic) 671 | embed_ind = unpack_one(embed_ind, ps, 'h *') 672 | 673 | unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c') 674 | quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed) 675 | ''' 676 | if needs_codebook_dim: 677 | quantize, embed_ind = map(lambda t: rearrange(t, '1 ... 
-> ...'), (quantize, embed_ind)) 678 | 679 | dist = unpack_one(dist, ps, 'h * d') 680 | return quantize, embed_ind, dist 681 | 682 | # main class 683 | 684 | class VectorQuantize(nn.Module): 685 | def __init__( 686 | self, 687 | dim, 688 | codebook_size, 689 | codebook_dim = None, 690 | heads = 1, 691 | separate_codebook_per_head = False, 692 | decay = 0.8, 693 | eps = 1e-5, 694 | kmeans_init = False, 695 | kmeans_iters = 10, 696 | sync_kmeans = True, 697 | use_cosine_sim = False, 698 | threshold_ema_dead_code = 0, 699 | channel_last = True, 700 | accept_image_fmap = False, 701 | commitment_weight = 1., 702 | commitment_use_cross_entropy_loss = False, 703 | orthogonal_reg_weight = 0., 704 | orthogonal_reg_active_codes_only = False, 705 | orthogonal_reg_max_codes = None, 706 | stochastic_sample_codes = False, 707 | sample_codebook_temp = 1., 708 | straight_through = False, 709 | reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all 710 | sync_codebook = False, 711 | sync_affine_param = False, 712 | ema_update = True, 713 | learnable_codebook = False, 714 | in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook 715 | affine_param = False, 716 | affine_param_batch_decay = 0.99, 717 | affine_param_codebook_decay = 0.9, 718 | sync_update_v = 0., # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf 719 | ): 720 | super().__init__() 721 | self.dim = dim 722 | self.heads = heads 723 | self.separate_codebook_per_head = separate_codebook_per_head 724 | codebook_dim = default(codebook_dim, dim) 725 | codebook_input_dim = codebook_dim * heads 726 | 727 | requires_projection = codebook_input_dim != dim 728 | self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() 729 | self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() 730 | 731 | self.eps = eps 732 | self.commitment_weight = commitment_weight 733 | self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss 734 | 735 | self.learnable_codebook = learnable_codebook 736 | 737 | has_codebook_orthogonal_loss = orthogonal_reg_weight > 0 738 | self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss 739 | self.orthogonal_reg_weight = orthogonal_reg_weight 740 | self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only 741 | self.orthogonal_reg_max_codes = orthogonal_reg_max_codes 742 | 743 | assert not (ema_update and learnable_codebook), 'learnable codebook not compatible with EMA update' 744 | 745 | assert 0 <= sync_update_v <= 1. 746 | assert not (sync_update_v > 0. 
and not learnable_codebook), 'learnable codebook must be turned on' 747 | 748 | self.sync_update_v = sync_update_v 749 | 750 | codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook 751 | 752 | gumbel_sample_fn = partial( 753 | gumbel_sample, 754 | stochastic = stochastic_sample_codes, 755 | reinmax = reinmax, 756 | straight_through = straight_through 757 | ) 758 | 759 | codebook_kwargs = dict( 760 | dim = codebook_dim, 761 | num_codebooks = heads if separate_codebook_per_head else 1, 762 | codebook_size = codebook_size, 763 | kmeans_init = kmeans_init, 764 | kmeans_iters = kmeans_iters, 765 | sync_kmeans = sync_kmeans, 766 | decay = decay, 767 | eps = eps, 768 | threshold_ema_dead_code = threshold_ema_dead_code, 769 | use_ddp = sync_codebook, 770 | learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook, 771 | sample_codebook_temp = sample_codebook_temp, 772 | gumbel_sample = gumbel_sample_fn, 773 | ema_update = ema_update, 774 | ) 775 | 776 | if affine_param: 777 | assert not use_cosine_sim, 'affine param is only compatible with euclidean codebook' 778 | codebook_kwargs = dict( 779 | **codebook_kwargs, 780 | affine_param = True, 781 | sync_affine_param = sync_affine_param, 782 | affine_param_batch_decay = affine_param_batch_decay, 783 | affine_param_codebook_decay = affine_param_codebook_decay, 784 | ) 785 | 786 | self._codebook = codebook_class(**codebook_kwargs) 787 | 788 | self.in_place_codebook_optimizer = in_place_codebook_optimizer(self._codebook.parameters()) if exists(in_place_codebook_optimizer) else None 789 | 790 | self.codebook_size = codebook_size 791 | 792 | self.accept_image_fmap = accept_image_fmap 793 | self.channel_last = channel_last 794 | 795 | @property 796 | def codebook(self): 797 | codebook = self._codebook.embed 798 | if self.separate_codebook_per_head: 799 | return codebook 800 | 801 | return rearrange(codebook, '1 ... -> ...') 802 | 803 | def get_codes_from_indices(self, indices): 804 | codebook = self.codebook 805 | is_multiheaded = codebook.ndim > 2 806 | 807 | if not is_multiheaded: 808 | codes = codebook[indices] 809 | return rearrange(codes, '... h d -> ... 
(h d)') 810 | 811 | indices, ps = pack_one(indices, 'b * h') 812 | indices = rearrange(indices, 'b n h -> b h n') 813 | 814 | indices = repeat(indices, 'b h n -> b h n d', d = codebook.shape[-1]) 815 | codebook = repeat(codebook, 'h n d -> b h n d', b = indices.shape[0]) 816 | 817 | codes = codebook.gather(2, indices) 818 | codes = rearrange(codes, 'b h n d -> b n (h d)') 819 | codes = unpack_one(codes, ps, 'b * d') 820 | return codes 821 | 822 | def forward( 823 | self, 824 | x, 825 | stochastic, 826 | update = True, 827 | indices = None, 828 | mask = None, 829 | sample_codebook_temp = None 830 | ): 831 | orig_input = x 832 | 833 | only_one = x.ndim == 2 834 | 835 | if only_one: 836 | assert not exists(mask) 837 | x = rearrange(x, 'b d -> b 1 d') 838 | 839 | shape, device, heads, is_multiheaded, codebook_size, return_loss = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size, exists(indices) 840 | 841 | need_transpose = not self.channel_last and not self.accept_image_fmap 842 | should_inplace_optimize = exists(self.in_place_codebook_optimizer) 843 | 844 | # rearrange inputs 845 | 846 | if self.accept_image_fmap: 847 | height, width = x.shape[-2:] 848 | x = rearrange(x, 'b c h w -> b (h w) c') 849 | 850 | if need_transpose: 851 | x = rearrange(x, 'b d n -> b n d') 852 | 853 | # project input 854 | 855 | x = self.project_in(x) 856 | 857 | # handle multi-headed separate codebooks 858 | 859 | if is_multiheaded: 860 | ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d' 861 | x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads) 862 | 863 | # l2norm for cosine sim, otherwise identity 864 | 865 | x = self._codebook.transform_input(x) 866 | 867 | # codebook forward kwargs 868 | 869 | codebook_forward_kwargs = dict( 870 | stochastic = stochastic, 871 | update = update, 872 | sample_codebook_temp = sample_codebook_temp, 873 | mask = mask 874 | ) 875 | 876 | # quantize 877 | 878 | quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) 879 | 880 | # one step in-place update 881 | 882 | if should_inplace_optimize and self.training: 883 | if exists(mask): 884 | loss = F.mse_loss(quantize, x.detach(), reduction = 'none') 885 | 886 | loss_mask = mask 887 | if is_multiheaded: 888 | loss_mask = repeat(mask, 'b n -> c (b h) n', c = loss.shape[0], h = loss.shape[1] // mask.shape[0]) 889 | 890 | loss = loss[loss_mask].mean() 891 | 892 | else: 893 | loss = F.mse_loss(quantize, x.detach()) 894 | 895 | loss.backward() 896 | self.in_place_codebook_optimizer.step() 897 | self.in_place_codebook_optimizer.zero_grad() 898 | 899 | # quantize again 900 | 901 | quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) 902 | 903 | if self.training: 904 | # determine code to use for commitment loss 905 | maybe_detach = torch.detach if not self.learnable_codebook else identity 906 | 907 | commit_quantize = maybe_detach(quantize) 908 | 909 | # straight through 910 | 911 | quantize = x + (quantize - x).detach() 912 | 913 | if self.sync_update_v > 0.: 914 | # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf 915 | quantize = quantize + self.sync_update_v * (quantize - quantize.detach()) 916 | 917 | # function for calculating cross entropy loss to distance matrix 918 | # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss 919 | 920 | def calculate_ce_loss(codes): 921 | if not is_multiheaded: 922 | dist_einops_eq = '1 b n l -> b l n' 923 | elif 
self.separate_codebook_per_head: 924 | dist_einops_eq = 'c b n l -> b l n c' 925 | else: 926 | dist_einops_eq = '1 (b h) n l -> b l n h' 927 | 928 | ce_loss = F.cross_entropy( 929 | rearrange(distances, dist_einops_eq, b = shape[0]), 930 | codes, 931 | ignore_index = -1 932 | ) 933 | 934 | return ce_loss 935 | 936 | # if returning cross entropy loss on codes that were passed in 937 | 938 | if return_loss: 939 | return quantize, calculate_ce_loss(indices) 940 | 941 | # transform embedding indices 942 | 943 | if is_multiheaded: 944 | if self.separate_codebook_per_head: 945 | embed_ind = rearrange(embed_ind, 'h b n -> b n h', h = heads) 946 | else: 947 | embed_ind = rearrange(embed_ind, '1 (b h) n -> b n h', h = heads) 948 | 949 | if self.accept_image_fmap: 950 | embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width) 951 | 952 | if only_one: 953 | embed_ind = rearrange(embed_ind, 'b 1 -> b') 954 | 955 | # aggregate loss 956 | 957 | loss = torch.tensor([0.], device = device, requires_grad = self.training) 958 | 959 | if self.training: 960 | if self.commitment_weight > 0: 961 | if self.commitment_use_cross_entropy_loss: 962 | if exists(mask): 963 | ce_loss_mask = mask 964 | if is_multiheaded: 965 | ce_loss_mask = repeat(ce_loss_mask, 'b n -> b n h', h = heads) 966 | 967 | embed_ind.masked_fill_(~ce_loss_mask, -1) 968 | 969 | commit_loss = calculate_ce_loss(embed_ind) 970 | else: 971 | if exists(mask): 972 | # with variable lengthed sequences 973 | commit_loss = F.mse_loss(commit_quantize, x, reduction = 'none') 974 | 975 | loss_mask = mask 976 | if is_multiheaded: 977 | loss_mask = repeat(loss_mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0]) 978 | 979 | commit_loss = commit_loss[loss_mask].mean() 980 | else: 981 | commit_loss = F.mse_loss(commit_quantize, x) 982 | 983 | loss = loss + commit_loss * self.commitment_weight 984 | 985 | if self.has_codebook_orthogonal_loss: 986 | codebook = self._codebook.embed 987 | 988 | # only calculate orthogonal loss for the activated codes for this batch 989 | 990 | if self.orthogonal_reg_active_codes_only: 991 | assert not (is_multiheaded and self.separate_codebook_per_head), 'orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet' 992 | unique_code_ids = torch.unique(embed_ind) 993 | codebook = codebook[:, unique_code_ids] 994 | 995 | num_codes = codebook.shape[-2] 996 | 997 | if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: 998 | rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes] 999 | codebook = codebook[:, rand_ids] 1000 | 1001 | orthogonal_reg_loss = orthogonal_loss_fn(codebook) 1002 | loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight 1003 | 1004 | # handle multi-headed quantized embeddings 1005 | 1006 | if is_multiheaded: 1007 | if self.separate_codebook_per_head: 1008 | quantize = rearrange(quantize, 'h b n d -> b n (h d)', h = heads) 1009 | else: 1010 | quantize = rearrange(quantize, '1 (b h) n d -> b n (h d)', h = heads) 1011 | 1012 | # project out 1013 | 1014 | quantize = self.project_out(quantize) 1015 | 1016 | # rearrange quantized embeddings 1017 | 1018 | if need_transpose: 1019 | quantize = rearrange(quantize, 'b n d -> b d n') 1020 | 1021 | if self.accept_image_fmap: 1022 | quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width) 1023 | 1024 | if only_one: 1025 | quantize = rearrange(quantize, 'b 1 d -> b d') 
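Only the final masking branch of `forward` remains below. For orientation, here is a minimal usage sketch of this modified `VectorQuantize` (illustrative, not from the repo; unlike upstream vector-quantize-pytorch, this fork takes an explicit `stochastic` flag and returns the distance matrix as a fourth output):

```python
# Illustrative usage sketch only; shapes assume the default single-head, channel-last setup.
import torch

vq = VectorQuantize(dim=128, codebook_size=4096)   # defaults: 1 head, EMA codebook updates
x = torch.randn(8, 100, 128)                       # (batch, frames, feature)
quantize, indices, loss, distances = vq(x, stochastic=False, update=True)
# quantize:  (8, 100, 128) straight-through quantized features
# indices:   (8, 100)      selected codebook entries
# loss:      commitment (+ optional orthogonal) regularization; nonzero only in training mode
# distances: negative L2 distances to all 4096 codes, used for the cross-entropy losses
```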
1026 | 1027 | # if masking, only return quantized for where mask has True 1028 | 1029 | if exists(mask): 1030 | quantize = torch.where( 1031 | rearrange(mask, '... -> ... 1'), 1032 | quantize, 1033 | orig_input 1034 | ) 1035 | 1036 | return quantize, embed_ind, loss, distances 1037 | -------------------------------------------------------------------------------- /noisy_p232_005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonSWFu/VQscore/ed7189fe6b78bd1f9e8edcfe141d604ca88b0ffe/noisy_p232_005.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.7.0 2 | librosa==0.9.2 3 | matplotlib==3.7.0 4 | numpy==1.23.5 5 | onnxruntime_gpu==1.9.0 6 | pandas==1.5.3 7 | pesq==0.0.4 8 | PyYAML==6.0 9 | Requests==2.31.0 10 | scipy==1.12.0 11 | SoundFile==0.10.3.post1 12 | tensorboardX==2.6 13 | torch==2.0.0 14 | torchaudio==2.0.1 15 | tqdm==4.64.1 16 | -------------------------------------------------------------------------------- /trainVQVAE.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import logging 6 | import argparse 7 | import soundfile as sf 8 | from torch.utils.data import DataLoader 9 | 10 | from dataloader.dataset import SingleDataset 11 | from models.VQVAE_models import VQVAE_SE, VQVAE_QE 12 | from trainer.autoencoder import Trainer as TrainerAutoEncoder 13 | from bin.train import Train 14 | 15 | 16 | class TrainMain(Train): 17 | def __init__(self, args,): 18 | super(TrainMain, self).__init__(args=args,) 19 | self.train_mode = self.config.get('train_mode', 'autoencoder') 20 | self.data_path = self.config['data']['path'] 21 | 22 | 23 | def initialize_data_loader(self): 24 | logging.info("Loading datasets...") 25 | 26 | if self.train_mode in ['autoencoder']: 27 | train_set = self._audio('clean_train') 28 | valid_set = self._audio('clean_valid') 29 | # collater = CollaterAudio(batch_length=self.config['batch_length']) 30 | collater = None 31 | self.Trainer = TrainerAutoEncoder 32 | else: 33 | raise NotImplementedError(f"Train mode: {self.train_mode} is not supported!") 34 | 35 | logging.info(f"The number of training files = {len(train_set)}.") 36 | logging.info(f"The number of validation files = {len(valid_set)}.") 37 | dataset = {'train': train_set, 'dev': valid_set} 38 | self._data_loader(dataset, collater) 39 | 40 | def define_model(self): 41 | if self.config['task'] == "Speech_Enhancement": 42 | VQVAE = VQVAE_SE( 43 | **self.config['VQVAE_params']).to(self.device) 44 | elif self.config['task'] == "Quality_Estimation": 45 | VQVAE = VQVAE_QE( 46 | **self.config['VQVAE_params']).to(self.device) 47 | 48 | self.model = {"VQVAE": VQVAE} 49 | self._define_optimizer_scheduler() 50 | 51 | def define_trainer(self): 52 | self._show_setting() 53 | trainer_parameters = {} 54 | trainer_parameters['steps'] = 0 55 | trainer_parameters['epochs'] = 0 56 | trainer_parameters['data_loader'] = self.data_loader 57 | trainer_parameters['model'] = self.model 58 | trainer_parameters['criterion'] = self.criterion 59 | trainer_parameters['optimizer'] = self.optimizer 60 | trainer_parameters['scheduler'] = self.scheduler 61 | trainer_parameters['config'] = self.config 62 | trainer_parameters['device'] = self.device 63 | self.trainer = self.Trainer(**trainer_parameters) 64 | 65 | 66 | def 
_data_loader(self, dataset, collater):
67 |         self.data_loader = {
68 |             'train': DataLoader(
69 |                 dataset=dataset['train'],
70 |                 shuffle=True,
71 |                 collate_fn=collater,
72 |                 batch_size=self.config['batch_size'],
73 |                 num_workers=self.config['num_workers'],
74 |                 pin_memory=self.config['pin_memory'],
75 |             ),
76 |             'dev': DataLoader(
77 |                 dataset=dataset['dev'],
78 |                 shuffle=False,
79 |                 collate_fn=collater,
80 |                 batch_size=self.config['batch_size'],
81 |                 num_workers=self.config['num_workers'],
82 |                 pin_memory=self.config['pin_memory'],
83 |             ),
84 |         }
85 | 
86 | 
87 |     def _audio(self, subset, subset_num=-1, return_utt_id=False):
88 |         audio_dir = os.path.join(
89 |             self.data_path, self.config['data']['subset'][subset])
90 |         params = {
91 |             'data_path': '/',
92 |             'files': audio_dir,
93 |             'query': "*.wav",
94 |             'load_fn': sf.read,
95 |             'return_utt_id': return_utt_id,
96 |             'subset_num': subset_num,
97 |             'batch_length': self.config['batch_length'],
98 |         }
99 |         return SingleDataset(**params)
100 | 
101 | 
102 | def main():
103 |     parser = argparse.ArgumentParser()
104 |     parser.add_argument('-c', '--config', type=str, required=True)
105 |     parser.add_argument("--tag", type=str, required=True)
106 |     parser.add_argument("--exp_root", type=str, default="exp")
107 |     parser.add_argument("--resume", default="", type=str, nargs="?",
108 |                         help='checkpoint file path to resume training. (default="")',
109 |                         )
110 |     parser.add_argument('--seed', default=1337, type=int)
111 |     parser.add_argument('--disable_cudnn', choices=('True', 'False'), default='False', help='Disable cuDNN')
112 |     args = parser.parse_args()
113 | 
114 |     # initialize the training driver
115 |     train_main = TrainMain(args=args)
116 | 
117 |     # build the datasets and data loaders
118 |     train_main.initialize_data_loader()
119 | 
120 |     # define models, optimizers, and schedulers
121 |     train_main.define_model()
122 | 
123 |     # define criterions
124 |     # train_main.define_criterion()
125 | 
126 |     # define trainer
127 |     train_main.define_trainer()
128 | 
129 |     # model initialization
130 |     train_main.initialize_model()
131 | 
132 |     # run training loop
133 |     train_main.run()
134 | 
135 | if __name__ == "__main__":
136 |     main()
137 | 
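For reference, training is launched through this script with a config from the repo's config/ directory; the --tag argument names the run (exactly how the tag is consumed is defined in bin/train.py, not shown here).
Ex: python trainVQVAE.py -c config/SE_cbook_4096_1_128_lr_1m5_1m5_github.yaml --tag SE_cbook_4096_1_128_lr_1m5_1m5_github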
--------------------------------------------------------------------------------
/trainer/autoencoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | # Author: Szu-Wei Fu
5 | 
6 | """Training flow of VQVAE."""
7 | 
8 | import matplotlib
9 | # Force matplotlib to not use any Xwindows backend.
10 | matplotlib.use('Agg')
11 | import matplotlib.pyplot as plt
12 | 
13 | import os
14 | import sys
15 | import copy
16 | import torch
17 | import torchaudio
18 | import numpy as np
19 | import shutil
20 | import pandas as pd
21 | import torch.nn.functional as F
22 | 
23 | from scipy.stats import entropy
24 | from scipy.stats import pearsonr, spearmanr
25 | 
26 | from trainer.eval_dataset import load_IUB, load_Tencent, load_DNS1, load_DNS3, load_VCTK_validSet, load_VCTK_testSet
27 | from trainer.trainerAE import TrainerAE
28 | 
29 | from pesq import pesq
30 | sys.path.append('./DNSMOS')
31 | from dnsmos_local import ComputeScore
32 | compute_score = ComputeScore('./DNSMOS/ONNX_models/sig_bak_ovr.onnx', './DNSMOS/ONNX_models/model_v8.onnx')
33 | 
34 | class SpectralConvergenceLoss(torch.nn.Module):
35 |     """Spectral convergence loss module."""
36 | 
37 |     def __init__(self):
38 |         """Initialize the spectral convergence loss module."""
39 |         super(SpectralConvergenceLoss, self).__init__()
40 | 
41 |     def forward(self, x_mag, y_mag):
42 |         """Calculate forward propagation.
43 | 
44 |         Args:
45 |             x_mag (Tensor): Magnitude spectrogram of the predicted signal (B, #frames, #freq_bins).
46 |             y_mag (Tensor): Magnitude spectrogram of the groundtruth signal (B, #frames, #freq_bins).
47 | 
48 |         Returns:
49 |             Tensor: Spectral convergence loss value (here computed with L1 norms).
50 | 
51 |         """
52 |         return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
53 | 
54 | 
55 | def resynthesize(enhanced_mag, noisy_inputs, hop_size):
56 |     """Resynthesize waveforms from enhanced magnitudes, using the noisy phase.
57 |     Arguments
58 |     ---------
59 |     enhanced_mag : torch.Tensor
60 |         Predicted spectral magnitude; should be three-dimensional.
61 |     noisy_inputs : torch.Tensor
62 |         The noisy waveforms before any processing, used to extract the phase.
63 |     Returns
64 |     -------
65 |     enhanced_wav : torch.Tensor
66 |         The waveforms resynthesized from the enhanced magnitudes and the noisy phase.
67 |     """
68 | 
69 |     # Extract noisy phase from inputs
70 | 
71 |     noisy_feats = torch.stft(noisy_inputs, n_fft=512, hop_length=hop_size, win_length=512,
72 |                              window=torch.hamming_window(512).to('cuda'),
73 |                              center=True,
74 |                              pad_mode="constant",
75 |                              onesided=True,
76 |                              return_complex=False).transpose(2, 1)
77 | 
78 |     noisy_phase = torch.atan2(noisy_feats[:, :, :, 1], noisy_feats[:, :, :, 0])[:, 0:enhanced_mag.shape[1], :]
79 | 
80 |     # Combine with enhanced magnitude
81 |     predictions = torch.mul(
82 |         torch.unsqueeze(enhanced_mag, -1),
83 |         torch.cat(
84 |             (
85 |                 torch.unsqueeze(torch.cos(noisy_phase), -1),
86 |                 torch.unsqueeze(torch.sin(noisy_phase), -1),
87 |             ),
88 |             -1,
89 |         ),
90 |     ).permute(0, 2, 1, 3)
91 | 
92 |     # istft requires a complex-valued input
93 |     complex_predictions = torch.complex(predictions[..., 0], predictions[..., 1])
94 |     pred_wavs = torch.istft(input=complex_predictions, n_fft=512, hop_length=hop_size, win_length=512,
95 |                             window=torch.hamming_window(512).to('cuda'),
96 |                             center=True,
97 |                             onesided=True,
98 |                             length=noisy_inputs.shape[1])
99 | 
100 |     return pred_wavs
101 | 
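The phase-borrowing step above is the standard magnitude-plus-noisy-phase reconstruction: with $|\hat{S}|$ the enhanced magnitude and $\angle Y$ the phase of the noisy STFT, the enhanced waveform is

$$\hat{s} = \mathrm{iSTFT}\!\left(|\hat{S}| \cdot e^{j\angle Y}\right),$$

so only the magnitude is modified, and any residual phase artifacts come from the noisy recording.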
102 | def get_filepaths(directory):
103 |     """
104 |     Collect the full paths of all files in a directory tree by walking
105 |     the tree top-down with os.walk. For each directory in the tree rooted
106 |     at `directory` (including `directory` itself), os.walk yields a
107 |     3-tuple (dirpath, dirnames, filenames).
108 |     """
109 |     file_paths = []  # List which will store all of the full filepaths.
110 |     # Walk the tree.
111 |     for root, directories, files in os.walk(directory):
112 |         for filename in files:
113 |             # Join the two strings in order to form the full filepath.
114 |             filepath = os.path.join(root, filename)
115 |             file_paths.append(filepath)  # Add it to the list.
116 |     return file_paths
117 | 
118 | eps = 1e-5  # small constant to avoid division by zero
119 | class Trainer(TrainerAE):
120 |     def __init__(
121 |         self,
122 |         steps,
123 |         epochs,
124 |         data_loader,
125 |         model,
126 |         criterion,
127 |         optimizer,
128 |         scheduler,
129 |         config,
130 |         device=torch.device("cpu"),
131 |     ):
132 |         super(Trainer, self).__init__(
133 |             steps=steps,
134 |             epochs=epochs,
135 |             data_loader=data_loader,
136 |             model=model,
137 |             optimizer=optimizer,
138 |             scheduler=scheduler,
139 |             config=config,
140 |             device=device,
141 |         )
142 | 
143 |         self.VQVAE_start = config.get('start_steps', {}).get('VQVAE', 0)
144 |         self.spectral_convergence_loss = SpectralConvergenceLoss()
145 |         self.exp_dir = './exp/'+self.config["name"]+'/'
146 | 
147 | 
148 |         ## load evaluation data
149 |         self.vctk_test = load_VCTK_testSet('./VCTK_noisy_testSet_with_scores.pickle')  # the validation set for QE, and one of the test sets for SE
150 |         if self.config['task'] == 'Speech_Enhancement':
151 |             self.hop_size = 128
152 |             self.vctk_Clean_path = self.config['vctk_Clean_path']
153 |             self.dns1 = load_DNS1(self.config['DNS1_test'])
154 |             self.dns3 = load_DNS3(self.config['DNS3_test'])
155 |             self.vctk_valid = load_VCTK_validSet('./VCTK_noisy_validationSet.pickle')  # as in MetricGAN-U, noisy data of speakers p226 and p287 serve as the validation set
156 |             self.highest_pesq, self.highest_dnsmos_ovr = 0, 0
157 |         elif self.config['task'] == 'Quality_Estimation':
158 |             self.hop_size = 256
159 |             self.tencent = load_Tencent(pickle_path='./Tencent_ind2.pickle', number_test_set=250)
160 |             self.iub = load_IUB(pickle_path='./IUB_ind2.pickle', number_test_set=200)
161 |             self.highest_dnsmos_ovr_CC = 0
162 | 
163 |         # Copy the code to the current exp directory so that modifications can be traced
164 |         shutil.copyfile('./trainer/autoencoder.py', self.exp_dir+'autoencoder.py')
165 |         shutil.copyfile('./trainer/trainerAE.py', self.exp_dir+'trainerAE.py')
166 |         shutil.copyfile('./models/VQVAE_models.py', self.exp_dir+'VQVAE_models.py')
167 |         shutil.copyfile('./config/'+self.config["name"]+'.yaml', self.exp_dir+self.config["name"]+'.yaml')
168 | 
169 |     def stft_magnitude(self, x, hop_size, fft_size=512, win_length=512):
170 |         if x.is_cuda:
171 |             x_stft = torch.stft(
172 |                 x, fft_size, hop_size, win_length, window=torch.hann_window(win_length).to('cuda'), return_complex=False
173 |             )
174 |         else:
175 |             x_stft = torch.stft(
176 |                 x, fft_size, hop_size, win_length, window=torch.hann_window(win_length), return_complex=False
177 |             )
178 |         real = x_stft[..., 0]
179 |         imag = x_stft[..., 1]
180 | 
181 |         return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
182 | 
183 |     def scatter_plot(self, Real_scores, predicted_scores, image_name):
184 |         # Plot predicted vs. real scores, with LCC/SRCC/MSE in the title
185 |         plt.scatter(Real_scores, predicted_scores, s=14)
186 |         plt.xlabel('Real_scores')
187 |         plt.ylabel('Predicted_scores')
188 | 
189 |         LCC = pearsonr(Real_scores, predicted_scores)[0]
190 |         SRCC = spearmanr(Real_scores, predicted_scores)[0]
191 |         MSE = np.mean((np.asarray(Real_scores)-np.asarray(predicted_scores))**2)
192 | 
193 |         plt.title('LCC= %f, SRCC= %f, MSE= %f' % (LCC, SRCC, MSE))
194 |         plt.show()
195 |         plt.savefig(self.exp_dir+image_name, dpi=150)
196 |         plt.clf()
197 | 
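The `cos_loss` method defined next returns the negative mean frame-wise cosine similarity between two (B, T, F) representations; the same quantity, negated, is later reused as the VQScore(cos) quality score. A standalone sketch of the computation (illustrative, not repo code):

```python
import torch

def frame_cosine_loss(A: torch.Tensor, B: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # A, B: (batch, frames, bins). Normalize each frame to unit L2 norm, then
    # average the per-frame dot products and negate (lower = more similar).
    A_n = A / (torch.norm(A, p=2, dim=-1, keepdim=True) + eps)
    B_n = B / (torch.norm(B, p=2, dim=-1, keepdim=True) + eps)
    return -torch.mean(torch.sum(A_n * B_n, dim=-1))
```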
198 |     def cos_loss(self, SP_noisy, SP_y_noisy):
199 |         SP_noisy_norm = torch.norm(SP_noisy, p=2, dim=-1, keepdim=True)+eps
200 |         SP_y_noisy_norm = torch.norm(SP_y_noisy, p=2, dim=-1, keepdim=True)+eps
201 |         Cos_frame = torch.sum(SP_noisy/SP_noisy_norm * SP_y_noisy/SP_y_noisy_norm, dim=-1)  # torch.Size([B, T])
202 | 
203 |         return -torch.mean(Cos_frame)
204 | 
205 | 
206 |     def _train_step(self, batch):
207 |         """Single step of training."""
208 |         mode = 'train'
209 |         x_c = batch[0]  # torch.Size([64, 1, 48000])
210 |         scalar = torch.rand((x_c.shape[0], 1, 1))*1.95+0.05  # random gain in [0.05, 2.0) for volume diversity
211 |         sample_max = torch.max(abs(x_c), dim=-1, keepdim=True)[0]
212 |         scalar2 = torch.clamp(scalar, max=1/sample_max)  # cap the gain so the scaled waveform stays within [-1, 1]
213 | 
214 |         x = (x_c*scalar2).to(self.device)
215 | 
216 |         # check VQVAE step
217 |         if self.steps < self.VQVAE_start:
218 |             self.VQVAE_train = False
219 |         else:
220 |             self.VQVAE_train = True
221 | 
222 |         if self.steps == self.config['AT_training_start_steps']:
223 |             VQVAE_optimizer_class = getattr(
224 |                 torch.optim,
225 |                 self.config['VQVAE_optimizer_type'])
226 | 
227 |             self.optimizer = {
228 |                 'VQVAE': VQVAE_optimizer_class(self.model['VQVAE'].parameters(), **self.config['VQVAE_AT_optimizer_params'])}
229 | 
230 |             # start AT from the checkpoint with the highest PESQ on the validation set
231 |             print('Load highest pesq model for AT training...')
232 |             self.model["VQVAE"].load_state_dict(torch.load(self.exp_dir + 'checkpoint-pesq='+ str(self.highest_pesq)[0:5]+'.pkl')['model']['VQVAE'])
233 |             self.teacher_model = copy.deepcopy(self.model["VQVAE"])
234 | 
235 |         if self.config['task'] == 'Speech_Enhancement' and self.steps >= self.config['AT_training_start_steps']:
236 |             # Freeze the teacher model
237 |             for parameter in self.teacher_model.parameters():
238 |                 parameter.requires_grad = False
239 |             self.teacher_model.quantizer.eval()
240 |             self.teacher_model.eval()
241 | 
242 |             # Freeze the quantizer of the student model
243 |             for parameter in self.model["VQVAE"].quantizer.parameters():
244 |                 parameter.requires_grad = False
245 |             self.model["VQVAE"].quantizer.eval()
246 | 
247 | 
248 |         #######################
249 |         #        VQVAE        #
250 |         #######################
251 |         if self.VQVAE_train:
252 |             gen_loss = 0.0  # initialize VQVAE loss
253 |             x = torch.squeeze(x)
254 |             if self.config['task'] == 'Speech_Enhancement' and self.steps >= self.config['AT_training_start_steps']:
255 |                 x.requires_grad = True  # enable gradients on the input waveform (needed for the adversarial attack)
256 |             X = self.stft_magnitude(x, hop_size=self.hop_size)  # shape = torch.Size([B, T, F])
257 |             if self.config['input_transform'] == 'log1p':
258 |                 X = torch.log1p(X)
259 | 
260 |             # Sec. 2.5, step 2: self-distillation with adversarial training (see the paper)
261 |             if self.config['task'] == 'Speech_Enhancement' and self.steps >= self.config['AT_training_start_steps']:
262 |                 with torch.no_grad():
263 |                     z_teacher = self.teacher_model.CNN_1D_encoder(X)
264 |                     teacher_zq, indices_teacher, _, _ = self.teacher_model.quantizer(z_teacher, stochastic=False, update=False)
265 | 
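Step 2-1 below is essentially a fast-gradient-style attack in the waveform domain, with the step size expressed as a random perturbation-to-signal energy ratio. Distilled into standalone form (illustrative, not repo code; min_eps/max_eps correspond to the adv_min_epsilon/adv_max_epsilon config entries):

```python
import torch

def snr_scaled_attack(x, input_grad, min_eps, max_eps):
    # Rescale the input gradient so the perturbation-to-signal energy ratio
    # lies in [min_eps, max_eps], then take one ascent step on the attack loss.
    power_ratio = torch.norm(input_grad, p=2, dim=-1, keepdim=True) / torch.norm(x, p=2, dim=-1, keepdim=True)
    random_ratio = (max_eps - min_eps) * torch.rand(x.shape[0], 1, device=x.device) + min_eps
    return x + (random_ratio / power_ratio) * input_grad
```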
266 |                 ###### Step 2-1 adversarial attack start ######
267 |                 self.model["VQVAE"].eval()
268 |                 z_att = self.model["VQVAE"].CNN_1D_encoder(X)
269 |                 zq_att, attack_cross_entropy_loss = self.model["VQVAE"].quantizer(z_att, stochastic=False, update=False, indices=indices_teacher.detach())
270 | 
271 |                 # Zero all existing gradients
272 |                 self.model["VQVAE"].CNN_1D_encoder.zero_grad()
273 | 
274 |                 # Calculate gradients of the model in the backward pass
275 |                 attack_cross_entropy_loss.backward()
276 | 
277 |                 # Get the input gradient
278 |                 adversarial_noise = x.grad.data
279 | 
280 |                 power_ratio = torch.norm(adversarial_noise, p=2, dim=-1, keepdim=True)/torch.norm(x, p=2, dim=-1, keepdim=True)
281 |                 random_ratio = (self.config['adv_max_epsilon']-self.config['adv_min_epsilon'])*torch.rand((X.shape[0], 1), device='cuda')+self.config['adv_min_epsilon']
282 | 
283 |                 perturbed_wav = x + (random_ratio/power_ratio) * adversarial_noise  # gradient ascent on the attack loss
284 |                 perturbed_X = self.stft_magnitude(perturbed_wav, hop_size=self.hop_size)
285 |                 if self.config['input_transform'] == 'log1p':
286 |                     perturbed_X = torch.log1p(perturbed_X)
287 |                 ###### adversarial attack end ######
288 | 
289 |                 ###### Step 2-2 adversarial training start ######
290 |                 self.model["VQVAE"].train()
291 |                 perturbed_z = self.model["VQVAE"].CNN_1D_encoder(perturbed_X.detach())
292 |                 perturbed_zq, noisy_cross_entropy_loss = self.model["VQVAE"].quantizer(perturbed_z, stochastic=False, update=False, indices=indices_teacher.detach())
293 |                 perturbed_Y_ = self.model["VQVAE"].CNN_1D_decoder(perturbed_zq.detach())
294 | 
295 |                 z = self.model["VQVAE"].CNN_1D_encoder(X.detach())
296 |                 zq, clean_cross_entropy_loss = self.model["VQVAE"].quantizer(z, stochastic=False, update=False, indices=indices_teacher.detach())
297 |                 Y_ = self.model["VQVAE"].CNN_1D_decoder(zq.detach())
298 | 
299 |                 # save the attacked audio for listening
300 |                 if self.steps % 1000 == 0:
301 |                     torchaudio.save(self.exp_dir + 'original.wav', x[0:1, :].cpu(), 16000)
302 |                     torchaudio.save(self.exp_dir + 'attacked.wav', perturbed_wav[0:1, :].cpu(), 16000)
303 |                     print([attack_cross_entropy_loss.item(), noisy_cross_entropy_loss.item()])
304 | 
305 |                 # cross-entropy loss: Eq.
7 in the paper 306 | ce_loss = noisy_cross_entropy_loss + clean_cross_entropy_loss 307 | ce_loss *= self.config["lambda_ce_loss"] 308 | gen_loss += ce_loss 309 | self.total_train_loss["train/noisy_cross_entropy_loss"] += noisy_cross_entropy_loss 310 | self.total_train_loss["train/clean_cross_entropy_loss"] += clean_cross_entropy_loss 311 | 312 | # reconstruction loss 313 | noisy_SP_loss = self.config["lambda_stft_loss"] * self.spectral_convergence_loss(perturbed_Y_, X.detach()) 314 | clean_SP_loss = self.config["lambda_stft_loss"] * self.spectral_convergence_loss(Y_, X.detach()) 315 | gen_loss += (noisy_SP_loss + clean_SP_loss) 316 | self.total_train_loss["train/noisy_SP_loss"] += noisy_SP_loss 317 | self.total_train_loss["train/clean_SP_loss"] += clean_SP_loss 318 | 319 | # Normal VQVAE training 320 | else: 321 | z = self.model["VQVAE"].CNN_1D_encoder(X) 322 | zq, indices, vqloss, distance = self.model["VQVAE"].quantizer(z, stochastic=False, update=True) 323 | Y_ = self.model["VQVAE"].CNN_1D_decoder(zq) 324 | 325 | vqloss *= self.config["lambda_vq_loss"] 326 | gen_loss += vqloss 327 | self.total_train_loss["train/vqloss"] += vqloss.item() 328 | 329 | main_loss = self.config["lambda_stft_loss"] * (self.cos_loss(X, Y_) if self.config['cos_loss'] else self.spectral_convergence_loss(Y_, X)) 330 | gen_loss += main_loss 331 | self.total_train_loss["train/main_loss"] += main_loss.item() 332 | 333 | # update VQVAE 334 | self._record_loss('VQVAE_loss', gen_loss, mode=mode) 335 | self._update_VQVAE(gen_loss) 336 | # update counts 337 | self.steps += 1 338 | self.tqdm.update(1) 339 | self._check_train_finish() 340 | 341 | 342 | @torch.no_grad() 343 | def _eval_step(self, batch): 344 | """Single step of evaluation.""" 345 | mode = 'eval' 346 | x_c = batch[0] #[:,:,0:batch[1].min().item()] # torch.Size([B, 1, 48000]) 347 | x = x_c.to(self.device) 348 | 349 | # initialize VQVAE loss 350 | gen_loss = 0.0 351 | x = torch.squeeze(x) 352 | X = self.stft_magnitude(x, hop_size=self.hop_size) 353 | if self.config['input_transform'] == 'log1p': 354 | X = torch.log1p(X) 355 | 356 | z = self.model["VQVAE"].CNN_1D_encoder(X) 357 | zq, indices, vqloss, distance = self.model["VQVAE"].quantizer(z, stochastic=False, update=False) 358 | Y_ = self.model["VQVAE"].CNN_1D_decoder(zq) 359 | 360 | # vq_loss 361 | self.total_eval_loss["eval/vqloss"] += F.mse_loss(zq, z.transpose(2, 1)) 362 | 363 | # metric loss 364 | SP_loss = self.config["lambda_stft_loss"] * (self.cos_loss(X, Y_) if self.config['cos_loss'] else self.spectral_convergence_loss(Y_, X)) 365 | gen_loss += SP_loss 366 | self.total_eval_loss["eval/SP_loss"] += SP_loss.item() 367 | 368 | if self.config['task'] == 'Speech_Enhancement': 369 | if self.config['input_transform'] == 'log1p': 370 | Y_ = torch.expm1(Y_) 371 | y_ = resynthesize(Y_, x, self.hop_size) 372 | waveform_loss = F.l1_loss(y_, x) 373 | self.total_eval_loss["eval/waveform_loss"] += waveform_loss.item() 374 | 375 | self._record_loss('VQVAE_loss', gen_loss, mode=mode) 376 | 377 | 378 | @torch.no_grad() 379 | def run_VQVAE(self, wav_input): 380 | wav_input = wav_input.to(self.device) 381 | SP_input = self.stft_magnitude(wav_input, hop_size=self.hop_size) 382 | if self.config['input_transform'] == 'log1p': 383 | SP_input = torch.log1p(SP_input) 384 | 385 | z = self.model["VQVAE"].CNN_1D_encoder(SP_input) 386 | zq, indices, vqloss, distance = self.model["VQVAE"].quantizer(z, stochastic=False, update=False) 387 | SP_output = self.model["VQVAE"].CNN_1D_decoder(zq) 388 | 389 | if 
self.config['input_transform'] == 'log1p': 390 | wav_output = resynthesize(torch.expm1(SP_output), wav_input, self.hop_size) 391 | else: 392 | wav_output = resynthesize(SP_output, wav_input, self.hop_size) 393 | 394 | return SP_input.cpu(), SP_output.cpu(), z.transpose(2, 1).cpu(), zq.cpu(), wav_output.cpu(), indices 395 | 396 | @torch.no_grad() 397 | def VQScore_Evaluation(self, data_dict, mos_list, dataset, dataset_sub_name): 398 | VQScore_l2_x, VQScore_cos_x = [], [] 399 | VQScore_l2_z, VQScore_cos_z = [], [] 400 | 401 | whole_name = dataset + '_' + dataset_sub_name # ex: IUB_cosine 402 | #if not os.path.exists(self.exp_dir + whole_name): 403 | # os.mkdir(self.exp_dir + whole_name) 404 | 405 | for file in data_dict: 406 | input_wav = data_dict[file] 407 | SP_input, SP_output, zT, zqT, wav_output, indices = self.run_VQVAE(input_wav) 408 | 409 | ###### Input_output error 410 | Square_diff, Square_input = torch.square(SP_input-SP_output), torch.square(SP_input) 411 | 412 | VQScore_l2_x.append(torch.mean(Square_diff / (torch.mean(Square_input, dim=-1, keepdim=True)+eps) ).numpy()) 413 | VQScore_cos_x.append(-self.cos_loss(SP_input, SP_output).numpy()) 414 | 415 | ##### Quantization error 416 | Square_z_diff, Square_z_input = torch.square(zT-zqT), torch.square(zT) 417 | 418 | VQScore_l2_z.append(torch.mean(Square_z_diff / (torch.mean(Square_z_input, dim=-1, keepdim=True)+eps) ).numpy()) 419 | VQScore_cos_z.append(-self.cos_loss(zT, zqT).numpy()) 420 | 421 | # torchaudio.save(self.exp_dir + whole_name + '/' + self.config["name"] + '_' + whole_name + '_'+ file.split('/')[-1], wav_output, 16000) 422 | 423 | ###### Record_CC: Input_output error 424 | self._record_loss(dataset_sub_name + '_VQScore_l2_x_pearsonr', pearsonr(VQScore_l2_x, mos_list)[0], mode=dataset) 425 | self._record_loss(dataset_sub_name + '_VQScore_cos_x_pearsonr', pearsonr(VQScore_cos_x, mos_list)[0], mode=dataset) 426 | self.scatter_plot(mos_list, VQScore_l2_x, whole_name + 'VQScore_l2_x.png') 427 | self.scatter_plot(mos_list, VQScore_cos_x, whole_name + 'VQScore_cos_x.png') 428 | 429 | ###### Record_CC: Quantization error 430 | self._record_loss(dataset_sub_name + '_VQScore_l2_z_pearsonr', pearsonr(VQScore_l2_z, mos_list)[0], mode=dataset+'_z') 431 | self._record_loss(dataset_sub_name + '_VQScore_cos_z_pearsonr', pearsonr(VQScore_cos_z, mos_list)[0], mode=dataset+'_z') 432 | self.scatter_plot(mos_list, VQScore_l2_z, whole_name + 'VQScore_l2_z.png') 433 | self.scatter_plot(mos_list, VQScore_cos_z, whole_name + 'VQScore_cos_z.png') 434 | 435 | df = pd.DataFrame({'file': list(data_dict.keys()),'Real mos': mos_list, 436 | 'VQScore_l2_x': VQScore_l2_x, 'VQScore_cos_x': VQScore_cos_x, 437 | 'VQScore_l2_z': VQScore_l2_z, 'VQScore_cos_z': VQScore_cos_z}) 438 | df.to_csv(self.exp_dir + '/' + whole_name + '.csv') 439 | 440 | @torch.no_grad() 441 | def _eval_IUB(self): 442 | print('_eval_IUB.........') 443 | self.VQScore_Evaluation(self.iub.IUB_cosine_data_dict, self.iub.IUB_cosine_mos, 'IUB', 'cosine') 444 | self.VQScore_Evaluation(self.iub.IUB_voices_data_dict, self.iub.IUB_voices_mos, 'IUB', 'voices') 445 | 446 | @torch.no_grad() 447 | def _eval_Tencent(self): 448 | print('_eval_Tencent.........') 449 | self.VQScore_Evaluation(self.tencent.Tencent_wR_data_dict, self.tencent.Tencent_wR_mos, 'Tencent', 'wR') 450 | self.VQScore_Evaluation(self.tencent.Tencent_woR_data_dict, self.tencent.Tencent_woR_mos, 'Tencent', 'woR') 451 | 452 | @torch.no_grad() 453 | def _eval_DNS1_test(self): 454 | print('_eval_DNS1_test.........') 455 | if not 
os.path.exists(self.exp_dir + '_DNS1'): 456 | os.mkdir(self.exp_dir + '_DNS1') 457 | 458 | Real_enhanced_dnsmos_p835 = [] 459 | for file in self.dns1.DNS1_Real_dict: 460 | noisy = self.dns1.DNS1_Real_dict[file] 461 | SP_noisy, SP_y_noisy, zT, zqT, y_noisy, noisy_indices = self.run_VQVAE(noisy) 462 | Real_enhanced_dnsmos_p835.append(compute_score(y_noisy[0].numpy(), 16000, is_personalized_MOS=False, is_normalized=True, is_p808=False)) 463 | torchaudio.save(self.exp_dir + '_DNS1' + '/' + self.config["name"] + '_'+ file.split('/')[-1], y_noisy, 16000) 464 | self._record_loss('Real_enhanced_dnsmos_sig', np.mean([i['SIG'] for i in Real_enhanced_dnsmos_p835]), mode='dns1') 465 | self._record_loss('Real_enhanced_dnsmos_bak', np.mean([i['BAK'] for i in Real_enhanced_dnsmos_p835]), mode='dns1') 466 | self._record_loss('Real_enhanced_dnsmos_ovr', np.mean([i['OVRL'] for i in Real_enhanced_dnsmos_p835]), mode='dns1') 467 | 468 | Noreverb_enhanced_dnsmos_p835 = [] 469 | for file in self.dns1.DNS1_Noreverb_dict: 470 | noisy = self.dns1.DNS1_Noreverb_dict[file] 471 | SP_noisy, SP_y_noisy, zT, zqT, y_noisy, noisy_indices = self.run_VQVAE(noisy) 472 | Noreverb_enhanced_dnsmos_p835.append(compute_score(y_noisy[0].numpy(), 16000, is_personalized_MOS=False, is_normalized=True, is_p808=False)) 473 | torchaudio.save(self.exp_dir + '_DNS1' + '/' + self.config["name"] + '_'+ file.split('/')[-1], y_noisy, 16000) 474 | self._record_loss('Noreverb_enhanced_dnsmos_sig', np.mean([i['SIG'] for i in Noreverb_enhanced_dnsmos_p835]), mode='dns1') 475 | self._record_loss('Noreverb_enhanced_dnsmos_bak', np.mean([i['BAK'] for i in Noreverb_enhanced_dnsmos_p835]), mode='dns1') 476 | self._record_loss('Noreverb_enhanced_dnsmos_ovr', np.mean([i['OVRL'] for i in Noreverb_enhanced_dnsmos_p835] ), mode='dns1') 477 | 478 | Reverb_enhanced_dnsmos_p835 = [] 479 | for file in self.dns1.DNS1_Reverb_dict: 480 | noisy = self.dns1.DNS1_Reverb_dict[file] 481 | SP_noisy, SP_y_noisy, zT, zqT, y_noisy, noisy_indices = self.run_VQVAE(noisy) 482 | Reverb_enhanced_dnsmos_p835.append(compute_score(y_noisy[0].numpy(), 16000, is_personalized_MOS=False, is_normalized=True, is_p808=False)) 483 | torchaudio.save(self.exp_dir + '_DNS1' + '/' + self.config["name"] + '_'+ file.split('/')[-1], y_noisy, 16000) 484 | self._record_loss('Reverb_enhanced_dnsmos_sig', np.mean([i['SIG'] for i in Reverb_enhanced_dnsmos_p835]), mode='dns1') 485 | self._record_loss('Reverb_enhanced_dnsmos_bak', np.mean([i['BAK'] for i in Reverb_enhanced_dnsmos_p835]), mode='dns1') 486 | self._record_loss('Reverb_enhanced_dnsmos_ovr', np.mean([i['OVRL'] for i in Reverb_enhanced_dnsmos_p835] ), mode='dns1') 487 | 488 | @torch.no_grad() 489 | def _eval_DNS3_test(self): 490 | print('_eval_DNS3_test.........') 491 | if not os.path.exists(self.exp_dir + '_DNS3'): 492 | os.mkdir(self.exp_dir + '_DNS3') 493 | 494 | nonenglish_synthetic_p835 = [] 495 | for file in self.dns3.DNS3_nonenglish_synthetic_dict: 496 | noisy = self.dns3.DNS3_nonenglish_synthetic_dict[file] 497 | SP_noisy, SP_y_noisy, zT, zqT, y_noisy, noisy_indices = self.run_VQVAE(noisy) 498 | nonenglish_synthetic_p835.append(compute_score(y_noisy[0].numpy(), 16000, is_personalized_MOS=False, is_normalized=True, is_p808=False)) 499 | torchaudio.save(self.exp_dir + '_DNS3' + '/' + self.config["name"] + '_'+ file.split('/')[-1], y_noisy, 16000) 500 | self._record_loss('nonenglish_synthetic_sig', np.mean([i['SIG'] for i in nonenglish_synthetic_p835]), mode='dns3') 501 | self._record_loss('nonenglish_synthetic_bak', np.mean([i['BAK'] 
for i in nonenglish_synthetic_p835]), mode='dns3') 502 | self._record_loss('nonenglish_synthetic_ovr', np.mean([i['OVRL'] for i in nonenglish_synthetic_p835]), mode='dns3') 503 | 504 | stationary_p835 = [] 505 | for file in self.dns3.DNS3_stationary_dict: 506 | noisy = self.dns3.DNS3_stationary_dict[file] 507 | SP_noisy, SP_y_noisy, zT, zqT, y_noisy, noisy_indices = self.run_VQVAE(noisy) 508 | stationary_p835.append(compute_score(y_noisy[0].numpy(), 16000, is_personalized_MOS=False, is_normalized=True, is_p808=False)) 509 | torchaudio.save(self.exp_dir + '_DNS3' + '/' + self.config["name"] + '_'+ file.split('/')[-1], y_noisy, 16000) 510 | self._record_loss('stationary_sig', np.mean([i['SIG'] for i in stationary_p835]), mode='dns3') 511 | self._record_loss('stationary_bak', np.mean([i['BAK'] for i in stationary_p835]), mode='dns3') 512 | self._record_loss('stationary_ovr', np.mean([i['OVRL'] for i in stationary_p835]), mode='dns3') 513 | 514 | ms_realrec_nonenglish_p835 = [] 515 | for file in self.dns3.DNS3_ms_realrec_nonenglish_dict: 516 | noisy = self.dns3.DNS3_ms_realrec_nonenglish_dict[file] 517 | SP_noisy, SP_y_noisy, zT, zqT, y_noisy, noisy_indices = self.run_VQVAE(noisy) 518 | ms_realrec_nonenglish_p835.append(compute_score(y_noisy[0].numpy(), 16000, is_personalized_MOS=False, is_normalized=True, is_p808=False)) 519 | torchaudio.save(self.exp_dir + '_DNS3' + '/' + self.config["name"] + '_'+ file.split('/')[-1], y_noisy, 16000) 520 | self._record_loss('ms_realrec_nonenglish_sig', np.mean([i['SIG'] for i in ms_realrec_nonenglish_p835]), mode='dns3') 521 | self._record_loss('ms_realrec_nonenglish_bak', np.mean([i['BAK'] for i in ms_realrec_nonenglish_p835]), mode='dns3') 522 | self._record_loss('ms_realrec_nonenglish_ovr', np.mean([i['OVRL'] for i in ms_realrec_nonenglish_p835]), mode='dns3') 523 | 524 | ms_realrec_p835 = [] 525 | for file in self.dns3.DNS3_ms_realrec_dict: 526 | noisy = self.dns3.DNS3_ms_realrec_dict[file] 527 | SP_noisy, SP_y_noisy, zT, zqT, y_noisy, noisy_indices = self.run_VQVAE(noisy) 528 | ms_realrec_p835.append(compute_score(y_noisy[0].numpy(), 16000, is_personalized_MOS=False, is_normalized=True, is_p808=False)) 529 | torchaudio.save(self.exp_dir + '_DNS3' + '/' + self.config["name"] + '_'+ file.split('/')[-1], y_noisy, 16000) 530 | self._record_loss('ms_realrec_sig', np.mean([i['SIG'] for i in ms_realrec_p835]), mode='dns3') 531 | self._record_loss('ms_realrec_bak', np.mean([i['BAK'] for i in ms_realrec_p835]), mode='dns3') 532 | self._record_loss('ms_realrec_ovr', np.mean([i['OVRL'] for i in ms_realrec_p835]), mode='dns3') 533 | 534 | 535 | @torch.no_grad() 536 | def _eval_vctk_ValidSet(self): 537 | print('_eval_vctk_ValidSet.........') 538 | enhanced_dnsmos_p835, enhanced_pesq = [], [] 539 | for file in self.vctk_valid.VCTK_data_dict: 540 | ########## Noisy ########## 541 | noisy = self.vctk_valid.VCTK_data_dict[file] 542 | SP_noisy, SP_y_noisy, zT, zqT, y_noisy, noisy_indices = self.run_VQVAE(noisy) 543 | 544 | ########## Clean ########## 545 | clean, fs = torchaudio.load(file.replace('noisy','clean')) 546 | SP_clean, SP_y_clean, zT, zqT, y_clean, clean_indices = self.run_VQVAE(clean) 547 | 548 | ## objective Metrics 549 | enhanced_dnsmos_p835.append(compute_score(y_noisy[0].numpy(), 16000, is_personalized_MOS=False, is_normalized=True, is_p808=False)) 550 | enhanced_pesq.append(pesq(fs=16000, ref=clean[0].numpy(), deg=y_noisy[0].numpy(), mode="wb")) 551 | 552 | ######## Objective Metrics (DNSMOS and PESQ) 553 | # noisy enhanced results 554 | 
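The checkpoint bookkeeping below repeats one delete-the-old, save-the-new pattern for each tracked metric (PESQ and DNSMOS OVRL, with and without the `_AT` suffix). A factored sketch of that pattern (illustrative, not repo code):

```python
import os

def save_if_best(trainer, value, best, prefix, suffix=''):
    # Keep a single best checkpoint per metric: remove the previous best, save the new one.
    if value <= best:
        return best
    old = os.path.join(trainer.config["outdir"], 'checkpoint-%s=%s%s.pkl' % (prefix, str(best)[0:5], suffix))
    if os.path.isfile(old):
        os.remove(old)
    trainer.save_checkpoint(os.path.join(trainer.config["outdir"], 'checkpoint-%s=%s%s.pkl' % (prefix, str(value)[0:5], suffix)))
    return value
```

With this helper, each branch below would reduce to a call such as `self.highest_pesq = save_if_best(self, np.mean(enhanced_pesq), self.highest_pesq, 'pesq')`.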
enhanced_dnsmos_sig = [i['SIG'] for i in enhanced_dnsmos_p835] 555 | enhanced_dnsmos_bak = [i['BAK'] for i in enhanced_dnsmos_p835] 556 | enhanced_dnsmos_ovr = [i['OVRL'] for i in enhanced_dnsmos_p835] 557 | self._record_loss('enhanced_dnsmos_sig', np.mean(enhanced_dnsmos_sig), mode='vctk_valid') 558 | self._record_loss('enhanced_dnsmos_bak', np.mean(enhanced_dnsmos_bak), mode='vctk_valid') 559 | self._record_loss('enhanced_dnsmos_ovr', np.mean(enhanced_dnsmos_ovr), mode='vctk_valid') 560 | self._record_loss('enhanced_pesq', np.mean(enhanced_pesq), mode='vctk_valid') 561 | 562 | # Save SE model checkpoints 563 | if self.steps <= self.config['AT_training_start_steps']: 564 | if np.mean(enhanced_pesq) > self.highest_pesq: 565 | if os.path.isfile(os.path.join(self.config["outdir"], 'checkpoint-pesq='+ str(self.highest_pesq)[0:5]+'.pkl')): 566 | os.remove(os.path.join(self.config["outdir"], 'checkpoint-pesq='+ str(self.highest_pesq)[0:5]+'.pkl')) 567 | self.highest_pesq = np.mean(enhanced_pesq) 568 | self.save_checkpoint(os.path.join(self.config["outdir"], 'checkpoint-pesq='+ str(self.highest_pesq)[0:5]+'.pkl')) 569 | 570 | if np.mean(enhanced_dnsmos_ovr) > self.highest_dnsmos_ovr: 571 | if os.path.isfile(os.path.join(self.config["outdir"], 'checkpoint-dnsmos_ovr='+ str(self.highest_dnsmos_ovr)[0:5]+'.pkl')): 572 | os.remove(os.path.join(self.config["outdir"], 'checkpoint-dnsmos_ovr='+ str(self.highest_dnsmos_ovr)[0:5]+'.pkl')) 573 | self.highest_dnsmos_ovr = np.mean(enhanced_dnsmos_ovr) 574 | self.save_checkpoint(os.path.join(self.config["outdir"], 'checkpoint-dnsmos_ovr='+ str(self.highest_dnsmos_ovr)[0:5]+'.pkl')) 575 | else: 576 | if np.mean(enhanced_pesq) > self.highest_pesq: 577 | if os.path.isfile(os.path.join(self.config["outdir"], 'checkpoint-pesq='+ str(self.highest_pesq)[0:5]+'_AT.pkl')): 578 | os.remove(os.path.join(self.config["outdir"], 'checkpoint-pesq='+ str(self.highest_pesq)[0:5]+'_AT.pkl')) 579 | self.highest_pesq = np.mean(enhanced_pesq) 580 | self.save_checkpoint(os.path.join(self.config["outdir"], 'checkpoint-pesq='+ str(self.highest_pesq)[0:5]+'_AT.pkl')) 581 | 582 | if np.mean(enhanced_dnsmos_ovr) > self.highest_dnsmos_ovr: 583 | if os.path.isfile(os.path.join(self.config["outdir"], 'checkpoint-dnsmos_ovr='+ str(self.highest_dnsmos_ovr)[0:5]+'_AT.pkl')): 584 | os.remove(os.path.join(self.config["outdir"], 'checkpoint-dnsmos_ovr='+ str(self.highest_dnsmos_ovr)[0:5]+'_AT.pkl')) 585 | self.highest_dnsmos_ovr = np.mean(enhanced_dnsmos_ovr) 586 | self.save_checkpoint(os.path.join(self.config["outdir"], 'checkpoint-dnsmos_ovr='+ str(self.highest_dnsmos_ovr)[0:5]+'_AT.pkl')) 587 | 588 | @torch.no_grad() 589 | def _eval_vctk_TestSet(self): 590 | print('_eval_vctk_TestSet.........') 591 | # Quality estimation 592 | if self.config['task'] == 'Quality_Estimation': 593 | VQScore_l2_x, VQScore_cos_x = [], [] 594 | VQScore_l2_z, VQScore_cos_z = [], [] 595 | 596 | for file in self.vctk_test.VCTK_data_dict: 597 | noisy = self.vctk_test.VCTK_data_dict[file] 598 | SP_input, SP_output, zT, zqT, wav_output, indices = self.run_VQVAE(noisy) 599 | 600 | ###### Input_output error 601 | Square_diff, Square_input = torch.square(SP_input-SP_output), torch.square(SP_input) 602 | 603 | VQScore_l2_x.append(torch.mean(Square_diff / (torch.mean(Square_input, dim=-1, keepdim=True)+eps) ).numpy()) 604 | VQScore_cos_x.append(-self.cos_loss(SP_input, SP_output).numpy()) 605 | 606 | ##### Quantization error 607 | Square_z_diff, Square_z_input = torch.square(zT-zqT), torch.square(zT) 608 | 609 | 
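The two quantization-level scores appended below have a compact standalone form (illustrative; zT and zqT are the (B, T, D) encoder outputs and their quantized counterparts returned by run_VQVAE, and the cosine variant is simply the negated cos_loss applied to zT and zqT):

```python
import torch
eps = 1e-5

def vqscore_l2(zT: torch.Tensor, zqT: torch.Tensor) -> float:
    # Element-wise squared quantization error, normalized per frame by the mean input energy.
    return torch.mean(torch.square(zT - zqT) / (torch.mean(torch.square(zT), dim=-1, keepdim=True) + eps)).item()
```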
VQScore_l2_z.append(torch.mean(Square_z_diff / (torch.mean(Square_z_input, dim=-1, keepdim=True)+eps) ).numpy()) 610 | VQScore_cos_z.append(-self.cos_loss(zT, zqT).numpy()) 611 | 612 | ###### Record_CC: Input_output error 613 | self._record_loss('sig_VQScore_l2_x_pearsonr', pearsonr(VQScore_l2_x, self.vctk_test.sig)[0], mode='vctk') 614 | self._record_loss('bak_VQScore_l2_x_pearsonr', pearsonr(VQScore_l2_x, self.vctk_test.bak)[0], mode='vctk') 615 | self._record_loss('ovr_VQScore_l2_x_pearsonr', pearsonr(VQScore_l2_x, self.vctk_test.ovr)[0], mode='vctk') 616 | self._record_loss('pesq_VQScore_l2_x_pearsonr', pearsonr(VQScore_l2_x, self.vctk_test.PESQ_list)[0], mode='vctk') 617 | self._record_loss('stoi_VQScore_l2_x_pearsonr', pearsonr(VQScore_l2_x, self.vctk_test.STOI_list)[0], mode='vctk') 618 | self._record_loss('snr_VQScore_l2_x_pearsonr', pearsonr(VQScore_l2_x, self.vctk_test.SNR_list)[0], mode='vctk') 619 | 620 | self._record_loss('sig_VQScore_cos_x_pearsonr', pearsonr(VQScore_cos_x, self.vctk_test.sig)[0], mode='vctk') 621 | self._record_loss('bak_VQScore_cos_x_pearsonr', pearsonr(VQScore_cos_x, self.vctk_test.bak)[0], mode='vctk') 622 | self._record_loss('ovr_VQScore_cos_x_pearsonr', pearsonr(VQScore_cos_x, self.vctk_test.ovr)[0], mode='vctk') 623 | self._record_loss('pesq_VQScore_cos_x_pearsonr', pearsonr(VQScore_cos_x, self.vctk_test.PESQ_list)[0], mode='vctk') 624 | self._record_loss('stoi_VQScore_cos_x_pearsonr', pearsonr(VQScore_cos_x, self.vctk_test.STOI_list)[0], mode='vctk') 625 | self._record_loss('snr_VQScore_cos_x_pearsonr', pearsonr(VQScore_cos_x, self.vctk_test.SNR_list)[0], mode='vctk') 626 | 627 | ###### Record_CC: Quantization error 628 | self._record_loss('sig_VQScore_l2_z_pearsonr', pearsonr(VQScore_l2_z, self.vctk_test.sig)[0], mode='vctk_z') 629 | self._record_loss('bak_VQScore_l2_z_pearsonr', pearsonr(VQScore_l2_z, self.vctk_test.bak)[0], mode='vctk_z') 630 | self._record_loss('ovr_VQScore_l2_z_pearsonr', pearsonr(VQScore_l2_z, self.vctk_test.ovr)[0], mode='vctk_z') 631 | self._record_loss('pesq_VQScore_l2_z_pearsonr', pearsonr(VQScore_l2_z, self.vctk_test.PESQ_list)[0], mode='vctk_z') 632 | self._record_loss('stoi_VQScore_l2_z_pearsonr', pearsonr(VQScore_l2_z, self.vctk_test.STOI_list)[0], mode='vctk_z') 633 | self._record_loss('snr_VQScore_l2_z_pearsonr', pearsonr(VQScore_l2_z, self.vctk_test.SNR_list)[0], mode='vctk_z') 634 | 635 | self._record_loss('sig_VQScore_cos_z_pearsonr', pearsonr(VQScore_cos_z, self.vctk_test.sig)[0], mode='vctk_z') 636 | self._record_loss('bak_VQScore_cos_z_pearsonr', pearsonr(VQScore_cos_z, self.vctk_test.bak)[0], mode='vctk_z') 637 | self._record_loss('ovr_VQScore_cos_z_pearsonr', pearsonr(VQScore_cos_z, self.vctk_test.ovr)[0], mode='vctk_z') 638 | self._record_loss('pesq_VQScore_cos_z_pearsonr', pearsonr(VQScore_cos_z, self.vctk_test.PESQ_list)[0], mode='vctk_z') 639 | self._record_loss('stoi_VQScore_cos_z_pearsonr', pearsonr(VQScore_cos_z, self.vctk_test.STOI_list)[0], mode='vctk_z') 640 | self._record_loss('snr_VQScore_cos_z_pearsonr', pearsonr(VQScore_cos_z, self.vctk_test.SNR_list)[0], mode='vctk_z') 641 | 642 | self.scatter_plot(self.vctk_test.sig, VQScore_cos_z, 'VCTK_VQScore_cos_z_sig.png') 643 | self.scatter_plot(self.vctk_test.bak, VQScore_cos_z, 'VCTK_VQScore_cos_z_bak.png') 644 | self.scatter_plot(self.vctk_test.ovr, VQScore_cos_z, 'VCTK_VQScore_cos_z_ovr.png') 645 | self.scatter_plot(self.vctk_test.PESQ_list, VQScore_cos_z, 'VCTK_VQScore_cos_z_pesq.png') 646 | self.scatter_plot(self.vctk_test.STOI_list, VQScore_cos_z, 
'VCTK_VQScore_cos_z_stoi.png') 647 | self.scatter_plot(self.vctk_test.SNR_list, VQScore_cos_z, 'VCTK_VQScore_cos_z_SNR.png') 648 | 649 | if pearsonr(VQScore_cos_z, self.vctk_test.ovr)[0] > self.highest_dnsmos_ovr_CC: 650 | if os.path.isfile(os.path.join(self.config["outdir"], 'checkpoint-dnsmos_ovr_CC='+ str(self.highest_dnsmos_ovr_CC)[0:5]+'.pkl')): 651 | os.remove(os.path.join(self.config["outdir"], 'checkpoint-dnsmos_ovr_CC='+ str(self.highest_dnsmos_ovr_CC)[0:5]+'.pkl')) 652 | self.highest_dnsmos_ovr_CC = pearsonr(VQScore_cos_z, self.vctk_test.ovr)[0] 653 | self.save_checkpoint(os.path.join(self.config["outdir"], 'checkpoint-dnsmos_ovr_CC='+ str(self.highest_dnsmos_ovr_CC)[0:5]+'.pkl')) 654 | 655 | elif self.config['task'] == 'Speech_Enhancement': 656 | ### speech enhancement 657 | index_accuracy = [] 658 | 659 | Noisy_spectral_convergence_loss, Clean_spectral_convergence_loss = [], [] 660 | Noisy_to_clean_spectral_convergence_loss = [] 661 | 662 | clean_recovery_dnsmos_p835, clean_recovery_pesq = [], [] 663 | enhanced_dnsmos_p835, enhanced_pesq = [], [] 664 | 665 | Noisy_index_hist = np.zeros(self.config['VQVAE_params']['codebook_size']) 666 | Clean_index_hist = np.zeros(self.config['VQVAE_params']['codebook_size']) 667 | 668 | if not os.path.exists(self.exp_dir + '_vctk_noisy'): 669 | os.mkdir(self.exp_dir + '_vctk_noisy') 670 | if not os.path.exists(self.exp_dir + '_vctk_clean'): 671 | os.mkdir(self.exp_dir + '_vctk_clean') 672 | 673 | for file in self.vctk_test.VCTK_data_dict: 674 | ########## Noisy ########## 675 | noisy = self.vctk_test.VCTK_data_dict[file] 676 | SP_noisy, SP_y_noisy, zT, zqT, y_noisy, noisy_indices = self.run_VQVAE(noisy) 677 | 678 | ########## Clean ########## 679 | clean, fs = torchaudio.load(self.vctk_Clean_path + file.split('/')[-1]) 680 | SP_clean, SP_y_clean, zT, zqT, y_clean, clean_indices = self.run_VQVAE(clean) 681 | 682 | for ind in noisy_indices.cpu().numpy(): 683 | Noisy_index_hist[ind] += 1 684 | 685 | for ind in clean_indices.cpu().numpy(): 686 | Clean_index_hist[ind] += 1 687 | 688 | index_accuracy.append((torch.sum(noisy_indices[0,:]==clean_indices[0,:])/clean_indices.shape[1]).cpu().numpy()) 689 | 690 | Noisy_spectral_convergence_loss.append(self.spectral_convergence_loss(SP_y_noisy, SP_noisy).numpy()) 691 | Clean_spectral_convergence_loss.append(self.spectral_convergence_loss(SP_y_clean, SP_clean).numpy()) 692 | Noisy_to_clean_spectral_convergence_loss.append(self.spectral_convergence_loss(SP_y_noisy, SP_clean).numpy()) 693 | 694 | torchaudio.save(self.exp_dir + '_vctk_noisy' + '/' + self.config["name"] + '_vctk_noisy' + '_'+ file.split('/')[-1], y_noisy, 16000) 695 | torchaudio.save(self.exp_dir + '_vctk_clean' + '/' + self.config["name"] + '_vctk_clean' + '_'+ file.split('/')[-1], y_clean, 16000) 696 | 697 | ## objective Metrics 698 | clean_recovery_dnsmos_p835.append(compute_score(y_clean[0].numpy(), 16000, is_personalized_MOS=False, is_normalized=True, is_p808=False)) 699 | clean_recovery_pesq.append(pesq(fs=16000, ref=clean[0].numpy(), deg=y_clean[0].numpy(), mode="wb")) 700 | 701 | enhanced_dnsmos_p835.append(compute_score(y_noisy[0].numpy(), 16000, is_personalized_MOS=False, is_normalized=True, is_p808=False)) 702 | enhanced_pesq.append(pesq(fs=16000, ref=clean[0].numpy(), deg=y_noisy[0].numpy(), mode="wb")) 703 | 704 | ######## Objective Metrics (DNSMOS and PESQ) 705 | # clean recovery results 706 | clean_recovery_dnsmos_sig = [i['SIG'] for i in clean_recovery_dnsmos_p835] 707 | clean_recovery_dnsmos_bak = [i['BAK'] for i in 
708 |             clean_recovery_dnsmos_ovr = [i['OVRL'] for i in clean_recovery_dnsmos_p835]
709 |             self._record_loss('clean_recovery_dnsmos_sig', np.mean(clean_recovery_dnsmos_sig), mode='vctk')
710 |             self._record_loss('clean_recovery_dnsmos_bak', np.mean(clean_recovery_dnsmos_bak), mode='vctk')
711 |             self._record_loss('clean_recovery_dnsmos_ovr', np.mean(clean_recovery_dnsmos_ovr), mode='vctk')
712 |             self._record_loss('clean_recovery_pesq', np.mean(clean_recovery_pesq), mode='vctk')
713 |
714 |             # noisy enhanced results
715 |             enhanced_dnsmos_sig = [i['SIG'] for i in enhanced_dnsmos_p835]
716 |             enhanced_dnsmos_bak = [i['BAK'] for i in enhanced_dnsmos_p835]
717 |             enhanced_dnsmos_ovr = [i['OVRL'] for i in enhanced_dnsmos_p835]
718 |             self._record_loss('enhanced_dnsmos_sig', np.mean(enhanced_dnsmos_sig), mode='vctk')
719 |             self._record_loss('enhanced_dnsmos_bak', np.mean(enhanced_dnsmos_bak), mode='vctk')
720 |             self._record_loss('enhanced_dnsmos_ovr', np.mean(enhanced_dnsmos_ovr), mode='vctk')
721 |             self._record_loss('enhanced_pesq', np.mean(enhanced_pesq), mode='vctk')
722 |
723 |             self._record_loss('Clean_spectral_convergence_loss', np.mean(Clean_spectral_convergence_loss), mode='vctk')
724 |             self._record_loss('Noisy_spectral_convergence_loss', np.mean(Noisy_spectral_convergence_loss), mode='vctk')
725 |             self._record_loss('Noisy_to_clean_spectral_convergence_loss', np.mean(Noisy_to_clean_spectral_convergence_loss), mode='vctk')
726 |             self._record_loss('index_accuracy', np.mean(index_accuracy), mode='vctk')
727 |
728 |
729 |             ######### Dictionary usage (optional)
730 |             cluster_size = self.model["VQVAE"].quantizer.quantizer._codebook.cluster_size[0].cpu().numpy()
731 |             plt.plot(range(self.config['VQVAE_params']['codebook_size']), cluster_size)
732 |             plt.xlabel('Index')
733 |             plt.ylabel('self_cluster_size')
734 |             plt.savefig(self.exp_dir +'self_cluster_size_hist.png', dpi=150)
735 |             plt.clf()
736 |
737 |             plt.plot(range(self.config['VQVAE_params']['codebook_size']), Noisy_index_hist)
738 |             plt.xlabel('Index')
739 |             plt.ylabel('Noisy_index_hist')
740 |             plt.savefig(self.exp_dir +'Noisy_VCTK_index_hist.png', dpi=150)
741 |             plt.clf()
742 |
743 |             plt.plot(range(self.config['VQVAE_params']['codebook_size']), Clean_index_hist)
744 |             plt.xlabel('Index')
745 |             plt.ylabel('Clean_index_hist')
746 |             plt.savefig(self.exp_dir +'Clean_VCTK_index_hist.png', dpi=150)
747 |             plt.clf()
748 |
749 |             Pr_N = Noisy_index_hist/Noisy_index_hist.sum()
750 |             self._record_loss('Noisy_index_entropy', entropy(Pr_N, base=2), mode='vctk')
751 |
752 |             Pr_C = Clean_index_hist/Clean_index_hist.sum()
753 |             self._record_loss('Clean_index_entropy', entropy(Pr_C, base=2), mode='vctk')
754 |
755 |
--------------------------------------------------------------------------------
/trainer/eval_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torchaudio
3 | import pandas as pd
4 | import pickle
5 | import os
6 |
7 | def get_filepaths(directory):
8 |     """
9 |     Collect the full paths of all files under `directory`.
10 |     The tree rooted at `directory` is walked with os.walk, and the
11 |     path of every file found (including files in subdirectories)
12 |     is returned as a flat list.
13 |     """
14 |     file_paths = []  # List which will store all of the full filepaths.
15 |     # Walk the tree.
16 |     for root, directories, files in os.walk(directory):
17 |         for filename in files:
18 |             # Join the two strings in order to form the full filepath.
19 |             filepath = os.path.join(root, filename)
20 |             file_paths.append(filepath)  # Add it to the list.
21 |     return file_paths
22 |
23 | class load_IUB:
24 |     def __init__(self, pickle_path, number_test_set):
25 |         with open(pickle_path, "rb") as fp:  # Unpickling
26 |             [self.IUB_cosine_data_path, self.IUB_cosine_mos] = pickle.load(fp)
27 |             [self.IUB_voices_data_path, self.IUB_voices_mos] = pickle.load(fp)
28 |
29 |         self.IUB_cosine_data_dict = dict()
30 |         for file in self.IUB_cosine_data_path[0:number_test_set]:
31 |             noisy, fs = torchaudio.load(file)
32 |             self.IUB_cosine_data_dict[file] = noisy
33 |         self.IUB_cosine_mos = self.IUB_cosine_mos[0:number_test_set]
34 |
35 |         self.IUB_voices_data_dict = dict()
36 |         for file in self.IUB_voices_data_path[0:number_test_set]:
37 |             noisy, fs = torchaudio.load(file)
38 |             self.IUB_voices_data_dict[file] = noisy
39 |         self.IUB_voices_mos = self.IUB_voices_mos[0:number_test_set]
40 |
41 |
42 | class load_Tencent:
43 |     def __init__(self, pickle_path, number_test_set):
44 |         with open(pickle_path, "rb") as fp:  # Unpickling
45 |             [Tencent_woR_data_path, Tencent_woR_mos] = pickle.load(fp)
46 |             [Tencent_wR_data_path, Tencent_wR_mos] = pickle.load(fp)
47 |
48 |         self.Tencent_woR_data_dict = dict()
49 |         self.Tencent_woR_data_path = []
50 |         self.Tencent_woR_mos = []
51 |         n = 0
52 |         for i, file in enumerate(Tencent_woR_data_path):
53 |             noisy, fs = torchaudio.load(file)
54 |             if fs == 16000:  # keep only 16 kHz clips
55 |                 self.Tencent_woR_data_dict[file] = noisy
56 |                 self.Tencent_woR_data_path.append(file)
57 |                 self.Tencent_woR_mos.append(Tencent_woR_mos[i])
58 |                 n += 1
59 |                 if n == number_test_set:
60 |                     break
61 |
62 |         self.Tencent_wR_data_dict = dict()
63 |         self.Tencent_wR_data_path = []
64 |         self.Tencent_wR_mos = []
65 |         n = 0
66 |         for i, file in enumerate(Tencent_wR_data_path):
67 |             noisy, fs = torchaudio.load(file)
68 |             if fs == 16000 and Tencent_wR_mos[i] > 1.1:  # keep only 16 kHz clips with MOS above 1.1
69 |                 self.Tencent_wR_data_dict[file] = noisy
70 |                 self.Tencent_wR_data_path.append(file)
71 |                 self.Tencent_wR_mos.append(Tencent_wR_mos[i])
72 |                 n += 1
73 |                 if n == number_test_set:
74 |                     break
75 |
76 | class load_DNS1:
77 |     def __init__(self, dir_path):
78 |         DNS1_Real_list = get_filepaths(dir_path+"/real")
79 |         DNS1_Noreverb_list = get_filepaths(dir_path+"/noreverb")
80 |         DNS1_Reverb_list = get_filepaths(dir_path+"/reverb")
81 |
82 |         self.DNS1_Real_dict = dict()
83 |         for file in DNS1_Real_list:
84 |             noisy, fs = torchaudio.load(file)
85 |             self.DNS1_Real_dict[file] = noisy
86 |
87 |         self.DNS1_Noreverb_dict = dict()
88 |         for file in DNS1_Noreverb_list:
89 |             noisy, fs = torchaudio.load(file)
90 |             self.DNS1_Noreverb_dict[file] = noisy
91 |
92 |         self.DNS1_Reverb_dict = dict()
93 |         for file in DNS1_Reverb_list:
94 |             noisy, fs = torchaudio.load(file)
95 |             self.DNS1_Reverb_dict[file] = noisy
96 |
97 | class load_DNS3:
98 |     def __init__(self, dir_path):
99 |         DNS3_list = get_filepaths(dir_path)
100 |         self.DNS3_nonenglish_synthetic_dict = dict()
101 |         self.DNS3_stationary_dict = dict()
102 |         self.DNS3_ms_realrec_nonenglish_dict = dict()
103 |         self.DNS3_ms_realrec_dict = dict()
104 |
105 |         for file in DNS3_list:
106 |             noisy, fs = torchaudio.load(file)
107 |             if 'ms_realrec_nonenglish' in file:
108 |                 self.DNS3_ms_realrec_nonenglish_dict[file] = noisy
109 |             elif 'nonenglish' in file:
110 |                 self.DNS3_nonenglish_synthetic_dict[file] = noisy
111 |             elif 'stationary_english' in file:
112 |                 self.DNS3_stationary_dict[file] = noisy
113 |             else:
114 |                 self.DNS3_ms_realrec_dict[file] = noisy
115 |
116 | class load_VCTK_testSet:
117 |     def __init__(self, pickle_path):
118 |         with open(pickle_path, "rb") as fp:  # Unpickling
119 |             self.vctk_Noisy_list = pickle.load(fp)
120 |             self.SNR_list = pickle.load(fp)
121 |             self.PESQ_list = pickle.load(fp)
122 |             test_DNSMOSp835 = pickle.load(fp)
123 |             self.STOI_list = pickle.load(fp)
124 |
125 |         self.sig = [i['SIG'] for i in test_DNSMOSp835]
126 |         self.bak = [i['BAK'] for i in test_DNSMOSp835]
127 |         self.ovr = [i['OVRL'] for i in test_DNSMOSp835]
128 |
129 |         self.VCTK_data_dict = dict()
130 |         for file in self.vctk_Noisy_list:
131 |             noisy, fs = torchaudio.load(file)
132 |             self.VCTK_data_dict[file] = noisy
133 |
134 | class load_VCTK_validSet:
135 |     def __init__(self, pickle_path):
136 |         with open(pickle_path, "rb") as fp:  # Unpickling
137 |             self.valid_list = pickle.load(fp)
138 |
139 |         self.VCTK_data_dict = dict()
140 |         for file in self.valid_list:
141 |             noisy, fs = torchaudio.load(file)
142 |             self.VCTK_data_dict[file] = noisy
143 |
--------------------------------------------------------------------------------
/trainer/trainerAE.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import logging
5 | import os
6 | import torch
7 | import time
8 |
9 | from collections import defaultdict
10 | from tensorboardX import SummaryWriter
11 | from tqdm import tqdm
12 |
13 | class TrainerAE(object):
14 |     def __init__(
15 |         self,
16 |         steps,
17 |         epochs,
18 |         data_loader,
19 |         model,
20 |         optimizer,
21 |         scheduler,
22 |         config,
23 |         device=torch.device("cpu"),
24 |     ):
25 |         """Initialize trainer.
26 |
27 |         Args:
28 |             steps (int): Initial global steps.
29 |             epochs (int): Initial global epochs.
30 |             data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
31 |             model (dict): Dict of models.
32 |             optimizer (dict): Dict of optimizers.
33 |             scheduler (dict): Dict of schedulers.
34 |             config (dict): Config dict loaded from yaml format configuration file.
35 |             device (torch.device): PyTorch device instance.
36 |
37 |         """
38 |         self.steps = steps
39 |         self.epochs = epochs
40 |         self.data_loader = data_loader
41 |         self.model = model
42 |         self.optimizer = optimizer
43 |         self.scheduler = scheduler
44 |         self.config = config
45 |         self.device = device
46 |         self.writer = SummaryWriter(config["outdir"])
47 |         self.total_train_loss = defaultdict(float)
48 |         self.total_eval_loss = defaultdict(float)
49 |         self.finish_train = False
50 |
51 |
52 |     def run(self):
53 |         """Run training."""
54 |         self.tqdm = tqdm(
55 |             initial=self.steps, total=self.config["train_max_steps"], desc="[train]"
56 |         )
57 |         while True:
58 |             self._train_epoch()
59 |
60 |             # check whether training is finished
61 |             if self.finish_train:
62 |                 break
63 |
64 |         self.tqdm.close()
65 |         logging.info("Finished training.")
66 |
67 |
68 |     def save_checkpoint(self, checkpoint_path):
69 |         """Save checkpoint.
70 |
71 |         Args:
72 |             checkpoint_path (str): Checkpoint path to be saved.
73 |
74 |         """
75 |         state_dict = {
76 |             "optimizer": {"VQVAE": self.optimizer["VQVAE"].state_dict()},
77 |             "scheduler": {"VQVAE": self.scheduler["VQVAE"].state_dict()},
78 |             "steps": self.steps,
79 |             "epochs": self.epochs,
80 |         }
81 |         state_dict["model"] = {
82 |             "VQVAE": self.model["VQVAE"].state_dict(),
83 |         }
84 |
85 |         if not os.path.exists(os.path.dirname(checkpoint_path)):
86 |             os.makedirs(os.path.dirname(checkpoint_path))
87 |         torch.save(state_dict, checkpoint_path)
88 |
89 |
90 |     def _train_step(self, batch):
91 |         """Single step of training."""
92 |         pass  # implemented by the task-specific subclass
93 |
94 |
95 |     def _train_epoch(self):
96 |         """One epoch of training."""
97 |         for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
98 |             # train one step
99 |             self._train_step(batch)
100 |
101 |             # check interval
102 |             self._check_log_interval()
103 |             self._check_eval_interval()
104 |             self._check_save_interval()
105 |
106 |             # check whether training is finished
107 |             if self.finish_train:
108 |                 return
109 |
110 |         # update counters
111 |         self.epochs += 1
112 |         self.train_steps_per_epoch = train_steps_per_epoch
113 |         if train_steps_per_epoch > 200:
114 |             logging.info(
115 |                 f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
116 |                 f"({self.train_steps_per_epoch} steps per epoch)."
117 |             )
118 |
119 |
120 |     def _eval_step(self, batch):
121 |         """Single step of evaluation."""
122 |         pass  # implemented by the task-specific subclass
123 |
124 |
125 |     def _eval_epoch(self):
126 |         """One epoch of evaluation."""
127 |         logging.info(f"(Steps: {self.steps}) Start evaluation.")
128 |         # change mode
129 |         for key in self.model.keys():
130 |             self.model[key].eval()
131 |
132 |         # calculate loss for each batch
133 |         for eval_steps_per_epoch, batch in enumerate(
134 |             tqdm(self.data_loader["dev"], desc="[eval]"), 1
135 |         ):
136 |             # eval one step
137 |             self._eval_step(batch)
138 |
139 |         #self.remove_code(self.steps)
140 |         start_time = time.time()
141 |
142 |
143 |         if self.config['task'] == 'Speech_Enhancement':
144 |             self._eval_vctk_ValidSet()  # validation set for SE
145 |             self._eval_vctk_TestSet()
146 |             self._eval_DNS1_test()
147 |             self._eval_DNS3_test()
148 |         elif self.config['task'] == 'Quality_Estimation':
149 |             self._eval_vctk_TestSet()  # validation set for QE
150 |             self._eval_Tencent()
151 |             self._eval_IUB()
152 |         else:
153 |             raise NotImplementedError("Task is not supported!")
154 |
155 |         end_time = time.time()
156 |         print('Evaluation took %.2f min' % ((end_time - start_time) / 60.))
157 |
158 |         logging.info(
159 |             f"(Steps: {self.steps}) Finished evaluation "
160 |             f"({eval_steps_per_epoch} steps per epoch)."
161 |         )
162 |
163 |         # average loss
164 |         for key in self.total_eval_loss.keys():
165 |             if key.startswith('eval'):
166 |                 self.total_eval_loss[key] /= eval_steps_per_epoch
167 |                 logging.info(
168 |                     f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
169 |                 )
170 |
171 |         # record
172 |         self._write_to_tensorboard(self.total_eval_loss)
173 |
174 |         # reset
175 |         self.total_eval_loss = defaultdict(float)
176 |
177 |         # restore mode
178 |         for key in self.model.keys():
179 |             self.model[key].train()
180 |
181 |
182 |     def _update_VQVAE(self, gen_loss):
183 |         """Update VQVAE."""
184 |         self.optimizer["VQVAE"].zero_grad()
185 |         gen_loss.backward()
186 |         if self.config["VQVAE_grad_norm"] > 0:
187 |             torch.nn.utils.clip_grad_norm_(
188 |                 self.model["VQVAE"].parameters(),
189 |                 self.config["VQVAE_grad_norm"],
190 |             )
191 |         self.optimizer["VQVAE"].step()
192 |         #self.scheduler["VQVAE"].step()
193 |
194 |
195 |     def _record_loss(self, name, loss, mode='train'):
196 |         """Record loss."""
197 |         if mode == 'train':
198 |             self.total_train_loss[f"train/{name}"] += loss.item()
199 |         elif mode == 'eval':
200 |             self.total_eval_loss[f"eval/{name}"] += loss.item()
201 |         elif mode == 'vctk':
202 |             self.total_eval_loss[f"vctk/{name}"] += loss
203 |         elif mode == 'vctk_valid':
204 |             self.total_eval_loss[f"vctk_valid/{name}"] += loss
205 |         elif mode == 'IUB':
206 |             self.total_eval_loss[f"IUB/{name}"] += loss
207 |         elif mode == 'Tencent':
208 |             self.total_eval_loss[f"Tencent/{name}"] += loss
209 |         elif mode == 'IUB_z':
210 |             self.total_eval_loss[f"IUB_z/{name}"] += loss
211 |         elif mode == 'Tencent_z':
212 |             self.total_eval_loss[f"Tencent_z/{name}"] += loss
213 |         elif mode == 'vctk_z':
214 |             self.total_eval_loss[f"vctk_z/{name}"] += loss
215 |         elif mode == 'dns1':
216 |             self.total_eval_loss[f"dns1/{name}"] += loss
217 |         elif mode == 'dns3':
218 |             self.total_eval_loss[f"dns3/{name}"] += loss
219 |         else:
220 |             raise NotImplementedError(f"Mode ({mode}) is not supported!")
221 |
222 |
223 |     def _write_to_tensorboard(self, loss):
224 |         """Write to tensorboard."""
225 |         for key, value in loss.items():
226 |             self.writer.add_scalar(key, value, self.steps)
227 |
228 |
229 |     def _check_save_interval(self):
230 |         if self.steps and (self.steps % self.config["save_interval_steps"] == 0):
231 |             self.save_checkpoint(
232 |                 os.path.join(self.config["outdir"], f"checkpoint-{self.steps}steps.pkl")
233 |             )
234 |             logging.info(f"Successfully saved checkpoint @ {self.steps} steps.")
235 |
236 |
237 |     def _check_eval_interval(self):
238 |         if self.steps % self.config["eval_interval_steps"] == 0:
239 |             self._eval_epoch()
240 |
241 |
242 |     def _check_log_interval(self):
243 |         if self.steps % self.config["log_interval_steps"] == 0:
244 |             for key in self.total_train_loss.keys():
245 |                 self.total_train_loss[key] /= self.config["log_interval_steps"]
246 |                 logging.info(
247 |                     f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
248 |                 )
249 |             self._write_to_tensorboard(self.total_train_loss)
250 |
251 |             # reset
252 |             self.total_train_loss = defaultdict(float)
253 |
254 |
255 |     def _check_train_finish(self):
256 |         if self.steps >= self.config["train_max_steps"]:
257 |             self.finish_train = True
258 |
259 |
--------------------------------------------------------------------------------
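
For readers skimming the quality-estimation branch of `_eval_vctk_TestSet` above: VQScore(cos, z) scores a clip by the cosine similarity between the encoder output `zT` and its vector-quantized counterpart `zqT` (the code stores `-self.cos_loss(zT, zqT)`, i.e. a negated cosine loss). The snippet below is a minimal, self-contained sketch of that computation on dummy tensors; `vqscore_cos_z` and the toy data are illustrative and not part of this repository.

```python
# Minimal sketch (not repository code) of the VQScore(cos, z) idea:
# score a clip by the mean frame-wise cosine similarity between the
# encoder output zT and its quantized counterpart zqT.
import torch
import torch.nn.functional as F

def vqscore_cos_z(zT: torch.Tensor, zqT: torch.Tensor) -> float:
    """zT, zqT: (frames, dim) tensors; returns mean cosine similarity."""
    return F.cosine_similarity(zT, zqT, dim=-1).mean().item()

# Toy illustration: a small quantization error yields a score near 1;
# a large error (as for noisy, out-of-distribution input) yields a lower score.
zT = torch.randn(100, 128)
print(vqscore_cos_z(zT, zT + 0.05 * torch.randn_like(zT)))  # close to 1
print(vqscore_cos_z(zT, zT + 1.00 * torch.randn_like(zT)))  # noticeably lower
```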
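Similarly, the checkpoint handling in `_eval_vctk_ValidSet` and `_eval_vctk_TestSet` keeps only the single best-scoring model per metric: when a new best score arrives, the file named after the previous best is deleted before the new one is saved, which is also why the shipped checkpoints carry names like `checkpoint-dnsmos_ovr=2.761_AT.pkl`. A standalone sketch of that rotation pattern follows; the `save_best` helper is hypothetical, not repository code.

```python
# Hypothetical sketch of the "keep only the best checkpoint" rotation used in
# the evaluation methods: delete the file named after the previous best score,
# then save a file named after the new best.
import os
import torch

def save_best(model_state: dict, outdir: str, metric: str,
              new_score: float, best_score: float) -> float:
    """Save model_state if new_score beats best_score; return the best score."""
    if new_score <= best_score:
        return best_score
    old_path = os.path.join(outdir, f"checkpoint-{metric}={str(best_score)[0:5]}.pkl")
    if os.path.isfile(old_path):
        os.remove(old_path)  # drop the superseded checkpoint
    os.makedirs(outdir, exist_ok=True)
    torch.save(model_state,
               os.path.join(outdir, f"checkpoint-{metric}={str(new_score)[0:5]}.pkl"))
    return new_score

# e.g. best_pesq = save_best(model.state_dict(), "exp/run1", "pesq", 2.31, best_pesq)
```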