├── Figures ├── model.png ├── table_16k.png └── table_48k.png ├── LICENSE ├── README.md ├── VCTK-Corpus-0.92 ├── flac2wav.py ├── test.txt ├── training.txt ├── vctk-silence-labels │ ├── README.md │ ├── assets │ │ ├── original.png │ │ ├── trim.png │ │ └── trimpad.png │ └── vctk-silences.0.92.txt ├── wav16 │ ├── test │ │ └── readme.txt │ └── train │ │ └── readme.txt └── wav48 │ ├── test │ └── readme.txt │ └── train │ └── readme.txt ├── cal_metrics.py ├── cal_visqol_48k.py ├── checkpoints └── README.md ├── configs ├── __init__.py ├── config_12kto48k.json ├── config_16kto48k.json ├── config_24kto48k.json ├── config_2kto16k.json ├── config_4kto16k.json ├── config_8kto16k.json └── config_8kto48k.json ├── datasets ├── __init__.py └── dataset.py ├── docs ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ └── index.css ├── index.html ├── js │ ├── ViewSDKInterface.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ ├── index.js │ ├── jquery.min.js │ └── main.js └── samples │ ├── Ablation │ ├── ap-bwe.png │ ├── ap-bwe.wav │ ├── ap-bwe_mpd_only.png │ ├── ap-bwe_mpd_only.wav │ ├── ap-bwe_mrad_only.png │ ├── ap-bwe_mrad_only.wav │ ├── ap-bwe_mrpd_only.png │ ├── ap-bwe_mrpd_only.wav │ ├── ap-bwe_wo_AtoP.png │ ├── ap-bwe_wo_AtoP.wav │ ├── ap-bwe_wo_PtoA.png │ ├── ap-bwe_wo_PtoA.wav │ ├── ap-bwe_wo_connect.png │ ├── ap-bwe_wo_connect.wav │ ├── ap-bwe_wo_disc.png │ ├── ap-bwe_wo_disc.wav │ ├── ap-bwe_wo_mpd.png │ ├── ap-bwe_wo_mpd.wav │ ├── ap-bwe_wo_mrad.png │ ├── ap-bwe_wo_mrad.wav │ ├── ap-bwe_wo_mrpd.png │ ├── ap-bwe_wo_mrpd.wav │ ├── wideband.png │ └── wideband.wav │ ├── BWE_16k │ ├── 2kto16k │ │ ├── afilm.png │ │ ├── afilm.wav │ │ ├── ap-bwe.png │ │ ├── ap-bwe.wav │ │ ├── nvsr.png │ │ ├── nvsr.wav │ │ ├── sinc.png │ │ ├── sinc.wav │ │ ├── tfilm.png │ │ ├── tfilm.wav │ │ ├── wideband.png │ │ └── wideband.wav │ ├── 4kto16k │ │ ├── afilm.png │ │ ├── afilm.wav │ │ ├── ap-bwe.png │ │ ├── ap-bwe.wav │ │ ├── nvsr.png │ │ ├── nvsr.wav │ │ ├── sinc.png │ │ ├── sinc.wav │ │ ├── tfilm.png │ │ ├── tfilm.wav │ │ ├── wideband.png │ │ └── wideband.wav │ └── 8kto16k │ │ ├── afilm.png │ │ ├── afilm.wav │ │ ├── ap-bwe.png │ │ ├── ap-bwe.wav │ │ ├── nvsr.png │ │ ├── nvsr.wav │ │ ├── sinc.png │ │ ├── sinc.wav │ │ ├── tfilm.png │ │ ├── tfilm.wav │ │ ├── wideband.png │ │ └── wideband.wav │ ├── BWE_48k │ ├── 12kto48k │ │ ├── ap-bwe.png │ │ ├── ap-bwe.wav │ │ ├── mdctgan.png │ │ ├── mdctgan.wav │ │ ├── nuwave2.png │ │ ├── nuwave2.wav │ │ ├── sinc.png │ │ ├── sinc.wav │ │ ├── udm+.png │ │ ├── udm+.wav │ │ ├── wideband.png │ │ └── wideband.wav │ ├── 16kto48k │ │ ├── ap-bwe.png │ │ ├── ap-bwe.wav │ │ ├── mdctgan.png │ │ ├── mdctgan.wav │ │ ├── nuwave2.png │ │ ├── nuwave2.wav │ │ ├── sinc.png │ │ ├── sinc.wav │ │ ├── udm+.png │ │ ├── udm+.wav │ │ ├── wideband.png │ │ └── wideband.wav │ ├── 24kto48k │ │ ├── ap-bwe.png │ │ ├── ap-bwe.wav │ │ ├── mdctgan.png │ │ ├── mdctgan.wav │ │ ├── nuwave2.png │ │ ├── nuwave2.wav │ │ ├── sinc.png │ │ ├── sinc.wav │ │ ├── udm+.png │ │ ├── udm+.wav │ │ ├── wideband.png │ │ └── wideband.wav │ └── 8kto48k │ │ ├── ap-bwe.png │ │ ├── ap-bwe.wav │ │ ├── mdctgan.png │ │ ├── mdctgan.wav │ │ ├── nuwave2.png │ │ ├── nuwave2.wav │ │ ├── sinc.png │ │ ├── sinc.wav │ │ ├── udm+.png │ │ ├── udm+.wav │ │ ├── wideband.png │ │ └── wideband.wav │ └── CrossDataset │ ├── HiFi-TTS │ ├── ap-bwe.png │ ├── ap-bwe.wav │ ├── mdctgan.png │ ├── mdctgan.wav │ ├── nuwave2.png │ ├── nuwave2.wav │ ├── udm+.png │ ├── udm+.wav │ 
├── wideband.png │ └── wideband.wav │ └── Libri-TTS │ ├── ap-bwe.png │ ├── ap-bwe.wav │ ├── mdctgan.png │ ├── mdctgan.wav │ ├── nuwave2.png │ ├── nuwave2.wav │ ├── udm+.png │ ├── udm+.wav │ ├── wideband.png │ └── wideband.wav ├── env.py ├── inference ├── __init__.py ├── inference_16k.py └── inference_48k.py ├── models ├── __init__.py └── model.py ├── requirements.txt ├── train ├── __init__.py ├── train_16k.py └── train_48k.py ├── utils.py └── weights_LICENSE.txt /Figures/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/Figures/model.png -------------------------------------------------------------------------------- /Figures/table_16k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/Figures/table_16k.png -------------------------------------------------------------------------------- /Figures/table_48k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/Figures/table_48k.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ye-Xin Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards High-Quality and Efficient Speech Bandwidth Extension with Parallel Amplitude and Phase Prediction 2 | ### Ye-Xin Lu, Yang Ai, Hui-Peng Du, Zhen-Hua Ling 3 | 4 | **Abstract:** 5 | Speech bandwidth extension (BWE) refers to widening the frequency bandwidth range of speech signals, enhancing the speech quality towards brighter and fuller. 6 | This paper proposes a generative adversarial network (GAN) based BWE model with parallel prediction of Amplitude and Phase spectra, named AP-BWE, which achieves both high-quality and efficient wideband speech waveform generation. 7 | The proposed AP-BWE generator is entirely based on convolutional neural networks (CNNs). 
8 | It features a dual-stream architecture with mutual interaction, where the amplitude stream and the phase stream communicate with each other and respectively extend the high-frequency components from the input narrowband amplitude and phase spectra. 9 | To improve the naturalness of the extended speech signals, we employ a multi-period discriminator at the waveform level and design a pair of multi-resolution amplitude and phase discriminators at the spectral level, respectively. 10 | Experimental results demonstrate that our proposed AP-BWE achieves state-of-the-art performance in terms of speech quality for BWE tasks targeting sampling rates of both 16 kHz and 48 kHz. 11 | In terms of generation efficiency, due to the all-convolutional architecture and all-frame-level operations, the proposed AP-BWE can generate 48 kHz waveform samples 292.3 times faster than real-time on a single RTX 4090 GPU and 18.1 times faster than real-time on a single CPU. 12 | Notably, to our knowledge, AP-BWE is the first to achieve the direct extension of the high-frequency phase spectrum, which is beneficial for improving the effectiveness of existing BWE methods. 13 | 14 | **We provide our implementation as open source in this repository. Audio samples can be found at the [demo website](http://yxlu-0102.github.io/AP-BWE).** 15 | 16 | ## License 17 | 18 | The code in this repository is licensed under the MIT License.
19 | The weights for this model are stored in [Google Drive](https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link). You can access and use the weights under the same MIT License. Please refer to the [weights_LICENSE.txt](weights_LICENSE.txt) for more details on the licensing. 20 | 21 | ## Pre-requisites 22 | 0. Python >= 3.9. 23 | 0. Clone this repository. 24 | 0. Install Python requirements. Please refer to [requirements.txt](requirements.txt). 25 | 0. Download datasets 26 | 1. Download and extract the [VCTK-0.92 dataset](https://datashare.ed.ac.uk/handle/10283/3443), move its `wav48` directory into [VCTK-Corpus-0.92](VCTK-Corpus-0.92), and rename it to `wav48_origin`. 27 | 1. Trim the silence of the dataset; the trimmed files will be saved to `wav48_silence_trimmed`. 28 | ``` 29 | cd VCTK-Corpus-0.92 30 | python flac2wav.py 31 | ``` 32 | 1. Move all the trimmed training files from `wav48_silence_trimmed` to [wav48/train](wav48/train) following the indexes in [training.txt](VCTK-Corpus-0.92/training.txt), and move all the untrimmed test files from `wav48_origin` to [wav48/test](wav48/test) following the indexes in [test.txt](VCTK-Corpus-0.92/test.txt). 33 | 34 | ## Training 35 | ``` 36 | cd train 37 | CUDA_VISIBLE_DEVICES=0 python train_16k.py --config [config file path] 38 | CUDA_VISIBLE_DEVICES=0 python train_48k.py --config [config file path] 39 | ``` 40 | Checkpoints and copies of the configuration file are saved in the `cp_model` directory by default.
41 | You can change the path by using the `--checkpoint_path` option. 42 | Here is an example: 43 | ``` 44 | CUDA_VISIBLE_DEVICES=0 python train_16k.py --config ../configs/config_2kto16k.json --checkpoint_path ../checkpoints/AP-BWE_2kto16k 45 | ``` 46 | 47 | ## Inference 48 | ``` 49 | cd inference 50 | python inference_16k.py --checkpoint_file [generator checkpoint file path] 51 | python inference_48k.py --checkpoint_file [generator checkpoint file path] 52 | ``` 53 | You can download the [pretrained weights](https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link) we provide and move all the files to the `checkpoints` directory. 54 |
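If you want to prepare narrowband inputs yourself, you can reuse the simulation from [datasets/dataset.py](datasets/dataset.py), which band-limits a waveform by resampling it down and then back up to the target rate. A minimal standalone sketch (the file names and the 8 kHz bandwidth are placeholders, not part of the provided scripts):
```
import torchaudio
import torchaudio.functional as aF

# Band-limit a 48 kHz waveform to 8 kHz content while keeping the 48 kHz
# sample grid, mirroring how datasets/dataset.py simulates narrowband inputs.
audio, sr = torchaudio.load('p225_001.wav')  # hypothetical input file
audio_lr = aF.resample(audio, orig_freq=sr, new_freq=8000)
audio_lr = aF.resample(audio_lr, orig_freq=8000, new_freq=sr)
audio_lr = audio_lr[:, : audio.size(1)]      # trim to the original length
torchaudio.save('p225_001_8k.wav', audio_lr, sr)
```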
55 | Generated wav files are saved in `generated_files` by default. 56 | You can change the path by adding the `--output_dir` option. 57 | Here is an example: 58 | ``` 59 | python inference_16k.py --checkpoint_file ../checkpoints/2kto16k/g_2kto16k --output_dir ../generated_files/2kto16k 60 | ``` 61 | 62 | ## Model Structure 63 | ![model](Figures/model.png) 64 | 65 | ## Comparison with Other Speech BWE Methods 66 | ### 2k/4k/8kHz to 16kHz 67 |

68 | ![comparison](Figures/table_16k.png) 69 |

70 | 71 | ### 8k/12k/16k/24kHz to 48kHz 72 |

73 | ![comparison](Figures/table_48k.png) 74 |
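The objective results in the tables above can be reproduced with the evaluation scripts in the repository root: [cal_metrics.py](cal_metrics.py) computes LSD, SNR, and anti-wrapping phase distances, and [cal_visqol_48k.py](cal_visqol_48k.py) computes ViSQOL (it additionally requires the `visqol` package). The flags below match the argparse definitions in those scripts; the directory paths are only examples:
```
python cal_metrics.py --reference_wav_dir VCTK-Corpus-0.92/wav16/test --synthesis_wav_dir generated_files/AP-BWE
python cal_visqol_48k.py --sampling_rate 48000 --ref_wav_dir VCTK-Corpus-0.92/wav48/test --syn_wav_dir generated_files/AP-BWE
```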

75 | 76 | ## Acknowledgements 77 | We referred to [HiFi-GAN](https://github.com/jik876/hifi-gan) and [NSPP](https://github.com/YangAi520/NSPP) when implementing this project. 78 | 79 | ## Citation 80 | ``` 81 | @article{lu2024towards, 82 | title={Towards high-quality and efficient speech bandwidth extension with parallel amplitude and phase prediction}, 83 | author={Lu, Ye-Xin and Ai, Yang and Du, Hui-Peng and Ling, Zhen-Hua}, 84 | journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, 85 | volume={33}, 86 | pages={236--250}, 87 | year={2024} 88 | } 89 | 90 | @inproceedings{lu2024multi, 91 | title={Multi-Stage Speech Bandwidth Extension with Flexible Sampling Rate Control}, 92 | author={Lu, Ye-Xin and Ai, Yang and Sheng, Zheng-Yan and Ling, Zhen-Hua}, 93 | booktitle={Proc. Interspeech}, 94 | pages={2270--2274}, 95 | year={2024} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /VCTK-Corpus-0.92/flac2wav.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import soundfile as sf 3 | import os 4 | from glob import glob 5 | from tqdm import tqdm 6 | import multiprocessing as mp 7 | import torch 8 | import torchaudio.functional as aF 9 | 10 | def flac2wav(wav): 11 | y, sr = librosa.load(wav, sr=48000, mono=True) 12 | file_id = os.path.split(wav)[-1].split('_mic')[0] 13 | if file_id in timestamps: 14 | start, end = timestamps[file_id] 15 | start = start - min(start, int(0.1 * sr))  # keep up to 0.1 s of leading context 16 | end = end + min(len(y) - end, int(0.1 * sr))  # keep up to 0.1 s of trailing context 17 | y = y[start: end] 18 | 19 | # y = torch.FloatTensor(y).unsqueeze(0) 20 | # y = aF.resample(y, orig_freq=sr, new_freq=22050).squeeze().numpy() 21 | 22 | os.makedirs(os.path.join('wav48_silence_trimmed', wav.split(os.sep)[-2]), exist_ok=True) 23 | 24 | wav_path = os.path.join('wav48_silence_trimmed', wav.split(os.sep)[-2], file_id +'.wav') 25 | 26 | sf.write(wav_path, y, 48000, 'PCM_16') 27 | del y 28 | return 29 | 30 | 31 | if __name__=='__main__': 32 | 33 | base_dir = 'wav48_origin' 34 | 35 | wavs = glob(os.path.join(base_dir, '*/*mic1.flac')) 36 | sampling_rate = 48000 37 | 38 | timestamps = {} 39 | path_timestamps = 'vctk-silence-labels/vctk-silences.0.92.txt' 40 | with open(path_timestamps, 'r') as f: 41 | timestamps_list = f.readlines() 42 | for line in timestamps_list: 43 | timestamp_data = line.strip().split(' ') 44 | if len(timestamp_data) == 3: 45 | file_id, t_start, t_end = timestamp_data 46 | t_start = int(float(t_start) * sampling_rate) 47 | t_end = int(float(t_end) * sampling_rate) 48 | timestamps[file_id] = (t_start, t_end) 49 | 50 | pool = mp.Pool(processes = 8) 51 | with tqdm(total = len(wavs)) as pbar: 52 | for _ in tqdm(pool.imap_unordered(flac2wav, wavs)): 53 | pbar.update() 54 | -------------------------------------------------------------------------------- /VCTK-Corpus-0.92/vctk-silence-labels/README.md: -------------------------------------------------------------------------------- 1 | # Leading and trailing silence label for VCTK Corpus (version 0.92) 2 | 3 | This repository provides information about the leading and trailing silences of the VCTK corpus. The information is obtained by training an ASR model on the corpus and using it to extract alignments, which is a more robust approach to non-uniform noise than signal processing. 4 | 5 | This information is useful for preprocessing the VCTK Corpus to remove excessive leading and trailing silence frames with noise. 
The silence information is only compatible with version 0.92 of the corpus. 6 | 7 | ## Original waveform and silence timestamp 8 | 9 | ![Original](assets/original.png) 10 | 11 | *p226_036_mic1.wav* 12 | 13 | The leading and trailing silence timestamps can be found in `vctk-silences.0.92.txt`. Each entry includes the utterance id and the start and end of the speech segment in seconds. 14 | ``` 15 | p226_036 0.930 3.450 16 | ``` 17 | 18 | ## Trim the leading and trailing silences 19 | ```bash 20 | sox p226_036_mic1.wav p226_036_mic1.trim.wav trim 0.930 =3.450 21 | ``` 22 | 23 | We can use `sox` to trim the silences as shown in the above example. As these are just timestamps, we can manipulate the clipping region by adjusting them. 24 | 25 | ![Trim](assets/trim.png) 26 | 27 | *p226_036_mic1.trim.wav* 28 | 29 | 30 | ## Trim then pad with artificial silences 31 | ```bash 32 | sox p226_036_mic1.wav p226_036_mic1.trimpad.wav trim 0.930 =3.450 pad 0.25 0.25 33 | ``` 34 | 35 | We can also pad the utterance with small leading and trailing silence segments as shown in the above example. This might be helpful for training end-to-end TTS models. 36 | 37 | ![Trim and Pad](assets/trimpad.png) 38 | 39 | *p226_036_mic1.trimpad.wav* 40 | 41 | ## References 42 | [VCTK Corpus (version 0.92)](https://datashare.is.ed.ac.uk/handle/10283/3443) 43 | -------------------------------------------------------------------------------- /VCTK-Corpus-0.92/vctk-silence-labels/assets/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/VCTK-Corpus-0.92/vctk-silence-labels/assets/original.png -------------------------------------------------------------------------------- /VCTK-Corpus-0.92/vctk-silence-labels/assets/trim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/VCTK-Corpus-0.92/vctk-silence-labels/assets/trim.png -------------------------------------------------------------------------------- /VCTK-Corpus-0.92/vctk-silence-labels/assets/trimpad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/VCTK-Corpus-0.92/vctk-silence-labels/assets/trimpad.png -------------------------------------------------------------------------------- /VCTK-Corpus-0.92/wav16/test/readme.txt: -------------------------------------------------------------------------------- 1 | Move all the test files here. 2 | -------------------------------------------------------------------------------- /VCTK-Corpus-0.92/wav16/train/readme.txt: -------------------------------------------------------------------------------- 1 | Move all the training files here. 2 | -------------------------------------------------------------------------------- /VCTK-Corpus-0.92/wav48/test/readme.txt: -------------------------------------------------------------------------------- 1 | Move all the test files here. 2 | -------------------------------------------------------------------------------- /VCTK-Corpus-0.92/wav48/train/readme.txt: -------------------------------------------------------------------------------- 1 | Move all the training files here. 
2 | -------------------------------------------------------------------------------- /cal_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torchaudio 5 | import numpy as np 6 | from rich.progress import track 7 | 8 | def stft(audio, n_fft=2048, hop_length=512): 9 | hann_window = torch.hann_window(n_fft).to(audio.device) 10 | stft_spec = torch.stft(audio, n_fft, hop_length, window=hann_window, return_complex=True) 11 | stft_mag = torch.abs(stft_spec) 12 | stft_pha = torch.angle(stft_spec) 13 | 14 | return stft_mag, stft_pha 15 | 16 | 17 | def cal_snr(pred, target): 18 | snr = (20 * torch.log10(torch.norm(target, dim=-1) / torch.norm(pred - target, dim=-1).clamp(min=1e-8))).mean() 19 | return snr 20 | 21 | 22 | def cal_lsd(pred, target): 23 | sp = torch.log10(stft(pred)[0].square().clamp(1e-8)) 24 | st = torch.log10(stft(target)[0].square().clamp(1e-8)) 25 | return (sp - st).square().mean(dim=1).sqrt().mean() 26 | 27 | 28 | def anti_wrapping_function(x): 29 | return x - torch.round(x / (2 * np.pi)) * 2 * np.pi 30 | 31 | 32 | def cal_apd(pred, target): 33 | pha_pred = stft(pred)[1] 34 | pha_target = stft(target)[1] 35 | dim_freq = 1025 36 | dim_time = pha_pred.size(-1) 37 | 38 | gd_matrix = (torch.triu(torch.ones(dim_freq, dim_freq), diagonal=1) - torch.triu(torch.ones(dim_freq, dim_freq), diagonal=2) - torch.eye(dim_freq)).to(device) 39 | gd_r = torch.matmul(pha_target.permute(0, 2, 1), gd_matrix) 40 | gd_g = torch.matmul(pha_pred.permute(0, 2, 1), gd_matrix) 41 | 42 | iaf_matrix = (torch.triu(torch.ones(dim_time, dim_time), diagonal=1) - torch.triu(torch.ones(dim_time, dim_time), diagonal=2) - torch.eye(dim_time)).to(device) 43 | iaf_r = torch.matmul(pha_target, iaf_matrix) 44 | iaf_g = torch.matmul(pha_pred, iaf_matrix) 45 | 46 | apd_ip = anti_wrapping_function(pha_pred - pha_target).square().mean(dim=1).sqrt().mean() 47 | apd_gd = anti_wrapping_function(gd_r - gd_g).square().mean(dim=1).sqrt().mean() 48 | apd_iaf = anti_wrapping_function(iaf_r - iaf_g).square().mean(dim=1).sqrt().mean() 49 | 50 | return apd_ip, apd_gd, apd_iaf 51 | 52 | 53 | def main(h): 54 | 55 | wav_indexes = os.listdir(h.reference_wav_dir) 56 | 57 | metrics = {'lsd':[], 'apd_ip': [], 'apd_gd': [], 'apd_iaf': [], 'snr':[]} 58 | 59 | for wav_index in track(wav_indexes): 60 | 61 | ref_wav, ref_sr = torchaudio.load(os.path.join(h.reference_wav_dir, wav_index)) 62 | syn_wav, syn_sr = torchaudio.load(os.path.join(h.synthesis_wav_dir, wav_index)) 63 | 64 | length = min(ref_wav.size(1), syn_wav.size(1)) 65 | ref_wav = ref_wav[:, : length].to(device) 66 | syn_wav = syn_wav[:, : length].to(device) 67 | ref_wav = ref_wav.to(device) 68 | syn_wav = syn_wav[:, : ref_wav.size(1)].to(device) 69 | 70 | lsd_score = cal_lsd(syn_wav, ref_wav) 71 | apd_score = cal_apd(syn_wav, ref_wav) 72 | snr_score = cal_snr(syn_wav, ref_wav) 73 | 74 | 75 | metrics['lsd'].append(lsd_score) 76 | metrics['apd_ip'].append(apd_score[0]) 77 | metrics['apd_gd'].append(apd_score[1]) 78 | metrics['apd_iaf'].append(apd_score[2]) 79 | metrics['snr'].append(snr_score) 80 | 81 | 82 | lsd_mean = torch.stack(metrics['lsd'], dim=0).mean() 83 | apd_ip_mean = torch.stack(metrics['apd_ip'], dim=0).mean() 84 | apd_gd_mean = torch.stack(metrics['apd_gd'], dim=0).mean() 85 | apd_iaf_mean = torch.stack(metrics['apd_iaf'], dim=0).mean() 86 | snr_mean = torch.stack(metrics['snr'], dim=0).mean() 87 | 88 | print('LSD: {:.3f}'.format(lsd_mean)) 89 | print('SNR: 
{:.3f}'.format(snr_mean)) 90 | print('APD_IP: {:.3f}'.format(apd_ip_mean)) 91 | print('APD_GD: {:.3f}'.format(apd_gd_mean)) 92 | print('APD_IAF: {:.3f}'.format(apd_iaf_mean)) 93 | 94 | if __name__ == '__main__': 95 | parser = argparse.ArgumentParser() 96 | 97 | parser.add_argument('--reference_wav_dir', default='VCTK-Corpus-0.92/wav16/test') 98 | parser.add_argument('--synthesis_wav_dir', default='generated_files/AP-BWE') 99 | 100 | h = parser.parse_args() 101 | 102 | global device 103 | if torch.cuda.is_available(): 104 | device = torch.device('cuda') 105 | else: 106 | device = torch.device('cpu') 107 | 108 | main(h) -------------------------------------------------------------------------------- /cal_visqol_48k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import librosa 5 | from visqol import visqol_lib_py 6 | from visqol.pb2 import visqol_config_pb2 7 | from visqol.pb2 import similarity_result_pb2 8 | from rich.progress import track 9 | from joblib import Parallel, delayed 10 | 11 | config = visqol_config_pb2.VisqolConfig() 12 | 13 | def cal_vq(reference, degraded, mode='audio'): 14 | if mode == "audio": 15 | config.audio.sample_rate = 48000 16 | config.options.use_speech_scoring = False 17 | svr_model_path = "libsvm_nu_svr_model.txt" 18 | elif mode == "speech": 19 | config.audio.sample_rate = 16000 20 | config.options.use_speech_scoring = True 21 | svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite" 22 | else: 23 | raise ValueError(f"Unrecognized mode: {mode}") 24 | 25 | config.options.svr_model_path = os.path.join( 26 | os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path) 27 | 28 | api = visqol_lib_py.VisqolApi() 29 | 30 | api.Create(config) 31 | 32 | similarity_result = api.Measure(reference, degraded) 33 | 34 | return similarity_result.moslqo 35 | 36 | 37 | def main(h): 38 | # with open(h.test_file, 'r', encoding='utf-8') as fi: 39 | # wav_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0] 40 | wav_indexes = os.listdir(h.ref_wav_dir) 41 | 42 | metrics = {'vq':[]} 43 | 44 | for wav_index in track(wav_indexes): 45 | 46 | ref_wav, ref_sr = librosa.load(os.path.join(h.ref_wav_dir, wav_index), sr=float(h.sampling_rate), dtype=np.float64) 47 | syn_wav, syn_sr = librosa.load(os.path.join(h.syn_wav_dir, wav_index), sr=float(h.sampling_rate), dtype=np.float64) 48 | 49 | if float(h.sampling_rate) != 48000: 50 | ref_wav = librosa.resample(ref_wav, orig_sr=float(h.sampling_rate), target_sr=48000) 51 | syn_wav = librosa.resample(syn_wav, orig_sr=float(h.sampling_rate), target_sr=48000) 52 | 53 | length = min(len(ref_wav), len(syn_wav)) 54 | ref_wav = ref_wav[: length] 55 | syn_wav = syn_wav[: length] 56 | try: 57 | vq_score = cal_vq(ref_wav, syn_wav) 58 | metrics['vq'].append(vq_score) 59 | except: 60 | vq_score = 0 61 | 62 | vq_mean = np.mean(metrics['vq']) 63 | 64 | print('VISQOL: {:.3f}'.format(vq_mean)) 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--sampling_rate', required=True) 69 | parser.add_argument('--ref_wav_dir', required=True) 70 | parser.add_argument('--syn_wav_dir', required=True) 71 | 72 | h = parser.parse_args() 73 | 74 | main(h) -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | Download the 
[pretrained weights](https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link) and move the checkpoints files here. 2 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /configs/config_12kto48k.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_gpus": 0, 3 | "batch_size": 16, 4 | "learning_rate": 0.0002, 5 | "adam_b1": 0.8, 6 | "adam_b2": 0.99, 7 | "lr_decay": 0.999, 8 | "seed": 1234, 9 | 10 | "ConvNeXt_channels": 512, 11 | "ConvNeXt_layers": 8, 12 | 13 | "segment_size": 8000, 14 | "n_fft": 1024, 15 | "hop_size": 80, 16 | "win_size": 320, 17 | 18 | "hr_sampling_rate": 48000, 19 | "lr_sampling_rate": 12000, 20 | "subsampling_rate": 4, 21 | 22 | "num_workers": 4, 23 | 24 | "dist_config": { 25 | "dist_backend": "nccl", 26 | "dist_url": "tcp://localhost:54321", 27 | "world_size": 1 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /configs/config_16kto48k.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_gpus": 0, 3 | "batch_size": 16, 4 | "learning_rate": 0.0002, 5 | "adam_b1": 0.8, 6 | "adam_b2": 0.99, 7 | "lr_decay": 0.999, 8 | "seed": 1234, 9 | 10 | "ConvNeXt_channels": 512, 11 | "ConvNeXt_layers": 8, 12 | 13 | "segment_size": 8000, 14 | "n_fft": 1024, 15 | "hop_size": 80, 16 | "win_size": 320, 17 | 18 | "hr_sampling_rate": 48000, 19 | "lr_sampling_rate": 16000, 20 | 21 | "num_workers": 4, 22 | 23 | "dist_config": { 24 | "dist_backend": "nccl", 25 | "dist_url": "tcp://localhost:54321", 26 | "world_size": 1 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /configs/config_24kto48k.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_gpus": 0, 3 | "batch_size": 16, 4 | "learning_rate": 0.0002, 5 | "adam_b1": 0.8, 6 | "adam_b2": 0.99, 7 | "lr_decay": 0.999, 8 | "seed": 1234, 9 | 10 | "ConvNeXt_channels": 512, 11 | "ConvNeXt_layers": 8, 12 | 13 | "segment_size": 8000, 14 | "n_fft": 1024, 15 | "hop_size": 80, 16 | "win_size": 320, 17 | 18 | "hr_sampling_rate": 48000, 19 | "lr_sampling_rate": 24000, 20 | "subsampling_rate": 2, 21 | 22 | "num_workers": 4, 23 | 24 | "dist_config": { 25 | "dist_backend": "nccl", 26 | "dist_url": "tcp://localhost:54321", 27 | "world_size": 1 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /configs/config_2kto16k.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_gpus": 0, 3 | "batch_size": 16, 4 | "learning_rate": 0.0002, 5 | "adam_b1": 0.8, 6 | "adam_b2": 0.99, 7 | "lr_decay": 0.999, 8 | "seed": 1234, 9 | 10 | "ConvNeXt_channels": 512, 11 | "ConvNeXt_layers": 8, 12 | 13 | "segment_size": 8000, 14 | "n_fft": 1024, 15 | "hop_size": 80, 16 | "win_size": 320, 17 | 18 | "hr_sampling_rate": 16000, 19 | "lr_sampling_rate": 2000, 20 | 21 | "num_workers": 4, 22 | 23 | "dist_config": { 24 | "dist_backend": "nccl", 25 | "dist_url": "tcp://localhost:54321", 26 | "world_size": 1 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /configs/config_4kto16k.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "num_gpus": 0, 3 | "batch_size": 16, 4 | "learning_rate": 0.0002, 5 | "adam_b1": 0.8, 6 | "adam_b2": 0.99, 7 | "lr_decay": 0.999, 8 | "seed": 1234, 9 | 10 | "ConvNeXt_channels": 512, 11 | "ConvNeXt_layers": 8, 12 | 13 | "segment_size": 8000, 14 | "n_fft": 1024, 15 | "hop_size": 80, 16 | "win_size": 320, 17 | 18 | "hr_sampling_rate": 16000, 19 | "lr_sampling_rate": 4000, 20 | 21 | "num_workers": 4, 22 | 23 | "dist_config": { 24 | "dist_backend": "nccl", 25 | "dist_url": "tcp://localhost:54321", 26 | "world_size": 1 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /configs/config_8kto16k.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_gpus": 0, 3 | "batch_size": 16, 4 | "learning_rate": 0.0002, 5 | "adam_b1": 0.8, 6 | "adam_b2": 0.99, 7 | "lr_decay": 0.999, 8 | "seed": 1234, 9 | 10 | "ConvNeXt_channels": 512, 11 | "ConvNeXt_layers": 8, 12 | 13 | "segment_size": 8000, 14 | "n_fft": 1024, 15 | "hop_size": 80, 16 | "win_size": 320, 17 | 18 | "hr_sampling_rate": 16000, 19 | "lr_sampling_rate": 8000, 20 | 21 | "num_workers": 4, 22 | 23 | "dist_config": { 24 | "dist_backend": "nccl", 25 | "dist_url": "tcp://localhost:54321", 26 | "world_size": 1 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /configs/config_8kto48k.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_gpus": 0, 3 | "batch_size": 16, 4 | "learning_rate": 0.0002, 5 | "adam_b1": 0.8, 6 | "adam_b2": 0.99, 7 | "lr_decay": 0.999, 8 | "seed": 1234, 9 | 10 | "ConvNeXt_channels": 512, 11 | "ConvNeXt_layers": 8, 12 | 13 | "segment_size": 8000, 14 | "n_fft": 1024, 15 | "hop_size": 80, 16 | "win_size": 320, 17 | 18 | "hr_sampling_rate": 48000, 19 | "lr_sampling_rate": 8000, 20 | "subsampling_rate": 6, 21 | 22 | "num_workers": 4, 23 | 24 | "dist_config": { 25 | "dist_backend": "nccl", 26 | "dist_url": "tcp://localhost:54321", 27 | "world_size": 1 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torchaudio 5 | import torch.utils.data 6 | import torchaudio.functional as aF 7 | 8 | def amp_pha_stft(audio, n_fft, hop_size, win_size, center=True): 9 | 10 | hann_window = torch.hann_window(win_size).to(audio.device) 11 | stft_spec = torch.stft(audio, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, 12 | center=center, pad_mode='reflect', normalized=False, return_complex=True) 13 | log_amp = torch.log(torch.abs(stft_spec)+1e-4) 14 | pha = torch.angle(stft_spec) 15 | 16 | com = torch.stack((torch.exp(log_amp)*torch.cos(pha), 17 | torch.exp(log_amp)*torch.sin(pha)), dim=-1) 18 | 19 | return log_amp, pha, com 20 | 21 | 22 | def amp_pha_istft(log_amp, pha, n_fft, hop_size, win_size, center=True): 23 | 24 | amp = torch.exp(log_amp) 25 | com = torch.complex(amp*torch.cos(pha), amp*torch.sin(pha)) 26 | hann_window = torch.hann_window(win_size).to(com.device) 27 | audio = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, 
window=hann_window, center=center) 28 | 29 | return audio 30 | 31 | 32 | def get_dataset_filelist(a): 33 | with open(a.input_training_file, 'r', encoding='utf-8') as fi: 34 | training_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0] 35 | 36 | with open(a.input_validation_file, 'r', encoding='utf-8') as fi: 37 | validation_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0] 38 | 39 | return training_indexes, validation_indexes 40 | 41 | 42 | class Dataset(torch.utils.data.Dataset): 43 | def __init__(self, training_indexes, wavs_dir, segment_size, hr_sampling_rate, lr_sampling_rate, 44 | split=True, shuffle=True, n_cache_reuse=1, device=None): 45 | self.audio_indexes = training_indexes 46 | random.seed(1234) 47 | if shuffle: 48 | random.shuffle(self.audio_indexes) 49 | self.wavs_dir = wavs_dir 50 | self.segment_size = segment_size 51 | self.hr_sampling_rate = hr_sampling_rate 52 | self.lr_sampling_rate = lr_sampling_rate 53 | self.split = split 54 | self.cached_wav = None 55 | self.n_cache_reuse = n_cache_reuse 56 | self._cache_ref_count = 0 57 | self.device = device 58 | 59 | def __getitem__(self, index): 60 | filename = self.audio_indexes[index] 61 | if self._cache_ref_count == 0: 62 | audio, orig_sampling_rate = torchaudio.load(os.path.join(self.wavs_dir, filename + '.wav')) 63 | self.cached_wav = audio 64 | self._cache_ref_count = self.n_cache_reuse 65 | else: 66 | audio = self.cached_wav 67 | self._cache_ref_count -= 1 68 | 69 | if orig_sampling_rate == self.hr_sampling_rate: 70 | audio_hr = audio 71 | else: 72 | audio_hr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.hr_sampling_rate) 73 | 74 | audio_lr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.lr_sampling_rate) 75 | audio_lr = aF.resample(audio_lr, orig_freq=self.lr_sampling_rate, new_freq=self.hr_sampling_rate) 76 | audio_lr = audio_lr[:, : audio_hr.size(1)] 77 | 78 | if self.split: 79 | if audio_hr.size(1) >= self.segment_size: 80 | max_audio_start = audio_hr.size(1) - self.segment_size 81 | audio_start = random.randint(0, max_audio_start) 82 | audio_hr = audio_hr[:, audio_start: audio_start+self.segment_size] 83 | audio_lr = audio_lr[:, audio_start: audio_start+self.segment_size] 84 | else: 85 | audio_hr = torch.nn.functional.pad(audio_hr, (0, self.segment_size - audio_hr.size(1)), 'constant') 86 | audio_lr = torch.nn.functional.pad(audio_lr, (0, self.segment_size - audio_lr.size(1)), 'constant') 87 | 88 | return (audio_hr.squeeze(), audio_lr.squeeze()) 89 | 90 | def __len__(self): 91 | 92 | return len(self.audio_indexes) 93 | -------------------------------------------------------------------------------- /docs/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center 
center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /docs/css/bulma-slider.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}input[type=range].slider{-webkit-appearance:none;-moz-appearance:none;appearance:none;margin:1rem 0;background:0 
0;touch-action:none}input[type=range].slider.is-fullwidth{display:block;width:100%}input[type=range].slider:focus{outline:0}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{width:100%}input[type=range].slider:not([orient=vertical])::-moz-range-track{width:100%}input[type=range].slider:not([orient=vertical])::-ms-track{width:100%}input[type=range].slider:not([orient=vertical]).has-output+output,input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{width:3rem;background:#4a4a4a;border-radius:4px;padding:.4rem .8rem;font-size:.75rem;line-height:.75rem;text-align:center;text-overflow:ellipsis;white-space:nowrap;color:#fff;overflow:hidden;pointer-events:none;z-index:200}input[type=range].slider:not([orient=vertical]).has-output-tooltip:disabled+output,input[type=range].slider:not([orient=vertical]).has-output:disabled+output{opacity:.5}input[type=range].slider:not([orient=vertical]).has-output{display:inline-block;vertical-align:middle;width:calc(100% - (4.2rem))}input[type=range].slider:not([orient=vertical]).has-output+output{display:inline-block;margin-left:.75rem;vertical-align:middle}input[type=range].slider:not([orient=vertical]).has-output-tooltip{display:block}input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{position:absolute;left:0;top:-.1rem}input[type=range].slider[orient=vertical]{-webkit-appearance:slider-vertical;-moz-appearance:slider-vertical;appearance:slider-vertical;-webkit-writing-mode:bt-lr;-ms-writing-mode:bt-lr;writing-mode:bt-lr}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{height:100%}input[type=range].slider[orient=vertical]::-moz-range-track{height:100%}input[type=range].slider[orient=vertical]::-ms-track{height:100%}input[type=range].slider::-webkit-slider-runnable-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-moz-range-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-fill-lower{background:#dbdbdb;border-radius:4px}input[type=range].slider::-ms-fill-upper{background:#dbdbdb;border-radius:4px}input[type=range].slider::-webkit-slider-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-moz-range-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-ms-thumb{box-shadow:none;border:1px solid 
#b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-webkit-slider-thumb{-webkit-appearance:none;appearance:none}input[type=range].slider.is-circle::-webkit-slider-thumb{border-radius:290486px}input[type=range].slider.is-circle::-moz-range-thumb{border-radius:290486px}input[type=range].slider.is-circle::-ms-thumb{border-radius:290486px}input[type=range].slider:active::-webkit-slider-thumb{-webkit-transform:scale(1.25);transform:scale(1.25)}input[type=range].slider:active::-moz-range-thumb{transform:scale(1.25)}input[type=range].slider:active::-ms-thumb{transform:scale(1.25)}input[type=range].slider:disabled{opacity:.5;cursor:not-allowed}input[type=range].slider:disabled::-webkit-slider-thumb{cursor:not-allowed;-webkit-transform:scale(1);transform:scale(1)}input[type=range].slider:disabled::-moz-range-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:disabled::-ms-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:not([orient=vertical]){min-height:calc((1rem + 2px) * 1.25)}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-moz-range-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-ms-track{height:.5rem}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{width:.5rem}input[type=range].slider[orient=vertical]::-moz-range-track{width:.5rem}input[type=range].slider[orient=vertical]::-ms-track{width:.5rem}input[type=range].slider::-webkit-slider-thumb{height:1rem;width:1rem}input[type=range].slider::-moz-range-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{margin-top:0}input[type=range].slider::-webkit-slider-thumb{margin-top:-.25rem}input[type=range].slider[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.25rem}input[type=range].slider.is-small:not([orient=vertical]){min-height:calc((.75rem + 2px) * 1.25)}input[type=range].slider.is-small:not([orient=vertical])::-webkit-slider-runnable-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-moz-range-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-ms-track{height:.375rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-runnable-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-moz-range-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-ms-track{width:.375rem}input[type=range].slider.is-small::-webkit-slider-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-moz-range-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{margin-top:0}input[type=range].slider.is-small::-webkit-slider-thumb{margin-top:-.1875rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.1875rem}input[type=range].slider.is-medium:not([orient=vertical]){min-height:calc((1.25rem + 2px) * 
1.25)}input[type=range].slider.is-medium:not([orient=vertical])::-webkit-slider-runnable-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-moz-range-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-ms-track{height:.625rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-runnable-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-moz-range-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-ms-track{width:.625rem}input[type=range].slider.is-medium::-webkit-slider-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-moz-range-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{margin-top:0}input[type=range].slider.is-medium::-webkit-slider-thumb{margin-top:-.3125rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.3125rem}input[type=range].slider.is-large:not([orient=vertical]){min-height:calc((1.5rem + 2px) * 1.25)}input[type=range].slider.is-large:not([orient=vertical])::-webkit-slider-runnable-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-moz-range-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-ms-track{height:.75rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-runnable-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-moz-range-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-ms-track{width:.75rem}input[type=range].slider.is-large::-webkit-slider-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-moz-range-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{margin-top:0}input[type=range].slider.is-large::-webkit-slider-thumb{margin-top:-.375rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.375rem}input[type=range].slider.is-white::-moz-range-track{background:#fff!important}input[type=range].slider.is-white::-webkit-slider-runnable-track{background:#fff!important}input[type=range].slider.is-white::-ms-track{background:#fff!important}input[type=range].slider.is-white::-ms-fill-lower{background:#fff}input[type=range].slider.is-white::-ms-fill-upper{background:#fff}input[type=range].slider.is-white .has-output-tooltip+output,input[type=range].slider.is-white.has-output+output{background-color:#fff;color:#0a0a0a}input[type=range].slider.is-black::-moz-range-track{background:#0a0a0a!important}input[type=range].slider.is-black::-webkit-slider-runnable-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-fill-lower{background:#0a0a0a}input[type=range].slider.is-black::-ms-fill-upper{background:#0a0a0a}input[type=range].slider.is-black 
.has-output-tooltip+output,input[type=range].slider.is-black.has-output+output{background-color:#0a0a0a;color:#fff}input[type=range].slider.is-light::-moz-range-track{background:#f5f5f5!important}input[type=range].slider.is-light::-webkit-slider-runnable-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-fill-lower{background:#f5f5f5}input[type=range].slider.is-light::-ms-fill-upper{background:#f5f5f5}input[type=range].slider.is-light .has-output-tooltip+output,input[type=range].slider.is-light.has-output+output{background-color:#f5f5f5;color:#363636}input[type=range].slider.is-dark::-moz-range-track{background:#363636!important}input[type=range].slider.is-dark::-webkit-slider-runnable-track{background:#363636!important}input[type=range].slider.is-dark::-ms-track{background:#363636!important}input[type=range].slider.is-dark::-ms-fill-lower{background:#363636}input[type=range].slider.is-dark::-ms-fill-upper{background:#363636}input[type=range].slider.is-dark .has-output-tooltip+output,input[type=range].slider.is-dark.has-output+output{background-color:#363636;color:#f5f5f5}input[type=range].slider.is-primary::-moz-range-track{background:#00d1b2!important}input[type=range].slider.is-primary::-webkit-slider-runnable-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-fill-lower{background:#00d1b2}input[type=range].slider.is-primary::-ms-fill-upper{background:#00d1b2}input[type=range].slider.is-primary .has-output-tooltip+output,input[type=range].slider.is-primary.has-output+output{background-color:#00d1b2;color:#fff}input[type=range].slider.is-link::-moz-range-track{background:#3273dc!important}input[type=range].slider.is-link::-webkit-slider-runnable-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-fill-lower{background:#3273dc}input[type=range].slider.is-link::-ms-fill-upper{background:#3273dc}input[type=range].slider.is-link .has-output-tooltip+output,input[type=range].slider.is-link.has-output+output{background-color:#3273dc;color:#fff}input[type=range].slider.is-info::-moz-range-track{background:#209cee!important}input[type=range].slider.is-info::-webkit-slider-runnable-track{background:#209cee!important}input[type=range].slider.is-info::-ms-track{background:#209cee!important}input[type=range].slider.is-info::-ms-fill-lower{background:#209cee}input[type=range].slider.is-info::-ms-fill-upper{background:#209cee}input[type=range].slider.is-info .has-output-tooltip+output,input[type=range].slider.is-info.has-output+output{background-color:#209cee;color:#fff}input[type=range].slider.is-success::-moz-range-track{background:#23d160!important}input[type=range].slider.is-success::-webkit-slider-runnable-track{background:#23d160!important}input[type=range].slider.is-success::-ms-track{background:#23d160!important}input[type=range].slider.is-success::-ms-fill-lower{background:#23d160}input[type=range].slider.is-success::-ms-fill-upper{background:#23d160}input[type=range].slider.is-success 
.has-output-tooltip+output,input[type=range].slider.is-success.has-output+output{background-color:#23d160;color:#fff}input[type=range].slider.is-warning::-moz-range-track{background:#ffdd57!important}input[type=range].slider.is-warning::-webkit-slider-runnable-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-fill-lower{background:#ffdd57}input[type=range].slider.is-warning::-ms-fill-upper{background:#ffdd57}input[type=range].slider.is-warning .has-output-tooltip+output,input[type=range].slider.is-warning.has-output+output{background-color:#ffdd57;color:rgba(0,0,0,.7)}input[type=range].slider.is-danger::-moz-range-track{background:#ff3860!important}input[type=range].slider.is-danger::-webkit-slider-runnable-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-fill-lower{background:#ff3860}input[type=range].slider.is-danger::-ms-fill-upper{background:#ff3860}input[type=range].slider.is-danger .has-output-tooltip+output,input[type=range].slider.is-danger.has-output+output{background-color:#ff3860;color:#fff} -------------------------------------------------------------------------------- /docs/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | padding: 20px; 121 | font-size: 0; 122 | } 123 | 124 | .results-carousel video { 125 | margin: 0; 126 | } 127 | 128 | 
.slider-pagination .slider-page { 129 | background: #000000; 130 | } 131 | 132 | .eql-cntrb { 133 | font-size: smaller; 134 | } 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | AP-BWE 5 | 6 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 |
28 |
29 |
30 |
31 |
32 |

Towards Efficient and High-Quality Bandwidth Extension with Parallel Amplitude-Phase Prediction

33 |
34 | 35 | 36 | Ye-Xin Lu, 37 | 38 | Yang Ai, 39 | 40 | Hui-Peng Du, 41 | 42 | Zhen-Hua Ling, 43 |
44 | 45 |
46 | National Engineering Research Center of Speech and Language Information Processing
University of Science and Technology of China
47 |
48 | 49 |
50 | 72 |
73 | 74 |
75 |
76 |
77 |
78 |
79 | 80 | 81 |
82 |
83 |
84 |
85 |

Abstract

86 |
87 |

88 | Speech bandwidth extension (BWE) refers to increasing the bandwidth range of speech signals, enhancing the speech quality so that it sounds brighter and fuller. 89 | This paper proposes a generative adversarial network (GAN) based BWE model with parallel prediction of Amplitude and Phase spectra, named AP-BWE, which achieves both efficient and high-quality wideband waveform generation. 90 | Notably, to our knowledge, AP-BWE is the first to achieve the direct extension of the high-frequency phase spectrum, which is beneficial for improving the effectiveness of existing BWE methods. 91 | The proposed AP-BWE generator is entirely based on convolutional neural networks (CNNs). It features a dual-stream architecture with mutual interaction, where the amplitude stream and the phase stream communicate with each other and respectively extend the high-frequency components from the narrowband amplitude and phase spectra. 92 | To improve the naturalness of the extended speech signals, we employ a multi-period discriminator at the waveform level and design a pair of multi-resolution amplitude and phase discriminators at the spectral level. 93 | Experimental results demonstrate that our proposed AP-BWE achieves state-of-the-art performance in speech quality for BWE tasks targeting sampling rates of both 16 kHz and 48 kHz. 94 | In terms of generation efficiency, due to the all-convolutional architecture and all-frame-level operations, the proposed AP-BWE can generate 48 kHz waveform samples 292.3 times faster than real-time on a single RTX 4090 GPU and 18.1 times faster than real-time on a single CPU. 95 |

96 |
97 |
98 |
99 |
100 |
101 | 102 | 103 |
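The dual-stream generator summarized in the abstract is easiest to see in code. Below is a minimal editor's sketch, not a file from this repository: the class name TinyDualStreamBWE and its hyperparameters are invented for illustration, and plain residual convolutions stand in for the ConvNeXt blocks used by the actual generator in models/model.py further down in this dump. What the sketch preserves is the wiring the abstract describes: two parallel streams over log-amplitude and phase spectra, cross-stream additions before every block, residual prediction of the wideband amplitude, and phase recovery via atan2 of predicted real and imaginary parts.

import torch
import torch.nn as nn

class TinyDualStreamBWE(nn.Module):
    """Toy dual-stream amplitude/phase network (illustration only)."""
    def __init__(self, n_freq: int = 513, channels: int = 64, layers: int = 4):
        super().__init__()
        self.pre_mag = nn.Conv1d(n_freq, channels, 7, padding=3)
        self.pre_pha = nn.Conv1d(n_freq, channels, 7, padding=3)
        self.blocks_mag = nn.ModuleList([nn.Conv1d(channels, channels, 7, padding=3) for _ in range(layers)])
        self.blocks_pha = nn.ModuleList([nn.Conv1d(channels, channels, 7, padding=3) for _ in range(layers)])
        self.post_mag = nn.Conv1d(channels, n_freq, 1)
        # Phase is recovered from parallel real/imaginary estimates via atan2,
        # so the output always lies in (-pi, pi] without explicit wrapping.
        self.post_pha_r = nn.Conv1d(channels, n_freq, 1)
        self.post_pha_i = nn.Conv1d(channels, n_freq, 1)

    def forward(self, mag_nb, pha_nb):
        # Inputs: narrowband log-amplitude and phase spectra, shape (B, F, T).
        x_mag = self.pre_mag(mag_nb)
        x_pha = self.pre_pha(pha_nb)
        for block_mag, block_pha in zip(self.blocks_mag, self.blocks_pha):
            # Mutual interaction: each stream receives the other's features
            # before running its own convolutional block.
            x_mag = x_mag + x_pha
            x_pha = x_pha + x_mag
            x_mag = x_mag + torch.relu(block_mag(x_mag))  # residual block
            x_pha = x_pha + torch.relu(block_pha(x_pha))
        # The amplitude stream predicts a residual on top of the narrowband input.
        mag_wb = mag_nb + self.post_mag(x_mag)
        pha_wb = torch.atan2(self.post_pha_i(x_pha), self.post_pha_r(x_pha))
        return mag_wb, pha_wb

# Smoke test on random spectra: 513 frequency bins (n_fft = 1024), 100 STFT frames.
mag = torch.randn(1, 513, 100)
pha = (torch.rand(1, 513, 100) * 2 - 1) * torch.pi
mag_wb, pha_wb = TinyDualStreamBWE()(mag, pha)
print(mag_wb.shape, pha_wb.shape)  # both torch.Size([1, 513, 100])

The atan2 parameterization is the same device models/model.py uses (linear_post_pha_r, linear_post_pha_i): predicting a real and an imaginary component and taking their angle keeps every output phase naturally bounded, and it is what makes frame-level direct phase extension workable.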
I. Audio Samples with Target Sampling Rate of 16 kHz

[Demo table stripped in extraction; each cell originally embedded a spectrogram image and an audio player. Recoverable structure:]
Columns: Wideband | Sinc | TFiLM | AFiLM | NVSR | AP-BWE (Ours)
Rows (input sr): 2 kHz | 4 kHz | 8 kHz

II. Audio Samples with Target Sampling Rate of 48 kHz

Columns: Wideband | Sinc | NU-Wave 2 | UDM+ | mdctGAN | AP-BWE (Ours)
Rows (input sr): 8 kHz | 12 kHz | 16 kHz | 24 kHz

III. Ablation Study (8 kHz to 48 kHz)

Systems compared: Wideband | AP-BWE (Ours) | w/o MPD | w/o MRAD | w/o MRPD | MPD Only |
MRAD Only | MRPD Only | w/o Disc. | w/o A to P | w/o P to A | w/o Connections

IV. Cross-Dataset Evaluation

Libri-TTS (8 kHz to 24 kHz)
Systems compared: Wideband | NU-Wave2 | UDM+ | mdctGAN | AP-BWE (Ours)

HiFi-TTS (8 kHz to 44.1 kHz)
Systems compared: Wideband | NU-Wave2 | UDM+ | mdctGAN | AP-BWE (Ours)
442 | 443 | 444 | 445 | 446 | 447 | 448 | 449 | 450 | 451 | -------------------------------------------------------------------------------- /docs/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | 4 | $(document).ready(function() { 5 | // Check for click events on the navbar burger icon 6 | 7 | var options = { 8 | slidesToScroll: 1, 9 | slidesToShow: 1, 10 | loop: true, 11 | infinite: true, 12 | autoplay: false, 13 | autoplaySpeed: 3000, 14 | } 15 | 16 | // Initialize all div with carousel class 17 | var carousels = bulmaCarousel.attach('.carousel', options); 18 | 19 | // Loop on each carousel initialized 20 | for(var i = 0; i < carousels.length; i++) { 21 | // Add listener to event 22 | carousels[i].on('before:show', state => { 23 | console.log(state); 24 | }); 25 | } 26 | 27 | // Access to bulmaCarousel instance of an element 28 | var element = document.querySelector('#my-element'); 29 | if (element && element.bulmaCarousel) { 30 | // bulmaCarousel instance is available as element.bulmaCarousel 31 | element.bulmaCarousel.on('before-show', function(state) { 32 | console.log(state); 33 | }); 34 | } 35 | 36 | /*var player = document.getElementById('interpolation-video'); 37 | player.addEventListener('loadedmetadata', function() { 38 | $('#interpolation-slider').on('input', function(event) { 39 | console.log(this.value, player.duration); 40 | player.currentTime = player.duration / 100 * this.value; 41 | }) 42 | }, false);*/ 43 | 44 | bulmaSlider.attach(); 45 | 46 | }) 47 | -------------------------------------------------------------------------------- /docs/js/main.js: -------------------------------------------------------------------------------- 1 | (()=>{"use strict";var e={r:e=>{"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})}},t={};e.r(t);const c=Object.freeze({latest:"3.12.1_3.2.3-63070bed",public:"3.12.1_3.2.2-0a1b32f6",legacy:"3.12.1_3.2.2-0a1b32f6",edit:"2.24.4_2.12.0-7d0a8c15"});!function loadSDKScript(){var e=arguments.length>0&&void 0!==arguments[0]?arguments[0]:"latest.js",t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:c.latest,n=arguments.length>2&&void 
0!==arguments[2]?arguments[2]:"View",r=Array.from(document.scripts),a="view-sdk/".concat(e),i=r.find((function(e){return e.src&&e.src.endsWith(a)})).src,o=i.substring(0,i.indexOf(e)),d="".concat("".concat(o+t,"/").concat(n),"SDKInterface.js"),s=document.createElement("script");s.async=!1,s.setAttribute("src",d),document.head.appendChild(s)}("main.js",c.legacy),window.adobe_dc_view_sdk=t})(); 2 | //# sourceMappingURL=3.12.1_3.2.3-63070bed/private/main.js.map -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe.png -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_mpd_only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mpd_only.png -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_mpd_only.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mpd_only.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_mrad_only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mrad_only.png -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_mrad_only.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mrad_only.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_mrpd_only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mrpd_only.png -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_mrpd_only.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mrpd_only.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_AtoP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_AtoP.png 
-------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_AtoP.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_AtoP.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_PtoA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_PtoA.png -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_PtoA.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_PtoA.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_connect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_connect.png -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_connect.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_connect.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_disc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_disc.png -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_disc.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_disc.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_mpd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mpd.png -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_mpd.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mpd.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_mrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mrad.png -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_mrad.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mrad.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_mrpd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mrpd.png -------------------------------------------------------------------------------- /docs/samples/Ablation/ap-bwe_wo_mrpd.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mrpd.wav -------------------------------------------------------------------------------- /docs/samples/Ablation/wideband.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/wideband.png -------------------------------------------------------------------------------- /docs/samples/Ablation/wideband.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/wideband.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/afilm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/afilm.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/afilm.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/afilm.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/ap-bwe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/ap-bwe.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/ap-bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/ap-bwe.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/nvsr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/nvsr.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/nvsr.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/nvsr.wav 
-------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/sinc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/sinc.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/sinc.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/sinc.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/tfilm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/tfilm.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/tfilm.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/tfilm.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/wideband.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/wideband.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/2kto16k/wideband.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/wideband.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/afilm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/afilm.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/afilm.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/afilm.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/ap-bwe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/ap-bwe.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/ap-bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/ap-bwe.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/nvsr.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/nvsr.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/nvsr.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/nvsr.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/sinc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/sinc.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/sinc.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/sinc.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/tfilm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/tfilm.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/tfilm.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/tfilm.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/wideband.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/wideband.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/4kto16k/wideband.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/wideband.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/afilm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/afilm.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/afilm.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/afilm.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/ap-bwe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/ap-bwe.png 
-------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/ap-bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/ap-bwe.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/nvsr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/nvsr.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/nvsr.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/nvsr.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/sinc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/sinc.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/sinc.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/sinc.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/tfilm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/tfilm.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/tfilm.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/tfilm.wav -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/wideband.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/wideband.png -------------------------------------------------------------------------------- /docs/samples/BWE_16k/8kto16k/wideband.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/wideband.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/ap-bwe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/ap-bwe.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/ap-bwe.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/ap-bwe.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/mdctgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/mdctgan.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/mdctgan.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/mdctgan.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/nuwave2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/nuwave2.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/nuwave2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/nuwave2.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/sinc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/sinc.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/sinc.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/sinc.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/udm+.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/udm+.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/udm+.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/udm+.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/wideband.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/wideband.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/12kto48k/wideband.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/wideband.wav 
-------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/ap-bwe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/ap-bwe.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/ap-bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/ap-bwe.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/mdctgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/mdctgan.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/mdctgan.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/mdctgan.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/nuwave2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/nuwave2.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/nuwave2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/nuwave2.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/sinc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/sinc.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/sinc.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/sinc.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/udm+.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/udm+.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/udm+.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/udm+.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/wideband.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/wideband.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/16kto48k/wideband.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/wideband.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/ap-bwe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/ap-bwe.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/ap-bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/ap-bwe.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/mdctgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/mdctgan.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/mdctgan.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/mdctgan.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/nuwave2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/nuwave2.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/nuwave2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/nuwave2.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/sinc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/sinc.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/sinc.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/sinc.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/udm+.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/udm+.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/udm+.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/udm+.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/wideband.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/wideband.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/24kto48k/wideband.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/wideband.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/ap-bwe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/ap-bwe.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/ap-bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/ap-bwe.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/mdctgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/mdctgan.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/mdctgan.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/mdctgan.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/nuwave2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/nuwave2.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/nuwave2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/nuwave2.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/sinc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/sinc.png 
-------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/sinc.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/sinc.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/udm+.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/udm+.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/udm+.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/udm+.wav -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/wideband.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/wideband.png -------------------------------------------------------------------------------- /docs/samples/BWE_48k/8kto48k/wideband.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/wideband.wav -------------------------------------------------------------------------------- /docs/samples/CrossDataset/HiFi-TTS/ap-bwe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/ap-bwe.png -------------------------------------------------------------------------------- /docs/samples/CrossDataset/HiFi-TTS/ap-bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/ap-bwe.wav -------------------------------------------------------------------------------- /docs/samples/CrossDataset/HiFi-TTS/mdctgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/mdctgan.png -------------------------------------------------------------------------------- /docs/samples/CrossDataset/HiFi-TTS/mdctgan.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/mdctgan.wav -------------------------------------------------------------------------------- /docs/samples/CrossDataset/HiFi-TTS/nuwave2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/nuwave2.png -------------------------------------------------------------------------------- /docs/samples/CrossDataset/HiFi-TTS/nuwave2.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/nuwave2.wav -------------------------------------------------------------------------------- /docs/samples/CrossDataset/HiFi-TTS/udm+.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/udm+.png -------------------------------------------------------------------------------- /docs/samples/CrossDataset/HiFi-TTS/udm+.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/udm+.wav -------------------------------------------------------------------------------- /docs/samples/CrossDataset/HiFi-TTS/wideband.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/wideband.png -------------------------------------------------------------------------------- /docs/samples/CrossDataset/HiFi-TTS/wideband.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/wideband.wav -------------------------------------------------------------------------------- /docs/samples/CrossDataset/Libri-TTS/ap-bwe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/ap-bwe.png -------------------------------------------------------------------------------- /docs/samples/CrossDataset/Libri-TTS/ap-bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/ap-bwe.wav -------------------------------------------------------------------------------- /docs/samples/CrossDataset/Libri-TTS/mdctgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/mdctgan.png -------------------------------------------------------------------------------- /docs/samples/CrossDataset/Libri-TTS/mdctgan.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/mdctgan.wav -------------------------------------------------------------------------------- /docs/samples/CrossDataset/Libri-TTS/nuwave2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/nuwave2.png -------------------------------------------------------------------------------- /docs/samples/CrossDataset/Libri-TTS/nuwave2.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/nuwave2.wav -------------------------------------------------------------------------------- /docs/samples/CrossDataset/Libri-TTS/udm+.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/udm+.png -------------------------------------------------------------------------------- /docs/samples/CrossDataset/Libri-TTS/udm+.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/udm+.wav -------------------------------------------------------------------------------- /docs/samples/CrossDataset/Libri-TTS/wideband.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/wideband.png -------------------------------------------------------------------------------- /docs/samples/CrossDataset/Libri-TTS/wideband.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/wideband.wav -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /inference/inference_16k.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | import sys 3 | sys.path.append("..") 4 | import glob 5 | import os 6 | import argparse 7 | import json 8 | from re import S 9 | import torch 10 | import time 11 | import numpy as np 12 | import torchaudio 13 | import torchaudio.functional as aF 14 | from env import AttrDict 15 | from datasets.dataset import amp_pha_stft, amp_pha_istft 16 | from models.model import APNet_BWE_Model 17 | import soundfile as sf 18 | import matplotlib.pyplot as plt 19 | from rich.progress import track 20 | 21 | h = None 22 | device = None 23 | 24 | def load_checkpoint(filepath, device): 25 | assert os.path.isfile(filepath) 26 | print("Loading '{}'".format(filepath)) 27 | checkpoint_dict = torch.load(filepath, map_location=device) 28 | print("Complete.") 29 | return checkpoint_dict 30 | 31 | def scan_checkpoint(cp_dir, prefix): 32 | pattern = os.path.join(cp_dir, prefix + '*') 33 
| cp_list = glob.glob(pattern) 34 | if len(cp_list) == 0: 35 | return '' 36 | return sorted(cp_list)[-1] 37 | 38 | def inference(a): 39 | model = APNet_BWE_Model(h).to(device) 40 | 41 | state_dict = load_checkpoint(a.checkpoint_file, device) 42 | model.load_state_dict(state_dict['generator']) 43 | 44 | test_indexes = os.listdir(a.input_wavs_dir) 45 | 46 | os.makedirs(a.output_dir, exist_ok=True) 47 | 48 | model.eval() 49 | duration_tot = 0 50 | with torch.no_grad(): 51 | for i, index in enumerate(track(test_indexes)): 52 | # print(index) 53 | audio, orig_sampling_rate = torchaudio.load(os.path.join(a.input_wavs_dir, index)) 54 | audio = audio.to(device) 55 | 56 | audio_hr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=h.hr_sampling_rate) 57 | audio_lr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=h.lr_sampling_rate) 58 | audio_lr = aF.resample(audio_lr, orig_freq=h.lr_sampling_rate, new_freq=h.hr_sampling_rate) 59 | audio_lr = audio_lr[:, : audio_hr.size(1)] 60 | 61 | amp_wb, pha_wb, com_wb = amp_pha_stft(audio_hr, h.n_fft, h.hop_size, h.win_size) 62 | 63 | pred_start = time.time() 64 | amp_nb, pha_nb, com_nb = amp_pha_stft(audio_lr, h.n_fft, h.hop_size, h.win_size) 65 | 66 | amp_wb_g, pha_wb_g, com_wb_g = model(amp_nb, pha_nb) 67 | 68 | audio_hr_g = amp_pha_istft(amp_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size) 69 | duration_tot += time.time() - pred_start 70 | 71 | output_file = os.path.join(a.output_dir, index) 72 | 73 | sf.write(output_file, audio_hr_g.squeeze().cpu().numpy(), h.hr_sampling_rate, 'PCM_16') 74 | 75 | print(duration_tot) 76 | 77 | 78 | def main(): 79 | print('Initializing Inference Process..') 80 | 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--input_wavs_dir', default='VCTK-Corpus-0.92/wav16/test') 83 | parser.add_argument('--output_dir', default='../generated_files') 84 | parser.add_argument('--checkpoint_file', required=True) 85 | a = parser.parse_args() 86 | 87 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') 88 | with open(config_file) as f: 89 | data = f.read() 90 | 91 | global h 92 | json_config = json.loads(data) 93 | h = AttrDict(json_config) 94 | 95 | torch.manual_seed(h.seed) 96 | global device 97 | if torch.cuda.is_available(): 98 | torch.cuda.manual_seed(h.seed) 99 | device = torch.device('cuda') 100 | else: 101 | device = torch.device('cpu') 102 | 103 | inference(a) 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | 109 | -------------------------------------------------------------------------------- /inference/inference_48k.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | import sys 3 | sys.path.append("..") 4 | import glob 5 | import os 6 | import argparse 7 | import json 8 | from re import S 9 | import torch 10 | import numpy as np 11 | import torchaudio 12 | import time 13 | import torchaudio.functional as aF 14 | from env import AttrDict 15 | from datasets.dataset import amp_pha_stft, amp_pha_istft 16 | from models.model import APNet_BWE_Model 17 | import soundfile as sf 18 | import matplotlib.pyplot as plt 19 | from rich.progress import track 20 | 21 | h = None 22 | device = None 23 | 24 | def load_checkpoint(filepath, device): 25 | assert os.path.isfile(filepath) 26 | print("Loading '{}'".format(filepath)) 27 | checkpoint_dict = torch.load(filepath, map_location=device) 28 | print("Complete.") 29 | return checkpoint_dict 30 | 31 
| def scan_checkpoint(cp_dir, prefix): 32 | pattern = os.path.join(cp_dir, prefix + '*') 33 | cp_list = glob.glob(pattern) 34 | if len(cp_list) == 0: 35 | return '' 36 | return sorted(cp_list)[-1] 37 | 38 | def inference(a): 39 | model = APNet_BWE_Model(h).to(device) 40 | 41 | state_dict = load_checkpoint(a.checkpoint_file, device) 42 | model.load_state_dict(state_dict['generator']) 43 | 44 | test_indexes = os.listdir(a.input_wavs_dir) 45 | 46 | os.makedirs(a.output_dir, exist_ok=True) 47 | 48 | model.eval() 49 | duration_tot = 0 50 | with torch.no_grad(): 51 | for i, index in enumerate(track(test_indexes)): 52 | # print(index) 53 | audio, orig_sampling_rate = torchaudio.load(os.path.join(a.input_wavs_dir, index)) 54 | audio = audio.to(device) 55 | 56 | audio_hr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=h.hr_sampling_rate) 57 | audio_lr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=h.lr_sampling_rate) 58 | audio_lr = aF.resample(audio_lr, orig_freq=h.lr_sampling_rate, new_freq=h.hr_sampling_rate) 59 | audio_lr = audio_lr[:, : audio_hr.size(1)] 60 | 61 | pred_start = time.time() 62 | amp_nb, pha_nb, com_nb = amp_pha_stft(audio_lr, h.n_fft, h.hop_size, h.win_size) 63 | 64 | amp_wb_g, pha_wb_g, com_wb_g = model(amp_nb, pha_nb) 65 | 66 | audio_hr_g = amp_pha_istft(amp_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size) 67 | duration_tot += time.time() - pred_start 68 | 69 | output_file = os.path.join(a.output_dir, index) 70 | 71 | sf.write(output_file, audio_hr_g.squeeze().cpu().numpy(), h.hr_sampling_rate, 'PCM_16') 72 | 73 | print(duration_tot) 74 | 75 | def main(): 76 | print('Initializing Inference Process..') 77 | 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('--input_wavs_dir', default='VCTK-Corpus-0.92/wav48/test') 80 | parser.add_argument('--output_dir', default='../generated_files') 81 | parser.add_argument('--checkpoint_file', required=True) 82 | a = parser.parse_args() 83 | 84 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') 85 | with open(config_file) as f: 86 | data = f.read() 87 | 88 | global h 89 | json_config = json.loads(data) 90 | h = AttrDict(json_config) 91 | 92 | torch.manual_seed(h.seed) 93 | global device 94 | if torch.cuda.is_available(): 95 | torch.cuda.manual_seed(h.seed) 96 | device = torch.device('cuda') 97 | else: 98 | device = torch.device('cpu') 99 | 100 | inference(a) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | 106 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 5 | from utils import init_weights, get_padding 6 | import numpy as np 7 | from typing import Tuple, List 8 | 9 | LRELU_SLOPE = 0.1 10 | 11 | class ConvNeXtBlock(nn.Module): 12 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. 13 | 14 | Args: 15 | dim (int): Number of input channels. 16 | intermediate_dim (int): Dimensionality of the intermediate layer. 17 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 18 | Defaults to None. 
19 |         adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
20 |             None means non-conditional LayerNorm. Defaults to None.
21 |     """
22 | 
23 |     def __init__(
24 |         self,
25 |         dim: int,
26 |         layer_scale_init_value=None,
27 |         adanorm_num_embeddings=None,
28 |     ):
29 |         super().__init__()
30 |         self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
31 |         self.adanorm = adanorm_num_embeddings is not None  # the conditional path is never taken in this repo
32 | 
33 |         self.norm = nn.LayerNorm(dim, eps=1e-6)
34 |         self.pwconv1 = nn.Linear(dim, dim*3)  # pointwise/1x1 convs, implemented with linear layers
35 |         self.act = nn.GELU()
36 |         self.pwconv2 = nn.Linear(dim*3, dim)
37 |         self.gamma = (
38 |             nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
39 |             if layer_scale_init_value is not None and layer_scale_init_value > 0
40 |             else None
41 |         )
42 | 
43 |     def forward(self, x, cond_embedding_id=None):
44 |         residual = x
45 |         x = self.dwconv(x)
46 |         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
47 |         if self.adanorm:
48 |             assert cond_embedding_id is not None
49 |             x = self.norm(x, cond_embedding_id)
50 |         else:
51 |             x = self.norm(x)
52 |         x = self.pwconv1(x)
53 |         x = self.act(x)
54 |         x = self.pwconv2(x)
55 |         if self.gamma is not None:
56 |             x = self.gamma * x
57 |         x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
58 | 
59 |         x = residual + x
60 |         return x
61 | 
62 | 
63 | class APNet_BWE_Model(torch.nn.Module):
64 |     def __init__(self, h):
65 |         super(APNet_BWE_Model, self).__init__()
66 |         self.h = h
67 |         self.adanorm_num_embeddings = None
68 |         layer_scale_init_value = 1 / h.ConvNeXt_layers
69 | 
70 |         self.conv_pre_mag = nn.Conv1d(h.n_fft//2+1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
71 |         self.norm_pre_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
72 |         self.conv_pre_pha = nn.Conv1d(h.n_fft//2+1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
73 |         self.norm_pre_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
74 | 
75 |         self.convnext_mag = nn.ModuleList(
76 |             [
77 |                 ConvNeXtBlock(
78 |                     dim=h.ConvNeXt_channels,
79 |                     layer_scale_init_value=layer_scale_init_value,
80 |                     adanorm_num_embeddings=self.adanorm_num_embeddings,
81 |                 )
82 |                 for _ in range(h.ConvNeXt_layers)
83 |             ]
84 |         )
85 | 
86 |         self.convnext_pha = nn.ModuleList(
87 |             [
88 |                 ConvNeXtBlock(
89 |                     dim=h.ConvNeXt_channels,
90 |                     layer_scale_init_value=layer_scale_init_value,
91 |                     adanorm_num_embeddings=self.adanorm_num_embeddings,
92 |                 )
93 |                 for _ in range(h.ConvNeXt_layers)
94 |             ]
95 |         )
96 | 
97 |         self.norm_post_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
98 |         self.norm_post_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
99 |         self.linear_post_mag = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
100 |         self.linear_post_pha_r = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
101 |         self.linear_post_pha_i = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
102 |         self.apply(self._init_weights)  # applied after all submodules exist, so the post-linears are initialized too
103 | 
104 |     def _init_weights(self, m):
105 |         if isinstance(m, (nn.Conv1d, nn.Linear)):
106 |             nn.init.trunc_normal_(m.weight, std=0.02)
107 |             nn.init.constant_(m.bias, 0)
108 | 
109 |     def forward(self, mag_nb, pha_nb):
110 | 
111 |         x_mag = self.conv_pre_mag(mag_nb)
112 |         x_pha = self.conv_pre_pha(pha_nb)
113 |         x_mag = self.norm_pre_mag(x_mag.transpose(1, 2)).transpose(1, 2)
114 |         x_pha = self.norm_pre_pha(x_pha.transpose(1, 2)).transpose(1, 2)
115 | 
116 |         for conv_block_mag, conv_block_pha in zip(self.convnext_mag, self.convnext_pha):
117 |             x_mag = x_mag + x_pha  # phase-to-amplitude connection
118 |             x_pha = x_pha + x_mag  # amplitude-to-phase connection (uses the already-updated x_mag)
119 |             x_mag = conv_block_mag(x_mag, cond_embedding_id=None)
120 |             x_pha = conv_block_pha(x_pha, cond_embedding_id=None)
121 | 
122 |         x_mag = self.norm_post_mag(x_mag.transpose(1, 2))
123 |         mag_wb = mag_nb + self.linear_post_mag(x_mag).transpose(1, 2)  # residual prediction on the log-amplitude spectrum
124 | 
125 |         x_pha = self.norm_post_pha(x_pha.transpose(1, 2))
126 |         x_pha_r = self.linear_post_pha_r(x_pha)
127 |         x_pha_i = self.linear_post_pha_i(x_pha)
128 |         pha_wb = torch.atan2(x_pha_i, x_pha_r).transpose(1, 2)
129 | 
130 |         com_wb = torch.stack((torch.exp(mag_wb)*torch.cos(pha_wb),
131 |                               torch.exp(mag_wb)*torch.sin(pha_wb)), dim=-1)
132 | 
133 |         return mag_wb, pha_wb, com_wb
134 | 
135 | 
136 | 
137 | class DiscriminatorP(torch.nn.Module):
138 |     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
139 |         super(DiscriminatorP, self).__init__()
140 |         self.period = period
141 |         norm_f = spectral_norm if use_spectral_norm else weight_norm
142 |         self.convs = nn.ModuleList([
143 |             norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
144 |             norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
145 |             norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
146 |             norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
147 |             norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
148 |         ])
149 |         self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
150 | 
151 |     def forward(self, x):
152 |         fmap = []
153 | 
154 |         # 1d to 2d
155 |         b, c, t = x.shape
156 |         if t % self.period != 0:  # pad first
157 |             n_pad = self.period - (t % self.period)
158 |             x = F.pad(x, (0, n_pad), "reflect")
159 |             t = t + n_pad
160 |         x = x.view(b, c, t // self.period, self.period)
161 | 
162 |         for i, l in enumerate(self.convs):
163 |             x = l(x)
164 |             x = F.leaky_relu(x, LRELU_SLOPE)
165 |             if i > 0:
166 |                 fmap.append(x)
167 |         x = self.conv_post(x)
168 |         fmap.append(x)
169 |         x = torch.flatten(x, 1, -1)
170 | 
171 |         return x, fmap
172 | 
173 | 
174 | class MultiPeriodDiscriminator(torch.nn.Module):
175 |     def __init__(self):
176 |         super(MultiPeriodDiscriminator, self).__init__()
177 |         self.discriminators = nn.ModuleList([
178 |             DiscriminatorP(2),
179 |             DiscriminatorP(3),
180 |             DiscriminatorP(5),
181 |             DiscriminatorP(7),
182 |             DiscriminatorP(11),
183 |         ])
184 | 
185 |     def forward(self, y, y_hat):
186 |         y_d_rs = []
187 |         y_d_gs = []
188 |         fmap_rs = []
189 |         fmap_gs = []
190 |         for i, d in enumerate(self.discriminators):
191 |             y_d_r, fmap_r = d(y)
192 |             y_d_g, fmap_g = d(y_hat)
193 |             y_d_rs.append(y_d_r)
194 |             fmap_rs.append(fmap_r)
195 |             y_d_gs.append(y_d_g)
196 |             fmap_gs.append(fmap_g)
197 | 
198 |         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
199 | 
200 | 
201 | class MultiResolutionAmplitudeDiscriminator(nn.Module):
202 |     def __init__(
203 |         self,
204 |         resolutions: Tuple[Tuple[int, int, int]] = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)),
205 |         num_embeddings: int = None,
206 |     ):
207 |         super().__init__()
208 |         self.discriminators = nn.ModuleList(
209 |             [DiscriminatorAR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
210 |         )
211 | 
212 |     def forward(
213 |         self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
214 |     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
215 |         y_d_rs = []
216 |         y_d_gs = []
217 |         fmap_rs = []
218 |         fmap_gs = []
219 | 
220 |         for d in self.discriminators:
221 |             y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
222 |             y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
223 |             y_d_rs.append(y_d_r)
224 |             fmap_rs.append(fmap_r)
225 |             y_d_gs.append(y_d_g)
226 |             fmap_gs.append(fmap_g)
227 | 
228 |         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
229 | 
230 | 
231 | class DiscriminatorAR(nn.Module):
232 |     def __init__(
233 |         self,
234 |         resolution: Tuple[int, int, int],
235 |         channels: int = 64,
236 |         in_channels: int = 1,
237 |         num_embeddings: int = None,
238 |     ):
239 |         super().__init__()
240 |         self.resolution = resolution
241 |         self.in_channels = in_channels
242 |         self.convs = nn.ModuleList(
243 |             [
244 |                 weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
245 |                 weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
246 |                 weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
247 |                 weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
248 |                 weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
249 |             ]
250 |         )
251 |         if num_embeddings is not None:
252 |             self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
253 |             torch.nn.init.zeros_(self.emb.weight)
254 |         self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
255 | 
256 |     def forward(
257 |         self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
258 |     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
259 |         fmap = []
260 |         x = x.squeeze(1)
261 | 
262 |         x = self.spectrogram(x)
263 |         x = x.unsqueeze(1)
264 |         for l in self.convs:
265 |             x = l(x)
266 |             x = F.leaky_relu(x, LRELU_SLOPE)
267 |             fmap.append(x)
268 |         if cond_embedding_id is not None:
269 |             emb = self.emb(cond_embedding_id)
270 |             h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
271 |         else:
272 |             h = 0
273 |         x = self.conv_post(x)
274 |         fmap.append(x)
275 |         x += h
276 |         x = torch.flatten(x, 1, -1)
277 | 
278 |         return x, fmap
279 | 
280 |     def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
281 |         n_fft, hop_length, win_length = self.resolution
282 |         amplitude_spectrogram = torch.stft(
283 |             x,
284 |             n_fft=n_fft,
285 |             hop_length=hop_length,
286 |             win_length=win_length,
287 |             window=None,  # interestingly rectangular window kind of works here
288 |             center=True,
289 |             return_complex=True,
290 |         ).abs()
291 | 
292 |         return amplitude_spectrogram
293 | 
294 | 
295 | class MultiResolutionPhaseDiscriminator(nn.Module):
296 |     def __init__(
297 |         self,
298 |         resolutions: Tuple[Tuple[int, int, int]] = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)),
299 |         num_embeddings: int = None,
300 |     ):
301 |         super().__init__()
302 |         self.discriminators = nn.ModuleList(
303 |             [DiscriminatorPR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
304 |         )
305 | 
306 |     def forward(
307 |         self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
308 |     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
309 |         y_d_rs = []
310 |         y_d_gs = []
311 |         fmap_rs = []
312 |         fmap_gs = []
313 | 
314 |         for d in self.discriminators:
315 |             y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
316 |             y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
317 |             y_d_rs.append(y_d_r)
318 |             fmap_rs.append(fmap_r)
319 |             y_d_gs.append(y_d_g)
320 |             fmap_gs.append(fmap_g)
321 | 
322 |         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
323 | 
324 | 
325 | class DiscriminatorPR(nn.Module):
326 |     def __init__(
327 |         self,
328 |         resolution: Tuple[int, int, int],
329 |         channels: int = 64,
330 |         in_channels: int = 1,
331 |         num_embeddings: int = None,
332 |     ):
333 |         super().__init__()
334 |         self.resolution = resolution
335 |         self.in_channels = in_channels
336 |         self.convs = nn.ModuleList(
337 |             [
338 |                 weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
339 |                 weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
340 |                 weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
341 |                 weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
342 |                 weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
343 |             ]
344 |         )
345 |         if num_embeddings is not None:
346 |             self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
347 |             torch.nn.init.zeros_(self.emb.weight)
348 |         self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
349 | 
350 |     def forward(
351 |         self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
352 |     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
353 |         fmap = []
354 |         x = x.squeeze(1)
355 | 
356 |         x = self.spectrogram(x)
357 |         x = x.unsqueeze(1)
358 |         for l in self.convs:
359 |             x = l(x)
360 |             x = F.leaky_relu(x, LRELU_SLOPE)
361 |             fmap.append(x)
362 |         if cond_embedding_id is not None:
363 |             emb = self.emb(cond_embedding_id)
364 |             h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
365 |         else:
366 |             h = 0
367 |         x = self.conv_post(x)
368 |         fmap.append(x)
369 |         x += h
370 |         x = torch.flatten(x, 1, -1)
371 | 
372 |         return x, fmap
373 | 
374 |     def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
375 |         n_fft, hop_length, win_length = self.resolution
376 |         phase_spectrogram = torch.stft(
377 |             x,
378 |             n_fft=n_fft,
379 |             hop_length=hop_length,
380 |             win_length=win_length,
381 |             window=None,  # interestingly rectangular window kind of works here
382 |             center=True,
383 |             return_complex=True,
384 |         ).angle()
385 | 
386 |         return phase_spectrogram
387 | 
388 | 
389 | def feature_loss(fmap_r, fmap_g):
390 |     loss = 0
391 |     for dr, dg in zip(fmap_r, fmap_g):
392 |         for rl, gl in zip(dr, dg):
393 |             loss += torch.mean(torch.abs(rl - gl))
394 | 
395 |     return loss
396 | 
397 | 
398 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
399 |     loss = 0
400 |     r_losses = []
401 |     g_losses = []
402 |     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
403 |         r_loss = torch.mean(torch.clamp(1 - dr, min=0))
404 |         g_loss = torch.mean(torch.clamp(1 + dg, min=0))
405 |         loss += r_loss + g_loss
406 |         r_losses.append(r_loss.item())
407 |         g_losses.append(g_loss.item())
408 | 
409 |     return loss, r_losses, g_losses
410 | 
411 | 
412 | def generator_loss(disc_outputs):
413 |     loss = 0
414 |     gen_losses = []
415 |     for dg in disc_outputs:
416 |         l = torch.mean(torch.clamp(1 - dg, min=0))
417 |         gen_losses.append(l)
418 |         loss += l
419 | 
420 |     return loss, gen_losses
421 | 
422 | 
423 | def phase_losses(phase_r, phase_g):
424 |     # instantaneous phase / group delay / instantaneous angular frequency losses
425 |     ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
426 |     gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
427 |     iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
428 | 
429 |     return ip_loss, gd_loss, iaf_loss
430 | 
431 | def anti_wrapping_function(x):
432 | 
433 |     return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
434 | 
435 | def stft_mag(audio, n_fft=2048, hop_length=512):
436 |     hann_window = torch.hann_window(n_fft).to(audio.device)
437 |     stft_spec = torch.stft(audio, n_fft, hop_length, window=hann_window, return_complex=True)
438 |     stft_mag = torch.abs(stft_spec)
439 |     return stft_mag
440 | 
441 | def cal_snr(pred, target):  # signal-to-noise ratio in dB
442 |     snr = (20 * torch.log10(torch.norm(target, dim=-1) / torch.norm(pred - target, dim=-1).clamp(min=1e-8))).mean()
443 |     return snr
444 | 
445 | def cal_lsd(pred, target):  # log-spectral distance
446 |     sp = torch.log10(stft_mag(pred).square().clamp(1e-8))
447 |     st = torch.log10(stft_mag(target).square().clamp(1e-8))
448 |     return (sp - st).square().mean(dim=1).sqrt().mean()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.0
2 | torchaudio==2.0.1
3 | matplotlib==3.8.0
4 | numpy==1.26.0
5 | librosa==0.7.2
6 | scipy==1.11.2
7 | tensorboard==2.11.2
8 | einops==0.8.0
9 | joblib==1.3.2
10 | natsort==8.3.0
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/train/train_16k.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import sys
4 | sys.path.append("..")
5 | import itertools
6 | import os
7 | import time
8 | import argparse
9 | import json
10 | import torch
11 | import torch.nn.functional as F
12 | from torch.utils.tensorboard import SummaryWriter
13 | 
14 | from torch.utils.data import DistributedSampler, DataLoader
15 | import torch.multiprocessing as mp
16 | from torch.distributed import init_process_group
17 | from torch.nn.parallel import DistributedDataParallel
18 | from env import AttrDict, build_env
19 | from datasets.dataset import Dataset, amp_pha_stft, amp_pha_istft, get_dataset_filelist
20 | from models.model import APNet_BWE_Model, MultiPeriodDiscriminator, MultiResolutionAmplitudeDiscriminator, MultiResolutionPhaseDiscriminator, \
21 |     feature_loss, generator_loss, discriminator_loss, phase_losses, cal_snr, cal_lsd
22 | from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
23 | 
24 | torch.backends.cudnn.benchmark = True
25 | 
26 | def train(rank, a, h):
27 |     if h.num_gpus > 1:
28 |         init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
29 |                            world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
30 | 
31 |     torch.cuda.manual_seed(h.seed)
32 |     device = torch.device('cuda:{:d}'.format(rank))
33 | 
34 |     generator = APNet_BWE_Model(h).to(device)
35 |     mpd = MultiPeriodDiscriminator().to(device)
36 |     mrad = MultiResolutionAmplitudeDiscriminator().to(device)
37 |     mrpd = MultiResolutionPhaseDiscriminator().to(device)
38 | 
39 |     if rank == 0:
40 |         print(generator)
41 |         num_params = 0
42 |         for p in generator.parameters():
43 |             num_params += p.numel()
44 |         print(num_params)
45 |         os.makedirs(a.checkpoint_path, exist_ok=True)
46 |         os.makedirs(os.path.join(a.checkpoint_path, 'logs'), exist_ok=True)
47 |         print("checkpoints directory : ", a.checkpoint_path)
48 | 
49 |     if os.path.isdir(a.checkpoint_path):
50 |         cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
51 |         cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
52 | 
53 |     steps = 0
54 |     if cp_g is None or cp_do is None:
55 |         state_dict_do = None
56 |         last_epoch = -1
57 |     else:
58 |         state_dict_g = load_checkpoint(cp_g, device)
59 |         state_dict_do = load_checkpoint(cp_do, device)
60 |         generator.load_state_dict(state_dict_g['generator'])
61 |         mpd.load_state_dict(state_dict_do['mpd'])
62 |         mrad.load_state_dict(state_dict_do['mrad'])
63 |         mrpd.load_state_dict(state_dict_do['mrpd'])
64 |         steps = state_dict_do['steps'] + 1
65 |         last_epoch = state_dict_do['epoch']
66 | 
67 |     if h.num_gpus > 1:
68 |         generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
69 |         mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
70 |         mrad = DistributedDataParallel(mrad, device_ids=[rank]).to(device)
71 |         mrpd = DistributedDataParallel(mrpd, device_ids=[rank]).to(device)
72 | 
73 |     optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
74 |     optim_d = torch.optim.AdamW(itertools.chain(mrad.parameters(), mrpd.parameters(), mpd.parameters()),
75 |                                 h.learning_rate, betas=[h.adam_b1, h.adam_b2])
76 | 
77 |     if state_dict_do is not None:
78 |         optim_g.load_state_dict(state_dict_do['optim_g'])
79 |         optim_d.load_state_dict(state_dict_do['optim_d'])
80 | 
81 |     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
82 |     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
83 | 
84 |     training_indexes, validation_indexes = get_dataset_filelist(a)
85 | 
86 |     trainset = Dataset(training_indexes, a.input_training_wavs_dir, h.segment_size, h.hr_sampling_rate, h.lr_sampling_rate,
87 |                        split=True, n_cache_reuse=0, shuffle=False if h.num_gpus > 1 else True, device=device)
88 | 
89 |     train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
90 | 
91 |     train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
92 |                               sampler=train_sampler,
93 |                               batch_size=h.batch_size,
94 |                               pin_memory=True,
95 |                               drop_last=True)
96 |     if rank == 0:
97 |         validset = Dataset(validation_indexes, a.input_validation_wavs_dir, h.segment_size, h.hr_sampling_rate, h.lr_sampling_rate,
98 |                            split=False, shuffle=False, n_cache_reuse=0, device=device)
99 | 
100 |         validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
101 |                                        sampler=None,
102 |                                        batch_size=1,
103 |                                        pin_memory=True,
104 |                                        drop_last=True)
105 | 
106 |         sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
107 | 
108 |     generator.train()
109 |     mpd.train()
110 |     mrad.train()
111 |     mrpd.train()
112 | 
113 |     for epoch in range(max(0, last_epoch), a.training_epochs):
114 |         if rank == 0:
115 |             start = time.time()
116 |             print("Epoch: {}".format(epoch+1))
117 | 
118 |         if h.num_gpus > 1:
119 |             train_sampler.set_epoch(epoch)
120 | 
121 |         for i, batch in enumerate(train_loader):
122 | 
123 |             if rank == 0:
124 |                 start_b = time.time()
125 |             audio_wb, audio_nb = batch  # [B, segment_size] waveform pairs, both at the high sampling rate
126 |             audio_wb = torch.autograd.Variable(audio_wb.to(device, non_blocking=True))
127 |             audio_nb = torch.autograd.Variable(audio_nb.to(device, non_blocking=True))
128 | 
129 |             mag_wb, pha_wb, com_wb = amp_pha_stft(audio_wb, h.n_fft, h.hop_size, h.win_size)
130 |             mag_nb, pha_nb, com_nb = amp_pha_stft(audio_nb, h.n_fft, h.hop_size, h.win_size)
131 | 
132 |             mag_wb_g, pha_wb_g, com_wb_g = generator(mag_nb, pha_nb)
133 | 
134 |             audio_wb_g = amp_pha_istft(mag_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size)
135 |             mag_wb_g_hat, pha_wb_g_hat, com_wb_g_hat = amp_pha_stft(audio_wb_g, h.n_fft, h.hop_size, h.win_size)  # re-analysis for the consistency loss
136 |             audio_wb, audio_wb_g = audio_wb.unsqueeze(1), audio_wb_g.unsqueeze(1)
137 | 
138 |             optim_d.zero_grad()
139 | 
140 |             # MPD
141 |             audio_df_r, audio_df_g, _, _ = mpd(audio_wb, audio_wb_g.detach())
142 |             loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(audio_df_r, audio_df_g)
143 | 
144 |             # MRAD
145 |             spec_da_r, spec_da_g, _, _ = mrad(audio_wb, audio_wb_g.detach())
146 |             loss_disc_a, losses_disc_a_r, losses_disc_a_g = discriminator_loss(spec_da_r, spec_da_g)
147 | 
148 |             # MRPD
149 |             spec_dp_r, spec_dp_g, _, _ = mrpd(audio_wb, audio_wb_g.detach())
150 |             loss_disc_p, losses_disc_p_r, losses_disc_p_g = discriminator_loss(spec_dp_r, spec_dp_g)
151 | 
152 |             loss_disc_all = (loss_disc_a + loss_disc_p) * 0.1 + loss_disc_f
153 | 
154 |             loss_disc_all.backward()
155 |             torch.nn.utils.clip_grad_norm_(parameters=mpd.parameters(), max_norm=10, norm_type=2)
156 |             torch.nn.utils.clip_grad_norm_(parameters=mrad.parameters(), max_norm=10, norm_type=2)
157 |             torch.nn.utils.clip_grad_norm_(parameters=mrpd.parameters(), max_norm=10, norm_type=2)
158 |             optim_d.step()
159 | 
160 |             # Generator
161 |             optim_g.zero_grad()
162 | 
163 |             # L2 Magnitude Loss
164 |             loss_mag = F.mse_loss(mag_wb, mag_wb_g) * 45
165 |             # Anti-wrapping Phase Loss
166 |             loss_ip, loss_gd, loss_iaf = phase_losses(pha_wb, pha_wb_g)
167 |             loss_pha = (loss_ip + loss_gd + loss_iaf) * 100
168 |             # L2 Complex Loss
169 |             loss_com = F.mse_loss(com_wb, com_wb_g) * 90
170 |             # L2 Consistency Loss
171 |             loss_stft = F.mse_loss(com_wb_g, com_wb_g_hat) * 90
172 | 
173 |             audio_df_r, audio_df_g, fmap_f_r, fmap_f_g = mpd(audio_wb, audio_wb_g)
174 |             spec_da_r, spec_da_g, fmap_a_r, fmap_a_g = mrad(audio_wb, audio_wb_g)
175 |             spec_dp_r, spec_dp_g, fmap_p_r, fmap_p_g = mrpd(audio_wb, audio_wb_g)
176 | 
177 |             loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
178 |             loss_fm_a = feature_loss(fmap_a_r, fmap_a_g)
179 |             loss_fm_p = feature_loss(fmap_p_r, fmap_p_g)
180 | 
181 |             loss_gen_f, losses_gen_f = generator_loss(audio_df_g)
182 |             loss_gen_a, losses_gen_a = generator_loss(spec_da_g)
183 |             loss_gen_p, losses_gen_p = generator_loss(spec_dp_g)
184 | 
185 |             loss_gen = (loss_gen_a + loss_gen_p) * 0.1 + loss_gen_f
186 |             loss_fm = (loss_fm_a + loss_fm_p) * 0.1 + loss_fm_f
187 | 
188 |             loss_gen_all = loss_mag + loss_pha + loss_com + loss_stft + loss_gen + loss_fm
189 | 
190 |             loss_gen_all.backward()
191 |             torch.nn.utils.clip_grad_norm_(parameters=generator.parameters(), max_norm=10, norm_type=2)
192 |             optim_g.step()
193 | 
194 |             if rank == 0:
195 |                 # STDOUT logging
196 |                 if steps % a.stdout_interval == 0:
197 |                     with torch.no_grad():
198 |                         mag_error = F.mse_loss(mag_wb, mag_wb_g).item()
199 |                         ip_error, gd_error, iaf_error = phase_losses(pha_wb, pha_wb_g)
200 |                         pha_error = (ip_error + gd_error + iaf_error).item()
201 |                         com_error = F.mse_loss(com_wb, com_wb_g).item()
202 |                         stft_error = F.mse_loss(com_wb_g, com_wb_g_hat).item()
203 |                     print('Steps : {:d}, Gen Loss: {:4.3f}, Magnitude Loss : {:4.3f}, Phase Loss : {:4.3f}, Complex Loss : {:4.3f}, STFT Loss : {:4.3f}, s/b : {:4.3f}'.
204 |                           format(steps, loss_gen_all, mag_error, pha_error, com_error, stft_error, time.time() - start_b))
205 | 
206 |                 # checkpointing
207 |                 if steps % a.checkpoint_interval == 0 and steps != 0:
208 |                     checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
209 |                     save_checkpoint(checkpoint_path,
210 |                                     {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
211 |                     checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
212 |                     save_checkpoint(checkpoint_path,
213 |                                     {'mpd': (mpd.module if h.num_gpus > 1
214 |                                              else mpd).state_dict(),
215 |                                      'mrad': (mrad.module if h.num_gpus > 1
216 |                                               else mrad).state_dict(),
217 |                                      'mrpd': (mrpd.module if h.num_gpus > 1
218 |                                               else mrpd).state_dict(),
219 |                                      'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
220 |                                      'epoch': epoch})
221 | 
222 |                 # Tensorboard summary logging
223 |                 if steps % a.summary_interval == 0:
224 |                     sw.add_scalar("Training/Generator Loss", loss_gen_all, steps)
225 |                     sw.add_scalar("Training/Magnitude Loss", mag_error, steps)
226 |                     sw.add_scalar("Training/Phase Loss", pha_error, steps)
227 |                     sw.add_scalar("Training/Complex Loss", com_error, steps)
228 |                     sw.add_scalar("Training/Consistency Loss", stft_error, steps)
229 | 
230 |                 # Validation
231 |                 if steps % a.validation_interval == 0:
232 |                     start_v = time.time()
233 |                     generator.eval()
234 |                     torch.cuda.empty_cache()
235 |                     val_mag_err_tot = 0
236 |                     val_pha_err_tot = 0
237 |                     val_com_err_tot = 0
238 |                     val_stft_err_tot = 0
239 |                     val_snr_score_tot = 0
240 |                     val_lsd_score_tot = 0
241 |                     with torch.no_grad():
242 |                         for j, batch in enumerate(validation_loader):
243 |                             audio_wb, audio_nb = batch
244 |                             audio_wb = torch.autograd.Variable(audio_wb.to(device, non_blocking=True))
245 |                             audio_nb = torch.autograd.Variable(audio_nb.to(device, non_blocking=True))
246 | 
247 |                             mag_wb, pha_wb, com_wb = amp_pha_stft(audio_wb, h.n_fft, h.hop_size, h.win_size)
248 |                             mag_nb, pha_nb, com_nb = amp_pha_stft(audio_nb, h.n_fft, h.hop_size, h.win_size)
249 | 
250 |                             mag_wb_g, pha_wb_g, com_wb_g = generator(mag_nb.to(device), pha_nb.to(device))
251 | 
252 |                             audio_wb = amp_pha_istft(mag_wb, pha_wb, h.n_fft, h.hop_size, h.win_size)
253 |                             audio_wb_g = amp_pha_istft(mag_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size)
254 |                             mag_wb_g_hat, pha_wb_g_hat, com_wb_g_hat = amp_pha_stft(audio_wb_g, h.n_fft, h.hop_size, h.win_size)
255 | 
256 |                             val_mag_err_tot += F.mse_loss(mag_wb, mag_wb_g_hat).item()
257 |                             val_ip_err, val_gd_err, val_iaf_err = phase_losses(pha_wb, pha_wb_g_hat)
258 |                             val_pha_err_tot += (val_ip_err + val_gd_err + val_iaf_err).item()
259 |                             val_com_err_tot += F.mse_loss(com_wb, com_wb_g_hat).item()
260 |                             val_stft_err_tot += F.mse_loss(com_wb_g, com_wb_g_hat).item()
261 |                             val_snr_score_tot += cal_snr(audio_wb_g, audio_wb).item()
262 |                             val_lsd_score_tot += cal_lsd(audio_wb_g, audio_wb).item()
263 | 
264 |                             if j <= 4:
265 |                                 if steps == 0:
266 |                                     sw.add_audio('gt/audio_nb_{}'.format(j), audio_nb[0], steps, h.hr_sampling_rate)
267 |                                     sw.add_audio('gt/audio_wb_{}'.format(j), audio_wb[0], steps, h.hr_sampling_rate)
268 |                                     sw.add_figure('gt/spec_nb_{}'.format(j), plot_spectrogram(mag_nb.squeeze().cpu().numpy()), steps)
269 |                                     sw.add_figure('gt/spec_wb_{}'.format(j), plot_spectrogram(mag_wb.squeeze().cpu().numpy()), steps)
270 | 
271 |                                 sw.add_audio('generated/audio_g_{}'.format(j), audio_wb_g[0], steps, h.hr_sampling_rate)
272 |                                 sw.add_figure('generated/spec_g_{}'.format(j), plot_spectrogram(mag_wb_g.squeeze().cpu().numpy()), steps)
273 | 
274 |                     val_mag_err = val_mag_err_tot / (j+1)
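                    # Note: the validation DataLoader uses batch_size=1, so (j+1)
                    # equals the number of validation utterances; these divisions
                    # reduce the running totals to per-utterance means.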
275 |                     val_pha_err = val_pha_err_tot / (j+1)
276 |                     val_com_err = val_com_err_tot / (j+1)
277 |                     val_stft_err = val_stft_err_tot / (j+1)
278 |                     val_snr_score = val_snr_score_tot / (j+1)
279 |                     val_lsd_score = val_lsd_score_tot / (j+1)
280 | 
281 |                     print('Steps : {:d}, SNR Score: {:4.3f}, LSD Score: {:4.3f}, s/b : {:4.3f}'.
282 |                           format(steps, val_snr_score, val_lsd_score, time.time() - start_v))
283 |                     sw.add_scalar("Validation/LSD Score", val_lsd_score, steps)
284 |                     sw.add_scalar("Validation/SNR Score", val_snr_score, steps)
285 |                     sw.add_scalar("Validation/Magnitude Loss", val_mag_err, steps)
286 |                     sw.add_scalar("Validation/Phase Loss", val_pha_err, steps)
287 |                     sw.add_scalar("Validation/Complex Loss", val_com_err, steps)
288 |                     sw.add_scalar("Validation/Consistency Loss", val_stft_err, steps)
289 | 
290 |                     generator.train()
291 | 
292 |             steps += 1
293 | 
294 |         scheduler_g.step()
295 |         scheduler_d.step()
296 | 
297 |         if rank == 0:
298 |             print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
299 | 
300 | 
301 | def main():
302 |     print('Initializing Training Process...')
303 | 
304 |     parser = argparse.ArgumentParser()
305 | 
306 |     parser.add_argument('--group_name', default=None)
307 |     parser.add_argument('--input_training_wavs_dir', default='VCTK-Corpus-0.92/wav16/train')
308 |     parser.add_argument('--input_validation_wavs_dir', default='VCTK-Corpus-0.92/wav16/test')
309 |     parser.add_argument('--input_training_file', default='VCTK-Corpus-0.92/training.txt')
310 |     parser.add_argument('--input_validation_file', default='VCTK-Corpus-0.92/test.txt')
311 |     parser.add_argument('--checkpoint_path', default='cp_model')
312 |     parser.add_argument('--config', default='')
313 |     parser.add_argument('--training_epochs', default=3100, type=int)
314 |     parser.add_argument('--stdout_interval', default=5, type=int)
315 |     parser.add_argument('--checkpoint_interval', default=5000, type=int)
316 |     parser.add_argument('--summary_interval', default=100, type=int)
317 |     parser.add_argument('--validation_interval', default=5000, type=int)
318 | 
319 |     a = parser.parse_args()
320 | 
321 |     with open(a.config) as f:
322 |         data = f.read()
323 | 
324 |     json_config = json.loads(data)
325 |     h = AttrDict(json_config)
326 |     build_env(a.config, 'config.json', a.checkpoint_path)
327 | 
328 |     torch.manual_seed(h.seed)
329 |     if torch.cuda.is_available():
330 |         torch.cuda.manual_seed(h.seed)
331 |         h.num_gpus = torch.cuda.device_count()
332 |         h.batch_size = int(h.batch_size / h.num_gpus)
333 |         print('Batch size per GPU :', h.batch_size)
334 |     else:
335 |         h.num_gpus = 1  # CPU fallback so the num_gpus checks below still work; training itself expects CUDA
336 | 
337 |     if h.num_gpus > 1:
338 |         mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
339 |     else:
340 |         train(0, a, h)
341 | 
342 | 
343 | if __name__ == '__main__':
344 |     main()
--------------------------------------------------------------------------------
/train/train_48k.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import sys
4 | sys.path.append("..")
5 | import itertools
6 | import os
7 | import time
8 | import argparse
9 | import json
10 | import torch
11 | import torch.nn.functional as F
12 | from torch.utils.tensorboard import SummaryWriter
13 | 
14 | from torch.utils.data import DistributedSampler, DataLoader
15 | import torch.multiprocessing as mp
16 | from torch.distributed import init_process_group
17 | from torch.nn.parallel import DistributedDataParallel
18 | from env import AttrDict, build_env
19 | from datasets.dataset import Dataset, amp_pha_stft, amp_pha_istft, get_dataset_filelist
20 | from models.model import APNet_BWE_Model, MultiPeriodDiscriminator, MultiResolutionAmplitudeDiscriminator, MultiResolutionPhaseDiscriminator, \
21 |     feature_loss, generator_loss, discriminator_loss, phase_losses, cal_snr, cal_lsd
22 | from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
23 | 
24 | torch.backends.cudnn.benchmark = True
25 | 
26 | def train(rank, a, h):
27 |     if h.num_gpus > 1:
28 |         init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
29 |                            world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
30 | 
31 |     torch.cuda.manual_seed(h.seed)
32 |     device = torch.device('cuda:{:d}'.format(rank))
33 | 
34 |     generator = APNet_BWE_Model(h).to(device)
35 |     mpd = MultiPeriodDiscriminator().to(device)
36 |     mrad = MultiResolutionAmplitudeDiscriminator().to(device)
37 |     mrpd = MultiResolutionPhaseDiscriminator().to(device)
38 | 
39 |     if rank == 0:
40 |         print(generator)
41 |         num_params = 0
42 |         for p in generator.parameters():
43 |             num_params += p.numel()
44 |         print(num_params)
45 |         os.makedirs(a.checkpoint_path, exist_ok=True)
46 |         os.makedirs(os.path.join(a.checkpoint_path, 'logs'), exist_ok=True)
47 |         print("checkpoints directory : ", a.checkpoint_path)
48 | 
49 |     if os.path.isdir(a.checkpoint_path):
50 |         cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
51 |         cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
52 | 
53 |     steps = 0
54 |     if cp_g is None or cp_do is None:
55 |         state_dict_do = None
56 |         last_epoch = -1
57 |     else:
58 |         state_dict_g = load_checkpoint(cp_g, device)
59 |         state_dict_do = load_checkpoint(cp_do, device)
60 |         generator.load_state_dict(state_dict_g['generator'])
61 |         mpd.load_state_dict(state_dict_do['mpd'])
62 |         mrad.load_state_dict(state_dict_do['mrad'])
63 |         mrpd.load_state_dict(state_dict_do['mrpd'])
64 |         steps = state_dict_do['steps'] + 1
65 |         last_epoch = state_dict_do['epoch']
66 | 
67 |     if h.num_gpus > 1:
68 |         generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
69 |         mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
70 |         mrad = DistributedDataParallel(mrad, device_ids=[rank]).to(device)
71 |         mrpd = DistributedDataParallel(mrpd, device_ids=[rank]).to(device)
72 | 
73 |     optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
74 |     optim_d = torch.optim.AdamW(itertools.chain(mrad.parameters(), mrpd.parameters(), mpd.parameters()),
75 |                                 h.learning_rate, betas=[h.adam_b1, h.adam_b2])
76 | 
77 |     if state_dict_do is not None:
78 |         optim_g.load_state_dict(state_dict_do['optim_g'])
79 |         optim_d.load_state_dict(state_dict_do['optim_d'])
80 | 
81 |     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
82 |     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
83 | 
84 |     training_indexes, validation_indexes = get_dataset_filelist(a)
85 | 
86 |     trainset = Dataset(training_indexes, a.input_training_wavs_dir, h.segment_size, h.hr_sampling_rate, h.lr_sampling_rate,
87 |                        split=True, n_cache_reuse=0, shuffle=False if h.num_gpus > 1 else True, device=device)
88 | 
89 |     train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
90 | 
91 |     train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
92 |                               sampler=train_sampler,
93 |                               batch_size=h.batch_size,
94 |                               pin_memory=True,
95 |                               drop_last=True)
96 |     if rank == 0:
97 |         validset = Dataset(validation_indexes, a.input_validation_wavs_dir, h.segment_size, h.hr_sampling_rate, h.lr_sampling_rate,
98 |                            split=False, shuffle=False, n_cache_reuse=0, device=device)
99 | 
100 |         validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
101 |                                        sampler=None,
102 |                                        batch_size=1,
103 |                                        pin_memory=True,
104 |                                        drop_last=True)
105 | 
106 |         sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
107 | 
108 |     generator.train()
109 |     mpd.train()
110 |     mrad.train()
111 |     mrpd.train()
112 | 
113 |     for epoch in range(max(0, last_epoch), a.training_epochs):
114 |         if rank == 0:
115 |             start = time.time()
116 |             print("Epoch: {}".format(epoch+1))
117 | 
118 |         if h.num_gpus > 1:
119 |             train_sampler.set_epoch(epoch)
120 | 
121 |         for i, batch in enumerate(train_loader):
122 | 
123 |             if rank == 0:
124 |                 start_b = time.time()
125 |             audio_wb, audio_nb = batch  # [B, segment_size] waveform pairs, both at the high sampling rate
126 |             audio_wb = torch.autograd.Variable(audio_wb.to(device, non_blocking=True))
127 |             audio_nb = torch.autograd.Variable(audio_nb.to(device, non_blocking=True))
128 | 
129 |             mag_wb, pha_wb, com_wb = amp_pha_stft(audio_wb, h.n_fft, h.hop_size, h.win_size)
130 |             mag_nb, pha_nb, com_nb = amp_pha_stft(audio_nb, h.n_fft, h.hop_size, h.win_size)
131 | 
132 |             mag_wb_g, pha_wb_g, com_wb_g = generator(mag_nb, pha_nb)
133 | 
134 |             audio_wb_g = amp_pha_istft(mag_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size)
135 |             mag_wb_g_hat, pha_wb_g_hat, com_wb_g_hat = amp_pha_stft(audio_wb_g, h.n_fft, h.hop_size, h.win_size)  # re-analysis for the consistency loss
136 |             audio_wb, audio_wb_g = audio_wb.unsqueeze(1), audio_wb_g.unsqueeze(1)
137 | 
138 |             optim_d.zero_grad()
139 | 
140 |             # MPD
141 |             audio_df_r, audio_df_g, _, _ = mpd(audio_wb, audio_wb_g.detach())
142 |             loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(audio_df_r, audio_df_g)
143 | 
144 |             # MRAD
145 |             spec_da_r, spec_da_g, _, _ = mrad(audio_wb, audio_wb_g.detach())
146 |             loss_disc_a, losses_disc_a_r, losses_disc_a_g = discriminator_loss(spec_da_r, spec_da_g)
147 | 
148 |             # MRPD
149 |             spec_dp_r, spec_dp_g, _, _ = mrpd(audio_wb, audio_wb_g.detach())
150 |             loss_disc_p, losses_disc_p_r, losses_disc_p_g = discriminator_loss(spec_dp_r, spec_dp_g)
151 | 
152 |             loss_disc_all = (loss_disc_a + loss_disc_p) * 0.1 + loss_disc_f
153 | 
154 |             loss_disc_all.backward()
155 |             torch.nn.utils.clip_grad_norm_(parameters=mpd.parameters(), max_norm=10, norm_type=2)
156 |             torch.nn.utils.clip_grad_norm_(parameters=mrad.parameters(), max_norm=10, norm_type=2)
157 |             torch.nn.utils.clip_grad_norm_(parameters=mrpd.parameters(), max_norm=10, norm_type=2)
158 |             optim_d.step()
159 | 
160 |             # Generator
161 |             optim_g.zero_grad()
162 | 
163 |             # L2 Magnitude Loss
164 |             loss_mag = F.mse_loss(mag_wb, mag_wb_g) * 45
165 |             # Anti-wrapping Phase Loss
166 |             loss_ip, loss_gd, loss_iaf = phase_losses(pha_wb, pha_wb_g)
167 |             loss_pha = (loss_ip + loss_gd + loss_iaf) * 100
168 |             # L2 Complex Loss
169 |             loss_com = F.mse_loss(com_wb, com_wb_g) * 90
170 |             # L2 Consistency Loss
171 |             loss_stft = F.mse_loss(com_wb_g, com_wb_g_hat) * 90
172 | 
173 |             audio_df_r, audio_df_g, fmap_f_r, fmap_f_g = mpd(audio_wb, audio_wb_g)
174 |             spec_da_r, spec_da_g, fmap_a_r, fmap_a_g = mrad(audio_wb, audio_wb_g)
175 |             spec_dp_r, spec_dp_g, fmap_p_r, fmap_p_g = mrpd(audio_wb, audio_wb_g)
176 | 
177 |             loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
178 |             loss_fm_a = feature_loss(fmap_a_r, fmap_a_g)
179 |             loss_fm_p = feature_loss(fmap_p_r, fmap_p_g)
180 | 
181 |             loss_gen_f, losses_gen_f = generator_loss(audio_df_g)
182 |             loss_gen_a, losses_gen_a = generator_loss(spec_da_g)
183 |             loss_gen_p, losses_gen_p = generator_loss(spec_dp_g)
184 | 
185 |             loss_gen = (loss_gen_a + loss_gen_p) * 0.1 + loss_gen_f
186 |             loss_fm = (loss_fm_a + loss_fm_p) * 0.1 + loss_fm_f
187 | 
188 |             loss_gen_all = loss_mag + loss_pha + loss_com + loss_stft + loss_gen + loss_fm
189 | 
190 |             loss_gen_all.backward()
191 |             torch.nn.utils.clip_grad_norm_(parameters=generator.parameters(), max_norm=10, norm_type=2)
192 |             optim_g.step()
193 | 
194 |             if rank == 0:
195 |                 # STDOUT logging
196 |                 if steps % a.stdout_interval == 0:
197 |                     with torch.no_grad():
198 |                         mag_error = F.mse_loss(mag_wb, mag_wb_g).item()
199 |                         ip_error, gd_error, iaf_error = phase_losses(pha_wb, pha_wb_g)
200 |                         pha_error = (ip_error + gd_error + iaf_error).item()
201 |                         com_error = F.mse_loss(com_wb, com_wb_g).item()
202 |                         stft_error = F.mse_loss(com_wb_g, com_wb_g_hat).item()
203 |                     print('Steps : {:d}, Gen Loss: {:4.3f}, Magnitude Loss : {:4.3f}, Phase Loss : {:4.3f}, Complex Loss : {:4.3f}, STFT Loss : {:4.3f}, s/b : {:4.3f}'.
204 |                           format(steps, loss_gen_all, mag_error, pha_error, com_error, stft_error, time.time() - start_b))
205 | 
206 |                 # checkpointing
207 |                 if steps % a.checkpoint_interval == 0 and steps != 0:
208 |                     checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
209 |                     save_checkpoint(checkpoint_path,
210 |                                     {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
211 |                     checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
212 |                     save_checkpoint(checkpoint_path,
213 |                                     {'mpd': (mpd.module if h.num_gpus > 1
214 |                                              else mpd).state_dict(),
215 |                                      'mrad': (mrad.module if h.num_gpus > 1
216 |                                               else mrad).state_dict(),
217 |                                      'mrpd': (mrpd.module if h.num_gpus > 1
218 |                                               else mrpd).state_dict(),
219 |                                      'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
220 |                                      'epoch': epoch})
221 | 
222 |                 # Tensorboard summary logging
223 |                 if steps % a.summary_interval == 0:
224 |                     sw.add_scalar("Training/Generator Loss", loss_gen_all, steps)
225 |                     sw.add_scalar("Training/Magnitude Loss", mag_error, steps)
226 |                     sw.add_scalar("Training/Phase Loss", pha_error, steps)
227 |                     sw.add_scalar("Training/Complex Loss", com_error, steps)
228 |                     sw.add_scalar("Training/Consistency Loss", stft_error, steps)
229 | 
230 |                 # Validation
231 |                 if steps % a.validation_interval == 0:
232 |                     start_v = time.time()
233 |                     generator.eval()
234 |                     torch.cuda.empty_cache()
235 |                     val_mag_err_tot = 0
236 |                     val_pha_err_tot = 0
237 |                     val_com_err_tot = 0
238 |                     val_stft_err_tot = 0
239 |                     val_snr_score_tot = 0
240 |                     val_lsd_score_tot = 0
241 |                     with torch.no_grad():
242 |                         for j, batch in enumerate(validation_loader):
243 |                             audio_wb, audio_nb = batch
244 |                             audio_wb = torch.autograd.Variable(audio_wb.to(device, non_blocking=True))
245 |                             audio_nb = torch.autograd.Variable(audio_nb.to(device, non_blocking=True))
246 | 
247 |                             mag_wb, pha_wb, com_wb = amp_pha_stft(audio_wb, h.n_fft, h.hop_size, h.win_size)
248 |                             mag_nb, pha_nb, com_nb = amp_pha_stft(audio_nb, h.n_fft, h.hop_size, h.win_size)
249 | 
250 |                             mag_wb_g, pha_wb_g, com_wb_g = generator(mag_nb.to(device), pha_nb.to(device))
251 | 
252 |                             audio_wb = amp_pha_istft(mag_wb, pha_wb, h.n_fft, h.hop_size, h.win_size)
253 |                             audio_wb_g = amp_pha_istft(mag_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size)
254 |                             mag_wb_g_hat, pha_wb_g_hat, com_wb_g_hat = amp_pha_stft(audio_wb_g, h.n_fft, h.hop_size, h.win_size)
255 | 
256 |                             val_mag_err_tot += F.mse_loss(mag_wb, mag_wb_g_hat).item()
257 |                             val_ip_err, val_gd_err, val_iaf_err = phase_losses(pha_wb, pha_wb_g_hat)
258 |                             val_pha_err_tot += (val_ip_err + val_gd_err + val_iaf_err).item()
259 |                             val_com_err_tot += F.mse_loss(com_wb, com_wb_g_hat).item()
260 |                             val_stft_err_tot += F.mse_loss(com_wb_g, com_wb_g_hat).item()
261 |                             val_snr_score_tot += cal_snr(audio_wb_g, audio_wb).item()
262 |                             val_lsd_score_tot += cal_lsd(audio_wb_g, audio_wb).item()
263 | 
264 |                             if j <= 4:
265 |                                 if steps == 0:
266 |                                     sw.add_audio('gt/audio_nb_{}'.format(j), audio_nb[0], steps, h.hr_sampling_rate)
267 |                                     sw.add_audio('gt/audio_wb_{}'.format(j), audio_wb[0], steps, h.hr_sampling_rate)
268 |                                     sw.add_figure('gt/spec_nb_{}'.format(j), plot_spectrogram(mag_nb.squeeze().cpu().numpy()), steps)
269 |                                     sw.add_figure('gt/spec_wb_{}'.format(j), plot_spectrogram(mag_wb.squeeze().cpu().numpy()), steps)
270 | 
271 |                                 sw.add_audio('generated/audio_g_{}'.format(j), audio_wb_g[0], steps, h.hr_sampling_rate)
272 |                                 sw.add_figure('generated/spec_g_{}'.format(j), plot_spectrogram(mag_wb_g.squeeze().cpu().numpy()), steps)
273 | 
274 |                     val_mag_err = val_mag_err_tot / (j+1)  # (j+1) = number of validation utterances (batch_size=1)
275 |                     val_pha_err = val_pha_err_tot / (j+1)
276 |                     val_com_err = val_com_err_tot / (j+1)
277 |                     val_stft_err = val_stft_err_tot / (j+1)
278 |                     val_snr_score = val_snr_score_tot / (j+1)
279 |                     val_lsd_score = val_lsd_score_tot / (j+1)
280 | 
281 |                     print('Steps : {:d}, SNR Score: {:4.3f}, LSD Score: {:4.3f}, s/b : {:4.3f}'.
282 |                           format(steps, val_snr_score, val_lsd_score, time.time() - start_v))
283 |                     sw.add_scalar("Validation/LSD Score", val_lsd_score, steps)
284 |                     sw.add_scalar("Validation/SNR Score", val_snr_score, steps)
285 |                     sw.add_scalar("Validation/Magnitude Loss", val_mag_err, steps)
286 |                     sw.add_scalar("Validation/Phase Loss", val_pha_err, steps)
287 |                     sw.add_scalar("Validation/Complex Loss", val_com_err, steps)
288 |                     sw.add_scalar("Validation/Consistency Loss", val_stft_err, steps)
289 | 
290 |                     generator.train()
291 | 
292 |             steps += 1
293 | 
294 |         scheduler_g.step()
295 |         scheduler_d.step()
296 | 
297 |         if rank == 0:
298 |             print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
299 | 
300 | 
301 | def main():
302 |     print('Initializing Training Process...')
303 | 
304 |     parser = argparse.ArgumentParser()
305 | 
306 |     parser.add_argument('--group_name', default=None)
307 |     parser.add_argument('--input_training_wavs_dir', default='VCTK-Corpus-0.92/wav48/train')
308 |     parser.add_argument('--input_validation_wavs_dir', default='VCTK-Corpus-0.92/wav48/test')
309 |     parser.add_argument('--input_training_file', default='VCTK-Corpus-0.92/training.txt')
310 |     parser.add_argument('--input_validation_file', default='VCTK-Corpus-0.92/test.txt')
311 |     parser.add_argument('--checkpoint_path', default='cp_model')
312 |     parser.add_argument('--config', default='')
313 |     parser.add_argument('--training_epochs', default=3100, type=int)
314 |     parser.add_argument('--stdout_interval', default=5, type=int)
315 |     parser.add_argument('--checkpoint_interval', default=5000, type=int)
316 |     parser.add_argument('--summary_interval', default=100, type=int)
317 |     parser.add_argument('--validation_interval', default=5000, type=int)
318 | 
319 |     a = parser.parse_args()
320 | 
321 |     with open(a.config) as f:
322 |         data = f.read()
323 | 
324 |     json_config = json.loads(data)
325 |     h = AttrDict(json_config)
326 |     build_env(a.config, 'config.json', a.checkpoint_path)
327 | 
328 |     torch.manual_seed(h.seed)
329 |     if torch.cuda.is_available():
330 |         torch.cuda.manual_seed(h.seed)
331 |         h.num_gpus = torch.cuda.device_count()
332 |         h.batch_size = int(h.batch_size / h.num_gpus)
333 |         print('Batch size per GPU :', h.batch_size)
334 |     else:
335 |         h.num_gpus = 1  # CPU fallback so the num_gpus checks below still work; training itself expects CUDA
336 | 
337 |     if h.num_gpus > 1:
338 |         mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
339 |     else:
340 |         train(0, a, h)
341 | 
342 | 
343 | if __name__ == '__main__':
344 |     main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.utils import weight_norm
6 | import matplotlib
7 | matplotlib.use("Agg")
8 | import matplotlib.pylab as plt
9 | 
10 | def plot_spectrogram(spectrogram):
11 |     fig, ax = plt.subplots(figsize=(4, 3))
12 |     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13 |                    interpolation='none')
14 |     plt.colorbar(im, ax=ax)
15 |     fig.canvas.draw()
16 |     plt.close(fig)
17 | 
18 |     return fig
19 | 
20 | def init_weights(m, mean=0.0, std=0.01):
21 |     classname = m.__class__.__name__
22 |     if classname.find("Conv") != -1:
23 |         m.weight.data.normal_(mean, std)
24 | 
25 | def apply_weight_norm(m):
26 |     classname = m.__class__.__name__
27 |     if classname.find("Conv") != -1:
28 |         weight_norm(m)
29 | 
30 | def get_padding(kernel_size, dilation=1):
31 |     return int((kernel_size*dilation - dilation)/2)
32 | 
33 | 
34 | def get_padding_2d(kernel_size, dilation=(1, 1)):
35 |     return (int((kernel_size[0]*dilation[0] - dilation[0])/2), int((kernel_size[1]*dilation[1] - dilation[1])/2))
36 | 
37 | 
38 | def load_checkpoint(filepath, device):
39 |     assert os.path.isfile(filepath)
40 |     print("Loading '{}'".format(filepath))
41 |     checkpoint_dict = torch.load(filepath, map_location=device)
42 |     print("Complete.")
43 |     return checkpoint_dict
44 | 
45 | 
46 | def save_checkpoint(filepath, obj):
47 |     print("Saving checkpoint to {}".format(filepath))
48 |     torch.save(obj, filepath)
49 |     print("Complete.")
50 | 
51 | 
52 | def scan_checkpoint(cp_dir, prefix):
53 |     pattern = os.path.join(cp_dir, prefix + '????????')
54 |     cp_list = glob.glob(pattern)
55 |     if len(cp_list) == 0:
56 |         return None
57 |     return sorted(cp_list)[-1]
--------------------------------------------------------------------------------
/weights_LICENSE.txt:
--------------------------------------------------------------------------------
1 | The weights for this model are licensed under the MIT License.
2 | You can use, modify, and distribute the weights freely as long as you comply with the terms of the MIT License.
3 | 
--------------------------------------------------------------------------------
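For orientation, a minimal sketch of how the components above compose at inference time. It mirrors inference/inference_48k.py with the batching and timing stripped out; the checkpoint, config, and audio paths are illustrative placeholders, not files shipped with the repository, and a mono input wav is assumed.

import json
import torch
import torchaudio
import torchaudio.functional as aF
from env import AttrDict
from datasets.dataset import amp_pha_stft, amp_pha_istft
from models.model import APNet_BWE_Model

# Illustrative paths: point these at a real checkpoint and its config.json.
CONFIG = 'cp_model/config.json'
CKPT = 'cp_model/g_00500000'

with open(CONFIG) as f:
    h = AttrDict(json.load(f))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = APNet_BWE_Model(h).to(device)
model.load_state_dict(torch.load(CKPT, map_location=device)['generator'])
model.eval()

# Simulate a narrowband input by downsampling, then upsampling back to the target rate.
wav, sr = torchaudio.load('input.wav')  # [1, T], mono assumed
wav = aF.resample(wav.to(device), orig_freq=sr, new_freq=h.lr_sampling_rate)
wav = aF.resample(wav, orig_freq=h.lr_sampling_rate, new_freq=h.hr_sampling_rate)

with torch.no_grad():
    # Analyze, extend amplitude and phase spectra in parallel, then resynthesize.
    amp_nb, pha_nb, _ = amp_pha_stft(wav, h.n_fft, h.hop_size, h.win_size)
    amp_wb, pha_wb, _ = model(amp_nb, pha_nb)
    wav_wb = amp_pha_istft(amp_wb, pha_wb, h.n_fft, h.hop_size, h.win_size)

torchaudio.save('output.wav', wav_wb.cpu(), h.hr_sampling_rate)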