├── .flake8
├── .github
│   ├── FUNDING.yml
│   ├── guide.jpeg
│   └── sponsor.jpg
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── README.zh-CN.md
├── docker
│   └── Dockerfile
├── docs
│   ├── images
│   │   ├── Overview.jpg
│   │   ├── train.png
│   │   └── vallf.png
│   └── requirements.txt
├── egs
│   ├── aishell1
│   │   ├── README.md
│   │   ├── bin
│   │   ├── demos
│   │   │   └── 0_demo.wav
│   │   ├── prepare.sh
│   │   ├── prompts
│   │   │   ├── ch_24k.txt
│   │   │   ├── ch_24k.wav
│   │   │   └── ch_24k_loudness_normalized20.wav
│   │   └── shared
│   ├── libritts
│   │   ├── README.md
│   │   ├── bin
│   │   ├── prepare.sh
│   │   ├── prompts
│   │   │   ├── 8455_210777_000067_000000.txt
│   │   │   ├── 8455_210777_000067_000000.wav
│   │   │   ├── 8463_294825_000043_000000.txt
│   │   │   └── 8463_294825_000043_000000.wav
│   │   └── shared
│   │       └── parse_options.sh
│   └── ljspeech
│       ├── README.md
│       ├── bin
│       ├── demos
│       │   ├── 0.wav
│       │   └── 1.wav
│       ├── prepare.sh
│       ├── prompts
│       │   ├── LJ049-0108_24K.txt
│       │   ├── LJ049-0108_24K.wav
│       │   ├── LJ049-0110_24K.txt
│       │   ├── LJ049-0110_24K.wav
│       │   ├── LJ049-0124.txt
│       │   ├── LJ049-0124_24K.wav
│       │   ├── LJ049-0185.txt
│       │   └── LJ049-0185_24K.wav
│       └── shared
├── examples
├── setup.py
├── test.sh
└── valle
    ├── __init__.py
    ├── bin
    │   ├── __init__.py
    │   ├── display_manifest_statistics.py
    │   ├── infer.py
    │   ├── tokenizer.py
    │   └── trainer.py
    ├── data
    │   ├── __init__.py
    │   ├── collation.py
    │   ├── datamodule.py
    │   ├── dataset.py
    │   ├── fbank.py
    │   ├── input_strategies.py
    │   └── tokenizer.py
    ├── models
    │   ├── __init__.py
    │   ├── macros.py
    │   ├── transformer.py
    │   ├── valle.py
    │   └── visualizer.py
    ├── modules
    │   ├── __init__.py
    │   ├── activation.py
    │   ├── embedding.py
    │   ├── optim.py
    │   ├── scaling.py
    │   ├── scheduler.py
    │   └── transformer.py
    ├── tests
    │   ├── data
    │   │   └── tokenizer_test.py
    │   ├── scaling_test.py
    │   └── valle_test.py
    └── utils
        ├── __init__.py
        └── symbol_table.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | show-source=true
3 | statistics=true
4 | max-line-length = 80
5 | per-file-ignores =
6 | # line too long
7 | valle/bin/trainer.py: E501, E402,
8 | valle/bin/infer.py: E501, E402,
9 | valle/bin/tokenizer.py: E501, E402,
10 | valle/models/*.py: E501, E203,
11 | valle/models/__init__.py: F401, E501
12 | valle/utils/__init__.py: F401,
13 | valle/__init__.py: F401,
14 | valle/modules/__init__.py: F401,
15 | valle/modules/activation.py: E501,
16 | valle/modules/embedding.py: E203,
17 | valle/modules/scheduler.py: F401,
18 | valle/modules/transformer.py: E501,
19 | valle/modules/scaling.py: E501, F841, F401
20 | valle/data/fbank.py: E501,
21 | valle/data/input_strategies.py: E501,
22 | valle/data/datamodule.py: E501,
23 | valle/tests/*.py: F841, F401
24 | valle/tests/data/*.py: F841, F401
25 |
26 | exclude =
27 | .git,
28 | egs/**/data/**,
29 | setup.py,
30 | valle/utils/symbol_table.py,
31 | valle/modules/optim.py,
32 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: lifeiteng
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: ['https://github.com/lifeiteng/SoundStorm/blob/master/.github/sponsor.jpg']
14 |
--------------------------------------------------------------------------------
/.github/guide.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/.github/guide.jpeg
--------------------------------------------------------------------------------
/.github/sponsor.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/.github/sponsor.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | valle.egg-info/
2 | __pycache__
3 | path.sh
4 | exp
5 | exp*/
6 | *.pt
7 | log
8 | .DS_Store
9 | .vscode
10 | egs/*/data
11 | egs/libritts/prompts/*.png
12 | valle/version.py
13 | infer/
14 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 21.6b0
4 | hooks:
5 | - id: black
6 | args: [--line-length=80]
7 | additional_dependencies: ['click==8.0.1']
8 | exclude: valle\/__init__\.py
9 | exclude: valle\/utils\/symbol_table\.py
10 | exclude: valle\/modules\/scaling\.py
11 | exclude: valle\/tests\/data\/tokenizer\_test\.py
12 |
13 | - repo: https://github.com/PyCQA/flake8
14 | rev: 3.9.2
15 | hooks:
16 | - id: flake8
17 | args: [--max-line-length=80]
18 | exclude: valle\/modules\/scaling\.py
19 | exclude: valle\/tests\/data\/tokenizer\_test\.py
20 |
21 | - repo: https://github.com/pycqa/isort
22 | rev: 5.12.0
23 | hooks:
24 | - id: isort
25 | args: [--profile=black, --line-length=80]
26 | exclude: valle\/utils\/symbol_table\.py
27 | exclude: valle\/modules\/scaling\.py
28 |
29 | - repo: https://github.com/pre-commit/pre-commit-hooks
30 | rev: v4.0.1
31 | hooks:
32 | - id: check-executables-have-shebangs
33 | - id: end-of-file-fixer
34 | - id: mixed-line-ending
35 | - id: trailing-whitespace
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Language : 🇺🇸 | [🇨🇳](./README.zh-CN.md)
2 |
3 | An unofficial PyTorch implementation of VALL-E ([Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers](https://arxiv.org/abs/2301.02111)).
4 |
5 | We can train the VALL-E model on one GPU.
6 |
7 | 
8 |
9 | ## Demo
10 |
11 | * [official demo](https://valle-demo.github.io/)
12 | * [reproduced demo](https://lifeiteng.github.io/valle/index.html)
13 |
14 |
15 |
16 |
17 |
18 |
19 | ## Broader impacts
20 |
21 | > Since VALL-E could synthesize speech that maintains speaker identity, it may carry potential risks in misuse of the model, such as spoofing voice identification or impersonating a specific speaker.
22 |
23 | To avoid abuse, well-trained models and services will not be provided.
24 |
25 | ## Install Deps
26 |
27 | To get up and running quickly, just follow the steps below:
28 |
29 | ```
30 | # PyTorch
31 | pip install torch==1.13.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
32 | pip install torchmetrics==0.11.1
33 | # fbank
34 | pip install librosa==0.8.1
35 |
36 | # phonemizer pypinyin
37 | apt-get install espeak-ng
38 | ## OSX: brew install espeak
39 | pip install phonemizer==3.2.1 pypinyin==0.48.0
40 |
41 | # lhotse update to newest version
42 | # https://github.com/lhotse-speech/lhotse/pull/956
43 | # https://github.com/lhotse-speech/lhotse/pull/960
44 | pip uninstall lhotse
45 | pip uninstall lhotse
46 | pip install git+https://github.com/lhotse-speech/lhotse
47 |
48 | # k2
49 | # find the right version in https://huggingface.co/csukuangfj/k2
50 | pip install https://huggingface.co/csukuangfj/k2/resolve/main/cuda/k2-1.23.4.dev20230224+cuda11.6.torch1.13.1-cp310-cp310-linux_x86_64.whl
51 |
52 | # icefall
53 | git clone https://github.com/k2-fsa/icefall
54 | cd icefall
55 | pip install -r requirements.txt
56 | export PYTHONPATH=`pwd`/../icefall:$PYTHONPATH
57 | echo "export PYTHONPATH=`pwd`/../icefall:\$PYTHONPATH" >> ~/.zshrc
58 | echo "export PYTHONPATH=`pwd`/../icefall:\$PYTHONPATH" >> ~/.bashrc
59 | cd -
60 | source ~/.zshrc
61 |
62 | # valle
63 | git clone https://github.com/lifeiteng/valle.git
64 | cd valle
65 | pip install -e .
66 | ```
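
After the install, a quick import check helps confirm that the pieces above can actually be found (a minimal sketch; it only verifies imports and prints a few versions):

```python
# Sanity check: make sure the core dependencies installed above are importable.
import torch
import torchaudio
import lhotse
import k2        # needs the wheel matching your torch / CUDA versions
import icefall   # needs PYTHONPATH to point at the icefall checkout
import valle

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torchaudio", torchaudio.__version__)
print("lhotse", lhotse.__version__)
```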
67 |
68 |
69 | ## Training & Inference
70 | * #### English example: [egs/libritts/README.md](egs/libritts/README.md)
71 | * #### Chinese example: [egs/aishell1/README.md](egs/aishell1/README.md)
72 | * ### Prefix Modes 0, 1, 2, 4 for the NAR Decoder
73 |   **Paper, Section 5.1:** "The average length of the waveform in LibriLight is 60 seconds. During
74 | training, we randomly crop the waveform to a random length between 10 seconds and 20 seconds. For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds from the same utterance."
75 | * **0**: no acoustic prompt tokens
76 |   * **1**: a random prefix of the current batched utterances **(recommended)**
77 |   * **2**: a random segment of the current batched utterances
78 |   * **4**: same as the paper (the long waveforms are randomly cropped into multiple utterances, so "the same utterance" here means a preceding or following utterance from the same long waveform); a rough sketch of the four modes follows the command below.
79 | ```
80 | # To train NAR decoders with prefix_mode 4
81 | python3 bin/trainer.py --prefix_mode 4 --dataset libritts --input-strategy PromptedPrecomputedFeatures ...
82 | ```
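
These modes only change which acoustic tokens are fed to the NAR decoder as its prompt; the actual selection logic lives in this repo's model/data code. The toy sketch below is just an illustration of how the modes differ (the helper, frame rate, and prompt lengths are illustrative assumptions, not the repo's implementation):

```python
# Toy illustration of prefix modes 0/1/2/4 -- NOT the repo's actual implementation.
import random

FRAMES_PER_SECOND = 75  # EnCodec at 24 kHz emits 75 codec frames per second


def select_nar_prompt(codes, prefix_mode, neighbor_codes=None, prompt_seconds=3):
    """codes: codec frames of the current utterance; neighbor_codes: frames of a
    pre/post utterance cropped from the same long waveform (mode 4 only)."""
    prompt_len = prompt_seconds * FRAMES_PER_SECOND
    if prefix_mode == 0:
        return []  # no acoustic prompt tokens
    if prefix_mode == 1:
        cut = random.randint(1, max(1, len(codes) // 2))  # length here is illustrative
        return codes[:cut]  # random prefix of the current utterance
    if prefix_mode == 2:
        start = random.randint(0, max(0, len(codes) - prompt_len))
        return codes[start:start + prompt_len]  # random segment of the current utterance
    if prefix_mode == 4:
        # Paper-style: a ~3 s segment from a neighboring utterance cropped from the
        # same long waveform, hence --input-strategy PromptedPrecomputedFeatures.
        assert neighbor_codes is not None
        start = random.randint(0, max(0, len(neighbor_codes) - prompt_len))
        return neighbor_codes[start:start + prompt_len]
    raise ValueError(f"unknown prefix_mode: {prefix_mode}")
```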
83 |
84 | #### [LibriTTS demo](https://lifeiteng.github.io/valle/index.html) trained on one GPU with 24 GB of memory
85 |
86 | ```
87 | cd examples/libritts
88 |
89 | # step1 prepare dataset
90 | bash prepare.sh --stage -1 --stop-stage 3
91 |
92 | # step2 train the model on one GPU with 24GB memory
93 | exp_dir=exp/valle
94 |
95 | ## Train AR model
96 | python3 bin/trainer.py --max-duration 80 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \
97 | --num-buckets 6 --dtype "bfloat16" --save-every-n 10000 --valid-interval 20000 \
98 | --model-name valle --share-embedding true --norm-first true --add-prenet false \
99 | --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
100 | --base-lr 0.05 --warmup-steps 200 --average-period 0 \
101 | --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 4 \
102 | --exp-dir ${exp_dir}
103 |
104 | ## Train NAR model
105 | cp ${exp_dir}/best-valid-loss.pt ${exp_dir}/epoch-2.pt # --start-epoch 3=2+1
106 | python3 bin/trainer.py --max-duration 40 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \
107 | --num-buckets 6 --dtype "float32" --save-every-n 10000 --valid-interval 20000 \
108 | --model-name valle --share-embedding true --norm-first true --add-prenet false \
109 | --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
110 | --base-lr 0.05 --warmup-steps 200 --average-period 0 \
111 | --num-epochs 40 --start-epoch 3 --start-batch 0 --accumulate-grad-steps 4 \
112 | --exp-dir ${exp_dir}
113 |
114 | # step3 inference
115 | python3 bin/infer.py --output-dir infer/demos \
116 | --checkpoint=${exp_dir}/best-valid-loss.pt \
117 | --text-prompts "KNOT one point one five miles per hour." \
118 | --audio-prompts ./prompts/8463_294825_000043_000000.wav \
119 | --text "To get up and running quickly just follow the steps below." \
120 |
121 | # Demo Inference
122 | https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/run.sh#L68
123 | ```
124 | 
125 |
126 | #### Troubleshooting
127 |
128 | * **SummaryWriter segmentation fault (core dumped)**
129 | * LINE `tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")`
130 | * FIX [https://github.com/tensorflow/tensorboard/pull/6135/files](https://github.com/tensorflow/tensorboard/pull/6135/files)
131 | ```
132 | file=`python -c 'import site; print(f"{site.getsitepackages()[0]}/tensorboard/summary/writer/event_file_writer.py")'`
133 | sed -i 's/import tf/import tensorflow_stub as tf/g' $file
134 | ```
135 |
136 | #### Training on a custom dataset?
137 | * Prepare the dataset as `lhotse` manifests (a minimal sketch is shown after this list)
138 |   * There are plenty of reference recipes in [lhotse/recipes](https://github.com/lhotse-speech/lhotse/tree/master/lhotse/recipes)
139 | * `python3 bin/tokenizer.py ...`
140 | * `python3 bin/trainer.py ...`
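
For a corpus that has no existing lhotse recipe, manifest creation can be as simple as the sketch below. The directory layout, IDs, and `metadata.csv` format here are hypothetical; the output naming just mirrors the `<prefix>_recordings_<part>.jsonl.gz` pattern used by the recipes in `egs/`:

```python
# Minimal sketch: build lhotse manifests for a "wav + transcript" corpus.
# The layout (wavs/ plus a metadata.csv with "utt_id|speaker|text" rows) is
# hypothetical -- adapt the parsing to your own corpus.
from pathlib import Path

from lhotse import Recording, RecordingSet, SupervisionSegment, SupervisionSet

corpus_dir = Path("download/my_corpus")
recordings, supervisions = [], []

for line in (corpus_dir / "metadata.csv").read_text().splitlines():
    utt_id, speaker, text = line.split("|", maxsplit=2)
    recording = Recording.from_file(corpus_dir / "wavs" / f"{utt_id}.wav", recording_id=utt_id)
    recordings.append(recording)
    supervisions.append(
        SupervisionSegment(
            id=utt_id,
            recording_id=utt_id,
            start=0.0,
            duration=recording.duration,
            text=text,
            speaker=speaker,
            language="English",
        )
    )

out_dir = Path("data/manifests")
out_dir.mkdir(parents=True, exist_ok=True)
RecordingSet.from_recordings(recordings).to_file(out_dir / "mycorpus_recordings_train.jsonl.gz")
SupervisionSet.from_segments(supervisions).to_file(out_dir / "mycorpus_supervisions_train.jsonl.gz")
```

With the manifests in `data/manifests`, the `bin/tokenizer.py` / `bin/trainer.py` steps apply as in the bundled recipes (use a matching `--prefix`).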
141 |
142 | ## Contributing
143 |
144 | * Parallelize bin/tokenizer.py across multiple GPUs
146 |
147 | ## Citing
148 |
149 | To cite this repository:
150 |
151 | ```bibtex
152 | @misc{valle,
153 | author={Feiteng Li},
154 | title={VALL-E: A neural codec language model},
155 | year={2023},
156 | url={http://github.com/lifeiteng/vall-e}
157 | }
158 | ```
159 |
160 | ```bibtex
161 | @article{VALL-E,
162 | title = {Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers},
163 |   author = {Chengyi Wang and Sanyuan Chen and Yu Wu and
164 |             Ziqiang Zhang and Long Zhou and Shujie Liu and
165 |             Zhuo Chen and Yanqing Liu and Huaming Wang and
166 |             Jinyu Li and Lei He and Sheng Zhao and Furu Wei},
167 | year = {2023},
168 | eprint = {2301.02111},
169 | archivePrefix = {arXiv},
170 | volume = {abs/2301.02111},
171 | url = {http://arxiv.org/abs/2301.02111},
172 | }
173 | ```
174 |
175 | ## Star History
176 |
177 | [](https://star-history.com/#lifeiteng/vall-e&Date)
178 |
--------------------------------------------------------------------------------
/README.zh-CN.md:
--------------------------------------------------------------------------------
1 | An unofficial open-source PyTorch implementation of VALL-E ([Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers](https://arxiv.org/abs/2301.02111)).
2 |
3 |
4 |
5 | Not kept in sync with the latest changes; please refer to the English version [🇺🇸](./README.md)
6 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
2 |
3 | RUN apt-get update && \
4 | apt-get upgrade -y
5 | RUN apt-get install -y vim wget git libsndfile1 espeak-ng
6 |
7 | RUN pip install torchmetrics==0.11.1 librosa==0.8.1 phonemizer==3.2.1 pypinyin==0.48.0 lhotse matplotlib h5py
8 |
9 | RUN pip install https://huggingface.co/csukuangfj/k2/resolve/main/cuda/k2-1.23.4.dev20230224+cuda11.6.torch1.13.1-cp310-cp310-linux_x86_64.whl
10 |
11 | RUN pip install torchaudio==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
12 |
13 | WORKDIR /workspace
14 | RUN git clone https://github.com/k2-fsa/icefall && \
15 | cd icefall && \
16 | pip install -r requirements.txt
17 | ENV PYTHONPATH=/workspace/icefall:$PYTHONPATH
18 |
19 | RUN cp /opt/conda/lib/libpython3.10.so.1.0 /usr/lib/x86_64-linux-gnu/
20 |
21 | WORKDIR /workspace
22 | RUN git clone https://github.com/lifeiteng/vall-e.git && \
23 | cd vall-e && \
24 | pip install -e .
25 |
26 | WORKDIR /workspace/vall-e
27 |
--------------------------------------------------------------------------------
/docs/images/Overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/docs/images/Overview.jpg
--------------------------------------------------------------------------------
/docs/images/train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/docs/images/train.png
--------------------------------------------------------------------------------
/docs/images/vallf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/docs/images/vallf.png
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/docs/requirements.txt
--------------------------------------------------------------------------------
/egs/aishell1/README.md:
--------------------------------------------------------------------------------
1 | # aishell1
2 | `150` hours of data is not enough to train a stable model; try scaling up to `500`/`1000` hours of training data.
3 | * In the current implementation, each final and its tone are combined into a single modeling unit, which further increases the amount of data required (see the sketch below).
4 |
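A rough illustration with `pypinyin` of what "final + tone" units look like (this is only a sketch of the idea behind the `pypinyin_initials_finals` text extractor, not the repo's actual tokenizer code):

```python
# Sketch: why tonal finals enlarge the modeling-unit inventory.
# Illustration only -- not the tokenizer used by bin/tokenizer.py.
from pypinyin import Style, lazy_pinyin

text = "甚至出现交易几乎停滞的情况"

initials = lazy_pinyin(text, style=Style.INITIALS, strict=False)
finals_with_tone = lazy_pinyin(text, style=Style.FINALS_TONE3, strict=False)

# Interleave initial/final pairs, e.g. 'sh', 'en4', 'zh', 'i4', ...
units = [u for pair in zip(initials, finals_with_tone) for u in pair if u]
print(units)
# Every (final, tone) combination such as en1/en2/en3/en4 is a distinct unit,
# so the inventory grows roughly by a factor of the number of tones,
# which in turn raises the amount of data needed per unit.
```
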
5 | ## Prepare Dataset
6 | ```
7 | cd egs/aishell1
8 |
9 | # Those stages are very time-consuming
10 | bash prepare.sh --stage -1 --stop-stage 3
11 |
12 | ## train
13 | Cut statistics:
14 | ╒═══════════════════════════╤═══════════╕
15 | │ Cuts count: │ 120098 │
16 | ├───────────────────────────┼───────────┤
17 | │ Total duration (hh:mm:ss) │ 150:51:08 │
18 | ├───────────────────────────┼───────────┤
19 | │ mean │ 4.5 │
20 | ├───────────────────────────┼───────────┤
21 | │ std │ 1.4 │
22 | ├───────────────────────────┼───────────┤
23 | │ min │ 1.2 │
24 | ├───────────────────────────┼───────────┤
25 | │ 25% │ 3.5 │
26 | ├───────────────────────────┼───────────┤
27 | │ 50% │ 4.3 │
28 | ├───────────────────────────┼───────────┤
29 | │ 75% │ 5.3 │
30 | ├───────────────────────────┼───────────┤
31 | │ 99% │ 8.5 │
32 | ├───────────────────────────┼───────────┤
33 | │ 99.5% │ 9.1 │
34 | ├───────────────────────────┼───────────┤
35 | │ 99.9% │ 10.5 │
36 | ├───────────────────────────┼───────────┤
37 | │ max │ 14.5 │
38 | ├───────────────────────────┼───────────┤
39 | │ Recordings available: │ 120098 │
40 | ├───────────────────────────┼───────────┤
41 | │ Features available: │ 120098 │
42 | ├───────────────────────────┼───────────┤
43 | │ Supervisions available: │ 120098 │
44 | ╘═══════════════════════════╧═══════════╛
45 | SUPERVISION custom fields:
46 | Speech duration statistics:
47 | ╒══════════════════════════════╤═══════════╤══════════════════════╕
48 | │ Total speech duration │ 150:51:08 │ 100.00% of recording │
49 | ├──────────────────────────────┼───────────┼──────────────────────┤
50 | │ Total speaking time duration │ 150:51:08 │ 100.00% of recording │
51 | ├──────────────────────────────┼───────────┼──────────────────────┤
52 | │ Total silence duration │ 00:00:00 │ 0.00% of recording │
53 | ╘══════════════════════════════╧═══════════╧══════════════════════╛
54 |
55 |
56 | ## dev
57 | Cut statistics:
58 | ╒═══════════════════════════╤══════════╕
59 | │ Cuts count: │ 400 │
60 | ├───────────────────────────┼──────────┤
61 | │ Total duration (hh:mm:ss) │ 00:28:37 │
62 | ├───────────────────────────┼──────────┤
63 | │ mean │ 4.3 │
64 | ├───────────────────────────┼──────────┤
65 | │ std │ 1.1 │
66 | ├───────────────────────────┼──────────┤
67 | │ min │ 2.3 │
68 | ├───────────────────────────┼──────────┤
69 | │ 25% │ 3.5 │
70 | ├───────────────────────────┼──────────┤
71 | │ 50% │ 4.0 │
72 | ├───────────────────────────┼──────────┤
73 | │ 75% │ 5.0 │
74 | ├───────────────────────────┼──────────┤
75 | │ 99% │ 7.4 │
76 | ├───────────────────────────┼──────────┤
77 | │ 99.5% │ 7.5 │
78 | ├───────────────────────────┼──────────┤
79 | │ 99.9% │ 8.0 │
80 | ├───────────────────────────┼──────────┤
81 | │ max │ 8.0 │
82 | ├───────────────────────────┼──────────┤
83 | │ Recordings available: │ 400 │
84 | ├───────────────────────────┼──────────┤
85 | │ Features available: │ 400 │
86 | ├───────────────────────────┼──────────┤
87 | │ Supervisions available: │ 400 │
88 | ╘═══════════════════════════╧══════════╛
89 | SUPERVISION custom fields:
90 | Speech duration statistics:
91 | ╒══════════════════════════════╤══════════╤══════════════════════╕
92 | │ Total speech duration │ 00:28:37 │ 100.00% of recording │
93 | ├──────────────────────────────┼──────────┼──────────────────────┤
94 | │ Total speaking time duration │ 00:28:37 │ 100.00% of recording │
95 | ├──────────────────────────────┼──────────┼──────────────────────┤
96 | │ Total silence duration │ 00:00:00 │ 0.00% of recording │
97 | ╘══════════════════════════════╧══════════╧══════════════════════╛
98 |
99 |
100 | ## test
101 | Cut statistics:
102 | ╒═══════════════════════════╤══════════╕
103 | │ Cuts count: │ 7176 │
104 | ├───────────────────────────┼──────────┤
105 | │ Total duration (hh:mm:ss) │ 10:01:49 │
106 | ├───────────────────────────┼──────────┤
107 | │ mean │ 5.0 │
108 | ├───────────────────────────┼──────────┤
109 | │ std │ 1.6 │
110 | ├───────────────────────────┼──────────┤
111 | │ min │ 1.9 │
112 | ├───────────────────────────┼──────────┤
113 | │ 25% │ 3.8 │
114 | ├───────────────────────────┼──────────┤
115 | │ 50% │ 4.7 │
116 | ├───────────────────────────┼──────────┤
117 | │ 75% │ 5.9 │
118 | ├───────────────────────────┼──────────┤
119 | │ 99% │ 9.9 │
120 | ├───────────────────────────┼──────────┤
121 | │ 99.5% │ 10.7 │
122 | ├───────────────────────────┼──────────┤
123 | │ 99.9% │ 11.9 │
124 | ├───────────────────────────┼──────────┤
125 | │ max │ 14.7 │
126 | ├───────────────────────────┼──────────┤
127 | │ Recordings available: │ 7176 │
128 | ├───────────────────────────┼──────────┤
129 | │ Features available: │ 7176 │
130 | ├───────────────────────────┼──────────┤
131 | │ Supervisions available: │ 7176 │
132 | ╘═══════════════════════════╧══════════╛
133 | SUPERVISION custom fields:
134 | Speech duration statistics:
135 | ╒══════════════════════════════╤══════════╤══════════════════════╕
136 | │ Total speech duration │ 10:01:49 │ 100.00% of recording │
137 | ├──────────────────────────────┼──────────┼──────────────────────┤
138 | │ Total speaking time duration │ 10:01:49 │ 100.00% of recording │
139 | ├──────────────────────────────┼──────────┼──────────────────────┤
140 | │ Total silence duration │ 00:00:00 │ 0.00% of recording │
141 | ╘══════════════════════════════╧══════════╧══════════════════════╛
142 | ```
143 |
144 | ## Training & Inference
145 | Refer to [Training & Inference](../../README.md#training--inference)
146 |
--------------------------------------------------------------------------------
/egs/aishell1/bin:
--------------------------------------------------------------------------------
1 | ../../valle/bin
--------------------------------------------------------------------------------
/egs/aishell1/demos/0_demo.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/aishell1/demos/0_demo.wav
--------------------------------------------------------------------------------
/egs/aishell1/prepare.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -eou pipefail
4 |
5 | # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
6 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
7 |
8 | nj=16
9 | stage=-1
10 | stop_stage=4
11 |
12 | # We assume dl_dir (download dir) contains the following
13 | # directories and files. If not, they will be downloaded
14 | # by this script automatically.
15 | #
16 | # - $dl_dir/aishell
17 | # You can download aishell from https://www.openslr.org/33/
18 | #
19 |
20 | dl_dir=$PWD/download
21 |
22 | dataset_parts="-p train -p dev -p test"  # all parts
23 |
24 | text_extractor="pypinyin_initials_finals"
25 | audio_extractor="Encodec" # or Fbank
26 | audio_feats_dir=data/tokenized
27 |
28 | . shared/parse_options.sh || exit 1
29 |
30 |
31 | # All files generated by this script are saved in "data".
32 | # You can safely remove "data" and rerun this script to regenerate it.
33 | mkdir -p data
34 |
35 | log() {
36 | # This function is from espnet
37 | local fname=${BASH_SOURCE[1]##*/}
38 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
39 | }
40 |
41 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
42 | log "dl_dir: $dl_dir"
43 | log "Stage 0: Download data"
44 |
45 | # If you have pre-downloaded it to /path/to/aishell,
46 | # you can create a symlink
47 | #
48 | # ln -sfv /path/to/aishell $dl_dir/aishell
49 | #
50 | if [ ! -d $dl_dir/aishell/dev ]; then
51 | lhotse download aishell $dl_dir
52 | fi
53 | fi
54 |
55 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
56 | log "Stage 1: Prepare aishell manifest"
57 | # We assume that you have downloaded the aishell corpus
58 | # to $dl_dir/aishell
59 | mkdir -p data/manifests
60 | if [ ! -e data/manifests/.aishell.done ]; then
61 | lhotse prepare aishell $dl_dir/aishell data/manifests
62 | touch data/manifests/.aishell.done
63 | fi
64 | fi
65 |
66 |
67 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
68 | log "Stage 2: Tokenize/Fbank aishell"
69 | mkdir -p ${audio_feats_dir}
70 | if [ ! -e ${audio_feats_dir}/.aishell.tokenize.done ]; then
71 | python3 bin/tokenizer.py --dataset-parts "${dataset_parts}" \
72 | --text-extractor ${text_extractor} \
73 | --audio-extractor ${audio_extractor} \
74 | --batch-duration 400 \
75 | --prefix "aishell" \
76 | --src-dir "data/manifests" \
77 | --output-dir "${audio_feats_dir}"
78 | fi
79 | touch ${audio_feats_dir}/.aishell.tokenize.done
80 | fi
81 |
82 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
83 | log "Stage 3: Prepare aishell train/dev/test"
84 | if [ ! -e ${audio_feats_dir}/.aishell.train.done ]; then
85 | # dev 14326
86 | lhotse subset --first 400 \
87 | ${audio_feats_dir}/aishell_cuts_dev.jsonl.gz \
88 | ${audio_feats_dir}/cuts_dev.jsonl.gz
89 |
90 | lhotse subset --last 13926 \
91 | ${audio_feats_dir}/aishell_cuts_dev.jsonl.gz \
92 | ${audio_feats_dir}/cuts_dev_others.jsonl.gz
93 |
94 | # train
95 | lhotse combine \
96 | ${audio_feats_dir}/cuts_dev_others.jsonl.gz \
97 | ${audio_feats_dir}/aishell_cuts_train.jsonl.gz \
98 | ${audio_feats_dir}/cuts_train.jsonl.gz
99 |
100 | # test
101 | lhotse copy \
102 | ${audio_feats_dir}/aishell_cuts_test.jsonl.gz \
103 | ${audio_feats_dir}/cuts_test.jsonl.gz
104 |
105 | touch ${audio_feats_dir}/.aishell.train.done
106 | fi
107 | fi
108 |
109 | python3 ./bin/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
110 |
--------------------------------------------------------------------------------
/egs/aishell1/prompts/ch_24k.txt:
--------------------------------------------------------------------------------
1 | 甚至 出现 交易 几乎 停滞 的 情况
2 |
--------------------------------------------------------------------------------
/egs/aishell1/prompts/ch_24k.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/aishell1/prompts/ch_24k.wav
--------------------------------------------------------------------------------
/egs/aishell1/prompts/ch_24k_loudness_normalized20.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/aishell1/prompts/ch_24k_loudness_normalized20.wav
--------------------------------------------------------------------------------
/egs/aishell1/shared:
--------------------------------------------------------------------------------
1 | ../libritts/shared/
--------------------------------------------------------------------------------
/egs/libritts/README.md:
--------------------------------------------------------------------------------
1 | # LibriTTS
2 |
3 | ## Install deps
4 | ```
5 | pip install librosa==0.8.1
6 | ```
7 |
8 | ## Prepare Dataset
9 | ```
10 | cd egs/libritts
11 |
12 | # Those stages are very time-consuming
13 | bash prepare.sh --stage -1 --stop-stage 3
14 | ```
15 | #### data
16 |
17 | ```
18 | ## train
19 | Cut statistics:
20 | ╒═══════════════════════════╤═══════════╕
21 | │ Cuts count: │ 354780 │
22 | ├───────────────────────────┼───────────┤
23 | │ Total duration (hh:mm:ss) │ 555:09:48 │
24 | ├───────────────────────────┼───────────┤
25 | │ mean │ 5.6 │
26 | ├───────────────────────────┼───────────┤
27 | │ std │ 4.5 │
28 | ├───────────────────────────┼───────────┤
29 | │ min │ 0.1 │
30 | ├───────────────────────────┼───────────┤
31 | │ 25% │ 2.3 │
32 | ├───────────────────────────┼───────────┤
33 | │ 50% │ 4.3 │
34 | ├───────────────────────────┼───────────┤
35 | │ 75% │ 7.6 │
36 | ├───────────────────────────┼───────────┤
37 | │ 80% │ 8.7 │
38 | ├───────────────────────────┼───────────┤
39 | │ 85% │ 10.0 │
40 | ├───────────────────────────┼───────────┤
41 | │ 90% │ 11.8 │
42 | ├───────────────────────────┼───────────┤
43 | │ 95% │ 14.8 │
44 | ├───────────────────────────┼───────────┤
45 | │ 99% │ 20.9 │
46 | ├───────────────────────────┼───────────┤
47 | │ 99.5% │ 23.1 │
48 | ├───────────────────────────┼───────────┤
49 | │ 99.9% │ 27.4 │
50 | ├───────────────────────────┼───────────┤
51 | │ max │ 43.9 │
52 | ├───────────────────────────┼───────────┤
53 | │ Recordings available: │ 354780 │
54 | ├───────────────────────────┼───────────┤
55 | │ Features available: │ 354780 │
56 | ├───────────────────────────┼───────────┤
57 | │ Supervisions available: │ 354780 │
58 | ╘═══════════════════════════╧═══════════╛
59 | SUPERVISION custom fields:
60 | Speech duration statistics:
61 | ╒══════════════════════════════╤═══════════╤══════════════════════╕
62 | │ Total speech duration │ 555:09:48 │ 100.00% of recording │
63 | ├──────────────────────────────┼───────────┼──────────────────────┤
64 | │ Total speaking time duration │ 555:09:48 │ 100.00% of recording │
65 | ├──────────────────────────────┼───────────┼──────────────────────┤
66 | │ Total silence duration │ 00:00:01 │ 0.00% of recording │
67 | ╘══════════════════════════════╧═══════════╧══════════════════════╛
68 | ```
69 |
70 | ## Training & Inference
71 | Refer to [Training & Inference](../../README.md#training--inference)
72 |
--------------------------------------------------------------------------------
/egs/libritts/bin:
--------------------------------------------------------------------------------
1 | ../../valle/bin
--------------------------------------------------------------------------------
/egs/libritts/prepare.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -eou pipefail
4 |
5 | # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
6 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
7 |
8 | nj=16
9 | stage=-1
10 | stop_stage=3
11 |
12 | # We assume dl_dir (download dir) contains the following
13 | # directories and files. If not, they will be downloaded
14 | # by this script automatically.
15 | #
16 | # - $dl_dir/LibriTTS
17 | # You can download LibriTTS from https://www.openslr.org/60/
18 | # After downloading tar.gz files, you should extract them into dl_dir/LibriTTS.
19 | # Ignoring the *.tar.gz files, which you can download to any location, the structure of $dl_dir should look like this:
20 | #
21 | # dl_dir
22 | # ├── dev-clean.tar.gz
23 | # ├── dev-other.tar.gz
24 | # ├── LibriTTS
25 | # │ ├── BOOKS.txt
26 | # │ ├── CHAPTERS.txt
27 | # │ ├── dev-clean
28 | # │ ├── dev-other
29 | # │ ├── eval_sentences10.tsv
30 | # │ ├── LICENSE.txt
31 | # │ ├── NOTE.txt
32 | # │ ├── reader_book.tsv
33 | # │ ├── README_librispeech.txt
34 | # │ ├── README_libritts.txt
35 | # │ ├── speakers.tsv
36 | # │ ├── SPEAKERS.txt
37 | # │ ├── test-clean
38 | # │ ├── test-other
39 | # │ ├── train-clean-100
40 | # │ ├── train-clean-360
41 | # │ └── train-other-500
42 | # ├── test-clean.tar.gz
43 | # ├── test-other.tar.gz
44 | # ├── train-clean-100.tar.gz
45 | # ├── train-clean-360.tar.gz
46 | # └── train-other-500.tar.gz
47 |
48 | echo "We will download the LibriTTS dataset by default. If the download fails or you prefer to download the dataset yourself, see the comments in this script for the steps."
49 |
50 | dl_dir=$PWD/download
51 |
52 | # dataset_parts="-p dev-clean -p test-clean" # debug
53 | dataset_parts="--dataset-parts all" # all
54 |
55 | audio_extractor="Encodec" # or Fbank
56 | audio_feats_dir=data/tokenized
57 |
58 | . shared/parse_options.sh || exit 1
59 |
60 |
61 | # All files generated by this script are saved in "data".
62 | # You can safely remove "data" and rerun this script to regenerate it.
63 | mkdir -p data
64 |
65 | log() {
66 | # This function is from espnet
67 | local fname=${BASH_SOURCE[1]##*/}
68 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
69 | }
70 |
71 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
72 | log "dl_dir: $dl_dir"
73 | log "Stage 0: Download data"
74 |
75 | # If you have pre-downloaded it to /path/to/LibriTTS,
76 | # you can create a symlink
77 | #
78 | # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
79 | #
80 | if [ ! -d $dl_dir/LibriTTS/dev-other ]; then
81 | # lhotse download libritts $dl_dir
82 | lhotse download libritts ${dataset_parts} $dl_dir
83 | fi
84 | fi
85 |
86 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
87 | log "Stage 1: Prepare LibriTTS manifest"
88 | # We assume that you have downloaded the LibriTTS corpus
89 | # to $dl_dir/LibriTTS
90 | mkdir -p data/manifests
91 | if [ ! -e data/manifests/.libritts.done ]; then
92 | lhotse prepare libritts ${dataset_parts} -j $nj $dl_dir/LibriTTS data/manifests
93 | touch data/manifests/.libritts.done
94 | fi
95 | fi
96 |
97 |
98 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
99 | log "Stage 2: Tokenize/Fbank LibriTTS"
100 | mkdir -p ${audio_feats_dir}
101 | if [ ! -e ${audio_feats_dir}/.libritts.tokenize.done ]; then
102 | python3 bin/tokenizer.py --dataset-parts "${dataset_parts}" \
103 | --audio-extractor ${audio_extractor} \
104 | --batch-duration 400 \
105 | --src-dir "data/manifests" \
106 | --output-dir "${audio_feats_dir}"
107 | fi
108 | touch ${audio_feats_dir}/.libritts.tokenize.done
109 | fi
110 |
111 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
112 | log "Stage 3: Prepare LibriTTS train/dev/test"
113 | if [ ! -e ${audio_feats_dir}/.libritts.train.done ]; then
114 | if [ "${dataset_parts}" == "--dataset-parts all" ];then
115 | # train
116 | lhotse combine \
117 | ${audio_feats_dir}/libritts_cuts_train-clean-100.jsonl.gz \
118 | ${audio_feats_dir}/libritts_cuts_train-clean-360.jsonl.gz \
119 | ${audio_feats_dir}/libritts_cuts_train-other-500.jsonl.gz \
120 | ${audio_feats_dir}/cuts_train.jsonl.gz
121 |
122 | # dev
123 | lhotse copy \
124 | ${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \
125 | ${audio_feats_dir}/cuts_dev.jsonl.gz
126 | else # debug
127 | # train
128 | lhotse copy \
129 | ${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \
130 | ${audio_feats_dir}/cuts_train.jsonl.gz
131 | # dev
132 | lhotse subset --first 400 \
133 | ${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \
134 | ${audio_feats_dir}/cuts_dev.jsonl.gz
135 | fi
136 |
137 | # test
138 | lhotse copy \
139 | ${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \
140 | ${audio_feats_dir}/cuts_test.jsonl.gz
141 |
142 | touch ${audio_feats_dir}/.libritts.train.done
143 | fi
144 | fi
145 |
146 | python3 ./bin/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
147 |
--------------------------------------------------------------------------------
/egs/libritts/prompts/8455_210777_000067_000000.txt:
--------------------------------------------------------------------------------
1 | This I read with great attention, while they sat silent.
2 |
--------------------------------------------------------------------------------
/egs/libritts/prompts/8455_210777_000067_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/libritts/prompts/8455_210777_000067_000000.wav
--------------------------------------------------------------------------------
/egs/libritts/prompts/8463_294825_000043_000000.txt:
--------------------------------------------------------------------------------
1 | KNOT one point one five miles per hour
2 |
--------------------------------------------------------------------------------
/egs/libritts/prompts/8463_294825_000043_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/libritts/prompts/8463_294825_000043_000000.wav
--------------------------------------------------------------------------------
/egs/libritts/shared/parse_options.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
4 | # Arnab Ghoshal, Karel Vesely
5 |
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15 | # MERCHANTABILITY OR NON-INFRINGEMENT.
16 | # See the Apache 2 License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | # Parse command-line options.
21 | # To be sourced by another script (as in ". parse_options.sh").
22 | # Option format is: --option-name arg
23 | # and shell variable "option_name" gets set to value "arg."
24 | # The exception is --help, which takes no arguments, but prints the
25 | # $help_message variable (if defined).
26 |
27 |
28 | ###
29 | ### The --config file options have lower priority to command line
30 | ### options, so we need to import them first...
31 | ###
32 |
33 | # Now import all the configs specified by command-line, in left-to-right order
34 | for ((argpos=1; argpos<$#; argpos++)); do
35 | if [ "${!argpos}" == "--config" ]; then
36 | argpos_plus1=$((argpos+1))
37 | config=${!argpos_plus1}
38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
39 | . $config # source the config file.
40 | fi
41 | done
42 |
43 |
44 | ###
45 | ### Now we process the command line options
46 | ###
47 | while true; do
48 | [ -z "${1:-}" ] && break; # break if there are no arguments
49 | case "$1" in
50 | # If the enclosing script is called with --help option, print the help
51 | # message and exit. Scripts should put help messages in $help_message
52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
53 | else printf "$help_message\n" 1>&2 ; fi;
54 | exit 0 ;;
55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
56 | exit 1 ;;
57 | # If the first command-line argument begins with "--" (e.g. --foo-bar),
58 | # then work out the variable name as $name, which will equal "foo_bar".
59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
60 |       # Next we test whether the variable in question is undefined -- if so it's
61 | # an invalid option and we die. Note: $0 evaluates to the name of the
62 | # enclosing script.
63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
64 | # is undefined. We then have to wrap this test inside "eval" because
65 | # foo_bar is itself inside a variable ($name).
66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
67 |
68 | oldval="`eval echo \\$$name`";
69 | # Work out whether we seem to be expecting a Boolean argument.
70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
71 | was_bool=true;
72 | else
73 | was_bool=false;
74 | fi
75 |
76 | # Set the variable to the right value-- the escaped quotes make it work if
77 | # the option had spaces, like --cmd "queue.pl -sync y"
78 | eval $name=\"$2\";
79 |
80 | # Check that Boolean-valued arguments are really Boolean.
81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
83 | exit 1;
84 | fi
85 | shift 2;
86 | ;;
87 | *) break;
88 | esac
89 | done
90 |
91 |
92 | # Check for an empty argument to the --cmd option, which can easily occur as a
93 | # result of scripting errors.
94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
95 |
96 |
97 | true; # so this script returns exit code 0.
98 |
--------------------------------------------------------------------------------
/egs/ljspeech/README.md:
--------------------------------------------------------------------------------
1 | # LJSpeech
2 |
3 | ## Install deps
4 | ```
5 | pip install librosa==0.8.1
6 |
7 | # lhotse update LJSpeech
8 | # https://github.com/lhotse-speech/lhotse/pull/988
9 | ```
10 |
11 | ## Prepare Dataset
12 | ```
13 | cd egs/ljspeech
14 |
15 | bash prepare.sh --stage -1 --stop-stage 3 \
16 | --audio_extractor "Encodec" \
17 | --audio_feats_dir data/tokenized
18 | ```
19 |
20 | ## Training & Inference
21 | **LJSpeech is only used for debugging; please use LibriTTS instead.**
22 |
23 | Refer to [LibriTTS Training & Inference](../../README.md#training--inference)
24 |
--------------------------------------------------------------------------------
/egs/ljspeech/bin:
--------------------------------------------------------------------------------
1 | ../../valle/bin
--------------------------------------------------------------------------------
/egs/ljspeech/demos/0.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/demos/0.wav
--------------------------------------------------------------------------------
/egs/ljspeech/demos/1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/demos/1.wav
--------------------------------------------------------------------------------
/egs/ljspeech/prepare.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -eou pipefail
4 |
5 | # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
6 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
7 |
8 | nj=16
9 | stage=-1
10 | stop_stage=3
11 |
12 | # We assume dl_dir (download dir) contains the following
13 | # directories and files. If not, they will be downloaded
14 | # by this script automatically.
15 | #
16 | # - $dl_dir/LJSpeech-1.1
17 |
18 | dl_dir=$PWD/download
19 |
20 | audio_extractor="Encodec" # or Fbank
21 | audio_feats_dir=data/tokenized
22 |
23 |
24 | . shared/parse_options.sh || exit 1
25 |
26 |
27 | # All files generated by this script are saved in "data".
28 | # You can safely remove "data" and rerun this script to regenerate it.
29 | mkdir -p data
30 |
31 | log() {
32 | # This function is from espnet
33 | local fname=${BASH_SOURCE[1]##*/}
34 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
35 | }
36 |
37 | log "dl_dir: $dl_dir"
38 |
39 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
40 | log "Stage 0: Download data"
41 |
42 | # If you have pre-downloaded it to /path/to/LJSpeech,
43 | # you can create a symlink
44 | #
45 | # ln -sfv /path/to/LJSpeech $dl_dir/LJSpeech
46 | #
47 | if [ ! -d $dl_dir/LJSpeech-1.1 ];then
48 | lhotse download ljspeech $dl_dir
49 | fi
50 | fi
51 |
52 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
53 | log "Stage 1: Prepare LJSpeech manifest"
54 | # We assume that you have downloaded the LJSpeech corpus
55 | # to $dl_dir/LJSpeech
56 | mkdir -p data/manifests
57 | if [ ! -e data/manifests/.ljspeech.done ]; then
58 | lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
59 | touch data/manifests/.ljspeech.done
60 | fi
61 | fi
62 |
63 |
64 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
65 | log "Stage 2: Split LJSpeech"
66 |
67 |   # 13100 total = dev/test/train = 200/400/12500
68 | if [ ! -e data/manifests/ljspeech_recordings_test.jsonl.gz ]; then
69 | for manifest in "recordings" "supervisions";do
70 | lhotse subset --last 600 data/manifests/ljspeech_${manifest}_all.jsonl.gz \
71 | data/manifests/ljspeech_${manifest}_dev_test.jsonl.gz || exit 1
72 | lhotse subset --last 400 data/manifests/ljspeech_${manifest}_dev_test.jsonl.gz \
73 | data/manifests/ljspeech_${manifest}_test.jsonl.gz || exit 1
74 | lhotse subset --first 200 data/manifests/ljspeech_${manifest}_dev_test.jsonl.gz \
75 | data/manifests/ljspeech_${manifest}_dev.jsonl.gz || exit 1
76 |
77 | lhotse subset --first 12500 data/manifests/ljspeech_${manifest}_all.jsonl.gz \
78 | data/manifests/ljspeech_${manifest}_train.jsonl.gz || exit 1
79 |
80 | rm -f data/manifests/ljspeech_${manifest}_dev_test.jsonl.gz
81 | done
82 | fi
83 | fi
84 |
85 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
86 | log "Stage 3: ${audio_extractor} LJSpeech"
87 |
88 | mkdir -p ${audio_feats_dir}
89 | if [ ! -e ${audio_feats_dir}/.ljspeech.done ]; then
90 | python3 bin/tokenizer.py --dataset-parts "train test dev" --prefix "ljspeech" \
91 | --audio-extractor ${audio_extractor} \
92 | --batch-duration 400 \
93 | --src-dir "data/manifests" \
94 | --output-dir "${audio_feats_dir}"
95 | fi
96 | touch ${audio_feats_dir}/.ljspeech.done
97 |
98 | cd ${audio_feats_dir}
99 | ln -sf ljspeech_cuts_train.jsonl.gz cuts_train.jsonl.gz
100 | ln -sf ljspeech_cuts_dev.jsonl.gz cuts_dev.jsonl.gz
101 | ln -sf ljspeech_cuts_test.jsonl.gz cuts_test.jsonl.gz
102 | cd -
103 | fi
104 |
105 | python3 ./bin/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
106 |
--------------------------------------------------------------------------------
/egs/ljspeech/prompts/LJ049-0108_24K.txt:
--------------------------------------------------------------------------------
1 | In addition, the restriction would probably eliminate a need for the requirement which has been urged as necessary for the exercise of Federal power,
2 |
--------------------------------------------------------------------------------
/egs/ljspeech/prompts/LJ049-0108_24K.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/prompts/LJ049-0108_24K.wav
--------------------------------------------------------------------------------
/egs/ljspeech/prompts/LJ049-0110_24K.txt:
--------------------------------------------------------------------------------
1 | The governmental consequences of assassination of one of the specified officials give the United States ample power to act for its own protection.
2 |
--------------------------------------------------------------------------------
/egs/ljspeech/prompts/LJ049-0110_24K.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/prompts/LJ049-0110_24K.wav
--------------------------------------------------------------------------------
/egs/ljspeech/prompts/LJ049-0124.txt:
--------------------------------------------------------------------------------
1 | In addition, the proposed legislation will insure.
2 |
--------------------------------------------------------------------------------
/egs/ljspeech/prompts/LJ049-0124_24K.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/prompts/LJ049-0124_24K.wav
--------------------------------------------------------------------------------
/egs/ljspeech/prompts/LJ049-0185.txt:
--------------------------------------------------------------------------------
1 | During the period the Commission was giving thought to this situation,
2 |
--------------------------------------------------------------------------------
/egs/ljspeech/prompts/LJ049-0185_24K.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/egs/ljspeech/prompts/LJ049-0185_24K.wav
--------------------------------------------------------------------------------
/egs/ljspeech/shared:
--------------------------------------------------------------------------------
1 | ../libritts/shared
--------------------------------------------------------------------------------
/examples:
--------------------------------------------------------------------------------
1 | egs/
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | import sys
4 | from pathlib import Path
5 | from subprocess import DEVNULL, PIPE, run
6 |
7 | from setuptools import find_packages, setup
8 |
9 | project_root = Path(__file__).parent
10 |
11 | # modified from https://github.com/lhotse-speech/lhotse/blob/master/setup.py
12 |
13 | # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! #
14 | # NOTE: REMEMBER TO UPDATE THE FALLBACK VERSION IN valle/__init__.py WHEN RELEASING #
15 | # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! #
16 | MAJOR_VERSION = 1
17 | MINOR_VERSION = 0
18 | PATCH_VERSION = 0
19 | IS_DEV_VERSION = True # False = public release, True = otherwise
20 |
21 |
22 | if sys.version_info < (3,):
23 | # fmt: off
24 | print(
25 | "Python 2 has reached end-of-life and is no longer supported by valle."
26 | )
27 | # fmt: on
28 | sys.exit(-1)
29 |
30 | if sys.version_info < (3, 7):
31 | print(
32 | "Python 3.6 has reached end-of-life on December 31st, 2021 "
33 | "and is no longer supported by valle."
34 | )
35 | sys.exit(-1)
36 |
37 |
38 | def discover_valle_version() -> str:
39 | """
40 | Scans Valle source code to determine the current version.
41 |     When a development version is detected, it queries git for the commit hash
42 |     and appends it as a local version identifier.
43 |
44 | Ideally this function would have been imported from valle.version and
45 | re-used when valle is imported to set the version, but it introduces
46 | a circular dependency. To avoid this, we write the determined version
47 | into project_root / 'valle' / 'version.py' during setup and read it
48 | from there later. If it's not detected, the version will be 0.0.0.dev.
49 | """
50 |
51 | version = f"{MAJOR_VERSION}.{MINOR_VERSION}.{PATCH_VERSION}"
52 | if not IS_DEV_VERSION:
53 | # This is a PyPI public release -- return a clean version string.
54 | return version
55 |
56 | version = version + ".dev"
57 |
58 | # This is not a PyPI release -- try to read the git commit
59 | try:
60 | git_commit = (
61 | run(
62 | ["git", "rev-parse", "--short", "HEAD"],
63 | check=True,
64 | stdout=PIPE,
65 | stderr=DEVNULL,
66 | )
67 | .stdout.decode()
68 | .rstrip("\n")
69 | .strip()
70 | )
71 | dirty_commit = (
72 | len(
73 | run(
74 | ["git", "diff", "--shortstat"],
75 | check=True,
76 | stdout=PIPE,
77 | stderr=DEVNULL,
78 | )
79 | .stdout.decode()
80 | .rstrip("\n")
81 | .strip()
82 | )
83 | > 0
84 | )
85 | git_commit = (
86 | git_commit + ".dirty" if dirty_commit else git_commit + ".clean"
87 | )
88 | source_version = f"+git.{git_commit}"
89 | except Exception:
90 | source_version = ".unknownsource"
91 | # See the format:
92 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#local-version-identifiers
93 | version = version + source_version
94 |
95 | return version
96 |
97 |
98 | def mark_valle_version(version: str) -> None:
99 | (project_root / "valle" / "version.py").write_text(
100 | f'__version__ = "{version}"'
101 | )
102 |
103 |
104 | VALLE_VERSION = discover_valle_version()
105 | mark_valle_version(VALLE_VERSION)
106 |
107 |
108 | install_requires = [
109 | "encodec",
110 | "phonemizer",
111 | ]
112 |
113 | try:
114 |     # If the user already installed PyTorch, make sure torchaudio is installed too.
115 | # Otherwise, we'll just install the latest versions from PyPI for the user.
116 | import torch
117 |
118 | try:
119 | import torchaudio
120 | except ImportError:
121 | raise ValueError(
122 | "We detected that you have already installed PyTorch, but haven't installed torchaudio. "
123 | "Unfortunately we can't detect the compatible torchaudio version for you; "
124 | "you will have to install it manually. "
125 | "For instructions, please refer either to https://pytorch.org/get-started/locally/ "
126 | "or https://github.com/pytorch/audio#dependencies"
127 | )
128 | except ImportError:
129 | install_requires.extend(["torch", "torchaudio"])
130 |
131 | docs_require = (
132 | (project_root / "docs" / "requirements.txt").read_text().splitlines()
133 | )
134 | tests_require = [
135 | # "pytest==7.1.3",
136 | # "pytest-forked==1.4.0",
137 | # "pytest-xdist==2.5.0",
138 | # "pytest-cov==4.0.0",
139 | ]
140 | workflow_requires = [""]
141 | dev_requires = sorted(
142 | docs_require
143 | + tests_require
144 | + workflow_requires
145 | + ["jupyterlab", "matplotlib"]
146 | )
147 | all_requires = sorted(dev_requires)
148 |
149 | if os.environ.get("READTHEDOCS", False):
150 | # When building documentation, omit torchaudio installation and mock it instead.
151 | # This works around the inability to install libsoundfile1 in read-the-docs env,
152 | # which caused the documentation builds to silently crash.
153 | install_requires = [
154 | req
155 | for req in install_requires
156 | if not any(req.startswith(dep) for dep in ["torchaudio", "SoundFile"])
157 | ]
158 |
159 | setup(
160 | name="valle",
161 | version=VALLE_VERSION,
162 | python_requires=">=3.7.0",
163 | description="Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers",
164 | author="The Valle Development Team",
165 | author_email="lifeiteng0422@163.com",
166 | long_description=(project_root / "README.md").read_text(encoding="utf-8"),
167 | long_description_content_type="text/markdown",
168 | license="Apache-2.0 License",
169 | packages=find_packages(exclude=["test", "test.*"]),
170 | include_package_data=True,
171 | entry_points={},
172 | install_requires=install_requires,
173 | extras_require={
174 | "docs": docs_require,
175 | "tests": tests_require,
176 | "dev": dev_requires,
177 | "all": all_requires,
178 | },
179 | classifiers=[
180 |         "Development Status :: 4 - Beta",
181 | "Programming Language :: Python :: 3.7",
182 | "Programming Language :: Python :: 3.8",
183 | "Programming Language :: Python :: 3.9",
184 | "Programming Language :: Python :: 3.10",
185 | "Intended Audience :: Science/Research",
186 | "Operating System :: POSIX :: Linux",
187 | "Operating System :: MacOS :: MacOS X",
188 | "License :: OSI Approved :: Apache Software License",
189 | "Topic :: Multimedia :: Sound/Audio :: Speech",
190 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
191 | "Topic :: Software Development :: Libraries :: Python Modules",
192 | "Typing :: Typed",
193 | ],
194 | )
195 |
--------------------------------------------------------------------------------
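Note: discover_valle_version() above resolves the version string and mark_valle_version() writes it into valle/version.py at build time. A minimal sketch of reading that marker back at runtime, assuming only the file written above (the exec-based loading here is illustrative, not the package's own code):

from pathlib import Path

# Load the __version__ string that mark_valle_version() wrote during setup.
namespace = {}
exec((Path("valle") / "version.py").read_text(), namespace)
print(namespace["__version__"])  # e.g. "1.0.0.dev+git.<sha>.clean" for a dev checkout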
/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
4 |
5 | python3 valle/tests/valle_test.py
6 | python3 valle/tests/scaling_test.py
7 | python3 valle/tests/data/tokenizer_test.py
8 |
--------------------------------------------------------------------------------
/valle/__init__.py:
--------------------------------------------------------------------------------
1 | from . import data, models, modules, utils
2 |
--------------------------------------------------------------------------------
/valle/bin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/bin/__init__.py
--------------------------------------------------------------------------------
/valle/bin/display_manifest_statistics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
3 | # Copyright 2023 (authors: Feiteng Li)
4 | #
5 | # See ../../../../LICENSE for clarification regarding multiple authors
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """
20 | This file displays duration statistics of utterances in the manifests.
21 | You can use the displayed values to choose the minimum/maximum duration
22 | for filtering out short and long utterances during training.
23 | """
24 |
25 | import argparse
26 | from pathlib import Path
27 |
28 | from lhotse import load_manifest_lazy
29 |
30 |
31 | def get_args():
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument(
34 | "--manifest-dir",
35 | type=Path,
36 | default=Path("data/tokenized"),
37 | help="Path to the tokenized manifests.",
38 | )
39 | return parser.parse_args()
40 |
41 |
42 | def main():
43 | args = get_args()
44 | manifest_dir = args.manifest_dir or Path("data/tokenized")
45 | for part in ["train", "dev", "test"]:
46 | print(f"## {part}")
47 | cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz")
48 | cuts.describe()
49 | print("\n")
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
--------------------------------------------------------------------------------
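Note: the per-split numbers printed by this script come from lhotse's CutSet.describe(). A minimal sketch of inspecting a single split interactively with the same calls, assuming the LJSpeech recipe above has already produced data/tokenized:

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/tokenized/cuts_dev.jsonl.gz")
cuts.describe()  # prints cut counts and duration statistics (total, mean, percentiles)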
/valle/bin/infer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | Synthesize speech with a trained VALL-E checkpoint, given text and audio prompts.
17 |
18 | Usage example:
19 | python3 bin/infer.py \
20 | --decoder-dim 128 --nhead 4 --num-decoder-layers 4 --model-name valle \
21 | --text-prompts "Go to her." \
22 | --audio-prompts ./prompts/61_70970_000007_000001.wav \
23 | --output-dir infer/demo_valle_epoch20 \
24 | --checkpoint exp/valle_nano_v2/epoch-20.pt
25 |
26 | """
27 | import argparse
28 | import logging
29 | import os
30 | from pathlib import Path
31 |
32 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
33 |
34 | import torch
35 | import torchaudio
36 | from icefall.utils import AttributeDict, str2bool
37 |
38 | from valle.data import (
39 | AudioTokenizer,
40 | TextTokenizer,
41 | tokenize_audio,
42 | tokenize_text,
43 | )
44 | from valle.data.collation import get_text_token_collater
45 | from valle.models import get_model
46 |
47 |
48 | def get_args():
49 | parser = argparse.ArgumentParser()
50 |
51 | parser.add_argument(
52 | "--text-prompts",
53 | type=str,
54 | default="",
55 | help="Text prompts which are separated by |.",
56 | )
57 |
58 | parser.add_argument(
59 | "--audio-prompts",
60 | type=str,
61 | default="",
62 | help="Audio prompts which are separated by | and should be aligned with --text-prompts.",
63 | )
64 |
65 | parser.add_argument(
66 | "--text",
67 | type=str,
68 | default="To get up and running quickly just follow the steps below.",
69 | help="Text to be synthesized.",
70 | )
71 |
72 | # model
73 | # add_model_arguments(parser)
74 | # parser.add_argument(
75 | # "--text-tokens",
76 | # type=str,
77 | # default="data/tokenized/unique_text_tokens.k2symbols",
78 | # help="Path to the unique text tokens file.",
79 | # )
80 |
81 | parser.add_argument(
82 | "--text-extractor",
83 | type=str,
84 | default="espeak",
85 | help="espeak or pypinyin or pypinyin_initials_finals",
86 | )
87 |
88 | parser.add_argument(
89 | "--checkpoint",
90 | type=str,
91 | default="exp/vallf_nano_full/checkpoint-100000.pt",
92 | help="Path to the saved checkpoint.",
93 | )
94 |
95 | parser.add_argument(
96 | "--output-dir",
97 | type=Path,
98 | default=Path("infer/demo"),
99 | help="Path to the tokenized files.",
100 | )
101 |
102 | parser.add_argument(
103 | "--top-k",
104 | type=int,
105 | default=-100,
106 |         help="Top-k sampling in the AR decoder (enabled if > 0).",
107 | )
108 |
109 | parser.add_argument(
110 | "--temperature",
111 | type=float,
112 | default=1.0,
113 | help="The temperature of AR Decoder top_k sampling.",
114 | )
115 |
116 | parser.add_argument(
117 | "--continual",
118 | type=str2bool,
119 | default=False,
120 | help="Do continual task.",
121 | )
122 |
123 | return parser.parse_args()
124 |
125 |
126 | def load_model(checkpoint, device):
127 | if not checkpoint:
128 | return None
129 |
130 | checkpoint = torch.load(checkpoint, map_location=device)
131 |
132 | args = AttributeDict(checkpoint)
133 | model = get_model(args)
134 |
135 | missing_keys, unexpected_keys = model.load_state_dict(
136 | checkpoint["model"], strict=True
137 | )
138 | assert not missing_keys
139 | model.to(device)
140 | model.eval()
141 |
142 | text_tokens = args.text_tokens
143 |
144 | return model, text_tokens
145 |
146 |
147 | @torch.no_grad()
148 | def main():
149 | args = get_args()
150 | text_tokenizer = TextTokenizer(backend=args.text_extractor)
151 |
152 | device = torch.device("cpu")
153 | if torch.cuda.is_available():
154 | device = torch.device("cuda", 0)
155 | model, text_tokens = load_model(args.checkpoint, device)
156 | text_collater = get_text_token_collater(text_tokens)
157 |
158 | audio_tokenizer = AudioTokenizer()
159 |
160 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
161 |
162 | text_prompts = " ".join(args.text_prompts.split("|"))
163 |
164 | audio_prompts = []
165 | if args.audio_prompts:
166 | for n, audio_file in enumerate(args.audio_prompts.split("|")):
167 | encoded_frames = tokenize_audio(audio_tokenizer, audio_file)
168 | if False:
169 | samples = audio_tokenizer.decode(encoded_frames)
170 | torchaudio.save(
171 | f"{args.output_dir}/p{n}.wav", samples[0], 24000
172 | )
173 |
174 | audio_prompts.append(encoded_frames[0][0])
175 |
176 | assert len(args.text_prompts.split("|")) == len(audio_prompts)
177 | audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1)
178 | audio_prompts = audio_prompts.to(device)
179 |
180 | if os.path.isfile(args.text): # for demos
181 | # https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py
182 | with open(args.text) as f:
183 | for line in f:
184 | fields = line.strip().split("\t")
185 | assert len(fields) == 4
186 | prompt_text, prompt_audio, text, audio_path = fields
187 | logging.info(f"synthesize text: {text}")
188 | text_tokens, text_tokens_lens = text_collater(
189 | [
190 | tokenize_text(
191 | text_tokenizer, text=f"{prompt_text} {text}".strip()
192 | )
193 | ]
194 | )
195 | _, enroll_x_lens = text_collater(
196 | [
197 | tokenize_text(
198 | text_tokenizer, text=f"{prompt_text}".strip()
199 | )
200 | ]
201 | )
202 |
203 | audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio)
204 | audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device)
205 |
206 | # synthesis
207 | encoded_frames = model.inference(
208 | text_tokens.to(device),
209 | text_tokens_lens.to(device),
210 | audio_prompts,
211 | enroll_x_lens=enroll_x_lens,
212 | top_k=args.top_k,
213 | temperature=args.temperature,
214 | )
215 |
216 | samples = audio_tokenizer.decode(
217 | [(encoded_frames.transpose(2, 1), None)]
218 | )
219 | # store
220 | torchaudio.save(audio_path, samples[0].cpu(), 24000)
221 | return
222 |
223 | for n, text in enumerate(args.text.split("|")):
224 | logging.info(f"synthesize text: {text}")
225 | text_tokens, text_tokens_lens = text_collater(
226 | [
227 | tokenize_text(
228 | text_tokenizer, text=f"{text_prompts} {text}".strip()
229 | )
230 | ]
231 | )
232 |
233 | # synthesis
234 | if args.continual:
235 | assert text == ""
236 | encoded_frames = model.continual(
237 | text_tokens.to(device),
238 | text_tokens_lens.to(device),
239 | audio_prompts,
240 | )
241 | else:
242 | enroll_x_lens = None
243 | if text_prompts:
244 | _, enroll_x_lens = text_collater(
245 | [
246 | tokenize_text(
247 | text_tokenizer, text=f"{text_prompts}".strip()
248 | )
249 | ]
250 | )
251 | encoded_frames = model.inference(
252 | text_tokens.to(device),
253 | text_tokens_lens.to(device),
254 | audio_prompts,
255 | enroll_x_lens=enroll_x_lens,
256 | top_k=args.top_k,
257 | temperature=args.temperature,
258 | )
259 |
260 |         if len(audio_prompts) != 0:
261 | samples = audio_tokenizer.decode(
262 | [(encoded_frames.transpose(2, 1), None)]
263 | )
264 | # store
265 | torchaudio.save(
266 | f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000
267 | )
268 | else: # Transformer
269 | pass
270 |
271 |
272 | torch.set_num_threads(1)
273 | torch.set_num_interop_threads(1)
274 | torch._C._jit_set_profiling_executor(False)
275 | torch._C._jit_set_profiling_mode(False)
276 | torch._C._set_graph_executor_optimize(False)
277 | if __name__ == "__main__":
278 | formatter = (
279 | "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
280 | )
281 | logging.basicConfig(format=formatter, level=logging.INFO)
282 | main()
283 |
--------------------------------------------------------------------------------
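Note: besides a literal sentence, --text may point to a file (the os.path.isfile(args.text) branch above), which is read line by line as four tab-separated fields: prompt transcript, prompt audio path, text to synthesize, and output wav path. A minimal sketch of writing such a demo list; the prompt reuses egs/ljspeech/prompts/LJ049-0124_24K.wav from this repo, while infer/demo.tsv and the output path are arbitrary names chosen here:

from pathlib import Path

Path("infer/demo").mkdir(parents=True, exist_ok=True)

rows = [
    (
        "In addition, the proposed legislation will insure.",          # prompt transcript
        "egs/ljspeech/prompts/LJ049-0124_24K.wav",                      # prompt audio
        "To get up and running quickly just follow the steps below.",  # text to synthesize
        "infer/demo/0.wav",                                             # where infer.py saves the result
    ),
]
with open("infer/demo.tsv", "w", encoding="utf-8") as f:
    for row in rows:
        f.write("\t".join(row) + "\n")

# Then e.g.: python3 bin/infer.py --text infer/demo.tsv --checkpoint exp/valle_nano_v2/epoch-20.pt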
/valle/bin/tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | Phonemize Text and EnCodec Audio.
17 |
18 | Usage example:
19 | python3 bin/tokenizer.py \
20 |         --src-dir ./data/manifests --output-dir ./data/tokenized
21 |
22 | """
23 | import argparse
24 | import logging
25 | import os
26 | from pathlib import Path
27 |
28 | import torch
29 | import torch.multiprocessing
30 | from icefall.utils import get_executor
31 | from lhotse import CutSet, NumpyHdf5Writer
32 | from lhotse.recipes.utils import read_manifests_if_cached
33 | from tqdm.auto import tqdm
34 |
35 | from valle.data import (
36 | AudioTokenConfig,
37 | AudioTokenExtractor,
38 | TextTokenizer,
39 | tokenize_text,
40 | )
41 | from valle.data.fbank import get_fbank_extractor
42 | from valle.utils import SymbolTable
43 |
44 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
45 |
46 |
47 | # Torch's multithreaded behavior needs to be disabled or
48 | # it wastes a lot of CPU and slows things down.
49 | # Do this outside of main() in case it needs to take effect
50 | # even when we are not invoking the main (e.g. when spawning subprocesses).
51 | torch.set_num_threads(1)
52 | torch.set_num_interop_threads(1)
53 | torch.multiprocessing.set_sharing_strategy("file_system")
54 |
55 |
56 | def get_args():
57 | parser = argparse.ArgumentParser()
58 |
59 | parser.add_argument(
60 | "--src-dir",
61 | type=Path,
62 | default=Path("data/manifests"),
63 | help="Path to the manifest files",
64 | )
65 | parser.add_argument(
66 | "--output-dir",
67 | type=Path,
68 | default=Path("data/tokenized"),
69 | help="Path to the tokenized files",
70 | )
71 | parser.add_argument(
72 | "--text-extractor",
73 | type=str,
74 | default="espeak",
75 | help="espeak or pypinyin or pypinyin_initials_finals",
76 | )
77 | parser.add_argument(
78 | "--audio-extractor",
79 | type=str,
80 | default="Encodec",
81 | help="Encodec or Fbank",
82 | )
83 | parser.add_argument(
84 | "--dataset-parts",
85 | type=str,
86 | default="dev-clean test-clean",
87 | help="Space separated dataset parts",
88 | )
89 | parser.add_argument(
90 | "--prefix",
91 | type=str,
92 | default="libritts",
93 | help="prefix of the manifest file",
94 | )
95 | parser.add_argument(
96 | "--suffix",
97 | type=str,
98 | default="jsonl.gz",
99 | help="suffix of the manifest file",
100 | )
101 | parser.add_argument(
102 | "--batch-duration",
103 | type=float,
104 | default=400.0,
105 |         help="The maximum number of audio seconds in a batch. "
106 |         "Determines the batch size dynamically.",
107 | )
108 |
109 | return parser.parse_args()
110 |
111 |
112 | def main():
113 | args = get_args()
114 |
115 | dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip()
116 | if dataset_parts == "all": # LibriTTS
117 | dataset_parts = [
118 | "dev-clean",
119 | "dev-other",
120 | "test-clean",
121 | "test-other",
122 | "train-clean-100",
123 | "train-clean-360",
124 | "train-other-500",
125 | ]
126 | else:
127 | dataset_parts = dataset_parts.replace("-p", "").strip().split(" ")
128 |
129 | assert len(dataset_parts) >= 1
130 |
131 | manifests = read_manifests_if_cached(
132 | dataset_parts=dataset_parts,
133 | output_dir=args.src_dir,
134 | prefix=args.prefix,
135 | suffix=args.suffix,
136 | types=["recordings", "supervisions", "cuts"],
137 | )
138 |
139 | text_tokenizer = None
140 | if args.text_extractor:
141 | text_tokenizer = TextTokenizer(backend=args.text_extractor)
142 |
143 | audio_extractor = None
144 | if args.audio_extractor:
145 | if args.audio_extractor == "Encodec":
146 | audio_extractor = AudioTokenExtractor(AudioTokenConfig())
147 | else:
148 | assert args.audio_extractor == "Fbank"
149 | audio_extractor = get_fbank_extractor()
150 |
151 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
152 | unique_symbols = set()
153 | num_jobs = min(32, os.cpu_count())
154 | logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}")
155 |
156 | prefix = args.prefix
157 | if prefix and not prefix.endswith("_"):
158 | prefix = f"{prefix}_"
159 | with get_executor() as ex:
160 | for partition, m in manifests.items():
161 | logging.info(
162 | f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
163 | )
164 | try:
165 | cut_set = CutSet.from_manifests(
166 | recordings=m["recordings"],
167 | supervisions=m["supervisions"],
168 | )
169 | except Exception:
170 | cut_set = m["cuts"]
171 |
172 | # AudioTokenizer
173 | if args.audio_extractor:
174 | if args.audio_extractor == "Encodec":
175 | storage_path = (
176 | f"{args.output_dir}/{args.prefix}_encodec_{partition}"
177 | )
178 | else:
179 | storage_path = (
180 | f"{args.output_dir}/{args.prefix}_fbank_{partition}"
181 | )
182 |
183 | if args.prefix.lower() in ["ljspeech", "aishell", "baker"]:
184 | cut_set = cut_set.resample(24000)
185 | # https://github.com/lifeiteng/vall-e/issues/90
186 | # if args.prefix == "aishell":
187 | # # NOTE: the loudness of aishell audio files is around -33
188 | # # The best way is datamodule --on-the-fly-feats --enable-audio-aug
189 | # cut_set = cut_set.normalize_loudness(
190 | # target=-20.0, affix_id=True
191 | # )
192 |
193 | with torch.no_grad():
194 | if (
195 | torch.cuda.is_available()
196 | and args.audio_extractor == "Encodec"
197 | ):
198 | cut_set = cut_set.compute_and_store_features_batch(
199 | extractor=audio_extractor,
200 | storage_path=storage_path,
201 | num_workers=num_jobs,
202 | batch_duration=args.batch_duration,
203 | collate=False,
204 | overwrite=True,
205 | storage_type=NumpyHdf5Writer,
206 | )
207 | else:
208 | cut_set = cut_set.compute_and_store_features(
209 | extractor=audio_extractor,
210 | storage_path=storage_path,
211 | num_jobs=num_jobs if ex is None else 64,
212 | executor=ex,
213 | storage_type=NumpyHdf5Writer,
214 | )
215 |
216 | # TextTokenizer
217 | if args.text_extractor:
218 | if (
219 | args.prefix == "baker"
220 | and args.text_extractor == "labeled_pinyin"
221 | ):
222 | for c in tqdm(cut_set):
223 | phonemes = c.supervisions[0].custom["tokens"]["text"]
224 | unique_symbols.update(phonemes)
225 | else:
226 | for c in tqdm(cut_set):
227 | if args.prefix == "ljspeech":
228 | text = c.supervisions[0].custom["normalized_text"]
229 | text = text.replace("”", '"').replace("“", '"')
230 | phonemes = tokenize_text(text_tokenizer, text=text)
231 | elif args.prefix == "aishell":
232 | phonemes = tokenize_text(
233 | text_tokenizer, text=c.supervisions[0].text
234 | )
235 | c.supervisions[0].custom = {}
236 | else:
237 | assert args.prefix == "libritts"
238 | phonemes = tokenize_text(
239 | text_tokenizer, text=c.supervisions[0].text
240 | )
241 | c.supervisions[0].custom["tokens"] = {"text": phonemes}
242 | unique_symbols.update(phonemes)
243 |
244 | cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}"
245 | cut_set.to_file(f"{args.output_dir}/{cuts_filename}")
246 |
247 | if args.text_extractor:
248 | unique_phonemes = SymbolTable()
249 | for s in sorted(list(unique_symbols)):
250 | unique_phonemes.add(s)
251 | logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}")
252 |
253 | unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols"
254 | unique_phonemes.to_file(unique_phonemes_file)
255 |
256 |
257 | if __name__ == "__main__":
258 | formatter = (
259 | "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
260 | )
261 | logging.basicConfig(format=formatter, level=logging.INFO)
262 | main()
263 |
--------------------------------------------------------------------------------
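Note: the helpers this script drives in bulk can also be exercised on a single sentence and wav file. A minimal sketch, assuming espeak is installed for the default phonemizer backend; the wav reuses an LJSpeech prompt from this repo:

from valle.data import AudioTokenizer, TextTokenizer, tokenize_audio, tokenize_text

text_tokenizer = TextTokenizer(backend="espeak")
phonemes = tokenize_text(
    text_tokenizer, text="In addition, the proposed legislation will insure."
)
print(phonemes)  # the phoneme symbols this script stores under supervision.custom["tokens"]["text"]

audio_tokenizer = AudioTokenizer()
encoded_frames = tokenize_audio(
    audio_tokenizer, "egs/ljspeech/prompts/LJ049-0124_24K.wav"
)
codes = encoded_frames[0][0]  # EnCodec code indices; infer.py above treats this as (1, num_quantizers, T)
print(codes.shape)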
/valle/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .datamodule import *
2 | from .tokenizer import *
3 | from .collation import *
4 |
--------------------------------------------------------------------------------
/valle/data/collation.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Tuple
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from valle.utils import SymbolTable
8 |
9 |
10 | class TextTokenCollater:
11 | """Collate list of text tokens
12 |
13 | Map sentences to integers. Sentences are padded to equal length.
14 | Beginning and end-of-sequence symbols can be added.
15 |
16 | Example:
17 | >>> token_collater = TextTokenCollater(text_tokens)
18 | >>> tokens_batch, tokens_lens = token_collater(text)
19 |
20 | Returns:
21 | tokens_batch: IntTensor of shape (B, L)
22 | B: batch dimension, number of input sentences
23 | L: length of the longest sentence
24 | tokens_lens: IntTensor of shape (B,)
25 |             Length of each sentence after adding <bos> and <eos>
26 |             but before padding.
27 | """
28 |
29 | def __init__(
30 | self,
31 | text_tokens: List[str],
32 | add_eos: bool = True,
33 | add_bos: bool = True,
34 |         pad_symbol: str = "<pad>",
35 |         bos_symbol: str = "<bos>",
36 |         eos_symbol: str = "<eos>",
37 | ):
38 | self.pad_symbol = pad_symbol
39 |
40 | self.add_eos = add_eos
41 | self.add_bos = add_bos
42 |
43 | self.bos_symbol = bos_symbol
44 | self.eos_symbol = eos_symbol
45 |
46 | unique_tokens = (
47 | [pad_symbol]
48 | + ([bos_symbol] if add_bos else [])
49 | + ([eos_symbol] if add_eos else [])
50 | + sorted(text_tokens)
51 | )
52 |
53 | self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
54 | self.idx2token = [token for token in unique_tokens]
55 |
56 | def index(
57 | self, tokens_list: List[str]
58 | ) -> Tuple[torch.Tensor, torch.Tensor]:
59 | seqs, seq_lens = [], []
60 | for tokens in tokens_list:
61 |             assert all(
62 |                 s in self.token2idx for s in tokens
63 |             )
64 |
65 | seq = (
66 | ([self.bos_symbol] if self.add_bos else [])
67 | + list(tokens)
68 | + ([self.eos_symbol] if self.add_eos else [])
69 | )
70 | seqs.append(seq)
71 | seq_lens.append(len(seq))
72 |
73 | max_len = max(seq_lens)
74 | for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
75 | seq.extend([self.pad_symbol] * (max_len - seq_len))
76 |
77 | tokens = torch.from_numpy(
78 | np.array(
79 | [[self.token2idx[token] for token in seq] for seq in seqs],
80 | dtype=np.int64,
81 | )
82 | )
83 | tokens_lens = torch.IntTensor(seq_lens)
84 |
85 | return tokens, tokens_lens
86 |
87 | def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
88 | tokens_seqs = [[p for p in text] for text in texts]
89 | max_len = len(max(tokens_seqs, key=len))
90 |
91 | seqs = [
92 | ([self.bos_symbol] if self.add_bos else [])
93 | + list(seq)
94 | + ([self.eos_symbol] if self.add_eos else [])
95 | + [self.pad_symbol] * (max_len - len(seq))
96 | for seq in tokens_seqs
97 | ]
98 |
99 | tokens_batch = torch.from_numpy(
100 | np.array(
101 | [[self.token2idx[token] for token in seq] for seq in seqs],
102 | dtype=np.int64,
103 | )
104 | )
105 |
106 | tokens_lens = torch.IntTensor(
107 | [
108 | len(seq) + int(self.add_eos) + int(self.add_bos)
109 | for seq in tokens_seqs
110 | ]
111 | )
112 |
113 | return tokens_batch, tokens_lens
114 |
115 |
116 | def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
117 | text_tokens_path = Path(text_tokens_file)
118 | unique_tokens = SymbolTable.from_file(text_tokens_path)
119 | collater = TextTokenCollater(
120 | unique_tokens.symbols, add_bos=True, add_eos=True
121 | )
122 | return collater
123 |
--------------------------------------------------------------------------------
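Note: in the recipes the collater is built from data/tokenized/unique_text_tokens.k2symbols via get_text_token_collater(). A minimal sketch on a made-up three-symbol inventory, just to show the returned shapes:

from valle.data.collation import TextTokenCollater

collater = TextTokenCollater(["a", "b", "c"], add_bos=True, add_eos=True)
tokens_batch, tokens_lens = collater([["a", "b"], ["c", "a", "b"]])
print(tokens_batch.shape)  # torch.Size([2, 5]): <bos> + tokens + <eos>, padded to the longest
print(tokens_lens)         # tensor([4, 5], dtype=torch.int32)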
/valle/data/datamodule.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | import argparse
19 | import inspect
20 | import logging
21 | from functools import lru_cache
22 | from pathlib import Path
23 | from typing import Any, Dict, Optional
24 |
25 | import torch
26 | from icefall.utils import str2bool
27 | from lhotse import CutSet, load_manifest_lazy
28 | from lhotse.dataset import (
29 | CutConcatenate,
30 | DynamicBucketingSampler,
31 | PrecomputedFeatures,
32 | SimpleCutSampler,
33 | SpecAugment,
34 | )
35 | from lhotse.dataset.input_strategies import OnTheFlyFeatures
36 | from lhotse.utils import fix_random_seed
37 | from torch.utils.data import DataLoader
38 |
39 | from valle.data.collation import get_text_token_collater
40 | from valle.data.dataset import SpeechSynthesisDataset
41 | from valle.data.fbank import get_fbank_extractor
42 | from valle.data.input_strategies import PromptedPrecomputedFeatures
43 |
44 | PrecomputedFeatures = PrecomputedFeatures
45 |
46 |
47 | class _SeedWorkers:
48 | def __init__(self, seed: int):
49 | self.seed = seed
50 |
51 | def __call__(self, worker_id: int):
52 | fix_random_seed(self.seed + worker_id)
53 |
54 |
55 | def _get_input_strategy(input_strategy, dataset, cuts):
56 | if input_strategy == "PromptedPrecomputedFeatures":
57 | return PromptedPrecomputedFeatures(dataset, cuts)
58 |
59 | return eval(input_strategy)()
60 |
61 |
62 | class TtsDataModule:
63 | """
64 | DataModule for VALL-E TTS experiments.
65 | It assumes there is always one train and valid dataloader.
66 |
67 | It contains all the common data pipeline modules used in TTS
68 | experiments, e.g.:
69 | - dynamic batch size,
70 | - bucketing samplers,
71 |     - cut concatenation [not used or tested yet],
72 |     - augmentation [not used or tested yet],
73 |     - on-the-fly feature extraction [not used or tested yet]
74 |
75 | This class should be derived for specific corpora used in TTS tasks.
76 | """
77 |
78 | def __init__(self, args: argparse.Namespace):
79 | self.args = args
80 |
81 | @classmethod
82 | def add_arguments(cls, parser: argparse.ArgumentParser):
83 | group = parser.add_argument_group(
84 | title="TTS data related options",
85 | description="These options are used for the preparation of "
86 | "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
87 | "effective batch sizes, sampling strategies, applied data "
88 | "augmentations, etc.",
89 | )
90 | group.add_argument(
91 | "--manifest-dir",
92 | type=Path,
93 | default=Path("data/tokenized"),
94 | help="Path to directory with train/valid/test cuts.",
95 | )
96 | group.add_argument(
97 | "--max-duration",
98 | type=int,
99 | default=40.0,
100 | help="Maximum pooled recordings duration (seconds) in a "
101 | "single batch. You can reduce it if it causes CUDA OOM.",
102 | )
103 | group.add_argument(
104 | "--bucketing-sampler",
105 | type=str2bool,
106 | default=True,
107 | help="When enabled, the batches will come from buckets of "
108 | "similar duration (saves padding frames).",
109 | )
110 | group.add_argument(
111 | "--num-buckets",
112 | type=int,
113 | default=10,
114 |         help="The number of buckets for the DynamicBucketingSampler "
115 |         "(you might want to increase it for larger datasets).",
116 | )
117 | group.add_argument(
118 | "--concatenate-cuts",
119 | type=str2bool,
120 | default=False,
121 | help="When enabled, utterances (cuts) will be concatenated "
122 | "to minimize the amount of padding.",
123 | )
124 | group.add_argument(
125 | "--duration-factor",
126 | type=float,
127 | default=1.0,
128 | help="Determines the maximum duration of a concatenated cut "
129 | "relative to the duration of the longest cut in a batch.",
130 | )
131 | group.add_argument(
132 | "--gap",
133 | type=float,
134 | default=0.1,
135 | help="The amount of padding (in seconds) inserted between "
136 | "concatenated cuts. This padding is filled with noise when "
137 | "noise augmentation is used.",
138 | )
139 | group.add_argument(
140 | "--on-the-fly-feats",
141 | type=str2bool,
142 | default=False,
143 | help="When enabled, use on-the-fly cut mixing and feature "
144 | "extraction. Will drop existing precomputed feature manifests "
145 | "if available.",
146 | )
147 | group.add_argument(
148 | "--shuffle",
149 | type=str2bool,
150 | default=True,
151 | help="When enabled (=default), the examples will be "
152 | "shuffled for each epoch.",
153 | )
154 | group.add_argument(
155 | "--buffer-size",
156 | type=int,
157 | default=40000,
158 |         help="How many cuts (or cut pairs, triplets) we hold at any time across all of the buckets. "
159 |         "Increasing ``max_duration`` (batch_size) or ``num_buckets`` might require increasing this number. "
160 |         "It will result in larger memory usage.",
161 | )
162 | group.add_argument(
163 | "--shuffle-buffer-size",
164 | type=int,
165 | default=100000,
166 |         help="How many cuts (or cut pairs, triplets) are being held in memory in "
167 |         "a buffer used for streaming shuffling. A larger number means better randomness at the cost "
168 |         "of higher memory usage.",
169 | )
170 | group.add_argument(
171 | "--drop-last",
172 | type=str2bool,
173 | default=False,
174 | help="Whether to drop last batch. Used by sampler.",
175 | )
176 | group.add_argument(
177 | "--return-cuts",
178 | type=str2bool,
179 | default=True,
180 | help="When enabled, each batch will have the "
181 | "field: batch['supervisions']['cut'] with the cuts that "
182 | "were used to construct it.",
183 | )
184 |
185 | group.add_argument(
186 | "--num-workers",
187 | type=int,
188 | default=8,
189 | help="The number of training dataloader workers that "
190 | "collect the batches.",
191 | )
192 |
193 | group.add_argument(
194 | "--enable-spec-aug",
195 | type=str2bool,
196 | default=False,
197 | help="When enabled, use SpecAugment for training dataset.",
198 | )
199 |
200 | group.add_argument(
201 | "--spec-aug-time-warp-factor",
202 | type=int,
203 | default=80,
204 | help="Used only when --enable-spec-aug is True. "
205 | "It specifies the factor for time warping in SpecAugment. "
206 | "Larger values mean more warping. "
207 | "A value less than 1 means to disable time warp.",
208 | )
209 |
210 | group.add_argument(
211 | "--input-strategy",
212 | type=str,
213 | default="PrecomputedFeatures",
214 | help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures",
215 | )
216 |
217 | group.add_argument(
218 | "--dataset",
219 | type=str,
220 | default="libritts",
221 | help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
222 | )
223 |
224 | parser.add_argument(
225 | "--text-tokens",
226 | type=str,
227 | default="data/tokenized/unique_text_tokens.k2symbols",
228 | help="Path to the unique text tokens file",
229 | )
230 |
231 | parser.add_argument(
232 | "--sampling-rate",
233 | type=int,
234 | default=24000,
235 | help="""Audio sampling rate.""",
236 | )
237 |
238 | def train_dataloaders(
239 | self,
240 | cuts_train: CutSet,
241 | sampler_state_dict: Optional[Dict[str, Any]] = None,
242 | ) -> DataLoader:
243 | """
244 | Args:
245 | cuts_train:
246 | CutSet for training.
247 | sampler_state_dict:
248 | The state dict for the training sampler.
249 | """
250 | transforms = []
251 |
252 | if self.args.concatenate_cuts:
253 | logging.info(
254 | f"Using cut concatenation with duration factor "
255 | f"{self.args.duration_factor} and gap {self.args.gap}."
256 | )
257 | # Cut concatenation should be the first transform in the list,
258 | # so that if we e.g. mix noise in, it will fill the gaps between
259 | # different utterances.
260 | transforms = [
261 | CutConcatenate(
262 | duration_factor=self.args.duration_factor, gap=self.args.gap
263 | )
264 | ] + transforms
265 |
266 | input_transforms = []
267 | if self.args.enable_spec_aug:
268 | logging.info("Enable SpecAugment")
269 | logging.info(
270 | f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
271 | )
272 | # Set the value of num_frame_masks according to Lhotse's version.
273 | # In different Lhotse's versions, the default of num_frame_masks is
274 | # different.
275 | num_frame_masks = 10
276 | num_frame_masks_parameter = inspect.signature(
277 | SpecAugment.__init__
278 | ).parameters["num_frame_masks"]
279 | if num_frame_masks_parameter.default == 1:
280 | num_frame_masks = 2
281 | logging.info(f"Num frame mask: {num_frame_masks}")
282 | input_transforms.append(
283 | SpecAugment(
284 | time_warp_factor=self.args.spec_aug_time_warp_factor,
285 | num_frame_masks=num_frame_masks,
286 | features_mask_size=27,
287 | num_feature_masks=2,
288 | frames_mask_size=100,
289 | )
290 | )
291 | else:
292 | logging.info("Disable SpecAugment")
293 |
294 | logging.info("About to create train dataset")
295 | if self.args.on_the_fly_feats:
296 | # NOTE: the PerturbSpeed transform should be added only if we
297 | # remove it from data prep stage.
298 | # Add on-the-fly speed perturbation; since originally it would
299 | # have increased epoch size by 3, we will apply prob 2/3 and use
300 | # 3x more epochs.
301 | # Speed perturbation probably should come first before
302 | # concatenation, but in principle the transforms order doesn't have
303 | # to be strict (e.g. could be randomized)
304 | # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
305 | # Drop feats to be on the safe side.
306 | train = SpeechSynthesisDataset(
307 | get_text_token_collater(self.args.text_tokens),
308 | cut_transforms=transforms,
309 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
310 | feature_transforms=input_transforms,
311 | )
312 | else:
313 | train = SpeechSynthesisDataset(
314 | get_text_token_collater(self.args.text_tokens),
315 | feature_input_strategy=_get_input_strategy(
316 | self.args.input_strategy, self.args.dataset, cuts_train
317 | ),
318 | cut_transforms=transforms,
319 | feature_transforms=input_transforms,
320 | )
321 |
322 | if self.args.bucketing_sampler:
323 | logging.info("Using DynamicBucketingSampler")
324 | train_sampler = DynamicBucketingSampler(
325 | cuts_train,
326 | max_duration=self.args.max_duration,
327 | shuffle=self.args.shuffle,
328 | buffer_size=self.args.buffer_size,
329 | shuffle_buffer_size=self.args.shuffle_buffer_size,
330 | quadratic_duration=10,
331 | num_cuts_for_bins_estimate=10000,
332 | drop_last=True,
333 | )
334 | else:
335 | logging.info(
336 |                 "Using SimpleCutSampler and sorting by duration (ascending)."
337 | )
338 | cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
339 | train_sampler = SimpleCutSampler(
340 | cuts_train,
341 | max_duration=self.args.max_duration,
342 | shuffle=self.args.shuffle,
343 | )
344 | logging.info("About to create train dataloader")
345 |
346 | if sampler_state_dict is not None:
347 | logging.info("Loading sampler state dict")
348 | train_sampler.load_state_dict(sampler_state_dict)
349 |
350 | # 'seed' is derived from the current random state, which will have
351 | # previously been set in the main process.
352 | seed = torch.randint(0, 100000, ()).item()
353 | worker_init_fn = _SeedWorkers(seed)
354 |
355 | train_dl = DataLoader(
356 | train,
357 | sampler=train_sampler,
358 | batch_size=None,
359 | num_workers=self.args.num_workers,
360 | persistent_workers=False,
361 | worker_init_fn=worker_init_fn,
362 | )
363 |
364 | return train_dl
365 |
366 | def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
367 | logging.info("About to create dev dataset")
368 | if self.args.on_the_fly_feats:
369 | validate = SpeechSynthesisDataset(
370 | get_text_token_collater(self.args.text_tokens),
371 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
372 | cut_transforms=[],
373 | )
374 | else:
375 | validate = SpeechSynthesisDataset(
376 | get_text_token_collater(self.args.text_tokens),
377 | feature_input_strategy=_get_input_strategy(
378 | self.args.input_strategy, self.args.dataset, cuts_valid
379 | ),
380 | cut_transforms=[],
381 | )
382 | valid_sampler = DynamicBucketingSampler(
383 | cuts_valid,
384 | max_duration=self.args.max_duration,
385 | shuffle=False,
386 | drop_last=True,
387 | )
388 | logging.info("About to create dev dataloader")
389 | valid_dl = DataLoader(
390 | validate,
391 | sampler=valid_sampler,
392 | batch_size=None,
393 | num_workers=4,
394 | persistent_workers=False,
395 | )
396 |
397 | return valid_dl
398 |
399 | def test_dataloaders(self, cuts: CutSet) -> DataLoader:
400 | logging.debug("About to create test dataset")
401 | test = SpeechSynthesisDataset(
402 | get_text_token_collater(self.args.text_tokens),
403 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
404 | if self.args.on_the_fly_feats
405 | else _get_input_strategy(
406 | self.args.input_strategy, self.args.dataset, cuts
407 | ),
408 | cut_transforms=[],
409 | )
410 | sampler = DynamicBucketingSampler(
411 | cuts,
412 | max_duration=self.args.max_duration,
413 | shuffle=False,
414 | drop_last=True,
415 | )
416 | logging.debug("About to create test dataloader")
417 | test_dl = DataLoader(
418 | test,
419 | batch_size=None,
420 | sampler=sampler,
421 | num_workers=self.args.num_workers,
422 | )
423 | return test_dl
424 |
425 | @lru_cache()
426 | def train_cuts(self) -> CutSet:
427 | logging.info("About to get train cuts")
428 | return load_manifest_lazy(
429 | self.args.manifest_dir / "cuts_train.jsonl.gz"
430 | )
431 |
432 | @lru_cache()
433 | def dev_cuts(self) -> CutSet:
434 | logging.info("About to get dev cuts")
435 | return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
436 |
437 | @lru_cache()
438 | def test_cuts(self) -> CutSet:
439 | logging.info("About to get test cuts")
440 | return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
441 |
--------------------------------------------------------------------------------
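Note: a minimal sketch of wiring the datamodule up outside bin/trainer.py, assuming a prepare.sh run has already populated data/tokenized (where the default --manifest-dir and --text-tokens point); the --max-duration and --num-workers values below are arbitrary:

import argparse

from valle.data.datamodule import TtsDataModule

parser = argparse.ArgumentParser()
TtsDataModule.add_arguments(parser)
args = parser.parse_args(["--max-duration", "80", "--num-workers", "2"])

dm = TtsDataModule(args)
train_dl = dm.train_dataloaders(dm.train_cuts())
valid_dl = dm.valid_dataloaders(dm.dev_cuts())

batch = next(iter(train_dl))
print(batch["text_tokens"].shape, batch["audio_features"].shape)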
/valle/data/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """
18 | modified from lhotse.dataset.speech_synthesis.py
19 | """
20 |
21 | from typing import Callable, Dict, List, Sequence, Union
22 |
23 | import torch
24 | from lhotse import validate
25 | from lhotse.cut import CutSet
26 | from lhotse.dataset.collation import collate_audio
27 | from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
28 | from lhotse.utils import ifnone
29 |
30 | from valle.data.collation import TextTokenCollater
31 |
32 |
33 | class SpeechSynthesisDataset(torch.utils.data.Dataset):
34 | """
35 |     The PyTorch Dataset for the speech synthesis (e.g. TTS) task.
36 | Each item in this dataset is a dict of:
37 |
38 | .. code-block::
39 |
40 | {
41 | 'audio': (B x NumSamples) float tensor
42 | 'audio_lens': (B, ) int tensor
43 | 'text': str
44 | 'audio_features': (B x NumFrames x NumFeatures) float tensor
45 | 'audio_features_lens': (B, ) int tensor
46 | 'text_tokens': (B x NumTextTokens) long tensor
47 | 'text_tokens_lens': (B, ) int tensor
48 | }
49 | """
50 |
51 | def __init__(
52 | self,
53 | text_token_collater: TextTokenCollater,
54 | cut_transforms: List[Callable[[CutSet], CutSet]] = None,
55 | feature_input_strategy: BatchIO = PrecomputedFeatures(),
56 | feature_transforms: Union[Sequence[Callable], Callable] = None,
57 | ) -> None:
58 | super().__init__()
59 |
60 | self.text_token_collater = text_token_collater
61 | self.cut_transforms = ifnone(cut_transforms, [])
62 | self.feature_input_strategy = feature_input_strategy
63 |
64 | if feature_transforms is None:
65 | feature_transforms = []
66 | elif not isinstance(feature_transforms, Sequence):
67 | feature_transforms = [feature_transforms]
68 |
69 | assert all(
70 | isinstance(transform, Callable) for transform in feature_transforms
71 | ), "Feature transforms must be Callable"
72 | self.feature_transforms = feature_transforms
73 |
74 | def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
75 | validate_for_tts(cuts)
76 |
77 | for transform in self.cut_transforms:
78 | cuts = transform(cuts)
79 |
80 | if False: # not used
81 | audio, audio_lens = collate_audio(cuts)
82 | else: # for sharing tokenized features in different machines
83 | audio, audio_lens = None, None
84 |
85 | audio_features, audio_features_lens = self.feature_input_strategy(cuts)
86 |
87 | for transform in self.feature_transforms:
88 | audio_features = transform(audio_features)
89 |
90 | text_tokens, text_tokens_lens = self.text_token_collater(
91 | [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
92 | )
93 |
94 | return {
95 | "utt_id": [cut.id for cut in cuts],
96 | "text": [cut.supervisions[0].text for cut in cuts],
97 | "audio": audio,
98 | "audio_lens": audio_lens,
99 | "audio_features": audio_features,
100 | "audio_features_lens": audio_features_lens,
101 | "text_tokens": text_tokens,
102 | "text_tokens_lens": text_tokens_lens,
103 | }
104 |
105 |
106 | def validate_for_tts(cuts: CutSet) -> None:
107 | validate(cuts)
108 | for cut in cuts:
109 | assert (
110 | len(cut.supervisions) == 1
111 | ), "Only the Cuts with single supervision are supported."
112 |
--------------------------------------------------------------------------------
/valle/data/fbank.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | from dataclasses import asdict, dataclass
19 | from typing import Any, Dict, Optional, Union
20 |
21 | import numpy as np
22 | import torch
23 | from lhotse.features.base import FeatureExtractor
24 | from lhotse.utils import EPSILON, Seconds, compute_num_frames
25 | from librosa.filters import mel as librosa_mel_fn
26 |
27 |
28 | @dataclass
29 | class BigVGANFbankConfig:
30 |     # Spectrogram-related part
31 | # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
32 | frame_length: Seconds = 1024 / 24000.0
33 | frame_shift: Seconds = 256 / 24000.0
34 | remove_dc_offset: bool = True
35 | round_to_power_of_two: bool = True
36 |
37 | # Fbank-related part
38 | low_freq: float = 0.0
39 | high_freq: float = 12000.0
40 | num_mel_bins: int = 100
41 | use_energy: bool = False
42 |
43 | def to_dict(self) -> Dict[str, Any]:
44 | return asdict(self)
45 |
46 | @staticmethod
47 | def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
48 | return BigVGANFbankConfig(**data)
49 |
50 |
51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
52 | return torch.log(torch.clamp(x, min=clip_val) * C)
53 |
54 |
55 | def spectral_normalize_torch(magnitudes):
56 | output = dynamic_range_compression_torch(magnitudes)
57 | return output
58 |
59 |
60 | # https://github.com/NVIDIA/BigVGAN
61 | # bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
62 | class BigVGANFbank(FeatureExtractor):
63 | name = "fbank"
64 | config_type = BigVGANFbankConfig
65 |
66 | def __init__(self, config: Optional[Any] = None):
67 | super(BigVGANFbank, self).__init__(config)
68 | sampling_rate = 24000
69 | self.mel_basis = torch.from_numpy(
70 | librosa_mel_fn(
71 | sampling_rate,
72 | 1024,
73 | self.config.num_mel_bins,
74 | self.config.low_freq,
75 | self.config.high_freq,
76 | ).astype(np.float32)
77 | )
78 | self.hann_window = torch.hann_window(1024)
79 |
80 | def _feature_fn(self, samples, **kwargs):
81 | win_length, n_fft = 1024, 1024
82 | hop_size = 256
83 | if True:
84 | sampling_rate = 24000
85 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
86 | expected_num_frames = compute_num_frames(
87 | duration=duration,
88 | frame_shift=self.frame_shift,
89 | sampling_rate=sampling_rate,
90 | )
91 | pad_size = (
92 | (expected_num_frames - 1) * hop_size
93 | + win_length
94 | - samples.shape[-1]
95 | )
96 | assert pad_size >= 0
97 |
98 | y = torch.nn.functional.pad(
99 | samples,
100 | (0, pad_size),
101 | mode="constant",
102 | )
103 | else:
104 | y = torch.nn.functional.pad(
105 | samples,
106 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
107 | mode="reflect",
108 | )
109 |
110 | y = y.squeeze(1)
111 |
112 | # complex tensor as default, then use view_as_real for future pytorch compatibility
113 | spec = torch.stft(
114 | y,
115 | n_fft,
116 | hop_length=hop_size,
117 | win_length=win_length,
118 | window=self.hann_window,
119 | center=False,
120 | pad_mode="reflect",
121 | normalized=False,
122 | onesided=True,
123 | return_complex=True,
124 | )
125 | spec = torch.view_as_real(spec)
126 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
127 |
128 | spec = torch.matmul(self.mel_basis, spec)
129 | spec = spectral_normalize_torch(spec)
130 |
131 | return spec.transpose(2, 1).squeeze(0)
132 |
133 | def extract(
134 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
135 | ) -> np.ndarray:
136 | assert sampling_rate == 24000
137 | params = asdict(self.config)
138 | params.update({"sample_frequency": sampling_rate, "snip_edges": False})
139 | params["frame_shift"] *= 1000.0
140 | params["frame_length"] *= 1000.0
141 | if not isinstance(samples, torch.Tensor):
142 | samples = torch.from_numpy(samples)
143 | # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
144 | if len(samples.shape) == 1:
145 | samples = samples.unsqueeze(0)
146 | features = self._feature_fn(samples, **params).to(torch.float32)
147 | return features.numpy()
148 |
149 | @property
150 | def frame_shift(self) -> Seconds:
151 | return self.config.frame_shift
152 |
153 | def feature_dim(self, sampling_rate: int) -> int:
154 | return self.config.num_mel_bins
155 |
156 | @staticmethod
157 | def mix(
158 | features_a: np.ndarray,
159 | features_b: np.ndarray,
160 | energy_scaling_factor_b: float,
161 | ) -> np.ndarray:
162 | return np.log(
163 | np.maximum(
164 | # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
165 | EPSILON,
166 | np.exp(features_a)
167 | + energy_scaling_factor_b * np.exp(features_b),
168 | )
169 | )
170 |
171 | @staticmethod
172 | def compute_energy(features: np.ndarray) -> float:
173 | return float(np.sum(np.exp(features)))
174 |
175 |
176 | def get_fbank_extractor() -> BigVGANFbank:
177 | return BigVGANFbank(BigVGANFbankConfig())
178 |
179 |
180 | if __name__ == "__main__":
181 | extractor = BigVGANFbank(BigVGANFbankConfig())
182 |
183 | samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
184 | samples = torch.clip(samples, -1.0, 1.0)
185 | fbank = extractor.extract(samples, 24000.0)
186 | print(f"fbank {fbank.shape}")
187 |
188 | from scipy.io.wavfile import read
189 |
190 | MAX_WAV_VALUE = 32768.0
191 |
192 | sampling_rate, samples = read(
193 | "egs/libritts/prompts/5639_40744_000000_000002.wav"
194 | )
195 | print(f"samples: [{samples.min()}, {samples.max()}]")
196 | fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
197 | print(f"fbank {fbank.shape}")
198 |
199 | import matplotlib.pyplot as plt
200 |
201 | _ = plt.figure(figsize=(18, 10))
202 | plt.imshow(
203 | X=fbank.transpose(1, 0),
204 | cmap=plt.get_cmap("jet"),
205 | aspect="auto",
206 | interpolation="nearest",
207 | )
208 | plt.gca().invert_yaxis()
209 | plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
210 | plt.close()
211 |
212 | print("fbank test PASS!")
213 |
--------------------------------------------------------------------------------
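The extractor above can also be exercised outside of the Lhotse pipeline. A minimal sketch, assuming the package is importable as `valle` (per setup.py); the waveform is random data purely for illustration:

    import torch

    from valle.data.fbank import get_fbank_extractor

    extractor = get_fbank_extractor()

    # One second of fake 24 kHz audio in [-1.0, 1.0].
    samples = torch.clip(torch.randn(24000), -1.0, 1.0)
    features = extractor.extract(samples, sampling_rate=24000)
    # float32 array of shape (num_frames, num_mel_bins); 100 bins for the
    # bigvgan_24khz_100band setup referenced above.
    print(features.shape)
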
/valle/data/input_strategies.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 | from concurrent.futures import ThreadPoolExecutor
4 | from typing import Tuple, Type
5 |
6 | from lhotse import CutSet
7 | from lhotse.dataset.collation import collate_features
8 | from lhotse.dataset.input_strategies import (
9 | ExecutorType,
10 | PrecomputedFeatures,
11 | _get_executor,
12 | )
13 | from lhotse.utils import fastcopy
14 |
15 |
16 | class PromptedFeatures:
17 | def __init__(self, prompts, features):
18 | self.prompts = prompts
19 | self.features = features
20 |
21 | def to(self, device):
22 | return PromptedFeatures(
23 | self.prompts.to(device), self.features.to(device)
24 | )
25 |
26 | def sum(self):
27 | return self.features.sum()
28 |
29 | @property
30 | def ndim(self):
31 | return self.features.ndim
32 |
33 | @property
34 | def data(self):
35 | return (self.prompts, self.features)
36 |
37 |
38 | class PromptedPrecomputedFeatures(PrecomputedFeatures):
39 | """
40 | :class:`InputStrategy` that reads pre-computed features, whose manifests
41 | are attached to cuts, from disk.
42 |
43 |     Besides the features of the target cut, it also loads the features of a neighbouring (previous or next) cut, which serve as the acoustic prompt.
44 |
45 | .. automethod:: __call__
46 | """
47 |
48 | def __init__(
49 | self,
50 | dataset: str,
51 | cuts: CutSet,
52 | num_workers: int = 0,
53 | executor_type: Type[ExecutorType] = ThreadPoolExecutor,
54 | ) -> None:
55 | super(PromptedPrecomputedFeatures, self).__init__(
56 | num_workers, executor_type
57 | )
58 |
59 | self.utt2neighbors = defaultdict(lambda: [])
60 |
61 | if dataset.lower() == "libritts":
62 | # 909_131041_000013_000002
63 | # 909_131041_000013_000003
64 | speaker2utts = defaultdict(lambda: [])
65 |
66 | utt2cut = {}
67 | for cut in cuts:
68 | speaker = cut.supervisions[0].speaker
69 | speaker2utts[speaker].append(cut.id)
70 | utt2cut[cut.id] = cut
71 |
72 | for spk in speaker2utts:
73 | uttids = sorted(speaker2utts[spk])
74 |                 # Use the ordering of the sorted keys to find the previous utterance.
75 |                 # The keys have the structure speaker_book_x_y, e.g. 1089_134691_000004_000001.
76 | if len(uttids) == 1:
77 | self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
78 | continue
79 |
80 | utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
81 | utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
82 |
83 | for utt in utt2prevutt:
84 | self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
85 |
86 | for utt in utt2postutt:
87 | self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
88 | elif dataset.lower() == "ljspeech":
89 | utt2cut = {}
90 | uttids = []
91 | for cut in cuts:
92 | uttids.append(cut.id)
93 | utt2cut[cut.id] = cut
94 |
95 | if len(uttids) == 1:
96 | self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
97 | else:
98 |                 # Use the ordering of the sorted keys to find the previous utterance.
99 |                 # The keys have the structure LJ001-0010.
100 | utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
101 | utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
102 |
103 | for utt in utt2postutt:
104 | postutt = utt2postutt[utt]
105 | if utt[:5] == postutt[:5]:
106 | self.utt2neighbors[utt].append(utt2cut[postutt])
107 |
108 | for utt in utt2prevutt:
109 | prevutt = utt2prevutt[utt]
110 | if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
111 | self.utt2neighbors[utt].append(utt2cut[prevutt])
112 | else:
113 | raise ValueError
114 |
115 | def __call__(
116 | self, cuts: CutSet
117 | ) -> Tuple[PromptedFeatures, PromptedFeatures]:
118 | """
119 | Reads the pre-computed features from disk/other storage.
120 |         The returned shape is ``(B, T, F) => (batch_size, num_frames, num_features)``.
121 |
122 | :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
123 | """
124 | features, features_lens = collate_features(
125 | cuts,
126 | executor=_get_executor(
127 | self.num_workers, executor_type=self._executor_type
128 | ),
129 | )
130 |
131 | prompts_cuts = []
132 | for k, cut in enumerate(cuts):
133 | prompts_cut = random.choice(self.utt2neighbors[cut.id])
134 | prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
135 |
136 | mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
137 | # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
138 | # max_duration=mini_duration,
139 | # offset_type="random",
140 | # preserve_id=True,
141 | # )
142 | prompts_cuts = CutSet(
143 | cuts={k: cut for k, cut in enumerate(prompts_cuts)}
144 | ).truncate(
145 | max_duration=mini_duration,
146 | offset_type="random",
147 | preserve_id=False,
148 | )
149 |
150 | prompts, prompts_lens = collate_features(
151 | prompts_cuts,
152 | executor=_get_executor(
153 | self.num_workers, executor_type=self._executor_type
154 | ),
155 | )
156 |
157 | return PromptedFeatures(prompts, features), PromptedFeatures(
158 | prompts_lens, features_lens
159 | )
160 |
--------------------------------------------------------------------------------
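A rough sketch of driving this strategy directly, assuming a CutSet whose cuts already carry precomputed fbank features and LibriTTS-style supervisions (the manifest path below is hypothetical):

    from lhotse import CutSet

    from valle.data.input_strategies import PromptedPrecomputedFeatures

    cuts = CutSet.from_file("data/fbank/cuts_dev.jsonl.gz")  # hypothetical manifest
    strategy = PromptedPrecomputedFeatures("libritts", cuts, num_workers=2)

    batch_cuts = cuts.subset(first=4)
    feats, lens = strategy(batch_cuts)
    prompt_feats, target_feats = feats.data  # (B, T_prompt, F), (B, T, F)
    prompt_lens, target_lens = lens.data
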
/valle/data/tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import re
17 | from dataclasses import asdict, dataclass
18 | from typing import Any, Dict, List, Optional, Pattern, Union
19 |
20 | import numpy as np
21 | import torch
22 | import torchaudio
23 | from encodec import EncodecModel
24 | from encodec.utils import convert_audio
25 | from lhotse.features import FeatureExtractor
26 | from lhotse.utils import Seconds, compute_num_frames
27 | from phonemizer.backend import EspeakBackend
28 | from phonemizer.backend.espeak.language_switch import LanguageSwitch
29 | from phonemizer.backend.espeak.words_mismatch import WordMismatch
30 | from phonemizer.punctuation import Punctuation
31 | from phonemizer.separator import Separator
32 |
33 | try:
34 | from pypinyin import Style, pinyin
35 | from pypinyin.style._utils import get_finals, get_initials
36 | except Exception:
37 | pass
38 |
39 |
40 | class PypinyinBackend:
41 | """PypinyinBackend for Chinese. Most codes is referenced from espnet.
42 | There are two types pinyin or initials_finals, one is
43 | just like "ni1 hao3", the other is like "n i1 h ao3".
44 | """
45 |
46 | def __init__(
47 | self,
48 | backend="initials_finals",
49 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
50 | ) -> None:
51 | self.backend = backend
52 | self.punctuation_marks = punctuation_marks
53 |
54 | def phonemize(
55 | self, text: List[str], separator: Separator, strip=True, njobs=1
56 | ) -> List[str]:
57 | assert isinstance(text, List)
58 | phonemized = []
59 | for _text in text:
60 | _text = re.sub(" +", " ", _text.strip())
61 | _text = _text.replace(" ", separator.word)
62 | phones = []
63 | if self.backend == "pypinyin":
64 | for n, py in enumerate(
65 | pinyin(
66 | _text, style=Style.TONE3, neutral_tone_with_five=True
67 | )
68 | ):
69 | if all([c in self.punctuation_marks for c in py[0]]):
70 | if len(phones):
71 | assert phones[-1] == separator.syllable
72 | phones.pop(-1)
73 |
74 | phones.extend(list(py[0]))
75 | else:
76 | phones.extend([py[0], separator.syllable])
77 | elif self.backend == "pypinyin_initials_finals":
78 | for n, py in enumerate(
79 | pinyin(
80 | _text, style=Style.TONE3, neutral_tone_with_five=True
81 | )
82 | ):
83 | if all([c in self.punctuation_marks for c in py[0]]):
84 | if len(phones):
85 | assert phones[-1] == separator.syllable
86 | phones.pop(-1)
87 | phones.extend(list(py[0]))
88 | else:
89 | if py[0][-1].isalnum():
90 | initial = get_initials(py[0], strict=False)
91 | if py[0][-1].isdigit():
92 | final = (
93 | get_finals(py[0][:-1], strict=False)
94 | + py[0][-1]
95 | )
96 | else:
97 | final = get_finals(py[0], strict=False)
98 | phones.extend(
99 | [
100 | initial,
101 | separator.phone,
102 | final,
103 | separator.syllable,
104 | ]
105 | )
106 | else:
107 |                             raise ValueError(py[0])
108 | else:
109 | raise NotImplementedError
110 | phonemized.append(
111 | "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
112 | )
113 | return phonemized
114 |
115 |
116 | class TextTokenizer:
117 | """Phonemize Text."""
118 |
119 | def __init__(
120 | self,
121 | language="en-us",
122 | backend="espeak",
123 | separator=Separator(word="_", syllable="-", phone="|"),
124 | preserve_punctuation=True,
125 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
126 | with_stress: bool = False,
127 | tie: Union[bool, str] = False,
128 | language_switch: LanguageSwitch = "keep-flags",
129 | words_mismatch: WordMismatch = "ignore",
130 | ) -> None:
131 | if backend == "espeak":
132 | phonemizer = EspeakBackend(
133 | language,
134 | punctuation_marks=punctuation_marks,
135 | preserve_punctuation=preserve_punctuation,
136 | with_stress=with_stress,
137 | tie=tie,
138 | language_switch=language_switch,
139 | words_mismatch=words_mismatch,
140 | )
141 | elif backend in ["pypinyin", "pypinyin_initials_finals"]:
142 | phonemizer = PypinyinBackend(
143 | backend=backend,
144 | punctuation_marks=punctuation_marks + separator.word,
145 | )
146 | else:
147 | raise NotImplementedError(f"{backend}")
148 |
149 | self.backend = phonemizer
150 | self.separator = separator
151 |
152 | def to_list(self, phonemized: str) -> List[str]:
153 | fields = []
154 | for word in phonemized.split(self.separator.word):
155 | # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
156 | pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
157 | fields.extend(
158 | [p for p in pp if p != self.separator.phone]
159 | + [self.separator.word]
160 | )
161 | assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
162 | self.separator.phone
163 | )
164 | return fields[:-1]
165 |
166 | def __call__(self, text, strip=True) -> List[List[str]]:
167 | if isinstance(text, str):
168 | text = [text]
169 |
170 | phonemized = self.backend.phonemize(
171 | text, separator=self.separator, strip=strip, njobs=1
172 | )
173 | return [self.to_list(p) for p in phonemized]
174 |
175 |
176 | def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
177 | phonemes = tokenizer([text.strip()])
178 | return phonemes[0] # k2symbols
179 |
180 |
181 | def remove_encodec_weight_norm(model):
182 | from encodec.modules import SConv1d
183 | from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
184 | from torch.nn.utils import remove_weight_norm
185 |
186 | encoder = model.encoder.model
187 | for key in encoder._modules:
188 | if isinstance(encoder._modules[key], SEANetResnetBlock):
189 | remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
190 | block_modules = encoder._modules[key].block._modules
191 | for skey in block_modules:
192 | if isinstance(block_modules[skey], SConv1d):
193 | remove_weight_norm(block_modules[skey].conv.conv)
194 | elif isinstance(encoder._modules[key], SConv1d):
195 | remove_weight_norm(encoder._modules[key].conv.conv)
196 |
197 | decoder = model.decoder.model
198 | for key in decoder._modules:
199 | if isinstance(decoder._modules[key], SEANetResnetBlock):
200 | remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
201 | block_modules = decoder._modules[key].block._modules
202 | for skey in block_modules:
203 | if isinstance(block_modules[skey], SConv1d):
204 | remove_weight_norm(block_modules[skey].conv.conv)
205 | elif isinstance(decoder._modules[key], SConvTranspose1d):
206 | remove_weight_norm(decoder._modules[key].convtr.convtr)
207 | elif isinstance(decoder._modules[key], SConv1d):
208 | remove_weight_norm(decoder._modules[key].conv.conv)
209 |
210 |
211 | class AudioTokenizer:
212 | """EnCodec audio."""
213 |
214 | def __init__(
215 | self,
216 | device: Any = None,
217 | ) -> None:
218 | # Instantiate a pretrained EnCodec model
219 | model = EncodecModel.encodec_model_24khz()
220 | model.set_target_bandwidth(6.0)
221 | remove_encodec_weight_norm(model)
222 |
223 | if not device:
224 | device = torch.device("cpu")
225 | if torch.cuda.is_available():
226 | device = torch.device("cuda:0")
227 |
228 | self._device = device
229 |
230 | self.codec = model.to(device)
231 | self.sample_rate = model.sample_rate
232 | self.channels = model.channels
233 |
234 | @property
235 | def device(self):
236 | return self._device
237 |
238 | def encode(self, wav: torch.Tensor) -> torch.Tensor:
239 | return self.codec.encode(wav.to(self.device))
240 |
241 | def decode(self, frames: torch.Tensor) -> torch.Tensor:
242 | return self.codec.decode(frames)
243 |
244 |
245 | def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str):
246 | # Load and pre-process the audio waveform
247 | wav, sr = torchaudio.load(audio_path)
248 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
249 | wav = wav.unsqueeze(0)
250 |
251 | # Extract discrete codes from EnCodec
252 | with torch.no_grad():
253 | encoded_frames = tokenizer.encode(wav)
254 | return encoded_frames
255 |
256 |
257 | @dataclass
258 | class AudioTokenConfig:
259 | frame_shift: Seconds = 320.0 / 24000
260 | num_quantizers: int = 8
261 |
262 | def to_dict(self) -> Dict[str, Any]:
263 | return asdict(self)
264 |
265 | @staticmethod
266 | def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
267 | return AudioTokenConfig(**data)
268 |
269 |
270 | class AudioTokenExtractor(FeatureExtractor):
271 | name = "encodec"
272 | config_type = AudioTokenConfig
273 |
274 | def __init__(self, config: Optional[Any] = None):
275 | super(AudioTokenExtractor, self).__init__(config)
276 | self.tokenizer = AudioTokenizer()
277 |
278 | def extract(
279 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
280 | ) -> np.ndarray:
281 | if not isinstance(samples, torch.Tensor):
282 | samples = torch.from_numpy(samples)
283 | if sampling_rate != self.tokenizer.sample_rate:
284 | samples = convert_audio(
285 | samples,
286 | sampling_rate,
287 | self.tokenizer.sample_rate,
288 | self.tokenizer.channels,
289 | )
290 | if len(samples.shape) == 2:
291 | samples = samples.unsqueeze(0)
292 | else:
293 | raise ValueError()
294 |
295 | device = self.tokenizer.device
296 | encoded_frames = self.tokenizer.encode(samples.detach().to(device))
297 | codes = encoded_frames[0][0] # [B, n_q, T]
298 | if True:
299 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
300 | expected_num_frames = compute_num_frames(
301 | duration=duration,
302 | frame_shift=self.frame_shift,
303 | sampling_rate=sampling_rate,
304 | )
305 | assert abs(codes.shape[-1] - expected_num_frames) <= 1
306 | codes = codes[..., :expected_num_frames]
307 | return codes.cpu().squeeze(0).permute(1, 0).numpy()
308 |
309 | @property
310 | def frame_shift(self) -> Seconds:
311 | return self.config.frame_shift
312 |
313 | def feature_dim(self, sampling_rate: int) -> int:
314 | return self.config.num_quantizers
315 |
316 | def pad_tensor_list(self, tensor_list, device, padding_value=0):
317 |         # Compute the length of each tensor in the list
318 | lengths = [tensor.shape[0] for tensor in tensor_list]
319 |         # Pad to the longest tensor with pad_sequence
320 | tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
321 | padded_tensor = torch.nn.utils.rnn.pad_sequence(
322 | tensor_list, batch_first=True, padding_value=padding_value
323 | )
324 | return padded_tensor, lengths
325 |
326 | def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
327 | samples = [wav.squeeze() for wav in samples]
328 | device = self.tokenizer.device
329 | samples, lengths = self.pad_tensor_list(samples, device)
330 | samples = samples.unsqueeze(1)
331 |
332 | if not isinstance(samples, torch.Tensor):
333 | samples = torch.from_numpy(samples)
334 | if len(samples.shape) != 3:
335 | raise ValueError()
336 | if sampling_rate != self.tokenizer.sample_rate:
337 | samples = [
338 | convert_audio(
339 | wav,
340 | sampling_rate,
341 | self.tokenizer.sample_rate,
342 | self.tokenizer.channels,
343 | )
344 | for wav in samples
345 | ]
346 | samples = torch.stack(samples, 0) # convert samples from list to tensor
347 | # Extract discrete codes from EnCodec
348 | with torch.no_grad():
349 | encoded_frames = self.tokenizer.encode(samples.detach().to(device))
350 | encoded_frames = encoded_frames[0][0] # [B, n_q, T]
351 | batch_codes = []
352 | for b, length in enumerate(lengths):
353 | codes = encoded_frames[b]
354 | duration = round(length / sampling_rate, ndigits=12)
355 | expected_num_frames = compute_num_frames(
356 | duration=duration,
357 | frame_shift=self.frame_shift,
358 | sampling_rate=sampling_rate,
359 | )
360 | batch_codes.append(codes[..., :expected_num_frames])
361 | return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
362 |
363 |
364 | if __name__ == "__main__":
365 | model = EncodecModel.encodec_model_24khz()
366 | model.set_target_bandwidth(6.0)
367 |
368 | samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
369 | torch.float32
370 | )
371 | codes_raw = model.encode(samples)
372 |
373 | remove_encodec_weight_norm(model)
374 | codes_norm = model.encode(samples)
375 |
376 | assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
377 |
--------------------------------------------------------------------------------
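A minimal end-to-end sketch of the two tokenizers defined above; it assumes espeak-ng and the pretrained EnCodec weights are available locally, and the prompt wav path is hypothetical:

    from valle.data.tokenizer import (
        AudioTokenizer,
        TextTokenizer,
        tokenize_audio,
        tokenize_text,
    )

    text_tokenizer = TextTokenizer(backend="espeak")
    phonemes = tokenize_text(text_tokenizer, text="VALL-E in one line.")
    print(phonemes)  # phoneme symbols, with "_" marking word boundaries

    audio_tokenizer = AudioTokenizer()
    encoded_frames = tokenize_audio(audio_tokenizer, "prompt.wav")  # hypothetical file
    codes = encoded_frames[0][0]  # (B, n_q, T) EnCodec codebook indices
    print(codes.shape)
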
/valle/models/__init__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch.nn as nn
4 | from icefall.utils import AttributeDict, str2bool
5 |
6 | from .macros import (
7 | NUM_AUDIO_TOKENS,
8 | NUM_MEL_BINS,
9 | NUM_SPEAKER_CLASSES,
10 | NUM_TEXT_TOKENS,
11 | SPEAKER_EMBEDDING_DIM,
12 | )
13 | from .transformer import Transformer
14 | from .valle import VALLE, VALLF
15 | from .visualizer import visualize
16 |
17 |
18 | def add_model_arguments(parser: argparse.ArgumentParser):
19 | parser.add_argument(
20 | "--model-name",
21 | type=str,
22 | default="VALL-E",
23 | help="VALL-E, VALL-F, Transformer.",
24 | )
25 | parser.add_argument(
26 | "--decoder-dim",
27 | type=int,
28 | default=1024,
29 | help="Embedding dimension in the decoder model.",
30 | )
31 | parser.add_argument(
32 | "--nhead",
33 | type=int,
34 | default=16,
35 | help="Number of attention heads in the Decoder layers.",
36 | )
37 | parser.add_argument(
38 | "--num-decoder-layers",
39 | type=int,
40 | default=12,
41 | help="Number of Decoder layers.",
42 | )
43 | parser.add_argument(
44 | "--scale-factor",
45 | type=float,
46 | default=1.0,
47 | help="Model scale factor which will be assigned different meanings in different models.",
48 | )
49 | parser.add_argument(
50 | "--norm-first",
51 | type=str2bool,
52 | default=True,
53 | help="Pre or Post Normalization.",
54 | )
55 | parser.add_argument(
56 | "--add-prenet",
57 | type=str2bool,
58 | default=False,
59 | help="Whether add PreNet after Inputs.",
60 | )
61 |
62 | # VALL-E & F
63 | parser.add_argument(
64 | "--prefix-mode",
65 | type=int,
66 | default=0,
67 | help="The mode for how to prefix VALL-E NAR Decoder, "
68 | "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
69 | )
70 | parser.add_argument(
71 | "--share-embedding",
72 | type=str2bool,
73 | default=True,
74 | help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
75 | )
76 | parser.add_argument(
77 | "--prepend-bos",
78 | type=str2bool,
79 | default=False,
80 | help="Whether prepend to the acoustic tokens -> AR Decoder inputs.",
81 | )
82 | parser.add_argument(
83 | "--num-quantizers",
84 | type=int,
85 | default=8,
86 | help="Number of Audio/Semantic quantization layers.",
87 | )
88 |
89 | # Transformer
90 | parser.add_argument(
91 | "--scaling-xformers",
92 | type=str2bool,
93 | default=False,
94 | help="Apply Reworked Conformer scaling on Transformers.",
95 | )
96 |
97 |
98 | def get_model(params: AttributeDict) -> nn.Module:
99 | if params.model_name.lower() in ["vall-f", "vallf"]:
100 | model = VALLF(
101 | params.decoder_dim,
102 | params.nhead,
103 | params.num_decoder_layers,
104 | norm_first=params.norm_first,
105 | add_prenet=params.add_prenet,
106 | prefix_mode=params.prefix_mode,
107 | share_embedding=params.share_embedding,
108 | nar_scale_factor=params.scale_factor,
109 | prepend_bos=params.prepend_bos,
110 | num_quantizers=params.num_quantizers,
111 | )
112 | elif params.model_name.lower() in ["vall-e", "valle"]:
113 | model = VALLE(
114 | params.decoder_dim,
115 | params.nhead,
116 | params.num_decoder_layers,
117 | norm_first=params.norm_first,
118 | add_prenet=params.add_prenet,
119 | prefix_mode=params.prefix_mode,
120 | share_embedding=params.share_embedding,
121 | nar_scale_factor=params.scale_factor,
122 | prepend_bos=params.prepend_bos,
123 | num_quantizers=params.num_quantizers,
124 | )
125 | else:
126 | assert params.model_name in ["Transformer"]
127 | model = Transformer(
128 | params.decoder_dim,
129 | params.nhead,
130 | params.num_decoder_layers,
131 | norm_first=params.norm_first,
132 | add_prenet=params.add_prenet,
133 | scaling_xformers=params.scaling_xformers,
134 | )
135 |
136 | return model
137 |
--------------------------------------------------------------------------------
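As a sketch, get_model() can be driven with an AttributeDict that mirrors the argparse defaults above (the values are just the defaults, not a tuned configuration):

    from icefall.utils import AttributeDict

    from valle.models import get_model

    params = AttributeDict(
        {
            "model_name": "VALL-E",
            "decoder_dim": 1024,
            "nhead": 16,
            "num_decoder_layers": 12,
            "scale_factor": 1.0,
            "norm_first": True,
            "add_prenet": False,
            "prefix_mode": 0,
            "share_embedding": True,
            "prepend_bos": False,
            "num_quantizers": 8,
            "scaling_xformers": False,
        }
    )
    model = get_model(params)
    print(sum(p.numel() for p in model.parameters()), "parameters")
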
/valle/models/macros.py:
--------------------------------------------------------------------------------
1 | # Text
2 | NUM_TEXT_TOKENS = 512
3 |
4 | # Audio
5 | NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins
6 | NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band
7 |
8 |
9 | # Speaker
10 | NUM_SPEAKER_CLASSES = 4096
11 | SPEAKER_EMBEDDING_DIM = 64
12 |
--------------------------------------------------------------------------------
/valle/models/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from functools import partial
16 | from typing import Any, Dict, List, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from icefall.utils import make_pad_mask
22 | from torchmetrics.classification import BinaryAccuracy
23 |
24 | from valle.models.valle import Transpose
25 | from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding
26 | from valle.modules.scaling import BalancedDoubleSwish, ScaledLinear
27 | from valle.modules.transformer import (
28 | BalancedBasicNorm,
29 | IdentityNorm,
30 | TransformerDecoderLayer,
31 | TransformerEncoder,
32 | TransformerEncoderLayer,
33 | )
34 |
35 | from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
36 | from .visualizer import visualize
37 |
38 | IdentityNorm = IdentityNorm
39 |
40 |
41 | class Transformer(nn.Module):
42 | """It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding)
43 | Neural Speech Synthesis with Transformer Network
44 | https://arxiv.org/abs/1809.08895
45 | """
46 |
47 | def __init__(
48 | self,
49 | d_model: int,
50 | nhead: int,
51 | num_layers: int,
52 | norm_first: bool = True,
53 | add_prenet: bool = False,
54 | scaling_xformers: bool = False,
55 | ):
56 | """
57 | Args:
58 | d_model:
59 | The number of expected features in the input (required).
60 | nhead:
61 | The number of heads in the multiheadattention models (required).
62 | num_layers:
63 | The number of sub-decoder-layers in the decoder (required).
64 | """
65 | super().__init__()
66 | self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
67 |
68 | if add_prenet:
69 | self.encoder_prenet = nn.Sequential(
70 | Transpose(),
71 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
72 | nn.BatchNorm1d(d_model),
73 | nn.ReLU(),
74 | nn.Dropout(0.5),
75 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
76 | nn.BatchNorm1d(d_model),
77 | nn.ReLU(),
78 | nn.Dropout(0.5),
79 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
80 | nn.BatchNorm1d(d_model),
81 | nn.ReLU(),
82 | nn.Dropout(0.5),
83 | Transpose(),
84 | nn.Linear(d_model, d_model),
85 | )
86 |
87 | self.decoder_prenet = nn.Sequential(
88 | nn.Linear(NUM_MEL_BINS, 256),
89 | nn.ReLU(),
90 | nn.Dropout(0.5),
91 | nn.Linear(256, 256),
92 | nn.ReLU(),
93 | nn.Dropout(0.5),
94 | nn.Linear(256, d_model),
95 | )
96 |
97 | assert scaling_xformers is False # TODO: update this block
98 | else:
99 | self.encoder_prenet = nn.Identity()
100 | if scaling_xformers:
101 | self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
102 | else:
103 | self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)
104 |
105 | self.encoder_position = SinePositionalEmbedding(
106 | d_model,
107 | dropout=0.1,
108 | scale=False,
109 | )
110 | self.decoder_position = SinePositionalEmbedding(
111 | d_model, dropout=0.1, scale=False
112 | )
113 |
114 | if scaling_xformers:
115 | self.encoder = TransformerEncoder(
116 | TransformerEncoderLayer(
117 | d_model,
118 | nhead,
119 | dim_feedforward=d_model * 4,
120 | dropout=0.1,
121 | batch_first=True,
122 | norm_first=norm_first,
123 | linear1_self_attention_cls=ScaledLinear,
124 | linear2_self_attention_cls=partial(
125 | ScaledLinear, initial_scale=0.01
126 | ),
127 | linear1_feedforward_cls=ScaledLinear,
128 | linear2_feedforward_cls=partial(
129 | ScaledLinear, initial_scale=0.01
130 | ),
131 | activation=partial(
132 | BalancedDoubleSwish,
133 | channel_dim=-1,
134 | max_abs=10.0,
135 | min_prob=0.25,
136 | ),
137 | layer_norm_cls=IdentityNorm,
138 | ),
139 | num_layers=num_layers,
140 | norm=BalancedBasicNorm(d_model) if norm_first else None,
141 | )
142 |
143 | self.decoder = nn.TransformerDecoder(
144 | TransformerDecoderLayer(
145 | d_model,
146 | nhead,
147 | dim_feedforward=d_model * 4,
148 | dropout=0.1,
149 | batch_first=True,
150 | norm_first=norm_first,
151 | linear1_self_attention_cls=ScaledLinear,
152 | linear2_self_attention_cls=partial(
153 | ScaledLinear, initial_scale=0.01
154 | ),
155 | linear1_feedforward_cls=ScaledLinear,
156 | linear2_feedforward_cls=partial(
157 | ScaledLinear, initial_scale=0.01
158 | ),
159 | activation=partial(
160 | BalancedDoubleSwish,
161 | channel_dim=-1,
162 | max_abs=10.0,
163 | min_prob=0.25,
164 | ),
165 | layer_norm_cls=IdentityNorm,
166 | ),
167 | num_layers=num_layers,
168 | norm=BalancedBasicNorm(d_model) if norm_first else None,
169 | )
170 |
171 | self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
172 | self.stop_layer = nn.Linear(d_model, 1)
173 | else:
174 | self.encoder = nn.TransformerEncoder(
175 | nn.TransformerEncoderLayer(
176 | d_model,
177 | nhead,
178 | dim_feedforward=d_model * 4,
179 | activation=F.relu,
180 | dropout=0.1,
181 | batch_first=True,
182 | norm_first=norm_first,
183 | ),
184 | num_layers=num_layers,
185 | norm=nn.LayerNorm(d_model) if norm_first else None,
186 | )
187 |
188 | self.decoder = nn.TransformerDecoder(
189 | nn.TransformerDecoderLayer(
190 | d_model,
191 | nhead,
192 | dim_feedforward=d_model * 4,
193 | activation=F.relu,
194 | dropout=0.1,
195 | batch_first=True,
196 | norm_first=norm_first,
197 | ),
198 | num_layers=num_layers,
199 | norm=nn.LayerNorm(d_model) if norm_first else None,
200 | )
201 |
202 | self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
203 | self.stop_layer = nn.Linear(d_model, 1)
204 |
205 | self.stop_accuracy_metric = BinaryAccuracy(
206 | threshold=0.5, multidim_average="global"
207 | )
208 |
209 | # self.apply(self._init_weights)
210 |
211 | # def _init_weights(self, module):
212 | # if isinstance(module, (nn.Linear)):
213 | # module.weight.data.normal_(mean=0.0, std=0.02)
214 | # if isinstance(module, nn.Linear) and module.bias is not None:
215 | # module.bias.data.zero_()
216 | # elif isinstance(module, nn.LayerNorm):
217 | # module.bias.data.zero_()
218 | # module.weight.data.fill_(1.0)
219 | # elif isinstance(module, nn.Embedding):
220 | # module.weight.data.normal_(mean=0.0, std=0.02)
221 |
222 | def forward(
223 | self,
224 | x: torch.Tensor,
225 | x_lens: torch.Tensor,
226 | y: torch.Tensor,
227 | y_lens: torch.Tensor,
228 | reduction: str = "sum",
229 | train_stage: int = 0,
230 | **kwargs,
231 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
232 | """
233 | Args:
234 | x:
235 | A 2-D tensor of shape (N, S).
236 | x_lens:
237 | A 1-D tensor of shape (N,). It contains the number of tokens in `x`
238 | before padding.
239 | y:
240 |             A 3-D tensor of shape (N, T, NUM_MEL_BINS) with the target mel-spectrogram frames.
241 | y_lens:
242 |             A 1-D tensor of shape (N,). It contains the number of frames in `y`
243 | before padding.
244 | train_stage:
245 | Not used in this model.
246 | Returns:
247 |           Return the encoder output and mel prediction, the total loss (MSE plus a weighted stop loss), and a dict of stop-token metrics.
248 | """
249 | del train_stage
250 |
251 | assert x.ndim == 2, x.shape
252 | assert x_lens.ndim == 1, x_lens.shape
253 | assert y.ndim == 3, y.shape
254 | assert y_lens.ndim == 1, y_lens.shape
255 |
256 | assert torch.all(x_lens > 0)
257 |
258 | # NOTE: x has been padded in TextTokenCollater
259 | x_mask = make_pad_mask(x_lens).to(x.device)
260 |
261 | x = self.text_embedding(x)
262 | x = self.encoder_prenet(x)
263 | x = self.encoder_position(x)
264 | x = self.encoder(x, src_key_padding_mask=x_mask)
265 |
266 | total_loss, metrics = 0.0, {}
267 |
268 | y_mask = make_pad_mask(y_lens).to(y.device)
269 | y_mask_float = y_mask.type(torch.float32)
270 | data_mask = 1.0 - y_mask_float.unsqueeze(-1)
271 |
272 | # Training
273 | # AR Decoder
274 | def pad_y(y):
275 | y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
276 | # inputs, targets
277 | return y[:, :-1], y[:, 1:]
278 |
279 | y, targets = pad_y(y * data_mask) # mask padding as zeros
280 |
281 | y_emb = self.decoder_prenet(y)
282 | y_pos = self.decoder_position(y_emb)
283 |
284 | y_len = y_lens.max()
285 | tgt_mask = torch.triu(
286 | torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
287 | diagonal=1,
288 | )
289 | y_dec = self.decoder(
290 | y_pos,
291 | x,
292 | tgt_mask=tgt_mask,
293 | memory_key_padding_mask=x_mask,
294 | )
295 |
296 | predict = self.predict_layer(y_dec)
297 | # loss
298 | total_loss = F.mse_loss(predict, targets, reduction=reduction)
299 |
300 | logits = self.stop_layer(y_dec).squeeze(-1)
301 | stop_loss = F.binary_cross_entropy_with_logits(
302 | logits,
303 | y_mask_float.detach(),
304 | weight=1.0 + y_mask_float.detach() * 4.0,
305 | reduction=reduction,
306 | )
307 | metrics["stop_loss"] = stop_loss.detach()
308 |
309 | stop_accuracy = self.stop_accuracy_metric(
310 | (torch.sigmoid(logits) >= 0.5).type(torch.int64),
311 | y_mask.type(torch.int64),
312 | )
313 | # icefall MetricsTracker.norm_items()
314 | metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
315 | torch.float32
316 | )
317 |
318 | return ((x, predict), total_loss + 100.0 * stop_loss, metrics)
319 |
320 | def inference(
321 | self,
322 | x: torch.Tensor,
323 | x_lens: torch.Tensor,
324 | y: Any = None,
325 | **kwargs,
326 | ) -> torch.Tensor:
327 | """
328 | Args:
329 | x:
330 | A 2-D tensor of shape (1, S).
331 | x_lens:
332 | A 1-D tensor of shape (1,). It contains the number of tokens in `x`
333 | before padding.
334 | Returns:
335 |           Return the predicted mel-spectrogram frames.
336 | """
337 | assert x.ndim == 2, x.shape
338 | assert x_lens.ndim == 1, x_lens.shape
339 |
340 | assert torch.all(x_lens > 0)
341 |
342 | x_mask = make_pad_mask(x_lens).to(x.device)
343 |
344 | x = self.text_embedding(x)
345 | x = self.encoder_prenet(x)
346 | x = self.encoder_position(x)
347 | x = self.encoder(x, src_key_padding_mask=x_mask)
348 |
349 | x_mask = make_pad_mask(x_lens).to(x.device)
350 |
351 | # AR Decoder
352 |         # TODO: Manage decoder steps to avoid repetitive computation
353 | y = torch.zeros(
354 | [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
355 | )
356 | while True:
357 | y_emb = self.decoder_prenet(y)
358 | y_pos = self.decoder_position(y_emb)
359 |
360 | tgt_mask = torch.triu(
361 | torch.ones(
362 | y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
363 | ),
364 | diagonal=1,
365 | )
366 |
367 | y_dec = self.decoder(
368 | y_pos,
369 | x,
370 | tgt_mask=tgt_mask,
371 | memory_mask=None,
372 | memory_key_padding_mask=x_mask,
373 | )
374 | predict = self.predict_layer(y_dec[:, -1:])
375 |
376 | logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5
377 | if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
378 | print(
379 | f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
380 | )
381 | break
382 |
383 | y = torch.concat([y, predict], dim=1)
384 |
385 | return y[:, 1:]
386 |
387 | def visualize(
388 | self,
389 | predicts: Tuple[torch.Tensor],
390 | batch: Dict[str, Union[List, torch.Tensor]],
391 | output_dir: str,
392 | limit: int = 4,
393 | ) -> None:
394 | visualize(predicts, batch, output_dir, limit=limit)
395 |
--------------------------------------------------------------------------------
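A shape-oriented sketch of one training call to the debug Transformer above, on random tensors; sizes are kept small so it runs on CPU:

    import torch

    from valle.models.macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
    from valle.models.transformer import Transformer

    model = Transformer(d_model=128, nhead=4, num_layers=2)

    N, S, T = 2, 11, 37
    x = torch.randint(0, NUM_TEXT_TOKENS, (N, S))  # padded text tokens
    x_lens = torch.tensor([S, S - 3])
    y = torch.randn(N, T, NUM_MEL_BINS)            # padded mel targets
    y_lens = torch.tensor([T, T - 5])

    (encoder_out, predict), loss, metrics = model(x, x_lens, y, y_lens)
    print(predict.shape, float(loss), float(metrics["stop_loss"]))
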
/valle/models/visualizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # See ../../../../LICENSE for clarification regarding multiple authors
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | from typing import Dict, List, Tuple, Union
20 |
21 | import matplotlib.pyplot as plt
22 | import numpy as np
23 | import torch
24 |
25 |
26 | def visualize(
27 | predicts: Tuple[torch.Tensor],
28 | batch: Dict[str, Union[List, torch.Tensor]],
29 | output_dir: str,
30 | limit: int = 4,
31 | ) -> None:
32 | text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
33 | text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
34 | audio_features = batch["audio_features"].to("cpu").detach().numpy()
35 | audio_features_lens = (
36 | batch["audio_features_lens"].to("cpu").detach().numpy()
37 | )
38 | assert text_tokens.ndim == 2
39 |
40 | utt_ids, texts = batch["utt_id"], batch["text"]
41 |
42 | encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
43 | decoder_outputs = predicts[1]
44 | if isinstance(decoder_outputs, list):
45 | decoder_outputs = decoder_outputs[-1]
46 | decoder_outputs = (
47 | decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
48 | )
49 |
50 | vmin, vmax = 0, 1024 # Encodec
51 | if decoder_outputs.dtype == np.float32:
52 | vmin, vmax = -6, 0 # Fbank
53 |
54 | num_figures = 3
55 | for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
56 | _ = plt.figure(figsize=(14, 8 * num_figures))
57 |
58 | S = text_tokens_lens[b]
59 | T = audio_features_lens[b]
60 |
61 | # encoder
62 | plt.subplot(num_figures, 1, 1)
63 | plt.title(f"Text: {text}")
64 | plt.imshow(
65 | X=np.transpose(encoder_outputs[b]),
66 | cmap=plt.get_cmap("jet"),
67 | aspect="auto",
68 | interpolation="nearest",
69 | )
70 | plt.gca().invert_yaxis()
71 | plt.axvline(x=S - 0.4, linewidth=2, color="r")
72 | plt.xlabel("Encoder Output")
73 | plt.colorbar()
74 |
75 | # decoder
76 | plt.subplot(num_figures, 1, 2)
77 | plt.imshow(
78 | X=np.transpose(decoder_outputs[b]),
79 | cmap=plt.get_cmap("jet"),
80 | aspect="auto",
81 | interpolation="nearest",
82 | vmin=vmin,
83 | vmax=vmax,
84 | )
85 | plt.gca().invert_yaxis()
86 | plt.axvline(x=T - 0.4, linewidth=2, color="r")
87 | plt.xlabel("Decoder Output")
88 | plt.colorbar()
89 |
90 | # target
91 | plt.subplot(num_figures, 1, 3)
92 | plt.imshow(
93 | X=np.transpose(audio_features[b]),
94 | cmap=plt.get_cmap("jet"),
95 | aspect="auto",
96 | interpolation="nearest",
97 | vmin=vmin,
98 | vmax=vmax,
99 | )
100 | plt.gca().invert_yaxis()
101 | plt.axvline(x=T - 0.4, linewidth=2, color="r")
102 | plt.xlabel("Decoder Target")
103 | plt.colorbar()
104 |
105 | plt.savefig(f"{output_dir}/{utt_id}.png")
106 | plt.close()
107 |
--------------------------------------------------------------------------------
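A self-contained sketch that calls visualize() on dummy tensors; it writes one PNG per utterance into output_dir (a hypothetical directory, created here only for illustration):

    import os

    import torch

    from valle.models.visualizer import visualize

    output_dir = "exp/figures"  # hypothetical
    os.makedirs(output_dir, exist_ok=True)

    B, S, T, D, FEAT = 2, 12, 40, 128, 100
    batch = {
        "utt_id": [f"utt_{i}" for i in range(B)],
        "text": ["hello world"] * B,
        "text_tokens": torch.randint(0, 512, (B, S)),
        "text_tokens_lens": torch.tensor([S, S - 4]),
        "audio_features": torch.randn(B, T, FEAT),
        "audio_features_lens": torch.tensor([T, T - 7]),
    }
    predicts = (torch.randn(B, S, D), torch.randn(B, T, FEAT))
    visualize(predicts, batch, output_dir)
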
/valle/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lifeiteng/vall-e/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/modules/__init__.py
--------------------------------------------------------------------------------
/valle/modules/activation.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import Linear, Module
6 | from torch.nn import functional as F
7 | from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
8 | from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
9 | from torch.nn.parameter import Parameter
10 |
11 |
12 | class MultiheadAttention(Module):
13 | r"""Allows the model to jointly attend to information
14 | from different representation subspaces as described in the paper:
15 |     `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
16 |
17 | Multi-Head Attention is defined as:
18 |
19 | .. math::
20 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
21 |
22 | where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
23 |
24 | ``forward()`` will use a special optimized implementation if all of the following
25 | conditions are met:
26 |
27 | - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
28 | restriction will be loosened in the future.)
29 | - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
30 | - training is disabled (using ``.eval()``)
31 | - dropout is 0
32 | - ``add_bias_kv`` is ``False``
33 | - ``add_zero_attn`` is ``False``
34 | - ``batch_first`` is ``True`` and the input is batched
35 | - ``kdim`` and ``vdim`` are equal to ``embed_dim``
36 | - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
37 |     - if a ``NestedTensor`` is passed, neither ``key_padding_mask``
38 | nor ``attn_mask`` is passed
39 |
40 | If the optimized implementation is in use, a
41 | `NestedTensor `_ can be passed for
42 | ``query``/``key``/``value`` to represent padding more efficiently than using a
43 |     padding mask. In this case, a ``NestedTensor``
44 | will be returned, and an additional speedup proportional to the fraction of the input
45 | that is padding can be expected.
46 |
47 | Args:
48 | embed_dim: Total dimension of the model.
49 | num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
50 | across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
51 | dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
52 | bias: If specified, adds bias to input / output projection layers. Default: ``True``.
53 | add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
54 | add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
55 | Default: ``False``.
56 | kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
57 | vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
58 | batch_first: If ``True``, then the input and output tensors are provided
59 | as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
60 |
61 | Examples::
62 |
63 | >>> # xdoctest: +SKIP
64 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
65 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
66 |
67 | """
68 | __constants__ = ["batch_first"]
69 | bias_k: Optional[torch.Tensor]
70 | bias_v: Optional[torch.Tensor]
71 |
72 | def __init__(
73 | self,
74 | embed_dim,
75 | num_heads,
76 | dropout=0.0,
77 | bias=True,
78 | add_bias_kv=False,
79 | add_zero_attn=False,
80 | kdim=None,
81 | vdim=None,
82 | batch_first=False,
83 | linear1_cls=Linear,
84 | linear2_cls=Linear,
85 | device=None,
86 | dtype=None,
87 | ) -> None:
88 | factory_kwargs = {"device": device, "dtype": dtype}
89 | super(MultiheadAttention, self).__init__()
90 | self.embed_dim = embed_dim
91 | self.kdim = kdim if kdim is not None else embed_dim
92 | self.vdim = vdim if vdim is not None else embed_dim
93 | self._qkv_same_embed_dim = (
94 | self.kdim == embed_dim and self.vdim == embed_dim
95 | )
96 |
97 | self.num_heads = num_heads
98 | self.dropout = dropout
99 | self.batch_first = batch_first
100 | self.head_dim = embed_dim // num_heads
101 | assert (
102 | self.head_dim * num_heads == self.embed_dim
103 | ), "embed_dim must be divisible by num_heads"
104 |
105 | if add_bias_kv:
106 | self.bias_k = Parameter(
107 | torch.empty((1, 1, embed_dim), **factory_kwargs)
108 | )
109 | self.bias_v = Parameter(
110 | torch.empty((1, 1, embed_dim), **factory_kwargs)
111 | )
112 | else:
113 | self.bias_k = self.bias_v = None
114 |
115 | if linear1_cls == Linear:
116 | if not self._qkv_same_embed_dim:
117 | self.q_proj_weight = Parameter(
118 | torch.empty((embed_dim, embed_dim), **factory_kwargs)
119 | )
120 | self.k_proj_weight = Parameter(
121 | torch.empty((embed_dim, self.kdim), **factory_kwargs)
122 | )
123 | self.v_proj_weight = Parameter(
124 | torch.empty((embed_dim, self.vdim), **factory_kwargs)
125 | )
126 | self.register_parameter("in_proj_weight", None)
127 | else:
128 | self.in_proj_weight = Parameter(
129 | torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
130 | )
131 | self.register_parameter("q_proj_weight", None)
132 | self.register_parameter("k_proj_weight", None)
133 | self.register_parameter("v_proj_weight", None)
134 |
135 | if bias:
136 | self.in_proj_bias = Parameter(
137 | torch.empty(3 * embed_dim, **factory_kwargs)
138 | )
139 | else:
140 | self.register_parameter("in_proj_bias", None)
141 | self.out_proj = NonDynamicallyQuantizableLinear(
142 | embed_dim, embed_dim, bias=bias, **factory_kwargs
143 | )
144 |
145 | self._reset_parameters()
146 | else:
147 | if not self._qkv_same_embed_dim:
148 | raise NotImplementedError
149 | else:
150 | self.in_proj_linear = linear1_cls(
151 | embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
152 | )
153 | self.in_proj_weight = self.in_proj_linear.weight
154 |
155 | self.register_parameter("q_proj_weight", None)
156 | self.register_parameter("k_proj_weight", None)
157 | self.register_parameter("v_proj_weight", None)
158 |
159 | if bias:
160 | self.in_proj_bias = self.in_proj_linear.bias
161 | else:
162 | self.register_parameter("in_proj_bias", None)
163 |
164 | self.out_proj = linear2_cls(
165 | embed_dim, embed_dim, bias=bias, **factory_kwargs
166 | )
167 |
168 | if self.bias_k is not None:
169 | xavier_normal_(self.bias_k)
170 | if self.bias_v is not None:
171 | xavier_normal_(self.bias_v)
172 |
173 | self.add_zero_attn = add_zero_attn
174 |
175 | def _reset_parameters(self):
176 | if self._qkv_same_embed_dim:
177 | xavier_uniform_(self.in_proj_weight)
178 | else:
179 | xavier_uniform_(self.q_proj_weight)
180 | xavier_uniform_(self.k_proj_weight)
181 | xavier_uniform_(self.v_proj_weight)
182 |
183 | if self.in_proj_bias is not None:
184 | constant_(self.in_proj_bias, 0.0)
185 | constant_(self.out_proj.bias, 0.0)
186 |
187 | if self.bias_k is not None:
188 | xavier_normal_(self.bias_k)
189 | if self.bias_v is not None:
190 | xavier_normal_(self.bias_v)
191 |
192 | def __setstate__(self, state):
193 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0
194 | if "_qkv_same_embed_dim" not in state:
195 | state["_qkv_same_embed_dim"] = True
196 |
197 | super(MultiheadAttention, self).__setstate__(state)
198 |
199 | def forward(
200 | self,
201 | query: Tensor,
202 | key: Tensor,
203 | value: Tensor,
204 | key_padding_mask: Optional[Tensor] = None,
205 | need_weights: bool = True,
206 | attn_mask: Optional[Tensor] = None,
207 | average_attn_weights: bool = True,
208 | ) -> Tuple[Tensor, Optional[Tensor]]:
209 | r"""
210 | Args:
211 | query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
212 | or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
213 | :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
214 | Queries are compared against key-value pairs to produce the output.
215 | See "Attention Is All You Need" for more details.
216 | key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
217 | or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
218 | :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
219 | See "Attention Is All You Need" for more details.
220 | value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
221 | ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
222 | sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
223 | See "Attention Is All You Need" for more details.
224 | key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
225 | to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
226 | Binary and byte masks are supported.
227 | For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
228 | the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
229 | need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
230 | Default: ``True``.
231 | attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
232 | :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
233 | :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
234 | broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
235 | Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
236 | corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
237 | corresponding position is not allowed to attend. For a float mask, the mask values will be added to
238 | the attention weight.
239 | average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
240 | heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
241 | effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
242 |
243 | Outputs:
244 | - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
245 | :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
246 | where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
247 | embedding dimension ``embed_dim``.
248 | - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
249 | returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
250 | :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
251 | :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
252 | head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
253 |
254 | .. note::
255 | `batch_first` argument is ignored for unbatched inputs.
256 | """
257 | is_batched = query.dim() == 3
258 | if key_padding_mask is not None:
259 | _kpm_dtype = key_padding_mask.dtype
260 | if _kpm_dtype != torch.bool and not torch.is_floating_point(
261 | key_padding_mask
262 | ):
263 | raise AssertionError(
264 | "only bool and floating types of key_padding_mask are supported"
265 | )
266 | why_not_fast_path = ""
267 | if not is_batched:
268 | why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
269 | elif query is not key or key is not value:
270 | # When lifting this restriction, don't forget to either
271 | # enforce that the dtypes all match or test cases where
272 | # they don't!
273 | why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
274 | elif (
275 | self.in_proj_bias is not None
276 | and query.dtype != self.in_proj_bias.dtype
277 | ):
278 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
279 | elif (
280 | self.in_proj_weight is not None
281 | and query.dtype != self.in_proj_weight.dtype
282 | ):
283 | # this case will fail anyway, but at least they'll get a useful error message.
284 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
285 | elif self.training:
286 | why_not_fast_path = "training is enabled"
287 | elif not self.batch_first:
288 | why_not_fast_path = "batch_first was not True"
289 | elif self.bias_k is not None:
290 | why_not_fast_path = "self.bias_k was not None"
291 | elif self.bias_v is not None:
292 | why_not_fast_path = "self.bias_v was not None"
293 | elif self.dropout:
294 | why_not_fast_path = f"dropout was {self.dropout}, required zero"
295 | elif self.add_zero_attn:
296 | why_not_fast_path = "add_zero_attn was enabled"
297 | elif not self._qkv_same_embed_dim:
298 | why_not_fast_path = "_qkv_same_embed_dim was not True"
299 | elif attn_mask is not None:
300 | why_not_fast_path = "attn_mask was not None"
301 | elif query.is_nested and key_padding_mask is not None:
302 | why_not_fast_path = (
303 | "key_padding_mask is not supported with NestedTensor input"
304 | )
305 | elif self.num_heads % 2 == 1:
306 | why_not_fast_path = "num_heads is odd"
307 | elif torch.is_autocast_enabled():
308 | why_not_fast_path = "autocast is enabled"
309 |
310 | if not why_not_fast_path:
311 | tensor_args = (
312 | query,
313 | key,
314 | value,
315 | self.in_proj_weight,
316 | self.in_proj_bias,
317 | self.out_proj.weight,
318 | self.out_proj.bias,
319 | )
320 | # We have to use list comprehensions below because TorchScript does not support
321 | # generator expressions.
322 | if torch.overrides.has_torch_function(tensor_args):
323 | why_not_fast_path = "some Tensor argument has_torch_function"
324 | elif not all(
325 | [
326 | (x is None or x.is_cuda or "cpu" in str(x.device))
327 | for x in tensor_args
328 | ]
329 | ):
330 | why_not_fast_path = (
331 | "some Tensor argument is neither CUDA nor CPU"
332 | )
333 | elif torch.is_grad_enabled() and any(
334 | [x is not None and x.requires_grad for x in tensor_args]
335 | ):
336 | why_not_fast_path = (
337 | "grad is enabled and at least one of query or the "
338 | "input/output projection weights or biases requires_grad"
339 | )
340 | if not why_not_fast_path:
341 | return torch._native_multi_head_attention(
342 | query,
343 | key,
344 | value,
345 | self.embed_dim,
346 | self.num_heads,
347 | self.in_proj_weight,
348 | self.in_proj_bias,
349 | self.out_proj.weight,
350 | self.out_proj.bias,
351 | key_padding_mask
352 | if key_padding_mask is not None
353 | else attn_mask,
354 | need_weights,
355 | average_attn_weights,
356 | 1
357 | if key_padding_mask is not None
358 | else 0
359 | if attn_mask is not None
360 | else None,
361 | )
362 |
363 | any_nested = query.is_nested or key.is_nested or value.is_nested
364 | assert not any_nested, (
365 | "MultiheadAttention does not support NestedTensor outside of its fast path. "
366 | + f"The fast path was not hit because {why_not_fast_path}"
367 | )
368 |
369 | if self.batch_first and is_batched:
370 | # make sure that the transpose op does not affect the "is" property
371 | if key is value:
372 | if query is key:
373 | query = key = value = query.transpose(1, 0)
374 | else:
375 | query, key = [x.transpose(1, 0) for x in (query, key)]
376 | value = key
377 | else:
378 | query, key, value = [
379 | x.transpose(1, 0) for x in (query, key, value)
380 | ]
381 |
382 | if not self._qkv_same_embed_dim:
383 | attn_output, attn_output_weights = F.multi_head_attention_forward(
384 | query,
385 | key,
386 | value,
387 | self.embed_dim,
388 | self.num_heads,
389 | self.in_proj_weight,
390 | self.in_proj_bias,
391 | self.bias_k,
392 | self.bias_v,
393 | self.add_zero_attn,
394 | self.dropout,
395 | self.out_proj.weight,
396 | self.out_proj.bias,
397 | training=self.training,
398 | key_padding_mask=key_padding_mask,
399 | need_weights=need_weights,
400 | attn_mask=attn_mask,
401 | use_separate_proj_weight=True,
402 | q_proj_weight=self.q_proj_weight,
403 | k_proj_weight=self.k_proj_weight,
404 | v_proj_weight=self.v_proj_weight,
405 | average_attn_weights=average_attn_weights,
406 | )
407 | else:
408 | attn_output, attn_output_weights = F.multi_head_attention_forward(
409 | query,
410 | key,
411 | value,
412 | self.embed_dim,
413 | self.num_heads,
414 | self.in_proj_weight,
415 | self.in_proj_bias,
416 | self.bias_k,
417 | self.bias_v,
418 | self.add_zero_attn,
419 | self.dropout,
420 | self.out_proj.weight,
421 | self.out_proj.bias,
422 | training=self.training,
423 | key_padding_mask=key_padding_mask,
424 | need_weights=need_weights,
425 | attn_mask=attn_mask,
426 | average_attn_weights=average_attn_weights,
427 | )
428 | if self.batch_first and is_batched:
429 | return attn_output.transpose(1, 0), attn_output_weights
430 | else:
431 | return attn_output, attn_output_weights
432 |
--------------------------------------------------------------------------------
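Echoing the docstring example above, a small self-attention sketch with this drop-in MultiheadAttention (batch-first inputs, attention weights averaged over heads by default):

    import torch

    from valle.modules.activation import MultiheadAttention

    attn = MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
    x = torch.randn(2, 10, 256)  # (batch, seq, feature)

    attn_output, attn_weights = attn(x, x, x, need_weights=True)
    print(attn_output.shape)   # torch.Size([2, 10, 256])
    print(attn_weights.shape)  # torch.Size([2, 10, 10]), averaged over heads
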
/valle/modules/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | class TokenEmbedding(nn.Module):
22 | def __init__(
23 | self,
24 | dim_model: int,
25 | vocab_size: int,
26 | dropout: float = 0.0,
27 | ):
28 | super().__init__()
29 |
30 | self.vocab_size = vocab_size
31 | self.dim_model = dim_model
32 |
33 | self.dropout = torch.nn.Dropout(p=dropout)
34 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
35 |
36 | @property
37 | def weight(self) -> torch.Tensor:
38 | return self.word_embeddings.weight
39 |
40 | def embedding(self, index: int) -> torch.Tensor:
41 | return self.word_embeddings.weight[index : index + 1]
42 |
43 | def forward(self, x: torch.Tensor):
44 | X = self.word_embeddings(x)
45 | X = self.dropout(X)
46 |
47 | return X
48 |
49 |
50 | class SinePositionalEmbedding(nn.Module):
51 | def __init__(
52 | self,
53 | dim_model: int,
54 | dropout: float = 0.0,
55 | scale: bool = False,
56 | alpha: bool = False,
57 | ):
58 | super().__init__()
59 | self.dim_model = dim_model
60 | self.x_scale = math.sqrt(dim_model) if scale else 1.0
61 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
62 | self.dropout = torch.nn.Dropout(p=dropout)
63 |
64 | self.reverse = False
65 | self.pe = None
66 | self.extend_pe(torch.tensor(0.0).expand(1, 4000))
67 |
68 | def extend_pe(self, x):
69 | """Reset the positional encodings."""
70 | if self.pe is not None:
71 | if self.pe.size(1) >= x.size(1):
72 | if self.pe.dtype != x.dtype or self.pe.device != x.device:
73 | self.pe = self.pe.to(dtype=x.dtype, device=x.device)
74 | return
75 | pe = torch.zeros(x.size(1), self.dim_model)
76 | if self.reverse:
77 | position = torch.arange(
78 | x.size(1) - 1, -1, -1.0, dtype=torch.float32
79 | ).unsqueeze(1)
80 | else:
81 | position = torch.arange(
82 | 0, x.size(1), dtype=torch.float32
83 | ).unsqueeze(1)
84 | div_term = torch.exp(
85 | torch.arange(0, self.dim_model, 2, dtype=torch.float32)
86 | * -(math.log(10000.0) / self.dim_model)
87 | )
88 | pe[:, 0::2] = torch.sin(position * div_term)
89 | pe[:, 1::2] = torch.cos(position * div_term)
90 | pe = pe.unsqueeze(0)
91 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
92 |
93 | def forward(self, x: torch.Tensor) -> torch.Tensor:
94 | self.extend_pe(x)
95 | output = x.unsqueeze(-1) if x.ndim == 2 else x
96 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
97 | return self.dropout(output)
98 |
--------------------------------------------------------------------------------
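A short sketch of how the two embeddings above compose in practice (shapes follow
the forward() signatures in embedding.py; the sizes are illustrative only):

    import torch

    from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding

    tokens = torch.randint(0, 100, (2, 16))              # (batch, time) token ids
    emb = TokenEmbedding(dim_model=64, vocab_size=100)
    pos = SinePositionalEmbedding(dim_model=64, scale=True, alpha=True)

    x = emb(tokens)   # (2, 16, 64)
    x = pos(x)        # x * sqrt(64) + alpha * sinusoidal positions, still (2, 16, 64)
    print(x.shape)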
/valle/modules/scheduler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # See ../../../../LICENSE for clarification regarding multiple authors
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | import torch
20 |
21 | from valle.modules.optim import Eden
22 |
23 |
24 | def calc_lr(step, dim_embed, warmup_steps):
25 | return dim_embed ** (-0.5) * min(
26 | step ** (-0.5), step * warmup_steps ** (-1.5)
27 | )
28 |
29 |
30 | class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
31 | def __init__(
32 | self,
33 | base_lr: float,
34 | optimizer: torch.optim.Optimizer,
35 | dim_embed: int,
36 | warmup_steps: int,
37 | last_epoch: int = -1,
38 | verbose: bool = False,
39 | ) -> None:
40 |
41 | self.dim_embed = dim_embed
42 | self.base_lr = base_lr
43 | self.warmup_steps = warmup_steps
44 | self.num_param_groups = len(optimizer.param_groups)
45 |
46 | super().__init__(optimizer, last_epoch, verbose)
47 |
48 |     def get_lr(self) -> list:
49 | lr = self.base_lr * calc_lr(
50 | self._step_count, self.dim_embed, self.warmup_steps
51 | )
52 | return [lr] * self.num_param_groups
53 |
54 | def set_step(self, step: int):
55 | self._step_count = step
56 |
57 |
58 | def get_scheduler(params, optimizer):
59 | if params.scheduler_name.lower() == "eden":
60 | scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
61 | elif params.scheduler_name.lower() == "noam":
62 | scheduler = NoamScheduler(
63 | params.base_lr,
64 | optimizer,
65 | params.decoder_dim,
66 | warmup_steps=params.warmup_steps,
67 | )
68 | # scheduler.set_step(params.start_batch or params.batch_idx_train)
69 | elif params.scheduler_name.lower() == "cosine":
70 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
71 |             optimizer,
72 |             params.warmup_steps,
73 | eta_min=params.base_lr,
74 | )
75 | else:
76 | raise NotImplementedError(f"{params.scheduler_name}")
77 |
78 | return scheduler
79 |
--------------------------------------------------------------------------------
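A minimal sketch of the NoamScheduler defined above: calc_lr() warms the learning
rate up linearly for warmup_steps steps and then decays it as step ** -0.5, both
scaled by dim_embed ** -0.5 and base_lr (the numbers below are illustrative):

    import torch

    from valle.modules.scheduler import NoamScheduler

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
    scheduler = NoamScheduler(
        base_lr=0.05, optimizer=optimizer, dim_embed=512, warmup_steps=200
    )

    for _ in range(400):
        optimizer.step()
        scheduler.step()

    # the peak LR is reached around step == warmup_steps, then decays
    print(optimizer.param_groups[0]["lr"])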
/valle/tests/data/tokenizer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Zhao Ming)
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import unittest
18 |
19 | from valle.data import TextTokenizer
20 |
21 |
22 | class TestTextTokenizer(unittest.TestCase):
23 | def test_espeak(self):
24 | text_tokenizer = TextTokenizer(backend="espeak")
25 |
26 | for (_input, _target) in [
27 | ("The two parties, the sheep and the wolves, met each other.",
28 | ['ð', 'ə', '_', 't', 'uː', '_', 'p', 'ɑːɹ', 'ɾ',]), # 'i', 'z', ',', '_', 'ð']
29 | ("Mother! dear father! do you hear me?",
30 | ['m', 'ʌ', 'ð', 'ɚ', '!', '_', 'd', 'ɪɹ', '_', 'f', 'ɑː', 'ð', 'ɚ', '!']),
31 | ("\"Whoever thou art,\" She exclaimed, suddenly seizing Rodolfo's hand,",
32 | ['"', 'h', 'uː', 'ɛ', 'v', 'ɚ', '_', 'ð', 'aʊ', '_', 'ɑːɹ', 't', ',', '"', '_', 'ʃ', 'iː',
33 | '_', 'ɛ', 'k', 's', 'k', 'l', 'eɪ', 'm', 'd', ',', '_', 's', 'ʌ', 'd', 'ə', 'n', 'l', 'i',
34 | '_', 's', 'iː', 'z', 'ɪ', 'ŋ', '_', 'ɹ', 'ə', 'd', 'ɑː', 'l', 'f', 'oʊ', 'z', '_', 'h',
35 | 'æ', 'n', 'd', ','])
36 | ]:
37 | phonemized = text_tokenizer(_input)
38 | self.assertEqual(phonemized[0][:len(_target)], _target)
39 |
40 | def test_pypinyin(self):
41 | text_tokenizer = TextTokenizer(backend="pypinyin")
42 |
43 | for (_input, _target) in [
44 | ("你好这是测试",
45 | ["ni3", '-', "hao3", '-', "zhe4", '-', "shi4", '-', "ce4", '-', "shi4"]),
46 | ("\"你好\", 这是测试.",
47 | ["\"", "ni3", '-', "hao3", "\"", ",", '_', "zhe4", '-', "shi4", '-', "ce4", '-', "shi4", "."]),
48 | ("此项 工作 还能 怎么 改进",
49 | ['ci3', '-', 'xiang4', '_', 'gong1', '-', 'zuo4', '_',
50 | 'hai2', '-', 'neng2', '_', 'zen3', '-', 'me5', '_', 'gai3', '-', 'jin4']), # AISHELL
51 | ]:
52 | phonemized = text_tokenizer(_input)
53 | self.assertEqual(phonemized[0], _target)
54 |
55 | def test_pypinyin_initials_finals(self):
56 | text_tokenizer = TextTokenizer(backend="pypinyin_initials_finals")
57 |
58 | for (_input, _target) in [
59 | ("你好这是测试",
60 | ["n", "i3", "-", "h", "ao3", "-", "zh", "e4", "-", "sh", "i4", "-", "c", "e4", "-", "sh", "i4"],
61 | ),
62 | ("\"你好.这是测试.",
63 | ["\"", "n", "i3", "-", "h", "ao3", ".", "zh", "e4", "-", "sh", "i4", "-", "c", "e4", "-", "sh", "i4", "."],
64 | ),
65 | ("\"你好. 这是测试.",
66 | ["\"", "n", "i3", "-", "h", "ao3", ".", "_", "zh", "e4", "-", "sh", "i4", "-", "c", "e4", "-", "sh", "i4", "."],
67 | ),
68 | ("此项 工作 还能 怎么 改进", ['c', 'i3', '-', 'x', 'iang4', '_', 'g', 'ong1', '-', 'z', 'uo4', '_',
69 | 'h', 'ai2', '-', 'n', 'eng2', '_', 'z', 'en3', '-', 'm', 'e5', '_',
70 | 'g', 'ai3', '-', 'j', 'in4']), # AISHELL
71 | ]:
72 | phonemized = text_tokenizer(_input)
73 | self.assertListEqual(phonemized[0], _target)
74 |
75 |
76 | if __name__ == "__main__":
77 | unittest.main()
78 |
--------------------------------------------------------------------------------
/valle/tests/scaling_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 |
4 | import unittest
5 |
6 | import numpy as np
7 | import torch
8 | from icefall.utils import AttributeDict
9 |
10 | from valle.models import NUM_MEL_BINS, get_model
11 |
12 |
13 | class TestModel(unittest.TestCase):
14 | @classmethod
15 | def setUpClass(cls):
16 | cls.devices = [torch.device("cpu")]
17 | if torch.cuda.is_available():
18 | cls.devices.append(torch.device("cuda", 0))
19 | if torch.cuda.device_count() > 1:
20 | torch.cuda.set_device(1)
21 | cls.devices.append(torch.device("cuda", 1))
22 |
23 | def test_scaling_transformer(self):
24 | params = AttributeDict()
25 | params.decoder_dim = 64
26 | params.nhead = 4
27 | params.num_decoder_layers = 4
28 |
29 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8]))
30 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4]))
31 | x_lens[-1] = 8
32 |
33 | y = torch.from_numpy(
34 | np.random.random((4, 16, NUM_MEL_BINS)).astype(np.float32)
35 | )
36 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4]))
37 | y_lens[-1] = 16
38 |
39 | params.model_name = "Transformer"
40 | params.norm_first = False
41 | params.add_prenet = False
42 | params.scaling_xformers = True
43 |
44 | for device in self.devices:
45 | # Transformer
46 | model = get_model(params)
47 | num_param = sum([p.numel() for p in model.parameters()])
48 |
49 | model.to(device)
50 | x = x.to(device)
51 | x_lens = x_lens.to(device)
52 | y = y.to(device)
53 | y_lens = y_lens.to(device)
54 |
55 | # Training
56 | codes, loss, metrics = model(x, x_lens, y, y_lens)
57 | # Inference
58 | model.eval()
59 | codes = model.inference(x[-1:], x_lens[-1:])
60 | params.add_prenet = False
61 |
62 |
63 | if __name__ == "__main__":
64 | unittest.main()
65 |
--------------------------------------------------------------------------------
/valle/tests/valle_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import unittest
17 |
18 | import numpy as np
19 | import torch
20 | from icefall.utils import AttributeDict
21 | from torchmetrics.classification import MulticlassAccuracy
22 |
23 | from valle.data.input_strategies import PromptedFeatures
24 | from valle.models import NUM_MEL_BINS, get_model
25 |
26 |
27 | class TestModel(unittest.TestCase):
28 | @classmethod
29 | def setUpClass(cls):
30 | cls.devices = [torch.device("cpu")]
31 | if torch.cuda.is_available():
32 | cls.devices.append(torch.device("cuda", 0))
33 | if torch.cuda.device_count() > 1:
34 | torch.cuda.set_device(1)
35 | cls.devices.append(torch.device("cuda", 1))
36 |
37 | def test_vallf(self):
38 | params = AttributeDict()
39 | params.decoder_dim = 64
40 | params.nhead = 16
41 | params.num_decoder_layers = 4
42 |
43 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8]))
44 | x_lens = torch.from_numpy(np.random.randint(6, 8, size=[4]))
45 | x_lens[-1] = 8
46 | enroll_x_lens = torch.from_numpy(np.random.randint(2, 4, size=[4]))
47 |
48 | y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8]))
49 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4]))
50 | y_lens[-1] = 16
51 |
52 | params.norm_first = True
53 | params.add_prenet = False
54 | params.model_name = "VALL-F"
55 | params.share_embedding = True
56 | params.scale_factor = 1.0
57 | params.prepend_bos = True
58 | params.num_quantizers = 1
59 |
60 | for device in self.devices:
61 | for mode in [0, 1, 2]:
62 | params.prefix_mode = mode
63 |                 # VALL-F
64 |                 model = get_model(params)
65 |
66 |                 # move model and inputs to the target device
67 | model.to(device)
68 | x = x.to(device)
69 | x_lens = x_lens.to(device)
70 | y = y.to(device)
71 | y_lens = y_lens.to(device)
72 |
73 | # Training
74 | for train_stage in [0, 1, 2]:
75 | codes, loss, metrics = model(
76 | x, x_lens, y, y_lens, train_stage=train_stage
77 | )
78 |
79 | # Inference
80 | model.eval()
81 | codes = model.inference(
82 | x[-1:],
83 | x_lens[-1:],
84 | y[-1:],
85 | enroll_x_lens=enroll_x_lens[-1:],
86 | )
87 |
88 | params.prepend_bos = not params.prepend_bos
89 | params.num_quantizers += 1
90 |
91 | def test_valle(self):
92 | params = AttributeDict()
93 | params.decoder_dim = 64
94 | params.nhead = 16
95 | params.num_decoder_layers = 4
96 |
97 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8]))
98 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4]))
99 | x_lens[-1] = 8
100 | enroll_x_lens = torch.from_numpy(np.random.randint(1, 3, size=[4]))
101 |
102 | y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8]))
103 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4]))
104 | y_lens[-1] = 16
105 |
106 | params.norm_first = False
107 | params.add_prenet = True
108 | params.model_name = "VALL-E"
109 | params.share_embedding = True
110 | params.scale_factor = 1.0
111 | params.prepend_bos = False
112 | params.num_quantizers = 8
113 |
114 | for device in self.devices:
115 | for mode in [0, 1, 2]:
116 | params.prefix_mode = mode
117 | # VALL-E
118 | model = get_model(params)
119 | model.to(device)
120 | x = x.to(device)
121 | x_lens = x_lens.to(device)
122 | y = y.to(device)
123 | y_lens = y_lens.to(device)
124 |
125 | # Training
126 | codes, loss, metrics = model(x, x_lens, y, y_lens)
127 | # Inference
128 | model.eval()
129 | codes = model.inference(
130 | x[-1:], x_lens[-1:], y[-1:], enroll_x_lens=enroll_x_lens
131 | )
132 | params.scale_factor = 0.5
133 |
134 | params.prepend_bos = not params.prepend_bos
135 | params.num_quantizers -= 1
136 |
137 | def test_vallef_prefix4(self):
138 | params = AttributeDict()
139 | params.decoder_dim = 64
140 | params.nhead = 16
141 | params.num_decoder_layers = 4
142 |
143 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8]))
144 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4]))
145 | x_lens[-1] = 8
146 | enroll_x_lens = torch.from_numpy(np.random.randint(1, 3, size=[4]))
147 |
148 | y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8]))
149 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4]))
150 | y_lens[-1] = 16
151 |
152 | prompts = torch.from_numpy(np.random.randint(0, 1000, size=[4, 12, 8]))
153 | prompts_lens = torch.from_numpy(np.random.randint(12, 13, size=[4]))
154 |
155 | params.norm_first = False
156 | params.add_prenet = True
157 | params.share_embedding = False
158 | params.scale_factor = 1.0
159 | params.prepend_bos = False
160 | params.num_quantizers = 8
161 |
162 | for device in self.devices:
163 | for model_name in ["VALL-E", "VALL-F"]:
164 | for mode in [4]:
165 | params.prefix_mode = mode
166 | params.model_name = model_name
167 |                     # VALL-E or VALL-F, depending on model_name
168 | model = get_model(params)
169 | model.to(device)
170 | x = x.to(device)
171 | x_lens = x_lens.to(device)
172 | y = y.to(device)
173 |
174 | _y = PromptedFeatures(prompts, y).to(device)
175 | _y_lens = PromptedFeatures(prompts_lens, y_lens).to(device)
176 |
177 | # Training
178 | codes, loss, metrics = model(x, x_lens, _y, _y_lens)
179 | # Inference
180 | model.eval()
181 | codes = model.inference(
182 | x[-1:], x_lens[-1:], y[-1:], enroll_x_lens=enroll_x_lens
183 | )
184 |
185 | def test_topmetric(self):
186 | metric_top10 = MulticlassAccuracy(1024, top_k=10, average="micro")
187 | metric_top1 = MulticlassAccuracy(1024, top_k=1, average="micro")
188 | batch_size, seq_len = 4, 16
189 | targets = np.random.randint(0, 1000, size=[batch_size, seq_len])
190 | logits = np.random.random([batch_size, 1024, seq_len]).astype(
191 | np.float32
192 | )
193 |
194 | larger_logits = np.clip(logits, -1.0, 1.0)
195 | smaller_logits = np.clip(logits, -1.0, 1.0)
196 | for b in range(batch_size):
197 | for t in range(seq_len):
198 | assert targets[b, t] >= 0
199 | larger_logits[b, targets[b, t], t] = 2.0
200 | smaller_logits[b, targets[b, t], t] = -2.0
201 |
202 | targets = torch.from_numpy(targets)
203 | larger_logits = torch.from_numpy(larger_logits)
204 | smaller_logits = torch.from_numpy(smaller_logits)
205 |
206 | for device in self.devices:
207 | metric_top10.to(device)
208 | metric_top1.to(device)
209 | targets = targets.to(device)
210 |
211 | one = metric_top10(larger_logits.to(device), targets)
212 | assert one.cpu().item() == 1.0, one.cpu().item()
213 |
214 | zero = metric_top1(smaller_logits.to(device), targets)
215 | assert zero.cpu().item() == 0.0, zero.cpu().item()
216 |
217 | half = metric_top1(
218 | torch.concat(
219 | [smaller_logits.to(device), larger_logits.to(device)], dim=2
220 | ),
221 | torch.concat([targets, targets], dim=1),
222 | )
223 | assert half.cpu().item() == 0.5, half.cpu().item()
224 |
225 | def test_transformer(self):
226 | params = AttributeDict()
227 | params.decoder_dim = 64
228 | params.nhead = 4
229 | params.num_decoder_layers = 4
230 |
231 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8]))
232 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4]))
233 | x_lens[-1] = 8
234 |
235 | y = torch.from_numpy(
236 | np.random.random((4, 16, NUM_MEL_BINS)).astype(np.float32)
237 | )
238 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4]))
239 | y_lens[-1] = 16
240 |
241 | params.model_name = "Transformer"
242 | params.norm_first = False
243 | params.add_prenet = True
244 | params.scaling_xformers = False
245 |
246 | for device in self.devices:
247 | # Transformer
248 | model = get_model(params)
249 | num_param = sum([p.numel() for p in model.parameters()])
250 |
251 | model.to(device)
252 | x = x.to(device)
253 | x_lens = x_lens.to(device)
254 | y = y.to(device)
255 | y_lens = y_lens.to(device)
256 |
257 | # Training
258 | codes, loss, metrics = model(x, x_lens, y, y_lens)
259 | # Inference
260 | model.eval()
261 | codes = model.inference(x[-1:], x_lens[-1:])
262 | params.add_prenet = False
263 |
264 | params.scaling_xformers = not params.scaling_xformers
265 |
266 |
267 | if __name__ == "__main__":
268 | unittest.main()
269 |
--------------------------------------------------------------------------------
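The tests above double as usage documentation; below is a condensed sketch of the
same construction pattern (hyperparameter values mirror test_valle and are
illustrative only):

    import numpy as np
    import torch
    from icefall.utils import AttributeDict

    from valle.models import get_model

    params = AttributeDict()
    params.model_name = "VALL-E"
    params.decoder_dim = 64
    params.nhead = 16
    params.num_decoder_layers = 4
    params.norm_first = False
    params.add_prenet = True
    params.share_embedding = True
    params.scale_factor = 1.0
    params.prepend_bos = False
    params.num_quantizers = 8
    params.prefix_mode = 0

    model = get_model(params)

    x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8]))       # phoneme ids
    x_lens = torch.tensor([8, 8, 8, 8])
    y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8]))  # audio codes
    y_lens = torch.tensor([16, 16, 16, 16])

    codes, loss, metrics = model(x, x_lens, y, y_lens)  # one training forward pass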
/valle/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from icefall.utils import make_pad_mask
4 |
5 | from .symbol_table import SymbolTable
6 |
7 | make_pad_mask = make_pad_mask  # re-export from icefall for convenience
8 | SymbolTable = SymbolTable  # re-export for convenience
9 |
10 |
11 | class Transpose(nn.Identity):
12 | """(N, T, D) -> (N, D, T)"""
13 |
14 | def forward(self, input: torch.Tensor) -> torch.Tensor:
15 | return input.transpose(1, 2)
16 |
--------------------------------------------------------------------------------
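A tiny sketch of the Transpose helper above: it swaps the time and feature axes so
(N, T, D) activations can feed Conv1d-style modules that expect (N, D, T):

    import torch

    from valle.utils import Transpose

    x = torch.randn(4, 100, 80)   # (N, T, D)
    y = Transpose()(x)            # (N, D, T)
    assert y.shape == (4, 80, 100)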
/valle/utils/symbol_table.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
2 | #
3 | # See ../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from dataclasses import dataclass
18 | from dataclasses import field
19 | from typing import Dict
20 | from typing import Generic
21 | from typing import List
22 | from typing import Optional
23 | from typing import TypeVar
24 | from typing import Union
25 |
26 | Symbol = TypeVar('Symbol')
27 |
28 |
29 | # Disable __repr__ otherwise it could freeze e.g. Jupyter.
30 | @dataclass(repr=False)
31 | class SymbolTable(Generic[Symbol]):
32 |     '''SymbolTable that maps symbol IDs found on the FSA arcs to
33 | actual objects. These objects can be arbitrary Python objects
34 | that can serve as keys in a dictionary (i.e. they need to be
35 | hashable and immutable).
36 |
37 |     The SymbolTable can only be written to/read from disk if the
38 | symbols are strings.
39 | '''
40 | _id2sym: Dict[int, Symbol] = field(default_factory=dict)
41 | '''Map an integer to a symbol.
42 | '''
43 |
44 | _sym2id: Dict[Symbol, int] = field(default_factory=dict)
45 | '''Map a symbol to an integer.
46 | '''
47 |
48 | _next_available_id: int = 1
49 | '''A helper internal field that helps adding new symbols
50 | to the table efficiently.
51 | '''
52 |
53 | eps: Symbol = ''
54 | '''Null symbol, always mapped to index 0.
55 | '''
56 |
57 | def __post_init__(self):
58 | for idx, sym in self._id2sym.items():
59 | assert self._sym2id[sym] == idx
60 | assert idx >= 0
61 |
62 | for sym, idx in self._sym2id.items():
63 | assert idx >= 0
64 | assert self._id2sym[idx] == sym
65 |
66 | if 0 not in self._id2sym:
67 | self._id2sym[0] = self.eps
68 | self._sym2id[self.eps] = 0
69 | else:
70 | assert self._id2sym[0] == self.eps
71 | assert self._sym2id[self.eps] == 0
72 |
73 | self._next_available_id = max(self._id2sym) + 1
74 |
75 | @staticmethod
76 | def from_str(s: str) -> 'SymbolTable':
77 | '''Build a symbol table from a string.
78 |
79 | The string consists of lines. Every line has two fields separated
80 | by space(s), tab(s) or both. The first field is the symbol and the
81 | second the integer id of the symbol.
82 |
83 | Args:
84 | s:
85 | The input string with the format described above.
86 | Returns:
87 | An instance of :class:`SymbolTable`.
88 | '''
89 | id2sym: Dict[int, str] = dict()
90 | sym2id: Dict[str, int] = dict()
91 |
92 | for line in s.split('\n'):
93 | fields = line.split()
94 | if len(fields) == 0:
95 | continue # skip empty lines
96 | assert len(fields) == 2, \
97 | f'Expect a line with 2 fields. Given: {len(fields)}'
98 | sym, idx = fields[0], int(fields[1])
99 | assert sym not in sym2id, f'Duplicated symbol {sym}'
100 | assert idx not in id2sym, f'Duplicated id {idx}'
101 | id2sym[idx] = sym
102 | sym2id[sym] = idx
103 |
104 | eps = id2sym.get(0, '')
105 |
106 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
107 |
108 | @staticmethod
109 | def from_file(filename: str) -> 'SymbolTable':
110 | '''Build a symbol table from file.
111 |
112 | Every line in the symbol table file has two fields separated by
113 | space(s), tab(s) or both. The following is an example file:
114 |
115 | .. code-block::
116 |
117 | 0
118 | a 1
119 | b 2
120 | c 3
121 |
122 | Args:
123 | filename:
124 | Name of the symbol table file. Its format is documented above.
125 |
126 | Returns:
127 | An instance of :class:`SymbolTable`.
128 |
129 | '''
130 | with open(filename, 'r', encoding='utf-8') as f:
131 | return SymbolTable.from_str(f.read().strip())
132 |
133 | def to_str(self) -> str:
134 | '''
135 | Returns:
136 | Return a string representation of this object. You can pass
137 | it to the method ``from_str`` to recreate an identical object.
138 | '''
139 | s = ''
140 | for idx, symbol in sorted(self._id2sym.items()):
141 | s += f'{symbol} {idx}\n'
142 | return s
143 |
144 | def to_file(self, filename: str):
145 | '''Serialize the SymbolTable to a file.
146 |
147 | Every line in the symbol table file has two fields separated by
148 | space(s), tab(s) or both. The following is an example file:
149 |
150 | .. code-block::
151 |
152 | 0
153 | a 1
154 | b 2
155 | c 3
156 |
157 | Args:
158 | filename:
159 | Name of the symbol table file. Its format is documented above.
160 | '''
161 | with open(filename, 'w') as f:
162 | for idx, symbol in sorted(self._id2sym.items()):
163 | print(symbol, idx, file=f)
164 |
165 | def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
166 | '''Add a new symbol to the SymbolTable.
167 |
168 | Args:
169 | symbol:
170 | The symbol to be added.
171 | index:
172 | Optional int id to which the symbol should be assigned.
173 |                 If the given id is already taken, a ValueError will be raised.
174 |
175 | Returns:
176 | The int id to which the symbol has been assigned.
177 | '''
178 | # Already in the table? Return its ID.
179 | if symbol in self._sym2id:
180 | return self._sym2id[symbol]
181 | # Specific ID not provided - use next available.
182 | if index is None:
183 | index = self._next_available_id
184 | # Specific ID provided but not available.
185 | if index in self._id2sym:
186 | raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
187 | f"already occupied by {self._id2sym[index]}")
188 | self._sym2id[symbol] = index
189 | self._id2sym[index] = symbol
190 |
191 | # Update next available ID if needed
192 | if self._next_available_id <= index:
193 | self._next_available_id = index + 1
194 |
195 | return index
196 |
197 | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
198 | '''Get a symbol for an id or get an id for a symbol
199 |
200 | Args:
201 | k:
202 | If it is an id, it tries to find the symbol corresponding
203 | to the id; if it is a symbol, it tries to find the id
204 | corresponding to the symbol.
205 |
206 | Returns:
207 | An id or a symbol depending on the given `k`.
208 | '''
209 | if isinstance(k, int):
210 | return self._id2sym[k]
211 | else:
212 | return self._sym2id[k]
213 |
214 | def merge(self, other: 'SymbolTable') -> 'SymbolTable':
215 | '''Create a union of two SymbolTables.
216 | Raises an AssertionError if the same IDs are occupied by
217 | different symbols.
218 |
219 | Args:
220 | other:
221 | A symbol table to merge with ``self``.
222 |
223 | Returns:
224 | A new symbol table.
225 | '''
226 | self._check_compatible(other)
227 |
228 | id2sym = {**self._id2sym, **other._id2sym}
229 | sym2id = {**self._sym2id, **other._sym2id}
230 |
231 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)
232 |
233 | def _check_compatible(self, other: 'SymbolTable') -> None:
234 | # Epsilon compatibility
235 | assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \
236 | f'{self.eps} != {other.eps}'
237 | # IDs compatibility
238 | common_ids = set(self._id2sym).intersection(other._id2sym)
239 | for idx in common_ids:
240 | assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \
241 | f'self[idx] = "{self[idx]}", ' \
242 | f'other[idx] = "{other[idx]}"'
243 | # Symbols compatibility
244 | common_symbols = set(self._sym2id).intersection(other._sym2id)
245 | for sym in common_symbols:
246 |             assert self[sym] == other[sym], f'ID conflict for symbol: {sym}, ' \
247 | f'self[sym] = "{self[sym]}", ' \
248 | f'other[sym] = "{other[sym]}"'
249 |
250 | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
251 | return self.get(item)
252 |
253 | def __contains__(self, item: Union[int, Symbol]) -> bool:
254 | if isinstance(item, int):
255 | return item in self._id2sym
256 | else:
257 | return item in self._sym2id
258 |
259 | def __len__(self) -> int:
260 | return len(self._id2sym)
261 |
262 | def __eq__(self, other: 'SymbolTable') -> bool:
263 | if len(self) != len(other):
264 | return False
265 |
266 | for s in self.symbols:
267 | if self[s] != other[s]:
268 | return False
269 |
270 | return True
271 |
272 | @property
273 | def ids(self) -> List[int]:
274 | '''Returns a list of integer IDs corresponding to the symbols.
275 | '''
276 | ans = list(self._id2sym.keys())
277 | ans.sort()
278 | return ans
279 |
280 | @property
281 | def symbols(self) -> List[Symbol]:
282 | '''Returns a list of symbols (e.g., strings) corresponding to
283 | the integer IDs.
284 | '''
285 | ans = list(self._sym2id.keys())
286 | ans.sort()
287 | return ans
288 |
--------------------------------------------------------------------------------
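A short usage sketch for SymbolTable above, exercising from_str, add, get and
merge (the symbols and ids are illustrative only):

    from valle.utils.symbol_table import SymbolTable

    table = SymbolTable.from_str("<eps> 0\na 1\nb 2")
    assert table["a"] == 1 and table[2] == "b"

    idx = table.add("c")          # next available id, here 3
    other = SymbolTable.from_str("<eps> 0\nd 10")
    merged = table.merge(other)   # union; asserts on conflicting ids/symbols
    assert merged["d"] == 10 and len(merged) == 5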