├── Figures
│   ├── model.png
│   ├── table_16k.png
│   └── table_48k.png
├── LICENSE
├── README.md
├── VCTK-Corpus-0.92
│   ├── flac2wav.py
│   ├── test.txt
│   ├── training.txt
│   ├── vctk-silence-labels
│   │   ├── README.md
│   │   ├── assets
│   │   │   ├── original.png
│   │   │   ├── trim.png
│   │   │   └── trimpad.png
│   │   └── vctk-silences.0.92.txt
│   ├── wav16
│   │   ├── test
│   │   │   └── readme.txt
│   │   └── train
│   │       └── readme.txt
│   └── wav48
│       ├── test
│       │   └── readme.txt
│       └── train
│           └── readme.txt
├── cal_metrics.py
├── cal_visqol_48k.py
├── checkpoints
│   └── README.md
├── configs
│   ├── __init__.py
│   ├── config_12kto48k.json
│   ├── config_16kto48k.json
│   ├── config_24kto48k.json
│   ├── config_2kto16k.json
│   ├── config_4kto16k.json
│   ├── config_8kto16k.json
│   └── config_8kto48k.json
├── datasets
│   ├── __init__.py
│   └── dataset.py
├── docs
│   ├── css
│   │   ├── bulma-carousel.min.css
│   │   ├── bulma-slider.min.css
│   │   ├── bulma.min.css
│   │   ├── fontawesome.all.min.css
│   │   └── index.css
│   ├── index.html
│   ├── js
│   │   ├── ViewSDKInterface.js
│   │   ├── bulma-carousel.min.js
│   │   ├── bulma-slider.min.js
│   │   ├── fontawesome.all.min.js
│   │   ├── index.js
│   │   ├── jquery.min.js
│   │   └── main.js
│   └── samples
│       ├── Ablation
│       │   ├── ap-bwe.png
│       │   ├── ap-bwe.wav
│       │   ├── ap-bwe_mpd_only.png
│       │   ├── ap-bwe_mpd_only.wav
│       │   ├── ap-bwe_mrad_only.png
│       │   ├── ap-bwe_mrad_only.wav
│       │   ├── ap-bwe_mrpd_only.png
│       │   ├── ap-bwe_mrpd_only.wav
│       │   ├── ap-bwe_wo_AtoP.png
│       │   ├── ap-bwe_wo_AtoP.wav
│       │   ├── ap-bwe_wo_PtoA.png
│       │   ├── ap-bwe_wo_PtoA.wav
│       │   ├── ap-bwe_wo_connect.png
│       │   ├── ap-bwe_wo_connect.wav
│       │   ├── ap-bwe_wo_disc.png
│       │   ├── ap-bwe_wo_disc.wav
│       │   ├── ap-bwe_wo_mpd.png
│       │   ├── ap-bwe_wo_mpd.wav
│       │   ├── ap-bwe_wo_mrad.png
│       │   ├── ap-bwe_wo_mrad.wav
│       │   ├── ap-bwe_wo_mrpd.png
│       │   ├── ap-bwe_wo_mrpd.wav
│       │   ├── wideband.png
│       │   └── wideband.wav
│       ├── BWE_16k
│       │   ├── 2kto16k
│       │   │   ├── afilm.png
│       │   │   ├── afilm.wav
│       │   │   ├── ap-bwe.png
│       │   │   ├── ap-bwe.wav
│       │   │   ├── nvsr.png
│       │   │   ├── nvsr.wav
│       │   │   ├── sinc.png
│       │   │   ├── sinc.wav
│       │   │   ├── tfilm.png
│       │   │   ├── tfilm.wav
│       │   │   ├── wideband.png
│       │   │   └── wideband.wav
│       │   ├── 4kto16k
│       │   │   ├── afilm.png
│       │   │   ├── afilm.wav
│       │   │   ├── ap-bwe.png
│       │   │   ├── ap-bwe.wav
│       │   │   ├── nvsr.png
│       │   │   ├── nvsr.wav
│       │   │   ├── sinc.png
│       │   │   ├── sinc.wav
│       │   │   ├── tfilm.png
│       │   │   ├── tfilm.wav
│       │   │   ├── wideband.png
│       │   │   └── wideband.wav
│       │   └── 8kto16k
│       │       ├── afilm.png
│       │       ├── afilm.wav
│       │       ├── ap-bwe.png
│       │       ├── ap-bwe.wav
│       │       ├── nvsr.png
│       │       ├── nvsr.wav
│       │       ├── sinc.png
│       │       ├── sinc.wav
│       │       ├── tfilm.png
│       │       ├── tfilm.wav
│       │       ├── wideband.png
│       │       └── wideband.wav
│       ├── BWE_48k
│       │   ├── 12kto48k
│       │   │   ├── ap-bwe.png
│       │   │   ├── ap-bwe.wav
│       │   │   ├── mdctgan.png
│       │   │   ├── mdctgan.wav
│       │   │   ├── nuwave2.png
│       │   │   ├── nuwave2.wav
│       │   │   ├── sinc.png
│       │   │   ├── sinc.wav
│       │   │   ├── udm+.png
│       │   │   ├── udm+.wav
│       │   │   ├── wideband.png
│       │   │   └── wideband.wav
│       │   ├── 16kto48k
│       │   │   ├── ap-bwe.png
│       │   │   ├── ap-bwe.wav
│       │   │   ├── mdctgan.png
│       │   │   ├── mdctgan.wav
│       │   │   ├── nuwave2.png
│       │   │   ├── nuwave2.wav
│       │   │   ├── sinc.png
│       │   │   ├── sinc.wav
│       │   │   ├── udm+.png
│       │   │   ├── udm+.wav
│       │   │   ├── wideband.png
│       │   │   └── wideband.wav
│       │   ├── 24kto48k
│       │   │   ├── ap-bwe.png
│       │   │   ├── ap-bwe.wav
│       │   │   ├── mdctgan.png
│       │   │   ├── mdctgan.wav
│       │   │   ├── nuwave2.png
│       │   │   ├── nuwave2.wav
│       │   │   ├── sinc.png
│       │   │   ├── sinc.wav
│       │   │   ├── udm+.png
│       │   │   ├── udm+.wav
│       │   │   ├── wideband.png
│       │   │   └── wideband.wav
│       │   └── 8kto48k
│       │       ├── ap-bwe.png
│       │       ├── ap-bwe.wav
│       │       ├── mdctgan.png
│       │       ├── mdctgan.wav
│       │       ├── nuwave2.png
│       │       ├── nuwave2.wav
│       │       ├── sinc.png
│       │       ├── sinc.wav
│       │       ├── udm+.png
│       │       ├── udm+.wav
│       │       ├── wideband.png
│       │       └── wideband.wav
│       └── CrossDataset
│           ├── HiFi-TTS
│           │   ├── ap-bwe.png
│           │   ├── ap-bwe.wav
│           │   ├── mdctgan.png
│           │   ├── mdctgan.wav
│           │   ├── nuwave2.png
│           │   ├── nuwave2.wav
│           │   ├── udm+.png
│           │   ├── udm+.wav
│           │   ├── wideband.png
│           │   └── wideband.wav
│           └── Libri-TTS
│               ├── ap-bwe.png
│               ├── ap-bwe.wav
│               ├── mdctgan.png
│               ├── mdctgan.wav
│               ├── nuwave2.png
│               ├── nuwave2.wav
│               ├── udm+.png
│               ├── udm+.wav
│               ├── wideband.png
│               └── wideband.wav
├── env.py
├── inference
│   ├── __init__.py
│   ├── inference_16k.py
│   └── inference_48k.py
├── models
│   ├── __init__.py
│   └── model.py
├── requirements.txt
├── train
│   ├── __init__.py
│   ├── train_16k.py
│   └── train_48k.py
├── utils.py
└── weights_LICENSE.txt
/Figures/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/Figures/model.png
--------------------------------------------------------------------------------
/Figures/table_16k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/Figures/table_16k.png
--------------------------------------------------------------------------------
/Figures/table_48k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/Figures/table_48k.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ye-Xin Lu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Towards High-Quality and Efficient Speech Bandwidth Extension with Parallel Amplitude and Phase Prediction
2 | ### Ye-Xin Lu, Yang Ai, Hui-Peng Du, Zhen-Hua Ling
3 |
4 | **Abstract:**
5 | Speech bandwidth extension (BWE) refers to widening the frequency bandwidth of speech signals, enhancing speech quality so that it sounds brighter and fuller.
6 | This paper proposes a generative adversarial network (GAN) based BWE model with parallel prediction of Amplitude and Phase spectra, named AP-BWE, which achieves both high-quality and efficient wideband speech waveform generation.
7 | The proposed AP-BWE generator is entirely based on convolutional neural networks (CNNs).
8 | It features a dual-stream architecture with mutual interaction, where the amplitude stream and the phase stream communicate with each other and respectively extend the high-frequency components from the input narrowband amplitude and phase spectra.
9 | To improve the naturalness of the extended speech signals, we employ a multi-period discriminator at the waveform level and design a pair of multi-resolution amplitude and phase discriminators at the spectral level.
10 | Experimental results demonstrate that our proposed AP-BWE achieves state-of-the-art performance in terms of speech quality for BWE tasks targeting sampling rates of both 16 kHz and 48 kHz.
11 | In terms of generation efficiency, due to the all-convolutional architecture and all-frame-level operations, the proposed AP-BWE can generate 48 kHz waveform samples 292.3 times faster than real-time on a single RTX 4090 GPU and 18.1 times faster than real-time on a single CPU.
12 | Notably, to our knowledge, AP-BWE is the first to achieve the direct extension of the high-frequency phase spectrum, which is beneficial for improving the effectiveness of existing BWE methods.
13 |
14 | **We provide our implementation as open source in this repository. Audio samples can be found at the [demo website](http://yxlu-0102.github.io/AP-BWE).**
15 |
16 | ## License
17 |
18 | The code in this repository is licensed under the MIT License.
19 | The weights for this model are stored in [Google Drive](https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link). You can access and use the weights under the same MIT License. Please refer to the [weights_LICENSE.txt](weights_LICENSE.txt) for more details on the licensing.
20 |
21 | ## Pre-requisites
22 | 0. Python >= 3.9.
23 | 0. Clone this repository.
24 | 0. Install Python requirements. Please refer to [requirements.txt](requirements.txt).
25 | 0. Download datasets
26 |    1. Download and extract the [VCTK-0.92 dataset](https://datashare.ed.ac.uk/handle/10283/3443), move its `wav48` directory into [VCTK-Corpus-0.92](VCTK-Corpus-0.92), and rename it `wav48_origin`.
27 |    1. Trim the silence from the dataset; the trimmed files will be saved to `wav48_silence_trimmed`.
28 | ```
29 | cd VCTK-Corpus-0.92
30 | python flac2wav.py
31 | ```
32 | 1. Move all the trimmed training files from `wav48_silence_trimmed` to [wav48/train](wav48/train) following the indexes in [training.txt](VCTK-Corpus-0.92/training.txt), and move all the untrimmed test files from `wav48_origin` to [wav48/test](wav48/test) following the indexes in [test.txt](VCTK-Corpus-0.92/test.txt).
33 |
34 | ## Training
35 | ```
36 | cd train
37 | CUDA_VISIBLE_DEVICES=0 python train_16k.py --config [config file path]
38 | CUDA_VISIBLE_DEVICES=0 python train_48k.py --config [config file path]
39 | ```
40 | Checkpoints and copies of the configuration file are saved in the `cp_model` directory by default.
41 | You can change the path by using the `--checkpoint_path` option.
42 | Here is an example:
43 | ```
44 | CUDA_VISIBLE_DEVICES=0 python train_16k.py --config ../configs/config_2kto16k.json --checkpoint_path ../checkpoints/AP-BWE_2kto16k
45 | ```
46 |
47 | ## Inference
48 | ```
49 | cd inference
50 | python inference_16k.py --checkpoint_file [generator checkpoint file path]
51 | python inference_48k.py --checkpoint_file [generator checkpoint file path]
52 | ```
53 | You can download the [pretrained weights](https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link) we provide and move all the files to the `checkpoints` directory.
54 |
55 | Generated wav files are saved in `generated_files` by default.
56 | You can change the path by adding the `--output_dir` option.
57 | Here is an example:
58 | ```
59 | python inference_16k.py --checkpoint_file ../checkpoints/2kto16k/g_2kto16k --output_dir ../generated_files/2kto16k
60 | ```
61 |
62 | ## Model Structure
63 | 
64 |
65 | ## Comparison with other speech BWE methods
66 | ### 2k/4k/8kHz to 16kHz
67 |
68 |
69 |
70 |
71 | ### 8k/12k/16/24kHz to 48kHz
72 |
73 |
74 |
75 |
76 | ## Acknowledgements
77 | We referred to [HiFi-GAN](https://github.com/jik876/hifi-gan) and [NSPP](https://github.com/YangAi520/NSPP) when implementing this project.
78 |
79 | ## Citation
80 | ```
81 | @article{lu2024towards,
82 | title={Towards high-quality and efficient speech bandwidth extension with parallel amplitude and phase prediction},
83 | author={Lu, Ye-Xin and Ai, Yang and Du, Hui-Peng and Ling, Zhen-Hua},
84 | journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
85 | volume={33},
86 | pages={236--250},
87 | year={2024}
88 | }
89 |
90 | @inproceedings{lu2024multi,
91 | title={Multi-Stage Speech Bandwidth Extension with Flexible Sampling Rate Control},
92 | author={Lu, Ye-Xin and Ai, Yang and Sheng, Zheng-Yan and Ling, Zhen-Hua},
93 | booktitle={Proc. Interspeech},
94 | pages={2270--2274},
95 | year={2024}
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
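
The dataset split in the README's last pre-requisite step can be scripted. A minimal sketch, assuming each index line begins with the utterance id (before any `|`, the same format parsed by `datasets/dataset.py`) and that `flac2wav.py` has written the trimmed files as `wav48_silence_trimmed/<speaker>/<id>.wav`; the helper name `move_split` is our own:

```python
import os
import shutil

def move_split(index_file, src_dir, dst_dir):
    """Copy the files named in an index file into a flat train/test directory."""
    os.makedirs(dst_dir, exist_ok=True)
    with open(index_file, 'r', encoding='utf-8') as f:
        ids = [line.split('|')[0].strip() for line in f if line.strip()]
    for file_id in ids:
        speaker = file_id.split('_')[0]  # e.g. 'p226' from 'p226_036'
        src = os.path.join(src_dir, speaker, file_id + '.wav')
        if os.path.isfile(src):
            shutil.copy(src, os.path.join(dst_dir, file_id + '.wav'))

# Trimmed training files; the untrimmed test files live as mic1 FLACs in
# wav48_origin and need the same decoding step as in flac2wav.py first.
move_split('training.txt', 'wav48_silence_trimmed', 'wav48/train')
```
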
/VCTK-Corpus-0.92/flac2wav.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import soundfile as sf
3 | import os
4 | from glob import glob
5 | from tqdm import tqdm
6 | import multiprocessing as mp
7 | import torch
8 | import torchaudio.functional as aF
9 |
10 | def flac2wav(wav):
11 | y, sr = librosa.load(wav, sr=48000, mono=True)
12 | file_id = os.path.split(wav)[-1].split('_mic')[0]
13 | if file_id in timestamps:
14 | start, end = timestamps[file_id]
15 |         start = start - min(start, int(0.1 * sr))  # keep up to 0.1 s of leading silence
16 |         end = end + min(len(y) - end, int(0.1 * sr))  # keep up to 0.1 s of trailing silence
17 | y = y[start: end]
18 |
19 | # y = torch.FloatTensor(y).unsqueeze(0)
20 | # y = aF.resample(y, orig_freq=sr, new_freq=22050).squeeze().numpy()
21 |
22 | os.makedirs(os.path.join('wav48_silence_trimmed', wav.split(os.sep)[-2]), exist_ok=True)
23 |
24 | wav_path = os.path.join('wav48_silence_trimmed', wav.split(os.sep)[-2], file_id +'.wav')
25 |
26 | sf.write(wav_path, y, 48000, 'PCM_16')
27 | del y
28 | return
29 |
30 |
31 | if __name__=='__main__':
32 |
33 | base_dir = 'wav48_origin'
34 |
35 | wavs = glob(os.path.join(base_dir, '*/*mic1.flac'))
36 | sampling_rate = 48000
37 |
38 | timestamps = {}
39 | path_timestamps = 'vctk-silence-labels/vctk-silences.0.92.txt'
40 | with open(path_timestamps, 'r') as f:
41 | timestamps_list = f.readlines()
42 | for line in timestamps_list:
43 | timestamp_data = line.strip().split(' ')
44 | if len(timestamp_data) == 3:
45 | file_id, t_start, t_end = timestamp_data
46 | t_start = int(float(t_start) * sampling_rate)
47 | t_end = int(float(t_end) * sampling_rate)
48 | timestamps[file_id] = (t_start, t_end)
49 |
50 | pool = mp.Pool(processes = 8)
51 | with tqdm(total = len(wavs)) as pbar:
52 |         for _ in pool.imap_unordered(flac2wav, wavs):
53 | pbar.update()
54 |
--------------------------------------------------------------------------------
/VCTK-Corpus-0.92/vctk-silence-labels/README.md:
--------------------------------------------------------------------------------
1 | # Leading and trailing silence label for VCTK Corpus (version 0.92)
2 |
3 | This repository provides information about the leading and trailing silences of the VCTK corpus. The information was obtained by training an ASR model on the corpus and using it to extract alignments, which is a more robust approach to non-uniform noise than signal processing.
4 |
5 | This information is useful for preprocessing the VCTK Corpus to remove excessive leading and trailing silence frames containing noise. The silence information is only compatible with version 0.92 of the corpus.
6 |
7 | ## Original waveform and silence timestamp
8 |
9 | 
10 |
11 | *p226_036_mic1.wav*
12 |
13 | The leading and trailing silence timestamps can be found in `vctk-silences.0.92.txt`. Each line contains the utterance id and the start and end of the speech segment in seconds.
14 | ```
15 | p226_036 0.930 3.450
16 | ```
17 |
18 | ## Trim the leading and trailing silences
19 | ```bash
20 | sox p226_036_mic1.wav p226_036_mic1.trim.wav trim 0.930 =3.450
21 | ```
22 |
23 | We can use `sox` to trim the silences as shown in the example above. Since these are just timestamps, we can adjust them to change the clipping region.
24 |
25 | 
26 |
27 | *p226_036_mic1.trim.wav*
28 |
29 |
30 | ## Trim then pad with artificial silences
31 | ```bash
32 | sox p226_036_mic1.wav p226_036_mic1.trimpad.wav trim 0.930 =3.450 pad 0.25 0.25
33 | ```
34 |
35 | We can also pad the utterance with small leading and trailing silence segments, as shown in the example above. This can be helpful for training end-to-end TTS models.
36 |
37 | 
38 |
39 | *p226_036_mic1.trimpad.wav*
40 |
41 | ## References
42 | [VCTK Corpus (version 0.92)](https://datashare.is.ed.ac.uk/handle/10283/3443)
43 |
--------------------------------------------------------------------------------
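
A batch version of the `sox` trimming above is sketched below, assuming the three-field `id start end` label format shown, the `wav48_origin/<speaker>/<id>_mic1.flac` layout used elsewhere in this repository, and a `sox` build with FLAC support; the `trimmed` output directory is our own choice:

```bash
mkdir -p trimmed
while read -r id start end; do
    spk=${id%_*}  # e.g. p226 from p226_036
    [ -n "$end" ] && sox "wav48_origin/$spk/${id}_mic1.flac" \
        "trimmed/${id}.wav" trim "$start" "=$end"
done < vctk-silences.0.92.txt
```
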
/VCTK-Corpus-0.92/vctk-silence-labels/assets/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/VCTK-Corpus-0.92/vctk-silence-labels/assets/original.png
--------------------------------------------------------------------------------
/VCTK-Corpus-0.92/vctk-silence-labels/assets/trim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/VCTK-Corpus-0.92/vctk-silence-labels/assets/trim.png
--------------------------------------------------------------------------------
/VCTK-Corpus-0.92/vctk-silence-labels/assets/trimpad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/VCTK-Corpus-0.92/vctk-silence-labels/assets/trimpad.png
--------------------------------------------------------------------------------
/VCTK-Corpus-0.92/wav16/test/readme.txt:
--------------------------------------------------------------------------------
1 | Move all the test files here.
2 |
--------------------------------------------------------------------------------
/VCTK-Corpus-0.92/wav16/train/readme.txt:
--------------------------------------------------------------------------------
1 | Move all the training files here.
2 |
--------------------------------------------------------------------------------
/VCTK-Corpus-0.92/wav48/test/readme.txt:
--------------------------------------------------------------------------------
1 | Move all the test files here.
2 |
--------------------------------------------------------------------------------
/VCTK-Corpus-0.92/wav48/train/readme.txt:
--------------------------------------------------------------------------------
1 | Move all the training files here.
2 |
--------------------------------------------------------------------------------
/cal_metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import torchaudio
5 | import numpy as np
6 | from rich.progress import track
7 |
8 | def stft(audio, n_fft=2048, hop_length=512):
9 | hann_window = torch.hann_window(n_fft).to(audio.device)
10 | stft_spec = torch.stft(audio, n_fft, hop_length, window=hann_window, return_complex=True)
11 | stft_mag = torch.abs(stft_spec)
12 | stft_pha = torch.angle(stft_spec)
13 |
14 | return stft_mag, stft_pha
15 |
16 |
17 | def cal_snr(pred, target):
18 | snr = (20 * torch.log10(torch.norm(target, dim=-1) / torch.norm(pred - target, dim=-1).clamp(min=1e-8))).mean()
19 | return snr
20 |
21 |
22 | def cal_lsd(pred, target):
23 | sp = torch.log10(stft(pred)[0].square().clamp(1e-8))
24 | st = torch.log10(stft(target)[0].square().clamp(1e-8))
25 | return (sp - st).square().mean(dim=1).sqrt().mean()
26 |
27 |
28 | def anti_wrapping_function(x):
29 | return x - torch.round(x / (2 * np.pi)) * 2 * np.pi
30 |
31 |
32 | def cal_apd(pred, target):
33 | pha_pred = stft(pred)[1]
34 | pha_target = stft(target)[1]
35 |     dim_freq = 1025  # n_fft // 2 + 1 for the n_fft = 2048 used in stft()
36 | dim_time = pha_pred.size(-1)
37 |
38 |     gd_matrix = (torch.triu(torch.ones(dim_freq, dim_freq), diagonal=1) - torch.triu(torch.ones(dim_freq, dim_freq), diagonal=2) - torch.eye(dim_freq)).to(device)  # first-order difference operator along frequency (group delay)
39 |     gd_r = torch.matmul(pha_target.permute(0, 2, 1), gd_matrix)
40 |     gd_g = torch.matmul(pha_pred.permute(0, 2, 1), gd_matrix)
41 |
42 |     iaf_matrix = (torch.triu(torch.ones(dim_time, dim_time), diagonal=1) - torch.triu(torch.ones(dim_time, dim_time), diagonal=2) - torch.eye(dim_time)).to(device)  # first-order difference operator along time (instantaneous angular frequency)
43 | iaf_r = torch.matmul(pha_target, iaf_matrix)
44 | iaf_g = torch.matmul(pha_pred, iaf_matrix)
45 |
46 | apd_ip = anti_wrapping_function(pha_pred - pha_target).square().mean(dim=1).sqrt().mean()
47 | apd_gd = anti_wrapping_function(gd_r - gd_g).square().mean(dim=1).sqrt().mean()
48 | apd_iaf = anti_wrapping_function(iaf_r - iaf_g).square().mean(dim=1).sqrt().mean()
49 |
50 | return apd_ip, apd_gd, apd_iaf
51 |
52 |
53 | def main(h):
54 |
55 | wav_indexes = os.listdir(h.reference_wav_dir)
56 |
57 | metrics = {'lsd':[], 'apd_ip': [], 'apd_gd': [], 'apd_iaf': [], 'snr':[]}
58 |
59 | for wav_index in track(wav_indexes):
60 |
61 | ref_wav, ref_sr = torchaudio.load(os.path.join(h.reference_wav_dir, wav_index))
62 | syn_wav, syn_sr = torchaudio.load(os.path.join(h.synthesis_wav_dir, wav_index))
63 |
64 |         # Truncate both signals to the shorter length so the metrics
65 |         # compare time-aligned samples of equal duration.
66 |         length = min(ref_wav.size(1), syn_wav.size(1))
67 |         ref_wav = ref_wav[:, : length].to(device)
68 |         syn_wav = syn_wav[:, : length].to(device)
69 |
70 | lsd_score = cal_lsd(syn_wav, ref_wav)
71 | apd_score = cal_apd(syn_wav, ref_wav)
72 | snr_score = cal_snr(syn_wav, ref_wav)
73 |
74 |
75 | metrics['lsd'].append(lsd_score)
76 | metrics['apd_ip'].append(apd_score[0])
77 | metrics['apd_gd'].append(apd_score[1])
78 | metrics['apd_iaf'].append(apd_score[2])
79 | metrics['snr'].append(snr_score)
80 |
81 |
82 | lsd_mean = torch.stack(metrics['lsd'], dim=0).mean()
83 | apd_ip_mean = torch.stack(metrics['apd_ip'], dim=0).mean()
84 | apd_gd_mean = torch.stack(metrics['apd_gd'], dim=0).mean()
85 | apd_iaf_mean = torch.stack(metrics['apd_iaf'], dim=0).mean()
86 | snr_mean = torch.stack(metrics['snr'], dim=0).mean()
87 |
88 | print('LSD: {:.3f}'.format(lsd_mean))
89 | print('SNR: {:.3f}'.format(snr_mean))
90 | print('APD_IP: {:.3f}'.format(apd_ip_mean))
91 | print('APD_GD: {:.3f}'.format(apd_gd_mean))
92 | print('APD_IAF: {:.3f}'.format(apd_iaf_mean))
93 |
94 | if __name__ == '__main__':
95 | parser = argparse.ArgumentParser()
96 |
97 | parser.add_argument('--reference_wav_dir', default='VCTK-Corpus-0.92/wav16/test')
98 | parser.add_argument('--synthesis_wav_dir', default='generated_files/AP-BWE')
99 |
100 | h = parser.parse_args()
101 |
102 | global device
103 | if torch.cuda.is_available():
104 | device = torch.device('cuda')
105 | else:
106 | device = torch.device('cpu')
107 |
108 | main(h)
--------------------------------------------------------------------------------
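
A usage sketch for the metrics script above, spelling out the two directory arguments (their defaults are `VCTK-Corpus-0.92/wav16/test` and `generated_files/AP-BWE`; the synthesis path assumes the inference example in the README):

```
python cal_metrics.py --reference_wav_dir VCTK-Corpus-0.92/wav16/test \
                      --synthesis_wav_dir generated_files/AP-BWE
```
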
/cal_visqol_48k.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | import librosa
5 | from visqol import visqol_lib_py
6 | from visqol.pb2 import visqol_config_pb2
7 | from visqol.pb2 import similarity_result_pb2
8 | from rich.progress import track
9 | from joblib import Parallel, delayed
10 |
11 | config = visqol_config_pb2.VisqolConfig()
12 |
13 | def cal_vq(reference, degraded, mode='audio'):
14 | if mode == "audio":
15 | config.audio.sample_rate = 48000
16 | config.options.use_speech_scoring = False
17 | svr_model_path = "libsvm_nu_svr_model.txt"
18 | elif mode == "speech":
19 | config.audio.sample_rate = 16000
20 | config.options.use_speech_scoring = True
21 | svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
22 | else:
23 | raise ValueError(f"Unrecognized mode: {mode}")
24 |
25 | config.options.svr_model_path = os.path.join(
26 | os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path)
27 |
28 | api = visqol_lib_py.VisqolApi()
29 |
30 | api.Create(config)
31 |
32 | similarity_result = api.Measure(reference, degraded)
33 |
34 | return similarity_result.moslqo
35 |
36 |
37 | def main(h):
38 | # with open(h.test_file, 'r', encoding='utf-8') as fi:
39 | # wav_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0]
40 | wav_indexes = os.listdir(h.ref_wav_dir)
41 |
42 | metrics = {'vq':[]}
43 |
44 | for wav_index in track(wav_indexes):
45 |
46 | ref_wav, ref_sr = librosa.load(os.path.join(h.ref_wav_dir, wav_index), sr=float(h.sampling_rate), dtype=np.float64)
47 | syn_wav, syn_sr = librosa.load(os.path.join(h.syn_wav_dir, wav_index), sr=float(h.sampling_rate), dtype=np.float64)
48 |
49 | if float(h.sampling_rate) != 48000:
50 | ref_wav = librosa.resample(ref_wav, orig_sr=float(h.sampling_rate), target_sr=48000)
51 | syn_wav = librosa.resample(syn_wav, orig_sr=float(h.sampling_rate), target_sr=48000)
52 |
53 | length = min(len(ref_wav), len(syn_wav))
54 | ref_wav = ref_wav[: length]
55 | syn_wav = syn_wav[: length]
56 | try:
57 | vq_score = cal_vq(ref_wav, syn_wav)
58 | metrics['vq'].append(vq_score)
59 |         except Exception:
60 |             print('ViSQOL failed on {}; skipping'.format(wav_index))
61 |
62 | vq_mean = np.mean(metrics['vq'])
63 |
64 | print('VISQOL: {:.3f}'.format(vq_mean))
65 |
66 | if __name__ == '__main__':
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument('--sampling_rate', required=True)
69 | parser.add_argument('--ref_wav_dir', required=True)
70 | parser.add_argument('--syn_wav_dir', required=True)
71 |
72 | h = parser.parse_args()
73 |
74 | main(h)
--------------------------------------------------------------------------------
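
All three arguments of the ViSQOL script above are required; passing 48000 takes the "audio" path in `cal_vq`. A usage sketch (the two directory paths are our assumptions following the repository layout):

```
python cal_visqol_48k.py --sampling_rate 48000 \
                         --ref_wav_dir VCTK-Corpus-0.92/wav48/test \
                         --syn_wav_dir generated_files/AP-BWE
```
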
/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | Download the [pretrained weights](https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link) and move the checkpoint files here.
2 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/configs/config_12kto48k.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_gpus": 0,
3 | "batch_size": 16,
4 | "learning_rate": 0.0002,
5 | "adam_b1": 0.8,
6 | "adam_b2": 0.99,
7 | "lr_decay": 0.999,
8 | "seed": 1234,
9 |
10 | "ConvNeXt_channels": 512,
11 | "ConvNeXt_layers": 8,
12 |
13 | "segment_size": 8000,
14 | "n_fft": 1024,
15 | "hop_size": 80,
16 | "win_size": 320,
17 |
18 | "hr_sampling_rate": 48000,
19 | "lr_sampling_rate": 12000,
20 | "subsampling_rate": 4,
21 |
22 | "num_workers": 4,
23 |
24 | "dist_config": {
25 | "dist_backend": "nccl",
26 | "dist_url": "tcp://localhost:54321",
27 | "world_size": 1
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/configs/config_16kto48k.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_gpus": 0,
3 | "batch_size": 16,
4 | "learning_rate": 0.0002,
5 | "adam_b1": 0.8,
6 | "adam_b2": 0.99,
7 | "lr_decay": 0.999,
8 | "seed": 1234,
9 |
10 | "ConvNeXt_channels": 512,
11 | "ConvNeXt_layers": 8,
12 |
13 | "segment_size": 8000,
14 | "n_fft": 1024,
15 | "hop_size": 80,
16 | "win_size": 320,
17 |
18 | "hr_sampling_rate": 48000,
19 | "lr_sampling_rate": 16000,
20 |
21 | "num_workers": 4,
22 |
23 | "dist_config": {
24 | "dist_backend": "nccl",
25 | "dist_url": "tcp://localhost:54321",
26 | "world_size": 1
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/configs/config_24kto48k.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_gpus": 0,
3 | "batch_size": 16,
4 | "learning_rate": 0.0002,
5 | "adam_b1": 0.8,
6 | "adam_b2": 0.99,
7 | "lr_decay": 0.999,
8 | "seed": 1234,
9 |
10 | "ConvNeXt_channels": 512,
11 | "ConvNeXt_layers": 8,
12 |
13 | "segment_size": 8000,
14 | "n_fft": 1024,
15 | "hop_size": 80,
16 | "win_size": 320,
17 |
18 | "hr_sampling_rate": 48000,
19 | "lr_sampling_rate": 24000,
20 | "subsampling_rate": 2,
21 |
22 | "num_workers": 4,
23 |
24 | "dist_config": {
25 | "dist_backend": "nccl",
26 | "dist_url": "tcp://localhost:54321",
27 | "world_size": 1
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/configs/config_2kto16k.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_gpus": 0,
3 | "batch_size": 16,
4 | "learning_rate": 0.0002,
5 | "adam_b1": 0.8,
6 | "adam_b2": 0.99,
7 | "lr_decay": 0.999,
8 | "seed": 1234,
9 |
10 | "ConvNeXt_channels": 512,
11 | "ConvNeXt_layers": 8,
12 |
13 | "segment_size": 8000,
14 | "n_fft": 1024,
15 | "hop_size": 80,
16 | "win_size": 320,
17 |
18 | "hr_sampling_rate": 16000,
19 | "lr_sampling_rate": 2000,
20 |
21 | "num_workers": 4,
22 |
23 | "dist_config": {
24 | "dist_backend": "nccl",
25 | "dist_url": "tcp://localhost:54321",
26 | "world_size": 1
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/configs/config_4kto16k.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_gpus": 0,
3 | "batch_size": 16,
4 | "learning_rate": 0.0002,
5 | "adam_b1": 0.8,
6 | "adam_b2": 0.99,
7 | "lr_decay": 0.999,
8 | "seed": 1234,
9 |
10 | "ConvNeXt_channels": 512,
11 | "ConvNeXt_layers": 8,
12 |
13 | "segment_size": 8000,
14 | "n_fft": 1024,
15 | "hop_size": 80,
16 | "win_size": 320,
17 |
18 | "hr_sampling_rate": 16000,
19 | "lr_sampling_rate": 4000,
20 |
21 | "num_workers": 4,
22 |
23 | "dist_config": {
24 | "dist_backend": "nccl",
25 | "dist_url": "tcp://localhost:54321",
26 | "world_size": 1
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/configs/config_8kto16k.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_gpus": 0,
3 | "batch_size": 16,
4 | "learning_rate": 0.0002,
5 | "adam_b1": 0.8,
6 | "adam_b2": 0.99,
7 | "lr_decay": 0.999,
8 | "seed": 1234,
9 |
10 | "ConvNeXt_channels": 512,
11 | "ConvNeXt_layers": 8,
12 |
13 | "segment_size": 8000,
14 | "n_fft": 1024,
15 | "hop_size": 80,
16 | "win_size": 320,
17 |
18 | "hr_sampling_rate": 16000,
19 | "lr_sampling_rate": 8000,
20 |
21 | "num_workers": 4,
22 |
23 | "dist_config": {
24 | "dist_backend": "nccl",
25 | "dist_url": "tcp://localhost:54321",
26 | "world_size": 1
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/configs/config_8kto48k.json:
--------------------------------------------------------------------------------
1 | {
2 | "num_gpus": 0,
3 | "batch_size": 16,
4 | "learning_rate": 0.0002,
5 | "adam_b1": 0.8,
6 | "adam_b2": 0.99,
7 | "lr_decay": 0.999,
8 | "seed": 1234,
9 |
10 | "ConvNeXt_channels": 512,
11 | "ConvNeXt_layers": 8,
12 |
13 | "segment_size": 8000,
14 | "n_fft": 1024,
15 | "hop_size": 80,
16 | "win_size": 320,
17 |
18 | "hr_sampling_rate": 48000,
19 | "lr_sampling_rate": 8000,
20 | "subsampling_rate": 6,
21 |
22 | "num_workers": 4,
23 |
24 | "dist_config": {
25 | "dist_backend": "nccl",
26 | "dist_url": "tcp://localhost:54321",
27 | "world_size": 1
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
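
Across the configs above only the sampling-rate fields differ, and `subsampling_rate`, where present, equals `hr_sampling_rate / lr_sampling_rate`. A minimal sketch to check a config for consistency before training (the check itself is our addition, not part of the repository):

```python
import json

with open('configs/config_12kto48k.json') as f:
    h = json.load(f)

# The upsampling ratio implied by the two sampling rates should match
# the explicit subsampling_rate whenever that field is present.
ratio = h['hr_sampling_rate'] // h['lr_sampling_rate']
assert h.get('subsampling_rate', ratio) == ratio, 'inconsistent sampling rates'
print('{} Hz -> {} Hz (x{})'.format(h['lr_sampling_rate'], h['hr_sampling_rate'], ratio))
```
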
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import torchaudio
5 | import torch.utils.data
6 | import torchaudio.functional as aF
7 |
8 | def amp_pha_stft(audio, n_fft, hop_size, win_size, center=True):
9 |
10 | hann_window = torch.hann_window(win_size).to(audio.device)
11 | stft_spec = torch.stft(audio, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
12 | center=center, pad_mode='reflect', normalized=False, return_complex=True)
13 | log_amp = torch.log(torch.abs(stft_spec)+1e-4)
14 | pha = torch.angle(stft_spec)
15 |
16 | com = torch.stack((torch.exp(log_amp)*torch.cos(pha),
17 | torch.exp(log_amp)*torch.sin(pha)), dim=-1)
18 |
19 | return log_amp, pha, com
20 |
21 |
22 | def amp_pha_istft(log_amp, pha, n_fft, hop_size, win_size, center=True):
23 |
24 | amp = torch.exp(log_amp)
25 | com = torch.complex(amp*torch.cos(pha), amp*torch.sin(pha))
26 | hann_window = torch.hann_window(win_size).to(com.device)
27 | audio = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
28 |
29 | return audio
30 |
31 |
32 | def get_dataset_filelist(a):
33 | with open(a.input_training_file, 'r', encoding='utf-8') as fi:
34 | training_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0]
35 |
36 | with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
37 | validation_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0]
38 |
39 | return training_indexes, validation_indexes
40 |
41 |
42 | class Dataset(torch.utils.data.Dataset):
43 | def __init__(self, training_indexes, wavs_dir, segment_size, hr_sampling_rate, lr_sampling_rate,
44 | split=True, shuffle=True, n_cache_reuse=1, device=None):
45 | self.audio_indexes = training_indexes
46 | random.seed(1234)
47 | if shuffle:
48 | random.shuffle(self.audio_indexes)
49 | self.wavs_dir = wavs_dir
50 | self.segment_size = segment_size
51 | self.hr_sampling_rate = hr_sampling_rate
52 | self.lr_sampling_rate = lr_sampling_rate
53 | self.split = split
54 |         self.cached_wav, self.cached_sr = None, None
55 | self.n_cache_reuse = n_cache_reuse
56 | self._cache_ref_count = 0
57 | self.device = device
58 |
59 | def __getitem__(self, index):
60 | filename = self.audio_indexes[index]
61 | if self._cache_ref_count == 0:
62 |             audio, orig_sampling_rate = torchaudio.load(os.path.join(self.wavs_dir, filename + '.wav'))
63 |             self.cached_wav, self.cached_sr = audio, orig_sampling_rate  # cache the sampling rate alongside the audio
64 |             self._cache_ref_count = self.n_cache_reuse
65 |         else:
66 |             audio, orig_sampling_rate = self.cached_wav, self.cached_sr  # reuse the cached audio and its sampling rate
67 |             self._cache_ref_count -= 1
68 |
69 | if orig_sampling_rate == self.hr_sampling_rate:
70 | audio_hr = audio
71 | else:
72 | audio_hr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.hr_sampling_rate)
73 |
74 | audio_lr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.lr_sampling_rate)
75 | audio_lr = aF.resample(audio_lr, orig_freq=self.lr_sampling_rate, new_freq=self.hr_sampling_rate)
76 | audio_lr = audio_lr[:, : audio_hr.size(1)]
77 |
78 | if self.split:
79 | if audio_hr.size(1) >= self.segment_size:
80 | max_audio_start = audio_hr.size(1) - self.segment_size
81 | audio_start = random.randint(0, max_audio_start)
82 | audio_hr = audio_hr[:, audio_start: audio_start+self.segment_size]
83 | audio_lr = audio_lr[:, audio_start: audio_start+self.segment_size]
84 | else:
85 | audio_hr = torch.nn.functional.pad(audio_hr, (0, self.segment_size - audio_hr.size(1)), 'constant')
86 | audio_lr = torch.nn.functional.pad(audio_lr, (0, self.segment_size - audio_lr.size(1)), 'constant')
87 |
88 | return (audio_hr.squeeze(), audio_lr.squeeze())
89 |
90 | def __len__(self):
91 |
92 | return len(self.audio_indexes)
93 |
--------------------------------------------------------------------------------
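
A round-trip sanity check for the amplitude/phase STFT helpers in `datasets/dataset.py` above, using the `n_fft`/`hop_size`/`win_size` values from the configs; run it from the repository root (the tolerance is our choice, and the 1e-4 log-amplitude floor makes the round trip only approximate):

```python
import torch
from datasets.dataset import amp_pha_stft, amp_pha_istft

audio = torch.randn(1, 8000)  # one segment_size-length example
log_amp, pha, com = amp_pha_stft(audio, n_fft=1024, hop_size=80, win_size=320)
recon = amp_pha_istft(log_amp, pha, n_fft=1024, hop_size=80, win_size=320)

# Compare over the overlapping samples returned by istft.
n = min(audio.size(1), recon.size(1))
print(torch.allclose(audio[:, :n], recon[:, :n], atol=1e-3))
```
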
/docs/css/bulma-carousel.min.css:
--------------------------------------------------------------------------------
1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center 
center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10}
--------------------------------------------------------------------------------
/docs/css/bulma-slider.min.css:
--------------------------------------------------------------------------------
1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}input[type=range].slider{-webkit-appearance:none;-moz-appearance:none;appearance:none;margin:1rem 0;background:0 0;touch-action:none}input[type=range].slider.is-fullwidth{display:block;width:100%}input[type=range].slider:focus{outline:0}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{width:100%}input[type=range].slider:not([orient=vertical])::-moz-range-track{width:100%}input[type=range].slider:not([orient=vertical])::-ms-track{width:100%}input[type=range].slider:not([orient=vertical]).has-output+output,input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{width:3rem;background:#4a4a4a;border-radius:4px;padding:.4rem .8rem;font-size:.75rem;line-height:.75rem;text-align:center;text-overflow:ellipsis;white-space:nowrap;color:#fff;overflow:hidden;pointer-events:none;z-index:200}input[type=range].slider:not([orient=vertical]).has-output-tooltip:disabled+output,input[type=range].slider:not([orient=vertical]).has-output:disabled+output{opacity:.5}input[type=range].slider:not([orient=vertical]).has-output{display:inline-block;vertical-align:middle;width:calc(100% - (4.2rem))}input[type=range].slider:not([orient=vertical]).has-output+output{display:inline-block;margin-left:.75rem;vertical-align:middle}input[type=range].slider:not([orient=vertical]).has-output-tooltip{display:block}input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{position:absolute;left:0;top:-.1rem}input[type=range].slider[orient=vertical]{-webkit-appearance:slider-vertical;-moz-appearance:slider-vertical;appearance:slider-vertical;-webkit-writing-mode:bt-lr;-ms-writing-mode:bt-lr;writing-mode:bt-lr}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{height:100%}input[type=range].slider[orient=vertical]::-moz-range-track{height:100%}input[type=range].slider[orient=vertical]::-ms-track{height:100%}input[type=range].slider::-webkit-slider-runnable-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-moz-range-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-fill-lower{background:#dbdbdb;border-radius:4px}input[type=range].slider::-ms-fill-upper{background:#dbdbdb;border-radius:4px}input[type=range].slider::-webkit-slider-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-moz-range-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-ms-thumb{box-shadow:none;border:1px solid 
#b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-webkit-slider-thumb{-webkit-appearance:none;appearance:none}input[type=range].slider.is-circle::-webkit-slider-thumb{border-radius:290486px}input[type=range].slider.is-circle::-moz-range-thumb{border-radius:290486px}input[type=range].slider.is-circle::-ms-thumb{border-radius:290486px}input[type=range].slider:active::-webkit-slider-thumb{-webkit-transform:scale(1.25);transform:scale(1.25)}input[type=range].slider:active::-moz-range-thumb{transform:scale(1.25)}input[type=range].slider:active::-ms-thumb{transform:scale(1.25)}input[type=range].slider:disabled{opacity:.5;cursor:not-allowed}input[type=range].slider:disabled::-webkit-slider-thumb{cursor:not-allowed;-webkit-transform:scale(1);transform:scale(1)}input[type=range].slider:disabled::-moz-range-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:disabled::-ms-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:not([orient=vertical]){min-height:calc((1rem + 2px) * 1.25)}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-moz-range-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-ms-track{height:.5rem}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{width:.5rem}input[type=range].slider[orient=vertical]::-moz-range-track{width:.5rem}input[type=range].slider[orient=vertical]::-ms-track{width:.5rem}input[type=range].slider::-webkit-slider-thumb{height:1rem;width:1rem}input[type=range].slider::-moz-range-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{margin-top:0}input[type=range].slider::-webkit-slider-thumb{margin-top:-.25rem}input[type=range].slider[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.25rem}input[type=range].slider.is-small:not([orient=vertical]){min-height:calc((.75rem + 2px) * 1.25)}input[type=range].slider.is-small:not([orient=vertical])::-webkit-slider-runnable-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-moz-range-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-ms-track{height:.375rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-runnable-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-moz-range-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-ms-track{width:.375rem}input[type=range].slider.is-small::-webkit-slider-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-moz-range-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{margin-top:0}input[type=range].slider.is-small::-webkit-slider-thumb{margin-top:-.1875rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.1875rem}input[type=range].slider.is-medium:not([orient=vertical]){min-height:calc((1.25rem + 2px) * 
1.25)}input[type=range].slider.is-medium:not([orient=vertical])::-webkit-slider-runnable-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-moz-range-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-ms-track{height:.625rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-runnable-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-moz-range-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-ms-track{width:.625rem}input[type=range].slider.is-medium::-webkit-slider-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-moz-range-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{margin-top:0}input[type=range].slider.is-medium::-webkit-slider-thumb{margin-top:-.3125rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.3125rem}input[type=range].slider.is-large:not([orient=vertical]){min-height:calc((1.5rem + 2px) * 1.25)}input[type=range].slider.is-large:not([orient=vertical])::-webkit-slider-runnable-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-moz-range-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-ms-track{height:.75rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-runnable-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-moz-range-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-ms-track{width:.75rem}input[type=range].slider.is-large::-webkit-slider-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-moz-range-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{margin-top:0}input[type=range].slider.is-large::-webkit-slider-thumb{margin-top:-.375rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.375rem}input[type=range].slider.is-white::-moz-range-track{background:#fff!important}input[type=range].slider.is-white::-webkit-slider-runnable-track{background:#fff!important}input[type=range].slider.is-white::-ms-track{background:#fff!important}input[type=range].slider.is-white::-ms-fill-lower{background:#fff}input[type=range].slider.is-white::-ms-fill-upper{background:#fff}input[type=range].slider.is-white .has-output-tooltip+output,input[type=range].slider.is-white.has-output+output{background-color:#fff;color:#0a0a0a}input[type=range].slider.is-black::-moz-range-track{background:#0a0a0a!important}input[type=range].slider.is-black::-webkit-slider-runnable-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-fill-lower{background:#0a0a0a}input[type=range].slider.is-black::-ms-fill-upper{background:#0a0a0a}input[type=range].slider.is-black 
.has-output-tooltip+output,input[type=range].slider.is-black.has-output+output{background-color:#0a0a0a;color:#fff}input[type=range].slider.is-light::-moz-range-track{background:#f5f5f5!important}input[type=range].slider.is-light::-webkit-slider-runnable-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-fill-lower{background:#f5f5f5}input[type=range].slider.is-light::-ms-fill-upper{background:#f5f5f5}input[type=range].slider.is-light .has-output-tooltip+output,input[type=range].slider.is-light.has-output+output{background-color:#f5f5f5;color:#363636}input[type=range].slider.is-dark::-moz-range-track{background:#363636!important}input[type=range].slider.is-dark::-webkit-slider-runnable-track{background:#363636!important}input[type=range].slider.is-dark::-ms-track{background:#363636!important}input[type=range].slider.is-dark::-ms-fill-lower{background:#363636}input[type=range].slider.is-dark::-ms-fill-upper{background:#363636}input[type=range].slider.is-dark .has-output-tooltip+output,input[type=range].slider.is-dark.has-output+output{background-color:#363636;color:#f5f5f5}input[type=range].slider.is-primary::-moz-range-track{background:#00d1b2!important}input[type=range].slider.is-primary::-webkit-slider-runnable-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-fill-lower{background:#00d1b2}input[type=range].slider.is-primary::-ms-fill-upper{background:#00d1b2}input[type=range].slider.is-primary .has-output-tooltip+output,input[type=range].slider.is-primary.has-output+output{background-color:#00d1b2;color:#fff}input[type=range].slider.is-link::-moz-range-track{background:#3273dc!important}input[type=range].slider.is-link::-webkit-slider-runnable-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-fill-lower{background:#3273dc}input[type=range].slider.is-link::-ms-fill-upper{background:#3273dc}input[type=range].slider.is-link .has-output-tooltip+output,input[type=range].slider.is-link.has-output+output{background-color:#3273dc;color:#fff}input[type=range].slider.is-info::-moz-range-track{background:#209cee!important}input[type=range].slider.is-info::-webkit-slider-runnable-track{background:#209cee!important}input[type=range].slider.is-info::-ms-track{background:#209cee!important}input[type=range].slider.is-info::-ms-fill-lower{background:#209cee}input[type=range].slider.is-info::-ms-fill-upper{background:#209cee}input[type=range].slider.is-info .has-output-tooltip+output,input[type=range].slider.is-info.has-output+output{background-color:#209cee;color:#fff}input[type=range].slider.is-success::-moz-range-track{background:#23d160!important}input[type=range].slider.is-success::-webkit-slider-runnable-track{background:#23d160!important}input[type=range].slider.is-success::-ms-track{background:#23d160!important}input[type=range].slider.is-success::-ms-fill-lower{background:#23d160}input[type=range].slider.is-success::-ms-fill-upper{background:#23d160}input[type=range].slider.is-success 
.has-output-tooltip+output,input[type=range].slider.is-success.has-output+output{background-color:#23d160;color:#fff}input[type=range].slider.is-warning::-moz-range-track{background:#ffdd57!important}input[type=range].slider.is-warning::-webkit-slider-runnable-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-fill-lower{background:#ffdd57}input[type=range].slider.is-warning::-ms-fill-upper{background:#ffdd57}input[type=range].slider.is-warning .has-output-tooltip+output,input[type=range].slider.is-warning.has-output+output{background-color:#ffdd57;color:rgba(0,0,0,.7)}input[type=range].slider.is-danger::-moz-range-track{background:#ff3860!important}input[type=range].slider.is-danger::-webkit-slider-runnable-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-fill-lower{background:#ff3860}input[type=range].slider.is-danger::-ms-fill-upper{background:#ff3860}input[type=range].slider.is-danger .has-output-tooltip+output,input[type=range].slider.is-danger.has-output+output{background-color:#ff3860;color:#fff}
--------------------------------------------------------------------------------
/docs/css/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Noto Sans', sans-serif;
3 | }
4 |
5 |
6 | .footer .icon-link {
7 | font-size: 25px;
8 | color: #000;
9 | }
10 |
11 | .link-block a {
12 | margin-top: 5px;
13 | margin-bottom: 5px;
14 | }
15 |
16 | .dnerf {
17 | font-variant: small-caps;
18 | }
19 |
20 |
21 | .teaser .hero-body {
22 | padding-top: 0;
23 | padding-bottom: 3rem;
24 | }
25 |
26 | .teaser {
27 | font-family: 'Google Sans', sans-serif;
28 | }
29 |
30 |
31 | .publication-title {
32 | }
33 |
34 | .publication-banner {
35 | max-height: parent;
36 |
37 | }
38 |
39 | .publication-banner video {
40 | position: relative;
41 | left: auto;
42 | top: auto;
43 | transform: none;
44 | object-fit: fit;
45 | }
46 |
47 | .publication-header .hero-body {
48 | }
49 |
50 | .publication-title {
51 | font-family: 'Google Sans', sans-serif;
52 | }
53 |
54 | .publication-authors {
55 | font-family: 'Google Sans', sans-serif;
56 | }
57 |
58 | .publication-venue {
59 | color: #555;
60 | width: fit-content;
61 | font-weight: bold;
62 | }
63 |
64 | .publication-awards {
65 | color: #ff3860;
66 | width: fit-content;
67 | font-weight: bolder;
68 | }
69 |
70 | .publication-authors {
71 | }
72 |
73 | .publication-authors a {
74 | color: hsl(204, 86%, 53%) !important;
75 | }
76 |
77 | .publication-authors a:hover {
78 | text-decoration: underline;
79 | }
80 |
81 | .author-block {
82 | display: inline-block;
83 | }
84 |
85 | .publication-banner img {
86 | }
87 |
88 | .publication-authors {
89 | /*color: #4286f4;*/
90 | }
91 |
92 | .publication-video {
93 | position: relative;
94 | width: 100%;
95 | height: 0;
96 | padding-bottom: 56.25%;
97 |
98 | overflow: hidden;
99 | border-radius: 10px !important;
100 | }
101 |
102 | .publication-video iframe {
103 | position: absolute;
104 | top: 0;
105 | left: 0;
106 | width: 100%;
107 | height: 100%;
108 | }
109 |
110 | .publication-body img {
111 | }
112 |
113 | .results-carousel {
114 | overflow: hidden;
115 | }
116 |
117 | .results-carousel .item {
118 | margin: 5px;
119 | overflow: hidden;
120 | padding: 20px;
121 | font-size: 0;
122 | }
123 |
124 | .results-carousel video {
125 | margin: 0;
126 | }
127 |
128 | .slider-pagination .slider-page {
129 | background: #000000;
130 | }
131 |
132 | .eql-cntrb {
133 | font-size: smaller;
134 | }
135 |
136 |
137 |
138 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | AP-BWE
5 |
6 |
8 |
9 |
10 |
11 |
12 |
13 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
Towards Efficient and High-Quality Bandwidth Extension with Parallel Amplitude-Phase Prediction
33 |
44 |
45 |
46 | National Engineering Research Center of Speech and Language Information Processing University of Science and Technology of China
47 |
48 |
49 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
Abstract
86 |
87 |
88 | Speech bandwidth extension (BWE) refers to increasing the bandwidth range of speech signals, enhancing their quality so that they sound brighter and fuller.
89 | This paper proposes a generative adversarial network (GAN) based BWE model with parallel prediction of Amplitude and Phase spectra, named AP-BWE, which achieves both efficient and high-quality wideband waveform generation.
90 | Notably, to our knowledge, AP-BWE is the first to achieve direct extension of the high-frequency phase spectrum, which is beneficial for improving the effectiveness of existing BWE methods.
91 | The proposed AP-BWE generator is entirely based on convolutional neural networks (CNNs). It features a dual-stream architecture with mutual interaction, in which the amplitude stream and the phase stream communicate with each other and respectively extend the high-frequency components of the narrowband amplitude and phase spectra.
92 | To improve the naturalness of the extended speech signals, we employ a multi-period discriminator at the waveform level and design a pair of multi-resolution amplitude and phase discriminators at the spectral level.
93 | Experimental results demonstrate that our proposed AP-BWE achieves state-of-the-art speech quality on BWE tasks targeting sampling rates of both 16 kHz and 48 kHz.
94 | In terms of generation efficiency, owing to its all-convolutional architecture and all-frame-level operations, the proposed AP-BWE can generate 48 kHz waveform samples 292.3 times faster than real-time on a single RTX 4090 GPU and 18.1 times faster than real-time on CPU.
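
To make the dual-stream description concrete, here is a minimal, hypothetical PyTorch sketch (invented for illustration only; the repo's actual generator is `APNet_BWE_Model` in `models/model.py`): two frame-level convolutional streams take the narrowband amplitude and phase spectra, exchange features, and predict their wideband counterparts.

```python
# Toy stand-in (NOT the repo's code) for AP-BWE's dual-stream generator:
# an amplitude stream and a phase stream that exchange features before
# projecting back to full-resolution spectra.
import torch
import torch.nn as nn

class TwoStreamToy(nn.Module):
    def __init__(self, freq_bins=513, channels=64):
        super().__init__()
        self.amp_in = nn.Conv1d(freq_bins, channels, kernel_size=7, padding=3)
        self.pha_in = nn.Conv1d(freq_bins, channels, kernel_size=7, padding=3)
        self.amp_out = nn.Conv1d(channels, freq_bins, kernel_size=7, padding=3)
        self.pha_out = nn.Conv1d(channels, freq_bins, kernel_size=7, padding=3)

    def forward(self, amp_nb, pha_nb):           # both (B, freq_bins, frames)
        a, p = self.amp_in(amp_nb), self.pha_in(pha_nb)
        a, p = a + p, p + a                      # crude stand-in for stream interaction
        return self.amp_out(a), self.pha_out(p)  # wideband amplitude / phase estimates

amp_wb, pha_wb = TwoStreamToy()(torch.randn(1, 513, 50), torch.randn(1, 513, 50))
```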
IV. Cross-Dataset Evaluation
Libri-TTS (8 kHz to 24 kHz)
381 | Wideband
382 | NU-Wave2
383 | UDM+
384 | mdctGAN
385 | AP-BWE (Ours)
HiFi-TTS (8 kHz to 44.1 kHz)
412 |
413 |
414 |
415 |
416 |
417 | Wideband
418 | NU-Wave2
419 | UDM+
420 | mdctGAN
421 | AP-BWE (Ours)
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
--------------------------------------------------------------------------------
/docs/js/bulma-slider.min.js:
--------------------------------------------------------------------------------
1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default});
--------------------------------------------------------------------------------
/docs/js/index.js:
--------------------------------------------------------------------------------
1 | window.HELP_IMPROVE_VIDEOJS = false;
2 |
3 |
4 | $(document).ready(function() {
5 |     // Set up the sample carousels and sliders once the DOM is ready
6 |
7 | var options = {
8 | slidesToScroll: 1,
9 | slidesToShow: 1,
10 | loop: true,
11 | infinite: true,
12 | autoplay: false,
13 | autoplaySpeed: 3000,
14 | }
15 |
16 | // Initialize all div with carousel class
17 | var carousels = bulmaCarousel.attach('.carousel', options);
18 |
19 |     // Loop over each initialized carousel
20 | for(var i = 0; i < carousels.length; i++) {
21 | // Add listener to event
22 | carousels[i].on('before:show', state => {
23 | console.log(state);
24 | });
25 | }
26 |
27 | // Access to bulmaCarousel instance of an element
28 | var element = document.querySelector('#my-element');
29 | if (element && element.bulmaCarousel) {
30 | // bulmaCarousel instance is available as element.bulmaCarousel
31 | element.bulmaCarousel.on('before-show', function(state) {
32 | console.log(state);
33 | });
34 | }
35 |
36 | /*var player = document.getElementById('interpolation-video');
37 | player.addEventListener('loadedmetadata', function() {
38 | $('#interpolation-slider').on('input', function(event) {
39 | console.log(this.value, player.duration);
40 | player.currentTime = player.duration / 100 * this.value;
41 | })
42 | }, false);*/
43 |
44 | bulmaSlider.attach();
45 |
46 | })
47 |
--------------------------------------------------------------------------------
/docs/js/main.js:
--------------------------------------------------------------------------------
1 | (()=>{"use strict";var e={r:e=>{"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})}},t={};e.r(t);const c=Object.freeze({latest:"3.12.1_3.2.3-63070bed",public:"3.12.1_3.2.2-0a1b32f6",legacy:"3.12.1_3.2.2-0a1b32f6",edit:"2.24.4_2.12.0-7d0a8c15"});!function loadSDKScript(){var e=arguments.length>0&&void 0!==arguments[0]?arguments[0]:"latest.js",t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:c.latest,n=arguments.length>2&&void 0!==arguments[2]?arguments[2]:"View",r=Array.from(document.scripts),a="view-sdk/".concat(e),i=r.find((function(e){return e.src&&e.src.endsWith(a)})).src,o=i.substring(0,i.indexOf(e)),d="".concat("".concat(o+t,"/").concat(n),"SDKInterface.js"),s=document.createElement("script");s.async=!1,s.setAttribute("src",d),document.head.appendChild(s)}("main.js",c.legacy),window.adobe_dc_view_sdk=t})();
2 | //# sourceMappingURL=3.12.1_3.2.3-63070bed/private/main.js.map
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_mpd_only.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mpd_only.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_mpd_only.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mpd_only.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_mrad_only.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mrad_only.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_mrad_only.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mrad_only.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_mrpd_only.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mrpd_only.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_mrpd_only.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_mrpd_only.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_AtoP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_AtoP.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_AtoP.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_AtoP.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_PtoA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_PtoA.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_PtoA.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_PtoA.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_connect.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_connect.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_connect.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_connect.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_disc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_disc.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_disc.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_disc.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_mpd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mpd.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_mpd.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mpd.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_mrad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mrad.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_mrad.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mrad.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_mrpd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mrpd.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/ap-bwe_wo_mrpd.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/ap-bwe_wo_mrpd.wav
--------------------------------------------------------------------------------
/docs/samples/Ablation/wideband.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/wideband.png
--------------------------------------------------------------------------------
/docs/samples/Ablation/wideband.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/Ablation/wideband.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/afilm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/afilm.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/afilm.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/afilm.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/ap-bwe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/ap-bwe.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/ap-bwe.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/ap-bwe.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/nvsr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/nvsr.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/nvsr.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/nvsr.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/sinc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/sinc.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/sinc.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/sinc.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/tfilm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/tfilm.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/tfilm.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/tfilm.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/wideband.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/wideband.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/2kto16k/wideband.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/2kto16k/wideband.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/afilm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/afilm.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/afilm.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/afilm.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/ap-bwe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/ap-bwe.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/ap-bwe.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/ap-bwe.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/nvsr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/nvsr.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/nvsr.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/nvsr.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/sinc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/sinc.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/sinc.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/sinc.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/tfilm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/tfilm.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/tfilm.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/tfilm.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/wideband.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/wideband.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/4kto16k/wideband.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/4kto16k/wideband.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/afilm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/afilm.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/afilm.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/afilm.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/ap-bwe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/ap-bwe.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/ap-bwe.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/ap-bwe.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/nvsr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/nvsr.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/nvsr.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/nvsr.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/sinc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/sinc.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/sinc.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/sinc.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/tfilm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/tfilm.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/tfilm.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/tfilm.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/wideband.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/wideband.png
--------------------------------------------------------------------------------
/docs/samples/BWE_16k/8kto16k/wideband.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_16k/8kto16k/wideband.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/ap-bwe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/ap-bwe.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/ap-bwe.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/ap-bwe.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/mdctgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/mdctgan.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/mdctgan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/mdctgan.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/nuwave2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/nuwave2.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/nuwave2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/nuwave2.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/sinc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/sinc.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/sinc.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/sinc.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/udm+.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/udm+.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/udm+.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/udm+.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/wideband.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/wideband.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/12kto48k/wideband.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/12kto48k/wideband.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/ap-bwe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/ap-bwe.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/ap-bwe.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/ap-bwe.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/mdctgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/mdctgan.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/mdctgan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/mdctgan.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/nuwave2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/nuwave2.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/nuwave2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/nuwave2.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/sinc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/sinc.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/sinc.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/sinc.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/udm+.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/udm+.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/udm+.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/udm+.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/wideband.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/wideband.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/16kto48k/wideband.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/16kto48k/wideband.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/ap-bwe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/ap-bwe.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/ap-bwe.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/ap-bwe.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/mdctgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/mdctgan.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/mdctgan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/mdctgan.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/nuwave2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/nuwave2.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/nuwave2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/nuwave2.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/sinc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/sinc.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/sinc.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/sinc.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/udm+.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/udm+.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/udm+.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/udm+.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/wideband.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/wideband.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/24kto48k/wideband.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/24kto48k/wideband.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/ap-bwe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/ap-bwe.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/ap-bwe.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/ap-bwe.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/mdctgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/mdctgan.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/mdctgan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/mdctgan.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/nuwave2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/nuwave2.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/nuwave2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/nuwave2.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/sinc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/sinc.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/sinc.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/sinc.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/udm+.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/udm+.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/udm+.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/udm+.wav
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/wideband.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/wideband.png
--------------------------------------------------------------------------------
/docs/samples/BWE_48k/8kto48k/wideband.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/BWE_48k/8kto48k/wideband.wav
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/HiFi-TTS/ap-bwe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/ap-bwe.png
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/HiFi-TTS/ap-bwe.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/ap-bwe.wav
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/HiFi-TTS/mdctgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/mdctgan.png
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/HiFi-TTS/mdctgan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/mdctgan.wav
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/HiFi-TTS/nuwave2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/nuwave2.png
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/HiFi-TTS/nuwave2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/nuwave2.wav
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/HiFi-TTS/udm+.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/udm+.png
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/HiFi-TTS/udm+.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/udm+.wav
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/HiFi-TTS/wideband.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/wideband.png
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/HiFi-TTS/wideband.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/HiFi-TTS/wideband.wav
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/Libri-TTS/ap-bwe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/ap-bwe.png
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/Libri-TTS/ap-bwe.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/ap-bwe.wav
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/Libri-TTS/mdctgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/mdctgan.png
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/Libri-TTS/mdctgan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/mdctgan.wav
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/Libri-TTS/nuwave2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/nuwave2.png
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/Libri-TTS/nuwave2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/nuwave2.wav
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/Libri-TTS/udm+.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/udm+.png
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/Libri-TTS/udm+.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/udm+.wav
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/Libri-TTS/wideband.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/wideband.png
--------------------------------------------------------------------------------
/docs/samples/CrossDataset/Libri-TTS/wideband.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxlu-0102/AP-BWE/751710f22404c27e5bcc983248f8b856a04b8422/docs/samples/CrossDataset/Libri-TTS/wideband.wav
--------------------------------------------------------------------------------
/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):  # a dict whose keys double as attributes
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 |     t_path = os.path.join(path, config_name)  # destination copy of the config
13 |     if config != t_path:  # skip if the config already lives in the target dir
14 |         os.makedirs(path, exist_ok=True)
15 |         shutil.copyfile(config, os.path.join(path, config_name))
16 |
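
A hypothetical usage sketch (values invented for illustration): `AttrDict` is what lets the training and inference scripts read JSON config entries as attributes such as `h.n_fft` and `h.hop_size`.

```python
# Hypothetical usage (not part of the repo): AttrDict exposes dict keys
# as attributes, so JSON configs read naturally as h.<key>.
from env import AttrDict

h = AttrDict({"n_fft": 1024, "hop_size": 80, "win_size": 320})
assert h.n_fft == h["n_fft"] == 1024   # attribute and key access are equivalent
h.seed = 1234                          # attribute writes update the dict, too
assert h["seed"] == 1234
```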
--------------------------------------------------------------------------------
/inference/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/inference/inference_16k.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 | import sys
3 | sys.path.append("..")
4 | import glob
5 | import os
6 | import argparse
7 | import json
9 | import torch
10 | import time
11 | import numpy as np
12 | import torchaudio
13 | import torchaudio.functional as aF
14 | from env import AttrDict
15 | from datasets.dataset import amp_pha_stft, amp_pha_istft
16 | from models.model import APNet_BWE_Model
17 | import soundfile as sf
18 | import matplotlib.pyplot as plt
19 | from rich.progress import track
20 |
21 | h = None
22 | device = None
23 |
24 | def load_checkpoint(filepath, device):
25 | assert os.path.isfile(filepath)
26 | print("Loading '{}'".format(filepath))
27 | checkpoint_dict = torch.load(filepath, map_location=device)
28 | print("Complete.")
29 | return checkpoint_dict
30 |
31 | def scan_checkpoint(cp_dir, prefix):
32 | pattern = os.path.join(cp_dir, prefix + '*')
33 | cp_list = glob.glob(pattern)
34 | if len(cp_list) == 0:
35 | return ''
36 | return sorted(cp_list)[-1]
37 |
38 | def inference(a):
39 | model = APNet_BWE_Model(h).to(device)
40 |
41 | state_dict = load_checkpoint(a.checkpoint_file, device)
42 | model.load_state_dict(state_dict['generator'])
43 |
44 | test_indexes = os.listdir(a.input_wavs_dir)
45 |
46 | os.makedirs(a.output_dir, exist_ok=True)
47 |
48 | model.eval()
49 | duration_tot = 0
50 | with torch.no_grad():
51 | for i, index in enumerate(track(test_indexes)):
52 | # print(index)
53 | audio, orig_sampling_rate = torchaudio.load(os.path.join(a.input_wavs_dir, index))
54 | audio = audio.to(device)
55 |
56 |             audio_hr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=h.hr_sampling_rate)  # high-rate reference
57 |             audio_lr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=h.lr_sampling_rate)  # simulate the narrowband input
58 |             audio_lr = aF.resample(audio_lr, orig_freq=h.lr_sampling_rate, new_freq=h.hr_sampling_rate)  # upsample back to the target rate
59 |             audio_lr = audio_lr[:, : audio_hr.size(1)]  # trim to the reference length
60 |
61 |             amp_wb, pha_wb, com_wb = amp_pha_stft(audio_hr, h.n_fft, h.hop_size, h.win_size)  # wideband reference spectra (unused below)
62 |
63 | pred_start = time.time()
64 | amp_nb, pha_nb, com_nb = amp_pha_stft(audio_lr, h.n_fft, h.hop_size, h.win_size)
65 |
66 | amp_wb_g, pha_wb_g, com_wb_g = model(amp_nb, pha_nb)
67 |
68 | audio_hr_g = amp_pha_istft(amp_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size)
69 | duration_tot += time.time() - pred_start
70 |
71 | output_file = os.path.join(a.output_dir, index)
72 |
73 | sf.write(output_file, audio_hr_g.squeeze().cpu().numpy(), h.hr_sampling_rate, 'PCM_16')
74 |
75 |     print(duration_tot)  # total model inference time in seconds
76 |
77 |
78 | def main():
79 |     print('Initializing Inference Process...')
80 |
81 | parser = argparse.ArgumentParser()
82 | parser.add_argument('--input_wavs_dir', default='VCTK-Corpus-0.92/wav16/test')
83 | parser.add_argument('--output_dir', default='../generated_files')
84 | parser.add_argument('--checkpoint_file', required=True)
85 | a = parser.parse_args()
86 |
87 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
88 | with open(config_file) as f:
89 | data = f.read()
90 |
91 | global h
92 | json_config = json.loads(data)
93 | h = AttrDict(json_config)
94 |
95 | torch.manual_seed(h.seed)
96 | global device
97 | if torch.cuda.is_available():
98 | torch.cuda.manual_seed(h.seed)
99 | device = torch.device('cuda')
100 | else:
101 | device = torch.device('cpu')
102 |
103 | inference(a)
104 |
105 |
106 | if __name__ == '__main__':
107 | main()
108 |
109 |
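
Since `amp_pha_stft` / `amp_pha_istft` are imported from `datasets/dataset.py` (whose body is outside this excerpt), the following is a minimal sketch, under stated assumptions, of the analysis/synthesis pair the script relies on, inferred purely from how they are called above; the repo's actual epsilon and compression choices may differ.

```python
# Minimal sketch (an assumption inferred from usage; the real implementation
# lives in datasets/dataset.py) of the amplitude/phase STFT pair.
import torch

def amp_pha_stft(audio, n_fft, hop_size, win_size):
    window = torch.hann_window(win_size, device=audio.device)
    spec = torch.stft(audio, n_fft, hop_length=hop_size, win_length=win_size,
                      window=window, center=True, return_complex=True)
    amp = torch.log(spec.abs() + 1e-5)   # log-amplitude spectrum (epsilon assumed)
    pha = spec.angle()                   # wrapped phase spectrum
    com = torch.stack((spec.real, spec.imag), dim=-1)   # real/imag pair
    return amp, pha, com

def amp_pha_istft(amp, pha, n_fft, hop_size, win_size):
    # Recombine amplitude and phase into a complex spectrum, then invert.
    spec = torch.exp(amp) * torch.exp(1j * pha)
    window = torch.hann_window(win_size, device=amp.device)
    return torch.istft(spec, n_fft, hop_length=hop_size, win_length=win_size,
                       window=window, center=True)
```

The key point for AP-BWE is that both directions are cheap frame-level operations, which is what enables the faster-than-real-time figures quoted on the demo page.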
--------------------------------------------------------------------------------
/inference/inference_48k.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 | import sys
3 | sys.path.append("..")
4 | import glob
5 | import os
6 | import argparse
7 | import json
9 | import torch
10 | import numpy as np
11 | import torchaudio
12 | import time
13 | import torchaudio.functional as aF
14 | from env import AttrDict
15 | from datasets.dataset import amp_pha_stft, amp_pha_istft
16 | from models.model import APNet_BWE_Model
17 | import soundfile as sf
18 | import matplotlib.pyplot as plt
19 | from rich.progress import track
20 |
21 | h = None
22 | device = None
23 |
24 | def load_checkpoint(filepath, device):
25 | assert os.path.isfile(filepath)
26 | print("Loading '{}'".format(filepath))
27 | checkpoint_dict = torch.load(filepath, map_location=device)
28 | print("Complete.")
29 | return checkpoint_dict
30 |
31 | def scan_checkpoint(cp_dir, prefix):
32 | pattern = os.path.join(cp_dir, prefix + '*')
33 | cp_list = glob.glob(pattern)
34 | if len(cp_list) == 0:
35 | return ''
36 | return sorted(cp_list)[-1]
37 |
38 | def inference(a):
39 | model = APNet_BWE_Model(h).to(device)
40 |
41 | state_dict = load_checkpoint(a.checkpoint_file, device)
42 | model.load_state_dict(state_dict['generator'])
43 |
44 | test_indexes = os.listdir(a.input_wavs_dir)
45 |
46 | os.makedirs(a.output_dir, exist_ok=True)
47 |
48 | model.eval()
49 | duration_tot = 0
50 | with torch.no_grad():
51 | for i, index in enumerate(track(test_indexes)):
52 | # print(index)
53 | audio, orig_sampling_rate = torchaudio.load(os.path.join(a.input_wavs_dir, index))
54 | audio = audio.to(device)
55 |
56 |             audio_hr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=h.hr_sampling_rate)  # high-rate reference
57 |             audio_lr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=h.lr_sampling_rate)  # simulate the narrowband input
58 |             audio_lr = aF.resample(audio_lr, orig_freq=h.lr_sampling_rate, new_freq=h.hr_sampling_rate)  # upsample back to the target rate
59 |             audio_lr = audio_lr[:, : audio_hr.size(1)]  # trim to the reference length
60 |
61 | pred_start = time.time()
62 | amp_nb, pha_nb, com_nb = amp_pha_stft(audio_lr, h.n_fft, h.hop_size, h.win_size)
63 |
64 | amp_wb_g, pha_wb_g, com_wb_g = model(amp_nb, pha_nb)
65 |
66 | audio_hr_g = amp_pha_istft(amp_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size)
67 | duration_tot += time.time() - pred_start
68 |
69 | output_file = os.path.join(a.output_dir, index)
70 |
71 | sf.write(output_file, audio_hr_g.squeeze().cpu().numpy(), h.hr_sampling_rate, 'PCM_16')
72 |
73 |     print(duration_tot)  # total model inference time in seconds
74 |
75 | def main():
76 |     print('Initializing Inference Process...')
77 |
78 | parser = argparse.ArgumentParser()
79 | parser.add_argument('--input_wavs_dir', default='VCTK-Corpus-0.92/wav48/test')
80 | parser.add_argument('--output_dir', default='../generated_files')
81 | parser.add_argument('--checkpoint_file', required=True)
82 | a = parser.parse_args()
83 |
84 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
85 | with open(config_file) as f:
86 | data = f.read()
87 |
88 | global h
89 | json_config = json.loads(data)
90 | h = AttrDict(json_config)
91 |
92 | torch.manual_seed(h.seed)
93 | global device
94 | if torch.cuda.is_available():
95 | torch.cuda.manual_seed(h.seed)
96 | device = torch.device('cuda')
97 | else:
98 | device = torch.device('cpu')
99 |
100 | inference(a)
101 |
102 |
103 | if __name__ == '__main__':
104 | main()
105 |
106 |
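
One practical note, grounded in `main()` above: `config.json` is loaded from the directory containing `--checkpoint_file`, so the two must live side by side. A hypothetical layout and invocation (all paths and file names invented for illustration):

```python
# Hypothetical invocation (paths invented). main() requires that config.json
# sit in the same directory as the checkpoint:
#
#   my_ckpts/
#   ├── config.json    # read automatically by the script
#   └── g_00100000     # pass this file via --checkpoint_file
import subprocess

subprocess.run(
    ["python", "inference_48k.py",
     "--checkpoint_file", "my_ckpts/g_00100000",
     "--output_dir", "generated_files"],
    check=True,
)
```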
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
5 | from utils import init_weights, get_padding
6 | import numpy as np
7 | from typing import Tuple, List, Optional
8 |
9 | LRELU_SLOPE = 0.1
10 |
11 | class ConvNeXtBlock(nn.Module):
12 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
13 |
14 | Args:
15 | dim (int): Number of input channels.
16 | intermediate_dim (int): Dimensionality of the intermediate layer.
17 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
18 | Defaults to None.
19 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
20 | None means non-conditional LayerNorm. Defaults to None.
21 | """
22 |
23 | def __init__(
24 | self,
25 | dim: int,
26 | layer_scale_init_value=None,
27 | adanorm_num_embeddings=None,
28 | ):
29 | super().__init__()
30 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
31 | self.adanorm = adanorm_num_embeddings is not None
32 |
33 | self.norm = nn.LayerNorm(dim, eps=1e-6)  # plain LayerNorm; a conditional AdaLayerNorm is not implemented here, so adanorm_num_embeddings must stay None
34 | self.pwconv1 = nn.Linear(dim, dim*3) # pointwise/1x1 convs, implemented with linear layers
35 | self.act = nn.GELU()
36 | self.pwconv2 = nn.Linear(dim*3, dim)
37 | self.gamma = (
38 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
39 | if layer_scale_init_value is not None and layer_scale_init_value > 0  # guard: comparing None > 0 raises TypeError
40 | else None
41 | )
42 |
43 | def forward(self, x, cond_embedding_id=None):
44 | residual = x
45 | x = self.dwconv(x)
46 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
47 | if self.adanorm:
48 | assert cond_embedding_id is not None
49 | x = self.norm(x, cond_embedding_id)
50 | else:
51 | x = self.norm(x)
52 | x = self.pwconv1(x)
53 | x = self.act(x)
54 | x = self.pwconv2(x)
55 | if self.gamma is not None:
56 | x = self.gamma * x
57 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
58 |
59 | x = residual + x
60 | return x
61 |
62 |
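# Shape sketch for ConvNeXtBlock (illustrative values): the block is a
# shape-preserving residual unit over frame-level features.
#
#   blk = ConvNeXtBlock(dim=512, layer_scale_init_value=1.0 / 8)
#   x = torch.randn(2, 512, 100)           # (B, C, T)
#   assert blk(x).shape == (2, 512, 100)   # depthwise 7-tap conv + 3x pointwise MLP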
63 | class APNet_BWE_Model(torch.nn.Module):
64 | def __init__(self, h):
65 | super(APNet_BWE_Model, self).__init__()
66 | self.h = h
67 | self.adanorm_num_embeddings = None
68 | layer_scale_init_value = 1 / h.ConvNeXt_layers
69 |
70 | self.conv_pre_mag = nn.Conv1d(h.n_fft//2+1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
71 | self.norm_pre_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
72 | self.conv_pre_pha = nn.Conv1d(h.n_fft//2+1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
73 | self.norm_pre_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
74 |
75 | self.convnext_mag = nn.ModuleList(
76 | [
77 | ConvNeXtBlock(
78 | dim=h.ConvNeXt_channels,
79 | layer_scale_init_value=layer_scale_init_value,
80 | adanorm_num_embeddings=self.adanorm_num_embeddings,
81 | )
82 | for _ in range(h.ConvNeXt_layers)
83 | ]
84 | )
85 |
86 | self.convnext_pha = nn.ModuleList(
87 | [
88 | ConvNeXtBlock(
89 | dim=h.ConvNeXt_channels,
90 | layer_scale_init_value=layer_scale_init_value,
91 | adanorm_num_embeddings=self.adanorm_num_embeddings,
92 | )
93 | for _ in range(h.ConvNeXt_layers)
94 | ]
95 | )
96 |
97 | self.norm_post_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
98 | self.norm_post_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
99 | self.linear_post_mag = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
100 | self.linear_post_pha_r = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
101 | self.linear_post_pha_i = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
102 | self.apply(self._init_weights)  # apply last, after all submodules exist, so the post-projection linears are initialized too
103 |
104 | def _init_weights(self, m):
105 | if isinstance(m, (nn.Conv1d, nn.Linear)):
106 | nn.init.trunc_normal_(m.weight, std=0.02)
107 | nn.init.constant_(m.bias, 0)
108 |
109 | def forward(self, mag_nb, pha_nb):
110 |
111 | x_mag = self.conv_pre_mag(mag_nb)
112 | x_pha = self.conv_pre_pha(pha_nb)
113 | x_mag = self.norm_pre_mag(x_mag.transpose(1, 2)).transpose(1, 2)
114 | x_pha = self.norm_pre_pha(x_pha.transpose(1, 2)).transpose(1, 2)
115 |
116 | for conv_block_mag, conv_block_pha in zip(self.convnext_mag, self.convnext_pha):
117 | x_mag = x_mag + x_pha
118 | x_pha = x_pha + x_mag
119 | x_mag = conv_block_mag(x_mag, cond_embedding_id=None)
120 | x_pha = conv_block_pha(x_pha, cond_embedding_id=None)
121 |
122 | x_mag = self.norm_post_mag(x_mag.transpose(1, 2))
123 | mag_wb = mag_nb + self.linear_post_mag(x_mag).transpose(1, 2)
124 |
125 | x_pha = self.norm_post_pha(x_pha.transpose(1, 2))
126 | x_pha_r = self.linear_post_pha_r(x_pha)
127 | x_pha_i = self.linear_post_pha_i(x_pha)
128 | pha_wb = torch.atan2(x_pha_i, x_pha_r).transpose(1, 2)
129 |
130 | com_wb = torch.stack((torch.exp(mag_wb)*torch.cos(pha_wb),
131 | torch.exp(mag_wb)*torch.sin(pha_wb)), dim=-1)
132 |
133 | return mag_wb, pha_wb, com_wb
134 |
135 |
136 |
137 | class DiscriminatorP(torch.nn.Module):
138 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
139 | super(DiscriminatorP, self).__init__()
140 | self.period = period
141 | norm_f = spectral_norm if use_spectral_norm else weight_norm
142 | self.convs = nn.ModuleList([
143 | norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
144 | norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
145 | norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
146 | norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
147 | norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
148 | ])
149 | self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
150 |
151 | def forward(self, x):
152 | fmap = []
153 |
154 | # 1d to 2d
155 | b, c, t = x.shape
156 | if t % self.period != 0: # pad first
157 | n_pad = self.period - (t % self.period)
158 | x = F.pad(x, (0, n_pad), "reflect")
159 | t = t + n_pad
160 | x = x.view(b, c, t // self.period, self.period)
161 |
162 | for i, l in enumerate(self.convs):
163 | x = l(x)
164 | x = F.leaky_relu(x, LRELU_SLOPE)
165 | if i > 0:
166 | fmap.append(x)
167 | x = self.conv_post(x)
168 | fmap.append(x)
169 | x = torch.flatten(x, 1, -1)
170 |
171 | return x, fmap
172 |
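# DiscriminatorP folds the waveform into a 2D grid so that every row holds
# samples exactly `period` steps apart: e.g. with t=100 and period=3 the input
# is reflect-padded to t=102 and viewed as (b, c, 34, 3); the (kernel, 1)
# convolutions then model periodic structure along the time axis only.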
173 |
174 | class MultiPeriodDiscriminator(torch.nn.Module):
175 | def __init__(self):
176 | super(MultiPeriodDiscriminator, self).__init__()
177 | self.discriminators = nn.ModuleList([
178 | DiscriminatorP(2),
179 | DiscriminatorP(3),
180 | DiscriminatorP(5),
181 | DiscriminatorP(7),
182 | DiscriminatorP(11),
183 | ])
184 |
185 | def forward(self, y, y_hat):
186 | y_d_rs = []
187 | y_d_gs = []
188 | fmap_rs = []
189 | fmap_gs = []
190 | for i, d in enumerate(self.discriminators):
191 | y_d_r, fmap_r = d(y)
192 | y_d_g, fmap_g = d(y_hat)
193 | y_d_rs.append(y_d_r)
194 | fmap_rs.append(fmap_r)
195 | y_d_gs.append(y_d_g)
196 | fmap_gs.append(fmap_g)
197 |
198 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
199 |
200 |
201 | class MultiResolutionAmplitudeDiscriminator(nn.Module):
202 | def __init__(
203 | self,
204 | resolutions: Tuple[Tuple[int, int, int], ...] = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)),
205 | num_embeddings: Optional[int] = None,
206 | ):
207 | super().__init__()
208 | self.discriminators = nn.ModuleList(
209 | [DiscriminatorAR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
210 | )
211 |
212 | def forward(
213 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
214 | ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
215 | y_d_rs = []
216 | y_d_gs = []
217 | fmap_rs = []
218 | fmap_gs = []
219 |
220 | for d in self.discriminators:
221 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
222 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
223 | y_d_rs.append(y_d_r)
224 | fmap_rs.append(fmap_r)
225 | y_d_gs.append(y_d_g)
226 | fmap_gs.append(fmap_g)
227 |
228 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
229 |
230 |
231 | class DiscriminatorAR(nn.Module):
232 | def __init__(
233 | self,
234 | resolution: Tuple[int, int, int],
235 | channels: int = 64,
236 | in_channels: int = 1,
237 | num_embeddings: Optional[int] = None,
238 | ):
239 | super().__init__()
240 | self.resolution = resolution
241 | self.in_channels = in_channels
242 | self.convs = nn.ModuleList(
243 | [
244 | weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
245 | weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
246 | weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
247 | weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
248 | weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
249 | ]
250 | )
251 | if num_embeddings is not None:
252 | self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
253 | torch.nn.init.zeros_(self.emb.weight)
254 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
255 |
256 | def forward(
257 | self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
258 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
259 | fmap = []
260 | x = x.squeeze(1)  # (B, 1, T) -> (B, T) waveform
261 |
262 | x = self.spectrogram(x)
263 | x = x.unsqueeze(1)
264 | for l in self.convs:
265 | x = l(x)
266 | x = F.leaky_relu(x, LRELU_SLOPE)
267 | fmap.append(x)
268 | if cond_embedding_id is not None:
269 | emb = self.emb(cond_embedding_id)
270 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdim=True)
271 | else:
272 | h = 0
273 | x = self.conv_post(x)
274 | fmap.append(x)
275 | x += h
276 | x = torch.flatten(x, 1, -1)
277 |
278 | return x, fmap
279 |
280 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
281 | n_fft, hop_length, win_length = self.resolution
282 | amplitude_spectrogram = torch.stft(
283 | x,
284 | n_fft=n_fft,
285 | hop_length=hop_length,
286 | win_length=win_length,
287 | window=None, # interestingly rectangular window kind of works here
288 | center=True,
289 | return_complex=True,
290 | ).abs()
291 |
292 | return amplitude_spectrogram
293 |
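# Quick shape sketch for DiscriminatorAR (illustrative): for a 1 s, 16 kHz
# waveform and resolution (512, 128, 512), the STFT magnitude is (B, 257, 126)
# with center=True, and forward() returns flattened logits plus six feature
# maps (five conv stages + conv_post).
#
#   d = DiscriminatorAR(resolution=(512, 128, 512))
#   logits, fmaps = d(torch.randn(2, 1, 16000))
#   assert len(fmaps) == 6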
294 |
295 | class MultiResolutionPhaseDiscriminator(nn.Module):
296 | def __init__(
297 | self,
298 | resolutions: Tuple[Tuple[int, int, int], ...] = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)),
299 | num_embeddings: Optional[int] = None,
300 | ):
301 | super().__init__()
302 | self.discriminators = nn.ModuleList(
303 | [DiscriminatorPR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
304 | )
305 |
306 | def forward(
307 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
308 | ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
309 | y_d_rs = []
310 | y_d_gs = []
311 | fmap_rs = []
312 | fmap_gs = []
313 |
314 | for d in self.discriminators:
315 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
316 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
317 | y_d_rs.append(y_d_r)
318 | fmap_rs.append(fmap_r)
319 | y_d_gs.append(y_d_g)
320 | fmap_gs.append(fmap_g)
321 |
322 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
323 |
324 |
325 | class DiscriminatorPR(nn.Module):
326 | def __init__(
327 | self,
328 | resolution: Tuple[int, int, int],
329 | channels: int = 64,
330 | in_channels: int = 1,
331 | num_embeddings: Optional[int] = None,
332 | ):
333 | super().__init__()
334 | self.resolution = resolution
335 | self.in_channels = in_channels
336 | self.convs = nn.ModuleList(
337 | [
338 | weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
339 | weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
340 | weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
341 | weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
342 | weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
343 | ]
344 | )
345 | if num_embeddings is not None:
346 | self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
347 | torch.nn.init.zeros_(self.emb.weight)
348 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
349 |
350 | def forward(
351 | self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
352 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
353 | fmap = []
354 | x = x.squeeze(1)  # (B, 1, T) -> (B, T) waveform
355 |
356 | x = self.spectrogram(x)
357 | x = x.unsqueeze(1)
358 | for l in self.convs:
359 | x = l(x)
360 | x = F.leaky_relu(x, LRELU_SLOPE)
361 | fmap.append(x)
362 | if cond_embedding_id is not None:
363 | emb = self.emb(cond_embedding_id)
364 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdim=True)
365 | else:
366 | h = 0
367 | x = self.conv_post(x)
368 | fmap.append(x)
369 | x += h
370 | x = torch.flatten(x, 1, -1)
371 |
372 | return x, fmap
373 |
374 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
375 | n_fft, hop_length, win_length = self.resolution
376 | phase_spectrogram = torch.stft(
377 | x,
378 | n_fft=n_fft,
379 | hop_length=hop_length,
380 | win_length=win_length,
381 | window=None, # interestingly rectangular window kind of works here
382 | center=True,
383 | return_complex=True,
384 | ).angle()
385 |
386 | return phase_spectrogram
387 |
388 |
389 | def feature_loss(fmap_r, fmap_g):
390 | loss = 0
391 | for dr, dg in zip(fmap_r, fmap_g):
392 | for rl, gl in zip(dr, dg):
393 | loss += torch.mean(torch.abs(rl - gl))
394 |
395 | return loss
396 |
397 |
398 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
399 | loss = 0
400 | r_losses = []
401 | g_losses = []
402 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
403 | r_loss = torch.mean(torch.clamp(1 - dr, min=0))
404 | g_loss = torch.mean(torch.clamp(1 + dg, min=0))
405 | loss += r_loss + g_loss
406 | r_losses.append(r_loss.item())
407 | g_losses.append(g_loss.item())
408 |
409 | return loss, r_losses, g_losses
410 |
411 |
412 | def generator_loss(disc_outputs):
413 | loss = 0
414 | gen_losses = []
415 | for dg in disc_outputs:
416 | l = torch.mean(torch.clamp(1 - dg, min=0))
417 | gen_losses.append(l)
418 | loss += l
419 |
420 | return loss, gen_losses
421 |
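# discriminator_loss and generator_loss above implement hinge-GAN objectives:
# the discriminator is pushed towards D(real) >= 1 and D(fake) <= -1 via
# clamp(1 - D(real), min=0) + clamp(1 + D(fake), min=0), while the generator
# minimizes clamp(1 - D(fake), min=0), i.e. tries to raise D(fake) above 1.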
422 |
423 | def phase_losses(phase_r, phase_g):
424 |
425 | ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
426 | gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
427 | iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
428 |
429 | return ip_loss, gd_loss, iaf_loss
430 |
431 | def anti_wrapping_function(x):
432 |
433 | return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
434 |
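# phase_losses compares instantaneous phase (ip), group delay (gd, diff along
# the frequency axis, dim=1) and instantaneous angular frequency (iaf, diff
# along the time axis, dim=2) after anti-wrapping. anti_wrapping_function maps
# a raw difference to its principal value in [0, pi]: e.g. x = 2*pi - 0.1
# gives |x - round(x / (2*pi)) * 2*pi| = 0.1, so numerically large but
# perceptually tiny phase errors are not over-penalized.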
435 | def stft_mag(audio, n_fft=2048, hop_length=512):
436 | hann_window = torch.hann_window(n_fft).to(audio.device)
437 | stft_spec = torch.stft(audio, n_fft, hop_length, window=hann_window, return_complex=True)
438 | spec_mag = torch.abs(stft_spec)  # renamed to avoid shadowing the function name
439 | return spec_mag
440 |
441 | def cal_snr(pred, target):  # signal-to-noise ratio in dB
442 | snr = (20 * torch.log10(torch.norm(target, dim=-1) / torch.norm(pred - target, dim=-1).clamp(min=1e-8))).mean()
443 | return snr
444 |
445 | def cal_lsd(pred, target):  # log-spectral distance over STFT magnitudes
446 | sp = torch.log10(stft_mag(pred).square().clamp(1e-8))
447 | st = torch.log10(stft_mag(target).square().clamp(1e-8))
448 | return (sp - st).square().mean(dim=1).sqrt().mean()
--------------------------------------------------------------------------------
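A minimal, self-contained sketch of calling APNet_BWE_Model directly, to make the expected tensor shapes concrete. The hyperparameter values below are illustrative stand-ins; real runs read n_fft, ConvNeXt_channels and ConvNeXt_layers from configs/*.json, and the inputs would come from amp_pha_stft rather than random tensors:

import math
import torch
from types import SimpleNamespace
from models.model import APNet_BWE_Model

# Illustrative hyperparameters; real values come from configs/*.json.
h = SimpleNamespace(n_fft=1024, ConvNeXt_channels=512, ConvNeXt_layers=8)
model = APNet_BWE_Model(h).eval()

B, F, T = 1, h.n_fft // 2 + 1, 120                    # F = n_fft//2 + 1 bins, T frames
mag_nb = torch.randn(B, F, T)                         # narrowband log-amplitude spectrum
pha_nb = torch.rand(B, F, T) * 2 * math.pi - math.pi  # narrowband phase in (-pi, pi]

with torch.no_grad():
    mag_wb, pha_wb, com_wb = model(mag_nb, pha_nb)    # wideband amplitude, phase, complex

print(mag_wb.shape, pha_wb.shape, com_wb.shape)       # (1, 513, 120), (1, 513, 120), (1, 513, 120, 2)
--------------------------------------------------------------------------------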
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.0
2 | torchaudio==2.0.1
3 | matplotlib==3.8.0
4 | numpy==1.26.0
5 | librosa==0.7.2
6 | scipy==1.11.2
7 | tensorboard==2.11.2
8 | einops==0.8.0
9 | joblib==1.3.2
10 | natsort==8.3.0
11 | rich  # used by inference.py (rich.progress.track)
12 | soundfile  # used by inference.py (sf.write)
13 |
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/train/train_16k.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import sys
4 | sys.path.append("..")
5 | import itertools
6 | import os
7 | import time
8 | import argparse
9 | import json
10 | import torch
11 | import torch.nn.functional as F
12 | from torch.utils.tensorboard import SummaryWriter
13 | from torch.nn.utils import clip_grad_norm_
14 | from torch.utils.data import DistributedSampler, DataLoader
15 | import torch.multiprocessing as mp
16 | from torch.distributed import init_process_group
17 | from torch.nn.parallel import DistributedDataParallel
18 | from env import AttrDict, build_env
19 | from datasets.dataset import Dataset, amp_pha_stft, amp_pha_istft, get_dataset_filelist
20 | from models.model import APNet_BWE_Model, MultiPeriodDiscriminator, MultiResolutionAmplitudeDiscriminator, MultiResolutionPhaseDiscriminator, \
21 | feature_loss, generator_loss, discriminator_loss, phase_losses, cal_snr, cal_lsd
22 | from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
23 |
24 | torch.backends.cudnn.benchmark = True
25 |
26 | def train(rank, a, h):
27 | if h.num_gpus > 1:
28 | init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
29 | world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
30 |
31 | torch.cuda.manual_seed(h.seed)
32 | device = torch.device('cuda:{:d}'.format(rank))
33 |
34 | generator = APNet_BWE_Model(h).to(device)
35 | mpd = MultiPeriodDiscriminator().to(device)
36 | mrad = MultiResolutionAmplitudeDiscriminator().to(device)
37 | mrpd = MultiResolutionPhaseDiscriminator().to(device)
38 |
39 | if rank == 0:
40 | print(generator)
41 | num_params = 0
42 | for p in generator.parameters():
43 | num_params += p.numel()
44 | print(num_params)
45 | os.makedirs(a.checkpoint_path, exist_ok=True)
46 | os.makedirs(os.path.join(a.checkpoint_path, 'logs'), exist_ok=True)
47 | print("checkpoints directory : ", a.checkpoint_path)
48 |
49 | cp_g = cp_do = None  # defaults for a fresh run without a checkpoint directory
50 | if os.path.isdir(a.checkpoint_path):
51 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
52 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
53 | steps = 0
54 | if cp_g is None or cp_do is None:
55 | state_dict_do = None
56 | last_epoch = -1
57 | else:
58 | state_dict_g = load_checkpoint(cp_g, device)
59 | state_dict_do = load_checkpoint(cp_do, device)
60 | generator.load_state_dict(state_dict_g['generator'])
61 | mpd.load_state_dict(state_dict_do['mpd'])
62 | mrad.load_state_dict(state_dict_do['mrad'])
63 | mrpd.load_state_dict(state_dict_do['mrpd'])
64 | steps = state_dict_do['steps'] + 1
65 | last_epoch = state_dict_do['epoch']
66 |
67 | if h.num_gpus > 1:
68 | generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
69 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
70 | mrad = DistributedDataParallel(mrad, device_ids=[rank]).to(device)
71 | mrpd = DistributedDataParallel(mrpd, device_ids=[rank]).to(device)
72 |
73 | optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
74 | optim_d = torch.optim.AdamW(itertools.chain(mrad.parameters(), mrpd.parameters(), mpd.parameters()),
75 | h.learning_rate, betas=[h.adam_b1, h.adam_b2])
76 |
77 | if state_dict_do is not None:
78 | optim_g.load_state_dict(state_dict_do['optim_g'])
79 | optim_d.load_state_dict(state_dict_do['optim_d'])
80 |
81 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
82 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
83 |
84 | training_indexes, validation_indexes = get_dataset_filelist(a)
85 |
86 | trainset = Dataset(training_indexes, a.input_training_wavs_dir, h.segment_size, h.hr_sampling_rate, h.lr_sampling_rate,
87 | split=True, n_cache_reuse=0, shuffle=False if h.num_gpus > 1 else True, device=device)
88 |
89 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
90 |
91 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
92 | sampler=train_sampler,
93 | batch_size=h.batch_size,
94 | pin_memory=True,
95 | drop_last=True)
96 | if rank == 0:
97 | validset = Dataset(validation_indexes, a.input_validation_wavs_dir, h.segment_size, h.hr_sampling_rate, h.lr_sampling_rate,
98 | split=False, shuffle=False, n_cache_reuse=0, device=device)
99 |
100 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
101 | sampler=None,
102 | batch_size=1,
103 | pin_memory=True,
104 | drop_last=True)
105 |
106 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
107 |
108 | generator.train()
109 | mpd.train()
110 | mrad.train()
111 | mrpd.train()
112 |
113 | for epoch in range(max(0, last_epoch), a.training_epochs):
114 | if rank == 0:
115 | start = time.time()
116 | print("Epoch: {}".format(epoch+1))
117 |
118 | if h.num_gpus > 1:
119 | train_sampler.set_epoch(epoch)
120 |
121 | for i, batch in enumerate(train_loader):
122 |
123 | if rank == 0:
124 | start_b = time.time()
125 | audio_wb, audio_nb = batch  # paired wideband/narrowband waveforms, [B, segment_size]
126 | audio_wb = audio_wb.to(device, non_blocking=True)  # torch.autograd.Variable is a deprecated no-op wrapper
127 | audio_nb = audio_nb.to(device, non_blocking=True)
128 |
129 | mag_wb, pha_wb, com_wb = amp_pha_stft(audio_wb, h.n_fft, h.hop_size, h.win_size)
130 | mag_nb, pha_nb, com_nb = amp_pha_stft(audio_nb, h.n_fft, h.hop_size, h.win_size)
131 |
132 | mag_wb_g, pha_wb_g, com_wb_g = generator(mag_nb, pha_nb)
133 |
134 | audio_wb_g = amp_pha_istft(mag_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size)
135 | mag_wb_g_hat, pha_wb_g_hat, com_wb_g_hat = amp_pha_stft(audio_wb_g, h.n_fft, h.hop_size, h.win_size)
136 | audio_wb, audio_wb_g = audio_wb.unsqueeze(1), audio_wb_g.unsqueeze(1)
137 |
138 | optim_d.zero_grad()
139 |
140 | # MPD
141 | audio_df_r, audio_df_g, _, _ = mpd(audio_wb, audio_wb_g.detach())
142 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(audio_df_r, audio_df_g)
143 |
144 | # MRAD
145 | spec_da_r, spec_da_g, _, _ = mrad(audio_wb, audio_wb_g.detach())
146 | loss_disc_a, losses_disc_a_r, losses_disc_a_g = discriminator_loss(spec_da_r, spec_da_g)
147 |
148 | # MRPD
149 | spec_dp_r, spec_dp_g, _, _ = mrpd(audio_wb, audio_wb_g.detach())
150 | loss_disc_p, losses_disc_p_r, losses_disc_p_g = discriminator_loss(spec_dp_r, spec_dp_g)
151 |
152 | loss_disc_all = (loss_disc_a + loss_disc_p) * 0.1 + loss_disc_f
153 |
154 | loss_disc_all.backward()
155 | torch.nn.utils.clip_grad_norm_(parameters=mpd.parameters(), max_norm=10, norm_type=2)
156 | torch.nn.utils.clip_grad_norm_(parameters=mrad.parameters(), max_norm=10, norm_type=2)
157 | torch.nn.utils.clip_grad_norm_(parameters=mrpd.parameters(), max_norm=10, norm_type=2)
158 | optim_d.step()
159 |
160 | # Generator
161 | optim_g.zero_grad()
162 |
163 | # L2 Magnitude Loss
164 | loss_mag = F.mse_loss(mag_wb, mag_wb_g) * 45
165 | # Anti-wrapping Phase Loss
166 | loss_ip, loss_gd, loss_iaf = phase_losses(pha_wb, pha_wb_g)
167 | loss_pha = (loss_ip + loss_gd + loss_iaf) * 100
168 | # L2 Complex Loss
169 | loss_com = F.mse_loss(com_wb, com_wb_g) * 90
170 | # L2 Consistency Loss
171 | loss_stft = F.mse_loss(com_wb_g, com_wb_g_hat) * 90
172 |
173 | audio_df_r, audio_df_g, fmap_f_r, fmap_f_g = mpd(audio_wb, audio_wb_g)
174 | spec_da_r, spec_da_g, fmap_a_r, fmap_a_g = mrad(audio_wb, audio_wb_g)
175 | spec_dp_r, spec_dp_g, fmap_p_r, fmap_p_g = mrpd(audio_wb, audio_wb_g)
176 |
177 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
178 | loss_fm_a = feature_loss(fmap_a_r, fmap_a_g)
179 | loss_fm_p = feature_loss(fmap_p_r, fmap_p_g)
180 |
181 | loss_gen_f, losses_gen_f = generator_loss(audio_df_g)
182 | loss_gen_a, losses_gen_a = generator_loss(spec_da_g)
183 | loss_gen_p, losses_gen_p = generator_loss(spec_dp_g)
184 |
185 | loss_gen = (loss_gen_a + loss_gen_p) * 0.1 + loss_gen_f
186 | loss_fm = (loss_fm_a + loss_fm_p) * 0.1 + loss_fm_f
187 |
188 | loss_gen_all = loss_mag + loss_pha + loss_com + loss_stft + loss_gen + loss_fm
189 |
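# Generator objective recap (weights as set above):
#   45 * L2(log-magnitude) + 100 * (IP + GD + IAF anti-wrapped phase losses)
#   + 90 * L2(complex spectrum) + 90 * L2(STFT consistency)
#   + hinge adversarial and feature-matching terms, with the two spectral
#   discriminators (MRAD, MRPD) down-weighted by 0.1 against MPD.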
190 | loss_gen_all.backward()
191 | torch.nn.utils.clip_grad_norm_(parameters=generator.parameters(), max_norm=10, norm_type=2)
192 | optim_g.step()
193 |
194 | if rank == 0:
195 | # STDOUT logging
196 | if steps % a.stdout_interval == 0:
197 | with torch.no_grad():
198 | mag_error = F.mse_loss(mag_wb, mag_wb_g).item()
199 | ip_error, gd_error, iaf_error = phase_losses(pha_wb, pha_wb_g)
200 | pha_error = (ip_error + gd_error + iaf_error).item()
201 | com_error = F.mse_loss(com_wb, com_wb_g).item()
202 | stft_error = F.mse_loss(com_wb_g, com_wb_g_hat).item()
203 | print('Steps : {:d}, Gen Loss: {:4.3f}, Magnitude Loss : {:4.3f}, Phase Loss : {:4.3f}, Complex Loss : {:4.3f}, STFT Loss : {:4.3f}, s/b : {:4.3f}'.
204 | format(steps, loss_gen_all, mag_error, pha_error, com_error, stft_error, time.time() - start_b))
205 |
206 | # checkpointing
207 | if steps % a.checkpoint_interval == 0 and steps != 0:
208 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
209 | save_checkpoint(checkpoint_path,
210 | {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
211 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
212 | save_checkpoint(checkpoint_path,
213 | {'mpd': (mpd.module if h.num_gpus > 1
214 | else mpd).state_dict(),
215 | 'mrad': (mrad.module if h.num_gpus > 1
216 | else mrad).state_dict(),
217 | 'mrpd': (mrpd.module if h.num_gpus > 1
218 | else mrpd).state_dict(),
219 | 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
220 | 'epoch': epoch})
221 |
222 | # Tensorboard summary logging
223 | if steps % a.summary_interval == 0:
224 | sw.add_scalar("Training/Generator Loss", loss_gen_all, steps)
225 | sw.add_scalar("Training/Magnitude Loss", mag_error, steps)
226 | sw.add_scalar("Training/Phase Loss", pha_error, steps)
227 | sw.add_scalar("Training/Complex Loss", com_error, steps)
228 | sw.add_scalar("Training/Consistency Loss", stft_error, steps)
229 |
230 | # Validation
231 | if steps % a.validation_interval == 0:
232 | start_v = time.time()
233 | generator.eval()
234 | torch.cuda.empty_cache()
235 | val_mag_err_tot = 0
236 | val_pha_err_tot = 0
237 | val_com_err_tot = 0
238 | val_stft_err_tot = 0
239 | val_snr_score_tot = 0
240 | val_lsd_score_tot = 0
241 | with torch.no_grad():
242 | for j, batch in enumerate(validation_loader):
243 | audio_wb, audio_nb = batch
244 | audio_wb = audio_wb.to(device, non_blocking=True)
245 | audio_nb = audio_nb.to(device, non_blocking=True)
246 |
247 | mag_wb, pha_wb, com_wb = amp_pha_stft(audio_wb, h.n_fft, h.hop_size, h.win_size)
248 | mag_nb, pha_nb, com_nb = amp_pha_stft(audio_nb, h.n_fft, h.hop_size, h.win_size)
249 |
250 | mag_wb_g, pha_wb_g, com_wb_g = generator(mag_nb.to(device), pha_nb.to(device))
251 |
252 | audio_wb = amp_pha_istft(mag_wb, pha_wb, h.n_fft, h.hop_size, h.win_size)
253 | audio_wb_g = amp_pha_istft(mag_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size)
254 | mag_wb_g_hat, pha_wb_g_hat, com_wb_g_hat = amp_pha_stft(audio_wb_g, h.n_fft, h.hop_size, h.win_size)
255 |
256 | val_mag_err_tot += F.mse_loss(mag_wb, mag_wb_g_hat).item()
257 | val_ip_err, val_gd_err, val_iaf_err = phase_losses(pha_wb, pha_wb_g_hat)
258 | val_pha_err_tot += (val_ip_err + val_gd_err + val_iaf_err).item()
259 | val_com_err_tot += F.mse_loss(com_wb, com_wb_g_hat).item()
260 | val_stft_err_tot += F.mse_loss(com_wb_g, com_wb_g_hat).item()
261 | val_snr_score_tot += cal_snr(audio_wb_g, audio_wb).item()
262 | val_lsd_score_tot += cal_lsd(audio_wb_g, audio_wb).item()
263 |
264 | if j <= 4:
265 | if steps == 0:
266 | sw.add_audio('gt/audio_nb_{}'.format(j), audio_nb[0], steps, h.hr_sampling_rate)
267 | sw.add_audio('gt/audio_wb_{}'.format(j), audio_wb[0], steps, h.hr_sampling_rate)
268 | sw.add_figure('gt/spec_nb_{}'.format(j), plot_spectrogram(mag_nb.squeeze().cpu().numpy()), steps)
269 | sw.add_figure('gt/spec_wb_{}'.format(j), plot_spectrogram(mag_wb.squeeze().cpu().numpy()), steps)
270 |
271 | sw.add_audio('generated/audio_g_{}'.format(j), audio_wb_g[0], steps, h.hr_sampling_rate)
272 | sw.add_figure('generated/spec_g_{}'.format(j), plot_spectrogram(mag_wb_g.squeeze().cpu().numpy()), steps)
273 |
274 | val_mag_err = val_mag_err_tot / (j+1)
275 | val_pha_err = val_pha_err_tot / (j+1)
276 | val_com_err = val_com_err_tot / (j+1)
277 | val_stft_err = val_stft_err_tot / (j+1)
278 | val_snr_score = val_snr_score_tot / (j+1)
279 | val_lsd_score = val_lsd_score_tot / (j+1)
280 |
281 | print('Steps : {:d}, SNR Score: {:4.3f}, LSD Score: {:4.3f}, s/b : {:4.3f}'.
282 | format(steps, val_snr_score, val_lsd_score, time.time() - start_v))
283 | sw.add_scalar("Validation/LSD Score", val_lsd_score, steps)
284 | sw.add_scalar("Validation/SNR Score", val_snr_score, steps)
285 | sw.add_scalar("Validation/Magnitude Loss", val_mag_err, steps)
286 | sw.add_scalar("Validation/Phase Loss", val_pha_err, steps)
287 | sw.add_scalar("Validation/Complex Loss", val_com_err, steps)
288 | sw.add_scalar("Validation/Consistency Loss", val_stft_err, steps)
289 |
290 | generator.train()
291 |
292 | steps += 1
293 |
294 | scheduler_g.step()
295 | scheduler_d.step()
296 |
297 | if rank == 0:
298 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
299 |
300 |
301 | def main():
302 | print('Initializing Training Process..')
303 |
304 | parser = argparse.ArgumentParser()
305 |
306 | parser.add_argument('--group_name', default=None)
307 | parser.add_argument('--input_training_wavs_dir', default='VCTK-Corpus-0.92/wav16/train')
308 | parser.add_argument('--input_validation_wavs_dir', default='VCTK-Corpus-0.92/wav16/test')
309 | parser.add_argument('--input_training_file', default='VCTK-Corpus-0.92/training.txt')
310 | parser.add_argument('--input_validation_file', default='VCTK-Corpus-0.92/test.txt')
311 | parser.add_argument('--checkpoint_path', default='cp_model')
312 | parser.add_argument('--config', default='')
313 | parser.add_argument('--training_epochs', default=3100, type=int)
314 | parser.add_argument('--stdout_interval', default=5, type=int)
315 | parser.add_argument('--checkpoint_interval', default=5000, type=int)
316 | parser.add_argument('--summary_interval', default=100, type=int)
317 | parser.add_argument('--validation_interval', default=5000, type=int)
318 |
319 | a = parser.parse_args()
320 |
321 | with open(a.config) as f:
322 | data = f.read()
323 |
324 | json_config = json.loads(data)
325 | h = AttrDict(json_config)
326 | build_env(a.config, 'config.json', a.checkpoint_path)
327 |
328 | torch.manual_seed(h.seed)
329 | if torch.cuda.is_available():
330 | torch.cuda.manual_seed(h.seed)
331 | h.num_gpus = torch.cuda.device_count()
332 | h.batch_size = int(h.batch_size / h.num_gpus)
333 | print('Batch size per GPU :', h.batch_size)
334 | else:
335 | raise RuntimeError('Training requires CUDA; no GPU was found.')
336 |
337 | if h.num_gpus > 1:
338 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
339 | else:
340 | train(0, a, h)
341 |
342 |
343 | if __name__ == '__main__':
344 | main()
--------------------------------------------------------------------------------
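A typical single-GPU launch of the 16 kHz recipe (a sketch, assuming the repository root as working directory; pick the config matching the source bandwidth): python train/train_16k.py --config configs/config_8kto16k.json. Multi-GPU training is dispatched automatically via mp.spawn when torch.cuda.device_count() > 1.
--------------------------------------------------------------------------------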
/train/train_48k.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import sys
4 | sys.path.append("..")
5 | import itertools
6 | import os
7 | import time
8 | import argparse
9 | import json
10 | import torch
11 | import torch.nn.functional as F
12 | from torch.utils.tensorboard import SummaryWriter
13 | from torch.nn.utils import clip_grad_norm_
14 | from torch.utils.data import DistributedSampler, DataLoader
15 | import torch.multiprocessing as mp
16 | from torch.distributed import init_process_group
17 | from torch.nn.parallel import DistributedDataParallel
18 | from env import AttrDict, build_env
19 | from datasets.dataset import Dataset, amp_pha_stft, amp_pha_istft, get_dataset_filelist
20 | from models.model import APNet_BWE_Model, MultiPeriodDiscriminator, MultiResolutionAmplitudeDiscriminator, MultiResolutionPhaseDiscriminator, \
21 | feature_loss, generator_loss, discriminator_loss, phase_losses, cal_snr, cal_lsd
22 | from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
23 |
24 | torch.backends.cudnn.benchmark = True
25 |
26 | def train(rank, a, h):
27 | if h.num_gpus > 1:
28 | init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
29 | world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
30 |
31 | torch.cuda.manual_seed(h.seed)
32 | device = torch.device('cuda:{:d}'.format(rank))
33 |
34 | generator = APNet_BWE_Model(h).to(device)
35 | mpd = MultiPeriodDiscriminator().to(device)
36 | mrad = MultiResolutionAmplitudeDiscriminator().to(device)
37 | mrpd = MultiResolutionPhaseDiscriminator().to(device)
38 |
39 | if rank == 0:
40 | print(generator)
41 | num_params = 0
42 | for p in generator.parameters():
43 | num_params += p.numel()
44 | print(num_params)
45 | os.makedirs(a.checkpoint_path, exist_ok=True)
46 | os.makedirs(os.path.join(a.checkpoint_path, 'logs'), exist_ok=True)
47 | print("checkpoints directory : ", a.checkpoint_path)
48 |
49 | cp_g = cp_do = None  # defaults for a fresh run without a checkpoint directory
50 | if os.path.isdir(a.checkpoint_path):
51 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
52 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
53 | steps = 0
54 | if cp_g is None or cp_do is None:
55 | state_dict_do = None
56 | last_epoch = -1
57 | else:
58 | state_dict_g = load_checkpoint(cp_g, device)
59 | state_dict_do = load_checkpoint(cp_do, device)
60 | generator.load_state_dict(state_dict_g['generator'])
61 | mpd.load_state_dict(state_dict_do['mpd'])
62 | mrad.load_state_dict(state_dict_do['mrad'])
63 | mrpd.load_state_dict(state_dict_do['mrpd'])
64 | steps = state_dict_do['steps'] + 1
65 | last_epoch = state_dict_do['epoch']
66 |
67 | if h.num_gpus > 1:
68 | generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
69 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
70 | mrad = DistributedDataParallel(mrad, device_ids=[rank]).to(device)
71 | mrpd = DistributedDataParallel(mrpd, device_ids=[rank]).to(device)
72 |
73 | optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
74 | optim_d = torch.optim.AdamW(itertools.chain(mrad.parameters(), mrpd.parameters(), mpd.parameters()),
75 | h.learning_rate, betas=[h.adam_b1, h.adam_b2])
76 |
77 | if state_dict_do is not None:
78 | optim_g.load_state_dict(state_dict_do['optim_g'])
79 | optim_d.load_state_dict(state_dict_do['optim_d'])
80 |
81 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
82 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
83 |
84 | training_indexes, validation_indexes = get_dataset_filelist(a)
85 |
86 | trainset = Dataset(training_indexes, a.input_training_wavs_dir, h.segment_size, h.hr_sampling_rate, h.lr_sampling_rate,
87 | split=True, n_cache_reuse=0, shuffle=False if h.num_gpus > 1 else True, device=device)
88 |
89 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
90 |
91 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
92 | sampler=train_sampler,
93 | batch_size=h.batch_size,
94 | pin_memory=True,
95 | drop_last=True)
96 | if rank == 0:
97 | validset = Dataset(validation_indexes, a.input_validation_wavs_dir, h.segment_size, h.hr_sampling_rate, h.lr_sampling_rate,
98 | split=False, shuffle=False, n_cache_reuse=0, device=device)
99 |
100 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
101 | sampler=None,
102 | batch_size=1,
103 | pin_memory=True,
104 | drop_last=True)
105 |
106 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
107 |
108 | generator.train()
109 | mpd.train()
110 | mrad.train()
111 | mrpd.train()
112 |
113 | for epoch in range(max(0, last_epoch), a.training_epochs):
114 | if rank == 0:
115 | start = time.time()
116 | print("Epoch: {}".format(epoch+1))
117 |
118 | if h.num_gpus > 1:
119 | train_sampler.set_epoch(epoch)
120 |
121 | for i, batch in enumerate(train_loader):
122 |
123 | if rank == 0:
124 | start_b = time.time()
125 | audio_wb, audio_nb = batch  # paired wideband/narrowband waveforms, [B, segment_size]
126 | audio_wb = audio_wb.to(device, non_blocking=True)  # torch.autograd.Variable is a deprecated no-op wrapper
127 | audio_nb = audio_nb.to(device, non_blocking=True)
128 |
129 | mag_wb, pha_wb, com_wb = amp_pha_stft(audio_wb, h.n_fft, h.hop_size, h.win_size)
130 | mag_nb, pha_nb, com_nb = amp_pha_stft(audio_nb, h.n_fft, h.hop_size, h.win_size)
131 |
132 | mag_wb_g, pha_wb_g, com_wb_g = generator(mag_nb, pha_nb)
133 |
134 | audio_wb_g = amp_pha_istft(mag_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size)
135 | mag_wb_g_hat, pha_wb_g_hat, com_wb_g_hat = amp_pha_stft(audio_wb_g, h.n_fft, h.hop_size, h.win_size)
136 | audio_wb, audio_wb_g = audio_wb.unsqueeze(1), audio_wb_g.unsqueeze(1)
137 |
138 | optim_d.zero_grad()
139 |
140 | # MPD
141 | audio_df_r, audio_df_g, _, _ = mpd(audio_wb, audio_wb_g.detach())
142 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(audio_df_r, audio_df_g)
143 |
144 | # MRAD
145 | spec_da_r, spec_da_g, _, _ = mrad(audio_wb, audio_wb_g.detach())
146 | loss_disc_a, losses_disc_a_r, losses_disc_a_g = discriminator_loss(spec_da_r, spec_da_g)
147 |
148 | # MRPD
149 | spec_dp_r, spec_dp_g, _, _ = mrpd(audio_wb, audio_wb_g.detach())
150 | loss_disc_p, losses_disc_p_r, losses_disc_p_g = discriminator_loss(spec_dp_r, spec_dp_g)
151 |
152 | loss_disc_all = (loss_disc_a + loss_disc_p) * 0.1 + loss_disc_f
153 |
154 | loss_disc_all.backward()
155 | torch.nn.utils.clip_grad_norm_(parameters=mpd.parameters(), max_norm=10, norm_type=2)
156 | torch.nn.utils.clip_grad_norm_(parameters=mrad.parameters(), max_norm=10, norm_type=2)
157 | torch.nn.utils.clip_grad_norm_(parameters=mrpd.parameters(), max_norm=10, norm_type=2)
158 | optim_d.step()
159 |
160 | # Generator
161 | optim_g.zero_grad()
162 |
163 | # L2 Magnitude Loss
164 | loss_mag = F.mse_loss(mag_wb, mag_wb_g) * 45
165 | # Anti-wrapping Phase Loss
166 | loss_ip, loss_gd, loss_iaf = phase_losses(pha_wb, pha_wb_g)
167 | loss_pha = (loss_ip + loss_gd + loss_iaf) * 100
168 | # L2 Complex Loss
169 | loss_com = F.mse_loss(com_wb, com_wb_g) * 90
170 | # L2 Consistency Loss
171 | loss_stft = F.mse_loss(com_wb_g, com_wb_g_hat) * 90
172 |
173 | audio_df_r, audio_df_g, fmap_f_r, fmap_f_g = mpd(audio_wb, audio_wb_g)
174 | spec_da_r, spec_da_g, fmap_a_r, fmap_a_g = mrad(audio_wb, audio_wb_g)
175 | spec_dp_r, spec_dp_g, fmap_p_r, fmap_p_g = mrpd(audio_wb, audio_wb_g)
176 |
177 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
178 | loss_fm_a = feature_loss(fmap_a_r, fmap_a_g)
179 | loss_fm_p = feature_loss(fmap_p_r, fmap_p_g)
180 |
181 | loss_gen_f, losses_gen_f = generator_loss(audio_df_g)
182 | loss_gen_a, losses_gen_a = generator_loss(spec_da_g)
183 | loss_gen_p, losses_gen_p = generator_loss(spec_dp_g)
184 |
185 | loss_gen = (loss_gen_a + loss_gen_p) * 0.1 + loss_gen_f
186 | loss_fm = (loss_fm_a + loss_fm_p) * 0.1 + loss_fm_f
187 |
188 | loss_gen_all = loss_mag + loss_pha + loss_com + loss_stft + loss_gen + loss_fm
189 |
190 | loss_gen_all.backward()
191 | torch.nn.utils.clip_grad_norm_(parameters=generator.parameters(), max_norm=10, norm_type=2)
192 | optim_g.step()
193 |
194 | if rank == 0:
195 | # STDOUT logging
196 | if steps % a.stdout_interval == 0:
197 | with torch.no_grad():
198 | mag_error = F.mse_loss(mag_wb, mag_wb_g).item()
199 | ip_error, gd_error, iaf_error = phase_losses(pha_wb, pha_wb_g)
200 | pha_error = (ip_error + gd_error + iaf_error).item()
201 | com_error = F.mse_loss(com_wb, com_wb_g).item()
202 | stft_error = F.mse_loss(com_wb_g, com_wb_g_hat).item()
203 | print('Steps : {:d}, Gen Loss: {:4.3f}, Magnitude Loss : {:4.3f}, Phase Loss : {:4.3f}, Complex Loss : {:4.3f}, STFT Loss : {:4.3f}, s/b : {:4.3f}'.
204 | format(steps, loss_gen_all, mag_error, pha_error, com_error, stft_error, time.time() - start_b))
205 |
206 | # checkpointing
207 | if steps % a.checkpoint_interval == 0 and steps != 0:
208 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
209 | save_checkpoint(checkpoint_path,
210 | {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
211 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
212 | save_checkpoint(checkpoint_path,
213 | {'mpd': (mpd.module if h.num_gpus > 1
214 | else mpd).state_dict(),
215 | 'mrad': (mrad.module if h.num_gpus > 1
216 | else mrad).state_dict(),
217 | 'mrpd': (mrpd.module if h.num_gpus > 1
218 | else mrpd).state_dict(),
219 | 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
220 | 'epoch': epoch})
221 |
222 | # Tensorboard summary logging
223 | if steps % a.summary_interval == 0:
224 | sw.add_scalar("Training/Generator Loss", loss_gen_all, steps)
225 | sw.add_scalar("Training/Magnitude Loss", mag_error, steps)
226 | sw.add_scalar("Training/Phase Loss", pha_error, steps)
227 | sw.add_scalar("Training/Complex Loss", com_error, steps)
228 | sw.add_scalar("Training/Consistency Loss", stft_error, steps)
229 |
230 | # Validation
231 | if steps % a.validation_interval == 0:
232 | start_v = time.time()
233 | generator.eval()
234 | torch.cuda.empty_cache()
235 | val_mag_err_tot = 0
236 | val_pha_err_tot = 0
237 | val_com_err_tot = 0
238 | val_stft_err_tot = 0
239 | val_snr_score_tot = 0
240 | val_lsd_score_tot = 0
241 | with torch.no_grad():
242 | for j, batch in enumerate(validation_loader):
243 | audio_wb, audio_nb = batch
244 | audio_wb = audio_wb.to(device, non_blocking=True)
245 | audio_nb = audio_nb.to(device, non_blocking=True)
246 |
247 | mag_wb, pha_wb, com_wb = amp_pha_stft(audio_wb, h.n_fft, h.hop_size, h.win_size)
248 | mag_nb, pha_nb, com_nb = amp_pha_stft(audio_nb, h.n_fft, h.hop_size, h.win_size)
249 |
250 | mag_wb_g, pha_wb_g, com_wb_g = generator(mag_nb.to(device), pha_nb.to(device))
251 |
252 | audio_wb = amp_pha_istft(mag_wb, pha_wb, h.n_fft, h.hop_size, h.win_size)
253 | audio_wb_g = amp_pha_istft(mag_wb_g, pha_wb_g, h.n_fft, h.hop_size, h.win_size)
254 | mag_wb_g_hat, pha_wb_g_hat, com_wb_g_hat = amp_pha_stft(audio_wb_g, h.n_fft, h.hop_size, h.win_size)
255 |
256 | val_mag_err_tot += F.mse_loss(mag_wb, mag_wb_g_hat).item()
257 | val_ip_err, val_gd_err, val_iaf_err = phase_losses(pha_wb, pha_wb_g_hat)
258 | val_pha_err_tot += (val_ip_err + val_gd_err + val_iaf_err).item()
259 | val_com_err_tot += F.mse_loss(com_wb, com_wb_g_hat).item()
260 | val_stft_err_tot += F.mse_loss(com_wb_g, com_wb_g_hat).item()
261 | val_snr_score_tot += cal_snr(audio_wb_g, audio_wb).item()
262 | val_lsd_score_tot += cal_lsd(audio_wb_g, audio_wb).item()
263 |
264 | if j <= 4:
265 | if steps == 0:
266 | sw.add_audio('gt/audio_nb_{}'.format(j), audio_nb[0], steps, h.hr_sampling_rate)
267 | sw.add_audio('gt/audio_wb_{}'.format(j), audio_wb[0], steps, h.hr_sampling_rate)
268 | sw.add_figure('gt/spec_nb_{}'.format(j), plot_spectrogram(mag_nb.squeeze().cpu().numpy()), steps)
269 | sw.add_figure('gt/spec_wb_{}'.format(j), plot_spectrogram(mag_wb.squeeze().cpu().numpy()), steps)
270 |
271 | sw.add_audio('generated/audio_g_{}'.format(j), audio_wb_g[0], steps, h.hr_sampling_rate)
272 | sw.add_figure('generated/spec_g_{}'.format(j), plot_spectrogram(mag_wb_g.squeeze().cpu().numpy()), steps)
273 |
274 | val_mag_err = val_mag_err_tot / (j+1)
275 | val_pha_err = val_pha_err_tot / (j+1)
276 | val_com_err = val_com_err_tot / (j+1)
277 | val_stft_err = val_stft_err_tot / (j+1)
278 | val_snr_score = val_snr_score_tot / (j+1)
279 | val_lsd_score = val_lsd_score_tot / (j+1)
280 |
281 | print('Steps : {:d}, SNR Score: {:4.3f}, LSD Score: {:4.3f}, s/b : {:4.3f}'.
282 | format(steps, val_snr_score, val_lsd_score, time.time() - start_v))
283 | sw.add_scalar("Validation/LSD Score", val_lsd_score, steps)
284 | sw.add_scalar("Validation/SNR Score", val_snr_score, steps)
285 | sw.add_scalar("Validation/Magnitude Loss", val_mag_err, steps)
286 | sw.add_scalar("Validation/Phase Loss", val_pha_err, steps)
287 | sw.add_scalar("Validation/Complex Loss", val_com_err, steps)
288 | sw.add_scalar("Validation/Consistency Loss", val_stft_err, steps)
289 |
290 | generator.train()
291 |
292 | steps += 1
293 |
294 | scheduler_g.step()
295 | scheduler_d.step()
296 |
297 | if rank == 0:
298 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
299 |
300 |
301 | def main():
302 | print('Initializing Training Process..')
303 |
304 | parser = argparse.ArgumentParser()
305 |
306 | parser.add_argument('--group_name', default=None)
307 | parser.add_argument('--input_training_wavs_dir', default='VCTK-Corpus-0.92/wav48/train')
308 | parser.add_argument('--input_validation_wavs_dir', default='VCTK-Corpus-0.92/wav48/test')
309 | parser.add_argument('--input_training_file', default='VCTK-Corpus-0.92/training.txt')
310 | parser.add_argument('--input_validation_file', default='VCTK-Corpus-0.92/test.txt')
311 | parser.add_argument('--checkpoint_path', default='cp_model')
312 | parser.add_argument('--config', default='')
313 | parser.add_argument('--training_epochs', default=3100, type=int)
314 | parser.add_argument('--stdout_interval', default=5, type=int)
315 | parser.add_argument('--checkpoint_interval', default=5000, type=int)
316 | parser.add_argument('--summary_interval', default=100, type=int)
317 | parser.add_argument('--validation_interval', default=5000, type=int)
318 |
319 | a = parser.parse_args()
320 |
321 | with open(a.config) as f:
322 | data = f.read()
323 |
324 | json_config = json.loads(data)
325 | h = AttrDict(json_config)
326 | build_env(a.config, 'config.json', a.checkpoint_path)
327 |
328 | torch.manual_seed(h.seed)
329 | if torch.cuda.is_available():
330 | torch.cuda.manual_seed(h.seed)
331 | h.num_gpus = torch.cuda.device_count()
332 | h.batch_size = int(h.batch_size / h.num_gpus)
333 | print('Batch size per GPU :', h.batch_size)
334 | else:
335 | raise RuntimeError('Training requires CUDA; no GPU was found.')
336 |
337 | if h.num_gpus > 1:
338 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
339 | else:
340 | train(0, a, h)
341 |
342 |
343 | if __name__ == '__main__':
344 | main()
--------------------------------------------------------------------------------
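train_48k.py is identical to train_16k.py except for its default wav directories (wav48 in place of wav16); all 48 kHz-specific behaviour (sampling rates, STFT sizes) comes from the JSON config passed via --config, e.g. configs/config_16kto48k.json.
--------------------------------------------------------------------------------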
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.utils import weight_norm
6 | import matplotlib
7 | matplotlib.use("Agg")
8 | import matplotlib.pylab as plt
9 |
10 | def plot_spectrogram(spectrogram):
11 | fig, ax = plt.subplots(figsize=(4, 3))
12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13 | interpolation='none')
14 | plt.colorbar(im, ax=ax)
15 | fig.canvas.draw()
16 | plt.close(fig)
17 |
18 | return fig
19 |
20 | def init_weights(m, mean=0.0, std=0.01):
21 | classname = m.__class__.__name__
22 | if classname.find("Conv") != -1:
23 | m.weight.data.normal_(mean, std)
24 |
25 | def apply_weight_norm(m):
26 | classname = m.__class__.__name__
27 | if classname.find("Conv") != -1:
28 | weight_norm(m)
29 |
30 | def get_padding(kernel_size, dilation=1):
31 | return int((kernel_size*dilation - dilation)/2)
32 |
33 |
34 | def get_padding_2d(kernel_size, dilation=(1, 1)):
35 | return (int((kernel_size[0]*dilation[0] - dilation[0])/2), int((kernel_size[1]*dilation[1] - dilation[1])/2))
36 |
37 |
38 | def load_checkpoint(filepath, device):
39 | assert os.path.isfile(filepath)
40 | print("Loading '{}'".format(filepath))
41 | checkpoint_dict = torch.load(filepath, map_location=device)
42 | print("Complete.")
43 | return checkpoint_dict
44 |
45 |
46 | def save_checkpoint(filepath, obj):
47 | print("Saving checkpoint to {}".format(filepath))
48 | torch.save(obj, filepath)
49 | print("Complete.")
50 |
51 |
52 | def scan_checkpoint(cp_dir, prefix):
53 | pattern = os.path.join(cp_dir, prefix + '????????')
54 | cp_list = glob.glob(pattern)
55 | if len(cp_list) == 0:
56 | return None
57 | return sorted(cp_list)[-1]
--------------------------------------------------------------------------------
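A small usage sketch for the checkpoint helpers above (paths are hypothetical). scan_checkpoint's 'g_????????' glob matches exactly the zero-padded 'g_{:08d}' filenames that the training scripts pass to save_checkpoint:

import os
import torch
from utils import save_checkpoint, scan_checkpoint, load_checkpoint

os.makedirs('cp_model', exist_ok=True)
save_checkpoint('cp_model/g_00005000', {'generator': {}})  # toy payload
latest = scan_checkpoint('cp_model', 'g_')                 # -> 'cp_model/g_00005000'
state = load_checkpoint(latest, torch.device('cpu'))
--------------------------------------------------------------------------------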
/weights_LICENSE.txt:
--------------------------------------------------------------------------------
1 | The weights for this model are licensed under the MIT License.
2 | You can use, modify, and distribute the weights freely as long as you comply with the terms of the MIT License.
3 |
--------------------------------------------------------------------------------