├── .gitignore ├── LICENSE ├── README.md ├── config ├── default_c16.yaml └── default_c32.yaml ├── datasets ├── dataloader.py └── metadata │ ├── libritts_train_clean_360_audiopath_text_sid_train.txt │ └── libritts_train_clean_360_audiopath_text_sid_val.txt ├── docs ├── index.html ├── loss.png ├── model_architecture.png └── samples │ ├── seen │ ├── c16 │ │ ├── 2004_147967_000029_000002.wav │ │ ├── 337_126286_000008_000000.wav │ │ ├── 3537_5704_000008_000005.wav │ │ ├── 5319_84357_000005_000004.wav │ │ ├── 6294_86679_000035_000004.wav │ │ └── 949_134657_000002_000005.wav │ ├── c32 │ │ ├── 2004_147967_000029_000002.wav │ │ ├── 337_126286_000008_000000.wav │ │ ├── 3537_5704_000008_000005.wav │ │ ├── 5319_84357_000005_000004.wav │ │ ├── 6294_86679_000035_000004.wav │ │ └── 949_134657_000002_000005.wav │ ├── ground_truth │ │ ├── 2004_147967_000029_000002.wav │ │ ├── 337_126286_000008_000000.wav │ │ ├── 3537_5704_000008_000005.wav │ │ ├── 5319_84357_000005_000004.wav │ │ ├── 6294_86679_000035_000004.wav │ │ └── 949_134657_000002_000005.wav │ ├── official_c16 │ │ ├── 2004_147967_000029_000002.wav │ │ ├── 337_126286_000008_000000.wav │ │ ├── 3537_5704_000008_000005.wav │ │ ├── 5319_84357_000005_000004.wav │ │ └── 6294_86679_000035_000004.wav │ └── official_c32 │ │ ├── 2004_147967_000029_000002.wav │ │ ├── 337_126286_000008_000000.wav │ │ ├── 3537_5704_000008_000005.wav │ │ ├── 5319_84357_000005_000004.wav │ │ └── 6294_86679_000035_000004.wav │ └── unseen │ ├── c16 │ ├── 1089_134686_000007_000005.wav │ ├── 3575_170457_000037_000002.wav │ ├── 4507_16021_000029_000005.wav │ ├── 7021_85628_000037_000000.wav │ ├── 7176_92135_000006_000005.wav │ └── 8224_274384_000016_000000.wav │ ├── c32 │ ├── 1089_134686_000007_000005.wav │ ├── 3575_170457_000037_000002.wav │ ├── 4507_16021_000029_000005.wav │ ├── 7021_85628_000037_000000.wav │ ├── 7176_92135_000006_000005.wav │ └── 8224_274384_000016_000000.wav │ ├── ground_truth │ ├── 1089_134686_000007_000005.wav │ ├── 3575_170457_000037_000002.wav │ ├── 4507_16021_000029_000005.wav │ ├── 7021_85628_000037_000000.wav │ ├── 7176_92135_000006_000005.wav │ └── 8224_274384_000016_000000.wav │ ├── official_c16 │ ├── 1089_134686_000007_000005.wav │ ├── 3575_170457_000037_000002.wav │ ├── 7021_85628_000037_000000.wav │ ├── 7176_92135_000006_000005.wav │ └── 8224_274384_000016_000000.wav │ └── official_c32 │ ├── 1089_134686_000007_000005.wav │ ├── 3575_170457_000037_000002.wav │ ├── 7021_85628_000037_000000.wav │ ├── 7176_92135_000006_000005.wav │ └── 8224_274384_000016_000000.wav ├── inference.py ├── model ├── discriminator.py ├── generator.py ├── lvcnet.py ├── mpd.py └── mrd.py ├── requirements.txt ├── trainer.py └── utils ├── plotting.py ├── stft.py ├── stft_loss.py ├── train.py ├── utils.py ├── validation.py └── writer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE configuration 2 | .idea/ 3 | 4 | # configuration 5 | config/* 6 | !config/default.yaml 7 | temp-restore.yaml 8 | 9 | # logs, checkpoints 10 | chkpt/ 11 | logs/ 12 | 13 | # just a temporary folder 14 | temp/ 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are 
written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Seungwon Park 박승원 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UnivNet 2 | **UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation** 3 | 4 | This is an unofficial PyTorch implementation of ***Jang et al.* (Kakao), [UnivNet](https://arxiv.org/abs/2106.07889)**. 
5 |
6 | Audio samples are uploaded!
7 |
8 | [![arXiv](https://img.shields.io/badge/arXiv-2106.07889-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2106.07889) [![githubio](https://img.shields.io/static/v1?message=Audio%20Samples&logo=Github&labelColor=grey&color=blue&logoColor=white&label=%20&style=flat-square)](https://mindslab-ai.github.io/univnet/) [![GitHub](https://img.shields.io/github/license/mindslab-ai/univnet?style=flat-square)](./LICENSE)
9 |
10 | ## Notes
11 |
12 | **Both UnivNet-c16 and c32 results and the pre-trained weights have been uploaded.**
13 |
14 | **For both models, our implementation matches the objective scores (PESQ and RMSE) of the original paper.**
15 |
16 | ## Key Features
17 |
18 |
19 |
20 | - According to the authors of the paper, UnivNet obtained the best objective results among the recent GAN-based neural vocoders (including HiFi-GAN) and also outperformed HiFi-GAN in a subjective evaluation. Its inference speed is 1.5 times faster than HiFi-GAN's.
21 |
22 | - This repository uses the same mel-spectrogram function as the [Official HiFi-GAN](https://github.com/jik876/hifi-gan), which is compatible with [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2).
23 |
24 | - Our default mel calculation hyperparameters are as below, following the original paper.
25 |
26 | ```yaml
27 | audio:
28 | n_mel_channels: 100
29 | filter_length: 1024
30 | hop_length: 256 # WARNING: this can't be changed.
31 | win_length: 1024
32 | sampling_rate: 24000
33 | mel_fmin: 0.0
34 | mel_fmax: 12000.0
35 | ```
36 |
37 | You can modify the hyperparameters to be compatible with your acoustic model.
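For reference, the snippet below is a minimal sketch of how a mel-spectrogram can be computed and cached with these settings, mirroring the calls used in `datasets/dataloader.py` (`utils.stft.TacotronSTFT` and `utils.utils.read_wav_np`). The input/output paths are placeholders, and passing `device='cpu'` to `TacotronSTFT` is an assumption (the dataloader passes the training device); tensors saved this way follow the same `*.mel` format the dataloader caches and `inference.py` reads.

```python
import torch
from utils.stft import TacotronSTFT   # the repo's HiFi-GAN-compatible mel function
from utils.utils import read_wav_np

# Values mirror the default `audio` config above; change them together with the config.
stft = TacotronSTFT(1024, 256, 1024,   # filter_length, hop_length, win_length
                    100, 24000,        # n_mel_channels, sampling_rate
                    0.0, 12000.0,      # mel_fmin, mel_fmax
                    center=False, device='cpu')  # device='cpu' is an assumption

sr, wav = read_wav_np('path/to/audio.wav')        # hypothetical input file
assert sr == 24000, 'resample the audio to the configured sampling_rate first'

wav = torch.from_numpy(wav).unsqueeze(0)          # (1, num_samples)
mel = stft.mel_spectrogram(wav).squeeze(0)        # (n_mel_channels, frames)
torch.save(mel, 'path/to/audio.mel')              # same .mel format used for caching/inference
```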
38 |
39 | ## Prerequisites
40 |
41 | The implementation needs the following dependencies.
42 |
43 | 0. Python 3.6
44 | 1. [PyTorch](https://pytorch.org/) 1.6.0
45 | 2. [NumPy](https://numpy.org/) 1.17.4 and [SciPy](https://www.scipy.org/) 1.5.4
46 | 3. Install other dependencies in [requirements.txt](./requirements.txt).
47 | ```bash
48 | pip install -r requirements.txt
49 | ```
50 |
51 | ## Datasets
52 |
53 | **Preparing Data**
54 |
55 | - Download the training dataset. Any wav files with a sampling rate of 24,000 Hz can be used. The original paper used LibriTTS.
56 | - LibriTTS train-clean-360 split [tar.gz link](https://www.openslr.org/resources/60/train-clean-360.tar.gz)
57 | - Unzip and place its contents under `datasets/LibriTTS/train-clean-360`.
58 | - If you want to use wav files with a different sampling rate, please edit the configuration file (see below).
59 |
60 | Note: The mel-spectrogram calculated from each audio file is saved as `**.mel` the first time it is needed and loaded from disk afterwards.
61 |
62 | **Preparing Metadata**
63 |
64 | Following the format from [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2), the metadata should be formatted as:
65 | ```
66 | path_to_wav|transcript|speaker_id
67 | path_to_wav|transcript|speaker_id
68 | ...
69 | ```
70 |
71 | Train/validation metadata for the LibriTTS train-clean-360 split are already prepared in `datasets/metadata`.
72 | 5% of the train-clean-360 utterances were randomly sampled for validation.
73 |
74 | Since this model is a vocoder, the transcripts are **NOT** used during training.
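If you train on your own 24 kHz data, a metadata file in this format can be generated with a short script. The sketch below is only an illustration under assumptions: the corpus folder, output file names, dummy transcript, and speaker-id convention are placeholders; the only requirements are three `|`-separated fields per line and paths resolvable relative to `data.train_dir`/`data.val_dir`.

```python
import glob
import os
import random

data_root = 'datasets/'   # must match data.train_dir / data.val_dir in the config
wav_paths = sorted(glob.glob(os.path.join(data_root, 'my_corpus', '**', '*.wav'), recursive=True))

random.seed(1234)
random.shuffle(wav_paths)
n_val = max(1, len(wav_paths) // 20)   # hold out roughly 5% for validation

def write_meta(paths, out_path):
    with open(out_path, 'w', encoding='utf-8') as f:
        for p in paths:
            rel = os.path.relpath(p, data_root)         # stored relative to data_root
            spk = os.path.basename(os.path.dirname(p))  # placeholder: parent folder as speaker id
            f.write(f'{rel}|dummy transcript|{spk}\n')  # transcripts are ignored by the vocoder

os.makedirs(os.path.join(data_root, 'metadata'), exist_ok=True)
write_meta(wav_paths[n_val:], os.path.join(data_root, 'metadata', 'my_train.txt'))
write_meta(wav_paths[:n_val], os.path.join(data_root, 'metadata', 'my_val.txt'))
```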
75 |
76 | ## Train
77 |
78 | **Preparing Configuration Files**
79 |
80 | - Run `cp config/default_c32.yaml config/config.yaml` and then edit `config.yaml`.
81 |
82 | - Write down the root path of the train/validation data in the `data` section. The data loader parses the list of files within the path recursively.
83 |
84 | ```yaml
85 | data:
86 | train_dir: 'datasets/' # root path of train data (either relative/absolute path is ok)
87 | train_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_train.txt' # relative path of metadata file from train_dir
88 | val_dir: 'datasets/' # root path of validation data
89 | val_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_val.txt' # relative path of metadata file from val_dir
90 | ```
91 |
92 | We provide the default metadata for the LibriTTS train-clean-360 split.
93 |
94 | - Modify `channel_size` in `gen` to switch between UnivNet-c16 and c32 (a quick way to verify the resulting model size is shown below).
95 |
96 | ```yaml
97 | gen:
98 | noise_dim: 64
99 | channel_size: 32 # 32 or 16
100 | dilations: [1, 3, 9, 27]
101 | strides: [8, 8, 4]
102 | lReLU_slope: 0.2
103 | ```
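As a quick sanity check of which variant your config selects, you can instantiate the generator and count its parameters. The sketch below is adapted from the `__main__` block of `model/generator.py`; it assumes you run it from the repository root and have already created `config/config.yaml` as described above. Roughly 4M parameters corresponds to c16 and roughly 14M to c32 (see the Results table below).

```python
import torch
from omegaconf import OmegaConf
from model.generator import Generator

hp = OmegaConf.load('config/config.yaml')   # the config prepared above
model = Generator(hp)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'channel_size={hp.gen.channel_size}: {n_params / 1e6:.2f}M generator parameters')

# Optional shape check, as in model/generator.py's __main__ block:
c = torch.randn(3, hp.audio.n_mel_channels, 10)  # (batch, mel_channels, frames)
z = torch.randn(3, hp.gen.noise_dim, 10)         # (batch, noise_dim, frames)
print(model(c, z).shape)                         # torch.Size([3, 1, 2560]) with the default strides
```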
104 |
105 | **Training**
106 |
107 | ```bash
108 | python trainer.py -c CONFIG_YAML_FILE -n NAME_OF_THE_RUN
109 | ```
110 |
111 | **Tensorboard**
112 |
113 | ```bash
114 | tensorboard --logdir logs/
115 | ```
116 |
117 | If you are running tensorboard on a remote machine, you can open the tensorboard page by adding the `--bind_all` option.
118 |
119 | ## Inference
120 |
121 | ```bash
122 | python inference.py -p CHECKPOINT_PATH -i INPUT_MEL_PATH -o OUTPUT_WAV_PATH
123 | ```
124 |
125 | ## Pre-trained Model
126 |
127 | You can download the pre-trained models from the Google Drive links below. The models were trained on the LibriTTS train-clean-360 split.
128 | - **UnivNet-c16: [Google Drive](https://drive.google.com/file/d/1Iqw9T0rRklLsg-6aayNk6NlsLVHfuftv/view?usp=sharing)**
129 | - **UnivNet-c32: [Google Drive](https://drive.google.com/file/d/1QZFprpvYEhLWCDF90gSl6Dpn0gonS_Rv/view?usp=sharing)**
130 |
131 | ## Results
132 |
133 | See audio samples at https://mindslab-ai.github.io/univnet/
134 |
135 | We evaluated our models on the validation set.
136 |
137 | | Model | PESQ(↑) | RMSE(↓) | Model Size |
138 | | -------------------- | --------- | --------- | ---------- |
139 | | HiFi-GAN v1 | 3.54 | 0.423 | 14.01M |
140 | | Official UnivNet-c16 | 3.59 | 0.337 | 4.00M |
141 | | **Our UnivNet-c16** | **3.60** | **0.317** | **4.00M** |
142 | | Official UnivNet-c32 | 3.70 | 0.316 | 14.86M |
143 | | **Our UnivNet-c32** | **3.68** | **0.304** | **14.87M** |
144 |
145 | The loss graphs of UnivNet are shown below.
146 |
147 | The orange and blue graphs indicate c16 and c32, respectively.
148 |
149 |
150 |
151 | ## Implementation Authors
152 |
153 | Implementation authors are:
154 |
155 | - [Kang-wook Kim](http://github.com/wookladin) @ [MINDsLab Inc.](https://maum.ai/) (full324@snu.ac.kr, kwkim@mindslab.ai)
156 | - [Wonbin Jung](https://github.com/Wonbin-Jung) @ [MINDsLab Inc.](https://maum.ai/) (santabin@kaist.ac.kr, wbjung@mindslab.ai)
157 |
158 | Contributors are:
159 |
160 | - [Kuan Chen](https://github.com/azraelkuan)
161 |
162 | Special thanks to
163 |
164 | - [Seungu Han](https://github.com/Seungwoo0326) @ [MINDsLab Inc.](https://maum.ai/)
165 | - [Junhyeok Lee](https://github.com/junjun3518) @ [MINDsLab Inc.](https://maum.ai/)
166 | - [Sang Hoon Woo](https://github.com/tonyswoo) @ [MINDsLab Inc.](https://maum.ai/)
167 |
168 | ## License
169 |
170 | This code is licensed under the BSD 3-Clause License.
171 |
172 | We referred to the following code and repositories.
173 |
174 | - The overall structure of the repository is based on [https://github.com/seungwonpark/melgan](https://github.com/seungwonpark/melgan).
175 | - [datasets/dataloader.py](./datasets/dataloader.py) from https://github.com/NVIDIA/waveglow (BSD 3-Clause License)
176 | - [model/mpd.py](./model/mpd.py) from https://github.com/jik876/hifi-gan (MIT License)
177 | - [model/lvcnet.py](./model/lvcnet.py) from https://github.com/zceng/LVCNet (Apache License 2.0)
178 | - [utils/stft_loss.py](./utils/stft_loss.py) # Copyright 2019 Tomoki Hayashi # MIT License (https://opensource.org/licenses/MIT)
179 |
180 | ## References
181 |
182 | Papers
183 |
184 | - *Jang et al.*, [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889)
185 | - *Zeng et al.*, [LVCNet: Efficient Condition-Dependent Modeling Network for Waveform Generation](https://arxiv.org/abs/2102.10815)
186 | - *Kong et al.*, [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646)
187 |
188 | Datasets
189 |
190 | - [LibriTTS](https://openslr.org/60/)
191 |
-------------------------------------------------------------------------------- /config/default_c16.yaml: --------------------------------------------------------------------------------
1 | data:
2 | train_dir: 'datasets/' # root path of train data (either relative/absolute path is ok)
3 | train_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_train.txt' # relative path of metadata file from train_dir
4 | val_dir: 'datasets/' # root path of validation data
5 | val_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_val.txt' # relative path of metadata file from val_dir
6 | #############################
7 | train:
8 | num_workers: 8
9 | batch_size: 32
10 | optimizer: 'adam'
11 | seed: 1234
12 | adam:
13 | lr: 0.0001
14 | beta1: 0.5
15 | beta2: 0.9
16 | stft_lamb: 2.5
17 | spk_balanced: False # Using balanced sampling for each speaker
18 | #############################
19 | audio:
20 | n_mel_channels: 100
21 | segment_length: 16384 # Should be a multiple of 256
22 | pad_short: 2000
23 | filter_length: 1024
24 | hop_length: 256 # WARNING: this can't be changed.
25 | win_length: 1024
26 | sampling_rate: 24000
27 | mel_fmin: 0.0
28 | mel_fmax: 12000.0
29 | #############################
30 | gen:
31 | noise_dim: 64
32 | channel_size: 16
33 | dilations: [1, 3, 9, 27]
34 | strides: [8, 8, 4]
35 | lReLU_slope: 0.2
36 | kpnet_conv_size: 3
37 | #############################
38 | mpd:
39 | periods: [2,3,5,7,11]
40 | kernel_size: 5
41 | stride: 3
42 | use_spectral_norm: False
43 | lReLU_slope: 0.2
44 | #############################
45 | mrd:
46 | resolutions: "[(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]" # (filter_length, hop_length, win_length)
47 | use_spectral_norm: False
48 | lReLU_slope: 0.2
49 | #############################
50 | dist_config:
51 | dist_backend: "nccl"
52 | dist_url: "tcp://localhost:54321"
53 | world_size: 1
54 | #############################
55 | log:
56 | summary_interval: 1
57 | validation_interval: 1
58 | save_interval: 1
59 | num_audio: 5
60 | chkpt_dir: 'chkpt'
61 | log_dir: 'logs'
62 |
-------------------------------------------------------------------------------- /config/default_c32.yaml: --------------------------------------------------------------------------------
1 | data:
2 | train_dir: 'datasets/' # root path of train data (either relative/absolute path is ok)
3 | train_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_train.txt' # relative path of metadata file from train_dir
4 | val_dir: 'datasets/' # root path of validation data
5 | val_meta: 'metadata/libritts_train_clean_360_audiopath_text_sid_val.txt' # relative path of metadata file from val_dir
6 | #############################
7 | train:
8 | num_workers: 8
9 | batch_size: 32
10 | optimizer: 'adam'
11 | seed: 1234
12 | adam:
13 | lr: 0.0001
14 | beta1: 0.5
15 | beta2: 0.9
16 | stft_lamb: 2.5
17 | spk_balanced: False # Using balanced sampling for each speaker
18 | #############################
19 | audio:
20 | n_mel_channels: 100
21 | segment_length: 16384 # Should be a multiple of 256
22 | pad_short: 2000
23 | filter_length: 1024
24 | hop_length: 256 # WARNING: this can't be changed.
25 | win_length: 1024 26 | sampling_rate: 24000 27 | mel_fmin: 0.0 28 | mel_fmax: 12000.0 29 | ############################# 30 | gen: 31 | noise_dim: 64 32 | channel_size: 32 33 | dilations: [1, 3, 9, 27] 34 | strides: [8, 8, 4] 35 | lReLU_slope: 0.2 36 | kpnet_conv_size: 3 37 | ############################# 38 | mpd: 39 | periods: [2,3,5,7,11] 40 | kernel_size: 5 41 | stride: 3 42 | use_spectral_norm: False 43 | lReLU_slope: 0.2 44 | ############################# 45 | mrd: 46 | resolutions: "[(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]" # (filter_length, hop_length, win_length) 47 | use_spectral_norm: False 48 | lReLU_slope: 0.2 49 | ############################# 50 | dist_config: 51 | dist_backend: "nccl" 52 | dist_url: "tcp://localhost:54321" 53 | world_size: 1 54 | ############################# 55 | log: 56 | summary_interval: 1 57 | validation_interval: 1 58 | save_interval: 1 59 | num_audio: 5 60 | chkpt_dir: 'chkpt' 61 | log_dir: 'logs' 62 | -------------------------------------------------------------------------------- /datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import random 5 | import numpy as np 6 | from torch.utils.data import DistributedSampler, DataLoader, Dataset 7 | from collections import Counter 8 | 9 | from utils.utils import read_wav_np 10 | from utils.stft import TacotronSTFT 11 | 12 | 13 | def create_dataloader(hp, args, train, device): 14 | if train: 15 | dataset = MelFromDisk(hp, hp.data.train_dir, hp.data.train_meta, args, train, device) 16 | return DataLoader(dataset=dataset, batch_size=hp.train.batch_size, shuffle=False, 17 | num_workers=hp.train.num_workers, pin_memory=True, drop_last=True) 18 | 19 | else: 20 | dataset = MelFromDisk(hp, hp.data.val_dir, hp.data.val_meta, args, train, device) 21 | return DataLoader(dataset=dataset, batch_size=1, shuffle=False, 22 | num_workers=hp.train.num_workers, pin_memory=True, drop_last=False) 23 | 24 | 25 | class MelFromDisk(Dataset): 26 | def __init__(self, hp, data_dir, metadata_path, args, train, device): 27 | random.seed(hp.train.seed) 28 | self.hp = hp 29 | self.args = args 30 | self.train = train 31 | self.data_dir = data_dir 32 | metadata_path = os.path.join(data_dir, metadata_path) 33 | self.meta = self.load_metadata(metadata_path) 34 | self.stft = TacotronSTFT(hp.audio.filter_length, hp.audio.hop_length, hp.audio.win_length, 35 | hp.audio.n_mel_channels, hp.audio.sampling_rate, 36 | hp.audio.mel_fmin, hp.audio.mel_fmax, center=False, device=device) 37 | 38 | self.mel_segment_length = hp.audio.segment_length // hp.audio.hop_length 39 | self.shuffle = hp.train.spk_balanced 40 | 41 | if train and hp.train.spk_balanced: 42 | # balanced sampling for each speaker 43 | speaker_counter = Counter((spk_id \ 44 | for audiopath, text, spk_id in self.meta)) 45 | weights = [1.0 / speaker_counter[spk_id] \ 46 | for audiopath, text, spk_id in self.meta] 47 | 48 | self.mapping_weights = torch.DoubleTensor(weights) 49 | 50 | elif train: 51 | weights = [1.0 / len(self.meta) for _, _, _ in self.meta] 52 | self.mapping_weights = torch.DoubleTensor(weights) 53 | 54 | 55 | def __len__(self): 56 | return len(self.meta) 57 | 58 | def __getitem__(self, idx): 59 | if self.train: 60 | idx = torch.multinomial(self.mapping_weights, 1).item() 61 | return self.my_getitem(idx) 62 | else: 63 | return self.my_getitem(idx) 64 | 65 | def shuffle_mapping(self): 66 | random.shuffle(self.mapping_weights) 67 | 68 | def 
my_getitem(self, idx): 69 | wavpath, _, _ = self.meta[idx] 70 | wavpath = os.path.join(self.data_dir, wavpath) 71 | sr, audio = read_wav_np(wavpath) 72 | 73 | if len(audio) < self.hp.audio.segment_length + self.hp.audio.pad_short: 74 | audio = np.pad(audio, (0, self.hp.audio.segment_length + self.hp.audio.pad_short - len(audio)), \ 75 | mode='constant', constant_values=0.0) 76 | 77 | audio = torch.from_numpy(audio).unsqueeze(0) 78 | mel = self.get_mel(wavpath) 79 | 80 | if self.train: 81 | max_mel_start = mel.size(1) - self.mel_segment_length -1 82 | mel_start = random.randint(0, max_mel_start) 83 | mel_end = mel_start + self.mel_segment_length 84 | mel = mel[:, mel_start:mel_end] 85 | 86 | audio_start = mel_start * self.hp.audio.hop_length 87 | audio_len = self.hp.audio.segment_length 88 | audio = audio[:, audio_start:audio_start + audio_len] 89 | 90 | return mel, audio 91 | 92 | def get_mel(self, wavpath): 93 | melpath = wavpath.replace('.wav', '.mel') 94 | try: 95 | mel = torch.load(melpath, map_location='cpu') 96 | assert mel.size(0) == self.hp.audio.n_mel_channels, \ 97 | 'Mel dimension mismatch: expected %d, got %d' % \ 98 | (self.hp.audio.n_mel_channels, mel.size(0)) 99 | 100 | except (FileNotFoundError, RuntimeError, TypeError, AssertionError): 101 | sr, wav = read_wav_np(wavpath) 102 | assert sr == self.hp.audio.sampling_rate, \ 103 | 'sample mismatch: expected %d, got %d at %s' % (self.hp.audio.sampling_rate, sr, wavpath) 104 | 105 | if len(wav) < self.hp.audio.segment_length + self.hp.audio.pad_short: 106 | wav = np.pad(wav, (0, self.hp.audio.segment_length + self.hp.audio.pad_short - len(wav)), \ 107 | mode='constant', constant_values=0.0) 108 | 109 | wav = torch.from_numpy(wav).unsqueeze(0) 110 | mel = self.stft.mel_spectrogram(wav) 111 | 112 | mel = mel.squeeze(0) 113 | 114 | torch.save(mel, melpath) 115 | 116 | return mel 117 | 118 | def load_metadata(self, path, split="|"): 119 | metadata = [] 120 | with open(path, 'r', encoding='utf-8') as f: 121 | for line in f: 122 | stripped = line.strip().split(split) 123 | metadata.append(stripped) 124 | 125 | return metadata -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 13 | 14 | 15 | Audio Samples from Unofficial Implementation of UnivNet vocoder 16 | 17 | 18 |

Audio Samples from Unofficial Implementation of UnivNet vocoder

19 |

Paper: Jang et al., UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation

20 |

Implementation authors: Kang-wook Kim, Wonbin Jung @MINDsLab Inc.

21 |

GitHub repository: mindslab-ai/univnet 22 | 23 |

24 |

Please check the GitHub repository for the implementation details and the pre-trained model checkpoint.

25 |

The ground truth audio samples from the official demo page have a different volume from the original dataset. We believe the official ground truth audio samples have been post-processed. Therefore, our samples might differ in volume from the official ones.

26 |

Notes: Both UnivNet-c16 and c32 results and the pre-trained weights have been uploaded. Our implementation matches the objective scores (PESQ and RMSE) of the original paper.

27 |
28 |

Synthesis for seen speakers

29 |
  • LibriTTS/train-clean-360 dataset was used to both train and evaluate the model.
  • 30 |
  • The first five texts are taken directly from the official demo page for comparison.
  • 31 |
    32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 |
    Text

    Nobody could touch the body until the coroner came.

    Then they came to another grove of trees, where all the leaves were of gold; and afterwards to a third, where the leaves were all glittering diamonds.

    The shaven face of the priest is a further item to the same effect.

    The lady who had spoken English approached the table as if looking for something, and when Beaton looked again, the portrait had been turned on its face.

    He shouted, 'Where art thou, ring?' And the ring said, 'I am here,' though it was on the bed of the ocean.

    The monarch, after making some inquiry into the rank and character of his rival, despatched the informer with a present of a pair of purple slippers, to complete the magnificence of his Imperial habit.

    GT
    Official UnivNet-c16
    -
    Official UnivNet-c32
    -
    Our UnivNet-c16
    Our UnivNet-c32
    101 |
    102 | 103 |

    Synthesis for unseen speakers

    104 |
  • LibriTTS/test-clean dataset was used to evaluate the model.
  • 105 |
  • The first five texts are taken directly from the official demo page for comparison.
  • 106 |
    107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 |
    Text

    The stars began to crumble and a cloud of fine stardust fell through space.

    On August 27, 1837, she writes:--

    But when his big brother heard that he had refused to give his cap for a King's golden crown, he said that Anders was a stupid.

    The only thing is, I don't know anything about technique and stagecraft and the three unities and that sort of rot.

    The marquis of Worcester, a man past eighty-four, was the last in England that submitted to the authority of the parliament.

    True history being a mixture of all things, the true historian mingles in everything.

    GT
    Official UnivNet-c16
    -
    Official UnivNet-c32
    -
    Our UnivNet-c16
    Our UnivNet-c32
    176 | 177 | -------------------------------------------------------------------------------- /docs/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/loss.png -------------------------------------------------------------------------------- /docs/model_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/model_architecture.png -------------------------------------------------------------------------------- /docs/samples/seen/c16/2004_147967_000029_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/2004_147967_000029_000002.wav -------------------------------------------------------------------------------- /docs/samples/seen/c16/337_126286_000008_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/337_126286_000008_000000.wav -------------------------------------------------------------------------------- /docs/samples/seen/c16/3537_5704_000008_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/3537_5704_000008_000005.wav -------------------------------------------------------------------------------- /docs/samples/seen/c16/5319_84357_000005_000004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/5319_84357_000005_000004.wav -------------------------------------------------------------------------------- /docs/samples/seen/c16/6294_86679_000035_000004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/6294_86679_000035_000004.wav -------------------------------------------------------------------------------- /docs/samples/seen/c16/949_134657_000002_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c16/949_134657_000002_000005.wav -------------------------------------------------------------------------------- /docs/samples/seen/c32/2004_147967_000029_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/2004_147967_000029_000002.wav -------------------------------------------------------------------------------- /docs/samples/seen/c32/337_126286_000008_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/337_126286_000008_000000.wav -------------------------------------------------------------------------------- 
/docs/samples/seen/c32/3537_5704_000008_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/3537_5704_000008_000005.wav -------------------------------------------------------------------------------- /docs/samples/seen/c32/5319_84357_000005_000004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/5319_84357_000005_000004.wav -------------------------------------------------------------------------------- /docs/samples/seen/c32/6294_86679_000035_000004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/6294_86679_000035_000004.wav -------------------------------------------------------------------------------- /docs/samples/seen/c32/949_134657_000002_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/c32/949_134657_000002_000005.wav -------------------------------------------------------------------------------- /docs/samples/seen/ground_truth/2004_147967_000029_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/2004_147967_000029_000002.wav -------------------------------------------------------------------------------- /docs/samples/seen/ground_truth/337_126286_000008_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/337_126286_000008_000000.wav -------------------------------------------------------------------------------- /docs/samples/seen/ground_truth/3537_5704_000008_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/3537_5704_000008_000005.wav -------------------------------------------------------------------------------- /docs/samples/seen/ground_truth/5319_84357_000005_000004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/5319_84357_000005_000004.wav -------------------------------------------------------------------------------- /docs/samples/seen/ground_truth/6294_86679_000035_000004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/6294_86679_000035_000004.wav -------------------------------------------------------------------------------- /docs/samples/seen/ground_truth/949_134657_000002_000005.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/ground_truth/949_134657_000002_000005.wav -------------------------------------------------------------------------------- /docs/samples/seen/official_c16/2004_147967_000029_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c16/2004_147967_000029_000002.wav -------------------------------------------------------------------------------- /docs/samples/seen/official_c16/337_126286_000008_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c16/337_126286_000008_000000.wav -------------------------------------------------------------------------------- /docs/samples/seen/official_c16/3537_5704_000008_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c16/3537_5704_000008_000005.wav -------------------------------------------------------------------------------- /docs/samples/seen/official_c16/5319_84357_000005_000004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c16/5319_84357_000005_000004.wav -------------------------------------------------------------------------------- /docs/samples/seen/official_c16/6294_86679_000035_000004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c16/6294_86679_000035_000004.wav -------------------------------------------------------------------------------- /docs/samples/seen/official_c32/2004_147967_000029_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c32/2004_147967_000029_000002.wav -------------------------------------------------------------------------------- /docs/samples/seen/official_c32/337_126286_000008_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c32/337_126286_000008_000000.wav -------------------------------------------------------------------------------- /docs/samples/seen/official_c32/3537_5704_000008_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c32/3537_5704_000008_000005.wav -------------------------------------------------------------------------------- /docs/samples/seen/official_c32/5319_84357_000005_000004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c32/5319_84357_000005_000004.wav 
-------------------------------------------------------------------------------- /docs/samples/seen/official_c32/6294_86679_000035_000004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/seen/official_c32/6294_86679_000035_000004.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c16/1089_134686_000007_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/1089_134686_000007_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c16/3575_170457_000037_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/3575_170457_000037_000002.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c16/4507_16021_000029_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/4507_16021_000029_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c16/7021_85628_000037_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/7021_85628_000037_000000.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c16/7176_92135_000006_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/7176_92135_000006_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c16/8224_274384_000016_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c16/8224_274384_000016_000000.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c32/1089_134686_000007_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/1089_134686_000007_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c32/3575_170457_000037_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/3575_170457_000037_000002.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c32/4507_16021_000029_000005.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/4507_16021_000029_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c32/7021_85628_000037_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/7021_85628_000037_000000.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c32/7176_92135_000006_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/7176_92135_000006_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/c32/8224_274384_000016_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/c32/8224_274384_000016_000000.wav -------------------------------------------------------------------------------- /docs/samples/unseen/ground_truth/1089_134686_000007_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/1089_134686_000007_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/ground_truth/3575_170457_000037_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/3575_170457_000037_000002.wav -------------------------------------------------------------------------------- /docs/samples/unseen/ground_truth/4507_16021_000029_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/4507_16021_000029_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/ground_truth/7021_85628_000037_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/7021_85628_000037_000000.wav -------------------------------------------------------------------------------- /docs/samples/unseen/ground_truth/7176_92135_000006_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/7176_92135_000006_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/ground_truth/8224_274384_000016_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/ground_truth/8224_274384_000016_000000.wav 
-------------------------------------------------------------------------------- /docs/samples/unseen/official_c16/1089_134686_000007_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c16/1089_134686_000007_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/official_c16/3575_170457_000037_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c16/3575_170457_000037_000002.wav -------------------------------------------------------------------------------- /docs/samples/unseen/official_c16/7021_85628_000037_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c16/7021_85628_000037_000000.wav -------------------------------------------------------------------------------- /docs/samples/unseen/official_c16/7176_92135_000006_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c16/7176_92135_000006_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/official_c16/8224_274384_000016_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c16/8224_274384_000016_000000.wav -------------------------------------------------------------------------------- /docs/samples/unseen/official_c32/1089_134686_000007_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c32/1089_134686_000007_000005.wav -------------------------------------------------------------------------------- /docs/samples/unseen/official_c32/3575_170457_000037_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c32/3575_170457_000037_000002.wav -------------------------------------------------------------------------------- /docs/samples/unseen/official_c32/7021_85628_000037_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c32/7021_85628_000037_000000.wav -------------------------------------------------------------------------------- /docs/samples/unseen/official_c32/7176_92135_000006_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c32/7176_92135_000006_000005.wav -------------------------------------------------------------------------------- 
/docs/samples/unseen/official_c32/8224_274384_000016_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maum-ai/univnet/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/docs/samples/unseen/official_c32/8224_274384_000016_000000.wav -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import tqdm 4 | import torch 5 | import argparse 6 | from scipy.io.wavfile import write 7 | from omegaconf import OmegaConf 8 | 9 | from model.generator import Generator 10 | 11 | 12 | def main(args): 13 | checkpoint = torch.load(args.checkpoint_path) 14 | if args.config is not None: 15 | hp = OmegaConf.load(args.config) 16 | else: 17 | hp = OmegaConf.create(checkpoint['hp_str']) 18 | 19 | model = Generator(hp).cuda() 20 | saved_state_dict = checkpoint['model_g'] 21 | new_state_dict = {} 22 | 23 | for k, v in saved_state_dict.items(): 24 | try: 25 | new_state_dict[k] = saved_state_dict['module.' + k] 26 | except: 27 | new_state_dict[k] = v 28 | model.load_state_dict(new_state_dict) 29 | model.eval(inference=True) 30 | 31 | with torch.no_grad(): 32 | for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))): 33 | mel = torch.load(melpath) 34 | if len(mel.shape) == 2: 35 | mel = mel.unsqueeze(0) 36 | mel = mel.cuda() 37 | 38 | audio = model.inference(mel) 39 | audio = audio.cpu().detach().numpy() 40 | 41 | if args.output_folder is None: # if output folder is not defined, audio samples are saved in input folder 42 | out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch']) 43 | else: 44 | basename = os.path.basename(melpath) 45 | basename = basename.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch']) 46 | out_path = os.path.join(args.output_folder, basename) 47 | write(out_path, hp.audio.sampling_rate, audio) 48 | 49 | 50 | if __name__ == '__main__': 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument('-c', '--config', type=str, default=None, 53 | help="yaml file for config. 
will use hp_str from checkpoint if not given.") 54 | parser.add_argument('-p', '--checkpoint_path', type=str, required=True, 55 | help="path of checkpoint pt file for evaluation") 56 | parser.add_argument('-i', '--input_folder', type=str, required=True, 57 | help="directory of mel-spectrograms to invert into raw audio.") 58 | parser.add_argument('-o', '--output_folder', type=str, default=None, 59 | help="directory which generated raw audio is saved.") 60 | args = parser.parse_args() 61 | 62 | main(args) 63 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .mpd import MultiPeriodDiscriminator 5 | from .mrd import MultiResolutionDiscriminator 6 | from omegaconf import OmegaConf 7 | 8 | class Discriminator(nn.Module): 9 | def __init__(self, hp): 10 | super(Discriminator, self).__init__() 11 | self.MRD = MultiResolutionDiscriminator(hp) 12 | self.MPD = MultiPeriodDiscriminator(hp) 13 | 14 | def forward(self, x): 15 | return self.MRD(x), self.MPD(x) 16 | 17 | if __name__ == '__main__': 18 | hp = OmegaConf.load('../config/default.yaml') 19 | model = Discriminator(hp) 20 | 21 | x = torch.randn(3, 1, 16384) 22 | print(x.shape) 23 | 24 | mrd_output, mpd_output = model(x) 25 | for features, score in mpd_output: 26 | for feat in features: 27 | print(feat.shape) 28 | print(score.shape) 29 | 30 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 31 | print(pytorch_total_params) 32 | 33 | -------------------------------------------------------------------------------- /model/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from omegaconf import OmegaConf 4 | 5 | from .lvcnet import LVCBlock 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | class Generator(nn.Module): 10 | """UnivNet Generator""" 11 | def __init__(self, hp): 12 | super(Generator, self).__init__() 13 | self.mel_channel = hp.audio.n_mel_channels 14 | self.noise_dim = hp.gen.noise_dim 15 | self.hop_length = hp.audio.hop_length 16 | channel_size = hp.gen.channel_size 17 | kpnet_conv_size = hp.gen.kpnet_conv_size 18 | 19 | self.res_stack = nn.ModuleList() 20 | hop_length = 1 21 | for stride in hp.gen.strides: 22 | hop_length = stride * hop_length 23 | self.res_stack.append( 24 | LVCBlock( 25 | channel_size, 26 | hp.audio.n_mel_channels, 27 | stride=stride, 28 | dilations=hp.gen.dilations, 29 | lReLU_slope=hp.gen.lReLU_slope, 30 | cond_hop_length=hop_length, 31 | kpnet_conv_size=kpnet_conv_size 32 | ) 33 | ) 34 | 35 | self.conv_pre = \ 36 | nn.utils.weight_norm(nn.Conv1d(hp.gen.noise_dim, channel_size, 7, padding=3, padding_mode='reflect')) 37 | 38 | self.conv_post = nn.Sequential( 39 | nn.LeakyReLU(hp.gen.lReLU_slope), 40 | nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')), 41 | nn.Tanh(), 42 | ) 43 | 44 | def forward(self, c, z): 45 | ''' 46 | Args: 47 | c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length) 48 | z (Tensor): the noise sequence (batch, noise_dim, in_length) 49 | 50 | ''' 51 | z = self.conv_pre(z) # (B, c_g, L) 52 | 53 | for res_block in self.res_stack: 54 | res_block.to(z.device) 55 | z = res_block(z, c) # (B, c_g, L * s_0 * ... 
* s_i) 56 | 57 | z = self.conv_post(z) # (B, 1, L * 256) 58 | 59 | return z 60 | 61 | def eval(self, inference=False): 62 | super(Generator, self).eval() 63 | # don't remove weight norm while validation in training loop 64 | if inference: 65 | self.remove_weight_norm() 66 | 67 | def remove_weight_norm(self): 68 | print('Removing weight norm...') 69 | 70 | nn.utils.remove_weight_norm(self.conv_pre) 71 | 72 | for layer in self.conv_post: 73 | if len(layer.state_dict()) != 0: 74 | nn.utils.remove_weight_norm(layer) 75 | 76 | for res_block in self.res_stack: 77 | res_block.remove_weight_norm() 78 | 79 | def inference(self, c, z=None): 80 | # pad input mel with zeros to cut artifact 81 | # see https://github.com/seungwonpark/melgan/issues/8 82 | zero = torch.full((1, self.mel_channel, 10), -11.5129).to(c.device) 83 | mel = torch.cat((c, zero), dim=2) 84 | 85 | if z is None: 86 | z = torch.randn(1, self.noise_dim, mel.size(2)).to(mel.device) 87 | 88 | audio = self.forward(mel, z) 89 | audio = audio.squeeze() # collapse all dimension except time axis 90 | audio = audio[:-(self.hop_length*10)] 91 | audio = MAX_WAV_VALUE * audio 92 | audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1) 93 | audio = audio.short() 94 | 95 | return audio 96 | 97 | if __name__ == '__main__': 98 | hp = OmegaConf.load('../config/default.yaml') 99 | model = Generator(hp) 100 | 101 | c = torch.randn(3, 100, 10) 102 | z = torch.randn(3, 64, 10) 103 | print(c.shape) 104 | 105 | y = model(c, z) 106 | print(y.shape) 107 | assert y.shape == torch.Size([3, 1, 2560]) 108 | 109 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 110 | print(pytorch_total_params) 111 | -------------------------------------------------------------------------------- /model/lvcnet.py: -------------------------------------------------------------------------------- 1 | """ refer from https://github.com/zceng/LVCNet """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class KernelPredictor(torch.nn.Module): 8 | ''' Kernel predictor for the location-variable convolutions''' 9 | def __init__( 10 | self, 11 | cond_channels, 12 | conv_in_channels, 13 | conv_out_channels, 14 | conv_layers, 15 | conv_kernel_size=3, 16 | kpnet_hidden_channels=64, 17 | kpnet_conv_size=3, 18 | kpnet_dropout=0.0, 19 | kpnet_nonlinear_activation="LeakyReLU", 20 | kpnet_nonlinear_activation_params={"negative_slope":0.1}, 21 | ): 22 | ''' 23 | Args: 24 | cond_channels (int): number of channel for the conditioning sequence, 25 | conv_in_channels (int): number of channel for the input sequence, 26 | conv_out_channels (int): number of channel for the output sequence, 27 | conv_layers (int): number of layers 28 | ''' 29 | super().__init__() 30 | 31 | self.conv_in_channels = conv_in_channels 32 | self.conv_out_channels = conv_out_channels 33 | self.conv_kernel_size = conv_kernel_size 34 | self.conv_layers = conv_layers 35 | 36 | kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w 37 | kpnet_bias_channels = conv_out_channels * conv_layers # l_b 38 | 39 | self.input_conv = nn.Sequential( 40 | nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), 41 | getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 42 | ) 43 | 44 | self.residual_convs = nn.ModuleList() 45 | padding = (kpnet_conv_size - 1) // 2 46 | for _ in range(3): 47 | self.residual_convs.append( 48 | nn.Sequential( 49 | 
nn.Dropout(kpnet_dropout), 50 | nn.utils.weight_norm(nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True)), 51 | getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 52 | nn.utils.weight_norm(nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True)), 53 | getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 54 | ) 55 | ) 56 | self.kernel_conv = nn.utils.weight_norm( 57 | nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True)) 58 | self.bias_conv = nn.utils.weight_norm( 59 | nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True)) 60 | 61 | def forward(self, c): 62 | ''' 63 | Args: 64 | c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) 65 | ''' 66 | batch, _, cond_length = c.shape 67 | c = self.input_conv(c) 68 | for residual_conv in self.residual_convs: 69 | residual_conv.to(c.device) 70 | c = c + residual_conv(c) 71 | k = self.kernel_conv(c) 72 | b = self.bias_conv(c) 73 | kernels = k.contiguous().view( 74 | batch, 75 | self.conv_layers, 76 | self.conv_in_channels, 77 | self.conv_out_channels, 78 | self.conv_kernel_size, 79 | cond_length, 80 | ) 81 | bias = b.contiguous().view( 82 | batch, 83 | self.conv_layers, 84 | self.conv_out_channels, 85 | cond_length, 86 | ) 87 | 88 | return kernels, bias 89 | 90 | def remove_weight_norm(self): 91 | nn.utils.remove_weight_norm(self.input_conv[0]) 92 | nn.utils.remove_weight_norm(self.kernel_conv) 93 | nn.utils.remove_weight_norm(self.bias_conv) 94 | for block in self.residual_convs: 95 | nn.utils.remove_weight_norm(block[1]) 96 | nn.utils.remove_weight_norm(block[3]) 97 | 98 | class LVCBlock(torch.nn.Module): 99 | '''the location-variable convolutions''' 100 | def __init__( 101 | self, 102 | in_channels, 103 | cond_channels, 104 | stride, 105 | dilations=[1, 3, 9, 27], 106 | lReLU_slope=0.2, 107 | conv_kernel_size=3, 108 | cond_hop_length=256, 109 | kpnet_hidden_channels=64, 110 | kpnet_conv_size=3, 111 | kpnet_dropout=0.0, 112 | ): 113 | super().__init__() 114 | 115 | self.cond_hop_length = cond_hop_length 116 | self.conv_layers = len(dilations) 117 | self.conv_kernel_size = conv_kernel_size 118 | 119 | self.kernel_predictor = KernelPredictor( 120 | cond_channels=cond_channels, 121 | conv_in_channels=in_channels, 122 | conv_out_channels=2 * in_channels, 123 | conv_layers=len(dilations), 124 | conv_kernel_size=conv_kernel_size, 125 | kpnet_hidden_channels=kpnet_hidden_channels, 126 | kpnet_conv_size=kpnet_conv_size, 127 | kpnet_dropout=kpnet_dropout, 128 | kpnet_nonlinear_activation_params={"negative_slope":lReLU_slope} 129 | ) 130 | 131 | self.convt_pre = nn.Sequential( 132 | nn.LeakyReLU(lReLU_slope), 133 | nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride, padding=stride // 2 + stride % 2, output_padding=stride % 2)), 134 | ) 135 | 136 | self.conv_blocks = nn.ModuleList() 137 | for dilation in dilations: 138 | self.conv_blocks.append( 139 | nn.Sequential( 140 | nn.LeakyReLU(lReLU_slope), 141 | nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)), 142 | nn.LeakyReLU(lReLU_slope), 143 | ) 144 | ) 145 | 146 | def forward(self, x, c): 147 | ''' forward propagation of the location-variable convolutions. 
148 | Args: 149 | x (Tensor): the input sequence (batch, in_channels, in_length) 150 | c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) 151 | 152 | Returns: 153 | Tensor: the output sequence (batch, in_channels, in_length) 154 | ''' 155 | _, in_channels, _ = x.shape # (B, c_g, L') 156 | 157 | x = self.convt_pre(x) # (B, c_g, stride * L') 158 | kernels, bias = self.kernel_predictor(c) 159 | 160 | for i, conv in enumerate(self.conv_blocks): 161 | output = conv(x) # (B, c_g, stride * L') 162 | 163 | k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) 164 | b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) 165 | 166 | output = self.location_variable_convolution(output, k, b, hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC 167 | x = x + torch.sigmoid(output[ :, :in_channels, :]) * torch.tanh(output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU 168 | 169 | return x 170 | 171 | def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): 172 | ''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. 173 | Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. 174 | Args: 175 | x (Tensor): the input sequence (batch, in_channels, in_length). 176 | kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) 177 | bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) 178 | dilation (int): the dilation of convolution. 179 | hop_size (int): the hop_size of the conditioning sequence. 180 | Returns: 181 | (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). 
182 | ''' 183 | batch, _, in_length = x.shape 184 | batch, _, out_channels, kernel_size, kernel_length = kernel.shape 185 | assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" 186 | 187 | padding = dilation * int((kernel_size - 1) / 2) 188 | x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding) 189 | x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) 190 | 191 | if hop_size < dilation: 192 | x = F.pad(x, (0, dilation), 'constant', 0) 193 | x = x.unfold(3, dilation, dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) 194 | x = x[:, :, :, :, :hop_size] 195 | x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) 196 | x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) 197 | 198 | o = torch.einsum('bildsk,biokl->bolsd', x, kernel) 199 | o = o.to(memory_format=torch.channels_last_3d) 200 | bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) 201 | o = o + bias 202 | o = o.contiguous().view(batch, out_channels, -1) 203 | 204 | return o 205 | 206 | def remove_weight_norm(self): 207 | self.kernel_predictor.remove_weight_norm() 208 | nn.utils.remove_weight_norm(self.convt_pre[1]) 209 | for block in self.conv_blocks: 210 | nn.utils.remove_weight_norm(block[1]) -------------------------------------------------------------------------------- /model/mpd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm, spectral_norm 5 | 6 | class DiscriminatorP(nn.Module): 7 | def __init__(self, hp, period): 8 | super(DiscriminatorP, self).__init__() 9 | 10 | self.LRELU_SLOPE = hp.mpd.lReLU_slope 11 | self.period = period 12 | 13 | kernel_size = hp.mpd.kernel_size 14 | stride = hp.mpd.stride 15 | norm_f = weight_norm if hp.mpd.use_spectral_norm == False else spectral_norm 16 | 17 | self.convs = nn.ModuleList([ 18 | norm_f(nn.Conv2d(1, 64, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 19 | norm_f(nn.Conv2d(64, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 20 | norm_f(nn.Conv2d(128, 256, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 21 | norm_f(nn.Conv2d(256, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 22 | norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), 1, padding=(kernel_size // 2, 0))), 23 | ]) 24 | self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 25 | 26 | def forward(self, x): 27 | fmap = [] 28 | 29 | # 1d to 2d 30 | b, c, t = x.shape 31 | if t % self.period != 0: # pad first 32 | n_pad = self.period - (t % self.period) 33 | x = F.pad(x, (0, n_pad), "reflect") 34 | t = t + n_pad 35 | x = x.view(b, c, t // self.period, self.period) 36 | 37 | for l in self.convs: 38 | x = l(x) 39 | x = F.leaky_relu(x, self.LRELU_SLOPE) 40 | fmap.append(x) 41 | x = self.conv_post(x) 42 | fmap.append(x) 43 | x = torch.flatten(x, 1, -1) 44 | 45 | return fmap, x 46 | 47 | 48 | class MultiPeriodDiscriminator(nn.Module): 49 | def __init__(self, hp): 50 | super(MultiPeriodDiscriminator, self).__init__() 51 | 52 | self.discriminators = nn.ModuleList( 53 | [DiscriminatorP(hp, period) for period in hp.mpd.periods] 54 | ) 55 | 56 | def forward(self, x): 57 | ret = list() 58 | for disc in 
self.discriminators: 59 | ret.append(disc(x)) 60 | 61 | return ret # [(feat, score), (feat, score), (feat, score), (feat, score), (feat, score)] 62 | -------------------------------------------------------------------------------- /model/mrd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm, spectral_norm 5 | 6 | class DiscriminatorR(torch.nn.Module): 7 | def __init__(self, hp, resolution): 8 | super(DiscriminatorR, self).__init__() 9 | 10 | self.resolution = resolution 11 | self.LRELU_SLOPE = hp.mpd.lReLU_slope 12 | 13 | norm_f = weight_norm if hp.mrd.use_spectral_norm == False else spectral_norm 14 | 15 | self.convs = nn.ModuleList([ 16 | norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), 17 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 18 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 19 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 20 | norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), 21 | ]) 22 | self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) 23 | 24 | def forward(self, x): 25 | fmap = [] 26 | 27 | x = self.spectrogram(x) 28 | x = x.unsqueeze(1) 29 | for l in self.convs: 30 | x = l(x) 31 | x = F.leaky_relu(x, self.LRELU_SLOPE) 32 | fmap.append(x) 33 | x = self.conv_post(x) 34 | fmap.append(x) 35 | x = torch.flatten(x, 1, -1) 36 | 37 | return fmap, x 38 | 39 | def spectrogram(self, x): 40 | n_fft, hop_length, win_length = self.resolution 41 | x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect') 42 | x = x.squeeze(1) 43 | x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False) #[B, F, TT, 2] 44 | mag = torch.norm(x, p=2, dim =-1) #[B, F, TT] 45 | 46 | return mag 47 | 48 | 49 | class MultiResolutionDiscriminator(torch.nn.Module): 50 | def __init__(self, hp): 51 | super(MultiResolutionDiscriminator, self).__init__() 52 | self.resolutions = eval(hp.mrd.resolutions) 53 | self.discriminators = nn.ModuleList( 54 | [DiscriminatorR(hp, resolution) for resolution in self.resolutions] 55 | ) 56 | 57 | def forward(self, x): 58 | ret = list() 59 | for disc in self.discriminators: 60 | ret.append(disc(x)) 61 | 62 | return ret # [(feat, score), (feat, score), (feat, score)] 63 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.8.1 2 | matplotlib==3.1.3 3 | numpy==1.17.4 4 | scipy==1.5.4 5 | torch==1.6.0 6 | tensorboard==2.3.0 7 | tqdm==4.61.2 8 | omegaconf==2.0.6 9 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import argparse 5 | import torch 6 | import torch.multiprocessing as mp 7 | from omegaconf import OmegaConf 8 | 9 | from utils.train import train 10 | 11 | torch.backends.cudnn.benchmark = True 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-c', '--config', type=str, required=True, 17 | help="yaml file for configuration") 18 | parser.add_argument('-p', '--checkpoint_path', type=str, default=None, 19 | help="path of checkpoint pt file to resume training") 20 | parser.add_argument('-n', '--name', 
type=str, required=True, 21 | help="name of the model for logging, saving checkpoint") 22 | args = parser.parse_args() 23 | 24 | hp = OmegaConf.load(args.config) 25 | with open(args.config, 'r') as f: 26 | hp_str = ''.join(f.readlines()) 27 | 28 | assert hp.audio.hop_length == 256, \ 29 | 'hp.audio.hop_length must be equal to 256, got %d' % hp.audio.hop_length 30 | 31 | args.num_gpus = 0 32 | torch.manual_seed(hp.train.seed) 33 | if torch.cuda.is_available(): 34 | torch.cuda.manual_seed(hp.train.seed) 35 | args.num_gpus = torch.cuda.device_count() 36 | print('Batch size per GPU :', hp.train.batch_size) 37 | else: 38 | pass 39 | 40 | if args.num_gpus > 1: 41 | mp.spawn(train, nprocs=args.num_gpus, 42 | args=(args, args.checkpoint_path, hp, hp_str,)) 43 | else: 44 | train(0, args, args.checkpoint_path, hp, hp_str) 45 | -------------------------------------------------------------------------------- /utils/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | import numpy as np 5 | 6 | 7 | def save_figure_to_numpy(fig): 8 | # save it to a numpy array. 9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | data = np.transpose(data, (2, 0, 1)) 12 | return data 13 | 14 | 15 | def plot_waveform_to_numpy(waveform): 16 | fig, ax = plt.subplots(figsize=(12, 4)) 17 | ax.plot() 18 | ax.plot(range(len(waveform)), waveform, 19 | linewidth=0.1, alpha=0.7, color='blue') 20 | 21 | plt.xlabel("Samples") 22 | plt.ylabel("Amplitude") 23 | plt.ylim(-1, 1) 24 | plt.tight_layout() 25 | 26 | fig.canvas.draw() 27 | data = save_figure_to_numpy(fig) 28 | plt.close() 29 | 30 | return data 31 | 32 | 33 | def plot_spectrogram_to_numpy(spectrogram): 34 | fig, ax = plt.subplots(figsize=(12, 4)) 35 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 36 | interpolation='none') 37 | plt.colorbar(im, ax=ax) 38 | plt.xlabel("Frames") 39 | plt.ylabel("Channels") 40 | plt.tight_layout() 41 | 42 | fig.canvas.draw() 43 | data = save_figure_to_numpy(fig) 44 | plt.close() 45 | return data 46 | -------------------------------------------------------------------------------- /utils/stft.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Jungil Kong 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import math 24 | import os 25 | import random 26 | import torch 27 | import torch.utils.data 28 | import numpy as np 29 | from librosa.util import normalize 30 | from scipy.io.wavfile import read 31 | from librosa.filters import mel as librosa_mel_fn 32 | 33 | 34 | class TacotronSTFT(torch.nn.Module): 35 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 36 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 37 | mel_fmax=None, center=False, device='cpu'): 38 | super(TacotronSTFT, self).__init__() 39 | self.n_mel_channels = n_mel_channels 40 | self.sampling_rate = sampling_rate 41 | self.n_fft = filter_length 42 | self.hop_size = hop_length 43 | self.win_size = win_length 44 | self.fmin = mel_fmin 45 | self.fmax = mel_fmax 46 | self.center = center 47 | 48 | mel = librosa_mel_fn( 49 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 50 | 51 | mel_basis = torch.from_numpy(mel).float().to(device) 52 | hann_window = torch.hann_window(win_length).to(device) 53 | 54 | self.register_buffer('mel_basis', mel_basis) 55 | self.register_buffer('hann_window', hann_window) 56 | 57 | def linear_spectrogram(self, y): 58 | assert (torch.min(y.data) >= -1) 59 | assert (torch.max(y.data) <= 1) 60 | 61 | y = torch.nn.functional.pad(y.unsqueeze(1), 62 | (int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)), 63 | mode='reflect') 64 | y = y.squeeze(1) 65 | spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window, 66 | center=self.center, pad_mode='reflect', normalized=False, onesided=True) 67 | spec = torch.norm(spec, p=2, dim=-1) 68 | 69 | return spec 70 | 71 | def mel_spectrogram(self, y): 72 | """Computes mel-spectrograms from a batch of waves 73 | PARAMS 74 | ------ 75 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 76 | 77 | RETURNS 78 | ------- 79 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 80 | """ 81 | assert(torch.min(y.data) >= -1) 82 | assert(torch.max(y.data) <= 1) 83 | 84 | y = torch.nn.functional.pad(y.unsqueeze(1), 85 | (int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)), 86 | mode='reflect') 87 | y = y.squeeze(1) 88 | 89 | spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window, 90 | center=self.center, pad_mode='reflect', normalized=False, onesided=True) 91 | 92 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 93 | 94 | spec = torch.matmul(self.mel_basis, spec) 95 | spec = self.spectral_normalize_torch(spec) 96 | 97 | return spec 98 | 99 | def spectral_normalize_torch(self, magnitudes): 100 | output = self.dynamic_range_compression_torch(magnitudes) 101 | return output 102 | 103 | def dynamic_range_compression_torch(self, x, C=1, clip_val=1e-5): 104 | return torch.log(torch.clamp(x, min=clip_val) * C) 105 | -------------------------------------------------------------------------------- /utils/stft_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based Loss modules.""" 7 | 8 | import 
torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def stft(x, fft_size, hop_size, win_length, window): 13 | """Perform STFT and convert to magnitude spectrogram. 14 | Args: 15 | x (Tensor): Input signal tensor (B, T). 16 | fft_size (int): FFT size. 17 | hop_size (int): Hop size. 18 | win_length (int): Window length. 19 | window (str): Window function type. 20 | Returns: 21 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 22 | """ 23 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window) 24 | real = x_stft[..., 0] 25 | imag = x_stft[..., 1] 26 | 27 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 28 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 29 | 30 | 31 | class SpectralConvergengeLoss(torch.nn.Module): 32 | """Spectral convergence loss module.""" 33 | 34 | def __init__(self): 35 | """Initilize spectral convergence loss module.""" 36 | super(SpectralConvergengeLoss, self).__init__() 37 | 38 | def forward(self, x_mag, y_mag): 39 | """Calculate forward propagation. 40 | Args: 41 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 42 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 43 | Returns: 44 | Tensor: Spectral convergence loss value. 45 | """ 46 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 47 | 48 | 49 | class LogSTFTMagnitudeLoss(torch.nn.Module): 50 | """Log STFT magnitude loss module.""" 51 | 52 | def __init__(self): 53 | """Initilize los STFT magnitude loss module.""" 54 | super(LogSTFTMagnitudeLoss, self).__init__() 55 | 56 | def forward(self, x_mag, y_mag): 57 | """Calculate forward propagation. 58 | Args: 59 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 60 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 61 | Returns: 62 | Tensor: Log STFT magnitude loss value. 63 | """ 64 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 65 | 66 | 67 | class STFTLoss(torch.nn.Module): 68 | """STFT loss module.""" 69 | 70 | def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): 71 | """Initialize STFT loss module.""" 72 | super(STFTLoss, self).__init__() 73 | self.fft_size = fft_size 74 | self.shift_size = shift_size 75 | self.win_length = win_length 76 | self.window = getattr(torch, window)(win_length).to(device) 77 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 78 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 79 | 80 | def forward(self, x, y): 81 | """Calculate forward propagation. 82 | Args: 83 | x (Tensor): Predicted signal (B, T). 84 | y (Tensor): Groundtruth signal (B, T). 85 | Returns: 86 | Tensor: Spectral convergence loss value. 87 | Tensor: Log STFT magnitude loss value. 88 | """ 89 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 90 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 91 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 92 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 93 | 94 | return sc_loss, mag_loss 95 | 96 | 97 | class MultiResolutionSTFTLoss(torch.nn.Module): 98 | """Multi resolution STFT loss module.""" 99 | 100 | def __init__(self, 101 | device, 102 | resolutions, 103 | window="hann_window"): 104 | """Initialize Multi resolution STFT loss module. 105 | Args: 106 | resolutions (list): List of (FFT size, hop size, window length). 
107 | window (str): Window function type. 108 | """ 109 | super(MultiResolutionSTFTLoss, self).__init__() 110 | self.stft_losses = torch.nn.ModuleList() 111 | for fs, ss, wl in resolutions: 112 | self.stft_losses += [STFTLoss(device, fs, ss, wl, window)] 113 | 114 | def forward(self, x, y): 115 | """Calculate forward propagation. 116 | Args: 117 | x (Tensor): Predicted signal (B, T). 118 | y (Tensor): Groundtruth signal (B, T). 119 | Returns: 120 | Tensor: Multi resolution spectral convergence loss value. 121 | Tensor: Multi resolution log STFT magnitude loss value. 122 | """ 123 | sc_loss = 0.0 124 | mag_loss = 0.0 125 | for f in self.stft_losses: 126 | sc_l, mag_l = f(x, y) 127 | sc_loss += sc_l 128 | mag_loss += mag_l 129 | 130 | sc_loss /= len(self.stft_losses) 131 | mag_loss /= len(self.stft_losses) 132 | 133 | return sc_loss, mag_loss 134 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import math 5 | import tqdm 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.distributed import init_process_group 10 | from torch.nn.parallel import DistributedDataParallel 11 | import itertools 12 | import traceback 13 | 14 | from datasets.dataloader import create_dataloader 15 | from utils.writer import MyWriter 16 | from utils.stft import TacotronSTFT 17 | from utils.stft_loss import MultiResolutionSTFTLoss 18 | from model.generator import Generator 19 | from model.discriminator import Discriminator 20 | from .utils import get_commit_hash 21 | from .validation import validate 22 | 23 | 24 | def train(rank, args, chkpt_path, hp, hp_str): 25 | 26 | if args.num_gpus > 1: 27 | init_process_group(backend=hp.dist_config.dist_backend, init_method=hp.dist_config.dist_url, 28 | world_size=hp.dist_config.world_size * args.num_gpus, rank=rank) 29 | 30 | torch.cuda.manual_seed(hp.train.seed) 31 | device = torch.device('cuda:{:d}'.format(rank)) 32 | 33 | model_g = Generator(hp).to(device) 34 | model_d = Discriminator(hp).to(device) 35 | 36 | optim_g = torch.optim.AdamW(model_g.parameters(), 37 | lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2)) 38 | optim_d = torch.optim.AdamW(model_d.parameters(), 39 | lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2)) 40 | 41 | githash = get_commit_hash() 42 | 43 | init_epoch = -1 44 | step = 0 45 | 46 | # define logger, writer, valloader, stft at rank_zero 47 | if rank == 0: 48 | pt_dir = os.path.join(hp.log.chkpt_dir, args.name) 49 | log_dir = os.path.join(hp.log.log_dir, args.name) 50 | os.makedirs(pt_dir, exist_ok=True) 51 | os.makedirs(log_dir, exist_ok=True) 52 | 53 | logging.basicConfig( 54 | level=logging.INFO, 55 | format='%(asctime)s - %(levelname)s - %(message)s', 56 | handlers=[ 57 | logging.FileHandler(os.path.join(log_dir, '%s-%d.log' % (args.name, time.time()))), 58 | logging.StreamHandler() 59 | ] 60 | ) 61 | logger = logging.getLogger() 62 | writer = MyWriter(hp, log_dir) 63 | valloader = create_dataloader(hp, args, False, device='cpu') 64 | stft = TacotronSTFT(filter_length=hp.audio.filter_length, 65 | hop_length=hp.audio.hop_length, 66 | win_length=hp.audio.win_length, 67 | n_mel_channels=hp.audio.n_mel_channels, 68 | sampling_rate=hp.audio.sampling_rate, 69 | mel_fmin=hp.audio.mel_fmin, 70 | mel_fmax=hp.audio.mel_fmax, 71 | center=False, 72 | device=device) 73 | 74 | if chkpt_path is not 
None: 75 | if rank == 0: 76 | logger.info("Resuming from checkpoint: %s" % chkpt_path) 77 | checkpoint = torch.load(chkpt_path) 78 | model_g.load_state_dict(checkpoint['model_g']) 79 | model_d.load_state_dict(checkpoint['model_d']) 80 | optim_g.load_state_dict(checkpoint['optim_g']) 81 | optim_d.load_state_dict(checkpoint['optim_d']) 82 | step = checkpoint['step'] 83 | init_epoch = checkpoint['epoch'] 84 | 85 | if rank == 0: 86 | if hp_str != checkpoint['hp_str']: 87 | logger.warning("New hparams is different from checkpoint. Will use new.") 88 | 89 | if githash != checkpoint['githash']: 90 | logger.warning("Code might be different: git hash is different.") 91 | logger.warning("%s -> %s" % (checkpoint['githash'], githash)) 92 | 93 | else: 94 | if rank == 0: 95 | logger.info("Starting new training run.") 96 | 97 | if args.num_gpus > 1: 98 | model_g = DistributedDataParallel(model_g, device_ids=[rank]).to(device) 99 | model_d = DistributedDataParallel(model_d, device_ids=[rank]).to(device) 100 | 101 | # this accelerates training when the size of minibatch is always consistent. 102 | # if not consistent, it'll horribly slow down. 103 | torch.backends.cudnn.benchmark = True 104 | 105 | trainloader = create_dataloader(hp, args, True, device='cpu') 106 | 107 | model_g.train() 108 | model_d.train() 109 | 110 | resolutions = eval(hp.mrd.resolutions) 111 | stft_criterion = MultiResolutionSTFTLoss(device, resolutions) 112 | 113 | for epoch in itertools.count(init_epoch+1): 114 | 115 | if rank == 0 and epoch % hp.log.validation_interval == 0: 116 | with torch.no_grad(): 117 | validate(hp, args, model_g, model_d, valloader, stft, writer, step, device) 118 | 119 | trainloader.dataset.shuffle_mapping() 120 | if rank == 0: 121 | loader = tqdm.tqdm(trainloader, desc='Loading train data') 122 | else: 123 | loader = trainloader 124 | 125 | for mel, audio in loader: 126 | 127 | mel = mel.to(device) 128 | audio = audio.to(device) 129 | noise = torch.randn(hp.train.batch_size, hp.gen.noise_dim, mel.size(2)).to(device) 130 | 131 | # generator 132 | optim_g.zero_grad() 133 | fake_audio = model_g(mel, noise) 134 | 135 | # Multi-Resolution STFT Loss 136 | sc_loss, mag_loss = stft_criterion(fake_audio.squeeze(1), audio.squeeze(1)) 137 | stft_loss = (sc_loss + mag_loss) * hp.train.stft_lamb 138 | 139 | res_fake, period_fake = model_d(fake_audio) 140 | 141 | score_loss = 0.0 142 | 143 | for (_, score_fake) in res_fake + period_fake: 144 | score_loss += torch.mean(torch.pow(score_fake - 1.0, 2)) 145 | 146 | score_loss = score_loss / len(res_fake + period_fake) 147 | 148 | loss_g = score_loss + stft_loss 149 | 150 | loss_g.backward() 151 | optim_g.step() 152 | 153 | # discriminator 154 | 155 | optim_d.zero_grad() 156 | res_fake, period_fake = model_d(fake_audio.detach()) 157 | res_real, period_real = model_d(audio) 158 | 159 | loss_d = 0.0 160 | for (_, score_fake), (_, score_real) in zip(res_fake + period_fake, res_real + period_real): 161 | loss_d += torch.mean(torch.pow(score_real - 1.0, 2)) 162 | loss_d += torch.mean(torch.pow(score_fake, 2)) 163 | 164 | loss_d = loss_d / len(res_fake + period_fake) 165 | 166 | loss_d.backward() 167 | optim_d.step() 168 | 169 | step += 1 170 | # logging 171 | loss_g = loss_g.item() 172 | loss_d = loss_d.item() 173 | 174 | if rank == 0 and step % hp.log.summary_interval == 0: 175 | writer.log_training(loss_g, loss_d, stft_loss.item(), score_loss.item(), step) 176 | loader.set_description("g %.04f d %.04f | step %d" % (loss_g, loss_d, step)) 177 | 178 | if rank == 0 and epoch % 
hp.log.save_interval == 0: 179 | save_path = os.path.join(pt_dir, '%s_%04d.pt' 180 | % (args.name, epoch)) 181 | torch.save({ 182 | 'model_g': (model_g.module if args.num_gpus > 1 else model_g).state_dict(), 183 | 'model_d': (model_d.module if args.num_gpus > 1 else model_d).state_dict(), 184 | 'optim_g': optim_g.state_dict(), 185 | 'optim_d': optim_d.state_dict(), 186 | 'step': step, 187 | 'epoch': epoch, 188 | 'hp_str': hp_str, 189 | 'githash': githash, 190 | }, save_path) 191 | logger.info("Saved checkpoint to: %s" % save_path) 192 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import subprocess 3 | import numpy as np 4 | from scipy.io.wavfile import read 5 | 6 | 7 | def get_commit_hash(): 8 | message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) 9 | return message.strip().decode('utf-8') 10 | 11 | def read_wav_np(path): 12 | sr, wav = read(path) 13 | 14 | if len(wav.shape) == 2: 15 | wav = wav[:, 0] 16 | 17 | if wav.dtype == np.int16: 18 | wav = wav / 32768.0 19 | elif wav.dtype == np.int32: 20 | wav = wav / 2147483648.0 21 | elif wav.dtype == np.uint8: 22 | wav = (wav - 128) / 128.0 23 | 24 | wav = wav.astype(np.float32) 25 | 26 | return sr, wav 27 | -------------------------------------------------------------------------------- /utils/validation.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def validate(hp, args, generator, discriminator, valloader, stft, writer, step, device): 7 | generator.eval() 8 | discriminator.eval() 9 | torch.backends.cudnn.benchmark = False 10 | 11 | loader = tqdm.tqdm(valloader, desc='Validation loop') 12 | mel_loss = 0.0 13 | for idx, (mel, audio) in enumerate(loader): 14 | mel = mel.to(device) 15 | audio = audio.to(device) 16 | noise = torch.randn(1, hp.gen.noise_dim, mel.size(2)).to(device) 17 | 18 | fake_audio = generator(mel, noise)[:,:,:audio.size(2)] 19 | 20 | mel_fake = stft.mel_spectrogram(fake_audio.squeeze(1)) 21 | mel_real = stft.mel_spectrogram(audio.squeeze(1)) 22 | 23 | mel_loss += F.l1_loss(mel_fake, mel_real).item() 24 | 25 | if idx < hp.log.num_audio: 26 | spec_fake = stft.linear_spectrogram(fake_audio.squeeze(1)) 27 | spec_real = stft.linear_spectrogram(audio.squeeze(1)) 28 | 29 | audio = audio[0][0].cpu().detach().numpy() 30 | fake_audio = fake_audio[0][0].cpu().detach().numpy() 31 | spec_fake = spec_fake[0].cpu().detach().numpy() 32 | spec_real = spec_real[0].cpu().detach().numpy() 33 | writer.log_fig_audio(audio, fake_audio, spec_fake, spec_real, idx, step) 34 | 35 | mel_loss = mel_loss / len(valloader.dataset) 36 | 37 | writer.log_validation(mel_loss, generator, discriminator, step) 38 | 39 | torch.backends.cudnn.benchmark = True 40 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | import numpy as np 3 | import librosa 4 | 5 | from .plotting import plot_waveform_to_numpy, plot_spectrogram_to_numpy 6 | 7 | class MyWriter(SummaryWriter): 8 | def __init__(self, hp, logdir): 9 | super(MyWriter, self).__init__(logdir) 10 | self.sample_rate = hp.audio.sampling_rate 11 | self.is_first = True 12 | 13 | def log_training(self, g_loss, d_loss, stft_loss, score_loss, 
step): 14 | self.add_scalar('train/g_loss', g_loss, step) 15 | self.add_scalar('train/d_loss', d_loss, step) 16 | 17 | self.add_scalar('train/score_loss', score_loss, step) 18 | self.add_scalar('train/stft_loss', stft_loss, step) 19 | 20 | def log_validation(self, mel_loss, generator, discriminator, step): 21 | self.add_scalar('validation/mel_loss', mel_loss, step) 22 | 23 | self.log_histogram(generator, step) 24 | self.log_histogram(discriminator, step) 25 | if self.is_first: 26 | self.is_first = False 27 | 28 | def log_fig_audio(self, target, prediction, spec_fake, spec_real, idx, step): 29 | if idx == 0: 30 | spec_fake = librosa.amplitude_to_db(spec_fake, ref=np.max,top_db=80.) 31 | spec_real = librosa.amplitude_to_db(spec_real, ref=np.max,top_db=80.) 32 | self.add_image('spec/predicted', plot_spectrogram_to_numpy(spec_fake), step) 33 | self.add_image('spec/error', plot_spectrogram_to_numpy(np.power(spec_real - spec_fake, 2)), step) 34 | self.add_image('waveform/predicted', plot_waveform_to_numpy(prediction), step) 35 | if self.is_first: 36 | self.add_image('spec/target', plot_spectrogram_to_numpy(spec_real), step) 37 | self.add_image('waveform/target', plot_waveform_to_numpy(target), step) 38 | 39 | self.add_audio('predicted/raw_audio_%d' % idx, prediction, step, self.sample_rate) 40 | if self.is_first: 41 | self.add_audio('target/raw_audio_%d' % idx, target, step, self.sample_rate) 42 | 43 | 44 | def log_histogram(self, model, step): 45 | for tag, value in model.named_parameters(): 46 | self.add_histogram(tag.replace('.', '/'), value.cpu().detach().numpy(), step) 47 | --------------------------------------------------------------------------------
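
For orientation, here is a minimal sketch of running the `Generator` defined in `model/generator.py` on a mel-spectrogram, assuming a checkpoint saved by `trainer.py`. The config and checkpoint paths are placeholders, the random tensor stands in for a real log-mel input, and `inference.py` remains the supported entry point.

```python
# Minimal inference sketch (assumes a checkpoint produced by trainer.py).
# 'chkpt/univnet_0100.pt' is a placeholder path; substitute your own checkpoint.
import torch
from omegaconf import OmegaConf

from model.generator import Generator

hp = OmegaConf.load('config/default_c32.yaml')
checkpoint = torch.load('chkpt/univnet_0100.pt', map_location='cpu')

model = Generator(hp)
model.load_state_dict(checkpoint['model_g'])   # key written by trainer.py when saving
model.eval(inference=True)                     # removes weight norm for inference

# Random values stand in for a real log-mel spectrogram of shape
# (1, n_mel_channels, frames); in practice this comes from an acoustic model.
mel = torch.randn(1, hp.audio.n_mel_channels, 200)

with torch.no_grad():
    audio = model.inference(mel)   # int16 waveform, roughly frames * hop_length samples

print(audio.shape, audio.dtype)    # torch.Size([51200]) torch.int16 for 200 frames
```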
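
The shape contract documented in `LVCBlock.location_variable_convolution` (in `model/lvcnet.py`) can also be exercised in isolation. The sizes below are toy values chosen only to satisfy the method's `in_length == kernel_length * hop_size` assertion; they are not the ones used in training.

```python
# Toy shape check for the location-variable convolution; all sizes are
# illustrative, chosen only to satisfy in_length == kernel_length * hop_size.
import torch
from model.lvcnet import LVCBlock

batch, in_ch, out_ch = 2, 8, 16
kernel_size, kernel_length, hop_size = 3, 4, 256
in_length = kernel_length * hop_size   # required by the assert inside the method

block = LVCBlock(in_channels=in_ch, cond_channels=80, stride=8)

x = torch.randn(batch, in_ch, in_length)
kernel = torch.randn(batch, in_ch, out_ch, kernel_size, kernel_length)
bias = torch.randn(batch, out_ch, kernel_length)

y = block.location_variable_convolution(x, kernel, bias, hop_size=hop_size)
print(y.shape)   # torch.Size([2, 16, 1024]) -> (batch, out_channels, in_length)
```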
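
Likewise, the multi-resolution STFT criterion in `utils/stft_loss.py` can be run standalone. The `(FFT size, hop size, window length)` triples below are assumed values in the same format as `hp.mrd.resolutions` (the trainer reads them from the config), and the snippet relies on the STFT behavior of the pinned `torch==1.6.0`.

```python
# Standalone sketch of the multi-resolution STFT loss; see utils/train.py for
# how it is combined with the adversarial losses. The resolution triples are
# assumed (FFT size, hop size, window length) values, not read from a config.
import torch
from utils.stft_loss import MultiResolutionSTFTLoss

device = torch.device('cpu')
resolutions = [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]

criterion = MultiResolutionSTFTLoss(device, resolutions)

fake = torch.randn(4, 16384)   # (B, T) generated waveform, e.g. model_g(mel, noise).squeeze(1)
real = torch.randn(4, 16384)   # (B, T) ground-truth waveform

sc_loss, mag_loss = criterion(fake, real)
stft_loss = sc_loss + mag_loss   # scaled by hp.train.stft_lamb in the trainer
print(sc_loss.item(), mag_loss.item())
```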