├── .env.example ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── codecov.yml ├── dependabot.yml └── release-drafter.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .project-root ├── .pylintrc ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── configs ├── __init__.py ├── callbacks │ ├── default.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ ├── none.yaml │ └── rich_progress_bar.yaml ├── data │ ├── hi-fi_en-US_female.yaml │ ├── ljspeech.yaml │ └── vctk.yaml ├── debug │ ├── default.yaml │ ├── fdr.yaml │ ├── limit.yaml │ ├── overfit.yaml │ └── profiler.yaml ├── eval.yaml ├── experiment │ ├── hifi_dataset_piper_phonemizer.yaml │ ├── ljspeech.yaml │ ├── ljspeech_from_durations.yaml │ ├── ljspeech_min_memory.yaml │ └── multispeaker.yaml ├── extras │ └── default.yaml ├── hparams_search │ └── mnist_optuna.yaml ├── hydra │ └── default.yaml ├── local │ └── .gitkeep ├── logger │ ├── aim.yaml │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ ├── cfm │ │ └── default.yaml │ ├── decoder │ │ └── default.yaml │ ├── encoder │ │ └── default.yaml │ ├── matcha.yaml │ └── optimizer │ │ └── adam.yaml ├── paths │ └── default.yaml ├── train.yaml └── trainer │ ├── cpu.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ ├── default.yaml │ ├── gpu.yaml │ └── mps.yaml ├── matcha ├── VERSION ├── __init__.py ├── app.py ├── cli.py ├── data │ ├── __init__.py │ ├── components │ │ └── __init__.py │ └── text_mel_datamodule.py ├── hifigan │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── config.py │ ├── denoiser.py │ ├── env.py │ ├── meldataset.py │ ├── models.py │ └── xutils.py ├── models │ ├── __init__.py │ ├── baselightningmodule.py │ ├── components │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── flow_matching.py │ │ ├── text_encoder.py │ │ └── transformer.py │ └── matcha_tts.py ├── onnx │ ├── __init__.py │ ├── export.py │ └── infer.py ├── text │ ├── __init__.py │ ├── cleaners.py │ ├── numbers.py │ └── symbols.py ├── train.py └── utils │ ├── __init__.py │ ├── audio.py │ ├── data │ ├── __init__.py │ ├── hificaptain.py │ ├── ljspeech.py │ └── utils.py │ ├── generate_data_statistics.py │ ├── get_durations_from_trained_model.py │ ├── instantiators.py │ ├── logging_utils.py │ ├── model.py │ ├── monotonic_align │ ├── __init__.py │ ├── core.pyx │ └── setup.py │ ├── pylogger.py │ ├── rich_utils.py │ └── utils.py ├── notebooks └── .gitkeep ├── pyproject.toml ├── requirements.txt ├── scripts └── schedule.sh ├── setup.py └── synthesis.ipynb /.env.example: -------------------------------------------------------------------------------- 1 | # example of file for storing private and user specific environment variables, like keys or system paths 2 | # rename it to ".env" (excluded from version control by default) 3 | # .env is loaded by train.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | MY_VAR="/home/user/my/system/path" 7 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | 9 | 10 | Fixes #\ 11 | 12 | ## Before submitting 13 | 14 | - [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? 15 | - [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? 
16 | - [ ] Did you list all the **breaking changes** introduced by this pull request? 17 | - [ ] Did you **test your PR locally** with `pytest` command? 18 | - [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command? 19 | 20 | ## Did you have fun? 21 | 22 | Make sure you had fun coding 🙃 23 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | # measures overall project coverage 4 | project: 5 | default: 6 | threshold: 100% # how much decrease in coverage is needed to not consider success 7 | 8 | # measures PR or single commit coverage 9 | patch: 10 | default: 11 | threshold: 100% # how much decrease in coverage is needed to not consider success 12 | 13 | 14 | # project: off 15 | # patch: off 16 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | target-branch: "dev" 11 | schedule: 12 | interval: "daily" 13 | ignore: 14 | - dependency-name: "pytorch-lightning" 15 | update-types: ["version-update:semver-patch"] 16 | - dependency-name: "torchmetrics" 17 | update-types: ["version-update:semver-patch"] 18 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "v$RESOLVED_VERSION" 2 | tag-template: "v$RESOLVED_VERSION" 3 | 4 | categories: 5 | - title: "🚀 Features" 6 | labels: 7 | - "feature" 8 | - "enhancement" 9 | - title: "🐛 Bug Fixes" 10 | labels: 11 | - "fix" 12 | - "bugfix" 13 | - "bug" 14 | - title: "🧹 Maintenance" 15 | labels: 16 | - "maintenance" 17 | - "dependencies" 18 | - "refactoring" 19 | - "cosmetic" 20 | - "chore" 21 | - title: "📝️ Documentation" 22 | labels: 23 | - "documentation" 24 | - "docs" 25 | 26 | change-template: "- $TITLE @$AUTHOR (#$NUMBER)" 27 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions 28 | 29 | version-resolver: 30 | major: 31 | labels: 32 | - "major" 33 | minor: 34 | labels: 35 | - "minor" 36 | patch: 37 | labels: 38 | - "patch" 39 | default: patch 40 | 41 | template: | 42 | ## Changes 43 | 44 | $CHANGES 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script 
from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # Lightning-Hydra-Template 148 | configs/local/default.yaml 149 | /data/ 150 | /logs/ 151 | .env 152 | 153 | # Aim logging 154 | .aim 155 | 156 | # Cython complied files 157 | matcha/utils/monotonic_align/core.c 158 | 159 | # Ignoring hifigan checkpoint 160 | generator_v1 161 | g_02500000 162 | gradio_cached_examples/ 163 | synth_output/ 164 | /data 165 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.11 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.5.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | # - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-toml 16 | - id: check-case-conflict 17 | - id: check-added-large-files 18 | 19 | # python code formatting 20 | - repo: https://github.com/psf/black 21 | rev: 23.12.1 22 | hooks: 23 | - id: black 24 | args: [--line-length, "120"] 25 | 26 | # python import sorting 27 | - repo: https://github.com/PyCQA/isort 28 | rev: 5.13.2 29 | hooks: 30 | - id: isort 31 | args: 
["--profile", "black", "--filter-files"] 32 | 33 | # python upgrading syntax to newer version 34 | - repo: https://github.com/asottile/pyupgrade 35 | rev: v3.15.0 36 | hooks: 37 | - id: pyupgrade 38 | args: [--py38-plus] 39 | 40 | # python check (PEP8), programming errors and code complexity 41 | - repo: https://github.com/PyCQA/flake8 42 | rev: 7.0.0 43 | hooks: 44 | - id: flake8 45 | args: 46 | [ 47 | "--max-line-length", "120", 48 | "--extend-ignore", 49 | "E203,E402,E501,F401,F841,RST2,RST301", 50 | "--exclude", 51 | "logs/*,data/*,matcha/hifigan/*", 52 | ] 53 | additional_dependencies: [flake8-rst-docstrings==0.3.0] 54 | 55 | # pylint 56 | - repo: https://github.com/pycqa/pylint 57 | rev: v3.0.3 58 | hooks: 59 | - id: pylint 60 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shivam Mehta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE.txt 3 | include requirements.*.txt 4 | include *.cff 5 | include requirements.txt 6 | include matcha/VERSION 7 | recursive-include matcha *.json 8 | recursive-include matcha *.html 9 | recursive-include matcha *.png 10 | recursive-include matcha *.md 11 | recursive-include matcha *.py 12 | recursive-include matcha *.pyx 13 | recursive-exclude tests * 14 | prune tests* 15 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . 
| grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | create-package: ## Create wheel and tar gz 17 | rm -rf dist/ 18 | python setup.py bdist_wheel --plat-name=manylinux1_x86_64 19 | python setup.py sdist 20 | python -m twine upload dist/* --verbose --skip-existing 21 | 22 | format: ## Run pre-commit hooks 23 | pre-commit run -a 24 | 25 | sync: ## Merge changes from main branch to your current branch 26 | git pull 27 | git pull origin main 28 | 29 | test: ## Run not slow tests 30 | pytest -k "not slow" 31 | 32 | test-full: ## Run all tests 33 | pytest 34 | 35 | train-ljspeech: ## Train the model 36 | python matcha/train.py experiment=ljspeech 37 | 38 | train-ljspeech-min: ## Train the model with minimum memory 39 | python matcha/train.py experiment=ljspeech_min_memory 40 | 41 | start_app: ## Start the app 42 | python matcha/app.py 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching 4 | 5 | ### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/) 6 | 7 | [![python](https://img.shields.io/badge/-Python_3.10-blue?logo=python&logoColor=white)](https://www.python.org/downloads/release/python-3100/) 8 | [![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) 9 | [![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/) 10 | [![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/) 11 | [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) 12 | [![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) 13 | 14 |

19 | 20 | > This is the official code implementation of 🍵 Matcha-TTS [ICASSP 2024]. 21 | 22 | We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses [conditional flow matching](https://arxiv.org/abs/2210.02747) (similar to [rectified flows](https://arxiv.org/abs/2209.03003)) to speed up ODE-based speech synthesis. Our method: 23 | 24 | - Is probabilistic 25 | - Has compact memory footprint 26 | - Sounds highly natural 27 | - Is very fast to synthesise from 28 | 29 | Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS) and read [our ICASSP 2024 paper](https://arxiv.org/abs/2309.03199) for more details. 30 | 31 | [Pre-trained models](https://drive.google.com/drive/folders/17C_gYgEHOxI5ZypcfE_k1piKCtyR0isJ?usp=sharing) will be automatically downloaded with the CLI or gradio interface. 32 | 33 | You can also [try 🍵 Matcha-TTS in your browser on HuggingFace 🤗 spaces](https://huggingface.co/spaces/shivammehta25/Matcha-TTS). 34 | 35 | ## Teaser video 36 | 37 | [![Watch the video](https://img.youtube.com/vi/xmvJkz3bqw0/hqdefault.jpg)](https://youtu.be/xmvJkz3bqw0) 38 | 39 | ## Installation 40 | 41 | 1. Create an environment (suggested but optional) 42 | 43 | ``` 44 | conda create -n matcha-tts python=3.10 -y 45 | conda activate matcha-tts 46 | ``` 47 | 48 | 2. Install Matcha TTS using pip or from source 49 | 50 | ```bash 51 | pip install matcha-tts 52 | ``` 53 | 54 | from source 55 | 56 | ```bash 57 | pip install git+https://github.com/shivammehta25/Matcha-TTS.git 58 | cd Matcha-TTS 59 | pip install -e . 60 | ``` 61 | 62 | 3. Run CLI / gradio app / jupyter notebook 63 | 64 | ```bash 65 | # This will download the required models 66 | matcha-tts --text "" 67 | ``` 68 | 69 | or 70 | 71 | ```bash 72 | matcha-tts-app 73 | ``` 74 | 75 | or open `synthesis.ipynb` on jupyter notebook 76 | 77 | ### CLI Arguments 78 | 79 | - To synthesise from given text, run: 80 | 81 | ```bash 82 | matcha-tts --text "" 83 | ``` 84 | 85 | - To synthesise from a file, run: 86 | 87 | ```bash 88 | matcha-tts --file 89 | ``` 90 | 91 | - To batch synthesise from a file, run: 92 | 93 | ```bash 94 | matcha-tts --file --batched 95 | ``` 96 | 97 | Additional arguments 98 | 99 | - Speaking rate 100 | 101 | ```bash 102 | matcha-tts --text "" --speaking_rate 1.0 103 | ``` 104 | 105 | - Sampling temperature 106 | 107 | ```bash 108 | matcha-tts --text "" --temperature 0.667 109 | ``` 110 | 111 | - Euler ODE solver steps 112 | 113 | ```bash 114 | matcha-tts --text "" --steps 10 115 | ``` 116 | 117 | ## Train with your own dataset 118 | 119 | Let's assume we are training with LJ Speech 120 | 121 | 1. Download the dataset from [here](https://keithito.com/LJ-Speech-Dataset/), extract it to `data/LJSpeech-1.1`, and prepare the file lists to point to the extracted data like for [item 5 in the setup of the NVIDIA Tacotron 2 repo](https://github.com/NVIDIA/tacotron2#setup). 122 | 123 | 2. Clone and enter the Matcha-TTS repository 124 | 125 | ```bash 126 | git clone https://github.com/shivammehta25/Matcha-TTS.git 127 | cd Matcha-TTS 128 | ``` 129 | 130 | 3. Install the package from source 131 | 132 | ```bash 133 | pip install -e . 134 | ``` 135 | 136 | 4. Go to `configs/data/ljspeech.yaml` and change 137 | 138 | ```yaml 139 | train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt 140 | valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt 141 | ``` 142 | 143 | 5. 
Generate normalisation statistics with the yaml file of dataset configuration 144 | 145 | ```bash 146 | matcha-data-stats -i ljspeech.yaml 147 | # Output: 148 | #{'mel_mean': -5.53662231756592, 'mel_std': 2.1161014277038574} 149 | ``` 150 | 151 | Update these values in `configs/data/ljspeech.yaml` under `data_statistics` key. 152 | 153 | ```bash 154 | data_statistics: # Computed for ljspeech dataset 155 | mel_mean: -5.536622 156 | mel_std: 2.116101 157 | ``` 158 | 159 | to the paths of your train and validation filelists. 160 | 161 | 6. Run the training script 162 | 163 | ```bash 164 | make train-ljspeech 165 | ``` 166 | 167 | or 168 | 169 | ```bash 170 | python matcha/train.py experiment=ljspeech 171 | ``` 172 | 173 | - for a minimum memory run 174 | 175 | ```bash 176 | python matcha/train.py experiment=ljspeech_min_memory 177 | ``` 178 | 179 | - for multi-gpu training, run 180 | 181 | ```bash 182 | python matcha/train.py experiment=ljspeech trainer.devices=[0,1] 183 | ``` 184 | 185 | 7. Synthesise from the custom trained model 186 | 187 | ```bash 188 | matcha-tts --text "" --checkpoint_path 189 | ``` 190 | 191 | ## ONNX support 192 | 193 | > Special thanks to [@mush42](https://github.com/mush42) for implementing ONNX export and inference support. 194 | 195 | It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph. 196 | 197 | ### ONNX export 198 | 199 | To export a checkpoint to ONNX, first install ONNX with 200 | 201 | ```bash 202 | pip install onnx 203 | ``` 204 | 205 | then run the following: 206 | 207 | ```bash 208 | python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5 209 | ``` 210 | 211 | Optionally, the ONNX exporter accepts **vocoder-name** and **vocoder-checkpoint** arguments. This enables you to embed the vocoder in the exported graph and generate waveforms in a single run (similar to end-to-end TTS systems). 212 | 213 | **Note** that `n_timesteps` is treated as a hyper-parameter rather than a model input. This means you should specify it during export (not during inference). If not specified, `n_timesteps` is set to **5**. 214 | 215 | **Important**: for now, torch>=2.1.0 is needed for export since the `scaled_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release. 216 | 217 | ### ONNX Inference 218 | 219 | To run inference on the exported model, first install `onnxruntime` using 220 | 221 | ```bash 222 | pip install onnxruntime 223 | pip install onnxruntime-gpu # for GPU inference 224 | ``` 225 | 226 | then use the following: 227 | 228 | ```bash 229 | python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs 230 | ``` 231 | 232 | You can also control synthesis parameters: 233 | 234 | ```bash 235 | python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --temperature 0.4 --speaking_rate 0.9 --spk 0 236 | ``` 237 | 238 | To run inference on **GPU**, make sure to install **onnxruntime-gpu** package, and then pass `--gpu` to the inference command: 239 | 240 | ```bash 241 | python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --gpu 242 | ``` 243 | 244 | If you exported only Matcha to ONNX, this will write mel-spectrogram as graphs and `numpy` arrays to the output directory. 245 | If you embedded the vocoder in the exported graph, this will write `.wav` audio files to the output directory. 
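If you would rather drive the exported graph from Python yourself instead of using the bundled CLI, a minimal `onnxruntime` sketch could look like the one below. The input names (`x`, `x_lengths`, `scales`) and the layout of `scales` are assumptions about the exported graph, not something this README guarantees; verify them with `session.get_inputs()` for your export.

```python
# Minimal sketch of direct onnxruntime inference on an exported Matcha graph.
# Assumption: the graph exposes inputs named "x", "x_lengths" and "scales";
# check session.get_inputs() / session.get_outputs() for your export.
import numpy as np
import onnxruntime as ort

from matcha.text import text_to_sequence
from matcha.utils.utils import intersperse

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Convert text to phoneme ids the same way the training datamodule does (add_blank=True).
phoneme_ids, _ = text_to_sequence("Hey there!", ["english_cleaners2"])
phoneme_ids = intersperse(phoneme_ids, 0)

x = np.array([phoneme_ids], dtype=np.int64)
x_lengths = np.array([x.shape[1]], dtype=np.int64)
scales = np.array([0.667, 1.0], dtype=np.float32)  # assumed order: [temperature, speaking_rate]

outputs = session.run(None, {"x": x, "x_lengths": x_lengths, "scales": scales})
mel_or_wav = outputs[0]  # mel-spectrogram, or a waveform if the vocoder was embedded
print(mel_or_wav.shape)
```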
246 | 247 | If you exported only Matcha to ONNX, and you want to run a full TTS pipeline, you can pass a path to a vocoder model in `ONNX` format: 248 | 249 | ```bash 250 | python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vocoder hifigan.small.onnx 251 | ``` 252 | 253 | This will write `.wav` audio files to the output directory. 254 | 255 | ## Extract phoneme alignments from Matcha-TTS 256 | 257 | If the dataset is structured as 258 | 259 | ```bash 260 | data/ 261 | └── LJSpeech-1.1 262 | ├── metadata.csv 263 | ├── README 264 | ├── test.txt 265 | ├── train.txt 266 | ├── val.txt 267 | └── wavs 268 | ``` 269 | then you can extract the phoneme-level alignments from a trained Matcha-TTS model using: 270 | ```bash 271 | python matcha/utils/get_durations_from_trained_model.py -i dataset_yaml -c 272 | ``` 273 | Example: 274 | ```bash 275 | python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c matcha_ljspeech.ckpt 276 | ``` 277 | or simply: 278 | ```bash 279 | matcha-tts-get-durations -i ljspeech.yaml -c matcha_ljspeech.ckpt 280 | ``` 281 | --- 282 | ## Train using extracted alignments 283 | 284 | In the dataset config, turn on `load_durations`. 285 | Example: `ljspeech.yaml` 286 | ```yaml 287 | load_durations: True 288 | ``` 289 | or see an example in `configs/experiment/ljspeech_from_durations.yaml` 290 | 291 | 292 | ## Citation information 293 | 294 | If you use our code or otherwise find this work useful, please cite our paper: 295 | 296 | ```text 297 | @inproceedings{mehta2024matcha, 298 | title={Matcha-{TTS}: A fast {TTS} architecture with conditional flow matching}, 299 | author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje}, 300 | booktitle={Proc. ICASSP}, 301 | year={2024} 302 | } 303 | ``` 304 | 305 | ## Acknowledgements 306 | 307 | Since this code uses [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), you have all the powers that come with it.
308 | 309 | Other source code we would like to acknowledge: 310 | 311 | - [Coqui-TTS](https://github.com/coqui-ai/TTS/tree/dev): For helping me figure out how to make cython binaries pip installable and encouragement 312 | - [Hugging Face Diffusers](https://huggingface.co/): For their awesome diffusers library and its components 313 | - [Grad-TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS): For the monotonic alignment search source code 314 | - [torchdyn](https://github.com/DiffEqML/torchdyn): Useful for trying other ODE solvers during research and development 315 | - [labml.ai](https://nn.labml.ai/transformers/rope/index.html): For the RoPE implementation 316 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - model_summary.yaml 4 | - rich_progress_bar.yaml 5 | - _self_ 6 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: ${paths.output_dir}/checkpoints # directory to save the model file 6 | filename: checkpoint_{epoch:03d} # checkpoint filename 7 | monitor: epoch # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 10 # save k best models (determined by above metric) 11 | mode: "max" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: 100 # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 3 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/configs/callbacks/none.yaml 
-------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /configs/data/hi-fi_en-US_female.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ljspeech 3 | - _self_ 4 | 5 | # Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/ 6 | _target_: matcha.data.text_mel_datamodule.TextMelDataModule 7 | name: hi-fi_en-US_female 8 | train_filelist_path: data/hi-fi_en-US_female/train.txt 9 | valid_filelist_path: data/hi-fi_en-US_female/val.txt 10 | batch_size: 32 11 | cleaners: [english_cleaners_piper] 12 | data_statistics: # Computed for this dataset 13 | mel_mean: -6.38385 14 | mel_std: 2.541796 15 | -------------------------------------------------------------------------------- /configs/data/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | _target_: matcha.data.text_mel_datamodule.TextMelDataModule 2 | name: ljspeech 3 | train_filelist_path: data/LJSpeech-1.1/train.txt 4 | valid_filelist_path: data/LJSpeech-1.1/val.txt 5 | batch_size: 32 6 | num_workers: 20 7 | pin_memory: True 8 | cleaners: [english_cleaners2] 9 | add_blank: True 10 | n_spks: 1 11 | n_fft: 1024 12 | n_feats: 80 13 | sample_rate: 22050 14 | hop_length: 256 15 | win_length: 1024 16 | f_min: 0 17 | f_max: 8000 18 | data_statistics: # Computed for ljspeech dataset 19 | mel_mean: -5.536622 20 | mel_std: 2.116101 21 | seed: ${seed} 22 | load_durations: false 23 | -------------------------------------------------------------------------------- /configs/data/vctk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ljspeech 3 | - _self_ 4 | 5 | _target_: matcha.data.text_mel_datamodule.TextMelDataModule 6 | name: vctk 7 | train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt 8 | valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt 9 | batch_size: 32 10 | add_blank: True 11 | n_spks: 109 12 | data_statistics: # Computed for vctk dataset 13 | mel_mean: -6.630575 14 | mel_std: 2.482914 15 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | # callbacks: null 11 | # logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise 
exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 1 10 | # profiler: "simple" 11 | profiler: "advanced" 12 | # profiler: "pytorch" 13 | accelerator: gpu 14 | 15 | limit_train_batches: 0.02 16 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: mnist # choose datamodule with `test_dataloader()` for evaluation 6 | - model: mnist 7 | - logger: null 8 | - trainer: default 9 | - paths: default 10 | - extras: default 11 | - hydra: default 12 | 13 | task_name: "eval" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for evaluation 18 | ckpt_path: ??? 
19 | -------------------------------------------------------------------------------- /configs/experiment/hifi_dataset_piper_phonemizer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: hi-fi_en-US_female.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"] 13 | 14 | run_name: hi-fi_en-US_female_piper_phonemizer 15 | -------------------------------------------------------------------------------- /configs/experiment/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech 15 | -------------------------------------------------------------------------------- /configs/experiment/ljspeech_from_durations.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech 15 | 16 | 17 | data: 18 | load_durations: True 19 | batch_size: 64 20 | -------------------------------------------------------------------------------- /configs/experiment/ljspeech_min_memory.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech_min 15 | 16 | 17 | model: 18 | out_size: 172 19 | -------------------------------------------------------------------------------- /configs/experiment/multispeaker.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: vctk.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["multispeaker"] 13 | 14 | run_name: multispeaker 15 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: 
True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | data.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log 20 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/logger/aim.yaml: 
-------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group runs under experiment name 15 | experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet 5 | - csv 6 | # - mlflow 7 | # - neptune 8 | - tensorboard 9 | - wandb 10 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | 
neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "lightning-hydra-template" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/model/cfm/default.yaml: -------------------------------------------------------------------------------- 1 | name: CFM 2 | solver: euler 3 | sigma_min: 1e-4 4 | -------------------------------------------------------------------------------- /configs/model/decoder/default.yaml: -------------------------------------------------------------------------------- 1 | channels: [256, 256] 2 | dropout: 0.05 3 | attention_head_dim: 64 4 | n_blocks: 1 5 | num_mid_blocks: 2 6 | num_heads: 2 7 | act_fn: snakebeta 8 | -------------------------------------------------------------------------------- /configs/model/encoder/default.yaml: -------------------------------------------------------------------------------- 1 | encoder_type: RoPE Encoder 2 | encoder_params: 3 | n_feats: ${model.n_feats} 4 | n_channels: 192 5 | filter_channels: 768 6 | filter_channels_dp: 256 7 | n_heads: 2 8 | n_layers: 6 9 | kernel_size: 3 10 | p_dropout: 0.1 11 | spk_emb_dim: 64 12 | n_spks: 1 13 | prenet: true 14 | 15 | duration_predictor_params: 16 | filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp} 17 | kernel_size: 3 18 | p_dropout: ${model.encoder.encoder_params.p_dropout} 19 | -------------------------------------------------------------------------------- /configs/model/matcha.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - encoder: default.yaml 4 | - decoder: default.yaml 5 | - cfm: default.yaml 6 | - optimizer: adam.yaml 7 | 8 | _target_: matcha.models.matcha_tts.MatchaTTS 9 | n_vocab: 178 10 | n_spks: ${data.n_spks} 11 | spk_emb_dim: 64 12 | n_feats: 80 13 | data_statistics: ${data.data_statistics} 14 | out_size: null # Must be divisible by 4 15 | prior_loss: true 16 | use_precomputed_durations: ${data.load_durations} 17 | -------------------------------------------------------------------------------- /configs/model/optimizer/adam.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | _partial_: true 3 | lr: 1e-4 4 | weight_decay: 0.0 5 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: ljspeech 8 | - model: matcha 9 | - callbacks: default 10 | - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | run_name: ??? 
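# Note: `???` is OmegaConf's "mandatory value" marker; `run_name` must be supplied by an
# experiment config or on the command line, e.g. `python matcha/train.py experiment=ljspeech run_name=my_run`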
34 | 35 | # tags to help you identify your experiments 36 | # you can overwrite this in experiment configs 37 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 38 | tags: ["dev"] 39 | 40 | # set False to skip model training 41 | train: True 42 | 43 | # evaluate on test set, using best model weights achieved during training 44 | # lightning chooses best weights based on the metric specified in checkpoint callback 45 | test: True 46 | 47 | # simply provide checkpoint path to resume training 48 | ckpt_path: null 49 | 50 | # seed for random number generators in pytorch, numpy and python.random 51 | seed: 1234 52 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: [0,1] 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | max_epochs: -1 6 | 7 | accelerator: gpu 8 | devices: [0] 9 | 10 | # mixed precision for extra speed-up 11 | precision: 16-mixed 12 | 13 | # perform a validation loop every N training epochs 14 | check_val_every_n_epoch: 1 15 | 16 | # set True to to ensure deterministic results 17 | # makes training slower but gives more reproducibility than just setting seeds 18 | deterministic: False 19 | 20 | gradient_clip_val: 5.0 21 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /matcha/VERSION: -------------------------------------------------------------------------------- 1 | 0.0.7.2 2 | -------------------------------------------------------------------------------- /matcha/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/matcha/__init__.py -------------------------------------------------------------------------------- /matcha/data/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/matcha/data/__init__.py -------------------------------------------------------------------------------- /matcha/data/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/matcha/data/components/__init__.py -------------------------------------------------------------------------------- /matcha/data/text_mel_datamodule.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Any, Dict, Optional 4 | 5 | import numpy as np 6 | import torch 7 | import torchaudio as ta 8 | from lightning import LightningDataModule 9 | from torch.utils.data.dataloader import DataLoader 10 | 11 | from matcha.text import text_to_sequence 12 | from matcha.utils.audio import mel_spectrogram 13 | from matcha.utils.model import fix_len_compatibility, normalize 14 | from matcha.utils.utils import intersperse 15 | 16 | 17 | def parse_filelist(filelist_path, split_char="|"): 18 | with open(filelist_path, encoding="utf-8") as f: 19 | filepaths_and_text = [line.strip().split(split_char) for line in f] 20 | return filepaths_and_text 21 | 22 | 23 | class TextMelDataModule(LightningDataModule): 24 | def __init__( # pylint: disable=unused-argument 25 | self, 26 | name, 27 | train_filelist_path, 28 | valid_filelist_path, 29 | batch_size, 30 | num_workers, 31 | pin_memory, 32 | cleaners, 33 | add_blank, 34 | n_spks, 35 | n_fft, 36 | n_feats, 37 | sample_rate, 38 | hop_length, 39 | win_length, 40 | f_min, 41 | f_max, 42 | data_statistics, 43 | seed, 44 | load_durations, 45 | ): 46 | super().__init__() 47 | 48 | # this line allows to access init params with 'self.hparams' attribute 49 | # also ensures init params will be stored in ckpt 50 | self.save_hyperparameters(logger=False) 51 | 52 | def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument 53 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 54 | 55 | This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be 56 | careful not to execute things like random split twice! 
57 | """ 58 | # load and split datasets only if not loaded already 59 | 60 | self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init 61 | self.hparams.train_filelist_path, 62 | self.hparams.n_spks, 63 | self.hparams.cleaners, 64 | self.hparams.add_blank, 65 | self.hparams.n_fft, 66 | self.hparams.n_feats, 67 | self.hparams.sample_rate, 68 | self.hparams.hop_length, 69 | self.hparams.win_length, 70 | self.hparams.f_min, 71 | self.hparams.f_max, 72 | self.hparams.data_statistics, 73 | self.hparams.seed, 74 | self.hparams.load_durations, 75 | ) 76 | self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init 77 | self.hparams.valid_filelist_path, 78 | self.hparams.n_spks, 79 | self.hparams.cleaners, 80 | self.hparams.add_blank, 81 | self.hparams.n_fft, 82 | self.hparams.n_feats, 83 | self.hparams.sample_rate, 84 | self.hparams.hop_length, 85 | self.hparams.win_length, 86 | self.hparams.f_min, 87 | self.hparams.f_max, 88 | self.hparams.data_statistics, 89 | self.hparams.seed, 90 | self.hparams.load_durations, 91 | ) 92 | 93 | def train_dataloader(self): 94 | return DataLoader( 95 | dataset=self.trainset, 96 | batch_size=self.hparams.batch_size, 97 | num_workers=self.hparams.num_workers, 98 | pin_memory=self.hparams.pin_memory, 99 | shuffle=True, 100 | collate_fn=TextMelBatchCollate(self.hparams.n_spks), 101 | ) 102 | 103 | def val_dataloader(self): 104 | return DataLoader( 105 | dataset=self.validset, 106 | batch_size=self.hparams.batch_size, 107 | num_workers=self.hparams.num_workers, 108 | pin_memory=self.hparams.pin_memory, 109 | shuffle=False, 110 | collate_fn=TextMelBatchCollate(self.hparams.n_spks), 111 | ) 112 | 113 | def teardown(self, stage: Optional[str] = None): 114 | """Clean up after fit or test.""" 115 | pass # pylint: disable=unnecessary-pass 116 | 117 | def state_dict(self): 118 | """Extra things to save to checkpoint.""" 119 | return {} 120 | 121 | def load_state_dict(self, state_dict: Dict[str, Any]): 122 | """Things to do when loading checkpoint.""" 123 | pass # pylint: disable=unnecessary-pass 124 | 125 | 126 | class TextMelDataset(torch.utils.data.Dataset): 127 | def __init__( 128 | self, 129 | filelist_path, 130 | n_spks, 131 | cleaners, 132 | add_blank=True, 133 | n_fft=1024, 134 | n_mels=80, 135 | sample_rate=22050, 136 | hop_length=256, 137 | win_length=1024, 138 | f_min=0.0, 139 | f_max=8000, 140 | data_parameters=None, 141 | seed=None, 142 | load_durations=False, 143 | ): 144 | self.filepaths_and_text = parse_filelist(filelist_path) 145 | self.n_spks = n_spks 146 | self.cleaners = cleaners 147 | self.add_blank = add_blank 148 | self.n_fft = n_fft 149 | self.n_mels = n_mels 150 | self.sample_rate = sample_rate 151 | self.hop_length = hop_length 152 | self.win_length = win_length 153 | self.f_min = f_min 154 | self.f_max = f_max 155 | self.load_durations = load_durations 156 | 157 | if data_parameters is not None: 158 | self.data_parameters = data_parameters 159 | else: 160 | self.data_parameters = {"mel_mean": 0, "mel_std": 1} 161 | random.seed(seed) 162 | random.shuffle(self.filepaths_and_text) 163 | 164 | def get_datapoint(self, filepath_and_text): 165 | if self.n_spks > 1: 166 | filepath, spk, text = ( 167 | filepath_and_text[0], 168 | int(filepath_and_text[1]), 169 | filepath_and_text[2], 170 | ) 171 | else: 172 | filepath, text = filepath_and_text[0], filepath_and_text[1] 173 | spk = None 174 | 175 | text, cleaned_text = self.get_text(text, add_blank=self.add_blank) 176 | mel = self.get_mel(filepath) 177 | 178 | 
durations = self.get_durations(filepath, text) if self.load_durations else None 179 | 180 | return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations} 181 | 182 | def get_durations(self, filepath, text): 183 | filepath = Path(filepath) 184 | data_dir, name = filepath.parent.parent, filepath.stem 185 | 186 | try: 187 | dur_loc = data_dir / "durations" / f"{name}.npy" 188 | durs = torch.from_numpy(np.load(dur_loc).astype(int)) 189 | 190 | except FileNotFoundError as e: 191 | raise FileNotFoundError( 192 | f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n" 193 | ) from e 194 | 195 | assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match" 196 | 197 | return durs 198 | 199 | def get_mel(self, filepath): 200 | audio, sr = ta.load(filepath) 201 | assert sr == self.sample_rate 202 | mel = mel_spectrogram( 203 | audio, 204 | self.n_fft, 205 | self.n_mels, 206 | self.sample_rate, 207 | self.hop_length, 208 | self.win_length, 209 | self.f_min, 210 | self.f_max, 211 | center=False, 212 | ).squeeze() 213 | mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"]) 214 | return mel 215 | 216 | def get_text(self, text, add_blank=True): 217 | text_norm, cleaned_text = text_to_sequence(text, self.cleaners) 218 | if self.add_blank: 219 | text_norm = intersperse(text_norm, 0) 220 | text_norm = torch.IntTensor(text_norm) 221 | return text_norm, cleaned_text 222 | 223 | def __getitem__(self, index): 224 | datapoint = self.get_datapoint(self.filepaths_and_text[index]) 225 | return datapoint 226 | 227 | def __len__(self): 228 | return len(self.filepaths_and_text) 229 | 230 | 231 | class TextMelBatchCollate: 232 | def __init__(self, n_spks): 233 | self.n_spks = n_spks 234 | 235 | def __call__(self, batch): 236 | B = len(batch) 237 | y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator 238 | y_max_length = fix_len_compatibility(y_max_length) 239 | x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator 240 | n_feats = batch[0]["y"].shape[-2] 241 | 242 | y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) 243 | x = torch.zeros((B, x_max_length), dtype=torch.long) 244 | durations = torch.zeros((B, x_max_length), dtype=torch.long) 245 | 246 | y_lengths, x_lengths = [], [] 247 | spks = [] 248 | filepaths, x_texts = [], [] 249 | for i, item in enumerate(batch): 250 | y_, x_ = item["y"], item["x"] 251 | y_lengths.append(y_.shape[-1]) 252 | x_lengths.append(x_.shape[-1]) 253 | y[i, :, : y_.shape[-1]] = y_ 254 | x[i, : x_.shape[-1]] = x_ 255 | spks.append(item["spk"]) 256 | filepaths.append(item["filepath"]) 257 | x_texts.append(item["x_text"]) 258 | if item["durations"] is not None: 259 | durations[i, : item["durations"].shape[-1]] = item["durations"] 260 | 261 | y_lengths = torch.tensor(y_lengths, dtype=torch.long) 262 | x_lengths = torch.tensor(x_lengths, dtype=torch.long) 263 | spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None 264 | 265 | return { 266 | "x": x, 267 | "x_lengths": x_lengths, 268 | "y": y, 269 | "y_lengths": y_lengths, 270 | "spks": spks, 271 | "filepaths": filepaths, 272 | "x_texts": x_texts, 273 | "durations": durations if not torch.eq(durations, 0).all() else None, 274 | } 275 | 
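A minimal, illustrative sketch of driving `TextMelDataModule` by hand (outside Hydra), for orientation only. The filelist paths, cleaner name, and data statistics below are placeholders rather than values taken from this repository; in the actual project these arguments are populated from `configs/data/*.yaml`.

```python
# Illustrative only: stand-alone use of TextMelDataModule with placeholder values.
# The filelist paths and cleaner name are assumptions and must point to real files/cleaners.
from matcha.data.text_mel_datamodule import TextMelDataModule

datamodule = TextMelDataModule(
    name="ljspeech",
    train_filelist_path="data/filelists/train_filelist.txt",  # placeholder path
    valid_filelist_path="data/filelists/val_filelist.txt",    # placeholder path
    batch_size=32,
    num_workers=4,
    pin_memory=True,
    cleaners=["english_cleaners2"],  # placeholder cleaner name
    add_blank=True,
    n_spks=1,
    n_fft=1024,
    n_feats=80,
    sample_rate=22050,
    hop_length=256,
    win_length=1024,
    f_min=0,
    f_max=8000,
    data_statistics={"mel_mean": 0.0, "mel_std": 1.0},  # placeholder statistics
    seed=1234,
    load_durations=False,
)
datamodule.setup()
batch = next(iter(datamodule.train_dataloader()))
# batch["x"]: padded token ids (B, T_text); batch["y"]: normalized mels (B, n_feats, T_mel)
```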
-------------------------------------------------------------------------------- /matcha/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /matcha/hifigan/README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis 2 | 3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae 4 | 5 | In our [paper](https://arxiv.org/abs/2010.05646), 6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository. 8 | 9 | **Abstract :** 10 | Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. 11 | Although such methods improve the sampling efficiency and memory usage, 12 | their sample quality has not yet reached that of autoregressive and flow-based generative models. 13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. 14 | As speech audio consists of sinusoidal signals with various periods, 15 | we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. 16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method 17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than 18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen 19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times 20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart. 21 | 22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. 23 | 24 | ## Pre-requisites 25 | 26 | 1. Python >= 3.6 27 | 2. Clone this repository. 28 | 3. Install python requirements. Please refer [requirements.txt](requirements.txt) 29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). 30 | And move all wav files to `LJSpeech-1.1/wavs` 31 | 32 | ## Training 33 | 34 | ``` 35 | python train.py --config config_v1.json 36 | ``` 37 | 38 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
39 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
40 | You can change the path by adding the `--checkpoint_path` option. 41 | 42 | Validation loss during training with the V1 generator.
43 | ![validation loss](./validation_loss.png) 44 | 45 | ## Pretrained Model 46 | 47 | You can also use the pretrained models we provide.
48 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
49 | Details of each folder are as follows: 50 | 51 | | Folder Name | Generator | Dataset | Fine-Tuned | 52 | | ------------ | --------- | --------- | ------------------------------------------------------ | 53 | | LJ_V1 | V1 | LJSpeech | No | 54 | | LJ_V2 | V2 | LJSpeech | No | 55 | | LJ_V3 | V3 | LJSpeech | No | 56 | | LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 57 | | LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 58 | | LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 59 | | VCTK_V1 | V1 | VCTK | No | 60 | | VCTK_V2 | V2 | VCTK | No | 61 | | VCTK_V3 | V3 | VCTK | No | 62 | | UNIVERSAL_V1 | V1 | Universal | No | 63 | 64 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. 65 | 66 | ## Fine-Tuning 67 | 68 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
69 | The file name of the generated mel-spectrogram should match that of the audio file, and the extension should be `.npy`.
70 | Example: 71 | ` Audio File : LJ001-0001.wav 72 | Mel-Spectrogram File : LJ001-0001.npy` 73 | 2. Create a `ft_dataset` folder and copy the generated mel-spectrogram files into it.
74 | 3. Run the following command. 75 | ``` 76 | python train.py --fine_tuning True --config config_v1.json 77 | ``` 78 | For other command-line options, please refer to the training section. 79 | 80 | ## Inference from wav file 81 | 82 | 1. Make a `test_files` directory and copy wav files into it. 83 | 2. Run the following command. 84 | ` python inference.py --checkpoint_file [generator checkpoint file path]` 85 | Generated wav files are saved in `generated_files` by default.
86 | You can change the path by adding the `--output_dir` option. 87 | 88 | ## Inference for end-to-end speech synthesis 89 | 90 | 1. Make a `test_mel_files` directory and copy generated mel-spectrogram files into it.
91 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), 92 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. 93 | 2. Run the following command. 94 | ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` 95 | Generated wav files are saved in `generated_files_from_mel` by default.
96 | You can change the path by adding `--output_dir` option. 97 | 98 | ## Acknowledgements 99 | 100 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) 101 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. 102 | -------------------------------------------------------------------------------- /matcha/hifigan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/matcha/hifigan/__init__.py -------------------------------------------------------------------------------- /matcha/hifigan/config.py: -------------------------------------------------------------------------------- 1 | v1 = { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0004, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | "upsample_rates": [8, 8, 2, 2], 11 | "upsample_kernel_sizes": [16, 16, 4, 4], 12 | "upsample_initial_channel": 512, 13 | "resblock_kernel_sizes": [3, 7, 11], 14 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 15 | "resblock_initial_channel": 256, 16 | "segment_size": 8192, 17 | "num_mels": 80, 18 | "num_freq": 1025, 19 | "n_fft": 1024, 20 | "hop_size": 256, 21 | "win_size": 1024, 22 | "sampling_rate": 22050, 23 | "fmin": 0, 24 | "fmax": 8000, 25 | "fmax_loss": None, 26 | "num_workers": 4, 27 | "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, 28 | } 29 | -------------------------------------------------------------------------------- /matcha/hifigan/denoiser.py: -------------------------------------------------------------------------------- 1 | # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py 2 | 3 | """Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" 4 | import torch 5 | 6 | 7 | class ModeException(Exception): 8 | pass 9 | 10 | 11 | class Denoiser(torch.nn.Module): 12 | """Removes model bias from audio produced with waveglow""" 13 | 14 | def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): 15 | super().__init__() 16 | self.filter_length = filter_length 17 | self.hop_length = int(filter_length / n_overlap) 18 | self.win_length = win_length 19 | 20 | dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device 21 | self.device = device 22 | if mode == "zeros": 23 | mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) 24 | elif mode == "normal": 25 | mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) 26 | else: 27 | raise ModeException(f"Mode {mode} if not supported") 28 | 29 | def stft_fn(audio, n_fft, hop_length, win_length, window): 30 | spec = torch.stft( 31 | audio, 32 | n_fft=n_fft, 33 | hop_length=hop_length, 34 | win_length=win_length, 35 | window=window, 36 | return_complex=True, 37 | ) 38 | spec = torch.view_as_real(spec) 39 | return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) 40 | 41 | self.stft = lambda x: stft_fn( 42 | audio=x, 43 | n_fft=self.filter_length, 44 | hop_length=self.hop_length, 45 | win_length=self.win_length, 46 | window=torch.hann_window(self.win_length, device=device), 47 | ) 48 | self.istft = lambda x, y: torch.istft( 49 | torch.complex(x * torch.cos(y), x * 
torch.sin(y)), 50 | n_fft=self.filter_length, 51 | hop_length=self.hop_length, 52 | win_length=self.win_length, 53 | window=torch.hann_window(self.win_length, device=device), 54 | ) 55 | 56 | with torch.no_grad(): 57 | bias_audio = vocoder(mel_input).float().squeeze(0) 58 | bias_spec, _ = self.stft(bias_audio) 59 | 60 | self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) 61 | 62 | @torch.inference_mode() 63 | def forward(self, audio, strength=0.0005): 64 | audio_spec, audio_angles = self.stft(audio) 65 | audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength 66 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) 67 | audio_denoised = self.istft(audio_spec_denoised, audio_angles) 68 | return audio_denoised 69 | -------------------------------------------------------------------------------- /matcha/hifigan/env.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import os 4 | import shutil 5 | 6 | 7 | class AttrDict(dict): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.__dict__ = self 11 | 12 | 13 | def build_env(config, config_name, path): 14 | t_path = os.path.join(path, config_name) 15 | if config != t_path: 16 | os.makedirs(path, exist_ok=True) 17 | shutil.copyfile(config, os.path.join(path, config_name)) 18 | -------------------------------------------------------------------------------- /matcha/hifigan/meldataset.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import math 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | from librosa.filters import mel as librosa_mel_fn 11 | from librosa.util import normalize 12 | from scipy.io.wavfile import read 13 | 14 | MAX_WAV_VALUE = 32768.0 15 | 16 | 17 | def load_wav(full_path): 18 | sampling_rate, data = read(full_path) 19 | return data, sampling_rate 20 | 21 | 22 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 23 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 24 | 25 | 26 | def dynamic_range_decompression(x, C=1): 27 | return np.exp(x) / C 28 | 29 | 30 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 31 | return torch.log(torch.clamp(x, min=clip_val) * C) 32 | 33 | 34 | def dynamic_range_decompression_torch(x, C=1): 35 | return torch.exp(x) / C 36 | 37 | 38 | def spectral_normalize_torch(magnitudes): 39 | output = dynamic_range_compression_torch(magnitudes) 40 | return output 41 | 42 | 43 | def spectral_de_normalize_torch(magnitudes): 44 | output = dynamic_range_decompression_torch(magnitudes) 45 | return output 46 | 47 | 48 | mel_basis = {} 49 | hann_window = {} 50 | 51 | 52 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 53 | if torch.min(y) < -1.0: 54 | print("min value is ", torch.min(y)) 55 | if torch.max(y) > 1.0: 56 | print("max value is ", torch.max(y)) 57 | 58 | global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned 59 | if fmax not in mel_basis: 60 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 61 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 62 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 63 | 64 | y = torch.nn.functional.pad( 65 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), 
int((n_fft - hop_size) / 2)), mode="reflect" 66 | ) 67 | y = y.squeeze(1) 68 | 69 | spec = torch.view_as_real( 70 | torch.stft( 71 | y, 72 | n_fft, 73 | hop_length=hop_size, 74 | win_length=win_size, 75 | window=hann_window[str(y.device)], 76 | center=center, 77 | pad_mode="reflect", 78 | normalized=False, 79 | onesided=True, 80 | return_complex=True, 81 | ) 82 | ) 83 | 84 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 85 | 86 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 87 | spec = spectral_normalize_torch(spec) 88 | 89 | return spec 90 | 91 | 92 | def get_dataset_filelist(a): 93 | with open(a.input_training_file, encoding="utf-8") as fi: 94 | training_files = [ 95 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 96 | ] 97 | 98 | with open(a.input_validation_file, encoding="utf-8") as fi: 99 | validation_files = [ 100 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 101 | ] 102 | return training_files, validation_files 103 | 104 | 105 | class MelDataset(torch.utils.data.Dataset): 106 | def __init__( 107 | self, 108 | training_files, 109 | segment_size, 110 | n_fft, 111 | num_mels, 112 | hop_size, 113 | win_size, 114 | sampling_rate, 115 | fmin, 116 | fmax, 117 | split=True, 118 | shuffle=True, 119 | n_cache_reuse=1, 120 | device=None, 121 | fmax_loss=None, 122 | fine_tuning=False, 123 | base_mels_path=None, 124 | ): 125 | self.audio_files = training_files 126 | random.seed(1234) 127 | if shuffle: 128 | random.shuffle(self.audio_files) 129 | self.segment_size = segment_size 130 | self.sampling_rate = sampling_rate 131 | self.split = split 132 | self.n_fft = n_fft 133 | self.num_mels = num_mels 134 | self.hop_size = hop_size 135 | self.win_size = win_size 136 | self.fmin = fmin 137 | self.fmax = fmax 138 | self.fmax_loss = fmax_loss 139 | self.cached_wav = None 140 | self.n_cache_reuse = n_cache_reuse 141 | self._cache_ref_count = 0 142 | self.device = device 143 | self.fine_tuning = fine_tuning 144 | self.base_mels_path = base_mels_path 145 | 146 | def __getitem__(self, index): 147 | filename = self.audio_files[index] 148 | if self._cache_ref_count == 0: 149 | audio, sampling_rate = load_wav(filename) 150 | audio = audio / MAX_WAV_VALUE 151 | if not self.fine_tuning: 152 | audio = normalize(audio) * 0.95 153 | self.cached_wav = audio 154 | if sampling_rate != self.sampling_rate: 155 | raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") 156 | self._cache_ref_count = self.n_cache_reuse 157 | else: 158 | audio = self.cached_wav 159 | self._cache_ref_count -= 1 160 | 161 | audio = torch.FloatTensor(audio) 162 | audio = audio.unsqueeze(0) 163 | 164 | if not self.fine_tuning: 165 | if self.split: 166 | if audio.size(1) >= self.segment_size: 167 | max_audio_start = audio.size(1) - self.segment_size 168 | audio_start = random.randint(0, max_audio_start) 169 | audio = audio[:, audio_start : audio_start + self.segment_size] 170 | else: 171 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 172 | 173 | mel = mel_spectrogram( 174 | audio, 175 | self.n_fft, 176 | self.num_mels, 177 | self.sampling_rate, 178 | self.hop_size, 179 | self.win_size, 180 | self.fmin, 181 | self.fmax, 182 | center=False, 183 | ) 184 | else: 185 | mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) 186 | mel = torch.from_numpy(mel) 187 | 188 | if len(mel.shape) 
< 3: 189 | mel = mel.unsqueeze(0) 190 | 191 | if self.split: 192 | frames_per_seg = math.ceil(self.segment_size / self.hop_size) 193 | 194 | if audio.size(1) >= self.segment_size: 195 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) 196 | mel = mel[:, :, mel_start : mel_start + frames_per_seg] 197 | audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] 198 | else: 199 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") 200 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 201 | 202 | mel_loss = mel_spectrogram( 203 | audio, 204 | self.n_fft, 205 | self.num_mels, 206 | self.sampling_rate, 207 | self.hop_size, 208 | self.win_size, 209 | self.fmin, 210 | self.fmax_loss, 211 | center=False, 212 | ) 213 | 214 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 215 | 216 | def __len__(self): 217 | return len(self.audio_files) 218 | -------------------------------------------------------------------------------- /matcha/hifigan/models.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import torch 4 | import torch.nn as nn # pylint: disable=consider-using-from-import 5 | import torch.nn.functional as F 6 | from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d 7 | from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm 8 | 9 | from .xutils import get_padding, init_weights 10 | 11 | LRELU_SLOPE = 0.1 12 | 13 | 14 | class ResBlock1(torch.nn.Module): 15 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 16 | super().__init__() 17 | self.h = h 18 | self.convs1 = nn.ModuleList( 19 | [ 20 | weight_norm( 21 | Conv1d( 22 | channels, 23 | channels, 24 | kernel_size, 25 | 1, 26 | dilation=dilation[0], 27 | padding=get_padding(kernel_size, dilation[0]), 28 | ) 29 | ), 30 | weight_norm( 31 | Conv1d( 32 | channels, 33 | channels, 34 | kernel_size, 35 | 1, 36 | dilation=dilation[1], 37 | padding=get_padding(kernel_size, dilation[1]), 38 | ) 39 | ), 40 | weight_norm( 41 | Conv1d( 42 | channels, 43 | channels, 44 | kernel_size, 45 | 1, 46 | dilation=dilation[2], 47 | padding=get_padding(kernel_size, dilation[2]), 48 | ) 49 | ), 50 | ] 51 | ) 52 | self.convs1.apply(init_weights) 53 | 54 | self.convs2 = nn.ModuleList( 55 | [ 56 | weight_norm( 57 | Conv1d( 58 | channels, 59 | channels, 60 | kernel_size, 61 | 1, 62 | dilation=1, 63 | padding=get_padding(kernel_size, 1), 64 | ) 65 | ), 66 | weight_norm( 67 | Conv1d( 68 | channels, 69 | channels, 70 | kernel_size, 71 | 1, 72 | dilation=1, 73 | padding=get_padding(kernel_size, 1), 74 | ) 75 | ), 76 | weight_norm( 77 | Conv1d( 78 | channels, 79 | channels, 80 | kernel_size, 81 | 1, 82 | dilation=1, 83 | padding=get_padding(kernel_size, 1), 84 | ) 85 | ), 86 | ] 87 | ) 88 | self.convs2.apply(init_weights) 89 | 90 | def forward(self, x): 91 | for c1, c2 in zip(self.convs1, self.convs2): 92 | xt = F.leaky_relu(x, LRELU_SLOPE) 93 | xt = c1(xt) 94 | xt = F.leaky_relu(xt, LRELU_SLOPE) 95 | xt = c2(xt) 96 | x = xt + x 97 | return x 98 | 99 | def remove_weight_norm(self): 100 | for l in self.convs1: 101 | remove_weight_norm(l) 102 | for l in self.convs2: 103 | remove_weight_norm(l) 104 | 105 | 106 | class ResBlock2(torch.nn.Module): 107 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 108 | super().__init__() 109 | self.h = h 110 | self.convs = nn.ModuleList( 111 | [ 112 | 
weight_norm( 113 | Conv1d( 114 | channels, 115 | channels, 116 | kernel_size, 117 | 1, 118 | dilation=dilation[0], 119 | padding=get_padding(kernel_size, dilation[0]), 120 | ) 121 | ), 122 | weight_norm( 123 | Conv1d( 124 | channels, 125 | channels, 126 | kernel_size, 127 | 1, 128 | dilation=dilation[1], 129 | padding=get_padding(kernel_size, dilation[1]), 130 | ) 131 | ), 132 | ] 133 | ) 134 | self.convs.apply(init_weights) 135 | 136 | def forward(self, x): 137 | for c in self.convs: 138 | xt = F.leaky_relu(x, LRELU_SLOPE) 139 | xt = c(xt) 140 | x = xt + x 141 | return x 142 | 143 | def remove_weight_norm(self): 144 | for l in self.convs: 145 | remove_weight_norm(l) 146 | 147 | 148 | class Generator(torch.nn.Module): 149 | def __init__(self, h): 150 | super().__init__() 151 | self.h = h 152 | self.num_kernels = len(h.resblock_kernel_sizes) 153 | self.num_upsamples = len(h.upsample_rates) 154 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) 155 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 156 | 157 | self.ups = nn.ModuleList() 158 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 159 | self.ups.append( 160 | weight_norm( 161 | ConvTranspose1d( 162 | h.upsample_initial_channel // (2**i), 163 | h.upsample_initial_channel // (2 ** (i + 1)), 164 | k, 165 | u, 166 | padding=(k - u) // 2, 167 | ) 168 | ) 169 | ) 170 | 171 | self.resblocks = nn.ModuleList() 172 | for i in range(len(self.ups)): 173 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 174 | for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 175 | self.resblocks.append(resblock(h, ch, k, d)) 176 | 177 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 178 | self.ups.apply(init_weights) 179 | self.conv_post.apply(init_weights) 180 | 181 | def forward(self, x): 182 | x = self.conv_pre(x) 183 | for i in range(self.num_upsamples): 184 | x = F.leaky_relu(x, LRELU_SLOPE) 185 | x = self.ups[i](x) 186 | xs = None 187 | for j in range(self.num_kernels): 188 | if xs is None: 189 | xs = self.resblocks[i * self.num_kernels + j](x) 190 | else: 191 | xs += self.resblocks[i * self.num_kernels + j](x) 192 | x = xs / self.num_kernels 193 | x = F.leaky_relu(x) 194 | x = self.conv_post(x) 195 | x = torch.tanh(x) 196 | 197 | return x 198 | 199 | def remove_weight_norm(self): 200 | print("Removing weight norm...") 201 | for l in self.ups: 202 | remove_weight_norm(l) 203 | for l in self.resblocks: 204 | l.remove_weight_norm() 205 | remove_weight_norm(self.conv_pre) 206 | remove_weight_norm(self.conv_post) 207 | 208 | 209 | class DiscriminatorP(torch.nn.Module): 210 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 211 | super().__init__() 212 | self.period = period 213 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm 214 | self.convs = nn.ModuleList( 215 | [ 216 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 217 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 218 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 219 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 220 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 221 | ] 222 | ) 223 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 224 | 225 | def forward(self, x): 226 | fmap = [] 227 | 228 | # 1d to 2d 229 | b, c, t = 
x.shape 230 | if t % self.period != 0: # pad first 231 | n_pad = self.period - (t % self.period) 232 | x = F.pad(x, (0, n_pad), "reflect") 233 | t = t + n_pad 234 | x = x.view(b, c, t // self.period, self.period) 235 | 236 | for l in self.convs: 237 | x = l(x) 238 | x = F.leaky_relu(x, LRELU_SLOPE) 239 | fmap.append(x) 240 | x = self.conv_post(x) 241 | fmap.append(x) 242 | x = torch.flatten(x, 1, -1) 243 | 244 | return x, fmap 245 | 246 | 247 | class MultiPeriodDiscriminator(torch.nn.Module): 248 | def __init__(self): 249 | super().__init__() 250 | self.discriminators = nn.ModuleList( 251 | [ 252 | DiscriminatorP(2), 253 | DiscriminatorP(3), 254 | DiscriminatorP(5), 255 | DiscriminatorP(7), 256 | DiscriminatorP(11), 257 | ] 258 | ) 259 | 260 | def forward(self, y, y_hat): 261 | y_d_rs = [] 262 | y_d_gs = [] 263 | fmap_rs = [] 264 | fmap_gs = [] 265 | for _, d in enumerate(self.discriminators): 266 | y_d_r, fmap_r = d(y) 267 | y_d_g, fmap_g = d(y_hat) 268 | y_d_rs.append(y_d_r) 269 | fmap_rs.append(fmap_r) 270 | y_d_gs.append(y_d_g) 271 | fmap_gs.append(fmap_g) 272 | 273 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 274 | 275 | 276 | class DiscriminatorS(torch.nn.Module): 277 | def __init__(self, use_spectral_norm=False): 278 | super().__init__() 279 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm 280 | self.convs = nn.ModuleList( 281 | [ 282 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 283 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 284 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 285 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 286 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 287 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 288 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 289 | ] 290 | ) 291 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 292 | 293 | def forward(self, x): 294 | fmap = [] 295 | for l in self.convs: 296 | x = l(x) 297 | x = F.leaky_relu(x, LRELU_SLOPE) 298 | fmap.append(x) 299 | x = self.conv_post(x) 300 | fmap.append(x) 301 | x = torch.flatten(x, 1, -1) 302 | 303 | return x, fmap 304 | 305 | 306 | class MultiScaleDiscriminator(torch.nn.Module): 307 | def __init__(self): 308 | super().__init__() 309 | self.discriminators = nn.ModuleList( 310 | [ 311 | DiscriminatorS(use_spectral_norm=True), 312 | DiscriminatorS(), 313 | DiscriminatorS(), 314 | ] 315 | ) 316 | self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) 317 | 318 | def forward(self, y, y_hat): 319 | y_d_rs = [] 320 | y_d_gs = [] 321 | fmap_rs = [] 322 | fmap_gs = [] 323 | for i, d in enumerate(self.discriminators): 324 | if i != 0: 325 | y = self.meanpools[i - 1](y) 326 | y_hat = self.meanpools[i - 1](y_hat) 327 | y_d_r, fmap_r = d(y) 328 | y_d_g, fmap_g = d(y_hat) 329 | y_d_rs.append(y_d_r) 330 | fmap_rs.append(fmap_r) 331 | y_d_gs.append(y_d_g) 332 | fmap_gs.append(fmap_g) 333 | 334 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 335 | 336 | 337 | def feature_loss(fmap_r, fmap_g): 338 | loss = 0 339 | for dr, dg in zip(fmap_r, fmap_g): 340 | for rl, gl in zip(dr, dg): 341 | loss += torch.mean(torch.abs(rl - gl)) 342 | 343 | return loss * 2 344 | 345 | 346 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 347 | loss = 0 348 | r_losses = [] 349 | g_losses = [] 350 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 351 | r_loss = torch.mean((1 - dr) ** 2) 352 | g_loss = torch.mean(dg**2) 353 | loss += r_loss + g_loss 354 | r_losses.append(r_loss.item()) 355 
| g_losses.append(g_loss.item()) 356 | 357 | return loss, r_losses, g_losses 358 | 359 | 360 | def generator_loss(disc_outputs): 361 | loss = 0 362 | gen_losses = [] 363 | for dg in disc_outputs: 364 | l = torch.mean((1 - dg) ** 2) 365 | gen_losses.append(l) 366 | loss += l 367 | 368 | return loss, gen_losses 369 | -------------------------------------------------------------------------------- /matcha/hifigan/xutils.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import glob 4 | import os 5 | 6 | import matplotlib 7 | import torch 8 | from torch.nn.utils import weight_norm 9 | 10 | matplotlib.use("Agg") 11 | import matplotlib.pylab as plt 12 | 13 | 14 | def plot_spectrogram(spectrogram): 15 | fig, ax = plt.subplots(figsize=(10, 2)) 16 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 17 | plt.colorbar(im, ax=ax) 18 | 19 | fig.canvas.draw() 20 | plt.close() 21 | 22 | return fig 23 | 24 | 25 | def init_weights(m, mean=0.0, std=0.01): 26 | classname = m.__class__.__name__ 27 | if classname.find("Conv") != -1: 28 | m.weight.data.normal_(mean, std) 29 | 30 | 31 | def apply_weight_norm(m): 32 | classname = m.__class__.__name__ 33 | if classname.find("Conv") != -1: 34 | weight_norm(m) 35 | 36 | 37 | def get_padding(kernel_size, dilation=1): 38 | return int((kernel_size * dilation - dilation) / 2) 39 | 40 | 41 | def load_checkpoint(filepath, device): 42 | assert os.path.isfile(filepath) 43 | print(f"Loading '{filepath}'") 44 | checkpoint_dict = torch.load(filepath, map_location=device) 45 | print("Complete.") 46 | return checkpoint_dict 47 | 48 | 49 | def save_checkpoint(filepath, obj): 50 | print(f"Saving checkpoint to {filepath}") 51 | torch.save(obj, filepath) 52 | print("Complete.") 53 | 54 | 55 | def scan_checkpoint(cp_dir, prefix): 56 | pattern = os.path.join(cp_dir, prefix + "????????") 57 | cp_list = glob.glob(pattern) 58 | if len(cp_list) == 0: 59 | return None 60 | return sorted(cp_list)[-1] 61 | -------------------------------------------------------------------------------- /matcha/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/matcha/models/__init__.py -------------------------------------------------------------------------------- /matcha/models/baselightningmodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a base lightning module that can be used to train a model. 3 | The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. 
4 | """ 5 | import inspect 6 | from abc import ABC 7 | from typing import Any, Dict 8 | 9 | import torch 10 | from lightning import LightningModule 11 | from lightning.pytorch.utilities import grad_norm 12 | 13 | from matcha import utils 14 | from matcha.utils.utils import plot_tensor 15 | 16 | log = utils.get_pylogger(__name__) 17 | 18 | 19 | class BaseLightningClass(LightningModule, ABC): 20 | def update_data_statistics(self, data_statistics): 21 | if data_statistics is None: 22 | data_statistics = { 23 | "mel_mean": 0.0, 24 | "mel_std": 1.0, 25 | } 26 | 27 | self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) 28 | self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) 29 | 30 | def configure_optimizers(self) -> Any: 31 | optimizer = self.hparams.optimizer(params=self.parameters()) 32 | if self.hparams.scheduler not in (None, {}): 33 | scheduler_args = {} 34 | # Manage last epoch for exponential schedulers 35 | if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: 36 | if hasattr(self, "ckpt_loaded_epoch"): 37 | current_epoch = self.ckpt_loaded_epoch - 1 38 | else: 39 | current_epoch = -1 40 | 41 | scheduler_args.update({"optimizer": optimizer}) 42 | scheduler = self.hparams.scheduler.scheduler(**scheduler_args) 43 | scheduler.last_epoch = current_epoch 44 | return { 45 | "optimizer": optimizer, 46 | "lr_scheduler": { 47 | "scheduler": scheduler, 48 | "interval": self.hparams.scheduler.lightning_args.interval, 49 | "frequency": self.hparams.scheduler.lightning_args.frequency, 50 | "name": "learning_rate", 51 | }, 52 | } 53 | 54 | return {"optimizer": optimizer} 55 | 56 | def get_losses(self, batch): 57 | x, x_lengths = batch["x"], batch["x_lengths"] 58 | y, y_lengths = batch["y"], batch["y_lengths"] 59 | spks = batch["spks"] 60 | 61 | dur_loss, prior_loss, diff_loss, *_ = self( 62 | x=x, 63 | x_lengths=x_lengths, 64 | y=y, 65 | y_lengths=y_lengths, 66 | spks=spks, 67 | out_size=self.out_size, 68 | durations=batch["durations"], 69 | ) 70 | return { 71 | "dur_loss": dur_loss, 72 | "prior_loss": prior_loss, 73 | "diff_loss": diff_loss, 74 | } 75 | 76 | def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 77 | self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init 78 | 79 | def training_step(self, batch: Any, batch_idx: int): 80 | loss_dict = self.get_losses(batch) 81 | self.log( 82 | "step", 83 | float(self.global_step), 84 | on_step=True, 85 | prog_bar=True, 86 | logger=True, 87 | sync_dist=True, 88 | ) 89 | 90 | self.log( 91 | "sub_loss/train_dur_loss", 92 | loss_dict["dur_loss"], 93 | on_step=True, 94 | on_epoch=True, 95 | logger=True, 96 | sync_dist=True, 97 | ) 98 | self.log( 99 | "sub_loss/train_prior_loss", 100 | loss_dict["prior_loss"], 101 | on_step=True, 102 | on_epoch=True, 103 | logger=True, 104 | sync_dist=True, 105 | ) 106 | self.log( 107 | "sub_loss/train_diff_loss", 108 | loss_dict["diff_loss"], 109 | on_step=True, 110 | on_epoch=True, 111 | logger=True, 112 | sync_dist=True, 113 | ) 114 | 115 | total_loss = sum(loss_dict.values()) 116 | self.log( 117 | "loss/train", 118 | total_loss, 119 | on_step=True, 120 | on_epoch=True, 121 | logger=True, 122 | prog_bar=True, 123 | sync_dist=True, 124 | ) 125 | 126 | return {"loss": total_loss, "log": loss_dict} 127 | 128 | def validation_step(self, batch: Any, batch_idx: int): 129 | loss_dict = self.get_losses(batch) 130 | self.log( 131 | "sub_loss/val_dur_loss", 132 | loss_dict["dur_loss"], 133 | on_step=True, 
134 | on_epoch=True, 135 | logger=True, 136 | sync_dist=True, 137 | ) 138 | self.log( 139 | "sub_loss/val_prior_loss", 140 | loss_dict["prior_loss"], 141 | on_step=True, 142 | on_epoch=True, 143 | logger=True, 144 | sync_dist=True, 145 | ) 146 | self.log( 147 | "sub_loss/val_diff_loss", 148 | loss_dict["diff_loss"], 149 | on_step=True, 150 | on_epoch=True, 151 | logger=True, 152 | sync_dist=True, 153 | ) 154 | 155 | total_loss = sum(loss_dict.values()) 156 | self.log( 157 | "loss/val", 158 | total_loss, 159 | on_step=True, 160 | on_epoch=True, 161 | logger=True, 162 | prog_bar=True, 163 | sync_dist=True, 164 | ) 165 | 166 | return total_loss 167 | 168 | def on_validation_end(self) -> None: 169 | if self.trainer.is_global_zero: 170 | one_batch = next(iter(self.trainer.val_dataloaders)) 171 | if self.current_epoch == 0: 172 | log.debug("Plotting original samples") 173 | for i in range(2): 174 | y = one_batch["y"][i].unsqueeze(0).to(self.device) 175 | self.logger.experiment.add_image( 176 | f"original/{i}", 177 | plot_tensor(y.squeeze().cpu()), 178 | self.current_epoch, 179 | dataformats="HWC", 180 | ) 181 | 182 | log.debug("Synthesising...") 183 | for i in range(2): 184 | x = one_batch["x"][i].unsqueeze(0).to(self.device) 185 | x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) 186 | spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None 187 | output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks) 188 | y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] 189 | attn = output["attn"] 190 | self.logger.experiment.add_image( 191 | f"generated_enc/{i}", 192 | plot_tensor(y_enc.squeeze().cpu()), 193 | self.current_epoch, 194 | dataformats="HWC", 195 | ) 196 | self.logger.experiment.add_image( 197 | f"generated_dec/{i}", 198 | plot_tensor(y_dec.squeeze().cpu()), 199 | self.current_epoch, 200 | dataformats="HWC", 201 | ) 202 | self.logger.experiment.add_image( 203 | f"alignment/{i}", 204 | plot_tensor(attn.squeeze().cpu()), 205 | self.current_epoch, 206 | dataformats="HWC", 207 | ) 208 | 209 | def on_before_optimizer_step(self, optimizer): 210 | self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) 211 | -------------------------------------------------------------------------------- /matcha/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/matcha/models/components/__init__.py -------------------------------------------------------------------------------- /matcha/models/components/flow_matching.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from matcha.models.components.decoder import Decoder 7 | from matcha.utils.pylogger import get_pylogger 8 | 9 | log = get_pylogger(__name__) 10 | 11 | 12 | class BASECFM(torch.nn.Module, ABC): 13 | def __init__( 14 | self, 15 | n_feats, 16 | cfm_params, 17 | n_spks=1, 18 | spk_emb_dim=128, 19 | ): 20 | super().__init__() 21 | self.n_feats = n_feats 22 | self.n_spks = n_spks 23 | self.spk_emb_dim = spk_emb_dim 24 | self.solver = cfm_params.solver 25 | if hasattr(cfm_params, "sigma_min"): 26 | self.sigma_min = cfm_params.sigma_min 27 | else: 28 | self.sigma_min = 1e-4 29 | 30 | self.estimator = None 31 | 32 | @torch.inference_mode() 33 | def 
forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): 34 | """Forward diffusion 35 | 36 | Args: 37 | mu (torch.Tensor): output of encoder 38 | shape: (batch_size, n_feats, mel_timesteps) 39 | mask (torch.Tensor): output_mask 40 | shape: (batch_size, 1, mel_timesteps) 41 | n_timesteps (int): number of diffusion steps 42 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 43 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 44 | shape: (batch_size, spk_emb_dim) 45 | cond: Not used but kept for future purposes 46 | 47 | Returns: 48 | sample: generated mel-spectrogram 49 | shape: (batch_size, n_feats, mel_timesteps) 50 | """ 51 | z = torch.randn_like(mu) * temperature 52 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) 53 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) 54 | 55 | def solve_euler(self, x, t_span, mu, mask, spks, cond): 56 | """ 57 | Fixed euler solver for ODEs. 58 | Args: 59 | x (torch.Tensor): random noise 60 | t_span (torch.Tensor): n_timesteps interpolated 61 | shape: (n_timesteps + 1,) 62 | mu (torch.Tensor): output of encoder 63 | shape: (batch_size, n_feats, mel_timesteps) 64 | mask (torch.Tensor): output_mask 65 | shape: (batch_size, 1, mel_timesteps) 66 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 67 | shape: (batch_size, spk_emb_dim) 68 | cond: Not used but kept for future purposes 69 | """ 70 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 71 | 72 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 73 | # Or in future might add like a return_all_steps flag 74 | sol = [] 75 | 76 | for step in range(1, len(t_span)): 77 | dphi_dt = self.estimator(x, mask, mu, t, spks, cond) 78 | 79 | x = x + dt * dphi_dt 80 | t = t + dt 81 | sol.append(x) 82 | if step < len(t_span) - 1: 83 | dt = t_span[step + 1] - t 84 | 85 | return sol[-1] 86 | 87 | def compute_loss(self, x1, mask, mu, spks=None, cond=None): 88 | """Computes diffusion loss 89 | 90 | Args: 91 | x1 (torch.Tensor): Target 92 | shape: (batch_size, n_feats, mel_timesteps) 93 | mask (torch.Tensor): target mask 94 | shape: (batch_size, 1, mel_timesteps) 95 | mu (torch.Tensor): output of encoder 96 | shape: (batch_size, n_feats, mel_timesteps) 97 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
98 | shape: (batch_size, spk_emb_dim) 99 | 100 | Returns: 101 | loss: conditional flow matching loss 102 | y: conditional flow 103 | shape: (batch_size, n_feats, mel_timesteps) 104 | """ 105 | b, _, t = mu.shape 106 | 107 | # random timestep 108 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 109 | # sample noise p(x_0) 110 | z = torch.randn_like(x1) 111 | 112 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 113 | u = x1 - (1 - self.sigma_min) * z 114 | 115 | loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( 116 | torch.sum(mask) * u.shape[1] 117 | ) 118 | return loss, y 119 | 120 | 121 | class CFM(BASECFM): 122 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): 123 | super().__init__( 124 | n_feats=in_channels, 125 | cfm_params=cfm_params, 126 | n_spks=n_spks, 127 | spk_emb_dim=spk_emb_dim, 128 | ) 129 | 130 | in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) 131 | # Just change the architecture of the estimator here 132 | self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) 133 | -------------------------------------------------------------------------------- /matcha/models/components/transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | import torch.nn as nn # pylint: disable=consider-using-from-import 5 | from diffusers.models.attention import ( 6 | GEGLU, 7 | GELU, 8 | AdaLayerNorm, 9 | AdaLayerNormZero, 10 | ApproximateGELU, 11 | ) 12 | from diffusers.models.attention_processor import Attention 13 | from diffusers.models.lora import LoRACompatibleLinear 14 | from diffusers.utils.torch_utils import maybe_allow_in_graph 15 | 16 | 17 | class SnakeBeta(nn.Module): 18 | """ 19 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 20 | Shape: 21 | - Input: (B, C, T) 22 | - Output: (B, C, T), same shape as the input 23 | Parameters: 24 | - alpha - trainable parameter that controls frequency 25 | - beta - trainable parameter that controls magnitude 26 | References: 27 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 28 | https://arxiv.org/abs/2006.08195 29 | Examples: 30 | >>> a1 = snakebeta(256) 31 | >>> x = torch.randn(256) 32 | >>> x = a1(x) 33 | """ 34 | 35 | def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): 36 | """ 37 | Initialization. 38 | INPUT: 39 | - in_features: shape of the input 40 | - alpha - trainable parameter that controls frequency 41 | - beta - trainable parameter that controls magnitude 42 | alpha is initialized to 1 by default, higher values = higher-frequency. 43 | beta is initialized to 1 by default, higher values = higher-magnitude. 44 | alpha will be trained along with the rest of your model. 
45 | """ 46 | super().__init__() 47 | self.in_features = out_features if isinstance(out_features, list) else [out_features] 48 | self.proj = LoRACompatibleLinear(in_features, out_features) 49 | 50 | # initialize alpha 51 | self.alpha_logscale = alpha_logscale 52 | if self.alpha_logscale: # log scale alphas initialized to zeros 53 | self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) 54 | self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) 55 | else: # linear scale alphas initialized to ones 56 | self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) 57 | self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) 58 | 59 | self.alpha.requires_grad = alpha_trainable 60 | self.beta.requires_grad = alpha_trainable 61 | 62 | self.no_div_by_zero = 0.000000001 63 | 64 | def forward(self, x): 65 | """ 66 | Forward pass of the function. 67 | Applies the function to the input elementwise. 68 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 69 | """ 70 | x = self.proj(x) 71 | if self.alpha_logscale: 72 | alpha = torch.exp(self.alpha) 73 | beta = torch.exp(self.beta) 74 | else: 75 | alpha = self.alpha 76 | beta = self.beta 77 | 78 | x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) 79 | 80 | return x 81 | 82 | 83 | class FeedForward(nn.Module): 84 | r""" 85 | A feed-forward layer. 86 | 87 | Parameters: 88 | dim (`int`): The number of channels in the input. 89 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 90 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 91 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 92 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 93 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 94 | """ 95 | 96 | def __init__( 97 | self, 98 | dim: int, 99 | dim_out: Optional[int] = None, 100 | mult: int = 4, 101 | dropout: float = 0.0, 102 | activation_fn: str = "geglu", 103 | final_dropout: bool = False, 104 | ): 105 | super().__init__() 106 | inner_dim = int(dim * mult) 107 | dim_out = dim_out if dim_out is not None else dim 108 | 109 | if activation_fn == "gelu": 110 | act_fn = GELU(dim, inner_dim) 111 | if activation_fn == "gelu-approximate": 112 | act_fn = GELU(dim, inner_dim, approximate="tanh") 113 | elif activation_fn == "geglu": 114 | act_fn = GEGLU(dim, inner_dim) 115 | elif activation_fn == "geglu-approximate": 116 | act_fn = ApproximateGELU(dim, inner_dim) 117 | elif activation_fn == "snakebeta": 118 | act_fn = SnakeBeta(dim, inner_dim) 119 | 120 | self.net = nn.ModuleList([]) 121 | # project in 122 | self.net.append(act_fn) 123 | # project dropout 124 | self.net.append(nn.Dropout(dropout)) 125 | # project out 126 | self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) 127 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 128 | if final_dropout: 129 | self.net.append(nn.Dropout(dropout)) 130 | 131 | def forward(self, hidden_states): 132 | for module in self.net: 133 | hidden_states = module(hidden_states) 134 | return hidden_states 135 | 136 | 137 | @maybe_allow_in_graph 138 | class BasicTransformerBlock(nn.Module): 139 | r""" 140 | A basic Transformer block. 141 | 142 | Parameters: 143 | dim (`int`): The number of channels in the input and output. 144 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 
145 | attention_head_dim (`int`): The number of channels in each head. 146 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 147 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 148 | only_cross_attention (`bool`, *optional*): 149 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 150 | double_self_attention (`bool`, *optional*): 151 | Whether to use two self-attention layers. In this case no cross attention layers are used. 152 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 153 | num_embeds_ada_norm (: 154 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 155 | attention_bias (: 156 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 157 | """ 158 | 159 | def __init__( 160 | self, 161 | dim: int, 162 | num_attention_heads: int, 163 | attention_head_dim: int, 164 | dropout=0.0, 165 | cross_attention_dim: Optional[int] = None, 166 | activation_fn: str = "geglu", 167 | num_embeds_ada_norm: Optional[int] = None, 168 | attention_bias: bool = False, 169 | only_cross_attention: bool = False, 170 | double_self_attention: bool = False, 171 | upcast_attention: bool = False, 172 | norm_elementwise_affine: bool = True, 173 | norm_type: str = "layer_norm", 174 | final_dropout: bool = False, 175 | ): 176 | super().__init__() 177 | self.only_cross_attention = only_cross_attention 178 | 179 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 180 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 181 | 182 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 183 | raise ValueError( 184 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 185 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 186 | ) 187 | 188 | # Define 3 blocks. Each block has its own normalization layer. 189 | # 1. Self-Attn 190 | if self.use_ada_layer_norm: 191 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 192 | elif self.use_ada_layer_norm_zero: 193 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 194 | else: 195 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 196 | self.attn1 = Attention( 197 | query_dim=dim, 198 | heads=num_attention_heads, 199 | dim_head=attention_head_dim, 200 | dropout=dropout, 201 | bias=attention_bias, 202 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 203 | upcast_attention=upcast_attention, 204 | ) 205 | 206 | # 2. Cross-Attn 207 | if cross_attention_dim is not None or double_self_attention: 208 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 209 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 210 | # the second cross attention block. 
211 | self.norm2 = ( 212 | AdaLayerNorm(dim, num_embeds_ada_norm) 213 | if self.use_ada_layer_norm 214 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 215 | ) 216 | self.attn2 = Attention( 217 | query_dim=dim, 218 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 219 | heads=num_attention_heads, 220 | dim_head=attention_head_dim, 221 | dropout=dropout, 222 | bias=attention_bias, 223 | upcast_attention=upcast_attention, 224 | # scale_qk=False, # uncomment this to not to use flash attention 225 | ) # is self-attn if encoder_hidden_states is none 226 | else: 227 | self.norm2 = None 228 | self.attn2 = None 229 | 230 | # 3. Feed-forward 231 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 232 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) 233 | 234 | # let chunk size default to None 235 | self._chunk_size = None 236 | self._chunk_dim = 0 237 | 238 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): 239 | # Sets chunk feed-forward 240 | self._chunk_size = chunk_size 241 | self._chunk_dim = dim 242 | 243 | def forward( 244 | self, 245 | hidden_states: torch.FloatTensor, 246 | attention_mask: Optional[torch.FloatTensor] = None, 247 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 248 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 249 | timestep: Optional[torch.LongTensor] = None, 250 | cross_attention_kwargs: Dict[str, Any] = None, 251 | class_labels: Optional[torch.LongTensor] = None, 252 | ): 253 | # Notice that normalization is always applied before the real computation in the following blocks. 254 | # 1. Self-Attention 255 | if self.use_ada_layer_norm: 256 | norm_hidden_states = self.norm1(hidden_states, timestep) 257 | elif self.use_ada_layer_norm_zero: 258 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 259 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 260 | ) 261 | else: 262 | norm_hidden_states = self.norm1(hidden_states) 263 | 264 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} 265 | 266 | attn_output = self.attn1( 267 | norm_hidden_states, 268 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 269 | attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, 270 | **cross_attention_kwargs, 271 | ) 272 | if self.use_ada_layer_norm_zero: 273 | attn_output = gate_msa.unsqueeze(1) * attn_output 274 | hidden_states = attn_output + hidden_states 275 | 276 | # 2. Cross-Attention 277 | if self.attn2 is not None: 278 | norm_hidden_states = ( 279 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 280 | ) 281 | 282 | attn_output = self.attn2( 283 | norm_hidden_states, 284 | encoder_hidden_states=encoder_hidden_states, 285 | attention_mask=encoder_attention_mask, 286 | **cross_attention_kwargs, 287 | ) 288 | hidden_states = attn_output + hidden_states 289 | 290 | # 3. 
Feed-forward 291 | norm_hidden_states = self.norm3(hidden_states) 292 | 293 | if self.use_ada_layer_norm_zero: 294 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 295 | 296 | if self._chunk_size is not None: 297 | # "feed_forward_chunk_size" can be used to save memory 298 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: 299 | raise ValueError( 300 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 301 | ) 302 | 303 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size 304 | ff_output = torch.cat( 305 | [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], 306 | dim=self._chunk_dim, 307 | ) 308 | else: 309 | ff_output = self.ff(norm_hidden_states) 310 | 311 | if self.use_ada_layer_norm_zero: 312 | ff_output = gate_mlp.unsqueeze(1) * ff_output 313 | 314 | hidden_states = ff_output + hidden_states 315 | 316 | return hidden_states 317 | -------------------------------------------------------------------------------- /matcha/models/matcha_tts.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import math 3 | import random 4 | 5 | import torch 6 | 7 | import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import 8 | from matcha import utils 9 | from matcha.models.baselightningmodule import BaseLightningClass 10 | from matcha.models.components.flow_matching import CFM 11 | from matcha.models.components.text_encoder import TextEncoder 12 | from matcha.utils.model import ( 13 | denormalize, 14 | duration_loss, 15 | fix_len_compatibility, 16 | generate_path, 17 | sequence_mask, 18 | ) 19 | 20 | log = utils.get_pylogger(__name__) 21 | 22 | 23 | class MatchaTTS(BaseLightningClass): # 🍵 24 | def __init__( 25 | self, 26 | n_vocab, 27 | n_spks, 28 | spk_emb_dim, 29 | n_feats, 30 | encoder, 31 | decoder, 32 | cfm, 33 | data_statistics, 34 | out_size, 35 | optimizer=None, 36 | scheduler=None, 37 | prior_loss=True, 38 | use_precomputed_durations=False, 39 | ): 40 | super().__init__() 41 | 42 | self.save_hyperparameters(logger=False) 43 | 44 | self.n_vocab = n_vocab 45 | self.n_spks = n_spks 46 | self.spk_emb_dim = spk_emb_dim 47 | self.n_feats = n_feats 48 | self.out_size = out_size 49 | self.prior_loss = prior_loss 50 | self.use_precomputed_durations = use_precomputed_durations 51 | 52 | if n_spks > 1: 53 | self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) 54 | 55 | self.encoder = TextEncoder( 56 | encoder.encoder_type, 57 | encoder.encoder_params, 58 | encoder.duration_predictor_params, 59 | n_vocab, 60 | n_spks, 61 | spk_emb_dim, 62 | ) 63 | 64 | self.decoder = CFM( 65 | in_channels=2 * encoder.encoder_params.n_feats, 66 | out_channel=encoder.encoder_params.n_feats, 67 | cfm_params=cfm, 68 | decoder_params=decoder, 69 | n_spks=n_spks, 70 | spk_emb_dim=spk_emb_dim, 71 | ) 72 | 73 | self.update_data_statistics(data_statistics) 74 | 75 | @torch.inference_mode() 76 | def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0): 77 | """ 78 | Generates mel-spectrogram from text. Returns: 79 | 1. encoder outputs 80 | 2. decoder outputs 81 | 3. 
generated alignment 82 | 83 | Args: 84 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. 85 | shape: (batch_size, max_text_length) 86 | x_lengths (torch.Tensor): lengths of texts in batch. 87 | shape: (batch_size,) 88 | n_timesteps (int): number of steps to use for reverse diffusion in decoder. 89 | temperature (float, optional): controls variance of terminal distribution. 90 | spks (bool, optional): speaker ids. 91 | shape: (batch_size,) 92 | length_scale (float, optional): controls speech pace. 93 | Increase value to slow down generated speech and vice versa. 94 | 95 | Returns: 96 | dict: { 97 | "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), 98 | # Average mel spectrogram generated by the encoder 99 | "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), 100 | # Refined mel spectrogram improved by the CFM 101 | "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length), 102 | # Alignment map between text and mel spectrogram 103 | "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), 104 | # Denormalized mel spectrogram 105 | "mel_lengths": torch.Tensor, shape: (batch_size,), 106 | # Lengths of mel spectrograms 107 | "rtf": float, 108 | # Real-time factor 109 | } 110 | """ 111 | # For RTF computation 112 | t = dt.datetime.now() 113 | 114 | if self.n_spks > 1: 115 | # Get speaker embedding 116 | spks = self.spk_emb(spks.long()) 117 | 118 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 119 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) 120 | 121 | w = torch.exp(logw) * x_mask 122 | w_ceil = torch.ceil(w) * length_scale 123 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 124 | y_max_length = y_lengths.max() 125 | y_max_length_ = fix_len_compatibility(y_max_length) 126 | 127 | # Using obtained durations `w` construct alignment map `attn` 128 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) 129 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 130 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) 131 | 132 | # Align encoded text and get mu_y 133 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 134 | mu_y = mu_y.transpose(1, 2) 135 | encoder_outputs = mu_y[:, :, :y_max_length] 136 | 137 | # Generate sample tracing the probability flow 138 | decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks) 139 | decoder_outputs = decoder_outputs[:, :, :y_max_length] 140 | 141 | t = (dt.datetime.now() - t).total_seconds() 142 | rtf = t * 22050 / (decoder_outputs.shape[-1] * 256) 143 | 144 | return { 145 | "encoder_outputs": encoder_outputs, 146 | "decoder_outputs": decoder_outputs, 147 | "attn": attn[:, :, :y_max_length], 148 | "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std), 149 | "mel_lengths": y_lengths, 150 | "rtf": rtf, 151 | } 152 | 153 | def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None): 154 | """ 155 | Computes 3 losses: 156 | 1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS). 157 | 2. prior loss: loss between mel-spectrogram and encoder outputs. 158 | 3. flow matching loss: loss between mel-spectrogram and decoder outputs. 159 | 160 | Args: 161 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. 
162 | shape: (batch_size, max_text_length) 163 | x_lengths (torch.Tensor): lengths of texts in batch. 164 | shape: (batch_size,) 165 | y (torch.Tensor): batch of corresponding mel-spectrograms. 166 | shape: (batch_size, n_feats, max_mel_length) 167 | y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. 168 | shape: (batch_size,) 169 | out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. 170 | Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. 171 | spks (torch.Tensor, optional): speaker ids. 172 | shape: (batch_size,) 173 | """ 174 | if self.n_spks > 1: 175 | # Get speaker embedding 176 | spks = self.spk_emb(spks) 177 | 178 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 179 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) 180 | y_max_length = y.shape[-1] 181 | 182 | y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) 183 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 184 | 185 | if self.use_precomputed_durations: 186 | attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1)) 187 | else: 188 | # Use MAS to find most likely alignment `attn` between text and mel-spectrogram 189 | with torch.no_grad(): 190 | const = -0.5 * math.log(2 * math.pi) * self.n_feats 191 | factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) 192 | y_square = torch.matmul(factor.transpose(1, 2), y**2) 193 | y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) 194 | mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1) 195 | log_prior = y_square - y_mu_double + mu_square + const 196 | 197 | attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) 198 | attn = attn.detach() # b, t_text, T_mel 199 | 200 | # Compute loss between predicted log-scaled durations and those obtained from MAS 201 | # refered to as prior loss in the paper 202 | logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask 203 | dur_loss = duration_loss(logw, logw_, x_lengths) 204 | 205 | # Cut a small segment of mel-spectrogram in order to increase batch size 206 | # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it 207 | # - Do not need this hack for Matcha-TTS, but it works with it as well 208 | if not isinstance(out_size, type(None)): 209 | max_offset = (y_lengths - out_size).clamp(0) 210 | offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy())) 211 | out_offset = torch.LongTensor( 212 | [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges] 213 | ).to(y_lengths) 214 | attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device) 215 | y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device) 216 | 217 | y_cut_lengths = [] 218 | for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): 219 | y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0) 220 | y_cut_lengths.append(y_cut_length) 221 | cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length 222 | y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper] 223 | attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper] 224 | 225 | y_cut_lengths = torch.LongTensor(y_cut_lengths) 226 | y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask) 227 | 228 | attn = attn_cut 229 | y = y_cut 230 | y_mask = y_cut_mask 231 | 232 | # Align encoded 
text with mel-spectrogram and get mu_y segment 233 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 234 | mu_y = mu_y.transpose(1, 2) 235 | 236 | # Compute loss of the decoder 237 | diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond) 238 | 239 | if self.prior_loss: 240 | prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) 241 | prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) 242 | else: 243 | prior_loss = 0 244 | 245 | return dur_loss, prior_loss, diff_loss, attn 246 | -------------------------------------------------------------------------------- /matcha/onnx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/matcha/onnx/__init__.py -------------------------------------------------------------------------------- /matcha/onnx/export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from lightning import LightningModule 8 | 9 | from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder 10 | 11 | DEFAULT_OPSET = 15 12 | 13 | SEED = 1234 14 | random.seed(SEED) 15 | np.random.seed(SEED) 16 | torch.manual_seed(SEED) 17 | torch.cuda.manual_seed(SEED) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | 21 | 22 | class MatchaWithVocoder(LightningModule): 23 | def __init__(self, matcha, vocoder): 24 | super().__init__() 25 | self.matcha = matcha 26 | self.vocoder = vocoder 27 | 28 | def forward(self, x, x_lengths, scales, spks=None): 29 | mel, mel_lengths = self.matcha(x, x_lengths, scales, spks) 30 | wavs = self.vocoder(mel).clamp(-1, 1) 31 | lengths = mel_lengths * 256 32 | return wavs.squeeze(1), lengths 33 | 34 | 35 | def get_exportable_module(matcha, vocoder, n_timesteps): 36 | """ 37 | Return an appropriate `LighteningModule` and output-node names 38 | based on whether the vocoder is embedded in the final graph 39 | """ 40 | 41 | def onnx_forward_func(x, x_lengths, scales, spks=None): 42 | """ 43 | Custom forward function for accepting 44 | scaler parameters as tensors 45 | """ 46 | # Extract scaler parameters from tensors 47 | temperature = scales[0] 48 | length_scale = scales[1] 49 | output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale) 50 | return output["mel"], output["mel_lengths"] 51 | 52 | # Monkey-patch Matcha's forward function 53 | matcha.forward = onnx_forward_func 54 | 55 | if vocoder is None: 56 | model, output_names = matcha, ["mel", "mel_lengths"] 57 | else: 58 | model = MatchaWithVocoder(matcha, vocoder) 59 | output_names = ["wav", "wav_lengths"] 60 | return model, output_names 61 | 62 | 63 | def get_inputs(is_multi_speaker): 64 | """ 65 | Create dummy inputs for tracing 66 | """ 67 | dummy_input_length = 50 68 | x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long) 69 | x_lengths = torch.LongTensor([dummy_input_length]) 70 | 71 | # Scales 72 | temperature = 0.667 73 | length_scale = 1.0 74 | scales = torch.Tensor([temperature, length_scale]) 75 | 76 | model_inputs = [x, x_lengths, scales] 77 | input_names = [ 78 | "x", 79 | "x_lengths", 80 | "scales", 81 | ] 82 | 83 | if is_multi_speaker: 84 | spks = torch.LongTensor([1]) 85 | model_inputs.append(spks) 86 | 
input_names.append("spks") 87 | 88 | return tuple(model_inputs), input_names 89 | 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX") 93 | 94 | parser.add_argument( 95 | "checkpoint_path", 96 | type=str, 97 | help="Path to the model checkpoint", 98 | ) 99 | parser.add_argument("output", type=str, help="Path to output `.onnx` file") 100 | parser.add_argument( 101 | "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)" 102 | ) 103 | parser.add_argument( 104 | "--vocoder-name", 105 | type=str, 106 | choices=list(VOCODER_URLS.keys()), 107 | default=None, 108 | help="Name of the vocoder to embed in the ONNX graph", 109 | ) 110 | parser.add_argument( 111 | "--vocoder-checkpoint-path", 112 | type=str, 113 | default=None, 114 | help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience", 115 | ) 116 | parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15") 117 | 118 | args = parser.parse_args() 119 | 120 | print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}") 121 | print(f"Setting n_timesteps to {args.n_timesteps}") 122 | 123 | checkpoint_path = Path(args.checkpoint_path) 124 | matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu") 125 | 126 | if args.vocoder_name or args.vocoder_checkpoint_path: 127 | assert ( 128 | args.vocoder_name and args.vocoder_checkpoint_path 129 | ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph." 130 | vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu") 131 | else: 132 | vocoder = None 133 | 134 | is_multi_speaker = matcha.n_spks > 1 135 | 136 | dummy_input, input_names = get_inputs(is_multi_speaker) 137 | model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps) 138 | 139 | # Set dynamic shape for inputs/outputs 140 | dynamic_axes = { 141 | "x": {0: "batch_size", 1: "time"}, 142 | "x_lengths": {0: "batch_size"}, 143 | } 144 | 145 | if vocoder is None: 146 | dynamic_axes.update( 147 | { 148 | "mel": {0: "batch_size", 2: "time"}, 149 | "mel_lengths": {0: "batch_size"}, 150 | } 151 | ) 152 | else: 153 | print("Embedding the vocoder in the ONNX graph") 154 | dynamic_axes.update( 155 | { 156 | "wav": {0: "batch_size", 1: "time"}, 157 | "wav_lengths": {0: "batch_size"}, 158 | } 159 | ) 160 | 161 | if is_multi_speaker: 162 | dynamic_axes["spks"] = {0: "batch_size"} 163 | 164 | # Create the output directory (if not exists) 165 | Path(args.output).parent.mkdir(parents=True, exist_ok=True) 166 | 167 | model.to_onnx( 168 | args.output, 169 | dummy_input, 170 | input_names=input_names, 171 | output_names=output_names, 172 | dynamic_axes=dynamic_axes, 173 | opset_version=args.opset, 174 | export_params=True, 175 | do_constant_folding=True, 176 | ) 177 | print(f"[🍵] ONNX model exported to {args.output}") 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /matcha/onnx/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from pathlib import Path 5 | from time import perf_counter 6 | 7 | import numpy as np 8 | import onnxruntime as ort 9 | import soundfile as sf 10 | import torch 11 | 12 | from matcha.cli import plot_spectrogram_to_numpy, process_text 13 | 14 | 15 | def 
validate_args(args): 16 | assert ( 17 | args.text or args.file 18 | ), "Either text or file must be provided. Matcha-T(ea)TTS needs some text to whisk the waveforms." 19 | assert args.temperature >= 0, "Sampling temperature cannot be negative" 20 | assert args.speaking_rate >= 0, "Speaking rate cannot be negative" 21 | return args 22 | 23 | 24 | def write_wavs(model, inputs, output_dir, external_vocoder=None): 25 | if external_vocoder is None: 26 | print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly") 27 | t0 = perf_counter() 28 | wavs, wav_lengths = model.run(None, inputs) 29 | infer_secs = perf_counter() - t0 30 | mel_infer_secs = vocoder_infer_secs = None 31 | else: 32 | print("[🍵] Generating mel using Matcha") 33 | mel_t0 = perf_counter() 34 | mels, mel_lengths = model.run(None, inputs) 35 | mel_infer_secs = perf_counter() - mel_t0 36 | print("Generating waveform from mel using external vocoder") 37 | vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels} 38 | vocoder_t0 = perf_counter() 39 | wavs = external_vocoder.run(None, vocoder_inputs)[0] 40 | vocoder_infer_secs = perf_counter() - vocoder_t0 41 | wavs = wavs.squeeze(1) 42 | wav_lengths = mel_lengths * 256 43 | infer_secs = mel_infer_secs + vocoder_infer_secs 44 | 45 | output_dir = Path(output_dir) 46 | output_dir.mkdir(parents=True, exist_ok=True) 47 | for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)): 48 | output_filename = output_dir.joinpath(f"output_{i + 1}.wav") 49 | audio = wav[:wav_length] 50 | print(f"Writing audio to {output_filename}") 51 | sf.write(output_filename, audio, 22050, "PCM_24") 52 | 53 | wav_secs = wav_lengths.sum() / 22050 54 | print(f"Inference seconds: {infer_secs}") 55 | print(f"Generated wav seconds: {wav_secs}") 56 | rtf = infer_secs / wav_secs 57 | if mel_infer_secs is not None: 58 | mel_rtf = mel_infer_secs / wav_secs 59 | print(f"Matcha RTF: {mel_rtf}") 60 | if vocoder_infer_secs is not None: 61 | vocoder_rtf = vocoder_infer_secs / wav_secs 62 | print(f"Vocoder RTF: {vocoder_rtf}") 63 | print(f"Overall RTF: {rtf}") 64 | 65 | 66 | def write_mels(model, inputs, output_dir): 67 | t0 = perf_counter() 68 | mels, mel_lengths = model.run(None, inputs) 69 | infer_secs = perf_counter() - t0 70 | 71 | output_dir = Path(output_dir) 72 | output_dir.mkdir(parents=True, exist_ok=True) 73 | for i, mel in enumerate(mels): 74 | output_stem = output_dir.joinpath(f"output_{i + 1}") 75 | plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png")) 76 | np.save(output_stem.with_suffix(".npy"), mel) 77 | 78 | wav_secs = (mel_lengths * 256).sum() / 22050 79 | print(f"Inference seconds: {infer_secs}") 80 | print(f"Generated wav seconds: {wav_secs}") 81 | rtf = infer_secs / wav_secs 82 | print(f"RTF: {rtf}") 83 | 84 | 85 | def main(): 86 | parser = argparse.ArgumentParser( 87 | description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching" 88 | ) 89 | parser.add_argument( 90 | "model", 91 | type=str, 92 | help="ONNX model to use", 93 | ) 94 | parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)") 95 | parser.add_argument("--text", type=str, default=None, help="Text to synthesize") 96 | parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") 97 | parser.add_argument("--spk", type=int, default=None, help="Speaker ID") 98 | parser.add_argument( 99 | "--temperature", 100 | type=float, 101 | default=0.667, 102 | help="Variance of the x0 noise (default:
0.667)", 103 | ) 104 | parser.add_argument( 105 | "--speaking-rate", 106 | type=float, 107 | default=1.0, 108 | help="Change the speaking rate; a higher value means a slower speaking rate (default: 1.0)", 109 | ) 110 | parser.add_argument("--gpu", action="store_true", help="Use GPU for inference (default: CPU)") 111 | parser.add_argument( 112 | "--output-dir", 113 | type=str, 114 | default=os.getcwd(), 115 | help="Output folder to save results (default: current dir)", 116 | ) 117 | 118 | args = parser.parse_args() 119 | args = validate_args(args) 120 | 121 | if args.gpu: 122 | providers = ["CUDAExecutionProvider"] 123 | else: 124 | providers = ["CPUExecutionProvider"] 125 | model = ort.InferenceSession(args.model, providers=providers) 126 | 127 | model_inputs = model.get_inputs() 128 | model_outputs = list(model.get_outputs()) 129 | 130 | if args.text: 131 | text_lines = args.text.splitlines() 132 | else: 133 | with open(args.file, encoding="utf-8") as file: 134 | text_lines = file.read().splitlines() 135 | 136 | processed_lines = [process_text(0, line, "cpu") for line in text_lines] 137 | x = [line["x"].squeeze() for line in processed_lines] 138 | # Pad 139 | x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) 140 | x = x.detach().cpu().numpy() 141 | x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64) 142 | inputs = { 143 | "x": x, 144 | "x_lengths": x_lengths, 145 | "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32), 146 | } 147 | is_multi_speaker = len(model_inputs) == 4 148 | if is_multi_speaker: 149 | if args.spk is None: 150 | args.spk = 0 151 | warn = "[!] Speaker ID not provided! Using speaker ID 0" 152 | warnings.warn(warn, UserWarning) 153 | inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64) 154 | 155 | has_vocoder_embedded = model_outputs[0].name == "wav" 156 | if has_vocoder_embedded: 157 | write_wavs(model, inputs, args.output_dir) 158 | elif args.vocoder: 159 | external_vocoder = ort.InferenceSession(args.vocoder, providers=providers) 160 | write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder) 161 | else: 162 | warn = "[!] No vocoder is embedded in the graph and no external vocoder was provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory" 163 | warnings.warn(warn, UserWarning) 164 | write_mels(model, inputs, args.output_dir) 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /matcha/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from matcha.text import cleaners 3 | from matcha.text.symbols import symbols 4 | 5 | # Mappings from symbol to numeric ID and vice versa: 6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension 8 | 9 | 10 | class UnknownCleanerException(Exception): 11 | pass 12 | 13 | 14 | def text_to_sequence(text, cleaner_names): 15 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
16 | Args: 17 | text: string to convert to a sequence 18 | cleaner_names: names of the cleaner functions to run the text through 19 | Returns: 20 | List of integers corresponding to the symbols in the text 21 | """ 22 | sequence = [] 23 | 24 | clean_text = _clean_text(text, cleaner_names) 25 | for symbol in clean_text: 26 | symbol_id = _symbol_to_id[symbol] 27 | sequence += [symbol_id] 28 | return sequence, clean_text 29 | 30 | 31 | def cleaned_text_to_sequence(cleaned_text): 32 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 33 | Args: 34 | text: string to convert to a sequence 35 | Returns: 36 | List of integers corresponding to the symbols in the text 37 | """ 38 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] 39 | return sequence 40 | 41 | 42 | def sequence_to_text(sequence): 43 | """Converts a sequence of IDs back to a string""" 44 | result = "" 45 | for symbol_id in sequence: 46 | s = _id_to_symbol[symbol_id] 47 | result += s 48 | return result 49 | 50 | 51 | def _clean_text(text, cleaner_names): 52 | for name in cleaner_names: 53 | cleaner = getattr(cleaners, name) 54 | if not cleaner: 55 | raise UnknownCleanerException(f"Unknown cleaner: {name}") 56 | text = cleaner(text) 57 | return text 58 | -------------------------------------------------------------------------------- /matcha/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Cleaners are transformations that run over the input text at both training and eval time. 4 | 5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 7 | 1. "english_cleaners" for English text 8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 11 | the symbols in symbols.py to match your data). 
12 | """ 13 | 14 | import logging 15 | import re 16 | 17 | import phonemizer 18 | from unidecode import unidecode 19 | 20 | # To avoid excessive logging we set the log level of the phonemizer package to Critical 21 | critical_logger = logging.getLogger("phonemizer") 22 | critical_logger.setLevel(logging.CRITICAL) 23 | 24 | # Intializing the phonemizer globally significantly reduces the speed 25 | # now the phonemizer is not initialising at every call 26 | # Might be less flexible, but it is much-much faster 27 | global_phonemizer = phonemizer.backend.EspeakBackend( 28 | language="en-us", 29 | preserve_punctuation=True, 30 | with_stress=True, 31 | language_switch="remove-flags", 32 | logger=critical_logger, 33 | ) 34 | 35 | 36 | # Regular expression matching whitespace: 37 | _whitespace_re = re.compile(r"\s+") 38 | 39 | # Remove brackets 40 | _brackets_re = re.compile(r"[\[\]\(\)\{\}]") 41 | 42 | # List of (regular expression, replacement) pairs for abbreviations: 43 | _abbreviations = [ 44 | (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1]) 45 | for x in [ 46 | ("mrs", "misess"), 47 | ("mr", "mister"), 48 | ("dr", "doctor"), 49 | ("st", "saint"), 50 | ("co", "company"), 51 | ("jr", "junior"), 52 | ("maj", "major"), 53 | ("gen", "general"), 54 | ("drs", "doctors"), 55 | ("rev", "reverend"), 56 | ("lt", "lieutenant"), 57 | ("hon", "honorable"), 58 | ("sgt", "sergeant"), 59 | ("capt", "captain"), 60 | ("esq", "esquire"), 61 | ("ltd", "limited"), 62 | ("col", "colonel"), 63 | ("ft", "fort"), 64 | ] 65 | ] 66 | 67 | 68 | def expand_abbreviations(text): 69 | for regex, replacement in _abbreviations: 70 | text = re.sub(regex, replacement, text) 71 | return text 72 | 73 | 74 | def lowercase(text): 75 | return text.lower() 76 | 77 | 78 | def remove_brackets(text): 79 | return re.sub(_brackets_re, "", text) 80 | 81 | 82 | def collapse_whitespace(text): 83 | return re.sub(_whitespace_re, " ", text) 84 | 85 | 86 | def convert_to_ascii(text): 87 | return unidecode(text) 88 | 89 | 90 | def basic_cleaners(text): 91 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 92 | text = lowercase(text) 93 | text = collapse_whitespace(text) 94 | return text 95 | 96 | 97 | def transliteration_cleaners(text): 98 | """Pipeline for non-English text that transliterates to ASCII.""" 99 | text = convert_to_ascii(text) 100 | text = lowercase(text) 101 | text = collapse_whitespace(text) 102 | return text 103 | 104 | 105 | def english_cleaners2(text): 106 | """Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress""" 107 | text = convert_to_ascii(text) 108 | text = lowercase(text) 109 | text = expand_abbreviations(text) 110 | phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] 111 | # Added in some cases espeak is not removing brackets 112 | phonemes = remove_brackets(phonemes) 113 | phonemes = collapse_whitespace(phonemes) 114 | return phonemes 115 | 116 | 117 | def ipa_simplifier(text): 118 | replacements = [ 119 | ("ɐ", "ə"), 120 | ("ˈə", "ə"), 121 | ("ʤ", "dʒ"), 122 | ("ʧ", "tʃ"), 123 | ("ᵻ", "ɪ"), 124 | ] 125 | for replacement in replacements: 126 | text = text.replace(replacement[0], replacement[1]) 127 | phonemes = collapse_whitespace(text) 128 | return phonemes 129 | 130 | 131 | # I am removing this due to incompatibility with several version of python 132 | # However, if you want to use it, you can uncomment it 133 | # and install piper-phonemize with the following command: 134 | # pip install piper-phonemize 135 | 136 | # import piper_phonemize 137 | # def english_cleaners_piper(text): 138 | # """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 139 | # text = convert_to_ascii(text) 140 | # text = lowercase(text) 141 | # text = expand_abbreviations(text) 142 | # phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0]) 143 | # phonemes = collapse_whitespace(phonemes) 144 | # return phonemes 145 | -------------------------------------------------------------------------------- /matcha/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | import inflect 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return f"{dollars} {dollar_unit}, {cents} {cent_unit}" 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return f"{dollars} {dollar_unit}" 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return f"{cents} {cent_unit}" 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") 60 | else: 61 | return _inflect.number_to_words(num, andword="") 62 | 63 | 64 | 
def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r"\1 pounds", text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /matcha/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Defines the set of symbols used in text input to the model. 4 | """ 5 | _pad = "_" 6 | _punctuation = ';:,.!?¡¿—…"«»“” ' 7 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 8 | _letters_ipa = ( 9 | "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | ) 11 | 12 | 13 | # Export all symbols: 14 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 15 | 16 | # Special symbol ids 17 | SPACE_ID = symbols.index(" ") 18 | -------------------------------------------------------------------------------- /matcha/train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import hydra 4 | import lightning as L 5 | import rootutils 6 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 7 | from lightning.pytorch.loggers import Logger 8 | from omegaconf import DictConfig 9 | 10 | from matcha import utils 11 | 12 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 13 | # ------------------------------------------------------------------------------------ # 14 | # the setup_root above is equivalent to: 15 | # - adding project root dir to PYTHONPATH 16 | # (so you don't need to force user to install project as a package) 17 | # (necessary before importing any local modules e.g. `from src import utils`) 18 | # - setting up PROJECT_ROOT environment variable 19 | # (which is used as a base for paths in "configs/paths/default.yaml") 20 | # (this way all filepaths are the same no matter where you run the code) 21 | # - loading environment variables from ".env" in root dir 22 | # 23 | # you can remove it if you: 24 | # 1. either install project as a package or move entry files to project root dir 25 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 26 | # 27 | # more info: https://github.com/ashleve/rootutils 28 | # ------------------------------------------------------------------------------------ # 29 | 30 | 31 | log = utils.get_pylogger(__name__) 32 | 33 | 34 | @utils.task_wrapper 35 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 36 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 37 | training. 38 | 39 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 40 | failure. Useful for multiruns, saving info about the crash, etc. 41 | 42 | :param cfg: A DictConfig configuration composed by Hydra. 43 | :return: A tuple with metrics and dict with all instantiated objects. 
44 | """ 45 | # set seed for random number generators in pytorch, numpy and python.random 46 | if cfg.get("seed"): 47 | L.seed_everything(cfg.seed, workers=True) 48 | 49 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access 50 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 51 | 52 | log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access 53 | model: LightningModule = hydra.utils.instantiate(cfg.model) 54 | 55 | log.info("Instantiating callbacks...") 56 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 57 | 58 | log.info("Instantiating loggers...") 59 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) 60 | 61 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access 62 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 63 | 64 | object_dict = { 65 | "cfg": cfg, 66 | "datamodule": datamodule, 67 | "model": model, 68 | "callbacks": callbacks, 69 | "logger": logger, 70 | "trainer": trainer, 71 | } 72 | 73 | if logger: 74 | log.info("Logging hyperparameters!") 75 | utils.log_hyperparameters(object_dict) 76 | 77 | if cfg.get("train"): 78 | log.info("Starting training!") 79 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 80 | 81 | train_metrics = trainer.callback_metrics 82 | 83 | if cfg.get("test"): 84 | log.info("Starting testing!") 85 | ckpt_path = trainer.checkpoint_callback.best_model_path 86 | if ckpt_path == "": 87 | log.warning("Best ckpt not found! Using current weights for testing...") 88 | ckpt_path = None 89 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 90 | log.info(f"Best ckpt path: {ckpt_path}") 91 | 92 | test_metrics = trainer.callback_metrics 93 | 94 | # merge train and test metrics 95 | metric_dict = {**train_metrics, **test_metrics} 96 | 97 | return metric_dict, object_dict 98 | 99 | 100 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 101 | def main(cfg: DictConfig) -> Optional[float]: 102 | """Main entry point for training. 103 | 104 | :param cfg: DictConfig configuration composed by Hydra. 105 | :return: Optional[float] with optimized metric value. 106 | """ 107 | # apply extra utilities 108 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
109 | utils.extras(cfg) 110 | 111 | # train the model 112 | metric_dict, _ = train(cfg) 113 | 114 | # safely retrieve metric value for hydra-based hyperparameter optimization 115 | metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) 116 | 117 | # return optimized metric 118 | return metric_value 119 | 120 | 121 | if __name__ == "__main__": 122 | main() # pylint: disable=no-value-for-parameter 123 | -------------------------------------------------------------------------------- /matcha/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from matcha.utils.logging_utils import log_hyperparameters 3 | from matcha.utils.pylogger import get_pylogger 4 | from matcha.utils.rich_utils import enforce_tags, print_config_tree 5 | from matcha.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /matcha/utils/audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | def load_wav(full_path): 11 | sampling_rate, data = read(full_path) 12 | return data, sampling_rate 13 | 14 | 15 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 17 | 18 | 19 | def dynamic_range_decompression(x, C=1): 20 | return np.exp(x) / C 21 | 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | 27 | def dynamic_range_decompression_torch(x, C=1): 28 | return torch.exp(x) / C 29 | 30 | 31 | def spectral_normalize_torch(magnitudes): 32 | output = dynamic_range_compression_torch(magnitudes) 33 | return output 34 | 35 | 36 | def spectral_de_normalize_torch(magnitudes): 37 | output = dynamic_range_decompression_torch(magnitudes) 38 | return output 39 | 40 | 41 | mel_basis = {} 42 | hann_window = {} 43 | 44 | 45 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 46 | if torch.min(y) < -1.0: 47 | print("min value is ", torch.min(y)) 48 | if torch.max(y) > 1.0: 49 | print("max value is ", torch.max(y)) 50 | 51 | global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned 52 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 53 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 54 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 55 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 56 | 57 | y = torch.nn.functional.pad( 58 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 59 | ) 60 | y = y.squeeze(1) 61 | 62 | spec = torch.view_as_real( 63 | torch.stft( 64 | y, 65 | n_fft, 66 | hop_length=hop_size, 67 | win_length=win_size, 68 | window=hann_window[str(y.device)], 69 | center=center, 70 | pad_mode="reflect", 71 | normalized=False, 72 | onesided=True, 73 | return_complex=True, 74 | ) 75 | ) 76 | 77 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 78 | 79 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 80 | spec = 
spectral_normalize_torch(spec) 81 | 82 | return spec 83 | -------------------------------------------------------------------------------- /matcha/utils/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/matcha/utils/data/__init__.py -------------------------------------------------------------------------------- /matcha/utils/data/hificaptain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import os 4 | import sys 5 | import tempfile 6 | from pathlib import Path 7 | 8 | import torchaudio 9 | from torch.hub import download_url_to_file 10 | from tqdm import tqdm 11 | 12 | from matcha.utils.data.utils import _extract_zip 13 | 14 | URLS = { 15 | "en-US": { 16 | "female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_F.zip", 17 | "male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_M.zip", 18 | }, 19 | "ja-JP": { 20 | "female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_F.zip", 21 | "male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_M.zip", 22 | }, 23 | } 24 | 25 | INFO_PAGE = "https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/" 26 | 27 | # On their website they say "We NICT open-sourced Hi-Fi-CAPTAIN", 28 | # but they use this very-much-not-open-source licence. 29 | # Dunno if this is open washing or stupidity. 30 | LICENCE = "CC BY-NC-SA 4.0" 31 | 32 | # I'd normally put the citation here. It's on their website. 33 | # Boo to non-open-source stuff. 34 | 35 | 36 | def get_args(): 37 | parser = argparse.ArgumentParser() 38 | 39 | parser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files") 40 | parser.add_argument( 41 | "-r", 42 | "--skip-resampling", 43 | action="store_true", 44 | default=False, 45 | help="Skip resampling the data (from 48 to 22.05)", 46 | ) 47 | parser.add_argument( 48 | "-l", "--language", type=str, choices=["en-US", "ja-JP"], default="en-US", help="The language to download" 49 | ) 50 | parser.add_argument( 51 | "-g", 52 | "--gender", 53 | type=str, 54 | choices=["male", "female"], 55 | default="female", 56 | help="The gender of the speaker to download", 57 | ) 58 | parser.add_argument( 59 | "-o", 60 | "--output_dir", 61 | type=str, 62 | default="data", 63 | help="Place to store the converted data. 
Top-level only, the subdirectory will be created", 64 | ) 65 | 66 | return parser.parse_args() 67 | 68 | 69 | def process_text(infile, outpath: Path): 70 | outmode = "w" 71 | if infile.endswith("dev.txt"): 72 | outfile = outpath / "valid.txt" 73 | elif infile.endswith("eval.txt"): 74 | outfile = outpath / "test.txt" 75 | else: 76 | outfile = outpath / "train.txt" 77 | if outfile.exists(): 78 | outmode = "a" 79 | with ( 80 | open(infile, encoding="utf-8") as inf, 81 | open(outfile, outmode, encoding="utf-8") as of, 82 | ): 83 | for line in inf.readlines(): 84 | line = line.strip() 85 | fileid, rest = line.split(" ", maxsplit=1) 86 | outfile = str(outpath / f"{fileid}.wav") 87 | of.write(f"{outfile}|{rest}\n") 88 | 89 | 90 | def process_files(zipfile, outpath, resample=True): 91 | with tempfile.TemporaryDirectory() as tmpdirname: 92 | for filename in tqdm(_extract_zip(zipfile, tmpdirname)): 93 | if not filename.startswith(tmpdirname): 94 | filename = os.path.join(tmpdirname, filename) 95 | if filename.endswith(".txt"): 96 | process_text(filename, outpath) 97 | elif filename.endswith(".wav"): 98 | filepart = filename.rsplit("/", maxsplit=1)[-1] 99 | outfile = str(outpath / filepart) 100 | arr, sr = torchaudio.load(filename) 101 | if resample: 102 | arr = torchaudio.functional.resample(arr, orig_freq=sr, new_freq=22050) 103 | torchaudio.save(outfile, arr, 22050) 104 | else: 105 | continue 106 | 107 | 108 | def main(): 109 | args = get_args() 110 | 111 | save_dir = None 112 | if args.save_dir: 113 | save_dir = Path(args.save_dir) 114 | if not save_dir.is_dir(): 115 | save_dir.mkdir() 116 | 117 | if not args.output_dir: 118 | print("output directory not specified, exiting") 119 | sys.exit(1) 120 | 121 | URL = URLS[args.language][args.gender] 122 | dirname = f"hi-fi_{args.language}_{args.gender}" 123 | 124 | outbasepath = Path(args.output_dir) 125 | if not outbasepath.is_dir(): 126 | outbasepath.mkdir() 127 | outpath = outbasepath / dirname 128 | if not outpath.is_dir(): 129 | outpath.mkdir() 130 | 131 | resample = True 132 | if args.skip_resampling: 133 | resample = False 134 | 135 | if save_dir: 136 | zipname = URL.rsplit("/", maxsplit=1)[-1] 137 | zipfile = save_dir / zipname 138 | if not zipfile.exists(): 139 | download_url_to_file(URL, zipfile, progress=True) 140 | process_files(zipfile, outpath, resample) 141 | else: 142 | with tempfile.NamedTemporaryFile(suffix=".zip", delete=True) as zf: 143 | download_url_to_file(URL, zf.name, progress=True) 144 | process_files(zf.name, outpath, resample) 145 | 146 | 147 | if __name__ == "__main__": 148 | main() 149 | -------------------------------------------------------------------------------- /matcha/utils/data/ljspeech.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import random 4 | import tempfile 5 | from pathlib import Path 6 | 7 | from torch.hub import download_url_to_file 8 | 9 | from matcha.utils.data.utils import _extract_tar 10 | 11 | URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2" 12 | 13 | INFO_PAGE = "https://keithito.com/LJ-Speech-Dataset/" 14 | 15 | LICENCE = "Public domain (LibriVox copyright disclaimer)" 16 | 17 | CITATION = """ 18 | @misc{ljspeech17, 19 | author = {Keith Ito and Linda Johnson}, 20 | title = {The LJ Speech Dataset}, 21 | howpublished = {\\url{https://keithito.com/LJ-Speech-Dataset/}}, 22 | year = 2017 23 | } 24 | """ 25 | 26 | 27 | def decision(): 28 | return random.random() < 0.98 29 | 30 | 31 | def get_args(): 
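    # Hedged usage sketch (not part of the original script): assuming the repository is installed or on
    # PYTHONPATH, the downloader is typically invoked as
    #   python -m matcha.utils.data.hificaptain -l en-US -g female -s downloads -o data
    # which fetches the chosen zip, resamples the audio to 22.05 kHz and writes train/valid/test filelists.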
32 | parser = argparse.ArgumentParser() 33 | 34 | parser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files") 35 | parser.add_argument( 36 | "output_dir", 37 | type=str, 38 | nargs="?", 39 | default="data", 40 | help="Place to store the converted data (subdirectory LJSpeech-1.1 will be created)", 41 | ) 42 | 43 | return parser.parse_args() 44 | 45 | 46 | def process_csv(ljpath: Path): 47 | if (ljpath / "metadata.csv").exists(): 48 | basepath = ljpath 49 | elif (ljpath / "LJSpeech-1.1" / "metadata.csv").exists(): 50 | basepath = ljpath / "LJSpeech-1.1" 51 | csvpath = basepath / "metadata.csv" 52 | wavpath = basepath / "wavs" 53 | 54 | with ( 55 | open(csvpath, encoding="utf-8") as csvf, 56 | open(basepath / "train.txt", "w", encoding="utf-8") as tf, 57 | open(basepath / "val.txt", "w", encoding="utf-8") as vf, 58 | ): 59 | for line in csvf.readlines(): 60 | line = line.strip() 61 | parts = line.split("|") 62 | wavfile = str(wavpath / f"{parts[0]}.wav") 63 | if decision(): 64 | tf.write(f"{wavfile}|{parts[1]}\n") 65 | else: 66 | vf.write(f"{wavfile}|{parts[1]}\n") 67 | 68 | 69 | def main(): 70 | args = get_args() 71 | 72 | save_dir = None 73 | if args.save_dir: 74 | save_dir = Path(args.save_dir) 75 | if not save_dir.is_dir(): 76 | save_dir.mkdir() 77 | 78 | outpath = Path(args.output_dir) 79 | if not outpath.is_dir(): 80 | outpath.mkdir() 81 | 82 | if save_dir: 83 | tarname = URL.rsplit("/", maxsplit=1)[-1] 84 | tarfile = save_dir / tarname 85 | if not tarfile.exists(): 86 | download_url_to_file(URL, str(tarfile), progress=True) 87 | _extract_tar(tarfile, outpath) 88 | process_csv(outpath) 89 | else: 90 | with tempfile.NamedTemporaryFile(suffix=".tar.bz2", delete=True) as zf: 91 | download_url_to_file(URL, zf.name, progress=True) 92 | _extract_tar(zf.name, outpath) 93 | process_csv(outpath) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /matcha/utils/data/utils.py: -------------------------------------------------------------------------------- 1 | # taken from https://github.com/pytorch/audio/blob/main/src/torchaudio/datasets/utils.py 2 | # Copyright (c) 2017 Facebook Inc. 
(Soumith Chintala) 3 | # Licence: BSD 2-Clause 4 | # pylint: disable=C0123 5 | 6 | import logging 7 | import os 8 | import tarfile 9 | import zipfile 10 | from pathlib import Path 11 | from typing import Any, List, Optional, Union 12 | 13 | _LG = logging.getLogger(__name__) 14 | 15 | 16 | def _extract_tar(from_path: Union[str, Path], to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: 17 | if type(from_path) is Path: 18 | from_path = str(from_path) 19 | 20 | if to_path is None: 21 | to_path = os.path.dirname(from_path) 22 | 23 | with tarfile.open(from_path, "r") as tar: 24 | files = [] 25 | for file_ in tar: # type: Any 26 | file_path = os.path.join(to_path, file_.name) 27 | if file_.isfile(): 28 | files.append(file_path) 29 | if os.path.exists(file_path): 30 | _LG.info("%s already extracted.", file_path) 31 | if not overwrite: 32 | continue 33 | tar.extract(file_, to_path) 34 | return files 35 | 36 | 37 | def _extract_zip(from_path: Union[str, Path], to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: 38 | if type(from_path) is Path: 39 | from_path = str(from_path) 40 | 41 | if to_path is None: 42 | to_path = os.path.dirname(from_path) 43 | 44 | with zipfile.ZipFile(from_path, "r") as zfile: 45 | files = zfile.namelist() 46 | for file_ in files: 47 | file_path = os.path.join(to_path, file_) 48 | if os.path.exists(file_path): 49 | _LG.info("%s already extracted.", file_path) 50 | if not overwrite: 51 | continue 52 | zfile.extract(file_, to_path) 53 | return files 54 | -------------------------------------------------------------------------------- /matcha/utils/generate_data_statistics.py: -------------------------------------------------------------------------------- 1 | r""" 2 | This file creates a JSON file in which the values needed for data normalisation are stored, so the model can load them 3 | when needed.
4 | 5 | Parameters from hparam.py will be used 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | from pathlib import Path 12 | 13 | import rootutils 14 | import torch 15 | from hydra import compose, initialize 16 | from omegaconf import open_dict 17 | from tqdm.auto import tqdm 18 | 19 | from matcha.data.text_mel_datamodule import TextMelDataModule 20 | from matcha.utils.logging_utils import pylogger 21 | 22 | log = pylogger.get_pylogger(__name__) 23 | 24 | 25 | def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): 26 | """Generate data mean and standard deviation helpful in data normalisation 27 | 28 | Args: 29 | data_loader (torch.utils.data.Dataloader): _description_ 30 | out_channels (int): mel spectrogram channels 31 | """ 32 | total_mel_sum = 0 33 | total_mel_sq_sum = 0 34 | total_mel_len = 0 35 | 36 | for batch in tqdm(data_loader, leave=False): 37 | mels = batch["y"] 38 | mel_lengths = batch["y_lengths"] 39 | 40 | total_mel_len += torch.sum(mel_lengths) 41 | total_mel_sum += torch.sum(mels) 42 | total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) 43 | 44 | data_mean = total_mel_sum / (total_mel_len * out_channels) 45 | data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) 46 | 47 | return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser() 52 | 53 | parser.add_argument( 54 | "-i", 55 | "--input-config", 56 | type=str, 57 | default="vctk.yaml", 58 | help="The name of the yaml config file under configs/data", 59 | ) 60 | 61 | parser.add_argument( 62 | "-b", 63 | "--batch-size", 64 | type=int, 65 | default="256", 66 | help="Can have increased batch size for faster computation", 67 | ) 68 | 69 | parser.add_argument( 70 | "-f", 71 | "--force", 72 | action="store_true", 73 | default=False, 74 | required=False, 75 | help="force overwrite the file", 76 | ) 77 | args = parser.parse_args() 78 | output_file = Path(args.input_config).with_suffix(".json") 79 | 80 | if os.path.exists(output_file) and not args.force: 81 | print("File already exists. Use -f to force overwrite") 82 | sys.exit(1) 83 | 84 | with initialize(version_base="1.3", config_path="../../configs/data"): 85 | cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) 86 | 87 | root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") 88 | 89 | with open_dict(cfg): 90 | del cfg["hydra"] 91 | del cfg["_target_"] 92 | cfg["data_statistics"] = None 93 | cfg["seed"] = 1234 94 | cfg["batch_size"] = args.batch_size 95 | cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) 96 | cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) 97 | cfg["load_durations"] = False 98 | 99 | text_mel_datamodule = TextMelDataModule(**cfg) 100 | text_mel_datamodule.setup() 101 | data_loader = text_mel_datamodule.train_dataloader() 102 | log.info("Dataloader loaded! 
Now computing stats...") 103 | params = compute_data_statistics(data_loader, cfg["n_feats"]) 104 | print(params) 105 | with open(output_file, "w", encoding="utf-8") as dumpfile: 106 | json.dump(params, dumpfile) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /matcha/utils/get_durations_from_trained_model.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it 3 | when needed. 4 | 5 | Parameters from hparam.py will be used 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | from pathlib import Path 12 | 13 | import lightning 14 | import numpy as np 15 | import rootutils 16 | import torch 17 | from hydra import compose, initialize 18 | from omegaconf import open_dict 19 | from torch import nn 20 | from tqdm.auto import tqdm 21 | 22 | from matcha.cli import get_device 23 | from matcha.data.text_mel_datamodule import TextMelDataModule 24 | from matcha.models.matcha_tts import MatchaTTS 25 | from matcha.utils.logging_utils import pylogger 26 | from matcha.utils.utils import get_phoneme_durations 27 | 28 | log = pylogger.get_pylogger(__name__) 29 | 30 | 31 | def save_durations_to_folder( 32 | attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str 33 | ): 34 | durations = attn.squeeze().sum(1)[:x_length].numpy() 35 | durations_json = get_phoneme_durations(durations, text) 36 | output = output_folder / Path(filepath).name.replace(".wav", ".npy") 37 | with open(output.with_suffix(".json"), "w", encoding="utf-8") as f: 38 | json.dump(durations_json, f, indent=4, ensure_ascii=False) 39 | 40 | np.save(output, durations) 41 | 42 | 43 | @torch.inference_mode() 44 | def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder): 45 | """Generate durations from the model for each datapoint and save it in a folder 46 | 47 | Args: 48 | data_loader (torch.utils.data.DataLoader): Dataloader 49 | model (nn.Module): MatchaTTS model 50 | device (torch.device): GPU or CPU 51 | """ 52 | 53 | for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"): 54 | x, x_lengths = batch["x"], batch["x_lengths"] 55 | y, y_lengths = batch["y"], batch["y_lengths"] 56 | spks = batch["spks"] 57 | x = x.to(device) 58 | y = y.to(device) 59 | x_lengths = x_lengths.to(device) 60 | y_lengths = y_lengths.to(device) 61 | spks = spks.to(device) if spks is not None else None 62 | 63 | _, _, _, attn = model( 64 | x=x, 65 | x_lengths=x_lengths, 66 | y=y, 67 | y_lengths=y_lengths, 68 | spks=spks, 69 | ) 70 | attn = attn.cpu() 71 | for i in range(attn.shape[0]): 72 | save_durations_to_folder( 73 | attn[i], 74 | x_lengths[i].item(), 75 | y_lengths[i].item(), 76 | batch["filepaths"][i], 77 | output_folder, 78 | batch["x_texts"][i], 79 | ) 80 | 81 | 82 | def main(): 83 | parser = argparse.ArgumentParser() 84 | 85 | parser.add_argument( 86 | "-i", 87 | "--input-config", 88 | type=str, 89 | default="ljspeech.yaml", 90 | help="The name of the yaml config file under configs/data", 91 | ) 92 | 93 | parser.add_argument( 94 | "-b", 95 | "--batch-size", 96 | type=int, 97 | default="32", 98 | help="Can have increased batch size for faster computation", 99 | ) 100 | 101 | parser.add_argument( 102 | "-f", 103 | "--force", 104 | action="store_true", 105 | default=False, 106 | 
required=False, 107 | help="force overwrite the file", 108 | ) 109 | parser.add_argument( 110 | "-c", 111 | "--checkpoint_path", 112 | type=str, 113 | required=True, 114 | help="Path to the checkpoint file to load the model from", 115 | ) 116 | 117 | parser.add_argument( 118 | "-o", 119 | "--output-folder", 120 | type=str, 121 | default=None, 122 | help="Output folder to save the data statistics", 123 | ) 124 | 125 | parser.add_argument( 126 | "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)" 127 | ) 128 | 129 | args = parser.parse_args() 130 | 131 | with initialize(version_base="1.3", config_path="../../configs/data"): 132 | cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) 133 | 134 | root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") 135 | 136 | with open_dict(cfg): 137 | del cfg["hydra"] 138 | del cfg["_target_"] 139 | cfg["seed"] = 1234 140 | cfg["batch_size"] = args.batch_size 141 | cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) 142 | cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) 143 | cfg["load_durations"] = False 144 | 145 | if args.output_folder is not None: 146 | output_folder = Path(args.output_folder) 147 | else: 148 | output_folder = Path(cfg["train_filelist_path"]).parent / "durations" 149 | 150 | print(f"Output folder set to: {output_folder}") 151 | 152 | if os.path.exists(output_folder) and not args.force: 153 | print("Folder already exists. Use -f to force overwrite") 154 | sys.exit(1) 155 | 156 | output_folder.mkdir(parents=True, exist_ok=True) 157 | 158 | print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}") 159 | print("Loading model...") 160 | device = get_device(args) 161 | model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device) 162 | 163 | text_mel_datamodule = TextMelDataModule(**cfg) 164 | text_mel_datamodule.setup() 165 | try: 166 | print("Computing stats for training set if exists...") 167 | train_dataloader = text_mel_datamodule.train_dataloader() 168 | compute_durations(train_dataloader, model, device, output_folder) 169 | except lightning.fabric.utilities.exceptions.MisconfigurationException: 170 | print("No training set found") 171 | 172 | try: 173 | print("Computing stats for validation set if exists...") 174 | val_dataloader = text_mel_datamodule.val_dataloader() 175 | compute_durations(val_dataloader, model, device, output_folder) 176 | except lightning.fabric.utilities.exceptions.MisconfigurationException: 177 | print("No validation set found") 178 | 179 | try: 180 | print("Computing stats for test set if exists...") 181 | test_dataloader = text_mel_datamodule.test_dataloader() 182 | compute_durations(test_dataloader, model, device, output_folder) 183 | except lightning.fabric.utilities.exceptions.MisconfigurationException: 184 | print("No test set found") 185 | 186 | print(f"[+] Done! 
Data statistics saved to: {output_folder}") 187 | 188 | 189 | if __name__ == "__main__": 190 | # Helps with generating durations for the dataset to train other architectures 191 | # that cannot learn to align due to limited size of dataset 192 | # Example usage: 193 | # python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model 194 | # This will create a folder in data/processed_data/durations/ljspeech with the durations 195 | main() 196 | -------------------------------------------------------------------------------- /matcha/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from matcha.utils import pylogger 9 | 10 | log = pylogger.get_pylogger(__name__) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /matcha/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from matcha.utils import pylogger 7 | 8 | log = pylogger.get_pylogger(__name__) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 
22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) 39 | 40 | hparams["data"] = cfg["data"] 41 | hparams["trainer"] = cfg["trainer"] 42 | 43 | hparams["callbacks"] = cfg.get("callbacks") 44 | hparams["extras"] = cfg.get("extras") 45 | 46 | hparams["task_name"] = cfg.get("task_name") 47 | hparams["tags"] = cfg.get("tags") 48 | hparams["ckpt_path"] = cfg.get("ckpt_path") 49 | hparams["seed"] = cfg.get("seed") 50 | 51 | # send hparams to all loggers 52 | for logger in trainer.loggers: 53 | logger.log_hyperparams(hparams) 54 | -------------------------------------------------------------------------------- /matcha/utils/model.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def sequence_mask(length, max_length=None): 8 | if max_length is None: 9 | max_length = length.max() 10 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 11 | return x.unsqueeze(0) < length.unsqueeze(1) 12 | 13 | 14 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 15 | factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) 16 | length = (length / factor).ceil() * factor 17 | if not torch.onnx.is_in_onnx_export(): 18 | return length.int().item() 19 | else: 20 | return length 21 | 22 | 23 | def convert_pad_shape(pad_shape): 24 | inverted_shape = pad_shape[::-1] 25 | pad_shape = [item for sublist in inverted_shape for item in sublist] 26 | return pad_shape 27 | 28 | 29 | def generate_path(duration, mask): 30 | device = duration.device 31 | 32 | b, t_x, t_y = mask.shape 33 | cum_duration = torch.cumsum(duration, 1) 34 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 35 | 36 | cum_duration_flat = cum_duration.view(b * t_x) 37 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 38 | path = path.view(b, t_x, t_y) 39 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 40 | path = path * mask 41 | return path 42 | 43 | 44 | def duration_loss(logw, logw_, lengths): 45 | loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) 46 | return loss 47 | 48 | 49 | def normalize(data, mu, std): 50 | if not isinstance(mu, (float, int)): 51 | if isinstance(mu, list): 52 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 53 | elif isinstance(mu, torch.Tensor): 54 | mu = mu.to(data.device) 55 | elif isinstance(mu, np.ndarray): 56 | mu = torch.from_numpy(mu).to(data.device) 57 | mu = mu.unsqueeze(-1) 58 | 59 | if not isinstance(std, (float, int)): 60 | if isinstance(std, list): 61 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 62 | elif isinstance(std, torch.Tensor): 63 | std = std.to(data.device) 64 | elif isinstance(std, np.ndarray): 65 | std = torch.from_numpy(std).to(data.device) 66 | std = std.unsqueeze(-1) 67 | 68 | return (data - mu) / std 69 | 70 
| 71 | def denormalize(data, mu, std): 72 | if not isinstance(mu, float): 73 | if isinstance(mu, list): 74 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 75 | elif isinstance(mu, torch.Tensor): 76 | mu = mu.to(data.device) 77 | elif isinstance(mu, np.ndarray): 78 | mu = torch.from_numpy(mu).to(data.device) 79 | mu = mu.unsqueeze(-1) 80 | 81 | if not isinstance(std, float): 82 | if isinstance(std, list): 83 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 84 | elif isinstance(std, torch.Tensor): 85 | std = std.to(data.device) 86 | elif isinstance(std, np.ndarray): 87 | std = torch.from_numpy(std).to(data.device) 88 | std = std.unsqueeze(-1) 89 | 90 | return data * std + mu 91 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from matcha.utils.monotonic_align.core import maximum_path_c 5 | 6 | 7 | def maximum_path(value, mask): 8 | """Cython optimised version. 9 | value: [b, t_x, t_y] 10 | mask: [b, t_x, t_y] 11 | """ 12 | value = value * mask 13 | device = value.device 14 | dtype = value.dtype 15 | value = value.data.cpu().numpy().astype(np.float32) 16 | path = np.zeros_like(value).astype(np.int32) 17 | mask = mask.data.cpu().numpy() 18 | 19 | t_x_max = mask.sum(1)[:, 0].astype(np.int32) 20 | t_y_max = mask.sum(2)[:, 0].astype(np.int32) 21 | maximum_path_c(path, value, t_x_max, t_y_max) 22 | return torch.from_numpy(path).to(device=device, dtype=dtype) 23 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | cimport cython 4 | cimport numpy as np 5 | 6 | from cython.parallel import prange 7 | 8 | 9 | @cython.boundscheck(False) 10 | @cython.wraparound(False) 11 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: 12 | cdef int x 13 | cdef int y 14 | cdef float v_prev 15 | cdef float v_cur 16 | cdef float tmp 17 | cdef int index = t_x - 1 18 | 19 | for y in range(t_y): 20 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 21 | if x == y: 22 | v_cur = max_neg_val 23 | else: 24 | v_cur = value[x, y-1] 25 | if x == 0: 26 | if y == 0: 27 | v_prev = 0. 
28 | else: 29 | v_prev = max_neg_val 30 | else: 31 | v_prev = value[x-1, y-1] 32 | value[x, y] = max(v_cur, v_prev) + value[x, y] 33 | 34 | for y in range(t_y - 1, -1, -1): 35 | path[index, y] = 1 36 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): 37 | index = index - 1 38 | 39 | 40 | @cython.boundscheck(False) 41 | @cython.wraparound(False) 42 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: 43 | cdef int b = values.shape[0] 44 | 45 | cdef int i 46 | for i in prange(b, nogil=True): 47 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) 48 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | # from distutils.core import setup 2 | # from Cython.Build import cythonize 3 | # import numpy 4 | 5 | # setup(name='monotonic_align', 6 | # ext_modules=cythonize("core.pyx"), 7 | # include_dirs=[numpy.get_include()]) 8 | -------------------------------------------------------------------------------- /matcha/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name: str = __name__) -> logging.Logger: 7 | """Initializes a multi-GPU-friendly python command line logger. 8 | 9 | :param name: The name of the logger, defaults to ``__name__``. 10 | 11 | :return: A logger object. 12 | """ 13 | logger = logging.getLogger(name) 14 | 15 | # this ensures all logging levels get marked with the rank zero decorator 16 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 17 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 18 | for level in logging_levels: 19 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 20 | 21 | return logger 22 | -------------------------------------------------------------------------------- /matcha/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from matcha.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 
39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | _ = ( 48 | queue.append(field) 49 | if field in cfg 50 | else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") 51 | ) 52 | 53 | # add all the other fields to queue (not specified in `print_order`) 54 | for field in cfg: 55 | if field not in queue: 56 | queue.append(field) 57 | 58 | # generate config tree from queue 59 | for field in queue: 60 | branch = tree.add(field, style=style, guide_style=style) 61 | 62 | config_group = cfg[field] 63 | if isinstance(config_group, DictConfig): 64 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 65 | else: 66 | branch_content = str(config_group) 67 | 68 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 69 | 70 | # print config tree 71 | rich.print(tree) 72 | 73 | # save config tree to file 74 | if save_to_file: 75 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file: 76 | rich.print(tree, file=file) 77 | 78 | 79 | @rank_zero_only 80 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 81 | """Prompts user to input tags from command line if no tags are provided in config. 82 | 83 | :param cfg: A DictConfig composed by Hydra. 84 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 85 | """ 86 | if not cfg.get("tags"): 87 | if "id" in HydraConfig().cfg.hydra.job: 88 | raise ValueError("Specify tags before launching a multirun!") 89 | 90 | log.warning("No tags provided in config. Prompting user to input tags...") 91 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 92 | tags = [t.strip() for t in tags.split(",") if t != ""] 93 | 94 | with open_dict(cfg): 95 | cfg.tags = tags 96 | 97 | log.info(f"Tags: {cfg.tags}") 98 | 99 | if save_to_file: 100 | with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file: 101 | rich.print(cfg.tags, file=file) 102 | -------------------------------------------------------------------------------- /matcha/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | from importlib.util import find_spec 5 | from math import ceil 6 | from pathlib import Path 7 | from typing import Any, Callable, Dict, Tuple 8 | 9 | import gdown 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | import wget 14 | from omegaconf import DictConfig 15 | 16 | from matcha.utils import pylogger, rich_utils 17 | 18 | log = pylogger.get_pylogger(__name__) 19 | 20 | 21 | def extras(cfg: DictConfig) -> None: 22 | """Applies optional utilities before the task is started. 23 | 24 | Utilities: 25 | - Ignoring python warnings 26 | - Setting tags from command line 27 | - Rich config printing 28 | 29 | :param cfg: A DictConfig object containing the config tree. 30 | """ 31 | # return if no `extras` config 32 | if not cfg.get("extras"): 33 | log.warning("Extras config not found! ") 34 | return 35 | 36 | # disable python warnings 37 | if cfg.extras.get("ignore_warnings"): 38 | log.info("Disabling python warnings! ") 39 | warnings.filterwarnings("ignore") 40 | 41 | # prompt user to input tags from command line if none are provided in the config 42 | if cfg.extras.get("enforce_tags"): 43 | log.info("Enforcing tags! 
") 44 | rich_utils.enforce_tags(cfg, save_to_file=True) 45 | 46 | # pretty print config tree using Rich library 47 | if cfg.extras.get("print_config"): 48 | log.info("Printing config tree with Rich! ") 49 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 50 | 51 | 52 | def task_wrapper(task_func: Callable) -> Callable: 53 | """Optional decorator that controls the failure behavior when executing the task function. 54 | 55 | This wrapper can be used to: 56 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 57 | - save the exception to a `.log` file 58 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 59 | - etc. (adjust depending on your needs) 60 | 61 | Example: 62 | ``` 63 | @utils.task_wrapper 64 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 65 | ... 66 | return metric_dict, object_dict 67 | ``` 68 | 69 | :param task_func: The task function to be wrapped. 70 | 71 | :return: The wrapped task function. 72 | """ 73 | 74 | def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 75 | # execute the task 76 | try: 77 | metric_dict, object_dict = task_func(cfg=cfg) 78 | 79 | # things to do if exception occurs 80 | except Exception as ex: 81 | # save exception to `.log` file 82 | log.exception("") 83 | 84 | # some hyperparameter combinations might be invalid or cause out-of-memory errors 85 | # so when using hparam search plugins like Optuna, you might want to disable 86 | # raising the below exception to avoid multirun failure 87 | raise ex 88 | 89 | # things to always do after either success or exception 90 | finally: 91 | # display output dir path in terminal 92 | log.info(f"Output dir: {cfg.paths.output_dir}") 93 | 94 | # always close wandb run (even if exception occurs so multirun won't fail) 95 | if find_spec("wandb"): # check if wandb is installed 96 | import wandb 97 | 98 | if wandb.run: 99 | log.info("Closing wandb!") 100 | wandb.finish() 101 | 102 | return metric_dict, object_dict 103 | 104 | return wrap 105 | 106 | 107 | def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: 108 | """Safely retrieves value of the metric logged in LightningModule. 109 | 110 | :param metric_dict: A dict containing metric values. 111 | :param metric_name: The name of the metric to retrieve. 112 | :return: The value of the metric. 113 | """ 114 | if not metric_name: 115 | log.info("Metric name is None! Skipping metric value retrieval...") 116 | return None 117 | 118 | if metric_name not in metric_dict: 119 | raise ValueError( 120 | f"Metric value not found! \n" 121 | "Make sure metric name logged in LightningModule is correct!\n" 122 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 123 | ) 124 | 125 | metric_value = metric_dict[metric_name].item() 126 | log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") 127 | 128 | return metric_value 129 | 130 | 131 | def intersperse(lst, item): 132 | # Adds blank symbol 133 | result = [item] * (len(lst) * 2 + 1) 134 | result[1::2] = lst 135 | return result 136 | 137 | 138 | def save_figure_to_numpy(fig): 139 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 140 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 141 | return data 142 | 143 | 144 | def plot_tensor(tensor): 145 | plt.style.use("default") 146 | fig, ax = plt.subplots(figsize=(12, 3)) 147 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 148 | plt.colorbar(im, ax=ax) 149 | plt.tight_layout() 150 | fig.canvas.draw() 151 | data = save_figure_to_numpy(fig) 152 | plt.close() 153 | return data 154 | 155 | 156 | def save_plot(tensor, savepath): 157 | plt.style.use("default") 158 | fig, ax = plt.subplots(figsize=(12, 3)) 159 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 160 | plt.colorbar(im, ax=ax) 161 | plt.tight_layout() 162 | fig.canvas.draw() 163 | plt.savefig(savepath) 164 | plt.close() 165 | 166 | 167 | def to_numpy(tensor): 168 | if isinstance(tensor, np.ndarray): 169 | return tensor 170 | elif isinstance(tensor, torch.Tensor): 171 | return tensor.detach().cpu().numpy() 172 | elif isinstance(tensor, list): 173 | return np.array(tensor) 174 | else: 175 | raise TypeError("Unsupported type for conversion to numpy array") 176 | 177 | 178 | def get_user_data_dir(appname="matcha_tts"): 179 | """ 180 | Args: 181 | appname (str): Name of application 182 | 183 | Returns: 184 | Path: path to user data directory 185 | """ 186 | 187 | MATCHA_HOME = os.environ.get("MATCHA_HOME") 188 | if MATCHA_HOME is not None: 189 | ans = Path(MATCHA_HOME).expanduser().resolve(strict=False) 190 | elif sys.platform == "win32": 191 | import winreg # pylint: disable=import-outside-toplevel 192 | 193 | key = winreg.OpenKey( 194 | winreg.HKEY_CURRENT_USER, 195 | r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", 196 | ) 197 | dir_, _ = winreg.QueryValueEx(key, "Local AppData") 198 | ans = Path(dir_).resolve(strict=False) 199 | elif sys.platform == "darwin": 200 | ans = Path("~/Library/Application Support/").expanduser() 201 | else: 202 | ans = Path.home().joinpath(".local/share") 203 | 204 | final_path = ans.joinpath(appname) 205 | final_path.mkdir(parents=True, exist_ok=True) 206 | return final_path 207 | 208 | 209 | def assert_model_downloaded(checkpoint_path, url, use_wget=True): 210 | if Path(checkpoint_path).exists(): 211 | log.debug(f"[+] Model already present at {checkpoint_path}!") 212 | print(f"[+] Model already present at {checkpoint_path}!") 213 | return 214 | log.info(f"[-] Model not found at {checkpoint_path}! Will download it") 215 | print(f"[-] Model not found at {checkpoint_path}! 
Will download it") 216 | checkpoint_path = str(checkpoint_path) 217 | if not use_wget: 218 | gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) 219 | else: 220 | wget.download(url=url, out=checkpoint_path) 221 | 222 | 223 | def get_phoneme_durations(durations, phones): 224 | prev = durations[0] 225 | merged_durations = [] 226 | # Convolve with stride 2 227 | for i in range(1, len(durations), 2): 228 | if i == len(durations) - 2: 229 | # if it is last take full value 230 | next_half = durations[i + 1] 231 | else: 232 | next_half = ceil(durations[i + 1] / 2) 233 | 234 | curr = prev + durations[i] + next_half 235 | prev = durations[i + 1] - next_half 236 | merged_durations.append(curr) 237 | 238 | assert len(phones) == len(merged_durations) 239 | assert len(merged_durations) == (len(durations) - 1) // 2 240 | 241 | merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long) 242 | start = torch.tensor(0) 243 | duration_json = [] 244 | for i, duration in enumerate(merged_durations): 245 | duration_json.append( 246 | { 247 | phones[i]: { 248 | "starttime": start.item(), 249 | "endtime": duration.item(), 250 | "duration": duration.item() - start.item(), 251 | } 252 | } 253 | ) 254 | start = duration 255 | 256 | assert list(duration_json[-1].values())[0]["endtime"] == sum( 257 | durations 258 | ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}" 259 | return duration_json 260 | -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shivammehta25/Matcha-TTS/108906c603fad5055f2649b3fd71d2bbdf222eac/notebooks/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "cython==0.29.35", "numpy==1.24.3", "packaging"] 3 | 4 | [tool.black] 5 | line-length = 120 6 | target-version = ['py310'] 7 | exclude = ''' 8 | 9 | ( 10 | /( 11 | \.eggs # exclude a few common directories in the 12 | | \.git # root of the project 13 | | \.hg 14 | | \.mypy_cache 15 | | \.tox 16 | | \.venv 17 | | _build 18 | | buck-out 19 | | build 20 | | dist 21 | )/ 22 | | foo.py # also separately exclude a file named foo.py in 23 | # the root of the project 24 | ) 25 | ''' 26 | 27 | [tool.pytest.ini_options] 28 | addopts = [ 29 | "--color=yes", 30 | "--durations=0", 31 | "--strict-markers", 32 | "--doctest-modules", 33 | ] 34 | filterwarnings = [ 35 | "ignore::DeprecationWarning", 36 | "ignore::UserWarning", 37 | ] 38 | log_cli = "True" 39 | markers = [ 40 | "slow: slow tests", 41 | ] 42 | minversion = "6.0" 43 | testpaths = "tests/" 44 | 45 | [tool.coverage.report] 46 | exclude_lines = [ 47 | "pragma: nocover", 48 | "raise NotImplementedError", 49 | "raise NotImplementedError()", 50 | "if __name__ == .__main__.:", 51 | ] 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch>=2.0.0 3 | torchvision>=0.15.0 4 | lightning>=2.0.0 5 | torchmetrics>=0.11.4 6 | 7 | # --------- hydra --------- # 8 | hydra-core==1.3.2 9 | hydra-colorlog==1.2.0 10 | hydra-optuna-sweeper==1.2.0 11 | 12 | # --------- loggers --------- # 13 | # wandb 14 | # neptune-client 15 | # 
mlflow 16 | # comet-ml 17 | # aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 18 | 19 | # --------- others --------- # 20 | rootutils # standardizing the project root setup 21 | pre-commit # hooks for applying linters on commit 22 | rich # beautiful text formatting in terminal 23 | pytest # tests 24 | # sh # for running bash commands in some tests (linux/macos only) 25 | phonemizer # phonemization of text 26 | tensorboard 27 | librosa 28 | Cython 29 | numpy 30 | einops 31 | inflect 32 | Unidecode 33 | scipy 34 | torchaudio 35 | matplotlib 36 | pandas 37 | conformer==0.3.2 38 | diffusers # developed using version ==0.25.0 39 | notebook 40 | ipywidgets 41 | gradio==3.43.2 42 | gdown 43 | wget 44 | seaborn 45 | -------------------------------------------------------------------------------- /scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python src/train.py trainer.max_epochs=5 logger=csv 6 | 7 | python src/train.py trainer.max_epochs=10 logger=csv 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | 4 | import numpy 5 | from Cython.Build import cythonize 6 | from setuptools import Extension, find_packages, setup 7 | 8 | exts = [ 9 | Extension( 10 | name="matcha.utils.monotonic_align.core", 11 | sources=["matcha/utils/monotonic_align/core.pyx"], 12 | ) 13 | ] 14 | 15 | with open("README.md", encoding="utf-8") as readme_file: 16 | README = readme_file.read() 17 | 18 | cwd = os.path.dirname(os.path.abspath(__file__)) 19 | with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin: 20 | version = fin.read().strip() 21 | 22 | 23 | def get_requires(): 24 | requirements = os.path.join(os.path.dirname(__file__), "requirements.txt") 25 | with open(requirements, encoding="utf-8") as reqfile: 26 | return [str(r).strip() for r in reqfile] 27 | 28 | 29 | setup( 30 | name="matcha-tts", 31 | version=version, 32 | description="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching", 33 | long_description=README, 34 | long_description_content_type="text/markdown", 35 | author="Shivam Mehta", 36 | author_email="shivam.mehta25@gmail.com", 37 | url="https://shivammehta25.github.io/Matcha-TTS", 38 | install_requires=get_requires(), 39 | include_dirs=[numpy.get_include()], 40 | include_package_data=True, 41 | packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]), 42 | # use this to customize global commands available in the terminal after installing the package 43 | entry_points={ 44 | "console_scripts": [ 45 | "matcha-data-stats=matcha.utils.generate_data_statistics:main", 46 | "matcha-tts=matcha.cli:cli", 47 | "matcha-tts-app=matcha.app:main", 48 | "matcha-tts-get-durations=matcha.utils.get_durations_from_trained_model:main", 49 | ] 50 | }, 51 | ext_modules=cythonize(exts, language_level=3), 52 | python_requires=">=3.9.0", 53 | ) 54 | --------------------------------------------------------------------------------
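A minimal usage sketch of how the statistics written by matcha-data-stats (matcha/utils/generate_data_statistics.py) would typically feed the normalize/denormalize helpers in matcha/utils/model.py. The stats filename, batch size and mel-tensor shape below are illustrative assumptions, not values fixed by the project.

import json

import torch

from matcha.utils.model import denormalize, normalize

# Statistics file assumed to come from: matcha-data-stats -i ljspeech.yaml
# (the script writes <config-name>.json into the current working directory).
with open("ljspeech.json", encoding="utf-8") as f:
    stats = json.load(f)  # {"mel_mean": ..., "mel_std": ...}

# Illustrative mel batch of shape (batch, n_feats, frames); 80 is the usual mel channel count.
mel = torch.randn(2, 80, 120)

mel_norm = normalize(mel, stats["mel_mean"], stats["mel_std"])
mel_back = denormalize(mel_norm, stats["mel_mean"], stats["mel_std"])

# normalize/denormalize are exact inverses up to floating-point error.
assert torch.allclose(mel, mel_back, atol=1e-4)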
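A second small sketch around get_phoneme_durations from matcha/utils/utils.py, whose stride-2 merge over the intersperse()d token sequence (blank, phone, blank, ..., length 2 * len(phones) + 1) is easy to misread: each blank's duration is split between its neighbouring phones. The phone labels and duration values here are made up for illustration only.

from matcha.utils.utils import get_phoneme_durations, intersperse

# Two phones give an interspersed token sequence of length 2 * 2 + 1 = 5.
phones = ["a", "b"]
assert len(intersperse(phones, "_")) == 5

# Per-token durations (in frames) over [blank, "a", blank, "b", blank]; made-up values.
durations = [1, 4, 2, 6, 3]

duration_json = get_phoneme_durations(durations, phones)
# Each entry maps a phone to its start/end frame; endtimes accumulate to sum(durations) == 16.
print(duration_json)
# [{'a': {'starttime': 0, 'endtime': 6, 'duration': 6}},
#  {'b': {'starttime': 6, 'endtime': 16, 'duration': 10}}]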