├── .flake8 ├── .github ├── FUNDING.yml ├── guide.jpeg └── sponsor.jpg ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── README.zh-CN.md ├── docker └── Dockerfile ├── docs ├── images │ ├── Overview.jpg │ ├── train.png │ └── vallf.png └── requirements.txt ├── egs ├── aishell1 │ ├── README.md │ ├── bin │ ├── demos │ │ └── 0_demo.wav │ ├── prepare.sh │ ├── prompts │ │ ├── ch_24k.txt │ │ ├── ch_24k.wav │ │ └── ch_24k_loudness_normalized20.wav │ └── shared ├── libritts │ ├── README.md │ ├── bin │ ├── prepare.sh │ ├── prompts │ │ ├── 8455_210777_000067_000000.txt │ │ ├── 8455_210777_000067_000000.wav │ │ ├── 8463_294825_000043_000000.txt │ │ └── 8463_294825_000043_000000.wav │ └── shared │ │ └── parse_options.sh └── ljspeech │ ├── README.md │ ├── bin │ ├── demos │ ├── 0.wav │ └── 1.wav │ ├── prepare.sh │ ├── prompts │ ├── LJ049-0108_24K.txt │ ├── LJ049-0108_24K.wav │ ├── LJ049-0110_24K.txt │ ├── LJ049-0110_24K.wav │ ├── LJ049-0124.txt │ ├── LJ049-0124_24K.wav │ ├── LJ049-0185.txt │ └── LJ049-0185_24K.wav │ └── shared ├── examples ├── setup.py ├── test.sh └── valle ├── __init__.py ├── bin ├── __init__.py ├── display_manifest_statistics.py ├── infer.py ├── tokenizer.py └── trainer.py ├── data ├── __init__.py ├── collation.py ├── datamodule.py ├── dataset.py ├── fbank.py ├── input_strategies.py └── tokenizer.py ├── models ├── __init__.py ├── macros.py ├── transformer.py ├── valle.py └── visualizer.py ├── modules ├── __init__.py ├── activation.py ├── embedding.py ├── optim.py ├── scaling.py ├── scheduler.py └── transformer.py ├── tests ├── data │ └── tokenizer_test.py ├── scaling_test.py └── valle_test.py └── utils ├── __init__.py └── symbol_table.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | show-source=true 3 | statistics=true 4 | max-line-length = 80 5 | per-file-ignores = 6 | # line too long 7 | valle/bin/trainer.py: E501, E402, 8 | valle/bin/infer.py: E501, E402, 9 | valle/bin/tokenizer.py: E501, E402, 10 | valle/models/*.py: E501, E203, 11 | valle/models/__init__.py: F401, E501 12 | valle/utils/__init__.py: F401, 13 | valle/__init__.py: F401, 14 | valle/modules/__init__.py: F401, 15 | valle/modules/activation.py: E501, 16 | valle/modules/embedding.py: E203, 17 | valle/modules/scheduler.py: F401, 18 | valle/modules/transformer.py: E501, 19 | valle/modules/scaling.py: E501, F841, F401 20 | valle/data/fbank.py: E501, 21 | valle/data/input_strategies.py: E501, 22 | valle/data/datamodule.py: E501, 23 | valle/tests/*.py: F841, F401 24 | valle/tests/data/*.py: F841, F401 25 | 26 | exclude = 27 | .git, 28 | egs/**/data/**, 29 | setup.py, 30 | valle/utils/symbol_table.py, 31 | valle/modules/optim.py, 32 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: lifeiteng 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace 
with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: ['https://github.com/lifeiteng/SoundStorm/blob/master/.github/sponsor.jpg'] 14 | -------------------------------------------------------------------------------- /.github/guide.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/.github/guide.jpeg -------------------------------------------------------------------------------- /.github/sponsor.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/.github/sponsor.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | valle.egg-info/ 2 | __pycache__ 3 | path.sh 4 | exp 5 | exp*/ 6 | *.pt 7 | log 8 | .DS_Store 9 | .vscode 10 | egs/*/data 11 | egs/libritts/prompts/*.png 12 | valle/version.py 13 | infer/ 14 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 21.6b0 4 | hooks: 5 | - id: black 6 | args: [--line-length=80] 7 | additional_dependencies: ['click==8.0.1'] 8 | exclude: valle\/__init__\.py 9 | exclude: valle\/utils\/symbol_table\.py 10 | exclude: valle\/modules\/scaling\.py 11 | exclude: valle\/tests\/data\/tokenizer\_test\.py 12 | 13 | - repo: https://github.com/PyCQA/flake8 14 | rev: 3.9.2 15 | hooks: 16 | - id: flake8 17 | args: [--max-line-length=80] 18 | exclude: valle\/modules\/scaling\.py 19 | exclude: valle\/tests\/data\/tokenizer\_test\.py 20 | 21 | - repo: https://github.com/pycqa/isort 22 | rev: 5.12.0 23 | hooks: 24 | - id: isort 25 | args: [--profile=black, --line-length=80] 26 | exclude: valle\/utils\/symbol_table\.py 27 | exclude: valle\/modules\/scaling\.py 28 | 29 | - repo: https://github.com/pre-commit/pre-commit-hooks 30 | rev: v4.0.1 31 | hooks: 32 | - id: check-executables-have-shebangs 33 | - id: end-of-file-fixer 34 | - id: mixed-line-ending 35 | - id: trailing-whitespace 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Language : 🇺🇸 | [🇨🇳](./README.zh-CN.md) 2 | 3 | An unofficial PyTorch implementation of VALL-E ([Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers](https://arxiv.org/abs/2301.02111)). 4 | 5 | We can train the VALL-E model on one GPU. 6 | 7 | ![model](./docs/images/Overview.jpg) 8 | 9 | ## Demo 10 | 11 | * [official demo](https://valle-demo.github.io/) 12 | * [reproduced demo](https://lifeiteng.github.io/valle/index.html) 13 | 14 | Buy Me A Coffee 15 | 16 | 17 | 18 | 19 | ## Broader impacts 20 | 21 | > Since VALL-E could synthesize speech that maintains speaker identity, it may carry potential risks in misuse of the model, such as spoofing voice identification or impersonating a specific speaker. 22 | 23 | To avoid abuse, well-trained models and services will not be provided. 24 | 25 | ## Install Deps 26 | 27 | To get up and running quickly just follow the steps below: 28 | 29 | ``` 30 | # PyTorch 31 | pip install torch==1.13.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 32 | pip install torchmetrics==0.11.1 33 | # fbank 34 | pip install librosa==0.8.1 35 | 36 | # phonemizer pypinyin 37 | apt-get install espeak-ng 38 | ## OSX: brew install espeak 39 | pip install phonemizer==3.2.1 pypinyin==0.48.0 40 | 41 | # lhotse update to newest version 42 | # https://github.com/lhotse-speech/lhotse/pull/956 43 | # https://github.com/lhotse-speech/lhotse/pull/960 44 | pip uninstall lhotse 45 | pip uninstall lhotse 46 | pip install git+https://github.com/lhotse-speech/lhotse 47 | 48 | # k2 49 | # find the right version in https://huggingface.co/csukuangfj/k2 50 | pip install https://huggingface.co/csukuangfj/k2/resolve/main/cuda/k2-1.23.4.dev20230224+cuda11.6.torch1.13.1-cp310-cp310-linux_x86_64.whl 51 | 52 | # icefall 53 | git clone https://github.com/k2-fsa/icefall 54 | cd icefall 55 | pip install -r requirements.txt 56 | export PYTHONPATH=`pwd`/../icefall:$PYTHONPATH 57 | echo "export PYTHONPATH=`pwd`/../icefall:\$PYTHONPATH" >> ~/.zshrc 58 | echo "export PYTHONPATH=`pwd`/../icefall:\$PYTHONPATH" >> ~/.bashrc 59 | cd - 60 | source ~/.zshrc 61 | 62 | # valle 63 | git clone https://github.com/lifeiteng/valle.git 64 | cd valle 65 | pip install -e . 66 | ``` 67 | 68 | 69 | ## Training&Inference 70 | * #### English example [examples/libritts/README.md](egs/libritts/README.md) 71 | * #### Chinese example [examples/aishell1/README.md](egs/aishell1/README.md) 72 | * ### Prefix Mode 0 1 2 4 for NAR Decoder 73 | **Paper Chapter 5.1** "The average length of the waveform in LibriLight is 60 seconds. During 74 | training, we randomly crop the waveform to a random length between 10 seconds and 20 seconds. For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds from the same utterance." 75 | * **0**: no acoustic prompt tokens 76 | * **1**: random prefix of current batched utterances **(This is recommended)** 77 | * **2**: random segment of current batched utterances 78 | * **4**: same as the paper (since the long waveform is randomly cropped into multiple utterances, "the same utterance" here means a preceding or following utterance from the same long waveform.) 79 | ``` 80 | # To train NAR decoders with prefix_mode 4 81 | python3 bin/trainer.py --prefix_mode 4 --dataset libritts --input-strategy PromptedPrecomputedFeatures ...
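# (prefix_mode 4 needs the PromptedPrecomputedFeatures input strategy because, per the mode descriptions above, its acoustic prompt comes from the same source recording rather than from the current batch.)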
82 | ``` 83 | 84 | #### [LibriTTS demo](https://lifeiteng.github.io/valle/index.html) Trained on one GPU with 24G memory 85 | 86 | ``` 87 | cd examples/libritts 88 | 89 | # step1 prepare dataset 90 | bash prepare.sh --stage -1 --stop-stage 3 91 | 92 | # step2 train the model on one GPU with 24GB memory 93 | exp_dir=exp/valle 94 | 95 | ## Train AR model 96 | python3 bin/trainer.py --max-duration 80 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ 97 | --num-buckets 6 --dtype "bfloat16" --save-every-n 10000 --valid-interval 20000 \ 98 | --model-name valle --share-embedding true --norm-first true --add-prenet false \ 99 | --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ 100 | --base-lr 0.05 --warmup-steps 200 --average-period 0 \ 101 | --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 4 \ 102 | --exp-dir ${exp_dir} 103 | 104 | ## Train NAR model 105 | cp ${exp_dir}/best-valid-loss.pt ${exp_dir}/epoch-2.pt # --start-epoch 3=2+1 106 | python3 bin/trainer.py --max-duration 40 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ 107 | --num-buckets 6 --dtype "float32" --save-every-n 10000 --valid-interval 20000 \ 108 | --model-name valle --share-embedding true --norm-first true --add-prenet false \ 109 | --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ 110 | --base-lr 0.05 --warmup-steps 200 --average-period 0 \ 111 | --num-epochs 40 --start-epoch 3 --start-batch 0 --accumulate-grad-steps 4 \ 112 | --exp-dir ${exp_dir} 113 | 114 | # step3 inference 115 | python3 bin/infer.py --output-dir infer/demos \ 116 | --checkpoint=${exp_dir}/best-valid-loss.pt \ 117 | --text-prompts "KNOT one point one five miles per hour." \ 118 | --audio-prompts ./prompts/8463_294825_000043_000000.wav \ 119 | --text "To get up and running quickly just follow the steps below." \ 120 | 121 | # Demo Inference 122 | https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/run.sh#L68 123 | ``` 124 | ![train](./docs/images/train.png) 125 | 126 | #### Troubleshooting 127 | 128 | * **SummaryWriter segmentation fault (core dumped)** 129 | * LINE `tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")` 130 | * FIX [https://github.com/tensorflow/tensorboard/pull/6135/files](https://github.com/tensorflow/tensorboard/pull/6135/files) 131 | ``` 132 | file=`python -c 'import site; print(f"{site.getsitepackages()[0]}/tensorboard/summary/writer/event_file_writer.py")'` 133 | sed -i 's/import tf/import tensorflow_stub as tf/g' $file 134 | ``` 135 | 136 | #### Training on a custom dataset? 
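A rough sketch of the first step below, for a hypothetical corpus of paired `*.wav`/`*.txt` files (the `download/my_corpus` path, `mycorpus` prefix, and single-speaker assumption are illustrative, not part of this repo):

```
# build_manifests.py (hypothetical helper)
from pathlib import Path

from lhotse import Recording, RecordingSet, SupervisionSegment, SupervisionSet

corpus_dir = Path("download/my_corpus")  # assumption: {utt_id}.wav sits next to {utt_id}.txt
out_dir = Path("data/manifests")
out_dir.mkdir(parents=True, exist_ok=True)

recordings, supervisions = [], []
for wav in sorted(corpus_dir.glob("*.wav")):
    rec = Recording.from_file(wav)  # reads duration / sampling rate from the header
    recordings.append(rec)
    supervisions.append(
        SupervisionSegment(
            id=rec.id,
            recording_id=rec.id,
            start=0.0,
            duration=rec.duration,
            text=wav.with_suffix(".txt").read_text().strip(),
            language="English",
            speaker="spk1",  # assumption: set real per-utterance speakers if available
        )
    )

# The bundled recipes name manifests "<prefix>_recordings_<part>.jsonl.gz", which is the
# layout bin/tokenizer.py expects to find via its --prefix / --src-dir options.
RecordingSet.from_recordings(recordings).to_file(out_dir / "mycorpus_recordings_train.jsonl.gz")
SupervisionSet.from_segments(supervisions).to_file(out_dir / "mycorpus_supervisions_train.jsonl.gz")
```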
137 | * prepare the dataset to `lhotse manifests` 138 | * There are plenty of references here [lhotse/recipes](https://github.com/lhotse-speech/lhotse/tree/master/lhotse/recipes) 139 | * `python3 bin/tokenizer.py ...` 140 | * `python3 bin/trainer.py ...` 141 | 142 | ## Contributing 143 | 144 | * Parallelize bin/tokenizer.py on multi-GPUs 145 | * Buy Me A Coffee 146 | 147 | ## Citing 148 | 149 | To cite this repository: 150 | 151 | ```bibtex 152 | @misc{valle, 153 | author={Feiteng Li}, 154 | title={VALL-E: A neural codec language model}, 155 | year={2023}, 156 | url={http://github.com/lifeiteng/vall-e} 157 | } 158 | ``` 159 | 160 | ```bibtex 161 | @article{VALL-E, 162 | title = {Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers}, 163 | author = {Chengyi Wang, Sanyuan Chen, Yu Wu, 164 | Ziqiang Zhang, Long Zhou, Shujie Liu, 165 | Zhuo Chen, Yanqing Liu, Huaming Wang, 166 | Jinyu Li, Lei He, Sheng Zhao, Furu Wei}, 167 | year = {2023}, 168 | eprint = {2301.02111}, 169 | archivePrefix = {arXiv}, 170 | volume = {abs/2301.02111}, 171 | url = {http://arxiv.org/abs/2301.02111}, 172 | } 173 | ``` 174 | 175 | ## Star History 176 | 177 | [![Star History Chart](https://api.star-history.com/svg?repos=lifeiteng/vall-e&type=Date)](https://star-history.com/#lifeiteng/vall-e&Date) 178 | -------------------------------------------------------------------------------- /README.zh-CN.md: -------------------------------------------------------------------------------- 1 | 非官方 VALL-E([Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers](https://arxiv.org/abs/2301.02111))开源 PyTorch 实现。 2 | 3 | Buy Me A Coffee 4 | 5 | 未同步更新,移步英文版[🇺🇸](./README.md) 6 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel 2 | 3 | RUN apt-get update && \ 4 | apt-get upgrade -y 5 | RUN apt-get install -y vim wget git libsndfile1 espeak-ng 6 | 7 | RUN pip install torchmetrics==0.11.1 librosa==0.8.1 phonemizer==3.2.1 pypinyin==0.48.0 lhotse matplotlib h5py 8 | 9 | RUN pip install https://huggingface.co/csukuangfj/k2/resolve/main/cuda/k2-1.23.4.dev20230224+cuda11.6.torch1.13.1-cp310-cp310-linux_x86_64.whl 10 | 11 | RUN pip install torchaudio==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 12 | 13 | WORKDIR /workspace 14 | RUN git clone https://github.com/k2-fsa/icefall && \ 15 | cd icefall && \ 16 | pip install -r requirements.txt 17 | ENV PYTHONPATH=/workspace/icefall:$PYTHONPATH 18 | 19 | RUN cp /opt/conda/lib/libpython3.10.so.1.0 /usr/lib/x86_64-linux-gnu/ 20 | 21 | WORKDIR /workspace 22 | RUN git clone https://github.com/lifeiteng/vall-e.git && \ 23 | cd vall-e && \ 24 | pip install -e . 
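# Example usage (not documented in this repo; assumes Docker with the NVIDIA container toolkit): build the image with `docker build -t valle -f docker/Dockerfile .` from the repo root, then start a shell with `docker run --gpus all -it --rm valle bash`.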
25 | 26 | WORKDIR /workspace/vall-e 27 | -------------------------------------------------------------------------------- /docs/images/Overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/docs/images/Overview.jpg -------------------------------------------------------------------------------- /docs/images/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/docs/images/train.png -------------------------------------------------------------------------------- /docs/images/vallf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/docs/images/vallf.png -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/docs/requirements.txt -------------------------------------------------------------------------------- /egs/aishell1/README.md: -------------------------------------------------------------------------------- 1 | # aishell1 2 | `150` hours of data are not enough to train a stable model; try increasing the training data to `500`/`1000` hours. 3 | * In the current implementation, finals and tones are combined into single modeling units, which further increases the amount of data required. 4 | 5 | ## Prepare Dataset 6 | ``` 7 | cd egs/aishell1 8 | 9 | # Those stages are very time-consuming 10 | bash prepare.sh --stage -1 --stop-stage 3 11 | 12 | ## train 13 | Cut statistics: 14 | ╒═══════════════════════════╤═══════════╕ 15 | │ Cuts count: │ 120098 │ 16 | ├───────────────────────────┼───────────┤ 17 | │ Total duration (hh:mm:ss) │ 150:51:08 │ 18 | ├───────────────────────────┼───────────┤ 19 | │ mean │ 4.5 │ 20 | ├───────────────────────────┼───────────┤ 21 | │ std │ 1.4 │ 22 | ├───────────────────────────┼───────────┤ 23 | │ min │ 1.2 │ 24 | ├───────────────────────────┼───────────┤ 25 | │ 25% │ 3.5 │ 26 | ├───────────────────────────┼───────────┤ 27 | │ 50% │ 4.3 │ 28 | ├───────────────────────────┼───────────┤ 29 | │ 75% │ 5.3 │ 30 | ├───────────────────────────┼───────────┤ 31 | │ 99% │ 8.5 │ 32 | ├───────────────────────────┼───────────┤ 33 | │ 99.5% │ 9.1 │ 34 | ├───────────────────────────┼───────────┤ 35 | │ 99.9% │ 10.5 │ 36 | ├───────────────────────────┼───────────┤ 37 | │ max │ 14.5 │ 38 | ├───────────────────────────┼───────────┤ 39 | │ Recordings available: │ 120098 │ 40 | ├───────────────────────────┼───────────┤ 41 | │ Features available: │ 120098 │ 42 | ├───────────────────────────┼───────────┤ 43 | │ Supervisions available: │ 120098 │ 44 | ╘═══════════════════════════╧═══════════╛ 45 | SUPERVISION custom fields: 46 | Speech duration statistics: 47 | ╒══════════════════════════════╤═══════════╤══════════════════════╕ 48 | │ Total speech duration │ 150:51:08 │ 100.00% of recording │ 49 | ├──────────────────────────────┼───────────┼──────────────────────┤ 50 | │ Total speaking time duration │ 150:51:08 │ 100.00% of recording │ 51 | ├──────────────────────────────┼───────────┼──────────────────────┤ 52 | │ Total silence duration │ 00:00:00 │ 0.00% of recording │ 53 | ╘══════════════════════════════╧═══════════╧══════════════════════╛ 54 | 55 | 56 | ## dev 57 | Cut statistics: 58 | ╒═══════════════════════════╤══════════╕ 59 | │ Cuts count: │ 400 │
60 | ├───────────────────────────┼──────────┤ 61 | │ Total duration (hh:mm:ss) │ 00:28:37 │ 62 | ├───────────────────────────┼──────────┤ 63 | │ mean │ 4.3 │ 64 | ├───────────────────────────┼──────────┤ 65 | │ std │ 1.1 │ 66 | ├───────────────────────────┼──────────┤ 67 | │ min │ 2.3 │ 68 | ├───────────────────────────┼──────────┤ 69 | │ 25% │ 3.5 │ 70 | ├───────────────────────────┼──────────┤ 71 | │ 50% │ 4.0 │ 72 | ├───────────────────────────┼──────────┤ 73 | │ 75% │ 5.0 │ 74 | ├───────────────────────────┼──────────┤ 75 | │ 99% │ 7.4 │ 76 | ├───────────────────────────┼──────────┤ 77 | │ 99.5% │ 7.5 │ 78 | ├───────────────────────────┼──────────┤ 79 | │ 99.9% │ 8.0 │ 80 | ├───────────────────────────┼──────────┤ 81 | │ max │ 8.0 │ 82 | ├───────────────────────────┼──────────┤ 83 | │ Recordings available: │ 400 │ 84 | ├───────────────────────────┼──────────┤ 85 | │ Features available: │ 400 │ 86 | ├───────────────────────────┼──────────┤ 87 | │ Supervisions available: │ 400 │ 88 | ╘═══════════════════════════╧══════════╛ 89 | SUPERVISION custom fields: 90 | Speech duration statistics: 91 | ╒══════════════════════════════╤══════════╤══════════════════════╕ 92 | │ Total speech duration │ 00:28:37 │ 100.00% of recording │ 93 | ├──────────────────────────────┼──────────┼──────────────────────┤ 94 | │ Total speaking time duration │ 00:28:37 │ 100.00% of recording │ 95 | ├──────────────────────────────┼──────────┼──────────────────────┤ 96 | │ Total silence duration │ 00:00:00 │ 0.00% of recording │ 97 | ╘══════════════════════════════╧══════════╧══════════════════════╛ 98 | 99 | 100 | ## test 101 | Cut statistics: 102 | ╒═══════════════════════════╤══════════╕ 103 | │ Cuts count: │ 7176 │ 104 | ├───────────────────────────┼──────────┤ 105 | │ Total duration (hh:mm:ss) │ 10:01:49 │ 106 | ├───────────────────────────┼──────────┤ 107 | │ mean │ 5.0 │ 108 | ├───────────────────────────┼──────────┤ 109 | │ std │ 1.6 │ 110 | ├───────────────────────────┼──────────┤ 111 | │ min │ 1.9 │ 112 | ├───────────────────────────┼──────────┤ 113 | │ 25% │ 3.8 │ 114 | ├───────────────────────────┼──────────┤ 115 | │ 50% │ 4.7 │ 116 | ├───────────────────────────┼──────────┤ 117 | │ 75% │ 5.9 │ 118 | ├───────────────────────────┼──────────┤ 119 | │ 99% │ 9.9 │ 120 | ├───────────────────────────┼──────────┤ 121 | │ 99.5% │ 10.7 │ 122 | ├───────────────────────────┼──────────┤ 123 | │ 99.9% │ 11.9 │ 124 | ├───────────────────────────┼──────────┤ 125 | │ max │ 14.7 │ 126 | ├───────────────────────────┼──────────┤ 127 | │ Recordings available: │ 7176 │ 128 | ├───────────────────────────┼──────────┤ 129 | │ Features available: │ 7176 │ 130 | ├───────────────────────────┼──────────┤ 131 | │ Supervisions available: │ 7176 │ 132 | ╘═══════════════════════════╧══════════╛ 133 | SUPERVISION custom fields: 134 | Speech duration statistics: 135 | ╒══════════════════════════════╤══════════╤══════════════════════╕ 136 | │ Total speech duration │ 10:01:49 │ 100.00% of recording │ 137 | ├──────────────────────────────┼──────────┼──────────────────────┤ 138 | │ Total speaking time duration │ 10:01:49 │ 100.00% of recording │ 139 | ├──────────────────────────────┼──────────┼──────────────────────┤ 140 | │ Total silence duration │ 00:00:00 │ 0.00% of recording │ 141 | ╘══════════════════════════════╧══════════╧══════════════════════╛ 142 | ``` 143 | 144 | ## Training & Inference 145 | refer to [Training](../../README.md##Training&Inference) 146 | -------------------------------------------------------------------------------- 
/egs/aishell1/bin: -------------------------------------------------------------------------------- 1 | ../../valle/bin -------------------------------------------------------------------------------- /egs/aishell1/demos/0_demo.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/aishell1/demos/0_demo.wav -------------------------------------------------------------------------------- /egs/aishell1/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eou pipefail 4 | 5 | # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 6 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 7 | 8 | nj=16 9 | stage=-1 10 | stop_stage=4 11 | 12 | # We assume dl_dir (download dir) contains the following 13 | # directories and files. If not, they will be downloaded 14 | # by this script automatically. 15 | # 16 | # - $dl_dir/aishell 17 | # You can download aishell from https://www.openslr.org/33/ 18 | # 19 | 20 | dl_dir=$PWD/download 21 | 22 | dataset_parts="-p train -p dev -p test" # debug 23 | 24 | text_extractor="pypinyin_initials_finals" 25 | audio_extractor="Encodec" # or Fbank 26 | audio_feats_dir=data/tokenized 27 | 28 | . shared/parse_options.sh || exit 1 29 | 30 | 31 | # All files generated by this script are saved in "data". 32 | # You can safely remove "data" and rerun this script to regenerate it. 33 | mkdir -p data 34 | 35 | log() { 36 | # This function is from espnet 37 | local fname=${BASH_SOURCE[1]##*/} 38 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 39 | } 40 | 41 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then 42 | log "dl_dir: $dl_dir" 43 | log "Stage 0: Download data" 44 | 45 | # If you have pre-downloaded it to /path/to/aishell, 46 | # you can create a symlink 47 | # 48 | # ln -sfv /path/to/aishell $dl_dir/aishell 49 | # 50 | if [ ! -d $dl_dir/aishell/dev ]; then 51 | lhotse download aishell $dl_dir 52 | fi 53 | fi 54 | 55 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then 56 | log "Stage 1: Prepare aishell manifest" 57 | # We assume that you have downloaded the aishell corpus 58 | # to $dl_dir/aishell 59 | mkdir -p data/manifests 60 | if [ ! -e data/manifests/.aishell.done ]; then 61 | lhotse prepare aishell $dl_dir/aishell data/manifests 62 | touch data/manifests/.aishell.done 63 | fi 64 | fi 65 | 66 | 67 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then 68 | log "Stage 2: Tokenize/Fbank aishell" 69 | mkdir -p ${audio_feats_dir} 70 | if [ ! -e ${audio_feats_dir}/.aishell.tokenize.done ]; then 71 | python3 bin/tokenizer.py --dataset-parts "${dataset_parts}" \ 72 | --text-extractor ${text_extractor} \ 73 | --audio-extractor ${audio_extractor} \ 74 | --batch-duration 400 \ 75 | --prefix "aishell" \ 76 | --src-dir "data/manifests" \ 77 | --output-dir "${audio_feats_dir}" 78 | fi 79 | touch ${audio_feats_dir}/.aishell.tokenize.done 80 | fi 81 | 82 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then 83 | log "Stage 3: Prepare aishell train/dev/test" 84 | if [ ! 
-e ${audio_feats_dir}/.aishell.train.done ]; then 85 | # dev 14326 86 | lhotse subset --first 400 \ 87 | ${audio_feats_dir}/aishell_cuts_dev.jsonl.gz \ 88 | ${audio_feats_dir}/cuts_dev.jsonl.gz 89 | 90 | lhotse subset --last 13926 \ 91 | ${audio_feats_dir}/aishell_cuts_dev.jsonl.gz \ 92 | ${audio_feats_dir}/cuts_dev_others.jsonl.gz 93 | 94 | # train 95 | lhotse combine \ 96 | ${audio_feats_dir}/cuts_dev_others.jsonl.gz \ 97 | ${audio_feats_dir}/aishell_cuts_train.jsonl.gz \ 98 | ${audio_feats_dir}/cuts_train.jsonl.gz 99 | 100 | # test 101 | lhotse copy \ 102 | ${audio_feats_dir}/aishell_cuts_test.jsonl.gz \ 103 | ${audio_feats_dir}/cuts_test.jsonl.gz 104 | 105 | touch ${audio_feats_dir}/.aishell.train.done 106 | fi 107 | fi 108 | 109 | python3 ./bin/display_manifest_statistics.py --manifest-dir ${audio_feats_dir} 110 | -------------------------------------------------------------------------------- /egs/aishell1/prompts/ch_24k.txt: -------------------------------------------------------------------------------- 1 | 甚至 出现 交易 几乎 停滞 的 情况 2 | -------------------------------------------------------------------------------- /egs/aishell1/prompts/ch_24k.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/aishell1/prompts/ch_24k.wav -------------------------------------------------------------------------------- /egs/aishell1/prompts/ch_24k_loudness_normalized20.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/aishell1/prompts/ch_24k_loudness_normalized20.wav -------------------------------------------------------------------------------- /egs/aishell1/shared: -------------------------------------------------------------------------------- 1 | ../libritts/shared/ -------------------------------------------------------------------------------- /egs/libritts/README.md: -------------------------------------------------------------------------------- 1 | # LibriTTS 2 | 3 | ## Install deps 4 | ``` 5 | pip install librosa==0.8.1 6 | ``` 7 | 8 | ## Prepare Dataset 9 | ``` 10 | cd egs/libritts 11 | 12 | # Those stages are very time-consuming 13 | bash prepare.sh --stage -1 --stop-stage 3 14 | ``` 15 | #### data 16 | 17 | ``` 18 | ## train 19 | Cut statistics: 20 | ╒═══════════════════════════╤═══════════╕ 21 | │ Cuts count: │ 354780 │ 22 | ├───────────────────────────┼───────────┤ 23 | │ Total duration (hh:mm:ss) │ 555:09:48 │ 24 | ├───────────────────────────┼───────────┤ 25 | │ mean │ 5.6 │ 26 | ├───────────────────────────┼───────────┤ 27 | │ std │ 4.5 │ 28 | ├───────────────────────────┼───────────┤ 29 | │ min │ 0.1 │ 30 | ├───────────────────────────┼───────────┤ 31 | │ 25% │ 2.3 │ 32 | ├───────────────────────────┼───────────┤ 33 | │ 50% │ 4.3 │ 34 | ├───────────────────────────┼───────────┤ 35 | │ 75% │ 7.6 │ 36 | ├───────────────────────────┼───────────┤ 37 | │ 80% │ 8.7 │ 38 | ├───────────────────────────┼───────────┤ 39 | │ 85% │ 10.0 │ 40 | ├───────────────────────────┼───────────┤ 41 | │ 90% │ 11.8 │ 42 | ├───────────────────────────┼───────────┤ 43 | │ 95% │ 14.8 │ 44 | ├───────────────────────────┼───────────┤ 45 | │ 99% │ 20.9 │ 46 | ├───────────────────────────┼───────────┤ 47 | │ 99.5% │ 23.1 │ 48 | ├───────────────────────────┼───────────┤ 49 | │ 99.9% │ 27.4 │ 50 | ├───────────────────────────┼───────────┤ 51 | │ max │ 43.9 │ 52 
| ├───────────────────────────┼───────────┤ 53 | │ Recordings available: │ 354780 │ 54 | ├───────────────────────────┼───────────┤ 55 | │ Features available: │ 354780 │ 56 | ├───────────────────────────┼───────────┤ 57 | │ Supervisions available: │ 354780 │ 58 | ╘═══════════════════════════╧═══════════╛ 59 | SUPERVISION custom fields: 60 | Speech duration statistics: 61 | ╒══════════════════════════════╤═══════════╤══════════════════════╕ 62 | │ Total speech duration │ 555:09:48 │ 100.00% of recording │ 63 | ├──────────────────────────────┼───────────┼──────────────────────┤ 64 | │ Total speaking time duration │ 555:09:48 │ 100.00% of recording │ 65 | ├──────────────────────────────┼───────────┼──────────────────────┤ 66 | │ Total silence duration │ 00:00:01 │ 0.00% of recording │ 67 | ╘══════════════════════════════╧═══════════╧══════════════════════╛ 68 | ``` 69 | 70 | ## Training & Inference 71 | refer to [Training](../../README.md##Training&Inference) 72 | -------------------------------------------------------------------------------- /egs/libritts/bin: -------------------------------------------------------------------------------- 1 | ../../valle/bin -------------------------------------------------------------------------------- /egs/libritts/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eou pipefail 4 | 5 | # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 6 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 7 | 8 | nj=16 9 | stage=-1 10 | stop_stage=3 11 | 12 | # We assume dl_dir (download dir) contains the following 13 | # directories and files. If not, they will be downloaded 14 | # by this script automatically. 15 | # 16 | # - $dl_dir/LibriTTS 17 | # You can download LibriTTS from https://www.openslr.org/60/ 18 | # After downloading tar.gz files, you should extract them into dl_dir/LibriTTS. 19 | # Ignoring *.tar.gz files, which you can download into anywhere, the structure of $dl_dir should look like below 20 | # 21 | # dl_dir 22 | # ├── dev-clean.tar.gz 23 | # ├── dev-other.tar.gz 24 | # ├── LibriTTS 25 | # │ ├── BOOKS.txt 26 | # │ ├── CHAPTERS.txt 27 | # │ ├── dev-clean 28 | # │ ├── dev-other 29 | # │ ├── eval_sentences10.tsv 30 | # │ ├── LICENSE.txt 31 | # │ ├── NOTE.txt 32 | # │ ├── reader_book.tsv 33 | # │ ├── README_librispeech.txt 34 | # │ ├── README_libritts.txt 35 | # │ ├── speakers.tsv 36 | # │ ├── SPEAKERS.txt 37 | # │ ├── test-clean 38 | # │ ├── test-other 39 | # │ ├── train-clean-100 40 | # │ ├── train-clean-360 41 | # │ └── train-other-500 42 | # ├── test-clean.tar.gz 43 | # ├── test-other.tar.gz 44 | # ├── train-clean-100.tar.gz 45 | # ├── train-clean-360.tar.gz 46 | # └── train-other-500.tar.gz 47 | 48 | echo "We will download the LibriTTS dataset by default. If the downloading fails or you want to download the dataset yourself, see the comments in this script for steps." 49 | 50 | dl_dir=$PWD/download 51 | 52 | # dataset_parts="-p dev-clean -p test-clean" # debug 53 | dataset_parts="--dataset-parts all" # all 54 | 55 | audio_extractor="Encodec" # or Fbank 56 | audio_feats_dir=data/tokenized 57 | 58 | . shared/parse_options.sh || exit 1 59 | 60 | 61 | # All files generated by this script are saved in "data". 62 | # You can safely remove "data" and rerun this script to regenerate it. 
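# Tip: the variables defined above can be overridden from the command line (they are parsed by shared/parse_options.sh); e.g. a smaller debug run that mirrors the commented-out dataset_parts value above: bash prepare.sh --stage -1 --stop-stage 3 --dataset-parts "-p dev-clean -p test-clean"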
63 | mkdir -p data 64 | 65 | log() { 66 | # This function is from espnet 67 | local fname=${BASH_SOURCE[1]##*/} 68 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 69 | } 70 | 71 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then 72 | log "dl_dir: $dl_dir" 73 | log "Stage 0: Download data" 74 | 75 | # If you have pre-downloaded it to /path/to/LibriTTS, 76 | # you can create a symlink 77 | # 78 | # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS 79 | # 80 | if [ ! -d $dl_dir/LibriTTS/dev-other ]; then 81 | # lhotse download libritts $dl_dir 82 | lhotse download libritts ${dataset_parts} $dl_dir 83 | fi 84 | fi 85 | 86 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then 87 | log "Stage 1: Prepare LibriTTS manifest" 88 | # We assume that you have downloaded the LibriTTS corpus 89 | # to $dl_dir/LibriTTS 90 | mkdir -p data/manifests 91 | if [ ! -e data/manifests/.libritts.done ]; then 92 | lhotse prepare libritts ${dataset_parts} -j $nj $dl_dir/LibriTTS data/manifests 93 | touch data/manifests/.libritts.done 94 | fi 95 | fi 96 | 97 | 98 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then 99 | log "Stage 2: Tokenize/Fbank LibriTTS" 100 | mkdir -p ${audio_feats_dir} 101 | if [ ! -e ${audio_feats_dir}/.libritts.tokenize.done ]; then 102 | python3 bin/tokenizer.py --dataset-parts "${dataset_parts}" \ 103 | --audio-extractor ${audio_extractor} \ 104 | --batch-duration 400 \ 105 | --src-dir "data/manifests" \ 106 | --output-dir "${audio_feats_dir}" 107 | fi 108 | touch ${audio_feats_dir}/.libritts.tokenize.done 109 | fi 110 | 111 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then 112 | log "Stage 3: Prepare LibriTTS train/dev/test" 113 | if [ ! -e ${audio_feats_dir}/.libritts.train.done ]; then 114 | if [ "${dataset_parts}" == "--dataset-parts all" ];then 115 | # train 116 | lhotse combine \ 117 | ${audio_feats_dir}/libritts_cuts_train-clean-100.jsonl.gz \ 118 | ${audio_feats_dir}/libritts_cuts_train-clean-360.jsonl.gz \ 119 | ${audio_feats_dir}/libritts_cuts_train-other-500.jsonl.gz \ 120 | ${audio_feats_dir}/cuts_train.jsonl.gz 121 | 122 | # dev 123 | lhotse copy \ 124 | ${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \ 125 | ${audio_feats_dir}/cuts_dev.jsonl.gz 126 | else # debug 127 | # train 128 | lhotse copy \ 129 | ${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \ 130 | ${audio_feats_dir}/cuts_train.jsonl.gz 131 | # dev 132 | lhotse subset --first 400 \ 133 | ${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \ 134 | ${audio_feats_dir}/cuts_dev.jsonl.gz 135 | fi 136 | 137 | # test 138 | lhotse copy \ 139 | ${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \ 140 | ${audio_feats_dir}/cuts_test.jsonl.gz 141 | 142 | touch ${audio_feats_dir}/.libritts.train.done 143 | fi 144 | fi 145 | 146 | python3 ./bin/display_manifest_statistics.py --manifest-dir ${audio_feats_dir} 147 | -------------------------------------------------------------------------------- /egs/libritts/prompts/8455_210777_000067_000000.txt: -------------------------------------------------------------------------------- 1 | This I read with great attention, while they sat silent. 
2 | -------------------------------------------------------------------------------- /egs/libritts/prompts/8455_210777_000067_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/libritts/prompts/8455_210777_000067_000000.wav -------------------------------------------------------------------------------- /egs/libritts/prompts/8463_294825_000043_000000.txt: -------------------------------------------------------------------------------- 1 | KNOT one point one five miles per hour 2 | -------------------------------------------------------------------------------- /egs/libritts/prompts/8463_294825_000043_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/libritts/prompts/8463_294825_000043_000000.wav -------------------------------------------------------------------------------- /egs/libritts/shared/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. 
--foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 98 | -------------------------------------------------------------------------------- /egs/ljspeech/README.md: -------------------------------------------------------------------------------- 1 | # LJSpeech 2 | 3 | ## Install deps 4 | ``` 5 | pip install librosa==0.8.1 6 | 7 | # lhotse update LJSpeech 8 | # https://github.com/lhotse-speech/lhotse/pull/988 9 | ``` 10 | 11 | ## Prepare Dataset 12 | ``` 13 | cd egs/ljspeech 14 | 15 | bash prepare.sh --stage -1 --stop-stage 3 \ 16 | --audio_extractor "Encodec" \ 17 | --audio_feats_dir data/tokenized 18 | ``` 19 | 20 | ## Training & Inference 21 | **LJSpeech is used to debug, Please try LibriTTS** 22 | 23 | refer to [LibriTTS Training](../../README.md##Training&Inference) 24 | -------------------------------------------------------------------------------- /egs/ljspeech/bin: -------------------------------------------------------------------------------- 1 | ../../valle/bin -------------------------------------------------------------------------------- /egs/ljspeech/demos/0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/demos/0.wav -------------------------------------------------------------------------------- /egs/ljspeech/demos/1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/demos/1.wav -------------------------------------------------------------------------------- /egs/ljspeech/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eou pipefail 4 | 5 | # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 6 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 7 | 8 | nj=16 9 | stage=-1 10 | stop_stage=3 
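# Typical invocation, as shown in egs/ljspeech/README.md (options in this script are parsed by shared/parse_options.sh): bash prepare.sh --stage -1 --stop-stage 3 --audio_extractor "Encodec" --audio_feats_dir data/tokenized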
11 | 12 | # We assume dl_dir (download dir) contains the following 13 | # directories and files. If not, they will be downloaded 14 | # by this script automatically. 15 | # 16 | # - $dl_dir/LJSpeech-1.1 17 | 18 | dl_dir=$PWD/download 19 | 20 | audio_extractor="Encodec" # or Fbank 21 | audio_feats_dir=data/tokenized 22 | 23 | 24 | . shared/parse_options.sh || exit 1 25 | 26 | 27 | # All files generated by this script are saved in "data". 28 | # You can safely remove "data" and rerun this script to regenerate it. 29 | mkdir -p data 30 | 31 | log() { 32 | # This function is from espnet 33 | local fname=${BASH_SOURCE[1]##*/} 34 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 35 | } 36 | 37 | log "dl_dir: $dl_dir" 38 | 39 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then 40 | log "Stage 0: Download data" 41 | 42 | # If you have pre-downloaded it to /path/to/LJSpeech, 43 | # you can create a symlink 44 | # 45 | # ln -sfv /path/to/LJSpeech $dl_dir/LJSpeech 46 | # 47 | if [ ! -d $dl_dir/LJSpeech-1.1 ];then 48 | lhotse download ljspeech $dl_dir 49 | fi 50 | fi 51 | 52 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then 53 | log "Stage 1: Prepare LJSpeech manifest" 54 | # We assume that you have downloaded the LJSpeech corpus 55 | # to $dl_dir/LJSpeech 56 | mkdir -p data/manifests 57 | if [ ! -e data/manifests/.ljspeech.done ]; then 58 | lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests 59 | touch data/manifests/.ljspeech.done 60 | fi 61 | fi 62 | 63 | 64 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then 65 | log "Stage 2: Split LJSpeech" 66 | 67 | # 13100 = dev/test/train = 200/200/12500 68 | if [ ! -e data/manifests/ljspeech_recordings_test.jsonl.gz ]; then 69 | for manifest in "recordings" "supervisions";do 70 | lhotse subset --last 600 data/manifests/ljspeech_${manifest}_all.jsonl.gz \ 71 | data/manifests/ljspeech_${manifest}_dev_test.jsonl.gz || exit 1 72 | lhotse subset --last 400 data/manifests/ljspeech_${manifest}_dev_test.jsonl.gz \ 73 | data/manifests/ljspeech_${manifest}_test.jsonl.gz || exit 1 74 | lhotse subset --first 200 data/manifests/ljspeech_${manifest}_dev_test.jsonl.gz \ 75 | data/manifests/ljspeech_${manifest}_dev.jsonl.gz || exit 1 76 | 77 | lhotse subset --first 12500 data/manifests/ljspeech_${manifest}_all.jsonl.gz \ 78 | data/manifests/ljspeech_${manifest}_train.jsonl.gz || exit 1 79 | 80 | rm -f data/manifests/ljspeech_${manifest}_dev_test.jsonl.gz 81 | done 82 | fi 83 | fi 84 | 85 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then 86 | log "Stage 3: ${audio_extractor} LJSpeech" 87 | 88 | mkdir -p ${audio_feats_dir} 89 | if [ ! 
-e ${audio_feats_dir}/.ljspeech.done ]; then 90 | python3 bin/tokenizer.py --dataset-parts "train test dev" --prefix "ljspeech" \ 91 | --audio-extractor ${audio_extractor} \ 92 | --batch-duration 400 \ 93 | --src-dir "data/manifests" \ 94 | --output-dir "${audio_feats_dir}" 95 | fi 96 | touch ${audio_feats_dir}/.ljspeech.done 97 | 98 | cd ${audio_feats_dir} 99 | ln -sf ljspeech_cuts_train.jsonl.gz cuts_train.jsonl.gz 100 | ln -sf ljspeech_cuts_dev.jsonl.gz cuts_dev.jsonl.gz 101 | ln -sf ljspeech_cuts_test.jsonl.gz cuts_test.jsonl.gz 102 | cd - 103 | fi 104 | 105 | python3 ./bin/display_manifest_statistics.py --manifest-dir ${audio_feats_dir} 106 | -------------------------------------------------------------------------------- /egs/ljspeech/prompts/LJ049-0108_24K.txt: -------------------------------------------------------------------------------- 1 | In addition, the restriction would probably eliminate a need for the requirement which has been urged as necessary for the exercise of Federal power, 2 | -------------------------------------------------------------------------------- /egs/ljspeech/prompts/LJ049-0108_24K.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/prompts/LJ049-0108_24K.wav -------------------------------------------------------------------------------- /egs/ljspeech/prompts/LJ049-0110_24K.txt: -------------------------------------------------------------------------------- 1 | The governmental consequences of assassination of one of the specified officials give the United States ample power to act for its own protection. 2 | -------------------------------------------------------------------------------- /egs/ljspeech/prompts/LJ049-0110_24K.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/prompts/LJ049-0110_24K.wav -------------------------------------------------------------------------------- /egs/ljspeech/prompts/LJ049-0124.txt: -------------------------------------------------------------------------------- 1 | In addition, the proposed legislation will insure. 
2 | -------------------------------------------------------------------------------- /egs/ljspeech/prompts/LJ049-0124_24K.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/prompts/LJ049-0124_24K.wav -------------------------------------------------------------------------------- /egs/ljspeech/prompts/LJ049-0185.txt: -------------------------------------------------------------------------------- 1 | During the period the Commission was giving thought to this situation, 2 | -------------------------------------------------------------------------------- /egs/ljspeech/prompts/LJ049-0185_24K.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/prompts/LJ049-0185_24K.wav -------------------------------------------------------------------------------- /egs/ljspeech/shared: -------------------------------------------------------------------------------- 1 | ../libritts/shared -------------------------------------------------------------------------------- /examples: -------------------------------------------------------------------------------- 1 | egs/ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import sys 4 | from pathlib import Path 5 | from subprocess import DEVNULL, PIPE, run 6 | 7 | from setuptools import find_packages, setup 8 | 9 | project_root = Path(__file__).parent 10 | 11 | # modified from https://github.com/lhotse-speech/lhotse/blob/master/setup.py 12 | 13 | # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # 14 | # NOTE: REMEMBER TO UPDATE THE FALLBACK VERSION IN valle/__init__.py WHEN RELEASING # 15 | # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # 16 | MAJOR_VERSION = 1 17 | MINOR_VERSION = 0 18 | PATCH_VERSION = 0 19 | IS_DEV_VERSION = True # False = public release, True = otherwise 20 | 21 | 22 | if sys.version_info < (3,): 23 | # fmt: off 24 | print( 25 | "Python 2 has reached end-of-life and is no longer supported by valle." 26 | ) 27 | # fmt: on 28 | sys.exit(-1) 29 | 30 | if sys.version_info < (3, 7): 31 | print( 32 | "Python 3.6 has reached end-of-life on December 31st, 2021 " 33 | "and is no longer supported by valle." 34 | ) 35 | sys.exit(-1) 36 | 37 | 38 | def discover_valle_version() -> str: 39 | """ 40 | Scans Valle source code to determine the current version. 41 | When development version is detected, it queries git for the commit hash 42 | to append it as a local version identifier. 43 | 44 | Ideally this function would have been imported from valle.version and 45 | re-used when valle is imported to set the version, but it introduces 46 | a circular dependency. To avoid this, we write the determined version 47 | into project_root / 'valle' / 'version.py' during setup and read it 48 | from there later. If it's not detected, the version will be 0.0.0.dev. 49 | """ 50 | 51 | version = f"{MAJOR_VERSION}.{MINOR_VERSION}.{PATCH_VERSION}" 52 | if not IS_DEV_VERSION: 53 | # This is a PyPI public release -- return a clean version string. 
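        # (Illustration: with the constants above, a public release resolves to
        # "1.0.0", while the dev path below yields a local version such as
        # "1.0.0.dev+git.abc1234.clean"; the short hash is a made-up example.)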
54 | return version 55 | 56 | version = version + ".dev" 57 | 58 | # This is not a PyPI release -- try to read the git commit 59 | try: 60 | git_commit = ( 61 | run( 62 | ["git", "rev-parse", "--short", "HEAD"], 63 | check=True, 64 | stdout=PIPE, 65 | stderr=DEVNULL, 66 | ) 67 | .stdout.decode() 68 | .rstrip("\n") 69 | .strip() 70 | ) 71 | dirty_commit = ( 72 | len( 73 | run( 74 | ["git", "diff", "--shortstat"], 75 | check=True, 76 | stdout=PIPE, 77 | stderr=DEVNULL, 78 | ) 79 | .stdout.decode() 80 | .rstrip("\n") 81 | .strip() 82 | ) 83 | > 0 84 | ) 85 | git_commit = ( 86 | git_commit + ".dirty" if dirty_commit else git_commit + ".clean" 87 | ) 88 | source_version = f"+git.{git_commit}" 89 | except Exception: 90 | source_version = ".unknownsource" 91 | # See the format: 92 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#local-version-identifiers 93 | version = version + source_version 94 | 95 | return version 96 | 97 | 98 | def mark_valle_version(version: str) -> None: 99 | (project_root / "valle" / "version.py").write_text( 100 | f'__version__ = "{version}"' 101 | ) 102 | 103 | 104 | VALLE_VERSION = discover_valle_version() 105 | mark_valle_version(VALLE_VERSION) 106 | 107 | 108 | install_requires = [ 109 | "encodec", 110 | "phonemizer", 111 | ] 112 | 113 | try: 114 | # If the user already installed PyTorch, make sure he has torchaudio too. 115 | # Otherwise, we'll just install the latest versions from PyPI for the user. 116 | import torch 117 | 118 | try: 119 | import torchaudio 120 | except ImportError: 121 | raise ValueError( 122 | "We detected that you have already installed PyTorch, but haven't installed torchaudio. " 123 | "Unfortunately we can't detect the compatible torchaudio version for you; " 124 | "you will have to install it manually. " 125 | "For instructions, please refer either to https://pytorch.org/get-started/locally/ " 126 | "or https://github.com/pytorch/audio#dependencies" 127 | ) 128 | except ImportError: 129 | install_requires.extend(["torch", "torchaudio"]) 130 | 131 | docs_require = ( 132 | (project_root / "docs" / "requirements.txt").read_text().splitlines() 133 | ) 134 | tests_require = [ 135 | # "pytest==7.1.3", 136 | # "pytest-forked==1.4.0", 137 | # "pytest-xdist==2.5.0", 138 | # "pytest-cov==4.0.0", 139 | ] 140 | workflow_requires = [""] 141 | dev_requires = sorted( 142 | docs_require 143 | + tests_require 144 | + workflow_requires 145 | + ["jupyterlab", "matplotlib"] 146 | ) 147 | all_requires = sorted(dev_requires) 148 | 149 | if os.environ.get("READTHEDOCS", False): 150 | # When building documentation, omit torchaudio installation and mock it instead. 151 | # This works around the inability to install libsoundfile1 in read-the-docs env, 152 | # which caused the documentation builds to silently crash. 
153 | install_requires = [ 154 | req 155 | for req in install_requires 156 | if not any(req.startswith(dep) for dep in ["torchaudio", "SoundFile"]) 157 | ] 158 | 159 | setup( 160 | name="valle", 161 | version=VALLE_VERSION, 162 | python_requires=">=3.7.0", 163 | description="Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers", 164 | author="The Valle Development Team", 165 | author_email="lifeiteng0422@163.com", 166 | long_description=(project_root / "README.md").read_text(encoding="utf-8"), 167 | long_description_content_type="text/markdown", 168 | license="Apache-2.0 License", 169 | packages=find_packages(exclude=["test", "test.*"]), 170 | include_package_data=True, 171 | entry_points={}, 172 | install_requires=install_requires, 173 | extras_require={ 174 | "docs": docs_require, 175 | "tests": tests_require, 176 | "dev": dev_requires, 177 | "all": all_requires, 178 | }, 179 | classifiers=[ 180 | "Development Status :: 1 - Beta", 181 | "Programming Language :: Python :: 3.7", 182 | "Programming Language :: Python :: 3.8", 183 | "Programming Language :: Python :: 3.9", 184 | "Programming Language :: Python :: 3.10", 185 | "Intended Audience :: Science/Research", 186 | "Operating System :: POSIX :: Linux", 187 | "Operating System :: MacOS :: MacOS X", 188 | "License :: OSI Approved :: Apache Software License", 189 | "Topic :: Multimedia :: Sound/Audio :: Speech", 190 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 191 | "Topic :: Software Development :: Libraries :: Python Modules", 192 | "Typing :: Typed", 193 | ], 194 | ) 195 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 4 | 5 | python3 valle/tests/valle_test.py 6 | python3 valle/tests/scaling_test.py 7 | python3 valle/tests/data/tokenizer_test.py 8 | -------------------------------------------------------------------------------- /valle/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data, models, modules, utils 2 | -------------------------------------------------------------------------------- /valle/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/bin/__init__.py -------------------------------------------------------------------------------- /valle/bin/display_manifest_statistics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) 3 | # Copyright 2023 (authors: Feiteng Li) 4 | # 5 | # See ../../../../LICENSE for clarification regarding multiple authors 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | """ 20 | This file displays duration statistics of utterances in the manifests. 21 | You can use the displayed value to choose minimum/maximum duration 22 | to remove short and long utterances during the training. 23 | """ 24 | 25 | import argparse 26 | from pathlib import Path 27 | 28 | from lhotse import load_manifest_lazy 29 | 30 | 31 | def get_args(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument( 34 | "--manifest-dir", 35 | type=Path, 36 | default=Path("data/tokenized"), 37 | help="Path to the tokenized manifests.", 38 | ) 39 | return parser.parse_args() 40 | 41 | 42 | def main(): 43 | args = get_args() 44 | manifest_dir = args.manifest_dir or Path("data/tokenized") 45 | for part in ["train", "dev", "test"]: 46 | print(f"## {part}") 47 | cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz") 48 | cuts.describe() 49 | print("\n") 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /valle/bin/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Phonemize Text and EnCodec Audio. 17 | 18 | Usage example: 19 | python3 bin/infer.py \ 20 | --decoder-dim 128 --nhead 4 --num-decoder-layers 4 --model-name valle \ 21 | --text-prompts "Go to her." 
\ 22 | --audio-prompts ./prompts/61_70970_000007_000001.wav \ 23 | --output-dir infer/demo_valle_epoch20 \ 24 | --checkpoint exp/valle_nano_v2/epoch-20.pt 25 | 26 | """ 27 | import argparse 28 | import logging 29 | import os 30 | from pathlib import Path 31 | 32 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" 33 | 34 | import torch 35 | import torchaudio 36 | from icefall.utils import AttributeDict, str2bool 37 | 38 | from valle.data import ( 39 | AudioTokenizer, 40 | TextTokenizer, 41 | tokenize_audio, 42 | tokenize_text, 43 | ) 44 | from valle.data.collation import get_text_token_collater 45 | from valle.models import get_model 46 | 47 | 48 | def get_args(): 49 | parser = argparse.ArgumentParser() 50 | 51 | parser.add_argument( 52 | "--text-prompts", 53 | type=str, 54 | default="", 55 | help="Text prompts which are separated by |.", 56 | ) 57 | 58 | parser.add_argument( 59 | "--audio-prompts", 60 | type=str, 61 | default="", 62 | help="Audio prompts which are separated by | and should be aligned with --text-prompts.", 63 | ) 64 | 65 | parser.add_argument( 66 | "--text", 67 | type=str, 68 | default="To get up and running quickly just follow the steps below.", 69 | help="Text to be synthesized.", 70 | ) 71 | 72 | # model 73 | # add_model_arguments(parser) 74 | # parser.add_argument( 75 | # "--text-tokens", 76 | # type=str, 77 | # default="data/tokenized/unique_text_tokens.k2symbols", 78 | # help="Path to the unique text tokens file.", 79 | # ) 80 | 81 | parser.add_argument( 82 | "--text-extractor", 83 | type=str, 84 | default="espeak", 85 | help="espeak or pypinyin or pypinyin_initials_finals", 86 | ) 87 | 88 | parser.add_argument( 89 | "--checkpoint", 90 | type=str, 91 | default="exp/vallf_nano_full/checkpoint-100000.pt", 92 | help="Path to the saved checkpoint.", 93 | ) 94 | 95 | parser.add_argument( 96 | "--output-dir", 97 | type=Path, 98 | default=Path("infer/demo"), 99 | help="Path to the tokenized files.", 100 | ) 101 | 102 | parser.add_argument( 103 | "--top-k", 104 | type=int, 105 | default=-100, 106 | help="Whether AR Decoder do top_k(if > 0) sampling.", 107 | ) 108 | 109 | parser.add_argument( 110 | "--temperature", 111 | type=float, 112 | default=1.0, 113 | help="The temperature of AR Decoder top_k sampling.", 114 | ) 115 | 116 | parser.add_argument( 117 | "--continual", 118 | type=str2bool, 119 | default=False, 120 | help="Do continual task.", 121 | ) 122 | 123 | return parser.parse_args() 124 | 125 | 126 | def load_model(checkpoint, device): 127 | if not checkpoint: 128 | return None 129 | 130 | checkpoint = torch.load(checkpoint, map_location=device) 131 | 132 | args = AttributeDict(checkpoint) 133 | model = get_model(args) 134 | 135 | missing_keys, unexpected_keys = model.load_state_dict( 136 | checkpoint["model"], strict=True 137 | ) 138 | assert not missing_keys 139 | model.to(device) 140 | model.eval() 141 | 142 | text_tokens = args.text_tokens 143 | 144 | return model, text_tokens 145 | 146 | 147 | @torch.no_grad() 148 | def main(): 149 | args = get_args() 150 | text_tokenizer = TextTokenizer(backend=args.text_extractor) 151 | 152 | device = torch.device("cpu") 153 | if torch.cuda.is_available(): 154 | device = torch.device("cuda", 0) 155 | model, text_tokens = load_model(args.checkpoint, device) 156 | text_collater = get_text_token_collater(text_tokens) 157 | 158 | audio_tokenizer = AudioTokenizer() 159 | 160 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 161 | 162 | text_prompts = " ".join(args.text_prompts.split("|")) 163 | 164 | 
audio_prompts = [] 165 | if args.audio_prompts: 166 | for n, audio_file in enumerate(args.audio_prompts.split("|")): 167 | encoded_frames = tokenize_audio(audio_tokenizer, audio_file) 168 | if False: 169 | samples = audio_tokenizer.decode(encoded_frames) 170 | torchaudio.save( 171 | f"{args.output_dir}/p{n}.wav", samples[0], 24000 172 | ) 173 | 174 | audio_prompts.append(encoded_frames[0][0]) 175 | 176 | assert len(args.text_prompts.split("|")) == len(audio_prompts) 177 | audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1) 178 | audio_prompts = audio_prompts.to(device) 179 | 180 | if os.path.isfile(args.text): # for demos 181 | # https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py 182 | with open(args.text) as f: 183 | for line in f: 184 | fields = line.strip().split("\t") 185 | assert len(fields) == 4 186 | prompt_text, prompt_audio, text, audio_path = fields 187 | logging.info(f"synthesize text: {text}") 188 | text_tokens, text_tokens_lens = text_collater( 189 | [ 190 | tokenize_text( 191 | text_tokenizer, text=f"{prompt_text} {text}".strip() 192 | ) 193 | ] 194 | ) 195 | _, enroll_x_lens = text_collater( 196 | [ 197 | tokenize_text( 198 | text_tokenizer, text=f"{prompt_text}".strip() 199 | ) 200 | ] 201 | ) 202 | 203 | audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio) 204 | audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device) 205 | 206 | # synthesis 207 | encoded_frames = model.inference( 208 | text_tokens.to(device), 209 | text_tokens_lens.to(device), 210 | audio_prompts, 211 | enroll_x_lens=enroll_x_lens, 212 | top_k=args.top_k, 213 | temperature=args.temperature, 214 | ) 215 | 216 | samples = audio_tokenizer.decode( 217 | [(encoded_frames.transpose(2, 1), None)] 218 | ) 219 | # store 220 | torchaudio.save(audio_path, samples[0].cpu(), 24000) 221 | return 222 | 223 | for n, text in enumerate(args.text.split("|")): 224 | logging.info(f"synthesize text: {text}") 225 | text_tokens, text_tokens_lens = text_collater( 226 | [ 227 | tokenize_text( 228 | text_tokenizer, text=f"{text_prompts} {text}".strip() 229 | ) 230 | ] 231 | ) 232 | 233 | # synthesis 234 | if args.continual: 235 | assert text == "" 236 | encoded_frames = model.continual( 237 | text_tokens.to(device), 238 | text_tokens_lens.to(device), 239 | audio_prompts, 240 | ) 241 | else: 242 | enroll_x_lens = None 243 | if text_prompts: 244 | _, enroll_x_lens = text_collater( 245 | [ 246 | tokenize_text( 247 | text_tokenizer, text=f"{text_prompts}".strip() 248 | ) 249 | ] 250 | ) 251 | encoded_frames = model.inference( 252 | text_tokens.to(device), 253 | text_tokens_lens.to(device), 254 | audio_prompts, 255 | enroll_x_lens=enroll_x_lens, 256 | top_k=args.top_k, 257 | temperature=args.temperature, 258 | ) 259 | 260 | if audio_prompts != []: 261 | samples = audio_tokenizer.decode( 262 | [(encoded_frames.transpose(2, 1), None)] 263 | ) 264 | # store 265 | torchaudio.save( 266 | f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000 267 | ) 268 | else: # Transformer 269 | pass 270 | 271 | 272 | torch.set_num_threads(1) 273 | torch.set_num_interop_threads(1) 274 | torch._C._jit_set_profiling_executor(False) 275 | torch._C._jit_set_profiling_mode(False) 276 | torch._C._set_graph_executor_optimize(False) 277 | if __name__ == "__main__": 278 | formatter = ( 279 | "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" 280 | ) 281 | logging.basicConfig(format=formatter, level=logging.INFO) 282 | main() 283 | 
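The inference script above reduces to a short sequence of calls: collate the phonemized prompt-plus-target text, EnCodec-tokenize the enrolment audio, call model.inference, then decode the predicted codes back to a waveform. Below is a minimal sketch of that flow; the checkpoint path, prompt wav, texts, and output path are illustrative placeholders, and the checkpoint is assumed to store the text-token table path just as load_model() above expects.

import torch
import torchaudio
from pathlib import Path

from icefall.utils import AttributeDict

from valle.data import AudioTokenizer, TextTokenizer, tokenize_audio, tokenize_text
from valle.data.collation import get_text_token_collater
from valle.models import get_model

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Restore the model the same way load_model() above does.
ckpt = torch.load("exp/valle/epoch-20.pt", map_location=device)  # placeholder path
args = AttributeDict(ckpt)
model = get_model(args)
model.load_state_dict(ckpt["model"], strict=True)
model.to(device).eval()

text_tokenizer = TextTokenizer(backend="espeak")
text_collater = get_text_token_collater(args.text_tokens)
audio_tokenizer = AudioTokenizer()

prompt_text = "Go to her."  # transcript of the prompt wav (placeholder)
target_text = "To get up and running quickly just follow the steps below."

text_tokens, text_tokens_lens = text_collater(
    [tokenize_text(text_tokenizer, text=f"{prompt_text} {target_text}")]
)
_, enroll_x_lens = text_collater([tokenize_text(text_tokenizer, text=prompt_text)])

# EnCodec codes of the enrolment audio: (B, T, n_q) after the transpose.
codes = tokenize_audio(audio_tokenizer, "prompts/prompt.wav")[0][0]
codes = codes.transpose(2, 1).to(device)

with torch.no_grad():
    predicted = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        codes,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1.0,
    )

samples = audio_tokenizer.decode([(predicted.transpose(2, 1), None)])
Path("infer/demo").mkdir(parents=True, exist_ok=True)
torchaudio.save("infer/demo/0.wav", samples[0].cpu(), 24000)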
-------------------------------------------------------------------------------- /valle/bin/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Phonemize Text and EnCodec Audio. 17 | 18 | Usage example: 19 | python3 bin/tokenizer.py \ 20 | --src_dir ./data/manifests --output_dir ./data/tokenized 21 | 22 | """ 23 | import argparse 24 | import logging 25 | import os 26 | from pathlib import Path 27 | 28 | import torch 29 | import torch.multiprocessing 30 | from icefall.utils import get_executor 31 | from lhotse import CutSet, NumpyHdf5Writer 32 | from lhotse.recipes.utils import read_manifests_if_cached 33 | from tqdm.auto import tqdm 34 | 35 | from valle.data import ( 36 | AudioTokenConfig, 37 | AudioTokenExtractor, 38 | TextTokenizer, 39 | tokenize_text, 40 | ) 41 | from valle.data.fbank import get_fbank_extractor 42 | from valle.utils import SymbolTable 43 | 44 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" 45 | 46 | 47 | # Torch's multithreaded behavior needs to be disabled or 48 | # it wastes a lot of CPU and slow things down. 49 | # Do this outside of main() in case it needs to take effect 50 | # even when we are not invoking the main (e.g. when spawning subprocesses). 51 | torch.set_num_threads(1) 52 | torch.set_num_interop_threads(1) 53 | torch.multiprocessing.set_sharing_strategy("file_system") 54 | 55 | 56 | def get_args(): 57 | parser = argparse.ArgumentParser() 58 | 59 | parser.add_argument( 60 | "--src-dir", 61 | type=Path, 62 | default=Path("data/manifests"), 63 | help="Path to the manifest files", 64 | ) 65 | parser.add_argument( 66 | "--output-dir", 67 | type=Path, 68 | default=Path("data/tokenized"), 69 | help="Path to the tokenized files", 70 | ) 71 | parser.add_argument( 72 | "--text-extractor", 73 | type=str, 74 | default="espeak", 75 | help="espeak or pypinyin or pypinyin_initials_finals", 76 | ) 77 | parser.add_argument( 78 | "--audio-extractor", 79 | type=str, 80 | default="Encodec", 81 | help="Encodec or Fbank", 82 | ) 83 | parser.add_argument( 84 | "--dataset-parts", 85 | type=str, 86 | default="dev-clean test-clean", 87 | help="Space separated dataset parts", 88 | ) 89 | parser.add_argument( 90 | "--prefix", 91 | type=str, 92 | default="libritts", 93 | help="prefix of the manifest file", 94 | ) 95 | parser.add_argument( 96 | "--suffix", 97 | type=str, 98 | default="jsonl.gz", 99 | help="suffix of the manifest file", 100 | ) 101 | parser.add_argument( 102 | "--batch-duration", 103 | type=float, 104 | default=400.0, 105 | help="The maximum number of audio seconds in a batch." 
106 | "Determines batch size dynamically.", 107 | ) 108 | 109 | return parser.parse_args() 110 | 111 | 112 | def main(): 113 | args = get_args() 114 | 115 | dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip() 116 | if dataset_parts == "all": # LibriTTS 117 | dataset_parts = [ 118 | "dev-clean", 119 | "dev-other", 120 | "test-clean", 121 | "test-other", 122 | "train-clean-100", 123 | "train-clean-360", 124 | "train-other-500", 125 | ] 126 | else: 127 | dataset_parts = dataset_parts.replace("-p", "").strip().split(" ") 128 | 129 | assert len(dataset_parts) >= 1 130 | 131 | manifests = read_manifests_if_cached( 132 | dataset_parts=dataset_parts, 133 | output_dir=args.src_dir, 134 | prefix=args.prefix, 135 | suffix=args.suffix, 136 | types=["recordings", "supervisions", "cuts"], 137 | ) 138 | 139 | text_tokenizer = None 140 | if args.text_extractor: 141 | text_tokenizer = TextTokenizer(backend=args.text_extractor) 142 | 143 | audio_extractor = None 144 | if args.audio_extractor: 145 | if args.audio_extractor == "Encodec": 146 | audio_extractor = AudioTokenExtractor(AudioTokenConfig()) 147 | else: 148 | assert args.audio_extractor == "Fbank" 149 | audio_extractor = get_fbank_extractor() 150 | 151 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 152 | unique_symbols = set() 153 | num_jobs = min(32, os.cpu_count()) 154 | logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}") 155 | 156 | prefix = args.prefix 157 | if prefix and not prefix.endswith("_"): 158 | prefix = f"{prefix}_" 159 | with get_executor() as ex: 160 | for partition, m in manifests.items(): 161 | logging.info( 162 | f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" 163 | ) 164 | try: 165 | cut_set = CutSet.from_manifests( 166 | recordings=m["recordings"], 167 | supervisions=m["supervisions"], 168 | ) 169 | except Exception: 170 | cut_set = m["cuts"] 171 | 172 | # AudioTokenizer 173 | if args.audio_extractor: 174 | if args.audio_extractor == "Encodec": 175 | storage_path = ( 176 | f"{args.output_dir}/{args.prefix}_encodec_{partition}" 177 | ) 178 | else: 179 | storage_path = ( 180 | f"{args.output_dir}/{args.prefix}_fbank_{partition}" 181 | ) 182 | 183 | if args.prefix.lower() in ["ljspeech", "aishell", "baker"]: 184 | cut_set = cut_set.resample(24000) 185 | # https://github.com/lifeiteng/vall-e/issues/90 186 | # if args.prefix == "aishell": 187 | # # NOTE: the loudness of aishell audio files is around -33 188 | # # The best way is datamodule --on-the-fly-feats --enable-audio-aug 189 | # cut_set = cut_set.normalize_loudness( 190 | # target=-20.0, affix_id=True 191 | # ) 192 | 193 | with torch.no_grad(): 194 | if ( 195 | torch.cuda.is_available() 196 | and args.audio_extractor == "Encodec" 197 | ): 198 | cut_set = cut_set.compute_and_store_features_batch( 199 | extractor=audio_extractor, 200 | storage_path=storage_path, 201 | num_workers=num_jobs, 202 | batch_duration=args.batch_duration, 203 | collate=False, 204 | overwrite=True, 205 | storage_type=NumpyHdf5Writer, 206 | ) 207 | else: 208 | cut_set = cut_set.compute_and_store_features( 209 | extractor=audio_extractor, 210 | storage_path=storage_path, 211 | num_jobs=num_jobs if ex is None else 64, 212 | executor=ex, 213 | storage_type=NumpyHdf5Writer, 214 | ) 215 | 216 | # TextTokenizer 217 | if args.text_extractor: 218 | if ( 219 | args.prefix == "baker" 220 | and args.text_extractor == "labeled_pinyin" 221 | ): 222 | for c in tqdm(cut_set): 223 | phonemes = c.supervisions[0].custom["tokens"]["text"] 224 | 
unique_symbols.update(phonemes) 225 | else: 226 | for c in tqdm(cut_set): 227 | if args.prefix == "ljspeech": 228 | text = c.supervisions[0].custom["normalized_text"] 229 | text = text.replace("”", '"').replace("“", '"') 230 | phonemes = tokenize_text(text_tokenizer, text=text) 231 | elif args.prefix == "aishell": 232 | phonemes = tokenize_text( 233 | text_tokenizer, text=c.supervisions[0].text 234 | ) 235 | c.supervisions[0].custom = {} 236 | else: 237 | assert args.prefix == "libritts" 238 | phonemes = tokenize_text( 239 | text_tokenizer, text=c.supervisions[0].text 240 | ) 241 | c.supervisions[0].custom["tokens"] = {"text": phonemes} 242 | unique_symbols.update(phonemes) 243 | 244 | cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}" 245 | cut_set.to_file(f"{args.output_dir}/{cuts_filename}") 246 | 247 | if args.text_extractor: 248 | unique_phonemes = SymbolTable() 249 | for s in sorted(list(unique_symbols)): 250 | unique_phonemes.add(s) 251 | logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}") 252 | 253 | unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols" 254 | unique_phonemes.to_file(unique_phonemes_file) 255 | 256 | 257 | if __name__ == "__main__": 258 | formatter = ( 259 | "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" 260 | ) 261 | logging.basicConfig(format=formatter, level=logging.INFO) 262 | main() 263 | -------------------------------------------------------------------------------- /valle/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datamodule import * 2 | from .tokenizer import * 3 | from .collation import * 4 | -------------------------------------------------------------------------------- /valle/data/collation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from valle.utils import SymbolTable 8 | 9 | 10 | class TextTokenCollater: 11 | """Collate list of text tokens 12 | 13 | Map sentences to integers. Sentences are padded to equal length. 14 | Beginning and end-of-sequence symbols can be added. 15 | 16 | Example: 17 | >>> token_collater = TextTokenCollater(text_tokens) 18 | >>> tokens_batch, tokens_lens = token_collater(text) 19 | 20 | Returns: 21 | tokens_batch: IntTensor of shape (B, L) 22 | B: batch dimension, number of input sentences 23 | L: length of the longest sentence 24 | tokens_lens: IntTensor of shape (B,) 25 | Length of each sentence after adding and 26 | but before padding. 
27 | """ 28 | 29 | def __init__( 30 | self, 31 | text_tokens: List[str], 32 | add_eos: bool = True, 33 | add_bos: bool = True, 34 | pad_symbol: str = "<pad>", 35 | bos_symbol: str = "<bos>", 36 | eos_symbol: str = "<eos>", 37 | ): 38 | self.pad_symbol = pad_symbol 39 | 40 | self.add_eos = add_eos 41 | self.add_bos = add_bos 42 | 43 | self.bos_symbol = bos_symbol 44 | self.eos_symbol = eos_symbol 45 | 46 | unique_tokens = ( 47 | [pad_symbol] 48 | + ([bos_symbol] if add_bos else []) 49 | + ([eos_symbol] if add_eos else []) 50 | + sorted(text_tokens) 51 | ) 52 | 53 | self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} 54 | self.idx2token = [token for token in unique_tokens] 55 | 56 | def index( 57 | self, tokens_list: List[str] 58 | ) -> Tuple[torch.Tensor, torch.Tensor]: 59 | seqs, seq_lens = [], [] 60 | for tokens in tokens_list: 61 | assert ( 62 | all([True if s in self.token2idx else False for s in tokens]) 63 | is True 64 | ) 65 | seq = ( 66 | ([self.bos_symbol] if self.add_bos else []) 67 | + list(tokens) 68 | + ([self.eos_symbol] if self.add_eos else []) 69 | ) 70 | seqs.append(seq) 71 | seq_lens.append(len(seq)) 72 | 73 | max_len = max(seq_lens) 74 | for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)): 75 | seq.extend([self.pad_symbol] * (max_len - seq_len)) 76 | 77 | tokens = torch.from_numpy( 78 | np.array( 79 | [[self.token2idx[token] for token in seq] for seq in seqs], 80 | dtype=np.int64, 81 | ) 82 | ) 83 | tokens_lens = torch.IntTensor(seq_lens) 84 | 85 | return tokens, tokens_lens 86 | 87 | def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: 88 | tokens_seqs = [[p for p in text] for text in texts] 89 | max_len = len(max(tokens_seqs, key=len)) 90 | 91 | seqs = [ 92 | ([self.bos_symbol] if self.add_bos else []) 93 | + list(seq) 94 | + ([self.eos_symbol] if self.add_eos else []) 95 | + [self.pad_symbol] * (max_len - len(seq)) 96 | for seq in tokens_seqs 97 | ] 98 | 99 | tokens_batch = torch.from_numpy( 100 | np.array( 101 | [[self.token2idx[token] for token in seq] for seq in seqs], 102 | dtype=np.int64, 103 | ) 104 | ) 105 | 106 | tokens_lens = torch.IntTensor( 107 | [ 108 | len(seq) + int(self.add_eos) + int(self.add_bos) 109 | for seq in tokens_seqs 110 | ] 111 | ) 112 | 113 | return tokens_batch, tokens_lens 114 | 115 | 116 | def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: 117 | text_tokens_path = Path(text_tokens_file) 118 | unique_tokens = SymbolTable.from_file(text_tokens_path) 119 | collater = TextTokenCollater( 120 | unique_tokens.symbols, add_bos=True, add_eos=True 121 | ) 122 | return collater 123 | -------------------------------------------------------------------------------- /valle/data/datamodule.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
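# Usage sketch for the collater defined in collation.py above (the k2symbols
# path and the phoneme sequences are illustrative; the real symbol file is
# written by bin/tokenizer.py):
#
#   >>> collater = get_text_token_collater("data/tokenized/unique_text_tokens.k2symbols")
#   >>> batch, lens = collater([["h", "ə", "l", "oʊ"], ["h", "aɪ"]])
#   >>> batch.shape, lens.tolist()  # <bos>/<eos> added, shorter sequence padded
#   (torch.Size([2, 6]), [6, 4])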
16 | 17 | 18 | import argparse 19 | import inspect 20 | import logging 21 | from functools import lru_cache 22 | from pathlib import Path 23 | from typing import Any, Dict, Optional 24 | 25 | import torch 26 | from icefall.utils import str2bool 27 | from lhotse import CutSet, load_manifest_lazy 28 | from lhotse.dataset import ( 29 | CutConcatenate, 30 | DynamicBucketingSampler, 31 | PrecomputedFeatures, 32 | SimpleCutSampler, 33 | SpecAugment, 34 | ) 35 | from lhotse.dataset.input_strategies import OnTheFlyFeatures 36 | from lhotse.utils import fix_random_seed 37 | from torch.utils.data import DataLoader 38 | 39 | from valle.data.collation import get_text_token_collater 40 | from valle.data.dataset import SpeechSynthesisDataset 41 | from valle.data.fbank import get_fbank_extractor 42 | from valle.data.input_strategies import PromptedPrecomputedFeatures 43 | 44 | PrecomputedFeatures = PrecomputedFeatures 45 | 46 | 47 | class _SeedWorkers: 48 | def __init__(self, seed: int): 49 | self.seed = seed 50 | 51 | def __call__(self, worker_id: int): 52 | fix_random_seed(self.seed + worker_id) 53 | 54 | 55 | def _get_input_strategy(input_strategy, dataset, cuts): 56 | if input_strategy == "PromptedPrecomputedFeatures": 57 | return PromptedPrecomputedFeatures(dataset, cuts) 58 | 59 | return eval(input_strategy)() 60 | 61 | 62 | class TtsDataModule: 63 | """ 64 | DataModule for VALL-E TTS experiments. 65 | It assumes there is always one train and valid dataloader. 66 | 67 | It contains all the common data pipeline modules used in TTS 68 | experiments, e.g.: 69 | - dynamic batch size, 70 | - bucketing samplers, 71 | - cut concatenation[not used & tested yet], 72 | - augmentation[not used & tested yet], 73 | - on-the-fly feature extraction[not used & tested yet] 74 | 75 | This class should be derived for specific corpora used in TTS tasks. 76 | """ 77 | 78 | def __init__(self, args: argparse.Namespace): 79 | self.args = args 80 | 81 | @classmethod 82 | def add_arguments(cls, parser: argparse.ArgumentParser): 83 | group = parser.add_argument_group( 84 | title="TTS data related options", 85 | description="These options are used for the preparation of " 86 | "PyTorch DataLoaders from Lhotse CutSet's -- they control the " 87 | "effective batch sizes, sampling strategies, applied data " 88 | "augmentations, etc.", 89 | ) 90 | group.add_argument( 91 | "--manifest-dir", 92 | type=Path, 93 | default=Path("data/tokenized"), 94 | help="Path to directory with train/valid/test cuts.", 95 | ) 96 | group.add_argument( 97 | "--max-duration", 98 | type=int, 99 | default=40.0, 100 | help="Maximum pooled recordings duration (seconds) in a " 101 | "single batch. 
You can reduce it if it causes CUDA OOM.", 102 | ) 103 | group.add_argument( 104 | "--bucketing-sampler", 105 | type=str2bool, 106 | default=True, 107 | help="When enabled, the batches will come from buckets of " 108 | "similar duration (saves padding frames).", 109 | ) 110 | group.add_argument( 111 | "--num-buckets", 112 | type=int, 113 | default=10, 114 | help="The number of buckets for the DynamicBucketingSampler" 115 | "(you might want to increase it for larger datasets).", 116 | ) 117 | group.add_argument( 118 | "--concatenate-cuts", 119 | type=str2bool, 120 | default=False, 121 | help="When enabled, utterances (cuts) will be concatenated " 122 | "to minimize the amount of padding.", 123 | ) 124 | group.add_argument( 125 | "--duration-factor", 126 | type=float, 127 | default=1.0, 128 | help="Determines the maximum duration of a concatenated cut " 129 | "relative to the duration of the longest cut in a batch.", 130 | ) 131 | group.add_argument( 132 | "--gap", 133 | type=float, 134 | default=0.1, 135 | help="The amount of padding (in seconds) inserted between " 136 | "concatenated cuts. This padding is filled with noise when " 137 | "noise augmentation is used.", 138 | ) 139 | group.add_argument( 140 | "--on-the-fly-feats", 141 | type=str2bool, 142 | default=False, 143 | help="When enabled, use on-the-fly cut mixing and feature " 144 | "extraction. Will drop existing precomputed feature manifests " 145 | "if available.", 146 | ) 147 | group.add_argument( 148 | "--shuffle", 149 | type=str2bool, 150 | default=True, 151 | help="When enabled (=default), the examples will be " 152 | "shuffled for each epoch.", 153 | ) 154 | group.add_argument( 155 | "--buffer-size", 156 | type=int, 157 | default=40000, 158 | help="How many cuts (or cut pairs, triplets) we hold at any time across all of the buckets." 159 | "Increasing ``max_duration`` (batch_size) or ``num_buckets`` might require increasing this number." 160 | "It will result in larger memory usage.", 161 | ) 162 | group.add_argument( 163 | "--shuffle-buffer-size", 164 | type=int, 165 | default=100000, 166 | help="How many cuts (or cut pairs, triplets) are being held in memory" 167 | "a buffer used for streaming shuffling. Larger number means better randomness at the cost" 168 | "of higher memory usage.", 169 | ) 170 | group.add_argument( 171 | "--drop-last", 172 | type=str2bool, 173 | default=False, 174 | help="Whether to drop last batch. Used by sampler.", 175 | ) 176 | group.add_argument( 177 | "--return-cuts", 178 | type=str2bool, 179 | default=True, 180 | help="When enabled, each batch will have the " 181 | "field: batch['supervisions']['cut'] with the cuts that " 182 | "were used to construct it.", 183 | ) 184 | 185 | group.add_argument( 186 | "--num-workers", 187 | type=int, 188 | default=8, 189 | help="The number of training dataloader workers that " 190 | "collect the batches.", 191 | ) 192 | 193 | group.add_argument( 194 | "--enable-spec-aug", 195 | type=str2bool, 196 | default=False, 197 | help="When enabled, use SpecAugment for training dataset.", 198 | ) 199 | 200 | group.add_argument( 201 | "--spec-aug-time-warp-factor", 202 | type=int, 203 | default=80, 204 | help="Used only when --enable-spec-aug is True. " 205 | "It specifies the factor for time warping in SpecAugment. " 206 | "Larger values mean more warping. 
" 207 | "A value less than 1 means to disable time warp.", 208 | ) 209 | 210 | group.add_argument( 211 | "--input-strategy", 212 | type=str, 213 | default="PrecomputedFeatures", 214 | help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures", 215 | ) 216 | 217 | group.add_argument( 218 | "--dataset", 219 | type=str, 220 | default="libritts", 221 | help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.", 222 | ) 223 | 224 | parser.add_argument( 225 | "--text-tokens", 226 | type=str, 227 | default="data/tokenized/unique_text_tokens.k2symbols", 228 | help="Path to the unique text tokens file", 229 | ) 230 | 231 | parser.add_argument( 232 | "--sampling-rate", 233 | type=int, 234 | default=24000, 235 | help="""Audio sampling rate.""", 236 | ) 237 | 238 | def train_dataloaders( 239 | self, 240 | cuts_train: CutSet, 241 | sampler_state_dict: Optional[Dict[str, Any]] = None, 242 | ) -> DataLoader: 243 | """ 244 | Args: 245 | cuts_train: 246 | CutSet for training. 247 | sampler_state_dict: 248 | The state dict for the training sampler. 249 | """ 250 | transforms = [] 251 | 252 | if self.args.concatenate_cuts: 253 | logging.info( 254 | f"Using cut concatenation with duration factor " 255 | f"{self.args.duration_factor} and gap {self.args.gap}." 256 | ) 257 | # Cut concatenation should be the first transform in the list, 258 | # so that if we e.g. mix noise in, it will fill the gaps between 259 | # different utterances. 260 | transforms = [ 261 | CutConcatenate( 262 | duration_factor=self.args.duration_factor, gap=self.args.gap 263 | ) 264 | ] + transforms 265 | 266 | input_transforms = [] 267 | if self.args.enable_spec_aug: 268 | logging.info("Enable SpecAugment") 269 | logging.info( 270 | f"Time warp factor: {self.args.spec_aug_time_warp_factor}" 271 | ) 272 | # Set the value of num_frame_masks according to Lhotse's version. 273 | # In different Lhotse's versions, the default of num_frame_masks is 274 | # different. 275 | num_frame_masks = 10 276 | num_frame_masks_parameter = inspect.signature( 277 | SpecAugment.__init__ 278 | ).parameters["num_frame_masks"] 279 | if num_frame_masks_parameter.default == 1: 280 | num_frame_masks = 2 281 | logging.info(f"Num frame mask: {num_frame_masks}") 282 | input_transforms.append( 283 | SpecAugment( 284 | time_warp_factor=self.args.spec_aug_time_warp_factor, 285 | num_frame_masks=num_frame_masks, 286 | features_mask_size=27, 287 | num_feature_masks=2, 288 | frames_mask_size=100, 289 | ) 290 | ) 291 | else: 292 | logging.info("Disable SpecAugment") 293 | 294 | logging.info("About to create train dataset") 295 | if self.args.on_the_fly_feats: 296 | # NOTE: the PerturbSpeed transform should be added only if we 297 | # remove it from data prep stage. 298 | # Add on-the-fly speed perturbation; since originally it would 299 | # have increased epoch size by 3, we will apply prob 2/3 and use 300 | # 3x more epochs. 301 | # Speed perturbation probably should come first before 302 | # concatenation, but in principle the transforms order doesn't have 303 | # to be strict (e.g. could be randomized) 304 | # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa 305 | # Drop feats to be on the safe side. 
306 | train = SpeechSynthesisDataset( 307 | get_text_token_collater(self.args.text_tokens), 308 | cut_transforms=transforms, 309 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()), 310 | feature_transforms=input_transforms, 311 | ) 312 | else: 313 | train = SpeechSynthesisDataset( 314 | get_text_token_collater(self.args.text_tokens), 315 | feature_input_strategy=_get_input_strategy( 316 | self.args.input_strategy, self.args.dataset, cuts_train 317 | ), 318 | cut_transforms=transforms, 319 | feature_transforms=input_transforms, 320 | ) 321 | 322 | if self.args.bucketing_sampler: 323 | logging.info("Using DynamicBucketingSampler") 324 | train_sampler = DynamicBucketingSampler( 325 | cuts_train, 326 | max_duration=self.args.max_duration, 327 | shuffle=self.args.shuffle, 328 | buffer_size=self.args.buffer_size, 329 | shuffle_buffer_size=self.args.shuffle_buffer_size, 330 | quadratic_duration=10, 331 | num_cuts_for_bins_estimate=10000, 332 | drop_last=True, 333 | ) 334 | else: 335 | logging.info( 336 | "Using SimpleCutSampler and sort by duraton(ascending=True)." 337 | ) 338 | cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True) 339 | train_sampler = SimpleCutSampler( 340 | cuts_train, 341 | max_duration=self.args.max_duration, 342 | shuffle=self.args.shuffle, 343 | ) 344 | logging.info("About to create train dataloader") 345 | 346 | if sampler_state_dict is not None: 347 | logging.info("Loading sampler state dict") 348 | train_sampler.load_state_dict(sampler_state_dict) 349 | 350 | # 'seed' is derived from the current random state, which will have 351 | # previously been set in the main process. 352 | seed = torch.randint(0, 100000, ()).item() 353 | worker_init_fn = _SeedWorkers(seed) 354 | 355 | train_dl = DataLoader( 356 | train, 357 | sampler=train_sampler, 358 | batch_size=None, 359 | num_workers=self.args.num_workers, 360 | persistent_workers=False, 361 | worker_init_fn=worker_init_fn, 362 | ) 363 | 364 | return train_dl 365 | 366 | def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: 367 | logging.info("About to create dev dataset") 368 | if self.args.on_the_fly_feats: 369 | validate = SpeechSynthesisDataset( 370 | get_text_token_collater(self.args.text_tokens), 371 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()), 372 | cut_transforms=[], 373 | ) 374 | else: 375 | validate = SpeechSynthesisDataset( 376 | get_text_token_collater(self.args.text_tokens), 377 | feature_input_strategy=_get_input_strategy( 378 | self.args.input_strategy, self.args.dataset, cuts_valid 379 | ), 380 | cut_transforms=[], 381 | ) 382 | valid_sampler = DynamicBucketingSampler( 383 | cuts_valid, 384 | max_duration=self.args.max_duration, 385 | shuffle=False, 386 | drop_last=True, 387 | ) 388 | logging.info("About to create dev dataloader") 389 | valid_dl = DataLoader( 390 | validate, 391 | sampler=valid_sampler, 392 | batch_size=None, 393 | num_workers=4, 394 | persistent_workers=False, 395 | ) 396 | 397 | return valid_dl 398 | 399 | def test_dataloaders(self, cuts: CutSet) -> DataLoader: 400 | logging.debug("About to create test dataset") 401 | test = SpeechSynthesisDataset( 402 | get_text_token_collater(self.args.text_tokens), 403 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()) 404 | if self.args.on_the_fly_feats 405 | else _get_input_strategy( 406 | self.args.input_strategy, self.args.dataset, cuts 407 | ), 408 | cut_transforms=[], 409 | ) 410 | sampler = DynamicBucketingSampler( 411 | cuts, 412 | max_duration=self.args.max_duration, 413 | 
shuffle=False, 414 | drop_last=True, 415 | ) 416 | logging.debug("About to create test dataloader") 417 | test_dl = DataLoader( 418 | test, 419 | batch_size=None, 420 | sampler=sampler, 421 | num_workers=self.args.num_workers, 422 | ) 423 | return test_dl 424 | 425 | @lru_cache() 426 | def train_cuts(self) -> CutSet: 427 | logging.info("About to get train cuts") 428 | return load_manifest_lazy( 429 | self.args.manifest_dir / "cuts_train.jsonl.gz" 430 | ) 431 | 432 | @lru_cache() 433 | def dev_cuts(self) -> CutSet: 434 | logging.info("About to get dev cuts") 435 | return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") 436 | 437 | @lru_cache() 438 | def test_cuts(self) -> CutSet: 439 | logging.info("About to get test cuts") 440 | return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz") 441 | -------------------------------------------------------------------------------- /valle/data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """ 18 | modified from lhoste.dataset.speech_synthesis.py 19 | """ 20 | 21 | from typing import Callable, Dict, List, Sequence, Union 22 | 23 | import torch 24 | from lhotse import validate 25 | from lhotse.cut import CutSet 26 | from lhotse.dataset.collation import collate_audio 27 | from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures 28 | from lhotse.utils import ifnone 29 | 30 | from valle.data.collation import TextTokenCollater 31 | 32 | 33 | class SpeechSynthesisDataset(torch.utils.data.Dataset): 34 | """ 35 | The PyTorch Dataset for the speech synthesis(e.g. TTS) task. 36 | Each item in this dataset is a dict of: 37 | 38 | .. 
code-block:: 39 | 40 | { 41 | 'audio': (B x NumSamples) float tensor 42 | 'audio_lens': (B, ) int tensor 43 | 'text': str 44 | 'audio_features': (B x NumFrames x NumFeatures) float tensor 45 | 'audio_features_lens': (B, ) int tensor 46 | 'text_tokens': (B x NumTextTokens) long tensor 47 | 'text_tokens_lens': (B, ) int tensor 48 | } 49 | """ 50 | 51 | def __init__( 52 | self, 53 | text_token_collater: TextTokenCollater, 54 | cut_transforms: List[Callable[[CutSet], CutSet]] = None, 55 | feature_input_strategy: BatchIO = PrecomputedFeatures(), 56 | feature_transforms: Union[Sequence[Callable], Callable] = None, 57 | ) -> None: 58 | super().__init__() 59 | 60 | self.text_token_collater = text_token_collater 61 | self.cut_transforms = ifnone(cut_transforms, []) 62 | self.feature_input_strategy = feature_input_strategy 63 | 64 | if feature_transforms is None: 65 | feature_transforms = [] 66 | elif not isinstance(feature_transforms, Sequence): 67 | feature_transforms = [feature_transforms] 68 | 69 | assert all( 70 | isinstance(transform, Callable) for transform in feature_transforms 71 | ), "Feature transforms must be Callable" 72 | self.feature_transforms = feature_transforms 73 | 74 | def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: 75 | validate_for_tts(cuts) 76 | 77 | for transform in self.cut_transforms: 78 | cuts = transform(cuts) 79 | 80 | if False: # not used 81 | audio, audio_lens = collate_audio(cuts) 82 | else: # for sharing tokenized features in different machines 83 | audio, audio_lens = None, None 84 | 85 | audio_features, audio_features_lens = self.feature_input_strategy(cuts) 86 | 87 | for transform in self.feature_transforms: 88 | audio_features = transform(audio_features) 89 | 90 | text_tokens, text_tokens_lens = self.text_token_collater( 91 | [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts] 92 | ) 93 | 94 | return { 95 | "utt_id": [cut.id for cut in cuts], 96 | "text": [cut.supervisions[0].text for cut in cuts], 97 | "audio": audio, 98 | "audio_lens": audio_lens, 99 | "audio_features": audio_features, 100 | "audio_features_lens": audio_features_lens, 101 | "text_tokens": text_tokens, 102 | "text_tokens_lens": text_tokens_lens, 103 | } 104 | 105 | 106 | def validate_for_tts(cuts: CutSet) -> None: 107 | validate(cuts) 108 | for cut in cuts: 109 | assert ( 110 | len(cut.supervisions) == 1 111 | ), "Only the Cuts with single supervision are supported." 112 | -------------------------------------------------------------------------------- /valle/data/fbank.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
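# Sketch of how the data pieces above are typically wired together for training,
# assuming `args` is an argparse.Namespace carrying the options added by
# TtsDataModule.add_arguments and manifests prepared under data/tokenized:
#
#   >>> dm = TtsDataModule(args)
#   >>> train_dl = dm.train_dataloaders(dm.train_cuts())
#   >>> batch = next(iter(train_dl))
#   >>> sorted(batch.keys())
#   ['audio', 'audio_features', 'audio_features_lens', 'audio_lens',
#    'text', 'text_tokens', 'text_tokens_lens', 'utt_id']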
16 | 17 | 18 | from dataclasses import asdict, dataclass 19 | from typing import Any, Dict, Optional, Union 20 | 21 | import numpy as np 22 | import torch 23 | from lhotse.features.base import FeatureExtractor 24 | from lhotse.utils import EPSILON, Seconds, compute_num_frames 25 | from librosa.filters import mel as librosa_mel_fn 26 | 27 | 28 | @dataclass 29 | class BigVGANFbankConfig: 30 | # Spectogram-related part 31 | # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them 32 | frame_length: Seconds = 1024 / 24000.0 33 | frame_shift: Seconds = 256 / 24000.0 34 | remove_dc_offset: bool = True 35 | round_to_power_of_two: bool = True 36 | 37 | # Fbank-related part 38 | low_freq: float = 0.0 39 | high_freq: float = 12000.0 40 | num_mel_bins: int = 100 41 | use_energy: bool = False 42 | 43 | def to_dict(self) -> Dict[str, Any]: 44 | return asdict(self) 45 | 46 | @staticmethod 47 | def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig": 48 | return BigVGANFbankConfig(**data) 49 | 50 | 51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 52 | return torch.log(torch.clamp(x, min=clip_val) * C) 53 | 54 | 55 | def spectral_normalize_torch(magnitudes): 56 | output = dynamic_range_compression_torch(magnitudes) 57 | return output 58 | 59 | 60 | # https://github.com/NVIDIA/BigVGAN 61 | # bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz 62 | class BigVGANFbank(FeatureExtractor): 63 | name = "fbank" 64 | config_type = BigVGANFbankConfig 65 | 66 | def __init__(self, config: Optional[Any] = None): 67 | super(BigVGANFbank, self).__init__(config) 68 | sampling_rate = 24000 69 | self.mel_basis = torch.from_numpy( 70 | librosa_mel_fn( 71 | sampling_rate, 72 | 1024, 73 | self.config.num_mel_bins, 74 | self.config.low_freq, 75 | self.config.high_freq, 76 | ).astype(np.float32) 77 | ) 78 | self.hann_window = torch.hann_window(1024) 79 | 80 | def _feature_fn(self, samples, **kwargs): 81 | win_length, n_fft = 1024, 1024 82 | hop_size = 256 83 | if True: 84 | sampling_rate = 24000 85 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12) 86 | expected_num_frames = compute_num_frames( 87 | duration=duration, 88 | frame_shift=self.frame_shift, 89 | sampling_rate=sampling_rate, 90 | ) 91 | pad_size = ( 92 | (expected_num_frames - 1) * hop_size 93 | + win_length 94 | - samples.shape[-1] 95 | ) 96 | assert pad_size >= 0 97 | 98 | y = torch.nn.functional.pad( 99 | samples, 100 | (0, pad_size), 101 | mode="constant", 102 | ) 103 | else: 104 | y = torch.nn.functional.pad( 105 | samples, 106 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 107 | mode="reflect", 108 | ) 109 | 110 | y = y.squeeze(1) 111 | 112 | # complex tensor as default, then use view_as_real for future pytorch compatibility 113 | spec = torch.stft( 114 | y, 115 | n_fft, 116 | hop_length=hop_size, 117 | win_length=win_length, 118 | window=self.hann_window, 119 | center=False, 120 | pad_mode="reflect", 121 | normalized=False, 122 | onesided=True, 123 | return_complex=True, 124 | ) 125 | spec = torch.view_as_real(spec) 126 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 127 | 128 | spec = torch.matmul(self.mel_basis, spec) 129 | spec = spectral_normalize_torch(spec) 130 | 131 | return spec.transpose(2, 1).squeeze(0) 132 | 133 | def extract( 134 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int 135 | ) -> np.ndarray: 136 | assert sampling_rate == 24000 137 | params = asdict(self.config) 138 | 
params.update({"sample_frequency": sampling_rate, "snip_edges": False}) 139 | params["frame_shift"] *= 1000.0 140 | params["frame_length"] *= 1000.0 141 | if not isinstance(samples, torch.Tensor): 142 | samples = torch.from_numpy(samples) 143 | # Torchaudio Kaldi feature extractors expect the channel dimension to be first. 144 | if len(samples.shape) == 1: 145 | samples = samples.unsqueeze(0) 146 | features = self._feature_fn(samples, **params).to(torch.float32) 147 | return features.numpy() 148 | 149 | @property 150 | def frame_shift(self) -> Seconds: 151 | return self.config.frame_shift 152 | 153 | def feature_dim(self, sampling_rate: int) -> int: 154 | return self.config.num_mel_bins 155 | 156 | @staticmethod 157 | def mix( 158 | features_a: np.ndarray, 159 | features_b: np.ndarray, 160 | energy_scaling_factor_b: float, 161 | ) -> np.ndarray: 162 | return np.log( 163 | np.maximum( 164 | # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0) 165 | EPSILON, 166 | np.exp(features_a) 167 | + energy_scaling_factor_b * np.exp(features_b), 168 | ) 169 | ) 170 | 171 | @staticmethod 172 | def compute_energy(features: np.ndarray) -> float: 173 | return float(np.sum(np.exp(features))) 174 | 175 | 176 | def get_fbank_extractor() -> BigVGANFbank: 177 | return BigVGANFbank(BigVGANFbankConfig()) 178 | 179 | 180 | if __name__ == "__main__": 181 | extractor = BigVGANFbank(BigVGANFbankConfig()) 182 | 183 | samples = torch.from_numpy(np.random.random([1000]).astype(np.float32)) 184 | samples = torch.clip(samples, -1.0, 1.0) 185 | fbank = extractor.extract(samples, 24000.0) 186 | print(f"fbank {fbank.shape}") 187 | 188 | from scipy.io.wavfile import read 189 | 190 | MAX_WAV_VALUE = 32768.0 191 | 192 | sampling_rate, samples = read( 193 | "egs/libritts/prompts/5639_40744_000000_000002.wav" 194 | ) 195 | print(f"samples: [{samples.min()}, {samples.max()}]") 196 | fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000) 197 | print(f"fbank {fbank.shape}") 198 | 199 | import matplotlib.pyplot as plt 200 | 201 | _ = plt.figure(figsize=(18, 10)) 202 | plt.imshow( 203 | X=fbank.transpose(1, 0), 204 | cmap=plt.get_cmap("jet"), 205 | aspect="auto", 206 | interpolation="nearest", 207 | ) 208 | plt.gca().invert_yaxis() 209 | plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png") 210 | plt.close() 211 | 212 | print("fbank test PASS!") 213 | -------------------------------------------------------------------------------- /valle/data/input_strategies.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from concurrent.futures import ThreadPoolExecutor 4 | from typing import Tuple, Type 5 | 6 | from lhotse import CutSet 7 | from lhotse.dataset.collation import collate_features 8 | from lhotse.dataset.input_strategies import ( 9 | ExecutorType, 10 | PrecomputedFeatures, 11 | _get_executor, 12 | ) 13 | from lhotse.utils import fastcopy 14 | 15 | 16 | class PromptedFeatures: 17 | def __init__(self, prompts, features): 18 | self.prompts = prompts 19 | self.features = features 20 | 21 | def to(self, device): 22 | return PromptedFeatures( 23 | self.prompts.to(device), self.features.to(device) 24 | ) 25 | 26 | def sum(self): 27 | return self.features.sum() 28 | 29 | @property 30 | def ndim(self): 31 | return self.features.ndim 32 | 33 | @property 34 | def data(self): 35 | return (self.prompts, self.features) 36 | 37 | 38 | class 
PromptedPrecomputedFeatures(PrecomputedFeatures): 39 | """ 40 | :class:`InputStrategy` that reads pre-computed features, whose manifests 41 | are attached to cuts, from disk. 42 | 43 | It automatically pads the feature matrices with pre or post feature. 44 | 45 | .. automethod:: __call__ 46 | """ 47 | 48 | def __init__( 49 | self, 50 | dataset: str, 51 | cuts: CutSet, 52 | num_workers: int = 0, 53 | executor_type: Type[ExecutorType] = ThreadPoolExecutor, 54 | ) -> None: 55 | super(PromptedPrecomputedFeatures, self).__init__( 56 | num_workers, executor_type 57 | ) 58 | 59 | self.utt2neighbors = defaultdict(lambda: []) 60 | 61 | if dataset.lower() == "libritts": 62 | # 909_131041_000013_000002 63 | # 909_131041_000013_000003 64 | speaker2utts = defaultdict(lambda: []) 65 | 66 | utt2cut = {} 67 | for cut in cuts: 68 | speaker = cut.supervisions[0].speaker 69 | speaker2utts[speaker].append(cut.id) 70 | utt2cut[cut.id] = cut 71 | 72 | for spk in speaker2utts: 73 | uttids = sorted(speaker2utts[spk]) 74 | # Using the property of sorted keys to find previous utterance 75 | # The keys has structure speaker_book_x_y e.g. 1089_134691_000004_000001 76 | if len(uttids) == 1: 77 | self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]]) 78 | continue 79 | 80 | utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1])) 81 | utt2postutt = dict(zip(uttids[:-1], uttids[1:])) 82 | 83 | for utt in utt2prevutt: 84 | self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]]) 85 | 86 | for utt in utt2postutt: 87 | self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]]) 88 | elif dataset.lower() == "ljspeech": 89 | utt2cut = {} 90 | uttids = [] 91 | for cut in cuts: 92 | uttids.append(cut.id) 93 | utt2cut[cut.id] = cut 94 | 95 | if len(uttids) == 1: 96 | self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]]) 97 | else: 98 | # Using the property of sorted keys to find previous utterance 99 | # The keys has structure: LJ001-0010 100 | utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1])) 101 | utt2postutt = dict(zip(uttids[:-1], uttids[1:])) 102 | 103 | for utt in utt2postutt: 104 | postutt = utt2postutt[utt] 105 | if utt[:5] == postutt[:5]: 106 | self.utt2neighbors[utt].append(utt2cut[postutt]) 107 | 108 | for utt in utt2prevutt: 109 | prevutt = utt2prevutt[utt] 110 | if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]: 111 | self.utt2neighbors[utt].append(utt2cut[prevutt]) 112 | else: 113 | raise ValueError 114 | 115 | def __call__( 116 | self, cuts: CutSet 117 | ) -> Tuple[PromptedFeatures, PromptedFeatures]: 118 | """ 119 | Reads the pre-computed features from disk/other storage. 120 | The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``. 121 | 122 | :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding. 
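        A hedged sketch of unpacking the two returned ``PromptedFeatures``
        wrappers (the variable names are illustrative, not part of the API)::

            feats, feats_lens = strategy(cuts)
            prompts, features = feats.data                 # (B, T_prompt, F), (B, T, F)
            prompts_lens, features_lens = feats_lens.data  # (B,), (B,)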
123 | """ 124 | features, features_lens = collate_features( 125 | cuts, 126 | executor=_get_executor( 127 | self.num_workers, executor_type=self._executor_type 128 | ), 129 | ) 130 | 131 | prompts_cuts = [] 132 | for k, cut in enumerate(cuts): 133 | prompts_cut = random.choice(self.utt2neighbors[cut.id]) 134 | prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}")) 135 | 136 | mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0]) 137 | # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate( 138 | # max_duration=mini_duration, 139 | # offset_type="random", 140 | # preserve_id=True, 141 | # ) 142 | prompts_cuts = CutSet( 143 | cuts={k: cut for k, cut in enumerate(prompts_cuts)} 144 | ).truncate( 145 | max_duration=mini_duration, 146 | offset_type="random", 147 | preserve_id=False, 148 | ) 149 | 150 | prompts, prompts_lens = collate_features( 151 | prompts_cuts, 152 | executor=_get_executor( 153 | self.num_workers, executor_type=self._executor_type 154 | ), 155 | ) 156 | 157 | return PromptedFeatures(prompts, features), PromptedFeatures( 158 | prompts_lens, features_lens 159 | ) 160 | -------------------------------------------------------------------------------- /valle/data/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import re 17 | from dataclasses import asdict, dataclass 18 | from typing import Any, Dict, List, Optional, Pattern, Union 19 | 20 | import numpy as np 21 | import torch 22 | import torchaudio 23 | from encodec import EncodecModel 24 | from encodec.utils import convert_audio 25 | from lhotse.features import FeatureExtractor 26 | from lhotse.utils import Seconds, compute_num_frames 27 | from phonemizer.backend import EspeakBackend 28 | from phonemizer.backend.espeak.language_switch import LanguageSwitch 29 | from phonemizer.backend.espeak.words_mismatch import WordMismatch 30 | from phonemizer.punctuation import Punctuation 31 | from phonemizer.separator import Separator 32 | 33 | try: 34 | from pypinyin import Style, pinyin 35 | from pypinyin.style._utils import get_finals, get_initials 36 | except Exception: 37 | pass 38 | 39 | 40 | class PypinyinBackend: 41 | """PypinyinBackend for Chinese. Most codes is referenced from espnet. 42 | There are two types pinyin or initials_finals, one is 43 | just like "ni1 hao3", the other is like "n i1 h ao3". 
44 | """ 45 | 46 | def __init__( 47 | self, 48 | backend="initials_finals", 49 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), 50 | ) -> None: 51 | self.backend = backend 52 | self.punctuation_marks = punctuation_marks 53 | 54 | def phonemize( 55 | self, text: List[str], separator: Separator, strip=True, njobs=1 56 | ) -> List[str]: 57 | assert isinstance(text, List) 58 | phonemized = [] 59 | for _text in text: 60 | _text = re.sub(" +", " ", _text.strip()) 61 | _text = _text.replace(" ", separator.word) 62 | phones = [] 63 | if self.backend == "pypinyin": 64 | for n, py in enumerate( 65 | pinyin( 66 | _text, style=Style.TONE3, neutral_tone_with_five=True 67 | ) 68 | ): 69 | if all([c in self.punctuation_marks for c in py[0]]): 70 | if len(phones): 71 | assert phones[-1] == separator.syllable 72 | phones.pop(-1) 73 | 74 | phones.extend(list(py[0])) 75 | else: 76 | phones.extend([py[0], separator.syllable]) 77 | elif self.backend == "pypinyin_initials_finals": 78 | for n, py in enumerate( 79 | pinyin( 80 | _text, style=Style.TONE3, neutral_tone_with_five=True 81 | ) 82 | ): 83 | if all([c in self.punctuation_marks for c in py[0]]): 84 | if len(phones): 85 | assert phones[-1] == separator.syllable 86 | phones.pop(-1) 87 | phones.extend(list(py[0])) 88 | else: 89 | if py[0][-1].isalnum(): 90 | initial = get_initials(py[0], strict=False) 91 | if py[0][-1].isdigit(): 92 | final = ( 93 | get_finals(py[0][:-1], strict=False) 94 | + py[0][-1] 95 | ) 96 | else: 97 | final = get_finals(py[0], strict=False) 98 | phones.extend( 99 | [ 100 | initial, 101 | separator.phone, 102 | final, 103 | separator.syllable, 104 | ] 105 | ) 106 | else: 107 | assert ValueError 108 | else: 109 | raise NotImplementedError 110 | phonemized.append( 111 | "".join(phones).rstrip(f"{separator.word}{separator.syllable}") 112 | ) 113 | return phonemized 114 | 115 | 116 | class TextTokenizer: 117 | """Phonemize Text.""" 118 | 119 | def __init__( 120 | self, 121 | language="en-us", 122 | backend="espeak", 123 | separator=Separator(word="_", syllable="-", phone="|"), 124 | preserve_punctuation=True, 125 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), 126 | with_stress: bool = False, 127 | tie: Union[bool, str] = False, 128 | language_switch: LanguageSwitch = "keep-flags", 129 | words_mismatch: WordMismatch = "ignore", 130 | ) -> None: 131 | if backend == "espeak": 132 | phonemizer = EspeakBackend( 133 | language, 134 | punctuation_marks=punctuation_marks, 135 | preserve_punctuation=preserve_punctuation, 136 | with_stress=with_stress, 137 | tie=tie, 138 | language_switch=language_switch, 139 | words_mismatch=words_mismatch, 140 | ) 141 | elif backend in ["pypinyin", "pypinyin_initials_finals"]: 142 | phonemizer = PypinyinBackend( 143 | backend=backend, 144 | punctuation_marks=punctuation_marks + separator.word, 145 | ) 146 | else: 147 | raise NotImplementedError(f"{backend}") 148 | 149 | self.backend = phonemizer 150 | self.separator = separator 151 | 152 | def to_list(self, phonemized: str) -> List[str]: 153 | fields = [] 154 | for word in phonemized.split(self.separator.word): 155 | # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. 
156 | pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) 157 | fields.extend( 158 | [p for p in pp if p != self.separator.phone] 159 | + [self.separator.word] 160 | ) 161 | assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( 162 | self.separator.phone 163 | ) 164 | return fields[:-1] 165 | 166 | def __call__(self, text, strip=True) -> List[List[str]]: 167 | if isinstance(text, str): 168 | text = [text] 169 | 170 | phonemized = self.backend.phonemize( 171 | text, separator=self.separator, strip=strip, njobs=1 172 | ) 173 | return [self.to_list(p) for p in phonemized] 174 | 175 | 176 | def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: 177 | phonemes = tokenizer([text.strip()]) 178 | return phonemes[0] # k2symbols 179 | 180 | 181 | def remove_encodec_weight_norm(model): 182 | from encodec.modules import SConv1d 183 | from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock 184 | from torch.nn.utils import remove_weight_norm 185 | 186 | encoder = model.encoder.model 187 | for key in encoder._modules: 188 | if isinstance(encoder._modules[key], SEANetResnetBlock): 189 | remove_weight_norm(encoder._modules[key].shortcut.conv.conv) 190 | block_modules = encoder._modules[key].block._modules 191 | for skey in block_modules: 192 | if isinstance(block_modules[skey], SConv1d): 193 | remove_weight_norm(block_modules[skey].conv.conv) 194 | elif isinstance(encoder._modules[key], SConv1d): 195 | remove_weight_norm(encoder._modules[key].conv.conv) 196 | 197 | decoder = model.decoder.model 198 | for key in decoder._modules: 199 | if isinstance(decoder._modules[key], SEANetResnetBlock): 200 | remove_weight_norm(decoder._modules[key].shortcut.conv.conv) 201 | block_modules = decoder._modules[key].block._modules 202 | for skey in block_modules: 203 | if isinstance(block_modules[skey], SConv1d): 204 | remove_weight_norm(block_modules[skey].conv.conv) 205 | elif isinstance(decoder._modules[key], SConvTranspose1d): 206 | remove_weight_norm(decoder._modules[key].convtr.convtr) 207 | elif isinstance(decoder._modules[key], SConv1d): 208 | remove_weight_norm(decoder._modules[key].conv.conv) 209 | 210 | 211 | class AudioTokenizer: 212 | """EnCodec audio.""" 213 | 214 | def __init__( 215 | self, 216 | device: Any = None, 217 | ) -> None: 218 | # Instantiate a pretrained EnCodec model 219 | model = EncodecModel.encodec_model_24khz() 220 | model.set_target_bandwidth(6.0) 221 | remove_encodec_weight_norm(model) 222 | 223 | if not device: 224 | device = torch.device("cpu") 225 | if torch.cuda.is_available(): 226 | device = torch.device("cuda:0") 227 | 228 | self._device = device 229 | 230 | self.codec = model.to(device) 231 | self.sample_rate = model.sample_rate 232 | self.channels = model.channels 233 | 234 | @property 235 | def device(self): 236 | return self._device 237 | 238 | def encode(self, wav: torch.Tensor) -> torch.Tensor: 239 | return self.codec.encode(wav.to(self.device)) 240 | 241 | def decode(self, frames: torch.Tensor) -> torch.Tensor: 242 | return self.codec.decode(frames) 243 | 244 | 245 | def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str): 246 | # Load and pre-process the audio waveform 247 | wav, sr = torchaudio.load(audio_path) 248 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) 249 | wav = wav.unsqueeze(0) 250 | 251 | # Extract discrete codes from EnCodec 252 | with torch.no_grad(): 253 | encoded_frames = tokenizer.encode(wav) 254 | return encoded_frames 255 | 256 | 257 | @dataclass 258 | class AudioTokenConfig: 
259 | frame_shift: Seconds = 320.0 / 24000
260 | num_quantizers: int = 8
261 |
262 | def to_dict(self) -> Dict[str, Any]:
263 | return asdict(self)
264 |
265 | @staticmethod
266 | def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
267 | return AudioTokenConfig(**data)
268 |
269 |
270 | class AudioTokenExtractor(FeatureExtractor):
271 | name = "encodec"
272 | config_type = AudioTokenConfig
273 |
274 | def __init__(self, config: Optional[Any] = None):
275 | super(AudioTokenExtractor, self).__init__(config)
276 | self.tokenizer = AudioTokenizer()
277 |
278 | def extract(
279 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
280 | ) -> np.ndarray:
281 | if not isinstance(samples, torch.Tensor):
282 | samples = torch.from_numpy(samples)
283 | if sampling_rate != self.tokenizer.sample_rate:
284 | samples = convert_audio(
285 | samples,
286 | sampling_rate,
287 | self.tokenizer.sample_rate,
288 | self.tokenizer.channels,
289 | )
290 | if len(samples.shape) == 2:
291 | samples = samples.unsqueeze(0)
292 | else:
293 | raise ValueError()
294 |
295 | device = self.tokenizer.device
296 | encoded_frames = self.tokenizer.encode(samples.detach().to(device))
297 | codes = encoded_frames[0][0] # [B, n_q, T]
298 | if True:
299 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
300 | expected_num_frames = compute_num_frames(
301 | duration=duration,
302 | frame_shift=self.frame_shift,
303 | sampling_rate=sampling_rate,
304 | )
305 | assert abs(codes.shape[-1] - expected_num_frames) <= 1
306 | codes = codes[..., :expected_num_frames]
307 | return codes.cpu().squeeze(0).permute(1, 0).numpy()
308 |
309 | @property
310 | def frame_shift(self) -> Seconds:
311 | return self.config.frame_shift
312 |
313 | def feature_dim(self, sampling_rate: int) -> int:
314 | return self.config.num_quantizers
315 |
316 | def pad_tensor_list(self, tensor_list, device, padding_value=0):
317 | # Compute the length of each tensor
318 | lengths = [tensor.shape[0] for tensor in tensor_list]
319 | # Pad them to a common length with pad_sequence
320 | tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
321 | padded_tensor = torch.nn.utils.rnn.pad_sequence(
322 | tensor_list, batch_first=True, padding_value=padding_value
323 | )
324 | return padded_tensor, lengths
325 |
326 | def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
327 | samples = [wav.squeeze() for wav in samples]
328 | device = self.tokenizer.device
329 | samples, lengths = self.pad_tensor_list(samples, device)
330 | samples = samples.unsqueeze(1)
331 |
332 | if not isinstance(samples, torch.Tensor):
333 | samples = torch.from_numpy(samples)
334 | if len(samples.shape) != 3:
335 | raise ValueError()
336 | if sampling_rate != self.tokenizer.sample_rate:
337 | samples = [
338 | convert_audio(
339 | wav,
340 | sampling_rate,
341 | self.tokenizer.sample_rate,
342 | self.tokenizer.channels,
343 | )
344 | for wav in samples
345 | ]
346 | samples = torch.stack(samples, 0) # convert samples from list to tensor
347 | # Extract discrete codes from EnCodec
348 | with torch.no_grad():
349 | encoded_frames = self.tokenizer.encode(samples.detach().to(device))
350 | encoded_frames = encoded_frames[0][0] # [B, n_q, T]
351 | batch_codes = []
352 | for b, length in enumerate(lengths):
353 | codes = encoded_frames[b]
354 | duration = round(length / sampling_rate, ndigits=12)
355 | expected_num_frames = compute_num_frames(
356 | duration=duration,
357 | frame_shift=self.frame_shift,
358 | sampling_rate=sampling_rate,
359 | )
360 | batch_codes.append(codes[...,
:expected_num_frames]) 361 | return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes] 362 | 363 | 364 | if __name__ == "__main__": 365 | model = EncodecModel.encodec_model_24khz() 366 | model.set_target_bandwidth(6.0) 367 | 368 | samples = torch.from_numpy(np.random.random([4, 1, 1600])).type( 369 | torch.float32 370 | ) 371 | codes_raw = model.encode(samples) 372 | 373 | remove_encodec_weight_norm(model) 374 | codes_norm = model.encode(samples) 375 | 376 | assert torch.allclose(codes_raw[0][0], codes_norm[0][0]) 377 | -------------------------------------------------------------------------------- /valle/models/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.nn as nn 4 | from icefall.utils import AttributeDict, str2bool 5 | 6 | from .macros import ( 7 | NUM_AUDIO_TOKENS, 8 | NUM_MEL_BINS, 9 | NUM_SPEAKER_CLASSES, 10 | NUM_TEXT_TOKENS, 11 | SPEAKER_EMBEDDING_DIM, 12 | ) 13 | from .transformer import Transformer 14 | from .valle import VALLE, VALLF 15 | from .visualizer import visualize 16 | 17 | 18 | def add_model_arguments(parser: argparse.ArgumentParser): 19 | parser.add_argument( 20 | "--model-name", 21 | type=str, 22 | default="VALL-E", 23 | help="VALL-E, VALL-F, Transformer.", 24 | ) 25 | parser.add_argument( 26 | "--decoder-dim", 27 | type=int, 28 | default=1024, 29 | help="Embedding dimension in the decoder model.", 30 | ) 31 | parser.add_argument( 32 | "--nhead", 33 | type=int, 34 | default=16, 35 | help="Number of attention heads in the Decoder layers.", 36 | ) 37 | parser.add_argument( 38 | "--num-decoder-layers", 39 | type=int, 40 | default=12, 41 | help="Number of Decoder layers.", 42 | ) 43 | parser.add_argument( 44 | "--scale-factor", 45 | type=float, 46 | default=1.0, 47 | help="Model scale factor which will be assigned different meanings in different models.", 48 | ) 49 | parser.add_argument( 50 | "--norm-first", 51 | type=str2bool, 52 | default=True, 53 | help="Pre or Post Normalization.", 54 | ) 55 | parser.add_argument( 56 | "--add-prenet", 57 | type=str2bool, 58 | default=False, 59 | help="Whether add PreNet after Inputs.", 60 | ) 61 | 62 | # VALL-E & F 63 | parser.add_argument( 64 | "--prefix-mode", 65 | type=int, 66 | default=0, 67 | help="The mode for how to prefix VALL-E NAR Decoder, " 68 | "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.", 69 | ) 70 | parser.add_argument( 71 | "--share-embedding", 72 | type=str2bool, 73 | default=True, 74 | help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.", 75 | ) 76 | parser.add_argument( 77 | "--prepend-bos", 78 | type=str2bool, 79 | default=False, 80 | help="Whether prepend to the acoustic tokens -> AR Decoder inputs.", 81 | ) 82 | parser.add_argument( 83 | "--num-quantizers", 84 | type=int, 85 | default=8, 86 | help="Number of Audio/Semantic quantization layers.", 87 | ) 88 | 89 | # Transformer 90 | parser.add_argument( 91 | "--scaling-xformers", 92 | type=str2bool, 93 | default=False, 94 | help="Apply Reworked Conformer scaling on Transformers.", 95 | ) 96 | 97 | 98 | def get_model(params: AttributeDict) -> nn.Module: 99 | if params.model_name.lower() in ["vall-f", "vallf"]: 100 | model = VALLF( 101 | params.decoder_dim, 102 | params.nhead, 103 | params.num_decoder_layers, 104 | norm_first=params.norm_first, 105 | add_prenet=params.add_prenet, 106 | prefix_mode=params.prefix_mode, 107 | share_embedding=params.share_embedding, 108 | 
nar_scale_factor=params.scale_factor, 109 | prepend_bos=params.prepend_bos, 110 | num_quantizers=params.num_quantizers, 111 | ) 112 | elif params.model_name.lower() in ["vall-e", "valle"]: 113 | model = VALLE( 114 | params.decoder_dim, 115 | params.nhead, 116 | params.num_decoder_layers, 117 | norm_first=params.norm_first, 118 | add_prenet=params.add_prenet, 119 | prefix_mode=params.prefix_mode, 120 | share_embedding=params.share_embedding, 121 | nar_scale_factor=params.scale_factor, 122 | prepend_bos=params.prepend_bos, 123 | num_quantizers=params.num_quantizers, 124 | ) 125 | else: 126 | assert params.model_name in ["Transformer"] 127 | model = Transformer( 128 | params.decoder_dim, 129 | params.nhead, 130 | params.num_decoder_layers, 131 | norm_first=params.norm_first, 132 | add_prenet=params.add_prenet, 133 | scaling_xformers=params.scaling_xformers, 134 | ) 135 | 136 | return model 137 | -------------------------------------------------------------------------------- /valle/models/macros.py: -------------------------------------------------------------------------------- 1 | # Text 2 | NUM_TEXT_TOKENS = 512 3 | 4 | # Audio 5 | NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins 6 | NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band 7 | 8 | 9 | # Speaker 10 | NUM_SPEAKER_CLASSES = 4096 11 | SPEAKER_EMBEDDING_DIM = 64 12 | -------------------------------------------------------------------------------- /valle/models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from functools import partial 16 | from typing import Any, Dict, List, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from icefall.utils import make_pad_mask 22 | from torchmetrics.classification import BinaryAccuracy 23 | 24 | from valle.models.valle import Transpose 25 | from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding 26 | from valle.modules.scaling import BalancedDoubleSwish, ScaledLinear 27 | from valle.modules.transformer import ( 28 | BalancedBasicNorm, 29 | IdentityNorm, 30 | TransformerDecoderLayer, 31 | TransformerEncoder, 32 | TransformerEncoderLayer, 33 | ) 34 | 35 | from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS 36 | from .visualizer import visualize 37 | 38 | IdentityNorm = IdentityNorm 39 | 40 | 41 | class Transformer(nn.Module): 42 | """It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding) 43 | Neural Speech Synthesis with Transformer Network 44 | https://arxiv.org/abs/1809.08895 45 | """ 46 | 47 | def __init__( 48 | self, 49 | d_model: int, 50 | nhead: int, 51 | num_layers: int, 52 | norm_first: bool = True, 53 | add_prenet: bool = False, 54 | scaling_xformers: bool = False, 55 | ): 56 | """ 57 | Args: 58 | d_model: 59 | The number of expected features in the input (required). 
60 | nhead: 61 | The number of heads in the multiheadattention models (required). 62 | num_layers: 63 | The number of sub-decoder-layers in the decoder (required). 64 | """ 65 | super().__init__() 66 | self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x 67 | 68 | if add_prenet: 69 | self.encoder_prenet = nn.Sequential( 70 | Transpose(), 71 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 72 | nn.BatchNorm1d(d_model), 73 | nn.ReLU(), 74 | nn.Dropout(0.5), 75 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 76 | nn.BatchNorm1d(d_model), 77 | nn.ReLU(), 78 | nn.Dropout(0.5), 79 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 80 | nn.BatchNorm1d(d_model), 81 | nn.ReLU(), 82 | nn.Dropout(0.5), 83 | Transpose(), 84 | nn.Linear(d_model, d_model), 85 | ) 86 | 87 | self.decoder_prenet = nn.Sequential( 88 | nn.Linear(NUM_MEL_BINS, 256), 89 | nn.ReLU(), 90 | nn.Dropout(0.5), 91 | nn.Linear(256, 256), 92 | nn.ReLU(), 93 | nn.Dropout(0.5), 94 | nn.Linear(256, d_model), 95 | ) 96 | 97 | assert scaling_xformers is False # TODO: update this block 98 | else: 99 | self.encoder_prenet = nn.Identity() 100 | if scaling_xformers: 101 | self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model) 102 | else: 103 | self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model) 104 | 105 | self.encoder_position = SinePositionalEmbedding( 106 | d_model, 107 | dropout=0.1, 108 | scale=False, 109 | ) 110 | self.decoder_position = SinePositionalEmbedding( 111 | d_model, dropout=0.1, scale=False 112 | ) 113 | 114 | if scaling_xformers: 115 | self.encoder = TransformerEncoder( 116 | TransformerEncoderLayer( 117 | d_model, 118 | nhead, 119 | dim_feedforward=d_model * 4, 120 | dropout=0.1, 121 | batch_first=True, 122 | norm_first=norm_first, 123 | linear1_self_attention_cls=ScaledLinear, 124 | linear2_self_attention_cls=partial( 125 | ScaledLinear, initial_scale=0.01 126 | ), 127 | linear1_feedforward_cls=ScaledLinear, 128 | linear2_feedforward_cls=partial( 129 | ScaledLinear, initial_scale=0.01 130 | ), 131 | activation=partial( 132 | BalancedDoubleSwish, 133 | channel_dim=-1, 134 | max_abs=10.0, 135 | min_prob=0.25, 136 | ), 137 | layer_norm_cls=IdentityNorm, 138 | ), 139 | num_layers=num_layers, 140 | norm=BalancedBasicNorm(d_model) if norm_first else None, 141 | ) 142 | 143 | self.decoder = nn.TransformerDecoder( 144 | TransformerDecoderLayer( 145 | d_model, 146 | nhead, 147 | dim_feedforward=d_model * 4, 148 | dropout=0.1, 149 | batch_first=True, 150 | norm_first=norm_first, 151 | linear1_self_attention_cls=ScaledLinear, 152 | linear2_self_attention_cls=partial( 153 | ScaledLinear, initial_scale=0.01 154 | ), 155 | linear1_feedforward_cls=ScaledLinear, 156 | linear2_feedforward_cls=partial( 157 | ScaledLinear, initial_scale=0.01 158 | ), 159 | activation=partial( 160 | BalancedDoubleSwish, 161 | channel_dim=-1, 162 | max_abs=10.0, 163 | min_prob=0.25, 164 | ), 165 | layer_norm_cls=IdentityNorm, 166 | ), 167 | num_layers=num_layers, 168 | norm=BalancedBasicNorm(d_model) if norm_first else None, 169 | ) 170 | 171 | self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS) 172 | self.stop_layer = nn.Linear(d_model, 1) 173 | else: 174 | self.encoder = nn.TransformerEncoder( 175 | nn.TransformerEncoderLayer( 176 | d_model, 177 | nhead, 178 | dim_feedforward=d_model * 4, 179 | activation=F.relu, 180 | dropout=0.1, 181 | batch_first=True, 182 | norm_first=norm_first, 183 | ), 184 | num_layers=num_layers, 185 | norm=nn.LayerNorm(d_model) if norm_first else None, 186 | ) 187 | 188 
| self.decoder = nn.TransformerDecoder(
189 | nn.TransformerDecoderLayer(
190 | d_model,
191 | nhead,
192 | dim_feedforward=d_model * 4,
193 | activation=F.relu,
194 | dropout=0.1,
195 | batch_first=True,
196 | norm_first=norm_first,
197 | ),
198 | num_layers=num_layers,
199 | norm=nn.LayerNorm(d_model) if norm_first else None,
200 | )
201 |
202 | self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
203 | self.stop_layer = nn.Linear(d_model, 1)
204 |
205 | self.stop_accuracy_metric = BinaryAccuracy(
206 | threshold=0.5, multidim_average="global"
207 | )
208 |
209 | # self.apply(self._init_weights)
210 |
211 | # def _init_weights(self, module):
212 | # if isinstance(module, (nn.Linear)):
213 | # module.weight.data.normal_(mean=0.0, std=0.02)
214 | # if isinstance(module, nn.Linear) and module.bias is not None:
215 | # module.bias.data.zero_()
216 | # elif isinstance(module, nn.LayerNorm):
217 | # module.bias.data.zero_()
218 | # module.weight.data.fill_(1.0)
219 | # elif isinstance(module, nn.Embedding):
220 | # module.weight.data.normal_(mean=0.0, std=0.02)
221 |
222 | def forward(
223 | self,
224 | x: torch.Tensor,
225 | x_lens: torch.Tensor,
226 | y: torch.Tensor,
227 | y_lens: torch.Tensor,
228 | reduction: str = "sum",
229 | train_stage: int = 0,
230 | **kwargs,
231 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
232 | """
233 | Args:
234 | x:
235 | A 2-D tensor of shape (N, S).
236 | x_lens:
237 | A 1-D tensor of shape (N,). It contains the number of tokens in `x`
238 | before padding.
239 | y:
240 | A 3-D tensor of shape (N, T, NUM_MEL_BINS) with the target mel-spectrogram frames.
241 | y_lens:
242 | A 1-D tensor of shape (N,). It contains the number of frames in `y`
243 | before padding.
244 | train_stage:
245 | Not used in this model.
246 | Returns:
247 | Return the (encoder output, predicted mel-spectrogram) pair, the total loss (frame-level MSE plus the weighted stop-token loss), and a metrics dict with `stop_loss` and `stop_accuracy`.
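        Example (a hedged sketch; shapes follow the Args above, all values are
        otherwise illustrative)::

            model = Transformer(d_model=64, nhead=4, num_layers=4)
            x = torch.randint(0, NUM_TEXT_TOKENS, (4, 8))   # text token ids
            x_lens = torch.tensor([8, 8, 6, 5])
            y = torch.randn(4, 16, NUM_MEL_BINS)            # target mel frames
            y_lens = torch.tensor([16, 12, 10, 9])
            (enc_out, predict), loss, metrics = model(x, x_lens, y, y_lens)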
248 | """ 249 | del train_stage 250 | 251 | assert x.ndim == 2, x.shape 252 | assert x_lens.ndim == 1, x_lens.shape 253 | assert y.ndim == 3, y.shape 254 | assert y_lens.ndim == 1, y_lens.shape 255 | 256 | assert torch.all(x_lens > 0) 257 | 258 | # NOTE: x has been padded in TextTokenCollater 259 | x_mask = make_pad_mask(x_lens).to(x.device) 260 | 261 | x = self.text_embedding(x) 262 | x = self.encoder_prenet(x) 263 | x = self.encoder_position(x) 264 | x = self.encoder(x, src_key_padding_mask=x_mask) 265 | 266 | total_loss, metrics = 0.0, {} 267 | 268 | y_mask = make_pad_mask(y_lens).to(y.device) 269 | y_mask_float = y_mask.type(torch.float32) 270 | data_mask = 1.0 - y_mask_float.unsqueeze(-1) 271 | 272 | # Training 273 | # AR Decoder 274 | def pad_y(y): 275 | y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach() 276 | # inputs, targets 277 | return y[:, :-1], y[:, 1:] 278 | 279 | y, targets = pad_y(y * data_mask) # mask padding as zeros 280 | 281 | y_emb = self.decoder_prenet(y) 282 | y_pos = self.decoder_position(y_emb) 283 | 284 | y_len = y_lens.max() 285 | tgt_mask = torch.triu( 286 | torch.ones(y_len, y_len, device=y.device, dtype=torch.bool), 287 | diagonal=1, 288 | ) 289 | y_dec = self.decoder( 290 | y_pos, 291 | x, 292 | tgt_mask=tgt_mask, 293 | memory_key_padding_mask=x_mask, 294 | ) 295 | 296 | predict = self.predict_layer(y_dec) 297 | # loss 298 | total_loss = F.mse_loss(predict, targets, reduction=reduction) 299 | 300 | logits = self.stop_layer(y_dec).squeeze(-1) 301 | stop_loss = F.binary_cross_entropy_with_logits( 302 | logits, 303 | y_mask_float.detach(), 304 | weight=1.0 + y_mask_float.detach() * 4.0, 305 | reduction=reduction, 306 | ) 307 | metrics["stop_loss"] = stop_loss.detach() 308 | 309 | stop_accuracy = self.stop_accuracy_metric( 310 | (torch.sigmoid(logits) >= 0.5).type(torch.int64), 311 | y_mask.type(torch.int64), 312 | ) 313 | # icefall MetricsTracker.norm_items() 314 | metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type( 315 | torch.float32 316 | ) 317 | 318 | return ((x, predict), total_loss + 100.0 * stop_loss, metrics) 319 | 320 | def inference( 321 | self, 322 | x: torch.Tensor, 323 | x_lens: torch.Tensor, 324 | y: Any = None, 325 | **kwargs, 326 | ) -> torch.Tensor: 327 | """ 328 | Args: 329 | x: 330 | A 2-D tensor of shape (1, S). 331 | x_lens: 332 | A 1-D tensor of shape (1,). It contains the number of tokens in `x` 333 | before padding. 334 | Returns: 335 | Return the predicted audio code matrix and cross-entropy loss. 
336 | """ 337 | assert x.ndim == 2, x.shape 338 | assert x_lens.ndim == 1, x_lens.shape 339 | 340 | assert torch.all(x_lens > 0) 341 | 342 | x_mask = make_pad_mask(x_lens).to(x.device) 343 | 344 | x = self.text_embedding(x) 345 | x = self.encoder_prenet(x) 346 | x = self.encoder_position(x) 347 | x = self.encoder(x, src_key_padding_mask=x_mask) 348 | 349 | x_mask = make_pad_mask(x_lens).to(x.device) 350 | 351 | # AR Decoder 352 | # TODO: Managing decoder steps avoid repetitive computation 353 | y = torch.zeros( 354 | [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device 355 | ) 356 | while True: 357 | y_emb = self.decoder_prenet(y) 358 | y_pos = self.decoder_position(y_emb) 359 | 360 | tgt_mask = torch.triu( 361 | torch.ones( 362 | y.shape[1], y.shape[1], device=y.device, dtype=torch.bool 363 | ), 364 | diagonal=1, 365 | ) 366 | 367 | y_dec = self.decoder( 368 | y_pos, 369 | x, 370 | tgt_mask=tgt_mask, 371 | memory_mask=None, 372 | memory_key_padding_mask=x_mask, 373 | ) 374 | predict = self.predict_layer(y_dec[:, -1:]) 375 | 376 | logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5 377 | if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()): 378 | print( 379 | f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]" 380 | ) 381 | break 382 | 383 | y = torch.concat([y, predict], dim=1) 384 | 385 | return y[:, 1:] 386 | 387 | def visualize( 388 | self, 389 | predicts: Tuple[torch.Tensor], 390 | batch: Dict[str, Union[List, torch.Tensor]], 391 | output_dir: str, 392 | limit: int = 4, 393 | ) -> None: 394 | visualize(predicts, batch, output_dir, limit=limit) 395 | -------------------------------------------------------------------------------- /valle/models/visualizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # See ../../../../LICENSE for clarification regarding multiple authors 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | 19 | from typing import Dict, List, Tuple, Union 20 | 21 | import matplotlib.pyplot as plt 22 | import numpy as np 23 | import torch 24 | 25 | 26 | def visualize( 27 | predicts: Tuple[torch.Tensor], 28 | batch: Dict[str, Union[List, torch.Tensor]], 29 | output_dir: str, 30 | limit: int = 4, 31 | ) -> None: 32 | text_tokens = batch["text_tokens"].to("cpu").detach().numpy() 33 | text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() 34 | audio_features = batch["audio_features"].to("cpu").detach().numpy() 35 | audio_features_lens = ( 36 | batch["audio_features_lens"].to("cpu").detach().numpy() 37 | ) 38 | assert text_tokens.ndim == 2 39 | 40 | utt_ids, texts = batch["utt_id"], batch["text"] 41 | 42 | encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() 43 | decoder_outputs = predicts[1] 44 | if isinstance(decoder_outputs, list): 45 | decoder_outputs = decoder_outputs[-1] 46 | decoder_outputs = ( 47 | decoder_outputs.to("cpu").type(torch.float32).detach().numpy() 48 | ) 49 | 50 | vmin, vmax = 0, 1024 # Encodec 51 | if decoder_outputs.dtype == np.float32: 52 | vmin, vmax = -6, 0 # Fbank 53 | 54 | num_figures = 3 55 | for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])): 56 | _ = plt.figure(figsize=(14, 8 * num_figures)) 57 | 58 | S = text_tokens_lens[b] 59 | T = audio_features_lens[b] 60 | 61 | # encoder 62 | plt.subplot(num_figures, 1, 1) 63 | plt.title(f"Text: {text}") 64 | plt.imshow( 65 | X=np.transpose(encoder_outputs[b]), 66 | cmap=plt.get_cmap("jet"), 67 | aspect="auto", 68 | interpolation="nearest", 69 | ) 70 | plt.gca().invert_yaxis() 71 | plt.axvline(x=S - 0.4, linewidth=2, color="r") 72 | plt.xlabel("Encoder Output") 73 | plt.colorbar() 74 | 75 | # decoder 76 | plt.subplot(num_figures, 1, 2) 77 | plt.imshow( 78 | X=np.transpose(decoder_outputs[b]), 79 | cmap=plt.get_cmap("jet"), 80 | aspect="auto", 81 | interpolation="nearest", 82 | vmin=vmin, 83 | vmax=vmax, 84 | ) 85 | plt.gca().invert_yaxis() 86 | plt.axvline(x=T - 0.4, linewidth=2, color="r") 87 | plt.xlabel("Decoder Output") 88 | plt.colorbar() 89 | 90 | # target 91 | plt.subplot(num_figures, 1, 3) 92 | plt.imshow( 93 | X=np.transpose(audio_features[b]), 94 | cmap=plt.get_cmap("jet"), 95 | aspect="auto", 96 | interpolation="nearest", 97 | vmin=vmin, 98 | vmax=vmax, 99 | ) 100 | plt.gca().invert_yaxis() 101 | plt.axvline(x=T - 0.4, linewidth=2, color="r") 102 | plt.xlabel("Decoder Target") 103 | plt.colorbar() 104 | 105 | plt.savefig(f"{output_dir}/{utt_id}.png") 106 | plt.close() 107 | -------------------------------------------------------------------------------- /valle/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/modules/__init__.py -------------------------------------------------------------------------------- /valle/modules/activation.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Linear, Module 6 | from torch.nn import functional as F 7 | from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ 8 | from torch.nn.modules.linear import NonDynamicallyQuantizableLinear 9 | from torch.nn.parameter import Parameter 10 | 11 | 12 | class MultiheadAttention(Module): 13 | r"""Allows the model to jointly attend to information 14 | from different 
representation subspaces as described in the paper: 15 | `Attention Is All You Need `_. 16 | 17 | Multi-Head Attention is defined as: 18 | 19 | .. math:: 20 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O 21 | 22 | where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. 23 | 24 | ``forward()`` will use a special optimized implementation if all of the following 25 | conditions are met: 26 | 27 | - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This 28 | restriction will be loosened in the future.) 29 | - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` 30 | - training is disabled (using ``.eval()``) 31 | - dropout is 0 32 | - ``add_bias_kv`` is ``False`` 33 | - ``add_zero_attn`` is ``False`` 34 | - ``batch_first`` is ``True`` and the input is batched 35 | - ``kdim`` and ``vdim`` are equal to ``embed_dim`` 36 | - at most one of ``key_padding_mask`` or ``attn_mask`` is passed 37 | - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` 38 | nor ``attn_mask`` is passed 39 | 40 | If the optimized implementation is in use, a 41 | `NestedTensor `_ can be passed for 42 | ``query``/``key``/``value`` to represent padding more efficiently than using a 43 | padding mask. In this case, a `NestedTensor `_ 44 | will be returned, and an additional speedup proportional to the fraction of the input 45 | that is padding can be expected. 46 | 47 | Args: 48 | embed_dim: Total dimension of the model. 49 | num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split 50 | across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). 51 | dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). 52 | bias: If specified, adds bias to input / output projection layers. Default: ``True``. 53 | add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. 54 | add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. 55 | Default: ``False``. 56 | kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). 57 | vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). 58 | batch_first: If ``True``, then the input and output tensors are provided 59 | as (batch, seq, feature). Default: ``False`` (seq, batch, feature). 
60 | 61 | Examples:: 62 | 63 | >>> # xdoctest: +SKIP 64 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 65 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value) 66 | 67 | """ 68 | __constants__ = ["batch_first"] 69 | bias_k: Optional[torch.Tensor] 70 | bias_v: Optional[torch.Tensor] 71 | 72 | def __init__( 73 | self, 74 | embed_dim, 75 | num_heads, 76 | dropout=0.0, 77 | bias=True, 78 | add_bias_kv=False, 79 | add_zero_attn=False, 80 | kdim=None, 81 | vdim=None, 82 | batch_first=False, 83 | linear1_cls=Linear, 84 | linear2_cls=Linear, 85 | device=None, 86 | dtype=None, 87 | ) -> None: 88 | factory_kwargs = {"device": device, "dtype": dtype} 89 | super(MultiheadAttention, self).__init__() 90 | self.embed_dim = embed_dim 91 | self.kdim = kdim if kdim is not None else embed_dim 92 | self.vdim = vdim if vdim is not None else embed_dim 93 | self._qkv_same_embed_dim = ( 94 | self.kdim == embed_dim and self.vdim == embed_dim 95 | ) 96 | 97 | self.num_heads = num_heads 98 | self.dropout = dropout 99 | self.batch_first = batch_first 100 | self.head_dim = embed_dim // num_heads 101 | assert ( 102 | self.head_dim * num_heads == self.embed_dim 103 | ), "embed_dim must be divisible by num_heads" 104 | 105 | if add_bias_kv: 106 | self.bias_k = Parameter( 107 | torch.empty((1, 1, embed_dim), **factory_kwargs) 108 | ) 109 | self.bias_v = Parameter( 110 | torch.empty((1, 1, embed_dim), **factory_kwargs) 111 | ) 112 | else: 113 | self.bias_k = self.bias_v = None 114 | 115 | if linear1_cls == Linear: 116 | if not self._qkv_same_embed_dim: 117 | self.q_proj_weight = Parameter( 118 | torch.empty((embed_dim, embed_dim), **factory_kwargs) 119 | ) 120 | self.k_proj_weight = Parameter( 121 | torch.empty((embed_dim, self.kdim), **factory_kwargs) 122 | ) 123 | self.v_proj_weight = Parameter( 124 | torch.empty((embed_dim, self.vdim), **factory_kwargs) 125 | ) 126 | self.register_parameter("in_proj_weight", None) 127 | else: 128 | self.in_proj_weight = Parameter( 129 | torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) 130 | ) 131 | self.register_parameter("q_proj_weight", None) 132 | self.register_parameter("k_proj_weight", None) 133 | self.register_parameter("v_proj_weight", None) 134 | 135 | if bias: 136 | self.in_proj_bias = Parameter( 137 | torch.empty(3 * embed_dim, **factory_kwargs) 138 | ) 139 | else: 140 | self.register_parameter("in_proj_bias", None) 141 | self.out_proj = NonDynamicallyQuantizableLinear( 142 | embed_dim, embed_dim, bias=bias, **factory_kwargs 143 | ) 144 | 145 | self._reset_parameters() 146 | else: 147 | if not self._qkv_same_embed_dim: 148 | raise NotImplementedError 149 | else: 150 | self.in_proj_linear = linear1_cls( 151 | embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs 152 | ) 153 | self.in_proj_weight = self.in_proj_linear.weight 154 | 155 | self.register_parameter("q_proj_weight", None) 156 | self.register_parameter("k_proj_weight", None) 157 | self.register_parameter("v_proj_weight", None) 158 | 159 | if bias: 160 | self.in_proj_bias = self.in_proj_linear.bias 161 | else: 162 | self.register_parameter("in_proj_bias", None) 163 | 164 | self.out_proj = linear2_cls( 165 | embed_dim, embed_dim, bias=bias, **factory_kwargs 166 | ) 167 | 168 | if self.bias_k is not None: 169 | xavier_normal_(self.bias_k) 170 | if self.bias_v is not None: 171 | xavier_normal_(self.bias_v) 172 | 173 | self.add_zero_attn = add_zero_attn 174 | 175 | def _reset_parameters(self): 176 | if self._qkv_same_embed_dim: 177 | xavier_uniform_(self.in_proj_weight) 178 
| else: 179 | xavier_uniform_(self.q_proj_weight) 180 | xavier_uniform_(self.k_proj_weight) 181 | xavier_uniform_(self.v_proj_weight) 182 | 183 | if self.in_proj_bias is not None: 184 | constant_(self.in_proj_bias, 0.0) 185 | constant_(self.out_proj.bias, 0.0) 186 | 187 | if self.bias_k is not None: 188 | xavier_normal_(self.bias_k) 189 | if self.bias_v is not None: 190 | xavier_normal_(self.bias_v) 191 | 192 | def __setstate__(self, state): 193 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 194 | if "_qkv_same_embed_dim" not in state: 195 | state["_qkv_same_embed_dim"] = True 196 | 197 | super(MultiheadAttention, self).__setstate__(state) 198 | 199 | def forward( 200 | self, 201 | query: Tensor, 202 | key: Tensor, 203 | value: Tensor, 204 | key_padding_mask: Optional[Tensor] = None, 205 | need_weights: bool = True, 206 | attn_mask: Optional[Tensor] = None, 207 | average_attn_weights: bool = True, 208 | ) -> Tuple[Tensor, Optional[Tensor]]: 209 | r""" 210 | Args: 211 | query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` 212 | or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, 213 | :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. 214 | Queries are compared against key-value pairs to produce the output. 215 | See "Attention Is All You Need" for more details. 216 | key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` 217 | or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, 218 | :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. 219 | See "Attention Is All You Need" for more details. 220 | value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when 221 | ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source 222 | sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. 223 | See "Attention Is All You Need" for more details. 224 | key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` 225 | to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. 226 | Binary and byte masks are supported. 227 | For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for 228 | the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. 229 | need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. 230 | Default: ``True``. 231 | attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape 232 | :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, 233 | :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be 234 | broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. 235 | Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the 236 | corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the 237 | corresponding position is not allowed to attend. 
For a float mask, the mask values will be added to 238 | the attention weight. 239 | average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across 240 | heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an 241 | effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) 242 | 243 | Outputs: 244 | - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, 245 | :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, 246 | where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the 247 | embedding dimension ``embed_dim``. 248 | - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, 249 | returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or 250 | :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and 251 | :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per 252 | head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. 253 | 254 | .. note:: 255 | `batch_first` argument is ignored for unbatched inputs. 256 | """ 257 | is_batched = query.dim() == 3 258 | if key_padding_mask is not None: 259 | _kpm_dtype = key_padding_mask.dtype 260 | if _kpm_dtype != torch.bool and not torch.is_floating_point( 261 | key_padding_mask 262 | ): 263 | raise AssertionError( 264 | "only bool and floating types of key_padding_mask are supported" 265 | ) 266 | why_not_fast_path = "" 267 | if not is_batched: 268 | why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" 269 | elif query is not key or key is not value: 270 | # When lifting this restriction, don't forget to either 271 | # enforce that the dtypes all match or test cases where 272 | # they don't! 273 | why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" 274 | elif ( 275 | self.in_proj_bias is not None 276 | and query.dtype != self.in_proj_bias.dtype 277 | ): 278 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" 279 | elif ( 280 | self.in_proj_weight is not None 281 | and query.dtype != self.in_proj_weight.dtype 282 | ): 283 | # this case will fail anyway, but at least they'll get a useful error message. 
284 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" 285 | elif self.training: 286 | why_not_fast_path = "training is enabled" 287 | elif not self.batch_first: 288 | why_not_fast_path = "batch_first was not True" 289 | elif self.bias_k is not None: 290 | why_not_fast_path = "self.bias_k was not None" 291 | elif self.bias_v is not None: 292 | why_not_fast_path = "self.bias_v was not None" 293 | elif self.dropout: 294 | why_not_fast_path = f"dropout was {self.dropout}, required zero" 295 | elif self.add_zero_attn: 296 | why_not_fast_path = "add_zero_attn was enabled" 297 | elif not self._qkv_same_embed_dim: 298 | why_not_fast_path = "_qkv_same_embed_dim was not True" 299 | elif attn_mask is not None: 300 | why_not_fast_path = "attn_mask was not None" 301 | elif query.is_nested and key_padding_mask is not None: 302 | why_not_fast_path = ( 303 | "key_padding_mask is not supported with NestedTensor input" 304 | ) 305 | elif self.num_heads % 2 == 1: 306 | why_not_fast_path = "num_heads is odd" 307 | elif torch.is_autocast_enabled(): 308 | why_not_fast_path = "autocast is enabled" 309 | 310 | if not why_not_fast_path: 311 | tensor_args = ( 312 | query, 313 | key, 314 | value, 315 | self.in_proj_weight, 316 | self.in_proj_bias, 317 | self.out_proj.weight, 318 | self.out_proj.bias, 319 | ) 320 | # We have to use list comprehensions below because TorchScript does not support 321 | # generator expressions. 322 | if torch.overrides.has_torch_function(tensor_args): 323 | why_not_fast_path = "some Tensor argument has_torch_function" 324 | elif not all( 325 | [ 326 | (x is None or x.is_cuda or "cpu" in str(x.device)) 327 | for x in tensor_args 328 | ] 329 | ): 330 | why_not_fast_path = ( 331 | "some Tensor argument is neither CUDA nor CPU" 332 | ) 333 | elif torch.is_grad_enabled() and any( 334 | [x is not None and x.requires_grad for x in tensor_args] 335 | ): 336 | why_not_fast_path = ( 337 | "grad is enabled and at least one of query or the " 338 | "input/output projection weights or biases requires_grad" 339 | ) 340 | if not why_not_fast_path: 341 | return torch._native_multi_head_attention( 342 | query, 343 | key, 344 | value, 345 | self.embed_dim, 346 | self.num_heads, 347 | self.in_proj_weight, 348 | self.in_proj_bias, 349 | self.out_proj.weight, 350 | self.out_proj.bias, 351 | key_padding_mask 352 | if key_padding_mask is not None 353 | else attn_mask, 354 | need_weights, 355 | average_attn_weights, 356 | 1 357 | if key_padding_mask is not None 358 | else 0 359 | if attn_mask is not None 360 | else None, 361 | ) 362 | 363 | any_nested = query.is_nested or key.is_nested or value.is_nested 364 | assert not any_nested, ( 365 | "MultiheadAttention does not support NestedTensor outside of its fast path. 
" 366 | + f"The fast path was not hit because {why_not_fast_path}" 367 | ) 368 | 369 | if self.batch_first and is_batched: 370 | # make sure that the transpose op does not affect the "is" property 371 | if key is value: 372 | if query is key: 373 | query = key = value = query.transpose(1, 0) 374 | else: 375 | query, key = [x.transpose(1, 0) for x in (query, key)] 376 | value = key 377 | else: 378 | query, key, value = [ 379 | x.transpose(1, 0) for x in (query, key, value) 380 | ] 381 | 382 | if not self._qkv_same_embed_dim: 383 | attn_output, attn_output_weights = F.multi_head_attention_forward( 384 | query, 385 | key, 386 | value, 387 | self.embed_dim, 388 | self.num_heads, 389 | self.in_proj_weight, 390 | self.in_proj_bias, 391 | self.bias_k, 392 | self.bias_v, 393 | self.add_zero_attn, 394 | self.dropout, 395 | self.out_proj.weight, 396 | self.out_proj.bias, 397 | training=self.training, 398 | key_padding_mask=key_padding_mask, 399 | need_weights=need_weights, 400 | attn_mask=attn_mask, 401 | use_separate_proj_weight=True, 402 | q_proj_weight=self.q_proj_weight, 403 | k_proj_weight=self.k_proj_weight, 404 | v_proj_weight=self.v_proj_weight, 405 | average_attn_weights=average_attn_weights, 406 | ) 407 | else: 408 | attn_output, attn_output_weights = F.multi_head_attention_forward( 409 | query, 410 | key, 411 | value, 412 | self.embed_dim, 413 | self.num_heads, 414 | self.in_proj_weight, 415 | self.in_proj_bias, 416 | self.bias_k, 417 | self.bias_v, 418 | self.add_zero_attn, 419 | self.dropout, 420 | self.out_proj.weight, 421 | self.out_proj.bias, 422 | training=self.training, 423 | key_padding_mask=key_padding_mask, 424 | need_weights=need_weights, 425 | attn_mask=attn_mask, 426 | average_attn_weights=average_attn_weights, 427 | ) 428 | if self.batch_first and is_batched: 429 | return attn_output.transpose(1, 0), attn_output_weights 430 | else: 431 | return attn_output, attn_output_weights 432 | -------------------------------------------------------------------------------- /valle/modules/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import math 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class TokenEmbedding(nn.Module): 22 | def __init__( 23 | self, 24 | dim_model: int, 25 | vocab_size: int, 26 | dropout: float = 0.0, 27 | ): 28 | super().__init__() 29 | 30 | self.vocab_size = vocab_size 31 | self.dim_model = dim_model 32 | 33 | self.dropout = torch.nn.Dropout(p=dropout) 34 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) 35 | 36 | @property 37 | def weight(self) -> torch.Tensor: 38 | return self.word_embeddings.weight 39 | 40 | def embedding(self, index: int) -> torch.Tensor: 41 | return self.word_embeddings.weight[index : index + 1] 42 | 43 | def forward(self, x: torch.Tensor): 44 | X = self.word_embeddings(x) 45 | X = self.dropout(X) 46 | 47 | return X 48 | 49 | 50 | class SinePositionalEmbedding(nn.Module): 51 | def __init__( 52 | self, 53 | dim_model: int, 54 | dropout: float = 0.0, 55 | scale: bool = False, 56 | alpha: bool = False, 57 | ): 58 | super().__init__() 59 | self.dim_model = dim_model 60 | self.x_scale = math.sqrt(dim_model) if scale else 1.0 61 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 62 | self.dropout = torch.nn.Dropout(p=dropout) 63 | 64 | self.reverse = False 65 | self.pe = None 66 | self.extend_pe(torch.tensor(0.0).expand(1, 4000)) 67 | 68 | def extend_pe(self, x): 69 | """Reset the positional encodings.""" 70 | if self.pe is not None: 71 | if self.pe.size(1) >= x.size(1): 72 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 73 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 74 | return 75 | pe = torch.zeros(x.size(1), self.dim_model) 76 | if self.reverse: 77 | position = torch.arange( 78 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 79 | ).unsqueeze(1) 80 | else: 81 | position = torch.arange( 82 | 0, x.size(1), dtype=torch.float32 83 | ).unsqueeze(1) 84 | div_term = torch.exp( 85 | torch.arange(0, self.dim_model, 2, dtype=torch.float32) 86 | * -(math.log(10000.0) / self.dim_model) 87 | ) 88 | pe[:, 0::2] = torch.sin(position * div_term) 89 | pe[:, 1::2] = torch.cos(position * div_term) 90 | pe = pe.unsqueeze(0) 91 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach() 92 | 93 | def forward(self, x: torch.Tensor) -> torch.Tensor: 94 | self.extend_pe(x) 95 | output = x.unsqueeze(-1) if x.ndim == 2 else x 96 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] 97 | return self.dropout(output) 98 | -------------------------------------------------------------------------------- /valle/modules/scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # See ../../../../LICENSE for clarification regarding multiple authors 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
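# A minimal usage sketch of the Noam schedule defined below (hyper-parameter
# values are illustrative, not the recipe defaults):
#
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
#     scheduler = NoamScheduler(
#         base_lr=0.05, optimizer=optimizer, dim_embed=1024, warmup_steps=200
#     )
#     for batch in loader:
#         ...
#         optimizer.step()
#         scheduler.step()  # lr = base_lr * dim_embed**-0.5 * min(s**-0.5, s * warmup_steps**-1.5)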
17 | 18 | 19 | import torch 20 | 21 | from valle.modules.optim import Eden 22 | 23 | 24 | def calc_lr(step, dim_embed, warmup_steps): 25 | return dim_embed ** (-0.5) * min( 26 | step ** (-0.5), step * warmup_steps ** (-1.5) 27 | ) 28 | 29 | 30 | class NoamScheduler(torch.optim.lr_scheduler._LRScheduler): 31 | def __init__( 32 | self, 33 | base_lr: float, 34 | optimizer: torch.optim.Optimizer, 35 | dim_embed: int, 36 | warmup_steps: int, 37 | last_epoch: int = -1, 38 | verbose: bool = False, 39 | ) -> None: 40 | 41 | self.dim_embed = dim_embed 42 | self.base_lr = base_lr 43 | self.warmup_steps = warmup_steps 44 | self.num_param_groups = len(optimizer.param_groups) 45 | 46 | super().__init__(optimizer, last_epoch, verbose) 47 | 48 | def get_lr(self) -> float: 49 | lr = self.base_lr * calc_lr( 50 | self._step_count, self.dim_embed, self.warmup_steps 51 | ) 52 | return [lr] * self.num_param_groups 53 | 54 | def set_step(self, step: int): 55 | self._step_count = step 56 | 57 | 58 | def get_scheduler(params, optimizer): 59 | if params.scheduler_name.lower() == "eden": 60 | scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps) 61 | elif params.scheduler_name.lower() == "noam": 62 | scheduler = NoamScheduler( 63 | params.base_lr, 64 | optimizer, 65 | params.decoder_dim, 66 | warmup_steps=params.warmup_steps, 67 | ) 68 | # scheduler.set_step(params.start_batch or params.batch_idx_train) 69 | elif params.scheduler_name.lower() == "cosine": 70 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 71 | optimizer, 72 | params.warmup_steps, 73 | eta_min=params.base_lr, 74 | ) 75 | else: 76 | raise NotImplementedError(f"{params.scheduler_name}") 77 | 78 | return scheduler 79 | -------------------------------------------------------------------------------- /valle/tests/data/tokenizer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Zhao Ming) 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import unittest 18 | 19 | from valle.data import TextTokenizer 20 | 21 | 22 | class TestTextTokenizer(unittest.TestCase): 23 | def test_espeak(self): 24 | text_tokenizer = TextTokenizer(backend="espeak") 25 | 26 | for (_input, _target) in [ 27 | ("The two parties, the sheep and the wolves, met each other.", 28 | ['ð', 'ə', '_', 't', 'uː', '_', 'p', 'ɑːɹ', 'ɾ',]), # 'i', 'z', ',', '_', 'ð'] 29 | ("Mother! dear father! 
do you hear me?", 30 | ['m', 'ʌ', 'ð', 'ɚ', '!', '_', 'd', 'ɪɹ', '_', 'f', 'ɑː', 'ð', 'ɚ', '!']), 31 | ("\"Whoever thou art,\" She exclaimed, suddenly seizing Rodolfo's hand,", 32 | ['"', 'h', 'uː', 'ɛ', 'v', 'ɚ', '_', 'ð', 'aʊ', '_', 'ɑːɹ', 't', ',', '"', '_', 'ʃ', 'iː', 33 | '_', 'ɛ', 'k', 's', 'k', 'l', 'eɪ', 'm', 'd', ',', '_', 's', 'ʌ', 'd', 'ə', 'n', 'l', 'i', 34 | '_', 's', 'iː', 'z', 'ɪ', 'ŋ', '_', 'ɹ', 'ə', 'd', 'ɑː', 'l', 'f', 'oʊ', 'z', '_', 'h', 35 | 'æ', 'n', 'd', ',']) 36 | ]: 37 | phonemized = text_tokenizer(_input) 38 | self.assertEqual(phonemized[0][:len(_target)], _target) 39 | 40 | def test_pypinyin(self): 41 | text_tokenizer = TextTokenizer(backend="pypinyin") 42 | 43 | for (_input, _target) in [ 44 | ("你好这是测试", 45 | ["ni3", '-', "hao3", '-', "zhe4", '-', "shi4", '-', "ce4", '-', "shi4"]), 46 | ("\"你好\", 这是测试.", 47 | ["\"", "ni3", '-', "hao3", "\"", ",", '_', "zhe4", '-', "shi4", '-', "ce4", '-', "shi4", "."]), 48 | ("此项 工作 还能 怎么 改进", 49 | ['ci3', '-', 'xiang4', '_', 'gong1', '-', 'zuo4', '_', 50 | 'hai2', '-', 'neng2', '_', 'zen3', '-', 'me5', '_', 'gai3', '-', 'jin4']), # AISHELL 51 | ]: 52 | phonemized = text_tokenizer(_input) 53 | self.assertEqual(phonemized[0], _target) 54 | 55 | def test_pypinyin_initials_finals(self): 56 | text_tokenizer = TextTokenizer(backend="pypinyin_initials_finals") 57 | 58 | for (_input, _target) in [ 59 | ("你好这是测试", 60 | ["n", "i3", "-", "h", "ao3", "-", "zh", "e4", "-", "sh", "i4", "-", "c", "e4", "-", "sh", "i4"], 61 | ), 62 | ("\"你好.这是测试.", 63 | ["\"", "n", "i3", "-", "h", "ao3", ".", "zh", "e4", "-", "sh", "i4", "-", "c", "e4", "-", "sh", "i4", "."], 64 | ), 65 | ("\"你好. 这是测试.", 66 | ["\"", "n", "i3", "-", "h", "ao3", ".", "_", "zh", "e4", "-", "sh", "i4", "-", "c", "e4", "-", "sh", "i4", "."], 67 | ), 68 | ("此项 工作 还能 怎么 改进", ['c', 'i3', '-', 'x', 'iang4', '_', 'g', 'ong1', '-', 'z', 'uo4', '_', 69 | 'h', 'ai2', '-', 'n', 'eng2', '_', 'z', 'en3', '-', 'm', 'e5', '_', 70 | 'g', 'ai3', '-', 'j', 'in4']), # AISHELL 71 | ]: 72 | phonemized = text_tokenizer(_input) 73 | self.assertListEqual(phonemized[0], _target) 74 | 75 | 76 | if __name__ == "__main__": 77 | unittest.main() 78 | -------------------------------------------------------------------------------- /valle/tests/scaling_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | 4 | import unittest 5 | 6 | import numpy as np 7 | import torch 8 | from icefall.utils import AttributeDict 9 | 10 | from valle.models import NUM_MEL_BINS, get_model 11 | 12 | 13 | class TestModel(unittest.TestCase): 14 | @classmethod 15 | def setUpClass(cls): 16 | cls.devices = [torch.device("cpu")] 17 | if torch.cuda.is_available(): 18 | cls.devices.append(torch.device("cuda", 0)) 19 | if torch.cuda.device_count() > 1: 20 | torch.cuda.set_device(1) 21 | cls.devices.append(torch.device("cuda", 1)) 22 | 23 | def test_scaling_transformer(self): 24 | params = AttributeDict() 25 | params.decoder_dim = 64 26 | params.nhead = 4 27 | params.num_decoder_layers = 4 28 | 29 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8])) 30 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4])) 31 | x_lens[-1] = 8 32 | 33 | y = torch.from_numpy( 34 | np.random.random((4, 16, NUM_MEL_BINS)).astype(np.float32) 35 | ) 36 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4])) 37 | y_lens[-1] = 16 38 | 39 | params.model_name = "Transformer" 40 | params.norm_first = False 41 | params.add_prenet = False 42 | 
params.scaling_xformers = True 43 | 44 | for device in self.devices: 45 | # Transformer 46 | model = get_model(params) 47 | num_param = sum([p.numel() for p in model.parameters()]) 48 | 49 | model.to(device) 50 | x = x.to(device) 51 | x_lens = x_lens.to(device) 52 | y = y.to(device) 53 | y_lens = y_lens.to(device) 54 | 55 | # Training 56 | codes, loss, metrics = model(x, x_lens, y, y_lens) 57 | # Inference 58 | model.eval() 59 | codes = model.inference(x[-1:], x_lens[-1:]) 60 | params.add_prenet = False 61 | 62 | 63 | if __name__ == "__main__": 64 | unittest.main() 65 | -------------------------------------------------------------------------------- /valle/tests/valle_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import unittest 17 | 18 | import numpy as np 19 | import torch 20 | from icefall.utils import AttributeDict 21 | from torchmetrics.classification import MulticlassAccuracy 22 | 23 | from valle.data.input_strategies import PromptedFeatures 24 | from valle.models import NUM_MEL_BINS, get_model 25 | 26 | 27 | class TestModel(unittest.TestCase): 28 | @classmethod 29 | def setUpClass(cls): 30 | cls.devices = [torch.device("cpu")] 31 | if torch.cuda.is_available(): 32 | cls.devices.append(torch.device("cuda", 0)) 33 | if torch.cuda.device_count() > 1: 34 | torch.cuda.set_device(1) 35 | cls.devices.append(torch.device("cuda", 1)) 36 | 37 | def test_vallf(self): 38 | params = AttributeDict() 39 | params.decoder_dim = 64 40 | params.nhead = 16 41 | params.num_decoder_layers = 4 42 | 43 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8])) 44 | x_lens = torch.from_numpy(np.random.randint(6, 8, size=[4])) 45 | x_lens[-1] = 8 46 | enroll_x_lens = torch.from_numpy(np.random.randint(2, 4, size=[4])) 47 | 48 | y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8])) 49 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4])) 50 | y_lens[-1] = 16 51 | 52 | params.norm_first = True 53 | params.add_prenet = False 54 | params.model_name = "VALL-F" 55 | params.share_embedding = True 56 | params.scale_factor = 1.0 57 | params.prepend_bos = True 58 | params.num_quantizers = 1 59 | 60 | for device in self.devices: 61 | for mode in [0, 1, 2]: 62 | params.prefix_mode = mode 63 | # VALL-E 64 | model = get_model(params) 65 | 66 | # VALL-F 67 | model.to(device) 68 | x = x.to(device) 69 | x_lens = x_lens.to(device) 70 | y = y.to(device) 71 | y_lens = y_lens.to(device) 72 | 73 | # Training 74 | for train_stage in [0, 1, 2]: 75 | codes, loss, metrics = model( 76 | x, x_lens, y, y_lens, train_stage=train_stage 77 | ) 78 | 79 | # Inference 80 | model.eval() 81 | codes = model.inference( 82 | x[-1:], 83 | x_lens[-1:], 84 | y[-1:], 85 | enroll_x_lens=enroll_x_lens[-1:], 86 | ) 87 | 88 | params.prepend_bos = not params.prepend_bos 89 | params.num_quantizers += 1 90 | 91 | def test_valle(self): 92 | params = 
AttributeDict() 93 | params.decoder_dim = 64 94 | params.nhead = 16 95 | params.num_decoder_layers = 4 96 | 97 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8])) 98 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4])) 99 | x_lens[-1] = 8 100 | enroll_x_lens = torch.from_numpy(np.random.randint(1, 3, size=[4])) 101 | 102 | y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8])) 103 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4])) 104 | y_lens[-1] = 16 105 | 106 | params.norm_first = False 107 | params.add_prenet = True 108 | params.model_name = "VALL-E" 109 | params.share_embedding = True 110 | params.scale_factor = 1.0 111 | params.prepend_bos = False 112 | params.num_quantizers = 8 113 | 114 | for device in self.devices: 115 | for mode in [0, 1, 2]: 116 | params.prefix_mode = mode 117 | # VALL-E 118 | model = get_model(params) 119 | model.to(device) 120 | x = x.to(device) 121 | x_lens = x_lens.to(device) 122 | y = y.to(device) 123 | y_lens = y_lens.to(device) 124 | 125 | # Training 126 | codes, loss, metrics = model(x, x_lens, y, y_lens) 127 | # Inference 128 | model.eval() 129 | codes = model.inference( 130 | x[-1:], x_lens[-1:], y[-1:], enroll_x_lens=enroll_x_lens 131 | ) 132 | params.scale_factor = 0.5 133 | 134 | params.prepend_bos = not params.prepend_bos 135 | params.num_quantizers -= 1 136 | 137 | def test_vallef_prefix4(self): 138 | params = AttributeDict() 139 | params.decoder_dim = 64 140 | params.nhead = 16 141 | params.num_decoder_layers = 4 142 | 143 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8])) 144 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4])) 145 | x_lens[-1] = 8 146 | enroll_x_lens = torch.from_numpy(np.random.randint(1, 3, size=[4])) 147 | 148 | y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8])) 149 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4])) 150 | y_lens[-1] = 16 151 | 152 | prompts = torch.from_numpy(np.random.randint(0, 1000, size=[4, 12, 8])) 153 | prompts_lens = torch.from_numpy(np.random.randint(12, 13, size=[4])) 154 | 155 | params.norm_first = False 156 | params.add_prenet = True 157 | params.share_embedding = False 158 | params.scale_factor = 1.0 159 | params.prepend_bos = False 160 | params.num_quantizers = 8 161 | 162 | for device in self.devices: 163 | for model_name in ["VALL-E", "VALL-F"]: 164 | for mode in [4]: 165 | params.prefix_mode = mode 166 | params.model_name = model_name 167 | # VALL-E 168 | model = get_model(params) 169 | model.to(device) 170 | x = x.to(device) 171 | x_lens = x_lens.to(device) 172 | y = y.to(device) 173 | 174 | _y = PromptedFeatures(prompts, y).to(device) 175 | _y_lens = PromptedFeatures(prompts_lens, y_lens).to(device) 176 | 177 | # Training 178 | codes, loss, metrics = model(x, x_lens, _y, _y_lens) 179 | # Inference 180 | model.eval() 181 | codes = model.inference( 182 | x[-1:], x_lens[-1:], y[-1:], enroll_x_lens=enroll_x_lens 183 | ) 184 | 185 | def test_topmetric(self): 186 | metric_top10 = MulticlassAccuracy(1024, top_k=10, average="micro") 187 | metric_top1 = MulticlassAccuracy(1024, top_k=1, average="micro") 188 | batch_size, seq_len = 4, 16 189 | targets = np.random.randint(0, 1000, size=[batch_size, seq_len]) 190 | logits = np.random.random([batch_size, 1024, seq_len]).astype( 191 | np.float32 192 | ) 193 | 194 | larger_logits = np.clip(logits, -1.0, 1.0) 195 | smaller_logits = np.clip(logits, -1.0, 1.0) 196 | for b in range(batch_size): 197 | for t in range(seq_len): 198 | assert targets[b, t] >= 0 199 | 
larger_logits[b, targets[b, t], t] = 2.0 200 | smaller_logits[b, targets[b, t], t] = -2.0 201 | 202 | targets = torch.from_numpy(targets) 203 | larger_logits = torch.from_numpy(larger_logits) 204 | smaller_logits = torch.from_numpy(smaller_logits) 205 | 206 | for device in self.devices: 207 | metric_top10.to(device) 208 | metric_top1.to(device) 209 | targets = targets.to(device) 210 | 211 | one = metric_top10(larger_logits.to(device), targets) 212 | assert one.cpu().item() == 1.0, one.cpu().item() 213 | 214 | zero = metric_top1(smaller_logits.to(device), targets) 215 | assert zero.cpu().item() == 0.0, zero.cpu().item() 216 | 217 | half = metric_top1( 218 | torch.concat( 219 | [smaller_logits.to(device), larger_logits.to(device)], dim=2 220 | ), 221 | torch.concat([targets, targets], dim=1), 222 | ) 223 | assert half.cpu().item() == 0.5, half.cpu().item() 224 | 225 | def test_transformer(self): 226 | params = AttributeDict() 227 | params.decoder_dim = 64 228 | params.nhead = 4 229 | params.num_decoder_layers = 4 230 | 231 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8])) 232 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4])) 233 | x_lens[-1] = 8 234 | 235 | y = torch.from_numpy( 236 | np.random.random((4, 16, NUM_MEL_BINS)).astype(np.float32) 237 | ) 238 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4])) 239 | y_lens[-1] = 16 240 | 241 | params.model_name = "Transformer" 242 | params.norm_first = False 243 | params.add_prenet = True 244 | params.scaling_xformers = False 245 | 246 | for device in self.devices: 247 | # Transformer 248 | model = get_model(params) 249 | num_param = sum([p.numel() for p in model.parameters()]) 250 | 251 | model.to(device) 252 | x = x.to(device) 253 | x_lens = x_lens.to(device) 254 | y = y.to(device) 255 | y_lens = y_lens.to(device) 256 | 257 | # Training 258 | codes, loss, metrics = model(x, x_lens, y, y_lens) 259 | # Inference 260 | model.eval() 261 | codes = model.inference(x[-1:], x_lens[-1:]) 262 | params.add_prenet = False 263 | 264 | params.scaling_xformers = not params.scaling_xformers 265 | 266 | 267 | if __name__ == "__main__": 268 | unittest.main() 269 | -------------------------------------------------------------------------------- /valle/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from icefall.utils import make_pad_mask 4 | 5 | from .symbol_table import SymbolTable 6 | 7 | make_pad_mask = make_pad_mask 8 | SymbolTable = SymbolTable 9 | 10 | 11 | class Transpose(nn.Identity): 12 | """(N, T, D) -> (N, D, T)""" 13 | 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | return input.transpose(1, 2) 16 | -------------------------------------------------------------------------------- /valle/utils/symbol_table.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) 2 | # 3 | # See ../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from dataclasses import dataclass 18 | from dataclasses import field 19 | from typing import Dict 20 | from typing import Generic 21 | from typing import List 22 | from typing import Optional 23 | from typing import TypeVar 24 | from typing import Union 25 | 26 | Symbol = TypeVar('Symbol') 27 | 28 | 29 | # Disable __repr__ otherwise it could freeze e.g. Jupyter. 30 | @dataclass(repr=False) 31 | class SymbolTable(Generic[Symbol]): 32 | '''SymbolTable that maps symbol IDs, found on the FSA arcs to 33 | actual objects. These objects can be arbitrary Python objects 34 | that can serve as keys in a dictionary (i.e. they need to be 35 | hashable and immutable). 36 | 37 | The SymbolTable can only be read to/written from disk if the 38 | symbols are strings. 39 | ''' 40 | _id2sym: Dict[int, Symbol] = field(default_factory=dict) 41 | '''Map an integer to a symbol. 42 | ''' 43 | 44 | _sym2id: Dict[Symbol, int] = field(default_factory=dict) 45 | '''Map a symbol to an integer. 46 | ''' 47 | 48 | _next_available_id: int = 1 49 | '''A helper internal field that helps adding new symbols 50 | to the table efficiently. 51 | ''' 52 | 53 | eps: Symbol = '' 54 | '''Null symbol, always mapped to index 0. 55 | ''' 56 | 57 | def __post_init__(self): 58 | for idx, sym in self._id2sym.items(): 59 | assert self._sym2id[sym] == idx 60 | assert idx >= 0 61 | 62 | for sym, idx in self._sym2id.items(): 63 | assert idx >= 0 64 | assert self._id2sym[idx] == sym 65 | 66 | if 0 not in self._id2sym: 67 | self._id2sym[0] = self.eps 68 | self._sym2id[self.eps] = 0 69 | else: 70 | assert self._id2sym[0] == self.eps 71 | assert self._sym2id[self.eps] == 0 72 | 73 | self._next_available_id = max(self._id2sym) + 1 74 | 75 | @staticmethod 76 | def from_str(s: str) -> 'SymbolTable': 77 | '''Build a symbol table from a string. 78 | 79 | The string consists of lines. Every line has two fields separated 80 | by space(s), tab(s) or both. The first field is the symbol and the 81 | second the integer id of the symbol. 82 | 83 | Args: 84 | s: 85 | The input string with the format described above. 86 | Returns: 87 | An instance of :class:`SymbolTable`. 88 | ''' 89 | id2sym: Dict[int, str] = dict() 90 | sym2id: Dict[str, int] = dict() 91 | 92 | for line in s.split('\n'): 93 | fields = line.split() 94 | if len(fields) == 0: 95 | continue # skip empty lines 96 | assert len(fields) == 2, \ 97 | f'Expect a line with 2 fields. Given: {len(fields)}' 98 | sym, idx = fields[0], int(fields[1]) 99 | assert sym not in sym2id, f'Duplicated symbol {sym}' 100 | assert idx not in id2sym, f'Duplicated id {idx}' 101 | id2sym[idx] = sym 102 | sym2id[sym] = idx 103 | 104 | eps = id2sym.get(0, '') 105 | 106 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps) 107 | 108 | @staticmethod 109 | def from_file(filename: str) -> 'SymbolTable': 110 | '''Build a symbol table from file. 111 | 112 | Every line in the symbol table file has two fields separated by 113 | space(s), tab(s) or both. The following is an example file: 114 | 115 | .. code-block:: 116 | 117 | 0 118 | a 1 119 | b 2 120 | c 3 121 | 122 | Args: 123 | filename: 124 | Name of the symbol table file. Its format is documented above. 125 | 126 | Returns: 127 | An instance of :class:`SymbolTable`. 
128 | 129 | ''' 130 | with open(filename, 'r', encoding='utf-8') as f: 131 | return SymbolTable.from_str(f.read().strip()) 132 | 133 | def to_str(self) -> str: 134 | ''' 135 | Returns: 136 | Return a string representation of this object. You can pass 137 | it to the method ``from_str`` to recreate an identical object. 138 | ''' 139 | s = '' 140 | for idx, symbol in sorted(self._id2sym.items()): 141 | s += f'{symbol} {idx}\n' 142 | return s 143 | 144 | def to_file(self, filename: str): 145 | '''Serialize the SymbolTable to a file. 146 | 147 | Every line in the symbol table file has two fields separated by 148 | space(s), tab(s) or both. The following is an example file: 149 | 150 | .. code-block:: 151 | 152 | 0 153 | a 1 154 | b 2 155 | c 3 156 | 157 | Args: 158 | filename: 159 | Name of the symbol table file. Its format is documented above. 160 | ''' 161 | with open(filename, 'w') as f: 162 | for idx, symbol in sorted(self._id2sym.items()): 163 | print(symbol, idx, file=f) 164 | 165 | def add(self, symbol: Symbol, index: Optional[int] = None) -> int: 166 | '''Add a new symbol to the SymbolTable. 167 | 168 | Args: 169 | symbol: 170 | The symbol to be added. 171 | index: 172 | Optional int id to which the symbol should be assigned. 173 | If it is not available, a ValueError will be raised. 174 | 175 | Returns: 176 | The int id to which the symbol has been assigned. 177 | ''' 178 | # Already in the table? Return its ID. 179 | if symbol in self._sym2id: 180 | return self._sym2id[symbol] 181 | # Specific ID not provided - use next available. 182 | if index is None: 183 | index = self._next_available_id 184 | # Specific ID provided but not available. 185 | if index in self._id2sym: 186 | raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - " 187 | f"already occupied by {self._id2sym[index]}") 188 | self._sym2id[symbol] = index 189 | self._id2sym[index] = symbol 190 | 191 | # Update next available ID if needed 192 | if self._next_available_id <= index: 193 | self._next_available_id = index + 1 194 | 195 | return index 196 | 197 | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]: 198 | '''Get a symbol for an id or get an id for a symbol 199 | 200 | Args: 201 | k: 202 | If it is an id, it tries to find the symbol corresponding 203 | to the id; if it is a symbol, it tries to find the id 204 | corresponding to the symbol. 205 | 206 | Returns: 207 | An id or a symbol depending on the given `k`. 208 | ''' 209 | if isinstance(k, int): 210 | return self._id2sym[k] 211 | else: 212 | return self._sym2id[k] 213 | 214 | def merge(self, other: 'SymbolTable') -> 'SymbolTable': 215 | '''Create a union of two SymbolTables. 216 | Raises an AssertionError if the same IDs are occupied by 217 | different symbols. 218 | 219 | Args: 220 | other: 221 | A symbol table to merge with ``self``. 222 | 223 | Returns: 224 | A new symbol table. 
225 | ''' 226 | self._check_compatible(other) 227 | 228 | id2sym = {**self._id2sym, **other._id2sym} 229 | sym2id = {**self._sym2id, **other._sym2id} 230 | 231 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps) 232 | 233 | def _check_compatible(self, other: 'SymbolTable') -> None: 234 | # Epsilon compatibility 235 | assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \ 236 | f'{self.eps} != {other.eps}' 237 | # IDs compatibility 238 | common_ids = set(self._id2sym).intersection(other._id2sym) 239 | for idx in common_ids: 240 | assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \ 241 | f'self[idx] = "{self[idx]}", ' \ 242 | f'other[idx] = "{other[idx]}"' 243 | # Symbols compatibility 244 | common_symbols = set(self._sym2id).intersection(other._sym2id) 245 | for sym in common_symbols: 246 | assert self[sym] == other[sym], f'Symbol conflict for symbol: {sym}, ' \ 247 | f'self[sym] = "{self[sym]}", ' \ 248 | f'other[sym] = "{other[sym]}"' 249 | 250 | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]: 251 | return self.get(item) 252 | 253 | def __contains__(self, item: Union[int, Symbol]) -> bool: 254 | if isinstance(item, int): 255 | return item in self._id2sym 256 | else: 257 | return item in self._sym2id 258 | 259 | def __len__(self) -> int: 260 | return len(self._id2sym) 261 | 262 | def __eq__(self, other: 'SymbolTable') -> bool: 263 | if len(self) != len(other): 264 | return False 265 | 266 | for s in self.symbols: 267 | if self[s] != other[s]: 268 | return False 269 | 270 | return True 271 | 272 | @property 273 | def ids(self) -> List[int]: 274 | '''Returns a list of integer IDs corresponding to the symbols. 275 | ''' 276 | ans = list(self._id2sym.keys()) 277 | ans.sort() 278 | return ans 279 | 280 | @property 281 | def symbols(self) -> List[Symbol]: 282 | '''Returns a list of symbols (e.g., strings) corresponding to 283 | the integer IDs. 284 | ''' 285 | ans = list(self._sym2id.keys()) 286 | ans.sort() 287 | return ans 288 | --------------------------------------------------------------------------------
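Usage sketches (illustrative only, not files from the repository). The three short examples below exercise modules listed above in the order they appear in this section. They assume the valle package and its dependencies (torch, etc.) are installed and importable; every concrete size and value is an arbitrary placeholder, not a setting taken from the recipes.

First, the embedding modules from valle/modules/embedding.py: TokenEmbedding maps integer token IDs to dense vectors, and SinePositionalEmbedding adds a fixed sinusoidal positional code, optionally scaling the input by sqrt(dim_model).

import torch

from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding

# Token IDs -> dense vectors, then add sinusoidal positional information.
emb = TokenEmbedding(dim_model=64, vocab_size=100, dropout=0.0)
pos = SinePositionalEmbedding(dim_model=64, dropout=0.0, scale=True, alpha=False)

tokens = torch.randint(0, 100, (2, 10))  # (batch, time) integer token IDs
x = emb(tokens)                          # (2, 10, 64)
y = pos(x)                               # x * sqrt(64) + positional code, same shape
print(y.shape)                           # torch.Size([2, 10, 64])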
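Next, the NoamScheduler from valle/modules/scheduler.py. Its rate follows base_lr * dim_embed ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5): a linear warm-up for warmup_steps steps, then an inverse-square-root decay. The toy model and the hyper-parameter values below are assumptions for illustration only.

import torch

from valle.modules.scheduler import NoamScheduler

# A toy module so the optimizer has parameters to manage.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)

scheduler = NoamScheduler(
    base_lr=0.05,        # multiplier applied on top of the Noam curve
    optimizer=optimizer,
    dim_embed=1024,      # typically the decoder dimension
    warmup_steps=200,
)

for step in range(1, 401):
    optimizer.step()     # a no-op here (no gradients), shown for the usual ordering
    scheduler.step()     # advances the internal step count and updates the LR
    if step in (1, 100, 200, 400):
        print(step, scheduler.get_last_lr()[0])
# The printed rate rises until step 200 (warmup_steps) and decays afterwards.

set_step() lets a caller jump the schedule to a given step; the commented-out call in get_scheduler suggests it was intended for resuming from a checkpointed batch index rather than restarting the warm-up.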
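Finally, the SymbolTable class from valle/utils/symbol_table.py (also re-exported by valle.utils). It maps symbols to integer IDs and back, with ID 0 reserved for the epsilon symbol; the symbols and IDs below are arbitrary.

from valle.utils.symbol_table import SymbolTable

# Build a table from text; __post_init__ reserves id 0 for eps ('').
table = SymbolTable.from_str("a 1\nb 2")
assert table["a"] == 1 and table[2] == "b"   # __getitem__ maps both ways
assert "a" in table and 0 in table

# add() returns the assigned id and keeps the next free id consistent.
assert table.add("c") == 3
assert table.add("z", index=10) == 10
assert table.add("d") == 11                  # continues after the explicit id

# merge() unions two tables; _check_compatible asserts that shared ids and
# shared symbols agree before the union is built.
other = SymbolTable.from_str("e 20")
merged = table.merge(other)
assert merged["e"] == 20 and merged["a"] == 1
print(len(merged), merged.ids)               # 7 [0, 1, 2, 3, 10, 11, 20]

Because the dataclass is declared with repr=False, printing a table directly is deliberately uninformative; to_str() and to_file() are the intended serialization paths.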