├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── default_c16.yaml
│   └── default_c32.yaml
├── datasets
│   ├── dataloader.py
│   └── metadata
│       ├── libritts_train_clean_360_audiopath_text_sid_train.txt
│       └── libritts_train_clean_360_audiopath_text_sid_val.txt
├── docs
│   ├── index.html
│   ├── loss.png
│   ├── model_architecture.png
│   └── samples
│       ├── seen
│       │   ├── c16
│       │   │   ├── 2004_147967_000029_000002.wav
│       │   │   ├── 337_126286_000008_000000.wav
│       │   │   ├── 3537_5704_000008_000005.wav
│       │   │   ├── 5319_84357_000005_000004.wav
│       │   │   ├── 6294_86679_000035_000004.wav
│       │   │   └── 949_134657_000002_000005.wav
│       │   ├── c32
│       │   │   ├── 2004_147967_000029_000002.wav
│       │   │   ├── 337_126286_000008_000000.wav
│       │   │   ├── 3537_5704_000008_000005.wav
│       │   │   ├── 5319_84357_000005_000004.wav
│       │   │   ├── 6294_86679_000035_000004.wav
│       │   │   └── 949_134657_000002_000005.wav
│       │   ├── ground_truth
│       │   │   ├── 2004_147967_000029_000002.wav
│       │   │   ├── 337_126286_000008_000000.wav
│       │   │   ├── 3537_5704_000008_000005.wav
│       │   │   ├── 5319_84357_000005_000004.wav
│       │   │   ├── 6294_86679_000035_000004.wav
│       │   │   └── 949_134657_000002_000005.wav
│       │   ├── official_c16
│       │   │   ├── 2004_147967_000029_000002.wav
│       │   │   ├── 337_126286_000008_000000.wav
│       │   │   ├── 3537_5704_000008_000005.wav
│       │   │   ├── 5319_84357_000005_000004.wav
│       │   │   └── 6294_86679_000035_000004.wav
│       │   └── official_c32
│       │       ├── 2004_147967_000029_000002.wav
│       │       ├── 337_126286_000008_000000.wav
│       │       ├── 3537_5704_000008_000005.wav
│       │       ├── 5319_84357_000005_000004.wav
│       │       └── 6294_86679_000035_000004.wav
│       └── unseen
│           ├── c16
│           │   ├── 1089_134686_000007_000005.wav
│           │   ├── 3575_170457_000037_000002.wav
│           │   ├── 4507_16021_000029_000005.wav
│           │   ├── 7021_85628_000037_000000.wav
│           │   ├── 7176_92135_000006_000005.wav
│           │   └── 8224_274384_000016_000000.wav
│           ├── c32
│           │   ├── 1089_134686_000007_000005.wav
│           │   ├── 3575_170457_000037_000002.wav
│           │   ├── 4507_16021_000029_000005.wav
│           │   ├── 7021_85628_000037_000000.wav
│           │   ├── 7176_92135_000006_000005.wav
│           │   └── 8224_274384_000016_000000.wav
│           ├── ground_truth
│           │   ├── 1089_134686_000007_000005.wav
│           │   ├── 3575_170457_000037_000002.wav
│           │   ├── 4507_16021_000029_000005.wav
│           │   ├── 7021_85628_000037_000000.wav
│           │   ├── 7176_92135_000006_000005.wav
│           │   └── 8224_274384_000016_000000.wav
│           ├── official_c16
│           │   ├── 1089_134686_000007_000005.wav
│           │   ├── 3575_170457_000037_000002.wav
│           │   ├── 7021_85628_000037_000000.wav
│           │   ├── 7176_92135_000006_000005.wav
│           │   └── 8224_274384_000016_000000.wav
│           └── official_c32
│               ├── 1089_134686_000007_000005.wav
│               ├── 3575_170457_000037_000002.wav
│               ├── 7021_85628_000037_000000.wav
│               ├── 7176_92135_000006_000005.wav
│               └── 8224_274384_000016_000000.wav
├── inference.py
├── model
│   ├── discriminator.py
│   ├── generator.py
│   ├── lvcnet.py
│   ├── mpd.py
│   └── mrd.py
├── requirements.txt
├── trainer.py
└── utils
    ├── plotting.py
    ├── stft.py
    ├── stft_loss.py
    ├── train.py
    ├── utils.py
    ├── validation.py
    └── writer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # IDE configuration
2 | .idea/
3 |
4 | # configuration
5 | config/*
6 | !config/default_c*.yaml
7 | temp-restore.yaml
8 |
9 | # logs, checkpoints
10 | chkpt/
11 | logs/
12 |
13 | # just a temporary folder
14 | temp/
15 |
16 | # Byte-compiled / optimized / DLL files
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 |
21 | # C extensions
22 | *.so
23 |
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 | MANIFEST
42 |
43 | # PyInstaller
44 | # Usually these files are written by a python script from a template
45 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
46 | *.manifest
47 | *.spec
48 |
49 | # Installer logs
50 | pip-log.txt
51 | pip-delete-this-directory.txt
52 |
53 | # Unit test / coverage reports
54 | htmlcov/
55 | .tox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # celery beat schedule file
94 | celerybeat-schedule
95 |
96 | # SageMath parsed files
97 | *.sage.py
98 |
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 |
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 |
112 | # Rope project settings
113 | .ropeproject
114 |
115 | # mkdocs documentation
116 | /site
117 |
118 | # mypy
119 | .mypy_cache/
120 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, Seungwon Park 박승원
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UnivNet
2 | **UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation**
3 |
4 | This is an unofficial PyTorch implementation of ***Jang et al.* (Kakao), [UnivNet](https://arxiv.org/abs/2106.07889)**.
5 |
6 | Audio samples are uploaded!
7 |
8 | [](https://arxiv.org/abs/2106.07889) [](https://mindslab-ai.github.io/univnet/) [](./LICENSE)
9 |
10 | ## Notes
11 |
12 | **Both UnivNet-c16 and c32 results and the pre-trained weights have been uploaded.**
13 |
14 | **For both models, our implementation matches the objective scores (PESQ and RMSE) of the original paper.**
15 |
16 | ## Key Features
17 |
18 |
19 | ![model architecture](./docs/model_architecture.png)
20 | - According to the authors of the paper, UnivNet obtained the best objective results among recent GAN-based neural vocoders (including HiFi-GAN) and also outperformed HiFi-GAN in a subjective evaluation. In addition, its inference speed is 1.5 times faster than that of HiFi-GAN.
21 |
22 | - This repository uses the same mel-spectrogram function as the [Official HiFi-GAN](https://github.com/jik876/hifi-gan), which is compatible with [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2).
23 |
24 | - Our default mel calculation hyperparameters are as below, following the original paper.
25 |
26 | ```yaml
27 | audio:
28 | n_mel_channels: 100
29 | filter_length: 1024
30 | hop_length: 256 # WARNING: this can't be changed. It must equal the generator's total upsampling factor, prod(gen.strides) = 8*8*4 = 256.
31 | win_length: 1024
32 | sampling_rate: 24000
33 | mel_fmin: 0.0
34 | mel_fmax: 12000.0
35 | ```
36 |
37 | You can modify the hyperparameters to be compatible with your acoustic model.
38 |
39 | ## Prerequisites
40 |
41 | The implementation requires the following dependencies.
42 |
43 | 0. Python 3.6
44 | 1. [PyTorch](https://pytorch.org/) 1.6.0
45 | 2. [NumPy](https://numpy.org/) 1.17.4 and [SciPy](https://www.scipy.org/) 1.5.4
46 | 3. Install other dependencies in [requirements.txt](./requirements.txt).
47 | ```bash
48 | pip install -r requirements.txt
49 | ```
50 |
51 | ## Datasets
52 |
53 | **Preparing Data**
54 |
55 | - Download the training dataset. Any set of wav files with a sampling rate of 24,000 Hz can be used. The original paper used LibriTTS.
56 | - LibriTTS train-clean-360 split [tar.gz link](https://www.openslr.org/resources/60/train-clean-360.tar.gz)
57 | - Unzip and place its contents under `datasets/LibriTTS/train-clean-360`.
58 | - If you want to use wav files with a different sampling rate, please edit the configuration file (see below).
59 |
60 | Note: The mel-spectrogram calculated from each audio file will be saved as `**.mel` the first time it is read, and loaded from disk afterwards.
61 |
62 | **Preparing Metadata**
63 |
64 | Following the format from [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2), the metadata should be formatted as:
65 | ```
66 | path_to_wav|transcript|speaker_id
67 | path_to_wav|transcript|speaker_id
68 | ...
69 | ```
70 |
71 | Train/validation metadata for the LibriTTS train-clean-360 split are already prepared in `datasets/metadata`.
72 | 5% of the train-clean-360 utterances were randomly sampled for validation.
73 |
74 | Since this model is a vocoder, the transcripts are **NOT** used during training.
75 |
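If you train on your own corpus, you need to build such a metadata file yourself. The sketch below is only a minimal example (not part of this repository): the output file name, the dummy transcript, and the LibriTTS-style speaker-id parsing are assumptions to adapt to your data. Paths are written relative to `data.train_dir`, which is how `datasets/dataloader.py` resolves them.

```python
# Hypothetical helper: build a path|transcript|speaker_id metadata file.
# The transcript column is not used by the vocoder, so a placeholder is enough.
import os
import glob

data_root = 'datasets/'  # should match data.train_dir in the config
wav_glob = os.path.join(data_root, 'LibriTTS', 'train-clean-360', '**', '*.wav')

lines = []
for wav_path in sorted(glob.glob(wav_glob, recursive=True)):
    rel_path = os.path.relpath(wav_path, data_root)        # resolved against train_dir at load time
    speaker_id = os.path.basename(wav_path).split('_')[0]  # LibriTTS file names start with the speaker id
    lines.append(f'{rel_path}|dummy transcript|{speaker_id}')

with open(os.path.join(data_root, 'metadata', 'my_metadata_train.txt'), 'w', encoding='utf-8') as f:
    f.write('\n'.join(lines) + '\n')
```
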
76 | ## Train
77 |
78 | **Preparing Configuration Files**
79 |
80 | - Run `cp config/default_c32.yaml config/config.yaml` and then edit `config.yaml`
81 |
82 | - Set the root path of the train/validation data in the `data` section. The data loader recursively parses the list of files under that path.
83 |
84 | ```yaml
85 | data:
86 | train_dir: 'datasets/' # root path of train data (either relative or absolute path is ok)
87 | train_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_train.txt' # relative path of metadata file from train_dir
88 | val_dir: 'datasets/' # root path of validation data
89 | val_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_val.txt' # relative path of metadata file from val_dir
90 | ```
91 |
92 | We provide the default metadata for LibriTTS train-clean-360 split.
93 |
94 | - Modify `channel_size` in `gen` to switch between UnivNet-c16 and c32.
95 |
96 | ```yaml
97 | gen:
98 | noise_dim: 64
99 | channel_size: 32 # 32 or 16
100 | dilations: [1, 3, 9, 27]
101 | strides: [8, 8, 4]
102 | lReLU_slope: 0.2
103 | ```
104 |
105 | **Training**
106 |
107 | ```bash
108 | python trainer.py -c CONFIG_YAML_FILE -n NAME_OF_THE_RUN
109 | ```
110 |
111 | **Tensorboard**
112 |
113 | ```bash
114 | tensorboard --logdir logs/
115 | ```
116 |
117 | If you are running tensorboard on a remote machine, you can open the tensorboard page by adding the `--bind_all` option.
118 |
119 | ## Inference
120 |
121 | ```bash
122 | python inference.py -p CHECKPOINT_PATH -i INPUT_MEL_PATH -o OUTPUT_WAV_PATH
123 | ```
124 |
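`inference.py` expects torch-saved mel tensors (`*.mel`, shape `[n_mel_channels, frames]`) in the input folder. If you only have wav files, you can create compatible inputs with the same `TacotronSTFT` the data loader uses. The sketch below is only an illustration: the wav/mel paths are placeholders, and the constructor arguments mirror the call in `datasets/dataloader.py`.

```python
# Hypothetical helper: convert a wav file into a .mel input for inference.py,
# using the same mel function and hyperparameters as datasets/dataloader.py.
import torch
from omegaconf import OmegaConf

from utils.stft import TacotronSTFT
from utils.utils import read_wav_np

hp = OmegaConf.load('config/default_c32.yaml')

stft = TacotronSTFT(hp.audio.filter_length, hp.audio.hop_length, hp.audio.win_length,
                    hp.audio.n_mel_channels, hp.audio.sampling_rate,
                    hp.audio.mel_fmin, hp.audio.mel_fmax, center=False, device='cpu')

sr, wav = read_wav_np('path/to/utterance.wav')       # placeholder input path
assert sr == hp.audio.sampling_rate, 'resample the wav to audio.sampling_rate first'

wav = torch.from_numpy(wav).unsqueeze(0)             # (1, T)
mel = stft.mel_spectrogram(wav).squeeze(0)           # (n_mel_channels, frames)
torch.save(mel, 'path/to/utterance.mel')             # inference.py globs *.mel under -i
```
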
125 | ## Pre-trained Model
126 |
127 | You can download the pre-trained models from the Google Drive links below. The models were trained on the LibriTTS train-clean-360 split. A minimal loading sketch follows the links.
128 | - **UnivNet-c16: [Google Drive](https://drive.google.com/file/d/1Iqw9T0rRklLsg-6aayNk6NlsLVHfuftv/view?usp=sharing)**
129 | - **UnivNet-c32: [Google Drive](https://drive.google.com/file/d/1QZFprpvYEhLWCDF90gSl6Dpn0gonS_Rv/view?usp=sharing)**
130 |
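The sketch below shows one way to run a downloaded checkpoint directly from Python instead of through `inference.py`. It is a hedged example, not an official API: the checkpoint keys (`model_g`, `hp_str`) and the `Generator.eval(inference=True)` / `Generator.inference(mel)` calls mirror `inference.py`, while the file names are placeholders.

```python
# Hypothetical usage: synthesize audio from a pre-trained checkpoint and a saved mel.
import torch
from scipy.io.wavfile import write
from omegaconf import OmegaConf

from model.generator import Generator

checkpoint = torch.load('univnet_c32.pt', map_location='cpu')  # placeholder checkpoint name
hp = OmegaConf.create(checkpoint['hp_str'])                    # hyperparameters are stored in the checkpoint

model = Generator(hp)
state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
              for k, v in checkpoint['model_g'].items()}       # strip DataParallel prefix if present
model.load_state_dict(state_dict)
model.eval(inference=True)                                     # also removes weight norm

mel = torch.load('path/to/utterance.mel')                      # (n_mel_channels, frames)
if mel.dim() == 2:
    mel = mel.unsqueeze(0)

with torch.no_grad():
    audio = model.inference(mel)                               # int16 waveform tensor
write('reconstructed.wav', hp.audio.sampling_rate, audio.cpu().numpy())
```
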
131 | ## Results
132 |
133 | See audio samples at https://mindslab-ai.github.io/univnet/
134 |
135 | We evaluated our models on the validation set.
136 |
137 | | Model | PESQ(↑) | RMSE(↓) | Model Size |
138 | | -------------------- | --------- | --------- | ---------- |
139 | | HiFi-GAN v1 | 3.54 | 0.423 | 14.01M |
140 | | Official UnivNet-c16 | 3.59 | 0.337 | 4.00M |
141 | | **Our UnivNet-c16** | **3.60** | **0.317** | **4.00M** |
142 | | Official UnivNet-c32 | 3.70 | 0.316 | 14.86M |
143 | | **Our UnivNet-c32** | **3.68** | **0.304** | **14.87M** |
144 |
145 | The loss curves of UnivNet are shown below.
146 |
147 | The orange and blue curves correspond to c16 and c32, respectively.
148 |
149 |
150 | ![loss curves](./docs/loss.png)
151 | ## Implementation Authors
152 |
153 | Implementation authors are:
154 |
155 | - [Kang-wook Kim](http://github.com/wookladin) @ [MINDsLab Inc.](https://maum.ai/) (full324@snu.ac.kr, kwkim@mindslab.ai)
156 | - [Wonbin Jung](https://github.com/Wonbin-Jung) @ [MINDsLab Inc.](https://maum.ai/) (santabin@kaist.ac.kr, wbjung@mindslab.ai)
157 |
158 | Contributors are:
159 |
160 | - [Kuan Chen](https://github.com/azraelkuan)
161 |
162 | Special thanks to
163 |
164 | - [Seungu Han](https://github.com/Seungwoo0326) @ [MINDsLab Inc.](https://maum.ai/)
165 | - [Junhyeok Lee](https://github.com/junjun3518) @ [MINDsLab Inc.](https://maum.ai/)
166 | - [Sang Hoon Woo](https://github.com/tonyswoo) @ [MINDsLab Inc.](https://maum.ai/)
167 |
168 | ## License
169 |
170 | This code is licensed under BSD 3-Clause License.
171 |
172 | We referred to the following code and repositories.
173 |
174 | - The overall structure of the repository is based on [https://github.com/seungwonpark/melgan](https://github.com/seungwonpark/melgan).
175 | - [datasets/dataloader.py](./datasets/dataloader.py) from https://github.com/NVIDIA/waveglow (BSD 3-Clause License)
176 | - [model/mpd.py](./model/mpd.py) from https://github.com/jik876/hifi-gan (MIT License)
177 | - [model/lvcnet.py](./model/lvcnet.py) from https://github.com/zceng/LVCNet (Apache License 2.0)
178 | - [utils/stft_loss.py](./utils/stft_loss.py): Copyright 2019 Tomoki Hayashi, MIT License (https://opensource.org/licenses/MIT)
179 |
180 | ## References
181 |
182 | Papers
183 |
184 | - *Jang et al.*, [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889)
185 | - *Zeng et al.*, [LVCNet: Efficient Condition-Dependent Modeling Network for Waveform Generation](https://arxiv.org/abs/2102.10815)
186 | - *Kong et al.*, [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646)
187 |
188 | Datasets
189 |
190 | - [LibriTTS](https://openslr.org/60/)
191 |
--------------------------------------------------------------------------------
/config/default_c16.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_dir: 'datasets/' # root path of train data (either relative or absolute path is ok)
3 | train_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_train.txt' # relative path of metadata file from train_dir
4 | val_dir: 'datasets/' # root path of validation data
5 | val_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_val.txt' # relative path of metadata file from val_dir
6 | #############################
7 | train:
8 | num_workers: 8
9 | batch_size: 32
10 | optimizer: 'adam'
11 | seed: 1234
12 | adam:
13 | lr: 0.0001
14 | beta1: 0.5
15 | beta2: 0.9
16 | stft_lamb: 2.5
17 | spk_balanced: False # Using balanced sampling for each speaker
18 | #############################
19 | audio:
20 | n_mel_channels: 100
21 | segment_length: 16384 # Should be multiple of 256
22 | pad_short: 2000
23 | filter_length: 1024
24 | hop_length: 256 # WARNING: this can't be changed.
25 | win_length: 1024
26 | sampling_rate: 24000
27 | mel_fmin: 0.0
28 | mel_fmax: 12000.0
29 | #############################
30 | gen:
31 | noise_dim: 64
32 | channel_size: 16
33 | dilations: [1, 3, 9, 27]
34 | strides: [8, 8, 4]
35 | lReLU_slope: 0.2
36 | kpnet_conv_size: 3
37 | #############################
38 | mpd:
39 | periods: [2,3,5,7,11]
40 | kernel_size: 5
41 | stride: 3
42 | use_spectral_norm: False
43 | lReLU_slope: 0.2
44 | #############################
45 | mrd:
46 | resolutions: "[(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]" # (filter_length, hop_length, win_length)
47 | use_spectral_norm: False
48 | lReLU_slope: 0.2
49 | #############################
50 | dist_config:
51 | dist_backend: "nccl"
52 | dist_url: "tcp://localhost:54321"
53 | world_size: 1
54 | #############################
55 | log:
56 | summary_interval: 1
57 | validation_interval: 1
58 | save_interval: 1
59 | num_audio: 5
60 | chkpt_dir: 'chkpt'
61 | log_dir: 'logs'
62 |
--------------------------------------------------------------------------------
/config/default_c32.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_dir: 'datasets/' # root path of train data (either relative or absolute path is ok)
3 | train_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_train.txt' # relative path of metadata file from train_dir
4 | val_dir: 'datasets/' # root path of validation data
5 | val_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_val.txt' # relative path of metadata file from val_dir
6 | #############################
7 | train:
8 | num_workers: 8
9 | batch_size: 32
10 | optimizer: 'adam'
11 | seed: 1234
12 | adam:
13 | lr: 0.0001
14 | beta1: 0.5
15 | beta2: 0.9
16 | stft_lamb: 2.5
17 | spk_balanced: False # Using balanced sampling for each speaker
18 | #############################
19 | audio:
20 | n_mel_channels: 100
21 | segment_length: 16384 # Should be multiple of 256
22 | pad_short: 2000
23 | filter_length: 1024
24 | hop_length: 256 # WARNING: this can't be changed.
25 | win_length: 1024
26 | sampling_rate: 24000
27 | mel_fmin: 0.0
28 | mel_fmax: 12000.0
29 | #############################
30 | gen:
31 | noise_dim: 64
32 | channel_size: 32
33 | dilations: [1, 3, 9, 27]
34 | strides: [8, 8, 4]
35 | lReLU_slope: 0.2
36 | kpnet_conv_size: 3
37 | #############################
38 | mpd:
39 | periods: [2,3,5,7,11]
40 | kernel_size: 5
41 | stride: 3
42 | use_spectral_norm: False
43 | lReLU_slope: 0.2
44 | #############################
45 | mrd:
46 | resolutions: "[(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]" # (filter_length, hop_length, win_length)
47 | use_spectral_norm: False
48 | lReLU_slope: 0.2
49 | #############################
50 | dist_config:
51 | dist_backend: "nccl"
52 | dist_url: "tcp://localhost:54321"
53 | world_size: 1
54 | #############################
55 | log:
56 | summary_interval: 1
57 | validation_interval: 1
58 | save_interval: 1
59 | num_audio: 5
60 | chkpt_dir: 'chkpt'
61 | log_dir: 'logs'
62 |
--------------------------------------------------------------------------------
/datasets/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import random
5 | import numpy as np
6 | from torch.utils.data import DistributedSampler, DataLoader, Dataset
7 | from collections import Counter
8 |
9 | from utils.utils import read_wav_np
10 | from utils.stft import TacotronSTFT
11 |
12 |
13 | def create_dataloader(hp, args, train, device):
14 | if train:
15 | dataset = MelFromDisk(hp, hp.data.train_dir, hp.data.train_meta, args, train, device)
16 | return DataLoader(dataset=dataset, batch_size=hp.train.batch_size, shuffle=False,
17 | num_workers=hp.train.num_workers, pin_memory=True, drop_last=True)
18 |
19 | else:
20 | dataset = MelFromDisk(hp, hp.data.val_dir, hp.data.val_meta, args, train, device)
21 | return DataLoader(dataset=dataset, batch_size=1, shuffle=False,
22 | num_workers=hp.train.num_workers, pin_memory=True, drop_last=False)
23 |
24 |
25 | class MelFromDisk(Dataset):
26 | def __init__(self, hp, data_dir, metadata_path, args, train, device):
27 | random.seed(hp.train.seed)
28 | self.hp = hp
29 | self.args = args
30 | self.train = train
31 | self.data_dir = data_dir
32 | metadata_path = os.path.join(data_dir, metadata_path)
33 | self.meta = self.load_metadata(metadata_path)
34 | self.stft = TacotronSTFT(hp.audio.filter_length, hp.audio.hop_length, hp.audio.win_length,
35 | hp.audio.n_mel_channels, hp.audio.sampling_rate,
36 | hp.audio.mel_fmin, hp.audio.mel_fmax, center=False, device=device)
37 |
38 | self.mel_segment_length = hp.audio.segment_length // hp.audio.hop_length
39 | self.shuffle = hp.train.spk_balanced
40 |
41 | if train and hp.train.spk_balanced:
42 | # balanced sampling for each speaker
43 | speaker_counter = Counter((spk_id \
44 | for audiopath, text, spk_id in self.meta))
45 | weights = [1.0 / speaker_counter[spk_id] \
46 | for audiopath, text, spk_id in self.meta]
47 |
48 | self.mapping_weights = torch.DoubleTensor(weights)
49 |
50 | elif train:
51 | weights = [1.0 / len(self.meta) for _, _, _ in self.meta]
52 | self.mapping_weights = torch.DoubleTensor(weights)
53 |
54 |
55 | def __len__(self):
56 | return len(self.meta)
57 |
58 | def __getitem__(self, idx):
59 | if self.train:
60 | idx = torch.multinomial(self.mapping_weights, 1).item()
61 | return self.my_getitem(idx)
62 | else:
63 | return self.my_getitem(idx)
64 |
65 | def shuffle_mapping(self):
66 | random.shuffle(self.mapping_weights)
67 |
68 | def my_getitem(self, idx):
69 | wavpath, _, _ = self.meta[idx]
70 | wavpath = os.path.join(self.data_dir, wavpath)
71 | sr, audio = read_wav_np(wavpath)
72 |
73 | if len(audio) < self.hp.audio.segment_length + self.hp.audio.pad_short:
74 | audio = np.pad(audio, (0, self.hp.audio.segment_length + self.hp.audio.pad_short - len(audio)), \
75 | mode='constant', constant_values=0.0)
76 |
77 | audio = torch.from_numpy(audio).unsqueeze(0)
78 | mel = self.get_mel(wavpath)
79 |
80 | if self.train:
81 | max_mel_start = mel.size(1) - self.mel_segment_length -1
82 | mel_start = random.randint(0, max_mel_start)
83 | mel_end = mel_start + self.mel_segment_length
84 | mel = mel[:, mel_start:mel_end]
85 |
86 | audio_start = mel_start * self.hp.audio.hop_length
87 | audio_len = self.hp.audio.segment_length
88 | audio = audio[:, audio_start:audio_start + audio_len]
89 |
90 | return mel, audio
91 |
92 | def get_mel(self, wavpath):
93 | melpath = wavpath.replace('.wav', '.mel')
94 | try:
95 | mel = torch.load(melpath, map_location='cpu')
96 | assert mel.size(0) == self.hp.audio.n_mel_channels, \
97 | 'Mel dimension mismatch: expected %d, got %d' % \
98 | (self.hp.audio.n_mel_channels, mel.size(0))
99 |
100 | except (FileNotFoundError, RuntimeError, TypeError, AssertionError):
101 | sr, wav = read_wav_np(wavpath)
102 | assert sr == self.hp.audio.sampling_rate, \
103 | 'sample mismatch: expected %d, got %d at %s' % (self.hp.audio.sampling_rate, sr, wavpath)
104 |
105 | if len(wav) < self.hp.audio.segment_length + self.hp.audio.pad_short:
106 | wav = np.pad(wav, (0, self.hp.audio.segment_length + self.hp.audio.pad_short - len(wav)), \
107 | mode='constant', constant_values=0.0)
108 |
109 | wav = torch.from_numpy(wav).unsqueeze(0)
110 | mel = self.stft.mel_spectrogram(wav)
111 |
112 | mel = mel.squeeze(0)
113 |
114 | torch.save(mel, melpath)
115 |
116 | return mel
117 |
118 | def load_metadata(self, path, split="|"):
119 | metadata = []
120 | with open(path, 'r', encoding='utf-8') as f:
121 | for line in f:
122 | stripped = line.strip().split(split)
123 | metadata.append(stripped)
124 |
125 | return metadata
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
Audio Samples from Unofficial Implementation of UnivNet vocoder

Implementation authors: Kang-wook Kim, Wonbin Jung @ MINDsLab Inc.

Please check the GitHub repository for the implementation details and the pre-trained model checkpoints.
The ground truth audio samples from the official demo page have a different volume from the original dataset. We believe the official ground truth audio samples have been post-processed, so our samples may differ in volume from the official ones.
Notes: Both UnivNet-c16 and c32 results and the pre-trained weights have been uploaded. Our implementation matches the objective scores (PESQ and RMSE) of the original paper.

Synthesis for seen speakers

The LibriTTS train-clean-360 dataset was used to both train and evaluate the model. The first five texts are taken directly from the official demo page for comparison. Each sample appears in five audio rows: GT, Official UnivNet-c16, Official UnivNet-c32, Our UnivNet-c16, and Our UnivNet-c32; one sample per official row has no published audio.

1. Nobody could touch the body until the coroner came.
2. Then they came to another grove of trees, where all the leaves were of gold; and afterwards to a third, where the leaves were all glittering diamonds.
3. The shaven face of the priest is a further item to the same effect.
4. The lady who had spoken English approached the table as if looking for something, and when Beaton looked again, the portrait had been turned on its face.
5. He shouted, 'Where art thou, ring?' And the ring said, 'I am here,' though it was on the bed of the ocean.
6. The monarch, after making some inquiry into the rank and character of his rival, despatched the informer with a present of a pair of purple slippers, to complete the magnificence of his Imperial habit.

Synthesis for unseen speakers

The LibriTTS test-clean dataset was used to evaluate the model. The first five texts are taken directly from the official demo page for comparison. Each sample appears in the same five audio rows as above.

1. The stars began to crumble and a cloud of fine stardust fell through space.
2. On August 27, 1837, she writes:--
3. But when his big brother heard that he had refused to give his cap for a King's golden crown, he said that Anders was a stupid.
4. The only thing is, I don't know anything about technique and stagecraft and the three unities and that sort of rot.
5. The marquis of Worcester, a man past eighty-four, was the last in England that submitted to the authority of the parliament.
6. True history being a mixture of all things, the true historian mingles in everything.
--------------------------------------------------------------------------------
/docs/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/loss.png
--------------------------------------------------------------------------------
/docs/model_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/model_architecture.png
--------------------------------------------------------------------------------
/docs/samples/seen/c16/2004_147967_000029_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/2004_147967_000029_000002.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c16/337_126286_000008_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/337_126286_000008_000000.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c16/3537_5704_000008_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/3537_5704_000008_000005.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c16/5319_84357_000005_000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/5319_84357_000005_000004.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c16/6294_86679_000035_000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/6294_86679_000035_000004.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c16/949_134657_000002_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/949_134657_000002_000005.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c32/2004_147967_000029_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/2004_147967_000029_000002.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c32/337_126286_000008_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/337_126286_000008_000000.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c32/3537_5704_000008_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/3537_5704_000008_000005.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c32/5319_84357_000005_000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/5319_84357_000005_000004.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c32/6294_86679_000035_000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/6294_86679_000035_000004.wav
--------------------------------------------------------------------------------
/docs/samples/seen/c32/949_134657_000002_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/949_134657_000002_000005.wav
--------------------------------------------------------------------------------
/docs/samples/seen/ground_truth/2004_147967_000029_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/2004_147967_000029_000002.wav
--------------------------------------------------------------------------------
/docs/samples/seen/ground_truth/337_126286_000008_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/337_126286_000008_000000.wav
--------------------------------------------------------------------------------
/docs/samples/seen/ground_truth/3537_5704_000008_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/3537_5704_000008_000005.wav
--------------------------------------------------------------------------------
/docs/samples/seen/ground_truth/5319_84357_000005_000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/5319_84357_000005_000004.wav
--------------------------------------------------------------------------------
/docs/samples/seen/ground_truth/6294_86679_000035_000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/6294_86679_000035_000004.wav
--------------------------------------------------------------------------------
/docs/samples/seen/ground_truth/949_134657_000002_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/949_134657_000002_000005.wav
--------------------------------------------------------------------------------
/docs/samples/seen/official_c16/2004_147967_000029_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c16/2004_147967_000029_000002.wav
--------------------------------------------------------------------------------
/docs/samples/seen/official_c16/337_126286_000008_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c16/337_126286_000008_000000.wav
--------------------------------------------------------------------------------
/docs/samples/seen/official_c16/3537_5704_000008_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c16/3537_5704_000008_000005.wav
--------------------------------------------------------------------------------
/docs/samples/seen/official_c16/5319_84357_000005_000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c16/5319_84357_000005_000004.wav
--------------------------------------------------------------------------------
/docs/samples/seen/official_c16/6294_86679_000035_000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c16/6294_86679_000035_000004.wav
--------------------------------------------------------------------------------
/docs/samples/seen/official_c32/2004_147967_000029_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c32/2004_147967_000029_000002.wav
--------------------------------------------------------------------------------
/docs/samples/seen/official_c32/337_126286_000008_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c32/337_126286_000008_000000.wav
--------------------------------------------------------------------------------
/docs/samples/seen/official_c32/3537_5704_000008_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c32/3537_5704_000008_000005.wav
--------------------------------------------------------------------------------
/docs/samples/seen/official_c32/5319_84357_000005_000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c32/5319_84357_000005_000004.wav
--------------------------------------------------------------------------------
/docs/samples/seen/official_c32/6294_86679_000035_000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c32/6294_86679_000035_000004.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c16/1089_134686_000007_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/1089_134686_000007_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c16/3575_170457_000037_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/3575_170457_000037_000002.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c16/4507_16021_000029_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/4507_16021_000029_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c16/7021_85628_000037_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/7021_85628_000037_000000.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c16/7176_92135_000006_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/7176_92135_000006_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c16/8224_274384_000016_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/8224_274384_000016_000000.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c32/1089_134686_000007_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/1089_134686_000007_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c32/3575_170457_000037_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/3575_170457_000037_000002.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c32/4507_16021_000029_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/4507_16021_000029_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c32/7021_85628_000037_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/7021_85628_000037_000000.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c32/7176_92135_000006_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/7176_92135_000006_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/c32/8224_274384_000016_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/8224_274384_000016_000000.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/ground_truth/1089_134686_000007_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/1089_134686_000007_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/ground_truth/3575_170457_000037_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/3575_170457_000037_000002.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/ground_truth/4507_16021_000029_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/4507_16021_000029_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/ground_truth/7021_85628_000037_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/7021_85628_000037_000000.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/ground_truth/7176_92135_000006_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/7176_92135_000006_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/ground_truth/8224_274384_000016_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/8224_274384_000016_000000.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/official_c16/1089_134686_000007_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c16/1089_134686_000007_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/official_c16/3575_170457_000037_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c16/3575_170457_000037_000002.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/official_c16/7021_85628_000037_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c16/7021_85628_000037_000000.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/official_c16/7176_92135_000006_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c16/7176_92135_000006_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/official_c16/8224_274384_000016_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c16/8224_274384_000016_000000.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/official_c32/1089_134686_000007_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c32/1089_134686_000007_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/official_c32/3575_170457_000037_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c32/3575_170457_000037_000002.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/official_c32/7021_85628_000037_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c32/7021_85628_000037_000000.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/official_c32/7176_92135_000006_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c32/7176_92135_000006_000005.wav
--------------------------------------------------------------------------------
/docs/samples/unseen/official_c32/8224_274384_000016_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c32/8224_274384_000016_000000.wav
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import tqdm
4 | import torch
5 | import argparse
6 | from scipy.io.wavfile import write
7 | from omegaconf import OmegaConf
8 |
9 | from model.generator import Generator
10 |
11 |
12 | def main(args):
13 | checkpoint = torch.load(args.checkpoint_path)
14 | if args.config is not None:
15 | hp = OmegaConf.load(args.config)
16 | else:
17 | hp = OmegaConf.create(checkpoint['hp_str'])
18 |
19 | model = Generator(hp).cuda()
20 | saved_state_dict = checkpoint['model_g']
21 | new_state_dict = {}
22 |
23 | for k, v in saved_state_dict.items():
24 | try:
25 | new_state_dict[k] = saved_state_dict['module.' + k]
26 | except KeyError:
27 | new_state_dict[k] = v
28 | model.load_state_dict(new_state_dict)
29 | model.eval(inference=True)
30 |
31 | with torch.no_grad():
32 | for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
33 | mel = torch.load(melpath)
34 | if len(mel.shape) == 2:
35 | mel = mel.unsqueeze(0)
36 | mel = mel.cuda()
37 |
38 | audio = model.inference(mel)
39 | audio = audio.cpu().detach().numpy()
40 |
41 | if args.output_folder is None: # if output folder is not defined, audio samples are saved in input folder
42 | out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
43 | else:
44 | basename = os.path.basename(melpath)
45 | basename = basename.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
46 | out_path = os.path.join(args.output_folder, basename)
47 | write(out_path, hp.audio.sampling_rate, audio)
48 |
49 |
50 | if __name__ == '__main__':
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument('-c', '--config', type=str, default=None,
53 | help="yaml file for config. will use hp_str from checkpoint if not given.")
54 | parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
55 | help="path of checkpoint pt file for evaluation")
56 | parser.add_argument('-i', '--input_folder', type=str, required=True,
57 | help="directory of mel-spectrograms to invert into raw audio.")
58 | parser.add_argument('-o', '--output_folder', type=str, default=None,
59 | help="directory which generated raw audio is saved.")
60 | args = parser.parse_args()
61 |
62 | main(args)
63 |
--------------------------------------------------------------------------------
/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .mpd import MultiPeriodDiscriminator
5 | from .mrd import MultiResolutionDiscriminator
6 | from omegaconf import OmegaConf
7 |
8 | class Discriminator(nn.Module):
9 | def __init__(self, hp):
10 | super(Discriminator, self).__init__()
11 | self.MRD = MultiResolutionDiscriminator(hp)
12 | self.MPD = MultiPeriodDiscriminator(hp)
13 |
14 | def forward(self, x):
15 | return self.MRD(x), self.MPD(x)
16 |
17 | if __name__ == '__main__':
18 | hp = OmegaConf.load('../config/default_c32.yaml')
19 | model = Discriminator(hp)
20 |
21 | x = torch.randn(3, 1, 16384)
22 | print(x.shape)
23 |
24 | mrd_output, mpd_output = model(x)
25 | for features, score in mpd_output:
26 | for feat in features:
27 | print(feat.shape)
28 | print(score.shape)
29 |
30 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
31 | print(pytorch_total_params)
32 |
33 |
--------------------------------------------------------------------------------
/model/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from omegaconf import OmegaConf
4 |
5 | from .lvcnet import LVCBlock
6 |
7 | MAX_WAV_VALUE = 32768.0
8 |
9 | class Generator(nn.Module):
10 | """UnivNet Generator"""
11 | def __init__(self, hp):
12 | super(Generator, self).__init__()
13 | self.mel_channel = hp.audio.n_mel_channels
14 | self.noise_dim = hp.gen.noise_dim
15 | self.hop_length = hp.audio.hop_length
16 | channel_size = hp.gen.channel_size
17 | kpnet_conv_size = hp.gen.kpnet_conv_size
18 |
19 | self.res_stack = nn.ModuleList()
20 | hop_length = 1
21 | for stride in hp.gen.strides:
22 | hop_length = stride * hop_length
23 | self.res_stack.append(
24 | LVCBlock(
25 | channel_size,
26 | hp.audio.n_mel_channels,
27 | stride=stride,
28 | dilations=hp.gen.dilations,
29 | lReLU_slope=hp.gen.lReLU_slope,
30 | cond_hop_length=hop_length,
31 | kpnet_conv_size=kpnet_conv_size
32 | )
33 | )
34 |
35 | self.conv_pre = \
36 | nn.utils.weight_norm(nn.Conv1d(hp.gen.noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
37 |
38 | self.conv_post = nn.Sequential(
39 | nn.LeakyReLU(hp.gen.lReLU_slope),
40 | nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
41 | nn.Tanh(),
42 | )
43 |
44 | def forward(self, c, z):
45 | '''
46 | Args:
47 | c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
48 | z (Tensor): the noise sequence (batch, noise_dim, in_length)
49 |
50 | '''
51 | z = self.conv_pre(z) # (B, c_g, L)
52 |
53 | for res_block in self.res_stack:
54 | res_block.to(z.device)
55 | z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
56 |
57 | z = self.conv_post(z) # (B, 1, L * 256)
58 |
59 | return z
60 |
61 | def eval(self, inference=False):
62 | super(Generator, self).eval()
63 | # don't remove weight norm while validation in training loop
64 | if inference:
65 | self.remove_weight_norm()
66 |
67 | def remove_weight_norm(self):
68 | print('Removing weight norm...')
69 |
70 | nn.utils.remove_weight_norm(self.conv_pre)
71 |
72 | for layer in self.conv_post:
73 | if len(layer.state_dict()) != 0:
74 | nn.utils.remove_weight_norm(layer)
75 |
76 | for res_block in self.res_stack:
77 | res_block.remove_weight_norm()
78 |
79 | def inference(self, c, z=None):
80 | # pad input mel with zeros to cut artifact
81 | # see https://github.com/seungwonpark/melgan/issues/8
82 | zero = torch.full((1, self.mel_channel, 10), -11.5129).to(c.device)
83 | mel = torch.cat((c, zero), dim=2)
84 |
85 | if z is None:
86 | z = torch.randn(1, self.noise_dim, mel.size(2)).to(mel.device)
87 |
88 | audio = self.forward(mel, z)
89 | audio = audio.squeeze() # collapse all dimensions except the time axis
90 | audio = audio[:-(self.hop_length*10)]
91 | audio = MAX_WAV_VALUE * audio
92 | audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
93 | audio = audio.short()
94 |
95 | return audio
96 |
97 | if __name__ == '__main__':
98 | hp = OmegaConf.load('../config/default_c32.yaml')
99 | model = Generator(hp)
100 |
101 | c = torch.randn(3, 100, 10)
102 | z = torch.randn(3, 64, 10)
103 | print(c.shape)
104 |
105 | y = model(c, z)
106 | print(y.shape)
107 | assert y.shape == torch.Size([3, 1, 2560])
108 |
109 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
110 | print(pytorch_total_params)
111 |
--------------------------------------------------------------------------------
/model/lvcnet.py:
--------------------------------------------------------------------------------
1 | """ refer from https://github.com/zceng/LVCNet """
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | class KernelPredictor(torch.nn.Module):
8 | ''' Kernel predictor for the location-variable convolutions'''
9 | def __init__(
10 | self,
11 | cond_channels,
12 | conv_in_channels,
13 | conv_out_channels,
14 | conv_layers,
15 | conv_kernel_size=3,
16 | kpnet_hidden_channels=64,
17 | kpnet_conv_size=3,
18 | kpnet_dropout=0.0,
19 | kpnet_nonlinear_activation="LeakyReLU",
20 | kpnet_nonlinear_activation_params={"negative_slope":0.1},
21 | ):
22 | '''
23 | Args:
24 | cond_channels (int): number of channel for the conditioning sequence,
25 | conv_in_channels (int): number of channel for the input sequence,
26 | conv_out_channels (int): number of channel for the output sequence,
27 | conv_layers (int): number of layers
28 | '''
29 | super().__init__()
30 |
31 | self.conv_in_channels = conv_in_channels
32 | self.conv_out_channels = conv_out_channels
33 | self.conv_kernel_size = conv_kernel_size
34 | self.conv_layers = conv_layers
35 |
36 | kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
37 | kpnet_bias_channels = conv_out_channels * conv_layers # l_b
38 |
39 | self.input_conv = nn.Sequential(
40 | nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
41 | getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
42 | )
43 |
44 | self.residual_convs = nn.ModuleList()
45 | padding = (kpnet_conv_size - 1) // 2
46 | for _ in range(3):
47 | self.residual_convs.append(
48 | nn.Sequential(
49 | nn.Dropout(kpnet_dropout),
50 | nn.utils.weight_norm(nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True)),
51 | getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
52 | nn.utils.weight_norm(nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True)),
53 | getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
54 | )
55 | )
56 | self.kernel_conv = nn.utils.weight_norm(
57 | nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True))
58 | self.bias_conv = nn.utils.weight_norm(
59 | nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True))
60 |
61 | def forward(self, c):
62 | '''
63 | Args:
64 | c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
65 | '''
66 | batch, _, cond_length = c.shape
67 | c = self.input_conv(c)
68 | for residual_conv in self.residual_convs:
69 | residual_conv.to(c.device)
70 | c = c + residual_conv(c)
71 | k = self.kernel_conv(c)
72 | b = self.bias_conv(c)
73 | kernels = k.contiguous().view(
74 | batch,
75 | self.conv_layers,
76 | self.conv_in_channels,
77 | self.conv_out_channels,
78 | self.conv_kernel_size,
79 | cond_length,
80 | )
81 | bias = b.contiguous().view(
82 | batch,
83 | self.conv_layers,
84 | self.conv_out_channels,
85 | cond_length,
86 | )
87 |
88 | return kernels, bias
89 |
90 | def remove_weight_norm(self):
91 | nn.utils.remove_weight_norm(self.input_conv[0])
92 | nn.utils.remove_weight_norm(self.kernel_conv)
93 | nn.utils.remove_weight_norm(self.bias_conv)
94 | for block in self.residual_convs:
95 | nn.utils.remove_weight_norm(block[1])
96 | nn.utils.remove_weight_norm(block[3])
97 |
98 | class LVCBlock(torch.nn.Module):
99 | '''the location-variable convolutions'''
100 | def __init__(
101 | self,
102 | in_channels,
103 | cond_channels,
104 | stride,
105 | dilations=[1, 3, 9, 27],
106 | lReLU_slope=0.2,
107 | conv_kernel_size=3,
108 | cond_hop_length=256,
109 | kpnet_hidden_channels=64,
110 | kpnet_conv_size=3,
111 | kpnet_dropout=0.0,
112 | ):
113 | super().__init__()
114 |
115 | self.cond_hop_length = cond_hop_length
116 | self.conv_layers = len(dilations)
117 | self.conv_kernel_size = conv_kernel_size
118 |
119 | self.kernel_predictor = KernelPredictor(
120 | cond_channels=cond_channels,
121 | conv_in_channels=in_channels,
122 | conv_out_channels=2 * in_channels,
123 | conv_layers=len(dilations),
124 | conv_kernel_size=conv_kernel_size,
125 | kpnet_hidden_channels=kpnet_hidden_channels,
126 | kpnet_conv_size=kpnet_conv_size,
127 | kpnet_dropout=kpnet_dropout,
128 | kpnet_nonlinear_activation_params={"negative_slope":lReLU_slope}
129 | )
130 |
131 | self.convt_pre = nn.Sequential(
132 | nn.LeakyReLU(lReLU_slope),
133 | nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride, padding=stride // 2 + stride % 2, output_padding=stride % 2)),
134 | )
135 |
136 | self.conv_blocks = nn.ModuleList()
137 | for dilation in dilations:
138 | self.conv_blocks.append(
139 | nn.Sequential(
140 | nn.LeakyReLU(lReLU_slope),
141 | nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)),
142 | nn.LeakyReLU(lReLU_slope),
143 | )
144 | )
145 |
146 | def forward(self, x, c):
147 | ''' forward propagation of the location-variable convolutions.
148 | Args:
149 | x (Tensor): the input sequence (batch, in_channels, in_length)
150 | c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
151 |
152 | Returns:
153 | Tensor: the output sequence (batch, in_channels, in_length)
154 | '''
155 | _, in_channels, _ = x.shape # (B, c_g, L')
156 |
157 | x = self.convt_pre(x) # (B, c_g, stride * L')
158 | kernels, bias = self.kernel_predictor(c)
159 |
160 | for i, conv in enumerate(self.conv_blocks):
161 | output = conv(x) # (B, c_g, stride * L')
162 |
163 | k = kernels[:, i, :, :, :, :] # (B, c_g, 2 * c_g, kernel_size, cond_length)
164 | b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
165 |
166 | output = self.location_variable_convolution(output, k, b, hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC
167 | x = x + torch.sigmoid(output[ :, :in_channels, :]) * torch.tanh(output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU
168 |
169 | return x
170 |
171 | def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
172 | ''' perform the location-variable convolution operation on the input sequence (x) using the local convolution kernel.
173 | Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
174 | Args:
175 | x (Tensor): the input sequence (batch, in_channels, in_length).
176 | kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
177 | bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
178 | dilation (int): the dilation of convolution.
179 | hop_size (int): the hop_size of the conditioning sequence.
180 | Returns:
181 | (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
182 | '''
183 | batch, _, in_length = x.shape
184 | batch, _, out_channels, kernel_size, kernel_length = kernel.shape
185 | assert in_length == (kernel_length * hop_size), "the lengths of x and kernel do not match"
186 |
187 | padding = dilation * int((kernel_size - 1) / 2)
188 | x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
189 | x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
190 |
191 | if hop_size < dilation:
192 | x = F.pad(x, (0, dilation), 'constant', 0)
193 | x = x.unfold(3, dilation, dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
194 | x = x[:, :, :, :, :hop_size]
195 | x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
196 | x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
197 |
198 | o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
199 | o = o.to(memory_format=torch.channels_last_3d)
200 | bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
201 | o = o + bias
202 | o = o.contiguous().view(batch, out_channels, -1)
203 |
204 | return o
205 |
206 | def remove_weight_norm(self):
207 | self.kernel_predictor.remove_weight_norm()
208 | nn.utils.remove_weight_norm(self.convt_pre[1])
209 | for block in self.conv_blocks:
210 | nn.utils.remove_weight_norm(block[1])
--------------------------------------------------------------------------------
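A small, illustrative shape check for the location-variable convolution above (not from the repository). The kernel predictor emits one kernel per conditioning frame, and each kernel is applied to the cond_hop_length samples that frame covers, so the transposed-convolution output length must equal cond_length * cond_hop_length. The channel and stride values here are arbitrary.

import torch
from model.lvcnet import LVCBlock

block = LVCBlock(in_channels=32, cond_channels=100, stride=8, cond_hop_length=8)

x = torch.randn(2, 32, 16)    # (B, in_channels, L)
c = torch.randn(2, 100, 16)   # (B, cond_channels, cond_length)

y = block(x, c)
print(y.shape)                # torch.Size([2, 32, 128]); L is upsampled by stride=8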
/model/mpd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import weight_norm, spectral_norm
5 |
6 | class DiscriminatorP(nn.Module):
7 | def __init__(self, hp, period):
8 | super(DiscriminatorP, self).__init__()
9 |
10 | self.LRELU_SLOPE = hp.mpd.lReLU_slope
11 | self.period = period
12 |
13 | kernel_size = hp.mpd.kernel_size
14 | stride = hp.mpd.stride
15 | norm_f = spectral_norm if hp.mpd.use_spectral_norm else weight_norm
16 |
17 | self.convs = nn.ModuleList([
18 | norm_f(nn.Conv2d(1, 64, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
19 | norm_f(nn.Conv2d(64, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
20 | norm_f(nn.Conv2d(128, 256, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
21 | norm_f(nn.Conv2d(256, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
22 | norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), 1, padding=(kernel_size // 2, 0))),
23 | ])
24 | self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
25 |
26 | def forward(self, x):
27 | fmap = []
28 |
29 | # 1d to 2d
30 | b, c, t = x.shape
31 | if t % self.period != 0: # pad first
32 | n_pad = self.period - (t % self.period)
33 | x = F.pad(x, (0, n_pad), "reflect")
34 | t = t + n_pad
35 | x = x.view(b, c, t // self.period, self.period)
36 |
37 | for l in self.convs:
38 | x = l(x)
39 | x = F.leaky_relu(x, self.LRELU_SLOPE)
40 | fmap.append(x)
41 | x = self.conv_post(x)
42 | fmap.append(x)
43 | x = torch.flatten(x, 1, -1)
44 |
45 | return fmap, x
46 |
47 |
48 | class MultiPeriodDiscriminator(nn.Module):
49 | def __init__(self, hp):
50 | super(MultiPeriodDiscriminator, self).__init__()
51 |
52 | self.discriminators = nn.ModuleList(
53 | [DiscriminatorP(hp, period) for period in hp.mpd.periods]
54 | )
55 |
56 | def forward(self, x):
57 | ret = list()
58 | for disc in self.discriminators:
59 | ret.append(disc(x))
60 |
61 | return ret # [(feat, score), (feat, score), (feat, score), (feat, score), (feat, score)]
62 |
--------------------------------------------------------------------------------
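Illustrative usage of the multi-period discriminator above (not from the repository): each sub-discriminator folds the 1-D waveform into a (time // period, period) grid before the 2-D convolutions. The hyperparameter values below are assumptions for the example, not the repository's config defaults.

import torch
from omegaconf import OmegaConf
from model.mpd import MultiPeriodDiscriminator

hp = OmegaConf.create({'mpd': {
    'lReLU_slope': 0.2,            # assumed values, for illustration only
    'kernel_size': 5,
    'stride': 3,
    'use_spectral_norm': False,
    'periods': [2, 3, 5, 7, 11],
}})

mpd = MultiPeriodDiscriminator(hp)
audio = torch.randn(1, 1, 8192)    # (B, 1, T)

for fmap, score in mpd(audio):
    print(len(fmap), score.shape)  # 6 feature maps and one flattened score per period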
/model/mrd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import weight_norm, spectral_norm
5 |
6 | class DiscriminatorR(torch.nn.Module):
7 | def __init__(self, hp, resolution):
8 | super(DiscriminatorR, self).__init__()
9 |
10 | self.resolution = resolution
11 | self.LRELU_SLOPE = hp.mpd.lReLU_slope
12 |
13 | norm_f = spectral_norm if hp.mrd.use_spectral_norm else weight_norm
14 |
15 | self.convs = nn.ModuleList([
16 | norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
17 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
18 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
19 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
20 | norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
21 | ])
22 | self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
23 |
24 | def forward(self, x):
25 | fmap = []
26 |
27 | x = self.spectrogram(x)
28 | x = x.unsqueeze(1)
29 | for l in self.convs:
30 | x = l(x)
31 | x = F.leaky_relu(x, self.LRELU_SLOPE)
32 | fmap.append(x)
33 | x = self.conv_post(x)
34 | fmap.append(x)
35 | x = torch.flatten(x, 1, -1)
36 |
37 | return fmap, x
38 |
39 | def spectrogram(self, x):
40 | n_fft, hop_length, win_length = self.resolution
41 | x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
42 | x = x.squeeze(1)
43 | x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False) #[B, F, TT, 2]
44 | mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
45 |
46 | return mag
47 |
48 |
49 | class MultiResolutionDiscriminator(torch.nn.Module):
50 | def __init__(self, hp):
51 | super(MultiResolutionDiscriminator, self).__init__()
52 | self.resolutions = eval(hp.mrd.resolutions)
53 | self.discriminators = nn.ModuleList(
54 | [DiscriminatorR(hp, resolution) for resolution in self.resolutions]
55 | )
56 |
57 | def forward(self, x):
58 | ret = list()
59 | for disc in self.discriminators:
60 | ret.append(disc(x))
61 |
62 | return ret # [(feat, score), (feat, score), (feat, score)]
63 |
--------------------------------------------------------------------------------
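Illustrative usage of the multi-resolution discriminator above (not from the repository): each sub-discriminator scores a magnitude spectrogram computed with its own (n_fft, hop_length, win_length). Note that resolutions is stored as a string and parsed with eval(), and that DiscriminatorR reads its LeakyReLU slope from hp.mpd. The values below are assumptions for the example; run with the pinned torch==1.6.0, whose torch.stft returns the real/imaginary pair expected by spectrogram().

import torch
from omegaconf import OmegaConf
from model.mrd import MultiResolutionDiscriminator

hp = OmegaConf.create({
    'mpd': {'lReLU_slope': 0.2},   # DiscriminatorR reads the slope from hp.mpd
    'mrd': {
        'use_spectral_norm': False,
        'resolutions': '[(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]',  # assumed triples
    },
})

mrd = MultiResolutionDiscriminator(hp)
audio = torch.randn(1, 1, 16384)   # (B, 1, T)

for fmap, score in mrd(audio):
    print(len(fmap), score.shape)  # 6 feature maps and one flattened score per resolution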
/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa==0.8.1
2 | matplotlib==3.1.3
3 | numpy==1.17.4
4 | scipy==1.5.4
5 | torch==1.6.0
6 | tensorboard==2.3.0
7 | tqdm==4.61.2
8 | omegaconf==2.0.6
9 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import argparse
5 | import torch
6 | import torch.multiprocessing as mp
7 | from omegaconf import OmegaConf
8 |
9 | from utils.train import train
10 |
11 | torch.backends.cudnn.benchmark = True
12 |
13 |
14 | if __name__ == '__main__':
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('-c', '--config', type=str, required=True,
17 | help="yaml file for configuration")
18 | parser.add_argument('-p', '--checkpoint_path', type=str, default=None,
19 | help="path of checkpoint pt file to resume training")
20 | parser.add_argument('-n', '--name', type=str, required=True,
21 | help="name of the model for logging, saving checkpoint")
22 | args = parser.parse_args()
23 |
24 | hp = OmegaConf.load(args.config)
25 | with open(args.config, 'r') as f:
26 | hp_str = ''.join(f.readlines())
27 |
28 | assert hp.audio.hop_length == 256, \
29 | 'hp.audio.hop_length must be equal to 256, got %d' % hp.audio.hop_length
30 |
31 | args.num_gpus = 0
32 | torch.manual_seed(hp.train.seed)
33 | if torch.cuda.is_available():
34 | torch.cuda.manual_seed(hp.train.seed)
35 | args.num_gpus = torch.cuda.device_count()
36 | print('Batch size per GPU :', hp.train.batch_size)
37 | else:
38 | pass
39 |
40 | if args.num_gpus > 1:
41 | mp.spawn(train, nprocs=args.num_gpus,
42 | args=(args, args.checkpoint_path, hp, hp_str,))
43 | else:
44 | train(0, args, args.checkpoint_path, hp, hp_str)
45 |
--------------------------------------------------------------------------------
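Given a YAML configuration that satisfies the hop-length assert above, the entry point might be launched as `python trainer.py -c <path/to/config.yaml> -n <run_name>`, optionally adding `-p <path/to/checkpoint.pt>` to resume; the flags mirror the argparse definitions above and the paths are placeholders. With more than one visible GPU, training is spawned once per device via torch.multiprocessing.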
/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use("Agg")
3 | import matplotlib.pylab as plt
4 | import numpy as np
5 |
6 |
7 | def save_figure_to_numpy(fig):
8 | # save it to a numpy array.
9 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
11 | data = np.transpose(data, (2, 0, 1))
12 | return data
13 |
14 |
15 | def plot_waveform_to_numpy(waveform):
16 | fig, ax = plt.subplots(figsize=(12, 4))
17 | # draw the waveform as a thin blue line
18 | ax.plot(range(len(waveform)), waveform,
19 | linewidth=0.1, alpha=0.7, color='blue')
20 |
21 | plt.xlabel("Samples")
22 | plt.ylabel("Amplitude")
23 | plt.ylim(-1, 1)
24 | plt.tight_layout()
25 |
26 | fig.canvas.draw()
27 | data = save_figure_to_numpy(fig)
28 | plt.close()
29 |
30 | return data
31 |
32 |
33 | def plot_spectrogram_to_numpy(spectrogram):
34 | fig, ax = plt.subplots(figsize=(12, 4))
35 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
36 | interpolation='none')
37 | plt.colorbar(im, ax=ax)
38 | plt.xlabel("Frames")
39 | plt.ylabel("Channels")
40 | plt.tight_layout()
41 |
42 | fig.canvas.draw()
43 | data = save_figure_to_numpy(fig)
44 | plt.close()
45 | return data
46 |
--------------------------------------------------------------------------------
/utils/stft.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020 Jungil Kong
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import math
24 | import os
25 | import random
26 | import torch
27 | import torch.utils.data
28 | import numpy as np
29 | from librosa.util import normalize
30 | from scipy.io.wavfile import read
31 | from librosa.filters import mel as librosa_mel_fn
32 |
33 |
34 | class TacotronSTFT(torch.nn.Module):
35 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
36 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
37 | mel_fmax=None, center=False, device='cpu'):
38 | super(TacotronSTFT, self).__init__()
39 | self.n_mel_channels = n_mel_channels
40 | self.sampling_rate = sampling_rate
41 | self.n_fft = filter_length
42 | self.hop_size = hop_length
43 | self.win_size = win_length
44 | self.fmin = mel_fmin
45 | self.fmax = mel_fmax
46 | self.center = center
47 |
48 | mel = librosa_mel_fn(
49 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
50 |
51 | mel_basis = torch.from_numpy(mel).float().to(device)
52 | hann_window = torch.hann_window(win_length).to(device)
53 |
54 | self.register_buffer('mel_basis', mel_basis)
55 | self.register_buffer('hann_window', hann_window)
56 |
57 | def linear_spectrogram(self, y):
58 | assert (torch.min(y.data) >= -1)
59 | assert (torch.max(y.data) <= 1)
60 |
61 | y = torch.nn.functional.pad(y.unsqueeze(1),
62 | (int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)),
63 | mode='reflect')
64 | y = y.squeeze(1)
65 | spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
66 | center=self.center, pad_mode='reflect', normalized=False, onesided=True)
67 | spec = torch.norm(spec, p=2, dim=-1)
68 |
69 | return spec
70 |
71 | def mel_spectrogram(self, y):
72 | """Computes mel-spectrograms from a batch of waves
73 | PARAMS
74 | ------
75 | y (torch.FloatTensor): waveform batch of shape (B, T) in range [-1, 1]
76 |
77 | RETURNS
78 | -------
79 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
80 | """
81 | assert(torch.min(y.data) >= -1)
82 | assert(torch.max(y.data) <= 1)
83 |
84 | y = torch.nn.functional.pad(y.unsqueeze(1),
85 | (int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)),
86 | mode='reflect')
87 | y = y.squeeze(1)
88 |
89 | spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
90 | center=self.center, pad_mode='reflect', normalized=False, onesided=True)
91 |
92 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
93 |
94 | spec = torch.matmul(self.mel_basis, spec)
95 | spec = self.spectral_normalize_torch(spec)
96 |
97 | return spec
98 |
99 | def spectral_normalize_torch(self, magnitudes):
100 | output = self.dynamic_range_compression_torch(magnitudes)
101 | return output
102 |
103 | def dynamic_range_compression_torch(self, x, C=1, clip_val=1e-5):
104 | return torch.log(torch.clamp(x, min=clip_val) * C)
105 |
--------------------------------------------------------------------------------
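A minimal sketch, not part of the repository, of turning a wav file into the mel conditioning used elsewhere in the code, reusing TacotronSTFT and read_wav_np. The constructor defaults above are used as-is, the wav path is a placeholder, and the file's sample rate should match the configured sampling_rate.

import torch
from utils.stft import TacotronSTFT
from utils.utils import read_wav_np

stft = TacotronSTFT()                        # defaults: 1024/256/1024, 80 mels, 22050 Hz

sr, wav = read_wav_np('example.wav')         # placeholder path; wav comes back as float32 in [-1, 1]
wav = torch.from_numpy(wav).unsqueeze(0)     # (1, T)

mel = stft.mel_spectrogram(wav)              # (1, n_mel_channels, frames), log-compressed
print(sr, mel.shape)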
/utils/stft_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2019 Tomoki Hayashi
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """STFT-based Loss modules."""
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 |
12 | def stft(x, fft_size, hop_size, win_length, window):
13 | """Perform STFT and convert to magnitude spectrogram.
14 | Args:
15 | x (Tensor): Input signal tensor (B, T).
16 | fft_size (int): FFT size.
17 | hop_size (int): Hop size.
18 | win_length (int): Window length.
19 | window (str): Window function type.
20 | Returns:
21 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
22 | """
23 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
24 | real = x_stft[..., 0]
25 | imag = x_stft[..., 1]
26 |
27 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
28 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
29 |
30 |
31 | class SpectralConvergengeLoss(torch.nn.Module):
32 | """Spectral convergence loss module."""
33 |
34 | def __init__(self):
35 | """Initilize spectral convergence loss module."""
36 | super(SpectralConvergengeLoss, self).__init__()
37 |
38 | def forward(self, x_mag, y_mag):
39 | """Calculate forward propagation.
40 | Args:
41 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
42 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
43 | Returns:
44 | Tensor: Spectral convergence loss value.
45 | """
46 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
47 |
48 |
49 | class LogSTFTMagnitudeLoss(torch.nn.Module):
50 | """Log STFT magnitude loss module."""
51 |
52 | def __init__(self):
53 | """Initilize los STFT magnitude loss module."""
54 | super(LogSTFTMagnitudeLoss, self).__init__()
55 |
56 | def forward(self, x_mag, y_mag):
57 | """Calculate forward propagation.
58 | Args:
59 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
60 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
61 | Returns:
62 | Tensor: Log STFT magnitude loss value.
63 | """
64 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
65 |
66 |
67 | class STFTLoss(torch.nn.Module):
68 | """STFT loss module."""
69 |
70 | def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
71 | """Initialize STFT loss module."""
72 | super(STFTLoss, self).__init__()
73 | self.fft_size = fft_size
74 | self.shift_size = shift_size
75 | self.win_length = win_length
76 | self.window = getattr(torch, window)(win_length).to(device)
77 | self.spectral_convergenge_loss = SpectralConvergengeLoss()
78 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
79 |
80 | def forward(self, x, y):
81 | """Calculate forward propagation.
82 | Args:
83 | x (Tensor): Predicted signal (B, T).
84 | y (Tensor): Groundtruth signal (B, T).
85 | Returns:
86 | Tensor: Spectral convergence loss value.
87 | Tensor: Log STFT magnitude loss value.
88 | """
89 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
90 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
91 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
92 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
93 |
94 | return sc_loss, mag_loss
95 |
96 |
97 | class MultiResolutionSTFTLoss(torch.nn.Module):
98 | """Multi resolution STFT loss module."""
99 |
100 | def __init__(self,
101 | device,
102 | resolutions,
103 | window="hann_window"):
104 | """Initialize Multi resolution STFT loss module.
105 | Args:
106 | resolutions (list): List of (FFT size, hop size, window length).
107 | window (str): Window function type.
108 | """
109 | super(MultiResolutionSTFTLoss, self).__init__()
110 | self.stft_losses = torch.nn.ModuleList()
111 | for fs, ss, wl in resolutions:
112 | self.stft_losses += [STFTLoss(device, fs, ss, wl, window)]
113 |
114 | def forward(self, x, y):
115 | """Calculate forward propagation.
116 | Args:
117 | x (Tensor): Predicted signal (B, T).
118 | y (Tensor): Groundtruth signal (B, T).
119 | Returns:
120 | Tensor: Multi resolution spectral convergence loss value.
121 | Tensor: Multi resolution log STFT magnitude loss value.
122 | """
123 | sc_loss = 0.0
124 | mag_loss = 0.0
125 | for f in self.stft_losses:
126 | sc_l, mag_l = f(x, y)
127 | sc_loss += sc_l
128 | mag_loss += mag_l
129 |
130 | sc_loss /= len(self.stft_losses)
131 | mag_loss /= len(self.stft_losses)
132 |
133 | return sc_loss, mag_loss
134 |
--------------------------------------------------------------------------------
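Illustrative usage of the multi-resolution STFT loss above (not from the repository): the loss compares predicted and ground-truth waveforms at several analysis resolutions and averages the spectral-convergence and log-magnitude terms. The resolution triples below are assumptions for the example; in training they come from eval(hp.mrd.resolutions).

import torch
from utils.stft_loss import MultiResolutionSTFTLoss

device = torch.device('cpu')
resolutions = [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]   # assumed (n_fft, hop, win) triples
criterion = MultiResolutionSTFTLoss(device, resolutions)

fake = torch.randn(2, 16384)   # predicted waveform (B, T)
real = torch.randn(2, 16384)   # ground-truth waveform (B, T)

sc_loss, mag_loss = criterion(fake, real)
print(sc_loss.item(), mag_loss.item())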
/utils/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import math
5 | import tqdm
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.distributed import init_process_group
10 | from torch.nn.parallel import DistributedDataParallel
11 | import itertools
12 | import traceback
13 |
14 | from datasets.dataloader import create_dataloader
15 | from utils.writer import MyWriter
16 | from utils.stft import TacotronSTFT
17 | from utils.stft_loss import MultiResolutionSTFTLoss
18 | from model.generator import Generator
19 | from model.discriminator import Discriminator
20 | from .utils import get_commit_hash
21 | from .validation import validate
22 |
23 |
24 | def train(rank, args, chkpt_path, hp, hp_str):
25 |
26 | if args.num_gpus > 1:
27 | init_process_group(backend=hp.dist_config.dist_backend, init_method=hp.dist_config.dist_url,
28 | world_size=hp.dist_config.world_size * args.num_gpus, rank=rank)
29 |
30 | torch.cuda.manual_seed(hp.train.seed)
31 | device = torch.device('cuda:{:d}'.format(rank))
32 |
33 | model_g = Generator(hp).to(device)
34 | model_d = Discriminator(hp).to(device)
35 |
36 | optim_g = torch.optim.AdamW(model_g.parameters(),
37 | lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
38 | optim_d = torch.optim.AdamW(model_d.parameters(),
39 | lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
40 |
41 | githash = get_commit_hash()
42 |
43 | init_epoch = -1
44 | step = 0
45 |
46 | # define logger, writer, valloader, stft at rank_zero
47 | if rank == 0:
48 | pt_dir = os.path.join(hp.log.chkpt_dir, args.name)
49 | log_dir = os.path.join(hp.log.log_dir, args.name)
50 | os.makedirs(pt_dir, exist_ok=True)
51 | os.makedirs(log_dir, exist_ok=True)
52 |
53 | logging.basicConfig(
54 | level=logging.INFO,
55 | format='%(asctime)s - %(levelname)s - %(message)s',
56 | handlers=[
57 | logging.FileHandler(os.path.join(log_dir, '%s-%d.log' % (args.name, time.time()))),
58 | logging.StreamHandler()
59 | ]
60 | )
61 | logger = logging.getLogger()
62 | writer = MyWriter(hp, log_dir)
63 | valloader = create_dataloader(hp, args, False, device='cpu')
64 | stft = TacotronSTFT(filter_length=hp.audio.filter_length,
65 | hop_length=hp.audio.hop_length,
66 | win_length=hp.audio.win_length,
67 | n_mel_channels=hp.audio.n_mel_channels,
68 | sampling_rate=hp.audio.sampling_rate,
69 | mel_fmin=hp.audio.mel_fmin,
70 | mel_fmax=hp.audio.mel_fmax,
71 | center=False,
72 | device=device)
73 |
74 | if chkpt_path is not None:
75 | if rank == 0:
76 | logger.info("Resuming from checkpoint: %s" % chkpt_path)
77 | checkpoint = torch.load(chkpt_path)
78 | model_g.load_state_dict(checkpoint['model_g'])
79 | model_d.load_state_dict(checkpoint['model_d'])
80 | optim_g.load_state_dict(checkpoint['optim_g'])
81 | optim_d.load_state_dict(checkpoint['optim_d'])
82 | step = checkpoint['step']
83 | init_epoch = checkpoint['epoch']
84 |
85 | if rank == 0:
86 | if hp_str != checkpoint['hp_str']:
87 | logger.warning("New hparams is different from checkpoint. Will use new.")
88 |
89 | if githash != checkpoint['githash']:
90 | logger.warning("Code might be different: git hash is different.")
91 | logger.warning("%s -> %s" % (checkpoint['githash'], githash))
92 |
93 | else:
94 | if rank == 0:
95 | logger.info("Starting new training run.")
96 |
97 | if args.num_gpus > 1:
98 | model_g = DistributedDataParallel(model_g, device_ids=[rank]).to(device)
99 | model_d = DistributedDataParallel(model_d, device_ids=[rank]).to(device)
100 |
101 | # cudnn benchmarking accelerates training when the minibatch size stays constant;
102 | # if the size varies between iterations, it will slow training down considerably.
103 | torch.backends.cudnn.benchmark = True
104 |
105 | trainloader = create_dataloader(hp, args, True, device='cpu')
106 |
107 | model_g.train()
108 | model_d.train()
109 |
110 | resolutions = eval(hp.mrd.resolutions)
111 | stft_criterion = MultiResolutionSTFTLoss(device, resolutions)
112 |
113 | for epoch in itertools.count(init_epoch+1):
114 |
115 | if rank == 0 and epoch % hp.log.validation_interval == 0:
116 | with torch.no_grad():
117 | validate(hp, args, model_g, model_d, valloader, stft, writer, step, device)
118 |
119 | trainloader.dataset.shuffle_mapping()
120 | if rank == 0:
121 | loader = tqdm.tqdm(trainloader, desc='Loading train data')
122 | else:
123 | loader = trainloader
124 |
125 | for mel, audio in loader:
126 |
127 | mel = mel.to(device)
128 | audio = audio.to(device)
129 | noise = torch.randn(hp.train.batch_size, hp.gen.noise_dim, mel.size(2)).to(device)
130 |
131 | # generator
132 | optim_g.zero_grad()
133 | fake_audio = model_g(mel, noise)
134 |
135 | # Multi-Resolution STFT Loss
136 | sc_loss, mag_loss = stft_criterion(fake_audio.squeeze(1), audio.squeeze(1))
137 | stft_loss = (sc_loss + mag_loss) * hp.train.stft_lamb
138 |
139 | res_fake, period_fake = model_d(fake_audio)
140 |
141 | score_loss = 0.0
142 |
143 | for (_, score_fake) in res_fake + period_fake:
144 | score_loss += torch.mean(torch.pow(score_fake - 1.0, 2))
145 |
146 | score_loss = score_loss / len(res_fake + period_fake)
147 |
148 | loss_g = score_loss + stft_loss
149 |
150 | loss_g.backward()
151 | optim_g.step()
152 |
153 | # discriminator
154 |
155 | optim_d.zero_grad()
156 | res_fake, period_fake = model_d(fake_audio.detach())
157 | res_real, period_real = model_d(audio)
158 |
159 | loss_d = 0.0
160 | for (_, score_fake), (_, score_real) in zip(res_fake + period_fake, res_real + period_real):
161 | loss_d += torch.mean(torch.pow(score_real - 1.0, 2))
162 | loss_d += torch.mean(torch.pow(score_fake, 2))
163 |
164 | loss_d = loss_d / len(res_fake + period_fake)
165 |
166 | loss_d.backward()
167 | optim_d.step()
168 |
169 | step += 1
170 | # logging
171 | loss_g = loss_g.item()
172 | loss_d = loss_d.item()
173 |
174 | if rank == 0 and step % hp.log.summary_interval == 0:
175 | writer.log_training(loss_g, loss_d, stft_loss.item(), score_loss.item(), step)
176 | loader.set_description("g %.04f d %.04f | step %d" % (loss_g, loss_d, step))
177 |
178 | if rank == 0 and epoch % hp.log.save_interval == 0:
179 | save_path = os.path.join(pt_dir, '%s_%04d.pt'
180 | % (args.name, epoch))
181 | torch.save({
182 | 'model_g': (model_g.module if args.num_gpus > 1 else model_g).state_dict(),
183 | 'model_d': (model_d.module if args.num_gpus > 1 else model_d).state_dict(),
184 | 'optim_g': optim_g.state_dict(),
185 | 'optim_d': optim_d.state_dict(),
186 | 'step': step,
187 | 'epoch': epoch,
188 | 'hp_str': hp_str,
189 | 'githash': githash,
190 | }, save_path)
191 | logger.info("Saved checkpoint to: %s" % save_path)
192 |
--------------------------------------------------------------------------------
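For reference, the loop in utils/train.py above optimizes a least-squares GAN objective: the generator minimizes the mean over all sub-discriminators of (D(x_fake) - 1)^2 plus hp.train.stft_lamb times the summed multi-resolution STFT terms (sc_loss + mag_loss), while the discriminator minimizes the mean of (D(x_real) - 1)^2 + D(x_fake)^2 with the generator output detached.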
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import subprocess
3 | import numpy as np
4 | from scipy.io.wavfile import read
5 |
6 |
7 | def get_commit_hash():
8 | message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
9 | return message.strip().decode('utf-8')
10 |
11 | def read_wav_np(path):
12 | sr, wav = read(path)
13 |
14 | if len(wav.shape) == 2:
15 | wav = wav[:, 0]
16 |
17 | if wav.dtype == np.int16:
18 | wav = wav / 32768.0
19 | elif wav.dtype == np.int32:
20 | wav = wav / 2147483648.0
21 | elif wav.dtype == np.uint8:
22 | wav = (wav - 128) / 128.0
23 |
24 | wav = wav.astype(np.float32)
25 |
26 | return sr, wav
27 |
--------------------------------------------------------------------------------
/utils/validation.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def validate(hp, args, generator, discriminator, valloader, stft, writer, step, device):
7 | generator.eval()
8 | discriminator.eval()
9 | torch.backends.cudnn.benchmark = False
10 |
11 | loader = tqdm.tqdm(valloader, desc='Validation loop')
12 | mel_loss = 0.0
13 | for idx, (mel, audio) in enumerate(loader):
14 | mel = mel.to(device)
15 | audio = audio.to(device)
16 | noise = torch.randn(1, hp.gen.noise_dim, mel.size(2)).to(device)
17 |
18 | fake_audio = generator(mel, noise)[:,:,:audio.size(2)]
19 |
20 | mel_fake = stft.mel_spectrogram(fake_audio.squeeze(1))
21 | mel_real = stft.mel_spectrogram(audio.squeeze(1))
22 |
23 | mel_loss += F.l1_loss(mel_fake, mel_real).item()
24 |
25 | if idx < hp.log.num_audio:
26 | spec_fake = stft.linear_spectrogram(fake_audio.squeeze(1))
27 | spec_real = stft.linear_spectrogram(audio.squeeze(1))
28 |
29 | audio = audio[0][0].cpu().detach().numpy()
30 | fake_audio = fake_audio[0][0].cpu().detach().numpy()
31 | spec_fake = spec_fake[0].cpu().detach().numpy()
32 | spec_real = spec_real[0].cpu().detach().numpy()
33 | writer.log_fig_audio(audio, fake_audio, spec_fake, spec_real, idx, step)
34 |
35 | mel_loss = mel_loss / len(valloader.dataset)
36 |
37 | writer.log_validation(mel_loss, generator, discriminator, step)
38 |
39 | torch.backends.cudnn.benchmark = True
40 |
--------------------------------------------------------------------------------
/utils/writer.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | import numpy as np
3 | import librosa
4 |
5 | from .plotting import plot_waveform_to_numpy, plot_spectrogram_to_numpy
6 |
7 | class MyWriter(SummaryWriter):
8 | def __init__(self, hp, logdir):
9 | super(MyWriter, self).__init__(logdir)
10 | self.sample_rate = hp.audio.sampling_rate
11 | self.is_first = True
12 |
13 | def log_training(self, g_loss, d_loss, stft_loss, score_loss, step):
14 | self.add_scalar('train/g_loss', g_loss, step)
15 | self.add_scalar('train/d_loss', d_loss, step)
16 |
17 | self.add_scalar('train/score_loss', score_loss, step)
18 | self.add_scalar('train/stft_loss', stft_loss, step)
19 |
20 | def log_validation(self, mel_loss, generator, discriminator, step):
21 | self.add_scalar('validation/mel_loss', mel_loss, step)
22 |
23 | self.log_histogram(generator, step)
24 | self.log_histogram(discriminator, step)
25 | if self.is_first:
26 | self.is_first = False
27 |
28 | def log_fig_audio(self, target, prediction, spec_fake, spec_real, idx, step):
29 | if idx == 0:
30 | spec_fake = librosa.amplitude_to_db(spec_fake, ref=np.max,top_db=80.)
31 | spec_real = librosa.amplitude_to_db(spec_real, ref=np.max,top_db=80.)
32 | self.add_image('spec/predicted', plot_spectrogram_to_numpy(spec_fake), step)
33 | self.add_image('spec/error', plot_spectrogram_to_numpy(np.power(spec_real - spec_fake, 2)), step)
34 | self.add_image('waveform/predicted', plot_waveform_to_numpy(prediction), step)
35 | if self.is_first:
36 | self.add_image('spec/target', plot_spectrogram_to_numpy(spec_real), step)
37 | self.add_image('waveform/target', plot_waveform_to_numpy(target), step)
38 |
39 | self.add_audio('predicted/raw_audio_%d' % idx, prediction, step, self.sample_rate)
40 | if self.is_first:
41 | self.add_audio('target/raw_audio_%d' % idx, target, step, self.sample_rate)
42 |
43 |
44 | def log_histogram(self, model, step):
45 | for tag, value in model.named_parameters():
46 | self.add_histogram(tag.replace('.', '/'), value.cpu().detach().numpy(), step)
47 |
--------------------------------------------------------------------------------