├── .gitignore ├── LICENSE ├── README.md ├── convert.py ├── encode.py ├── hubconf.py ├── requirements.txt ├── resample.py ├── segment.py ├── train_rhythm_model.py ├── train_segmenter.py ├── train_vocoder.py ├── urhythmic ├── __init__.py ├── dataset.py ├── model.py ├── rhythm.py ├── segmenter.py ├── stretcher.py ├── utils.py └── vocoder.py └── urhythmic_demo.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # VSCode project settings 136 | .vscode 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Benjamin van Niekerk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Urhythmic: Rhythm Modeling for Voice Conversion 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2307.06040) 4 | [![demo](https://img.shields.io/static/v1?message=Audio%20Samples&logo=Github&labelColor=grey&color=blue&logoColor=white&label=%20&style=flat)](https://ubisoft-laforge.github.io/speech/urhythmic/) 5 | [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bshall/urhythmic/blob/main/urhythmic_demo.ipynb) 6 | 7 | Official repository for [Rhythm Modeling for Voice Conversion](https://arxiv.org/abs/2307.06040). 8 | Audio samples can be found [here](https://ubisoft-laforge.github.io/speech/urhythmic/). 
9 | Colab demo can be found [here](https://colab.research.google.com/github/bshall/urhythmic/blob/main/urhythmic_demo.ipynb). 10 | 11 | **Abstract**: Voice conversion aims to transform source speech into a different target voice. However, typical voice conversion systems do not account for rhythm, which is an important factor in the perception of speaker identity. To bridge this gap, we introduce Urhythmic - an unsupervised method for rhythm conversion that does not require parallel data or text transcriptions. Using self-supervised representations, we first divide source audio into segments approximating sonorants, obstruents, and silences. Then we model rhythm by estimating speaking rate or the duration distribution of each segment type. Finally, we match the target speaking rate or rhythm by time-stretching the speech segments. Experiments show that Urhythmic outperforms existing unsupervised methods in terms of quality and prosody. 12 | 13 | Note: Urhythmic builds on soft speech units from our paper [A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion](https://github.com/bshall/soft-vc/). 14 | 15 | ## Example Usage 16 | 17 | ### Programmatic Usage 18 | 19 | ```python 20 | import torch, torchaudio 21 | 22 | # Load the HubertSoft content encoder (see https://github.com/bshall/hubert/) 23 | hubert = torch.hub.load("bshall/hubert:main", "hubert_soft").cuda() 24 | 25 | # Select the source and target speakers 26 | # Pretrained models are available for: 27 | # VCTK: p228, p268, p225, p232, p257, p231. 28 | # and LJSpeech. 29 | source, target = "p231", "p225" 30 | 31 | # Load Urhythmic (either urhythmic_fine or urhythmic_global) 32 | urhythmic, encode = torch.hub.load( 33 | "bshall/urhythmic:main", 34 | "urhythmic_fine", 35 | source_speaker=source, 36 | target_speaker=target, 37 | ) 38 | urhythmic.cuda() 39 | 40 | # Load the source audio 41 | wav, sr = torchaudio.load("path/to/wav") 42 | assert sr == 16000 43 | wav = wav.unsqueeze(0).cuda() 44 | 45 | # Convert to the target speaker 46 | with torch.inference_mode(): 47 | # Extract speech units and log probabilities 48 | units, log_probs = encode(hubert, wav) 49 | # Convert to the target speaker 50 | wav_ = urhythmic(units, log_probs) 51 | ``` 52 | 53 | ### Script-Based Usage 54 | 55 | ``` 56 | usage: convert.py [-h] [--extension EXTENSION] 57 | {urhythmic_fine,urhythmic_global} source-speaker target-speaker in-dir out-dir 58 | 59 | Convert audio samples using Urhythmic. 60 | 61 | positional arguments: 62 | {urhythmic_fine,urhythmic_global} 63 | available models (Urhythmic-Fine or Urhythmic-Global). 64 | source-speaker the source speaker: p228, p268, p225, p232, p257, p231, LJSpeech 65 | target-speaker the target speaker: p228, p268, p225, p232, p257, p231, LJSpeech 66 | in-dir path to the dataset directory. 67 | out-dir path to the output directory. 68 | 69 | options: 70 | -h, --help show this help message and exit 71 | --extension EXTENSION 72 | extension of the audio files (defaults to .wav). 73 | ``` 74 | 75 | ## Training 76 | 77 | Here we outline the training steps for [VCTK](https://datashare.ed.ac.uk/handle/10283/3443). However, it should be straightforward to extend the recipe to other datasets. 78 | 79 | 1. [Prepare the Dataset](#step-1-prepare-the-dataset) 80 | 2. [Extract Soft Speech Units and Log Probabilities](#step-2-extract-soft-speech-units-and-log-probabilities) 81 | 3. [Train the Segmenter](#step-3-train-the-segmenter) 82 | 4. [Segmentation and Clustering](#step-4-segmentation-and-clustering) 83 | 5. 
[Train the Rhythm Model](#step-5-train-the-rhythm-model) 84 | 6. [Train or Finetune the Vocoder](#step-6-train-or-finetune-the-vocoder) 85 | 86 | To apply `Urhythmic` to your own data, you can skip step 3. (i.e., the `Segmenter` doesn't need to be re-trained). 87 | 88 | ### Step 1: Prepare the Dataset 89 | 90 | Download and extract [VCTK](https://datashare.ed.ac.uk/handle/10283/3443). Split the data into `dev`, `test`, and `train` sets for a given speaker (e.g. `p225`). The resulting directory should have the following structure: 91 | 92 | ``` 93 | p225 94 | ├── dev 95 | │   ├── wavs 96 | │   │   ├── p225_025.wav 97 | │   │   ├── ... 98 | │   │   ├── p225_045.wav 99 | ├── test 100 | │   ├── wavs 101 | │   │   ├── p225_001.wav 102 | │   │   ├── ... 103 | │   │   ├── p225_024.wav 104 | ├── train 105 | │   ├── wavs 106 | │   │   ├── p225_046.wav 107 | │   │   ├── ... 108 | │   │   ├── p225_366.wav 109 | 110 | ``` 111 | 112 | Note that for VCTK we take the first 24 parallel utterances as the `test` set. 113 | 114 | Next, resample the audio to 16kHz using the `resample.py` script. 115 | The script will replace each file with a 16kHz version so remember to copy your data if you want to keep the originals. 116 | 117 | ``` 118 | usage: resample.py [-h] [--sample_rate SAMPLE_RATE] in-dir 119 | 120 | Resample an audio dataset. 121 | 122 | positional arguments: 123 | in-dir path to dataset directory. 124 | 125 | options: 126 | -h, --help show this help message and exit 127 | --sample_rate SAMPLE_RATE 128 | target sample rate (defaults to 16000). 129 | ``` 130 | 131 | For example: 132 | 133 | ``` 134 | python resample.py path/to/p225 135 | ``` 136 | 137 | ### Step 2: Extract Soft Speech Units and Log Probabilities 138 | 139 | Encode the `dev`, `test`, and `train` sets using HuBERT-Soft and the `encode.py` script: 140 | 141 | ``` 142 | usage: encode.py [-h] [--extension EXTENSION] in-dir out-dir 143 | 144 | Encode an audio dataset into soft speech units and the log probabilities of the associated discrete units. 145 | 146 | positional arguments: 147 | in-dir path to the dataset directory. 148 | out-dir path to the output directory. 149 | 150 | options: 151 | -h, --help show this help message and exit 152 | --extension EXTENSION 153 | extension of the audio files (defaults to .wav). 154 | ``` 155 | 156 | for example: 157 | 158 | ``` 159 | python encode.py path/to/p225/dev/wavs path/to/p225/dev 160 | ``` 161 | 162 | At this point the directory tree should look as follows: 163 | 164 | ``` 165 | p225 166 | ├── dev 167 | │   ├── wavs 168 | │   ├── soft 169 | │   ├── logprobs 170 | ├── test 171 | │   ├── wavs 172 | │   ├── soft 173 | │   ├── logprobs 174 | ├── train 175 | │   ├── wavs 176 | │   ├── soft 177 | │   ├── logprobs 178 | ``` 179 | 180 | ### Step 3: Train the Segmenter 181 | 182 | Cluster the discrete speech units and identify the cluster id corresponding to sonorants, obstruents, and silences using the `train_segmenter.py` script: 183 | 184 | ``` 185 | usage: train_segmenter.py [-h] dataset-dir checkpoint-path 186 | 187 | Cluster the codebook of discrete speech units and identify the cluster id 188 | corresponding to sonorants, obstruents, and silences. 189 | 190 | positional arguments: 191 | dataset-dir path to the directory of segmented speech. 192 | checkpoint-path path to save checkpoint. 
193 | 194 | options: 195 | -h, --help show this help message and exit 196 | ``` 197 | 198 | for example: 199 | 200 | ``` 201 | python train_segmenter.py path/to/p225/dev/ path/to/checkpoints/segmenter.pt 202 | ``` 203 | 204 | ### Step 4: Segmentation and Clustering 205 | 206 | Segment the `dev`, `test`, and `train` sets using the `segment.py` script. 207 | Note, this script uses the [segmenter checkpoint](https://github.com/bshall/urhythmic/releases/tag/v0.1). 208 | You'll need to adapt the script to use your own checkpoint if required. 209 | 210 | ``` 211 | usage: segment.py [-h] in-dir out-dir 212 | 213 | Segment an audio dataset. 214 | 215 | positional arguments: 216 | in-dir path to the log probability directory. 217 | out-dir path to the output directory. 218 | 219 | options: 220 | -h, --help show this help message and exit 221 | ``` 222 | 223 | for example: 224 | 225 | ``` 226 | python segment.py path/to/p225/dev/logprobs path/to/p225/dev/segments/ 227 | ``` 228 | 229 | At this point the directory tree should look as follows: 230 | 231 | ``` 232 | p225 233 | ├── dev 234 | │   ├── wavs 235 | │   ├── soft 236 | │   ├── logprobs 237 | │   ├── segments 238 | ├── test 239 | │   ├── wavs 240 | │   ├── soft 241 | │   ├── logprobs 242 | │   ├── segments 243 | ├── train 244 | │   ├── wavs 245 | │   ├── soft 246 | │   ├── logprobs 247 | │   ├── segments 248 | ``` 249 | 250 | ### Step 5: Train the Rhythm Model 251 | 252 | Train the fine-grained or global rhythm model using the `train_rhythm_model.py` script: 253 | 254 | ``` 255 | usage: train_rhythm_model.py [-h] {fine,global} dataset-dir checkpoint-path 256 | 257 | Train the FineGrained or Global rhythm model. 258 | 259 | positional arguments: 260 | {fine,global} type of rhythm model (fine-grained or global). 261 | dataset-dir path to the directory of segmented speech. 262 | checkpoint-path path to save checkpoint. 263 | 264 | options: 265 | -h, --help show this help message and exit 266 | ``` 267 | 268 | for example: 269 | 270 | ``` 271 | python train_rhythm_model.py fine path/to/p225/train/segments path/to/checkpoints/rhythm-fine-p225.pt 272 | ``` 273 | 274 | ### Step 6: Train or Finetune the Vocoder 275 | 276 | Train or finetune the HiFiGAN vocoder. We recommend finetuning from the [LJSpeech checkpoint](https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-LJSpeech-ceb1368d.pt). 277 | 278 | 279 | ``` 280 | usage: train_vocoder.py [-h] [--resume RESUME] [--finetune | --no-finetune] dataset-dir checkpoint-dir 281 | 282 | Train or finetune the HiFiGAN vocoder. 283 | 284 | positional arguments: 285 | dataset-dir path to the preprocessed data directory 286 | checkpoint-dir path to the checkpoint directory 287 | 288 | options: 289 | -h, --help show this help message and exit 290 | --resume RESUME path to the checkpoint to resume from 291 | --finetune, --no-finetune 292 | whether to finetune 293 | 294 | ``` 295 | 296 | For example, to train from scratch: 297 | 298 | ``` 299 | python train_vocoder.py /path/to/p225 /path/to/checkpoints 300 | ``` 301 | 302 | To finetune, download the [LJSpeech checkpoint](https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-LJSpeech-ceb1368d.pt) and run: 303 | 304 | ``` 305 | python train_vocoder.py /path/to/p225 /path/to/checkpoints --resume hifigan-LJSpeech-ceb1368d.pt --finetune 306 | ``` 307 | 308 | ## Citation 309 | 310 | If you found this work helpful please consider citing our paper. 
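A BibTeX entry along the following lines should work (reconstructed from the arXiv listing; please double-check the author list and venue against https://arxiv.org/abs/2307.06040):

```bibtex
@article{vanniekerk2023rhythm,
  title   = {Rhythm Modeling for Voice Conversion},
  author  = {van Niekerk, Benjamin and Carbonneau, Marc-Andr{\'e} and Kamper, Herman},
  journal = {arXiv preprint arXiv:2307.06040},
  year    = {2023}
}
```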
311 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torchaudio 9 | import torchaudio.functional as AF 10 | 11 | logging.basicConfig(level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | SPEAKERS = ["p228", "p268", "p225", "p232", "p257", "p231", "LJSpeech"] 15 | 16 | 17 | def convert(args): 18 | logging.info("Loading HuBERT-Soft checkpoint") 19 | hubert = torch.hub.load("bshall/hubert:main", "hubert_soft", trust_repo=True).cuda() 20 | 21 | logging.info("Loading Urhythmic checkpoint") 22 | urhythmic, encode = torch.hub.load( 23 | "bshall/urhythmic:main", 24 | args.model, 25 | source_speaker=args.source, 26 | target_speaker=args.target, 27 | trust_repo=True, 28 | ) 29 | urhythmic.cuda() 30 | 31 | logging.info(f"Converting {args.in_dir} to {args.target}") 32 | for in_path in tqdm(list(args.in_dir.rglob(f"*{args.extension}"))): 33 | wav, sr = torchaudio.load(in_path) 34 | wav = AF.resample(wav, sr, 16000) 35 | wav = wav.unsqueeze(0).cuda() 36 | 37 | with torch.inference_mode(): 38 | units, log_probs = encode(hubert, wav) 39 | wav = urhythmic(units, log_probs) 40 | 41 | out_path = args.out_dir / in_path.relative_to(args.in_dir) 42 | out_path.parent.mkdir(parents=True, exist_ok=True) 43 | torchaudio.save( 44 | out_path.with_suffix(args.extension), wav.squeeze(0).cpu(), 16000 45 | ) 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser( 50 | description="Convert audio samples using Urhythmic." 51 | ) 52 | parser.add_argument( 53 | "model", 54 | help="available models (Urhythmic-Fine or Urhythmic-Global).", 55 | choices=["urhythmic_fine", "urhythmic_global"], 56 | ) 57 | parser.add_argument( 58 | "source", 59 | metavar="source-speaker", 60 | help=f"the source speaker: {', '.join(SPEAKERS)}", 61 | choices=SPEAKERS, 62 | ) 63 | parser.add_argument( 64 | "target", 65 | metavar="target-speaker", 66 | help=f"the target speaker: {', '.join(SPEAKERS)}", 67 | choices=SPEAKERS, 68 | ) 69 | parser.add_argument( 70 | "in_dir", 71 | metavar="in-dir", 72 | help="path to the dataset directory.", 73 | type=Path, 74 | ) 75 | parser.add_argument( 76 | "out_dir", 77 | metavar="out-dir", 78 | help="path to the output directory.", 79 | type=Path, 80 | ) 81 | parser.add_argument( 82 | "--extension", 83 | help="extension of the audio files (defaults to .wav).", 84 | default=".wav", 85 | type=str, 86 | ) 87 | args = parser.parse_args() 88 | convert(args) 89 | -------------------------------------------------------------------------------- /encode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torchaudio 10 | import torchaudio.functional as AF 11 | 12 | from urhythmic.model import encode 13 | 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def encode_dataset(args): 20 | logging.info("Loading hubert checkpoint") 21 | hubert = torch.hub.load("bshall/hubert:main", "hubert_soft").cuda() 22 | 23 | logging.info(f"Encoding dataset at {args.in_dir}") 24 | for in_path in tqdm(list(args.in_dir.rglob(f"*{args.extension}"))): 25 | wav, sr = torchaudio.load(in_path) 26 | if sr != 16000: 27 | 
raise ValueError(f"Sample rate: {sr} should be 16kHz.") 28 | wav = wav.unsqueeze(0).cuda() 29 | 30 | with torch.inference_mode(): 31 | units, log_probs = encode(hubert, wav) 32 | 33 | units_out_path = args.out_dir / "soft" / in_path.relative_to(args.in_dir) 34 | units_out_path.parent.mkdir(parents=True, exist_ok=True) 35 | np.save(units_out_path.with_suffix(".npy"), units.squeeze().cpu().numpy()) 36 | 37 | probs_out_path = args.out_dir / "logprobs" / in_path.relative_to(args.in_dir) 38 | probs_out_path.parent.mkdir(parents=True, exist_ok=True) 39 | np.save(probs_out_path.with_suffix(".npy"), log_probs.squeeze().cpu().numpy()) 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser( 44 | description="Encode an audio dataset into soft speech units and the log probabilities of the associated discrete units." 45 | ) 46 | parser.add_argument( 47 | "in_dir", 48 | metavar="in-dir", 49 | help="path to the dataset directory.", 50 | type=Path, 51 | ) 52 | parser.add_argument( 53 | "out_dir", 54 | metavar="out-dir", 55 | help="path to the output directory.", 56 | type=Path, 57 | ) 58 | parser.add_argument( 59 | "--extension", 60 | help="extension of the audio files (defaults to .wav).", 61 | default=".wav", 62 | type=str, 63 | ) 64 | args = parser.parse_args() 65 | encode_dataset(args) 66 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ["torch", "torchaudio", "numpy", "scipy", "numba", "sklearn"] 2 | 3 | URLS = { 4 | "segmenter-3": "https://github.com/bshall/urhythmic/releases/download/v0.1/segmenter-3-61beaeac.pt", 5 | "segmenter-8": "https://github.com/bshall/urhythmic/releases/download/v0.1/segmenter-8-b3d14f93.pt", 6 | "rhythm-model-fine-grained": "https://github.com/bshall/urhythmic/releases/download/v0.1/rhythm-fine-143621e1.pt", 7 | "rhythm-model-global": "https://github.com/bshall/urhythmic/releases/download/v0.1/rhythm-global-745d52d8.pt", 8 | "hifigan-p228": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p228-4ab1748f.pt", 9 | "hifigan-p268": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p268-36a1d51a.pt", 10 | "hifigan-p225": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p225-cc447edc.pt", 11 | "hifigan-p232": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p232-e0efc4c3.pt", 12 | "hifigan-p257": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p257-06fd495b.pt", 13 | "hifigan-p231": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p231-250198a1.pt", 14 | "hifigan-LJSpeech": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-LJSpeech-ceb1368d.pt", 15 | } 16 | 17 | SPEAKERS = {"p228", "p268", "p225", "p232", "p257", "p231", "LJSpeech"} 18 | 19 | from typing import Tuple, Callable 20 | 21 | import torch 22 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 23 | 24 | from urhythmic.model import UrhythmicFine, UrhythmicGlobal, encode 25 | from urhythmic.segmenter import Segmenter 26 | from urhythmic.rhythm import RhythmModelFineGrained, RhythmModelGlobal 27 | from urhythmic.stretcher import TimeStretcherFineGrained, TimeStretcherGlobal 28 | from urhythmic.vocoder import HifiganGenerator, HifiganDiscriminator 29 | 30 | 31 | def segmenter( 32 | num_clusters: int, 33 | gamma: float = 2, 34 | pretrained: bool = True, 35 | progress=True, 36 | ) -> Segmenter: 37 | 
"""Segmentation and clustering block. Groups similar speech units into short segments. 38 | The segments are then combined into coarser groups approximating sonorants, obstruents, and silences. 39 | 40 | Args: 41 | num_clusters (int): number of clusters used for agglomerative clustering. 42 | gamma (float): regularizer weight encouraging longer segments 43 | pretrained (bool): load pretrained weights into the model. 44 | progress (bool): show progress bar when downloading model. 45 | 46 | Returns: 47 | Segmenter: the segmentation and clustering block (optionally pretrained). 48 | """ 49 | segmenter = Segmenter(num_clusters=num_clusters, gamma=gamma) 50 | if pretrained: 51 | checkpoint = torch.hub.load_state_dict_from_url( 52 | URLS[f"segmenter-{num_clusters}"], 53 | progress=progress, 54 | ) 55 | segmenter.load_state_dict(checkpoint) 56 | return segmenter 57 | 58 | 59 | def rhythm_model_fine_grained( 60 | source_speaker: None | str, 61 | target_speaker: None | str, 62 | pretrained: bool = True, 63 | progress=True, 64 | ) -> RhythmModelFineGrained: 65 | """Rhythm modeling block (Fine-Grained). Estimates the duration distribution of each sound type. 66 | 67 | Available speakers: 68 | VCTK: p228, p268, p225, p232, p257, p231. 69 | LJSpeech. 70 | 71 | Args: 72 | source_speaker (None | str): the source speaker. None to fit your own source speaker or a selection from the available speakers. 73 | target_speaker (None | str): the target speaker. None to fit your own source speaker or a selection from the available speakers. 74 | pretrained (bool): load pretrained weights into the model. 75 | progress (bool): show progress bar when downloading model. 76 | 77 | Returns: 78 | RhythmModelFineGrained: the fine-grained rhythm modeling block (optionally preloaded with source and target duration models). 79 | """ 80 | if source_speaker is not None and source_speaker not in SPEAKERS: 81 | raise ValueError(f"source speaker is not in available set: {SPEAKERS}") 82 | if target_speaker is not None and target_speaker not in SPEAKERS: 83 | raise ValueError(f"target speaker is not in available set: {SPEAKERS}") 84 | 85 | rhythm_model = RhythmModelFineGrained() 86 | if pretrained: 87 | checkpoint = torch.hub.load_state_dict_from_url( 88 | URLS["rhythm-model-fine-grained"], 89 | progress=progress, 90 | ) 91 | state_dict = {} 92 | if target_speaker: 93 | state_dict["target"] = checkpoint[target_speaker] 94 | if source_speaker: 95 | state_dict["source"] = checkpoint[source_speaker] 96 | rhythm_model.load_state_dict(state_dict) 97 | return rhythm_model 98 | 99 | 100 | def rhythm_model_global( 101 | source_speaker: None | str, 102 | target_speaker: None | str, 103 | pretrained: bool = True, 104 | progress=True, 105 | ) -> RhythmModelGlobal: 106 | """Rhythm modeling block (Global). Estimates speaking rate. 107 | 108 | Available speakers: 109 | VCTK: p228, p268, p225, p232, p257, p231. 110 | LJSpeech. 111 | 112 | Args: 113 | source_speaker (None | str): the source speaker. None to fit your own source speaker or a selection from the available speakers. 114 | target_speaker (None | str): the target speaker. None to fit your own source speaker or a selection from the available speakers. 115 | pretrained (bool): load pretrained weights into the model. 116 | progress (bool): show progress bar when downloading model. 117 | 118 | Returns: 119 | RhythmModelGlobal: the global rhythm modeling block (optionally preloaded with source and target speaking rates). 
120 | """ 121 | if source_speaker is not None and source_speaker not in SPEAKERS: 122 | raise ValueError(f"source speaker is not in available set: {SPEAKERS}") 123 | if target_speaker is not None and target_speaker not in SPEAKERS: 124 | raise ValueError(f"target speaker is not in available set: {SPEAKERS}") 125 | 126 | rhythm_model = RhythmModelGlobal() 127 | if pretrained: 128 | checkpoint = torch.hub.load_state_dict_from_url( 129 | URLS["rhythm-model-global"], 130 | progress=progress, 131 | ) 132 | state_dict = {} 133 | if target_speaker: 134 | state_dict["target_rate"] = checkpoint[target_speaker] 135 | if source_speaker: 136 | state_dict["source_rate"] = checkpoint[source_speaker] 137 | rhythm_model.load_state_dict(state_dict) 138 | return rhythm_model 139 | 140 | 141 | def hifigan_generator( 142 | speaker: None | str, 143 | pretrained: bool = True, 144 | progress: bool = True, 145 | map_location=None, 146 | ) -> HifiganGenerator: 147 | """HifiGAN Generator. 148 | 149 | Available speakers: 150 | VCTK: p228, p268, p225, p232, p257, p231. 151 | LJSpeech. 152 | 153 | Args: 154 | speaker (None | str): the target speaker. None to fit your own speaker or a selection from the available speakers. 155 | pretrained (bool): load pretrained weights into the model. 156 | progress (bool): show progress bar when downloading model. 157 | map_location: function or a dict specifying how to remap storage locations (see torch.load) 158 | 159 | Returns: 160 | HifiganGenerator: the HifiGAN Generator (pretrained on LJSpeech or one of the VCTK speakers). 161 | """ 162 | if speaker is not None and speaker not in SPEAKERS: 163 | raise ValueError(f"target speaker is not in available set: {SPEAKERS}") 164 | 165 | hifigan = HifiganGenerator() 166 | if pretrained: 167 | checkpoint = torch.hub.load_state_dict_from_url( 168 | URLS[f"hifigan-{speaker}"], map_location=map_location, progress=progress 169 | ) 170 | consume_prefix_in_state_dict_if_present( 171 | checkpoint["generator"]["model"], "module." 172 | ) 173 | hifigan.load_state_dict(checkpoint["generator"]["model"]) 174 | hifigan.eval() 175 | hifigan.remove_weight_norm() 176 | return hifigan 177 | 178 | 179 | def hifigan_discriminator( 180 | pretrained: bool = True, progress: bool = True, map_location=None 181 | ) -> HifiganDiscriminator: 182 | """HifiGAN Discriminator. 183 | 184 | Args: 185 | pretrained (bool): load pretrained weights into the model. 186 | progress (bool): show progress bar when downloading model. 187 | map_location: function or a dict specifying how to remap storage locations (see torch.load) 188 | 189 | Returns: 190 | HifiganDiscriminator: the HifiGAN Discriminator (pretrained on LJSpeech). 191 | """ 192 | discriminator = HifiganDiscriminator() 193 | if pretrained: 194 | checkpoint = torch.hub.load_state_dict_from_url( 195 | URLS["hifigan-LJSpeech"], map_location=map_location, progress=progress 196 | ) 197 | consume_prefix_in_state_dict_if_present( 198 | checkpoint["discriminator"]["model"], "module." 199 | ) 200 | discriminator.load_state_dict(checkpoint["discriminator"]["model"]) 201 | discriminator.eval() 202 | return discriminator 203 | 204 | 205 | def urhythmic_fine( 206 | source_speaker: str | None, 207 | target_speaker: str | None, 208 | pretrained: bool = True, 209 | progress: bool = True, 210 | map_location=None, 211 | ) -> Tuple[UrhythmicFine, Callable]: 212 | """Urhythmic (Fine-Grained), a voice and rhythm conversion system that does not require text or parallel data. 
213 | 214 | Available speakers: 215 | VCTK: p228, p268, p225, p232, p257, p231. 216 | LJSpeech. 217 | 218 | Args: 219 | source_speaker (None | str): the source speaker. None to fit your own source speaker or a selection from the available speakers. 220 | target_speaker (None | str): the target speaker. None to fit your own target speaker or a selection from the available speakers. 221 | pretrained (bool): load pretrained weights into the model. 222 | progress (bool): show progress bar when downloading model. 223 | map_location: function or a dict specifying how to remap storage locations (see torch.load) 224 | 225 | Returns: 226 | UrhythmicFine: the Fine-Grained Urhythmic model. 227 | Callable: the encode function to extract soft speech units and log probabilities using HubertSoft. 228 | """ 229 | seg = segmenter(num_clusters=3, gamma=2, pretrained=pretrained, progress=progress) 230 | rhythm_model = rhythm_model_fine_grained( 231 | source_speaker=source_speaker, 232 | target_speaker=target_speaker, 233 | pretrained=pretrained, 234 | progress=progress, 235 | ) 236 | time_stretcher = TimeStretcherFineGrained() 237 | vocoder = hifigan_generator( 238 | speaker=target_speaker, 239 | pretrained=pretrained, 240 | progress=progress, 241 | map_location=map_location, 242 | ) 243 | return UrhythmicFine(seg, rhythm_model, time_stretcher, vocoder), encode 244 | 245 | 246 | def urhythmic_global( 247 | source_speaker: str | None, 248 | target_speaker: str | None, 249 | pretrained: bool = True, 250 | progress: bool = True, 251 | map_location=None, 252 | ) -> Tuple[UrhythmicGlobal, Callable]: 253 | """Urhythmic (Global), a voice and rhythm conversion system that does not require text or parallel data. 254 | 255 | Available speakers: 256 | VCTK: p228, p268, p225, p232, p257, p231. 257 | LJSpeech. 258 | 259 | Args: 260 | source_speaker (None | str): the source speaker. None to fit your own source speaker or a selection from the available speakers. 261 | target_speaker (None | str): the target speaker. None to fit your own target speaker or a selection from the available speakers. 262 | pretrained (bool): load pretrained weights into the model. 263 | progress (bool): show progress bar when downloading model. 264 | map_location: function or a dict specifying how to remap storage locations (see torch.load) 265 | 266 | Returns: 267 | UrhythmicGlobal: the Global Urhythmic model. 268 | Callable: the encode function to extract soft speech units and log probabilities using HubertSoft. 
269 | """ 270 | seg = segmenter(num_clusters=3, gamma=2, pretrained=pretrained, progress=progress) 271 | rhythm_model = rhythm_model_global( 272 | source_speaker=source_speaker, 273 | target_speaker=target_speaker, 274 | pretrained=pretrained, 275 | progress=progress, 276 | ) 277 | time_stretcher = TimeStretcherGlobal() 278 | vocoder = hifigan_generator( 279 | speaker=target_speaker, 280 | pretrained=pretrained, 281 | progress=progress, 282 | map_location=map_location, 283 | ) 284 | return UrhythmicGlobal(seg, rhythm_model, time_stretcher, vocoder), encode 285 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchaudio==2.0.2 3 | numpy==1.24.3 4 | scipy==1.10.1 5 | numba==0.57.0 6 | scikit-learn==1.2.2 7 | tqdm==4.65.0 8 | librosa==0.10.0.post2 9 | webrtcvad==2.0.10 10 | tensorboard==2.13.0 11 | matplotlib==3.7.1 -------------------------------------------------------------------------------- /resample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | 5 | from concurrent.futures import ProcessPoolExecutor 6 | from tqdm import tqdm 7 | 8 | import torchaudio 9 | import torchaudio.functional as AF 10 | import numpy as np 11 | import itertools 12 | 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def resample_file(path, sample_rate): 19 | wav, sr = torchaudio.load(path) 20 | wav = AF.resample(wav, sr, sample_rate) 21 | torchaudio.save(path, wav, sample_rate) 22 | return wav.size(-1) / sample_rate 23 | 24 | 25 | def resample_dataset(args): 26 | logger.info(f"Resampling dataset at {args.in_dir}") 27 | paths = list(args.in_dir.rglob("*.wav")) 28 | with ProcessPoolExecutor(max_workers=4) as executor: 29 | results = list( 30 | tqdm( 31 | executor.map(resample_file, paths, itertools.repeat(args.sample_rate)), 32 | total=len(paths), 33 | ) 34 | ) 35 | logger.info(f"Processed {np.sum(results) / 60 / 60:4f} hours of audio.") 36 | 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser(description="Resample an audio dataset.") 40 | parser.add_argument( 41 | "in_dir", 42 | metavar="in-dir", 43 | type=Path, 44 | help="path to dataset directory.", 45 | ) 46 | parser.add_argument( 47 | "--sample_rate", 48 | help="target sample rate (defaults to 16000).", 49 | type=int, 50 | default=16000, 51 | ) 52 | args = parser.parse_args() 53 | resample_dataset(args) 54 | -------------------------------------------------------------------------------- /segment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | 5 | from concurrent.futures import ProcessPoolExecutor 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import numpy as np 10 | import itertools 11 | 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def segment_file(segmenter, in_path, out_path): 18 | log_probs = np.load(in_path) 19 | segments, boundaries = segmenter(log_probs) 20 | np.savez(out_path.with_suffix(".npz"), segments=segments, boundaries=boundaries) 21 | return log_probs.shape[0], np.mean(np.diff(boundaries)) 22 | 23 | 24 | def segment_dataset(args): 25 | logging.info("Loading segmenter checkpoint") 26 | segmenter = torch.hub.load("bshall/urhythmic:main", 
"segmenter", num_clusters=3) 27 | 28 | in_paths = list(args.in_dir.rglob("*.npy")) 29 | out_paths = [args.out_dir / path.relative_to(args.in_dir) for path in in_paths] 30 | 31 | logger.info("Setting up folder structure") 32 | for path in tqdm(out_paths): 33 | path.parent.mkdir(exist_ok=True, parents=True) 34 | 35 | logger.info("Segmenting dataset") 36 | with ProcessPoolExecutor(max_workers=4) as executor: 37 | results = list( 38 | tqdm( 39 | executor.map( 40 | segment_file, 41 | itertools.repeat(segmenter), 42 | in_paths, 43 | out_paths, 44 | ), 45 | total=len(in_paths), 46 | ) 47 | ) 48 | 49 | frames, boundary_length = zip(*results) 50 | logger.info(f"Segmented {sum(frames) * 0.02 / 60 / 60:.2f} hours of audio") 51 | logger.info( 52 | f"Average segment length: {np.mean(boundary_length) * 0.02:.4f} seconds" 53 | ) 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser(description="Segment an audio dataset.") 58 | parser.add_argument( 59 | "in_dir", 60 | metavar="in-dir", 61 | type=Path, 62 | help="path to the log probability directory.", 63 | ) 64 | parser.add_argument( 65 | "out_dir", metavar="out-dir", type=Path, help="path to the output directory." 66 | ) 67 | args = parser.parse_args() 68 | segment_dataset(args) 69 | -------------------------------------------------------------------------------- /train_rhythm_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from urhythmic.rhythm import RhythmModelFineGrained, RhythmModelGlobal 10 | 11 | 12 | logging.basicConfig(level=logging.INFO) 13 | logger = logging.getLogger(__name__) 14 | 15 | HOP_LENGTH = 320 16 | SAMPLE_RATE = 16000 17 | 18 | 19 | def train_rhythm_model(args): 20 | logger.info(f"Training {args.model} rhythm model on {args.dataset_dir}") 21 | 22 | model_type = RhythmModelFineGrained if args.model == "fine" else RhythmModelGlobal 23 | rhythm_model = model_type(hop_length=HOP_LENGTH, sample_rate=SAMPLE_RATE) 24 | 25 | utterances = [] 26 | for path in tqdm(list(args.dataset_dir.rglob("*.npz"))): 27 | file = np.load(path, allow_pickle=True) 28 | segments = list(file["segments"]) 29 | boundaries = list(file["boundaries"]) 30 | utterances.append((segments, boundaries)) 31 | 32 | dists = rhythm_model._fit(utterances) 33 | 34 | logger.info(f"Saving checkpoint to {args.checkpoint_path}") 35 | 36 | torch.save(dists, args.checkpoint_path) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser( 41 | description="Train the FineGrained or Global rhythm model." 
42 | ) 43 | parser.add_argument( 44 | "model", 45 | help="type of rhythm model (fine-grained or global).", 46 | type=str, 47 | choices=["fine", "global"], 48 | ) 49 | parser.add_argument( 50 | "dataset_dir", 51 | metavar="dataset-dir", 52 | help="path to the directory of segmented speech.", 53 | type=Path, 54 | ) 55 | parser.add_argument( 56 | "checkpoint_path", 57 | metavar="checkpoint-path", 58 | help="path to save checkpoint.", 59 | type=Path, 60 | ) 61 | args = parser.parse_args() 62 | train_rhythm_model(args) 63 | -------------------------------------------------------------------------------- /train_segmenter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | 5 | import webrtcvad 6 | import struct 7 | import numpy as np 8 | import librosa 9 | from tqdm import tqdm 10 | 11 | import torchaudio 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | from urhythmic.segmenter import Segmenter 16 | 17 | logging.basicConfig(level=logging.INFO) 18 | logger = logging.getLogger(__name__) 19 | 20 | INT16_MAX = (2**15) - 1 21 | 22 | 23 | def mark_silences( 24 | vad: webrtcvad.Vad, 25 | wav: torch.Tensor, 26 | hop_length: int = 320, 27 | sample_rate: int = 16000, 28 | pad: int = 40, 29 | ): 30 | """Marks silent frames using webrtcvad. 31 | 32 | Args: 33 | vad (webrtcvad.Vad): instance of the webrtcvad.Vad class. 34 | wav (Tensor): audio waveform of shape (1, T) where T is the number of samples. 35 | hop_length (int): the hop length measured in number of frames (defaults to 320). 36 | sample_rate (int): the sample rate (defaults to 16kHz). 37 | pad (int): padding (defaults to 40) 38 | 39 | Returns: 40 | NDArray: array of booleans indicating whether each frame is silent. 
41 | """ 42 | win_length = hop_length 43 | 44 | wav = F.pad(wav, (pad, pad)) # add padding to match HuBERT 45 | wav = wav[:, : wav.size(-1) - (wav.size(-1) % win_length)] 46 | 47 | pcm = struct.pack( 48 | "%dh" % wav.size(-1), 49 | *(np.round(wav.squeeze().numpy() * INT16_MAX)).astype(np.int16), 50 | ) 51 | 52 | flags = [] 53 | for window_start in range(0, wav.size(-1), hop_length): 54 | window_end = window_start + win_length 55 | flag = vad.is_speech(pcm[window_start * 2 : window_end * 2], sample_rate) 56 | flags.append(flag) 57 | return ~np.array(flags) 58 | 59 | 60 | def mark_voiced( 61 | wav: torch.Tensor, 62 | hop_length: int = 320, 63 | win_length: int = 1024, 64 | sample_rate: int = 16000, 65 | ): 66 | _, voiced_flags, _ = librosa.pyin( 67 | wav.squeeze().numpy(), 68 | fmin=librosa.note_to_hz("C2"), 69 | fmax=librosa.note_to_hz("C5"), 70 | sr=sample_rate, 71 | hop_length=hop_length, 72 | win_length=win_length, 73 | ) 74 | return voiced_flags 75 | 76 | 77 | def train_segmenter(args): 78 | logger.info(f"Training Segmenter on {args.dataset_dir}") 79 | 80 | segmenter = Segmenter(num_clusters=3) 81 | checkpoints = torch.hub.load_state_dict_from_url( 82 | "https://github.com/bshall/hubert/releases/download/v0.2/kmeans100-50f36a95.pt" 83 | ) 84 | codebook = checkpoints["cluster_centers_"].numpy() 85 | segmenter.cluster(codebook) 86 | 87 | vad = webrtcvad.Vad(2) 88 | 89 | wavs_dir = args.dataset_dir / "wavs" 90 | logprobs_dir = args.dataset_dir / "logprobs" 91 | 92 | logger.info("Extracting VAD and voicing flags") 93 | 94 | utterances = [] 95 | for wav_path in tqdm(list(wavs_dir.rglob("*.wav"))): 96 | log_prob_path = logprobs_dir / wav_path.relative_to(wavs_dir) 97 | 98 | wav, _ = torchaudio.load(wav_path) 99 | log_probs = np.load(log_prob_path.with_suffix(".npy")) 100 | 101 | segments, boundaries = segmenter._segment(log_probs) 102 | silences = mark_silences(vad, wav) 103 | voiced_flags = mark_voiced(wav) 104 | 105 | utterances.append((segments, boundaries, silences, voiced_flags)) 106 | 107 | logger.info("Identifying the cluster id corresponding to each sound type") 108 | sound_types = segmenter.identify(utterances) 109 | 110 | logger.info(f"cluster 0 - {sound_types[0]}") 111 | logger.info(f"cluster 1 - {sound_types[1]}") 112 | logger.info(f"cluster 2 - {sound_types[2]}") 113 | 114 | logger.info(f"Saving checkpoint to {args.checkpoint_path}") 115 | args.checkpoint_path.parent.mkdir(exist_ok=True, parents=True) 116 | torch.save(segmenter.state_dict(), args.checkpoint_path) 117 | 118 | 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser( 121 | description="""Cluster the codebook of discrete speech units 122 | and identify the cluster id corresponding to sonorants, obstruents, and silences. 
123 | """ 124 | ) 125 | parser.add_argument( 126 | "dataset_dir", 127 | metavar="dataset-dir", 128 | help="path to the directory of segmented speech.", 129 | type=Path, 130 | ) 131 | parser.add_argument( 132 | "checkpoint_path", 133 | metavar="checkpoint-path", 134 | help="path to save checkpoint.", 135 | type=Path, 136 | ) 137 | args = parser.parse_args() 138 | train_segmenter(args) 139 | -------------------------------------------------------------------------------- /train_vocoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | import torch.distributed as dist 11 | from torch.utils.data.distributed import DistributedSampler 12 | import torch.multiprocessing as mp 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | 15 | from urhythmic.vocoder import ( 16 | HifiganGenerator, 17 | HifiganDiscriminator, 18 | feature_loss, 19 | discriminator_loss, 20 | generator_loss, 21 | ) 22 | from urhythmic.dataset import MelDataset, LogMelSpectrogram 23 | from urhythmic.utils import Metric, load_checkpoint, save_checkpoint 24 | 25 | import matplotlib 26 | 27 | matplotlib.use("Agg") 28 | import matplotlib.pylab as plt 29 | 30 | 31 | def plot_spectrogram(spectrogram): 32 | fig, ax = plt.subplots(figsize=(10, 2)) 33 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 34 | plt.colorbar(im, ax=ax) 35 | 36 | fig.canvas.draw() 37 | plt.close() 38 | 39 | return fig 40 | 41 | 42 | logging.basicConfig(level=logging.DEBUG) 43 | logger = logging.getLogger(__name__) 44 | 45 | 46 | BATCH_SIZE = 8 47 | SEGMENT_LENGTH = 8320 48 | HOP_LENGTH = 320 49 | SAMPLE_RATE = 16000 50 | BASE_LEARNING_RATE = 2e-4 51 | FINETUNE_LEARNING_RATE = 5e-5 52 | BETAS = (0.8, 0.99) 53 | LEARNING_RATE_DECAY = 0.999 54 | WEIGHT_DECAY = 1e-2 55 | STEPS = 3000000 56 | FINETUNE_STEPS = 50000 57 | LOG_INTERVAL = 25 58 | VALIDATION_INTERVAL = 1000 59 | NUM_GENERATED_EXAMPLES = 40 60 | CHECKPOINT_INTERVAL = 10000 61 | 62 | 63 | def train_model(rank, world_size, args): 64 | dist.init_process_group( 65 | "nccl", 66 | rank=rank, 67 | world_size=world_size, 68 | init_method="tcp://localhost:54321", 69 | ) 70 | 71 | log_dir = args.checkpoint_dir / "logs" 72 | log_dir.mkdir(exist_ok=True, parents=True) 73 | 74 | if rank == 0: 75 | logger.setLevel(logging.DEBUG) 76 | handler = logging.FileHandler(log_dir / f"{args.checkpoint_dir.stem}.log") 77 | handler.setLevel(logging.DEBUG) 78 | formatter = logging.Formatter( 79 | "%(asctime)s [%(levelname)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S" 80 | ) 81 | handler.setFormatter(formatter) 82 | logger.addHandler(handler) 83 | else: 84 | logger.setLevel(logging.ERROR) 85 | 86 | writer = SummaryWriter(log_dir) if rank == 0 else None 87 | 88 | generator = HifiganGenerator().to(rank) 89 | discriminator = HifiganDiscriminator().to(rank) 90 | 91 | generator = DDP(generator, device_ids=[rank]) 92 | discriminator = DDP(discriminator, device_ids=[rank]) 93 | 94 | optimizer_generator = optim.AdamW( 95 | generator.parameters(), 96 | lr=BASE_LEARNING_RATE if not args.finetune else FINETUNE_LEARNING_RATE, 97 | betas=BETAS, 98 | weight_decay=WEIGHT_DECAY, 99 | ) 100 | optimizer_discriminator = optim.AdamW( 101 | discriminator.parameters(), 102 | lr=BASE_LEARNING_RATE if not args.finetune 
else FINETUNE_LEARNING_RATE, 103 | betas=BETAS, 104 | weight_decay=WEIGHT_DECAY, 105 | ) 106 | 107 | scheduler_generator = optim.lr_scheduler.ExponentialLR( 108 | optimizer_generator, gamma=LEARNING_RATE_DECAY 109 | ) 110 | scheduler_discriminator = optim.lr_scheduler.ExponentialLR( 111 | optimizer_discriminator, gamma=LEARNING_RATE_DECAY 112 | ) 113 | 114 | train_dataset = MelDataset( 115 | root=args.dataset_dir, 116 | segment_length=SEGMENT_LENGTH, 117 | sample_rate=SAMPLE_RATE, 118 | hop_length=HOP_LENGTH, 119 | train=True, 120 | ) 121 | train_sampler = DistributedSampler(train_dataset, drop_last=True) 122 | train_loader = DataLoader( 123 | train_dataset, 124 | batch_size=BATCH_SIZE, 125 | sampler=train_sampler, 126 | num_workers=8, 127 | pin_memory=True, 128 | shuffle=False, 129 | drop_last=True, 130 | ) 131 | 132 | validation_dataset = MelDataset( 133 | root=args.dataset_dir, 134 | segment_length=SEGMENT_LENGTH, 135 | sample_rate=SAMPLE_RATE, 136 | hop_length=HOP_LENGTH, 137 | train=False, 138 | ) 139 | validation_loader = DataLoader( 140 | validation_dataset, 141 | batch_size=1, 142 | shuffle=False, 143 | num_workers=8, 144 | pin_memory=True, 145 | ) 146 | 147 | melspectrogram = LogMelSpectrogram().to(rank) 148 | 149 | if args.resume is not None: 150 | global_step, best_loss = load_checkpoint( 151 | load_path=args.resume, 152 | generator=generator, 153 | discriminator=discriminator, 154 | optimizer_generator=optimizer_generator, 155 | optimizer_discriminator=optimizer_discriminator, 156 | scheduler_generator=scheduler_generator, 157 | scheduler_discriminator=scheduler_discriminator, 158 | rank=rank, 159 | logger=logger, 160 | finetune=args.finetune, 161 | ) 162 | else: 163 | global_step, best_loss = 0, float("inf") 164 | 165 | n_epochs = (STEPS if not args.finetune else FINETUNE_STEPS) // len(train_loader) + 1 166 | start_epoch = global_step // len(train_loader) + 1 167 | 168 | logger.info("**" * 40) 169 | logger.info(f"batch size: {BATCH_SIZE}") 170 | logger.info(f"iterations per epoch: {len(train_loader)}") 171 | logger.info(f"total number of epochs: {n_epochs}") 172 | logger.info(f"started at epoch: {start_epoch}") 173 | logger.info("**" * 40 + "\n") 174 | 175 | for epoch in range(start_epoch, n_epochs + 1): 176 | train_sampler.set_epoch(epoch) 177 | 178 | generator.train() 179 | discriminator.train() 180 | average_loss_mel = Metric() 181 | average_loss_discriminator = Metric() 182 | average_loss_generator = Metric() 183 | average_validation_loss = Metric() 184 | for i, (wavs, units, tgts) in enumerate(train_loader, 1): 185 | wavs, units, tgts = (wavs.to(rank), units.to(rank), tgts.to(rank)) 186 | 187 | # Discriminator 188 | optimizer_discriminator.zero_grad() 189 | 190 | wavs_ = generator(units) 191 | mels_ = melspectrogram(wavs_) 192 | 193 | scores, _ = discriminator(wavs) 194 | scores_, _ = discriminator(wavs_.detach()) 195 | 196 | loss_discriminator, _, _ = discriminator_loss(scores, scores_) 197 | 198 | loss_discriminator.backward() 199 | optimizer_discriminator.step() 200 | 201 | # Generator 202 | optimizer_generator.zero_grad() 203 | 204 | scores, features = discriminator(wavs) 205 | scores_, features_ = discriminator(wavs_) 206 | 207 | loss_mel = F.l1_loss(mels_, tgts) 208 | loss_features = feature_loss(features, features_) 209 | loss_generator_adversarial, _ = generator_loss(scores_) 210 | loss_generator = ( 211 | 45 * loss_mel + 2 * loss_features + loss_generator_adversarial 212 | ) 213 | 214 | loss_generator.backward() 215 | optimizer_generator.step() 216 | 217 | 
global_step += 1 218 | 219 | average_loss_mel.update(loss_mel.item()) 220 | average_loss_discriminator.update(loss_discriminator.item()) 221 | average_loss_generator.update(loss_generator.item()) 222 | 223 | if rank == 0: 224 | if global_step % LOG_INTERVAL == 0: 225 | writer.add_scalar( 226 | "train/loss_mel", 227 | average_loss_mel.value, 228 | global_step, 229 | ) 230 | writer.add_scalar( 231 | "train/loss_discriminator", 232 | average_loss_discriminator.value, 233 | global_step, 234 | ) 235 | writer.add_scalar( 236 | "train/loss_generator", 237 | average_loss_generator.value, 238 | global_step, 239 | ) 240 | average_loss_mel.reset() 241 | average_loss_discriminator.reset() 242 | average_loss_generator.reset() 243 | 244 | if global_step % VALIDATION_INTERVAL == 0: 245 | generator.eval() 246 | 247 | average_validation_loss.reset() 248 | for j, (wavs, units, tgts) in enumerate(validation_loader, 1): 249 | wavs, units, tgts = (wavs.to(rank), units.to(rank), tgts.to(rank)) 250 | 251 | with torch.no_grad(): 252 | wavs_ = generator(units) 253 | mels_ = melspectrogram(wavs_) 254 | 255 | length = min(mels_.size(-1), tgts.size(-1)) 256 | 257 | loss_mel = F.l1_loss(mels_[..., :length], tgts[..., :length]) 258 | 259 | average_validation_loss.update(loss_mel.item()) 260 | 261 | if rank == 0: 262 | if j <= NUM_GENERATED_EXAMPLES: 263 | writer.add_audio( 264 | f"generated/wav_{j}", 265 | wavs_.squeeze(0), 266 | global_step, 267 | sample_rate=16000, 268 | ) 269 | writer.add_figure( 270 | f"generated/mel_{j}", 271 | plot_spectrogram(mels_.squeeze().cpu().numpy()), 272 | global_step, 273 | ) 274 | 275 | generator.train() 276 | discriminator.train() 277 | 278 | if rank == 0: 279 | writer.add_scalar( 280 | "validation/mel_loss", 281 | average_validation_loss.value, 282 | global_step, 283 | ) 284 | logger.info( 285 | f"valid -- epoch: {epoch}, mel loss: {average_validation_loss.value:.4f}" 286 | ) 287 | 288 | new_best = best_loss > average_validation_loss.value 289 | if new_best or global_step % CHECKPOINT_INTERVAL == 0: 290 | if new_best: 291 | logger.info("-------- new best model found!") 292 | best_loss = average_validation_loss.value 293 | 294 | if rank == 0: 295 | save_checkpoint( 296 | checkpoint_dir=args.checkpoint_dir, 297 | generator=generator, 298 | discriminator=discriminator, 299 | optimizer_generator=optimizer_generator, 300 | optimizer_discriminator=optimizer_discriminator, 301 | scheduler_generator=scheduler_generator, 302 | scheduler_discriminator=scheduler_discriminator, 303 | step=global_step, 304 | loss=average_validation_loss.value, 305 | best=new_best, 306 | logger=logger, 307 | ) 308 | 309 | scheduler_discriminator.step() 310 | scheduler_generator.step() 311 | 312 | logger.info(f"train -- epoch: {epoch}") 313 | 314 | dist.destroy_process_group() 315 | 316 | 317 | if __name__ == "__main__": 318 | parser = argparse.ArgumentParser( 319 | description="Train or finetune the HiFiGAN vocoder." 
320 | ) 321 | parser.add_argument( 322 | "dataset_dir", 323 | metavar="dataset-dir", 324 | help="path to the preprocessed data directory", 325 | type=Path, 326 | ) 327 | parser.add_argument( 328 | "checkpoint_dir", 329 | metavar="checkpoint-dir", 330 | help="path to the checkpoint directory", 331 | type=Path, 332 | ) 333 | parser.add_argument( 334 | "--resume", 335 | help="path to the checkpoint to resume from", 336 | type=Path, 337 | ) 338 | parser.add_argument( 339 | "--finetune", 340 | help="whether to finetune", 341 | action=argparse.BooleanOptionalAction, 342 | ) 343 | args = parser.parse_args() 344 | 345 | # display training setup info 346 | logger.info(f"PyTorch version: {torch.__version__}") 347 | logger.info(f"CUDA version: {torch.version.cuda}") 348 | logger.info(f"CUDNN version: {torch.backends.cudnn.version()}") 349 | logger.info(f"CUDNN enabled: {torch.backends.cudnn.enabled}") 350 | logger.info(f"CUDNN deterministic: {torch.backends.cudnn.deterministic}") 351 | logger.info(f"CUDNN benchmark: {torch.backends.cudnn.benchmark}") 352 | logger.info(f"# of GPUS: {torch.cuda.device_count()}") 353 | 354 | # clear handlers 355 | logger.handlers.clear() 356 | 357 | world_size = torch.cuda.device_count() 358 | mp.spawn( 359 | train_model, 360 | args=(world_size, args), 361 | nprocs=world_size, 362 | join=True, 363 | ) 364 | -------------------------------------------------------------------------------- /urhythmic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshall/urhythmic/976ba9a7e4ff56abeb1617a6324196920bc32661/urhythmic/__init__.py -------------------------------------------------------------------------------- /urhythmic/dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import math 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import Dataset 8 | 9 | import torchaudio 10 | import torchaudio.transforms as transforms 11 | 12 | 13 | class LogMelSpectrogram(torch.nn.Module): 14 | def __init__( 15 | self, 16 | sample_rate: int = 16000, 17 | n_fft: int = 1024, 18 | win_length: int = 1024, 19 | hop_length: int = 320, 20 | n_mels: int = 80, 21 | ): 22 | super().__init__() 23 | self.melspctrogram = transforms.MelSpectrogram( 24 | sample_rate=sample_rate, 25 | n_fft=n_fft, 26 | win_length=win_length, 27 | hop_length=hop_length, 28 | center=False, 29 | power=1.0, 30 | norm="slaney", 31 | n_mels=n_mels, 32 | mel_scale="slaney", 33 | ) 34 | self.pad = (win_length - hop_length) // 2 35 | 36 | def forward(self, wav: torch.Tensor) -> torch.Tensor: 37 | wav = F.pad(wav, (self.pad, self.pad), "reflect") 38 | mel = self.melspctrogram(wav) 39 | logmel = torch.log(torch.clamp(mel, min=1e-5)) 40 | return logmel 41 | 42 | 43 | class MelDataset(Dataset): 44 | def __init__( 45 | self, 46 | root: Path, 47 | segment_length: int, 48 | sample_rate: int, 49 | hop_length: int, 50 | train: bool = True, 51 | ): 52 | split = "train" if train else "dev" 53 | self.wavs_dir = root / split / "wavs" 54 | self.units_dir = root / split / "soft" 55 | 56 | self.segment_length = segment_length 57 | self.sample_rate = sample_rate 58 | self.hop_length = hop_length 59 | self.train = train 60 | 61 | self.metadata = [ 62 | path.relative_to(self.wavs_dir).with_suffix("") 63 | for path in self.wavs_dir.rglob("*.wav") 64 | ] 65 | 66 | self.logmel = LogMelSpectrogram() 67 | 68 | def __len__(self) -> int: 
69 | return len(self.metadata) 70 | 71 | def __getitem__( 72 | self, index: int 73 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 74 | path = self.metadata[index] 75 | wav_path = self.wavs_dir / path 76 | units_path = self.units_dir / path 77 | 78 | info = torchaudio.info(wav_path.with_suffix(".wav")) 79 | if info.sample_rate != self.sample_rate: 80 | raise ValueError( 81 | f"Sample rate {info.sample_rate} doesn't match target of {self.sample_rate}" 82 | ) 83 | 84 | units = torch.from_numpy(np.load(units_path.with_suffix(".npy"))) 85 | units = units.transpose(0, 1) 86 | 87 | units_frames_per_segment = math.floor(self.segment_length / self.hop_length) 88 | units_diff = units.size(0) - units_frames_per_segment if self.train else 0 89 | units_offset = random.randint(0, max(units_diff, 0)) 90 | 91 | frame_offset = self.hop_length * units_offset 92 | 93 | wav, _ = torchaudio.load( 94 | filepath=wav_path.with_suffix(".wav"), 95 | frame_offset=frame_offset if self.train else 0, 96 | num_frames=self.segment_length if self.train else -1, 97 | ) 98 | 99 | if wav.size(-1) < self.segment_length: 100 | wav = F.pad(wav, (0, self.segment_length - wav.size(-1))) 101 | 102 | tgt_logmel = self.logmel(wav.unsqueeze(0)).squeeze(0) 103 | 104 | if self.train: 105 | units = units[units_offset : units_offset + units_frames_per_segment, :] 106 | 107 | if units.size(0) < units_frames_per_segment: 108 | diff = units_frames_per_segment - units.size(0) 109 | units = F.pad(units, (0, 0, 0, diff), "constant", units.mean()) 110 | 111 | return wav, units.transpose(0, 1), tgt_logmel 112 | -------------------------------------------------------------------------------- /urhythmic/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from urhythmic.vocoder import HifiganGenerator 6 | from urhythmic.segmenter import Segmenter 7 | from urhythmic.stretcher import TimeStretcherFineGrained, TimeStretcherGlobal 8 | from urhythmic.rhythm import RhythmModelFineGrained, RhythmModelGlobal 9 | 10 | 11 | @torch.inference_mode() 12 | def encode(hubert, wav): 13 | r"""Encode an audio waveform into soft speech units and the log probabilities of the associated discrete units. 14 | 15 | Args: 16 | hubert (HubertSoft): the HubertSoft content encoder. 17 | wav (Tensor): an audio waveform of shape (B, 1, T) where B is the batch size and T is the number of samples. 18 | 19 | Returns: 20 | Tensor: soft speech units of shape (B, D, N) where N is the number of frames, and D is the unit dimensions. 21 | Tensor: the predicted log probabilities over the discrete units of shape (B, N, K) where K is the number of discrete units. 22 | """ 23 | units = hubert.units(wav) 24 | logits = hubert.logits(units) 25 | log_probs = F.log_softmax(logits, dim=-1) 26 | return units.transpose(1, 2), log_probs 27 | 28 | 29 | class UrhythmicFine(nn.Module): 30 | """Urhythmic (Fine-Grained), a voice and rhythm conversion system that does not require text or parallel data.""" 31 | 32 | def __init__( 33 | self, 34 | segmenter: Segmenter, 35 | rhythm_model: RhythmModelFineGrained, 36 | time_stretcher: TimeStretcherFineGrained, 37 | vocoder: HifiganGenerator, 38 | ): 39 | """ 40 | Args: 41 | segmenter (Segmenter): the segmentation and clustering block groups similar units into short segments. 42 | The segments are then combined into coarser groups approximating sonorants, obstruents, and silences. 
43 | rhythm_model (RhythmModelFineGrained): the rhythm modeling block estimates the duration distribution of each group. 44 | time_stretcher (TimeStretcherFineGrained): the time-stretching block down/up-samples the speech units to match the target rhythm. 45 | vocoder (HifiganGenerator): the vocoder converts the speech units into an audio waveform. 46 | """ 47 | super().__init__() 48 | self.segmenter = segmenter 49 | self.rhythm_model = rhythm_model 50 | self.time_stretcher = time_stretcher 51 | self.vocoder = vocoder 52 | 53 | @torch.inference_mode() 54 | def forward(self, units: torch.Tensor, log_probs: torch.Tensor) -> torch.Tensor: 55 | """Convert to the target speaker's voice and rhythm. 56 | 57 | Args: 58 | units (Tensor): soft speech units of shape (1, D, N) where D is the unit dimensions and N is the number of frames. 59 | log_probs (Tensor): the predicted log probabilities over the discrete units of shape (1, N, K) where K is the number of discrete units. 60 | 61 | Returns: 62 | Tensor: the converted waveform of shape (1, 1, T) where T is the number of samples. 63 | """ 64 | clusters, boundaries = self.segmenter(log_probs.squeeze().cpu().numpy()) 65 | tgt_durations = self.rhythm_model(clusters, boundaries) 66 | units = self.time_stretcher(units, clusters, boundaries, tgt_durations) 67 | wav = self.vocoder(units) 68 | return wav 69 | 70 | 71 | class UrhythmicGlobal(nn.Module): 72 | """Urhythmic (Global), a voice and rhythm conversion system that does not require text or parallel data. 73 | 74 | Args: 75 | segmenter (Segmenter): the segmentation and clustering block groups similar units into short segments. 76 | The segments are then combined into coarser groups approximating sonorants, obstruents, and silences. 77 | rhythm_model (RhythmModelGlobal): the rhythm modeling block estimates speaking rate. 78 | time_stretcher (TimeStretcherGlobal): the time-stretching block down/up-samples the speech units to match the target speaking rate. 79 | vocoder (HifiganGenerator): the vocoder converts the speech units into an audio waveform. 80 | """ 81 | 82 | def __init__( 83 | self, 84 | segmenter: Segmenter, 85 | rhythm_model: RhythmModelGlobal, 86 | time_stretcher: TimeStretcherGlobal, 87 | vocoder: HifiganGenerator, 88 | ): 89 | super().__init__() 90 | self.segmenter = segmenter 91 | self.rhythm_model = rhythm_model 92 | self.time_stretcher = time_stretcher 93 | self.vocoder = vocoder 94 | 95 | @torch.inference_mode() 96 | def forward(self, units: torch.Tensor, log_probs: torch.Tensor): 97 | """Convert to the target speaker's voice and rhythm. 98 | 99 | Args: 100 | units (Tensor): soft speech units of shape (1, D, N) where D is the unit dimensions and N is the number of frames. 101 | log_probs (Tensor): the predicted log probabilities over the discrete units of shape (1, N, K) where K is the number of discrete units. 102 | 103 | Returns: 104 | Tensor: the converted waveform of shape (1, 1, T) where T is the number of samples.
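        Example:
            A minimal sketch mirroring the demo notebook; the speaker pair, the example
            file name, and the use of CUDA are assumptions rather than part of this method.

            >>> import torch, torchaudio
            >>> hubert = torch.hub.load("bshall/hubert:main", "hubert_soft", trust_repo=True).cuda()
            >>> urhythmic, encode = torch.hub.load(
            ...     "bshall/urhythmic:main", "urhythmic_global",
            ...     source_speaker="p228", target_speaker="p232", trust_repo=True,
            ... )
            >>> urhythmic = urhythmic.cuda()
            >>> wav, sr = torchaudio.load("p228_003.wav")  # 16 kHz source utterance
            >>> units, log_probs = encode(hubert, wav.unsqueeze(0).cuda())
            >>> converted = urhythmic(units, log_probs)  # (1, 1, T) waveform at 16 kHz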
105 | """ 106 | ratio = self.rhythm_model() 107 | units = self.time_stretcher(units, ratio) 108 | wav = self.vocoder(units) 109 | return wav 110 | -------------------------------------------------------------------------------- /urhythmic/rhythm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict, Mapping, Any 2 | 3 | import numpy as np 4 | import itertools 5 | import scipy.stats as stats 6 | 7 | from urhythmic.utils import SONORANT, SILENCE, SoundType 8 | 9 | 10 | def transform( 11 | source: stats.rv_continuous, target: stats.rv_continuous, sample: float 12 | ) -> float: 13 | return target.ppf(source.cdf(sample)) 14 | 15 | 16 | def segment_rate( 17 | codes: List[SoundType], 18 | boundaries: List[int], 19 | sonorant: SoundType = SONORANT, 20 | silence: SoundType = SILENCE, 21 | unit_rate=0.02, 22 | ) -> float: 23 | times = np.round(np.array(boundaries) * unit_rate, 2) 24 | segments = [ 25 | (code, t0, tn) 26 | for code, (t0, tn) in zip(codes, itertools.pairwise(times)) 27 | if code not in silence 28 | ] 29 | return len([code for code, _, _ in segments if code in sonorant]) / sum( 30 | [tn - t0 for _, t0, tn in segments] 31 | ) 32 | 33 | 34 | class RhythmModelFineGrained: 35 | """Rhythm modeling block (Fine-Grained). Estimates the duration distribution of each sound type.""" 36 | 37 | def __init__(self, hop_length: int = 320, sample_rate: int = 16000): 38 | """ 39 | Args: 40 | hop_length (int): hop length between the frames of speech units. 41 | sample_rate (int): the sample rate of the audio waveforms. 42 | """ 43 | self.hop_rate = hop_length / sample_rate 44 | self.source = None 45 | self.target = None 46 | 47 | def _tally_durations( 48 | self, utterances: List[Tuple[List[SoundType], List[int]]] 49 | ) -> Dict[SoundType, np.ndarray]: 50 | durations_dict = {} 51 | for clusters, boundaries in utterances: 52 | durations = np.diff(boundaries) 53 | for cluster, duration in zip(clusters, durations): 54 | if ( 55 | cluster in SILENCE and duration <= 3 56 | ): # ignore silences that are too short 57 | continue 58 | durations_dict.setdefault(cluster, []).append(self.hop_rate * duration) 59 | return { 60 | cluster: np.array(durations) 61 | for cluster, durations in durations_dict.items() 62 | } 63 | 64 | def state_dict(self) -> Mapping[str, Mapping[SoundType, Tuple[float, ...]]]: 65 | state_dict = {} 66 | if self.source: 67 | state_dict["source"] = { 68 | cluster: (dist.args[0], dist.kwds["scale"]) 69 | for cluster, dist in self.source.items() 70 | } 71 | if self.target: 72 | state_dict["target"] = { 73 | cluster: (dist.args[0], dist.kwds["scale"]) 74 | for cluster, dist in self.target.items() 75 | } 76 | return state_dict 77 | 78 | def load_state_dict( 79 | self, state_dict: Mapping[str, Mapping[SoundType, Tuple[float, ...]]] 80 | ): 81 | if "source" in state_dict: 82 | self.source = { 83 | cluster: stats.gamma(a, scale=scale) 84 | for cluster, (a, _, scale) in state_dict["source"].items() 85 | } 86 | if "target" in state_dict: 87 | self.target = { 88 | cluster: stats.gamma(a, scale=scale) 89 | for cluster, (a, _, scale) in state_dict["target"].items() 90 | } 91 | 92 | def _fit( 93 | self, utterances: List[Tuple[List[SoundType], List[int]]] 94 | ) -> Mapping[SoundType, Tuple[float, ...]]: 95 | duration_tally = self._tally_durations(utterances) 96 | dists = { 97 | cluster: stats.gamma.fit(durations, floc=0) 98 | for cluster, durations in duration_tally.items() 99 | } 100 | return dists 101 | 102 | def fit_source(self, 
utterances: List[Tuple[List[SoundType], List[int]]]): 103 | """Fit the duration model for the source speaker. 104 | 105 | Args: 106 | utterances (List[Tuple[List[SoundType], List[int]]]): list of segmented utterances. 107 | """ 108 | source = self._fit(utterances) 109 | self.source = { 110 | cluster: stats.gamma(a, scale=scale) 111 | for cluster, (a, _, scale) in source.items() 112 | } 113 | 114 | def fit_target(self, utterances: List[Tuple[List[SoundType], List[int]]]): 115 | """Fit the duration model for the target speaker. 116 | 117 | Args: 118 | utterances (List[Tuple[List[SoundType], List[int]]]): list of segmented utterances. 119 | """ 120 | 121 | target = self._fit(utterances) 122 | self.target = { 123 | cluster: stats.gamma(a, scale=scale) 124 | for cluster, (a, _, scale) in target.items() 125 | } 126 | 127 | def __call__(self, clusters: List[SoundType], boundaries: List[int]) -> List[int]: 128 | """Transforms the source durations to match the target rhythm. 129 | 130 | Args: 131 | clusters (List[SoundType]): list of segmented sound types of shape (N,). 132 | boundaries (List[int]): list of segment boundaries of shape (N+1,). 133 | 134 | Returns: 135 | List[int]: list of target durations of shape (N,). 136 | """ 137 | durations = self.hop_rate * np.diff(boundaries) 138 | durations = [ 139 | transform(self.source[cluster], self.target[cluster], duration) 140 | for cluster, duration in zip(clusters, durations) 141 | if cluster not in SILENCE 142 | or duration > 3 * self.hop_rate # ignore silences that are too short 143 | ] 144 | durations = [round(duration / self.hop_rate) for duration in durations] 145 | return durations 146 | 147 | 148 | class RhythmModelGlobal: 149 | """Rhythm modeling block (Global). Estimates speaking rate.""" 150 | 151 | def __init__(self, hop_length: int = 320, sample_rate: int = 16000): 152 | """ 153 | Args: 154 | hop_length (int): hop length between the frames of speech units. 155 | sample_rate (int): the sample rate of the audio waveforms. 156 | """ 157 | self.hop_rate = hop_length / sample_rate 158 | self.source_rate = None 159 | self.target_rate = None 160 | 161 | def state_dict(self) -> Mapping[str, Any]: 162 | state_dict = {} 163 | if self.source_rate: 164 | state_dict["source_rate"] = self.source_rate 165 | if self.target_rate: 166 | state_dict["target_rate"] = self.target_rate 167 | return state_dict 168 | 169 | def load_state_dict(self, state_dict: Mapping[str, Any]): 170 | if "source_rate" in state_dict: 171 | self.source_rate = state_dict["source_rate"] 172 | if "target_rate" in state_dict: 173 | self.target_rate = state_dict["target_rate"] 174 | 175 | def _fit(self, utterances: List[Tuple[List[SoundType], List[int]]]) -> float: 176 | return np.mean( 177 | [ 178 | segment_rate(clusters, boundaries, SONORANT, SILENCE, self.hop_rate) 179 | for clusters, boundaries in utterances 180 | ] 181 | ) 182 | 183 | def fit_source(self, utterances: List[Tuple[List[SoundType], List[int]]]): 184 | """Estimate the speaking rate of the source speaker. 185 | 186 | Args: 187 | utterances (List[Tuple[List[SoundType], List[int]]]): list of segmented utterances. 188 | """ 189 | self.source_rate = self._fit(utterances) 190 | 191 | def fit_target(self, utterances: List[Tuple[List[SoundType], List[int]]]): 192 | """Estimate the speaking rate of the target speaker. 193 | 194 | Args: 195 | utterances (List[Tuple[List[SoundType], List[int]]]): list of segmented utterances.
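        Example:
            An illustrative sketch with a single hand-made utterance; in practice the sound
            types and frame boundaries come from the Segmenter, so the values below are
            hypothetical.

            >>> from urhythmic.utils import SILENCE, SONORANT, OBSTRUENT
            >>> rhythm = RhythmModelGlobal()  # 320-sample hop at 16 kHz
            >>> utterance = ([SILENCE, SONORANT, OBSTRUENT], [0, 10, 40, 50])
            >>> rhythm.fit_target([utterance])
            >>> rate = rhythm.target_rate  # sonorant segments per second of non-silent speech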
196 | """ 197 | self.target_rate = self._fit(utterances) 198 | 199 | def __call__(self) -> float: 200 | """ 201 | Returns: 202 | float: ratio between the source and target speaking rates. 203 | """ 204 | return self.source_rate / self.target_rate 205 | -------------------------------------------------------------------------------- /urhythmic/segmenter.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Any, Tuple, List 2 | from collections import Counter 3 | import itertools 4 | 5 | import torch 6 | 7 | import numpy as np 8 | from sklearn.cluster import AgglomerativeClustering 9 | import numba 10 | 11 | from urhythmic.utils import SoundType, SILENCE, SONORANT, OBSTRUENT 12 | 13 | 14 | class Segmenter: 15 | """Segmentation and clustering block. Groups similar speech units into short segments. 16 | The segments are then combined into coarser groups approximating sonorants, obstruents, and silences. 17 | """ 18 | 19 | def __init__(self, num_clusters: int = 3, gamma: float = 2): 20 | """ 21 | Args: 22 | num_clusters (int): number of clusters used for agglomerative clustering. 23 | gamma (float): regularizer weight encouraging longer segments 24 | """ 25 | self.gamma = gamma 26 | self.clustering = AgglomerativeClustering(n_clusters=num_clusters) 27 | self.sound_types = dict() 28 | 29 | def state_dict(self) -> Mapping[str, Any]: 30 | return { 31 | "n_clusters_": self.clustering.n_clusters_, 32 | "labels_": torch.from_numpy(self.clustering.labels_), 33 | "n_leaves_": self.clustering.n_leaves_, 34 | "n_features_in_": self.clustering.n_features_in_, 35 | "children_": torch.from_numpy(self.clustering.children_), 36 | "sound_types": self.sound_types, 37 | } 38 | 39 | def load_state_dict(self, state_dict: Mapping[str, Any]): 40 | if self.clustering.n_clusters != state_dict["n_clusters_"]: 41 | raise RuntimeError( 42 | "Error in loading state_dict for {}", self.__class__.__name__ 43 | ) 44 | self.clustering.labels_ = state_dict["labels_"].numpy() 45 | self.clustering.n_leaves_ = state_dict["n_leaves_"] 46 | self.clustering.n_features_in_ = state_dict["n_features_in_"] 47 | self.clustering.children_ = state_dict["children_"].numpy() 48 | self.sound_types = state_dict["sound_types"] 49 | 50 | def cluster(self, codebook: np.ndarray): 51 | """Fit the hierarchical clustering from the codebook of discrete units. 52 | 53 | Args: 54 | codebook (NDArray): codebook of discrete units of shape (K, D) 55 | where K is the number of units and D is the unit dimension. 56 | """ 57 | self.clustering.fit(codebook) 58 | 59 | def identify( 60 | self, 61 | utterances: List[Tuple[np.ndarray, ...]], 62 | ) -> Mapping[int, SoundType]: 63 | """Identify which clusters correspond to sonorants, obstruents, and silences. 64 | Only implemented for num_clusters = 3. 65 | 66 | Args: 67 | utterances (List[Tuple[np.ndarray, ...]]): list of segmented utterances along with marked silences and voiced frames. 68 | 69 | Returns: 70 | Mapping[int, SoundType]: mapping of cluster id to sonorant, obstruent, or silence. 71 | """ 72 | if self.clustering.n_clusters_ != 3: 73 | raise ValueError( 74 | "Cluster identification is only implemented for num_clusters = 3." 
75 | ) 76 | 77 | silence_overlap = Counter() 78 | voiced_overlap = Counter() 79 | total = Counter() 80 | 81 | for segments, boundaries, silences, voiced_flags in utterances: 82 | for code, (a, b) in zip(segments, itertools.pairwise(boundaries)): 83 | silence_overlap[code] += np.count_nonzero(silences[a : b + 1]) 84 | voiced_overlap[code] += np.count_nonzero(voiced_flags[a : b + 1]) 85 | total[code] += b - a + 1 86 | 87 | clusters = {0, 1, 2} 88 | 89 | silence, _ = max( 90 | [(k, v / total[k]) for k, v in silence_overlap.items()], key=lambda x: x[1] 91 | ) 92 | clusters.remove(silence) 93 | 94 | sonorant, _ = max( 95 | [(k, v / total[k]) for k, v in voiced_overlap.items() if k in clusters], 96 | key=lambda x: x[1], 97 | ) 98 | clusters.remove(sonorant) 99 | 100 | obstruent = clusters.pop() 101 | 102 | self.sound_types = { 103 | silence: SILENCE, 104 | sonorant: SONORANT, 105 | obstruent: OBSTRUENT, 106 | } 107 | return self.sound_types 108 | 109 | def _segment(self, log_probs: np.ndarray) -> Tuple[List[int], List[int]]: 110 | codes, boundaries = segment(log_probs, self.gamma) 111 | segments = codes[boundaries[:-1]] 112 | segments, boundaries = cluster_merge(self.clustering, segments, boundaries) 113 | return list(segments), list(boundaries) 114 | 115 | def __call__(self, log_probs: np.ndarray) -> Tuple[List[SoundType], List[int]]: 116 | """Segment the soft speech units into groups approximating the different sound types. 117 | 118 | Args: 119 | log_probs (NDArray): log probabilities of each discrete unit of shape (T, K) where T is the number of frames and K is the number of discrete units 120 | 121 | Returns: 122 | List[SoundType]: list of segmented sound types of shape (N,). 123 | List[int]: list of segment boundaries of shape (N+1,). 124 | """ 125 | segments, boundaries = self._segment(log_probs) 126 | segments = [self.sound_types[cluster] for cluster in segments] 127 | return segments, boundaries 128 | 129 | 130 | def segment(log_probs: np.ndarray, gamma: float) -> Tuple[np.ndarray, np.ndarray]: 131 | alpha, P = _segment(log_probs, gamma) 132 | return _backtrack(alpha, P) 133 | 134 | 135 | @numba.njit() 136 | def _backtrack(alpha, P): 137 | rhs = len(alpha) - 1 138 | segments = np.zeros(len(alpha) - 1, dtype=np.int32) 139 | boundaries = [rhs] 140 | while rhs != 0: 141 | lhs, code = P[rhs, :] 142 | boundaries.append(lhs) 143 | segments[lhs:rhs] = code 144 | rhs = lhs 145 | boundaries.reverse() 146 | return segments, np.array(boundaries) 147 | 148 | 149 | @numba.njit() 150 | def _segment(log_probs, gamma): 151 | T, K = log_probs.shape 152 | 153 | alpha = np.zeros(T + 1, dtype=np.float32) 154 | P = np.zeros((T + 1, 2), dtype=np.int32) 155 | D = np.zeros((T, T, K), dtype=np.float32) 156 | 157 | for t in range(T): 158 | for k in range(K): 159 | D[t, t, k] = log_probs[t, k] 160 | for t in range(T): 161 | for s in range(t + 1, T): 162 | D[t, s, :] = D[t, s - 1, :] + log_probs[s, :] 163 | 164 | for t in range(T): 165 | alpha[t + 1] = -np.inf 166 | for s in range(t + 1): 167 | k = np.argmax(D[t - s, t, :]) 168 | alpha_max = alpha[t - s] + D[t - s, t, k] + gamma * s 169 | if alpha_max > alpha[t + 1]: 170 | P[t + 1, :] = t - s, k 171 | alpha[t + 1] = alpha_max 172 | return alpha, P 173 | 174 | 175 | def cluster_merge( 176 | clustering: AgglomerativeClustering, segments: np.ndarray, boundaries: np.ndarray 177 | ) -> Tuple[np.ndarray, np.ndarray]: 178 | clusters = clustering.labels_[segments] 179 | cluster_switches = np.diff(clusters, prepend=-1, append=-1) 180 | (cluster_boundaries,) = 
np.nonzero(cluster_switches) 181 | clusters = clusters[cluster_boundaries[:-1]] 182 | cluster_boundaries = boundaries[cluster_boundaries] 183 | return clusters, cluster_boundaries 184 | -------------------------------------------------------------------------------- /urhythmic/stretcher.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import itertools 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from urhythmic.utils import SoundType, SILENCE 9 | 10 | 11 | class TimeStretcherFineGrained: 12 | """Time stretching block (Fine-Grained). Up/down samples the speech units to match the target rhythm.""" 13 | 14 | def __call__( 15 | self, 16 | units: torch.Tensor, 17 | clusters: List[SoundType], 18 | boundaries: List[int], 19 | tgt_durations: List[int], 20 | ) -> torch.Tensor: 21 | """ 22 | Args: 23 | units (Tensor): soft speech units of shape (1, D, T) 24 | where D is the dimension of the units and T is the number of frames. 25 | clusters (List[SoundType]): list of sound types for each segment of shape (N,) 26 | where N is the number of segments. 27 | boundaries (List[int]): list of segment boundaries of shape (N+1,). 28 | tgt_durations (List[int]): list of target durations of shape (N,). 29 | Returns: 30 | Tensor: up/down sampled soft speech units. 31 | """ 32 | units = [ 33 | units[..., t0:tn] 34 | for cluster, (t0, tn) in zip(clusters, itertools.pairwise(boundaries)) 35 | if cluster not in SILENCE or tn - t0 > 3 36 | ] 37 | units = [ 38 | F.interpolate(segment, mode="linear", size=duration) 39 | for segment, duration in zip(units, tgt_durations) 40 | ] 41 | units = torch.cat(units, dim=-1) 42 | return units 43 | 44 | 45 | class TimeStretcherGlobal: 46 | """Time stretching block (Global). Up/down samples the speech units to match the target speaking rate.""" 47 | 48 | def __call__(self, units: torch.Tensor, ratio: float) -> torch.Tensor: 49 | """ 50 | Args: 51 | units (Tensor): soft speech units of shape (1, D, T) 52 | where D is the dimension of the units and T is the number of frames. 53 | ratio (float): ratio between the source and target speaking rates. 54 | Returns: 55 | Tensor: up/down sampled soft speech units.
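        Example:
            A quick shape check; the unit dimension and frame count below are made up.

            >>> import torch
            >>> stretcher = TimeStretcherGlobal()
            >>> units = torch.randn(1, 256, 100)  # (1, D, T) soft speech units
            >>> stretcher(units, ratio=1.25).shape  # T is rescaled by the ratio
            torch.Size([1, 256, 125])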
56 | """ 57 | units = F.interpolate(units, scale_factor=ratio, mode="linear") 58 | return units 59 | -------------------------------------------------------------------------------- /urhythmic/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from enum import Flag, auto 3 | 4 | 5 | class SoundType(Flag): 6 | VOWEL = auto() 7 | APPROXIMANT = auto() 8 | NASAL = auto() 9 | FRICATIVE = auto() 10 | STOP = auto() 11 | SILENCE = auto() 12 | 13 | 14 | SONORANT = SoundType.VOWEL | SoundType.APPROXIMANT | SoundType.NASAL 15 | OBSTRUENT = SoundType.FRICATIVE | SoundType.STOP 16 | SILENCE = SoundType.SILENCE 17 | 18 | 19 | def get_padding(k, d): 20 | return int((k * d - d) / 2) 21 | 22 | 23 | class Metric: 24 | def __init__(self): 25 | self.steps = 0 26 | self.value = 0 27 | 28 | def update(self, value): 29 | self.steps += 1 30 | self.value += (value - self.value) / self.steps 31 | return self.value 32 | 33 | def reset(self): 34 | self.steps = 0 35 | self.value = 0 36 | 37 | 38 | def save_checkpoint( 39 | checkpoint_dir, 40 | generator, 41 | discriminator, 42 | optimizer_generator, 43 | optimizer_discriminator, 44 | scheduler_generator, 45 | scheduler_discriminator, 46 | step, 47 | loss, 48 | best, 49 | logger, 50 | ): 51 | state = { 52 | "generator": { 53 | "model": generator.state_dict(), 54 | "optimizer": optimizer_generator.state_dict(), 55 | "scheduler": scheduler_generator.state_dict(), 56 | }, 57 | "discriminator": { 58 | "model": discriminator.state_dict(), 59 | "optimizer": optimizer_discriminator.state_dict(), 60 | "scheduler": scheduler_discriminator.state_dict(), 61 | }, 62 | "step": step, 63 | "loss": loss, 64 | } 65 | checkpoint_dir.mkdir(exist_ok=True, parents=True) 66 | checkpoint_path = checkpoint_dir / f"model-{step}.pt" 67 | torch.save(state, checkpoint_path) 68 | if best: 69 | best_path = checkpoint_dir / "model-best.pt" 70 | torch.save(state, best_path) 71 | logger.info(f"Saved checkpoint: {checkpoint_path.stem}") 72 | 73 | 74 | def load_checkpoint( 75 | load_path, 76 | generator, 77 | discriminator, 78 | optimizer_generator, 79 | optimizer_discriminator, 80 | scheduler_generator, 81 | scheduler_discriminator, 82 | rank, 83 | logger, 84 | finetune=False, 85 | ): 86 | verb = "Resuming" if not finetune else "Finetuning" 87 | logger.info(f"{verb} checkpoint from {load_path}") 88 | checkpoint = torch.load(load_path, map_location={"cuda:0": f"cuda:{rank}"}) 89 | generator.load_state_dict(checkpoint["generator"]["model"]) 90 | discriminator.load_state_dict(checkpoint["discriminator"]["model"]) 91 | if not finetune: 92 | optimizer_generator.load_state_dict(checkpoint["generator"]["optimizer"]) 93 | scheduler_generator.load_state_dict(checkpoint["generator"]["scheduler"]) 94 | optimizer_discriminator.load_state_dict( 95 | checkpoint["discriminator"]["optimizer"] 96 | ) 97 | scheduler_discriminator.load_state_dict( 98 | checkpoint["discriminator"]["scheduler"] 99 | ) 100 | return checkpoint["step"], checkpoint["loss"] 101 | else: 102 | return 0, float("inf") 103 | -------------------------------------------------------------------------------- /urhythmic/vocoder.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/jik876/hifi-gan/blob/master/models.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils import remove_weight_norm, weight_norm 6 | from typing import Tuple, List 7 | 8 | from urhythmic.utils import 
get_padding 9 | 10 | 11 | LRELU_SLOPE = 0.1 12 | 13 | 14 | class HifiganGenerator(torch.nn.Module): 15 | """HiFiGAN Generator. Converts speech units into an audio waveform.""" 16 | 17 | def __init__( 18 | self, 19 | in_channels: int = 256, 20 | resblock_dilation_sizes: Tuple[Tuple[int, ...], ...] = ( 21 | (1, 3, 5), 22 | (1, 3, 5), 23 | (1, 3, 5), 24 | ), 25 | resblock_kernel_sizes: Tuple[int, ...] = (3, 7, 11), 26 | upsample_kernel_sizes: Tuple[int, ...] = (20, 16, 4, 4), 27 | upsample_channels: int = 512, 28 | upsample_factors: Tuple[int, ...] = (10, 8, 2, 2), 29 | sample_rate: int = 16000, 30 | ): 31 | """ 32 | Args: 33 | in_channels (int): number of input channels. 34 | resblock_dilation_sizes (Tuple[Tuple[int, ...], ...]): list of dilation values in each layer of a `ResBlock`. 35 | resblock_kernel_sizes (Tuple[int, ...]): list of kernel sizes for each `ResBlock`. 36 | upsample_kernel_sizes (Tuple[int, ...]): list of kernel sizes for each transposed convolution. 37 | upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 38 | for each consecutive upsampling layer. 39 | upsample_factors (Tuple[int, ...]): upsampling factors (stride) for each upsampling layer. 40 | sample_rate (int): sample rate of the generated audio. 41 | """ 42 | super().__init__() 43 | self.num_kernels = len(resblock_kernel_sizes) 44 | self.num_upsamples = len(upsample_factors) 45 | self.sample_rate = sample_rate 46 | 47 | self.conv_pre = weight_norm( 48 | nn.Conv1d(in_channels, upsample_channels, 5, 1, padding=2), 49 | ) 50 | 51 | # upsampling layers 52 | self.ups = nn.ModuleList() 53 | for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): 54 | self.ups.append( 55 | weight_norm( 56 | nn.ConvTranspose1d( 57 | upsample_channels // (2**i), 58 | upsample_channels // (2 ** (i + 1)), 59 | k, 60 | u, 61 | padding=(k - u) // 2, 62 | ) 63 | ) 64 | ) 65 | 66 | # MRF blocks 67 | self.resblocks = nn.ModuleList() 68 | for i in range(len(self.ups)): 69 | ch = upsample_channels // (2 ** (i + 1)) 70 | for _, (k, d) in enumerate( 71 | zip(resblock_kernel_sizes, resblock_dilation_sizes) 72 | ): 73 | self.resblocks.append(ResBlock(ch, k, d)) 74 | 75 | # post convolution layer 76 | self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3)) 77 | 78 | def forward(self, x: torch.Tensor) -> torch.Tensor: 79 | """ 80 | Args: 81 | x (Tensor): soft speech units of shape (B, D, N) where B is the batch size, D is the unit dimensions, and N is the number of frames. 82 | """ 83 | output = self.conv_pre(x) 84 | for i in range(self.num_upsamples): 85 | output = F.leaky_relu(output, LRELU_SLOPE) 86 | output = self.ups[i](output) 87 | z_sum = None 88 | for j in range(self.num_kernels): 89 | if z_sum is None: 90 | z_sum = self.resblocks[i * self.num_kernels + j](output) 91 | else: 92 | z_sum += self.resblocks[i * self.num_kernels + j](output) 93 | output = z_sum / self.num_kernels 94 | output = F.leaky_relu(output) 95 | output = self.conv_post(output) 96 | output = torch.tanh(output) 97 | return output 98 | 99 | def remove_weight_norm(self): 100 | for l in self.ups: 101 | remove_weight_norm(l) 102 | for l in self.resblocks: 103 | l.remove_weight_norm() 104 | remove_weight_norm(self.conv_pre) 105 | remove_weight_norm(self.conv_post) 106 | 107 | 108 | class ResBlock(torch.nn.Module): 109 | def __init__( 110 | self, 111 | channels: int, 112 | kernel_size: int = 3, 113 | dilation: Tuple[int, ...] 
= (1, 3, 5), 114 | ) -> None: 115 | super().__init__() 116 | self.convs1 = nn.ModuleList( 117 | [ 118 | weight_norm( 119 | nn.Conv1d( 120 | channels, 121 | channels, 122 | kernel_size, 123 | 1, 124 | dilation=dilation[0], 125 | padding=get_padding(kernel_size, dilation[0]), 126 | ) 127 | ), 128 | weight_norm( 129 | nn.Conv1d( 130 | channels, 131 | channels, 132 | kernel_size, 133 | 1, 134 | dilation=dilation[1], 135 | padding=get_padding(kernel_size, dilation[1]), 136 | ) 137 | ), 138 | weight_norm( 139 | nn.Conv1d( 140 | channels, 141 | channels, 142 | kernel_size, 143 | 1, 144 | dilation=dilation[2], 145 | padding=get_padding(kernel_size, dilation[2]), 146 | ) 147 | ), 148 | ] 149 | ) 150 | 151 | self.convs2 = nn.ModuleList( 152 | [ 153 | weight_norm( 154 | nn.Conv1d( 155 | channels, 156 | channels, 157 | kernel_size, 158 | 1, 159 | dilation=1, 160 | padding=get_padding(kernel_size, 1), 161 | ) 162 | ), 163 | weight_norm( 164 | nn.Conv1d( 165 | channels, 166 | channels, 167 | kernel_size, 168 | 1, 169 | dilation=1, 170 | padding=get_padding(kernel_size, 1), 171 | ) 172 | ), 173 | weight_norm( 174 | nn.Conv1d( 175 | channels, 176 | channels, 177 | kernel_size, 178 | 1, 179 | dilation=1, 180 | padding=get_padding(kernel_size, 1), 181 | ) 182 | ), 183 | ] 184 | ) 185 | 186 | def forward(self, x: torch.Tensor) -> torch.Tensor: 187 | for c1, c2 in zip(self.convs1, self.convs2): 188 | xt = F.leaky_relu(x, LRELU_SLOPE) 189 | xt = c1(xt) 190 | xt = F.leaky_relu(xt, LRELU_SLOPE) 191 | xt = c2(xt) 192 | x = xt + x 193 | return x 194 | 195 | def remove_weight_norm(self): 196 | for l in self.convs1: 197 | remove_weight_norm(l) 198 | for l in self.convs2: 199 | remove_weight_norm(l) 200 | 201 | 202 | class PeriodDiscriminator(torch.nn.Module): 203 | """HiFiGAN Period Discriminator""" 204 | 205 | def __init__( 206 | self, 207 | period: int, 208 | kernel_size: int = 5, 209 | stride: int = 3, 210 | use_spectral_norm: bool = False, 211 | ) -> None: 212 | super().__init__() 213 | self.period = period 214 | norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm 215 | self.convs = nn.ModuleList( 216 | [ 217 | norm_f( 218 | nn.Conv2d( 219 | 1, 220 | 32, 221 | (kernel_size, 1), 222 | (stride, 1), 223 | padding=(get_padding(5, 1), 0), 224 | ) 225 | ), 226 | norm_f( 227 | nn.Conv2d( 228 | 32, 229 | 128, 230 | (kernel_size, 1), 231 | (stride, 1), 232 | padding=(get_padding(5, 1), 0), 233 | ) 234 | ), 235 | norm_f( 236 | nn.Conv2d( 237 | 128, 238 | 512, 239 | (kernel_size, 1), 240 | (stride, 1), 241 | padding=(get_padding(5, 1), 0), 242 | ) 243 | ), 244 | norm_f( 245 | nn.Conv2d( 246 | 512, 247 | 1024, 248 | (kernel_size, 1), 249 | (stride, 1), 250 | padding=(get_padding(5, 1), 0), 251 | ) 252 | ), 253 | norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 254 | ] 255 | ) 256 | self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 257 | 258 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: 259 | """ 260 | Args: 261 | x (Tensor): input waveform. 262 | Returns: 263 | [Tensor]: discriminator scores per sample in the batch. 264 | [List[Tensor]]: list of features from each convolutional layer. 
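        Example:
            A sanity check on a dummy batch; the 1 second length is arbitrary.

            >>> import torch
            >>> disc = PeriodDiscriminator(period=2)
            >>> score, feat = disc(torch.randn(1, 1, 16000))
            >>> len(feat)  # five strided convolutions plus the post convolution
            6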
265 | """ 266 | feat = [] 267 | 268 | # 1d to 2d 269 | b, c, t = x.shape 270 | if t % self.period != 0: # pad first 271 | n_pad = self.period - (t % self.period) 272 | x = F.pad(x, (0, n_pad), "reflect") 273 | t = t + n_pad 274 | x = x.view(b, c, t // self.period, self.period) 275 | 276 | for l in self.convs: 277 | x = l(x) 278 | x = F.leaky_relu(x, LRELU_SLOPE) 279 | feat.append(x) 280 | x = self.conv_post(x) 281 | feat.append(x) 282 | x = torch.flatten(x, 1, -1) 283 | 284 | return x, feat 285 | 286 | 287 | class MultiPeriodDiscriminator(torch.nn.Module): 288 | """HiFiGAN Multi-Period Discriminator (MPD)""" 289 | 290 | def __init__(self): 291 | super().__init__() 292 | self.discriminators = nn.ModuleList( 293 | [ 294 | PeriodDiscriminator(2), 295 | PeriodDiscriminator(3), 296 | PeriodDiscriminator(5), 297 | PeriodDiscriminator(7), 298 | PeriodDiscriminator(11), 299 | ] 300 | ) 301 | 302 | def forward( 303 | self, x: torch.Tensor 304 | ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: 305 | """ 306 | Args: 307 | x (Tensor): input waveform. 308 | Returns: 309 | [List[Tensor]]: list of scores from each discriminator. 310 | [List[List[Tensor]]]: list of features from each discriminator's convolutional layers. 311 | """ 312 | scores = [] 313 | feats = [] 314 | for _, d in enumerate(self.discriminators): 315 | score, feat = d(x) 316 | scores.append(score) 317 | feats.append(feat) 318 | return scores, feats 319 | 320 | 321 | class ScaleDiscriminator(torch.nn.Module): 322 | """HiFiGAN Scale Discriminator.""" 323 | 324 | def __init__(self, use_spectral_norm: bool = False) -> None: 325 | super().__init__() 326 | norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm 327 | self.convs = nn.ModuleList( 328 | [ 329 | norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)), 330 | norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)), 331 | norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)), 332 | norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)), 333 | norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 334 | norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 335 | norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), 336 | ] 337 | ) 338 | self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) 339 | 340 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: 341 | """ 342 | Args: 343 | x (Tensor): input waveform. 344 | Returns: 345 | Tensor: discriminator scores. 346 | List[Tensor]: list of features from the convolutional layers. 347 | """ 348 | feat = [] 349 | for l in self.convs: 350 | x = l(x) 351 | x = F.leaky_relu(x, LRELU_SLOPE) 352 | feat.append(x) 353 | x = self.conv_post(x) 354 | feat.append(x) 355 | x = torch.flatten(x, 1, -1) 356 | return x, feat 357 | 358 | 359 | class MultiScaleDiscriminator(torch.nn.Module): 360 | """HiFiGAN Multi-Scale Discriminator.""" 361 | 362 | def __init__(self): 363 | super().__init__() 364 | self.discriminators = nn.ModuleList( 365 | [ 366 | ScaleDiscriminator(use_spectral_norm=True), 367 | ScaleDiscriminator(), 368 | ScaleDiscriminator(), 369 | ] 370 | ) 371 | self.meanpools = nn.ModuleList( 372 | [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)] 373 | ) 374 | 375 | def forward( 376 | self, x: torch.Tensor 377 | ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: 378 | """ 379 | Args: 380 | x (Tensor): input waveform. 381 | Returns: 382 | List[Tensor]: discriminator scores. 383 | List[List[Tensor]]: list of features from each discriminator's convolutional layers. 
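        Example:
            One score and one feature list per scale; the input below is random noise.

            >>> import torch
            >>> msd = MultiScaleDiscriminator()
            >>> scores, feats = msd(torch.randn(1, 1, 16000))
            >>> len(scores), len(feats)
            (3, 3)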
384 | """ 385 | scores = [] 386 | feats = [] 387 | for i, d in enumerate(self.discriminators): 388 | if i != 0: 389 | x = self.meanpools[i - 1](x) 390 | score, feat = d(x) 391 | scores.append(score) 392 | feats.append(feat) 393 | return scores, feats 394 | 395 | 396 | class HifiganDiscriminator(nn.Module): 397 | """HiFiGAN discriminator""" 398 | 399 | def __init__(self): 400 | super().__init__() 401 | self.mpd = MultiPeriodDiscriminator() 402 | self.msd = MultiScaleDiscriminator() 403 | 404 | def forward( 405 | self, x: torch.Tensor 406 | ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: 407 | """ 408 | Args: 409 | x (Tensor): input waveform. 410 | Returns: 411 | List[Tensor]: discriminator scores. 412 | List[List[Tensor]]: list of features from from each discriminator's convolutional layers. 413 | """ 414 | scores, feats = self.mpd(x) 415 | scores_, feats_ = self.msd(x) 416 | return scores + scores_, feats + feats_ 417 | 418 | 419 | def feature_loss( 420 | features_real: List[List[torch.Tensor]], 421 | features_generated: List[List[torch.Tensor]], 422 | ) -> float: 423 | loss = 0 424 | for r, g in zip(features_real, features_generated): 425 | for rl, gl in zip(r, g): 426 | loss += torch.mean(torch.abs(rl - gl)) 427 | return loss 428 | 429 | 430 | def discriminator_loss( 431 | real: List[torch.Tensor], generated: List[torch.Tensor] 432 | ) -> Tuple[torch.Tensor, List[float], List[float]]: 433 | loss = 0 434 | real_losses = [] 435 | generated_losses = [] 436 | for r, g in zip(real, generated): 437 | r_loss = torch.mean((1 - r) ** 2) 438 | g_loss = torch.mean(g**2) 439 | loss += r_loss + g_loss 440 | real_losses.append(r_loss.item()) 441 | generated_losses.append(g_loss.item()) 442 | 443 | return loss, real_losses, generated_losses 444 | 445 | 446 | def generator_loss( 447 | discriminator_outputs: List[torch.Tensor], 448 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]: 449 | loss = 0 450 | generator_losses = [] 451 | for x in discriminator_outputs: 452 | l = torch.mean((1 - x) ** 2) 453 | generator_losses.append(l) 454 | loss += l 455 | 456 | return loss, generator_losses 457 | -------------------------------------------------------------------------------- /urhythmic_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "I0SMrN7KVoVO" 17 | }, 18 | "source": [ 19 | "#Urhythmic: Rhythm Modeling for Voice Conversion\n", 20 | "\n", 21 | "Demo for the paper: [Rhythm Modeling for Voice Conversion]().\n", 22 | "\n", 23 | "* [Companion webpage](https://ubisoft-laforge.github.io/speech/urhythmic/)\n", 24 | "* [Code repository](https://github.com/bshall/urhythmic)\n", 25 | "* [HuBERT content encoder](https://github.com/bshall/hubert)" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import torch, torchaudio\n", 35 | "import requests\n", 36 | "import IPython.display as display" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "id": "Y5GUYFKHWRfs" 43 | }, 44 | "source": [ 45 | "Download the `HubertSoft` content encoder (see https://github.com/bshall/hubert for details):" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | 
"source": [ 54 | "hubert = torch.hub.load(\"bshall/hubert:main\", \"hubert_soft\", trust_repo=True).cuda()" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "id": "m2FvNyEZXSl4" 61 | }, 62 | "source": [ 63 | " Select the source and target speakers. Pretrained models are available for:\n", 64 | "1. VCTK: p228, p268, p225, p232, p257, p231\n", 65 | "2. and LJSpeech.\n", 66 | "\n" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "source, target = \"p228\", \"p232\"" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": { 81 | "id": "QTnaU8u5W-2L" 82 | }, 83 | "source": [ 84 | "Download the `Urhythmic` voice conversion mode (either urhythmic_fine or urhythmic_global):" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "urhythmic, encode = torch.hub.load(\"bshall/urhythmic:main\", \"urhythmic_fine\", source_speaker=source, target_speaker=target, trust_repo=True)\n", 94 | "urhythmic = urhythmic.cuda()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": { 100 | "id": "ElGokuPViBng" 101 | }, 102 | "source": [ 103 | "Download an example utterance:" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "with open(\"p228_003.wav\", \"wb\") as file:\n", 113 | " response = requests.get(\"https://github.com/bshall/urhythmic/raw/gh-pages/samples/urhythmic-fine/source/p228_003.wav\")\n", 114 | " file.write(response.content)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": { 120 | "id": "dANIeGxH4JRv" 121 | }, 122 | "source": [ 123 | "Load the audio file:" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "wav, sr = torchaudio.load(\"p228_003.wav\")\n", 133 | "wav = wav.unsqueeze(0).cuda()" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": { 139 | "id": "Pw7W1kad4Nbm" 140 | }, 141 | "source": [ 142 | "Extract the soft speech units and log probabilies:" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "units, log_probs = encode(hubert, wav)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": { 157 | "id": "Wxxf-5Vt4VfT" 158 | }, 159 | "source": [ 160 | "Convert to the target speaker:" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "wav_ = urhythmic(units, log_probs)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": { 175 | "id": "gY7tKZkP4btg" 176 | }, 177 | "source": [ 178 | "Listen to the result!" 
179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "display.Audio(wav.squeeze().cpu().numpy(), rate=16000) # source" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "id": "3-p5dmxXmOdc" 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "display.Audio(wav_.squeeze().cpu().numpy(), rate=16000) # converted" 199 | ] 200 | } 201 | ], 202 | "metadata": { 203 | "accelerator": "GPU", 204 | "colab": { 205 | "authorship_tag": "ABX9TyOTVEzGSobOMIxcq+ibHeEp", 206 | "gpuType": "T4", 207 | "include_colab_link": true, 208 | "provenance": [] 209 | }, 210 | "kernelspec": { 211 | "display_name": "Python 3", 212 | "name": "python3" 213 | }, 214 | "language_info": { 215 | "name": "python" 216 | } 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 0 220 | } 221 | --------------------------------------------------------------------------------
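A note on the feature settings used above: with a hop of 320 samples at 16 kHz, the log-Mel
frames computed in urhythmic/dataset.py line up one-to-one with the soft speech units
(16000 / 320 = 50 frames per second). The snippet below is a quick sanity check of that
relationship; the 1 second random waveform is a placeholder, not project data:

    import torch
    from urhythmic.dataset import LogMelSpectrogram

    logmel = LogMelSpectrogram()    # win_length=1024, hop_length=320, n_mels=80
    wav = torch.randn(1, 1, 16000)  # (B, 1, T) dummy waveform, 1 second at 16 kHz
    mel = logmel(wav)               # reflect-padded by (1024 - 320) // 2 samples per side
    print(mel.shape)                # torch.Size([1, 1, 80, 50]) -> one frame per 320 samples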