├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── constants.py ├── data ├── collation.py ├── data_module.py ├── sampler.py ├── semantic_dataset.py └── single_speaker_dataset.py ├── datasets └── example │ └── train.json ├── demo ├── audios-speech-tokenizer │ ├── acoustic │ │ ├── POD0000004393_S0000029.npy │ │ ├── POD0000007005_S0000568.npy │ │ ├── POD0000009720_S0000244.npy │ │ ├── POD0000014360_S0000082.npy │ │ ├── POD0000015908_S0000037.npy │ │ ├── POD1000000022_S0000028.npy │ │ └── male_voice.npy │ └── semantic │ │ ├── POD0000004393_S0000029.npy │ │ ├── POD0000007005_S0000568.npy │ │ ├── POD0000009720_S0000244.npy │ │ ├── POD0000014360_S0000082.npy │ │ ├── POD0000015908_S0000037.npy │ │ ├── POD1000000022_S0000028.npy │ │ └── male_voice.npy ├── audios │ ├── POD0000004393_S0000029.wav │ ├── POD0000007005_S0000568.wav │ ├── POD0000009720_S0000244.wav │ ├── POD0000014360_S0000082.wav │ ├── POD0000015908_S0000037.wav │ ├── POD1000000022_S0000028.wav │ └── male_voice.wav ├── male_voice.wav └── manifest.json ├── docs ├── _config.yml ├── _layouts │ └── default.html ├── assets │ ├── css │ │ └── style.scss │ └── img │ │ └── polyai-logo.webp ├── index.md └── samples │ ├── empress │ ├── 114.wav │ ├── 148.wav │ ├── 161.wav │ ├── 189.wav │ ├── 217.wav │ ├── 226.wav │ ├── 234.wav │ ├── 242.wav │ ├── 262.wav │ ├── 269.wav │ ├── 29.wav │ └── 46.wav │ ├── gigaspeech │ ├── POD1000000004_S0000246.wav │ ├── POD1000000004_S0000247.wav │ ├── POD1000000018_S0000253.wav │ ├── POD1000000018_S0000254.wav │ ├── POD1000000048_S0000035.wav │ ├── POD1000000048_S0000036.wav │ ├── YOU1000000006_S0000051.wav │ ├── YOU1000000006_S0000052.wav │ ├── YOU1000000044_S0000798.wav │ └── YOU1000000044_S0000799.wav │ ├── pheme-100 │ ├── 019.wav │ ├── 042.wav │ ├── 080.wav │ ├── 188.wav │ └── 209.wav │ ├── pheme-300 │ ├── 019.wav │ ├── 042.wav │ ├── 080.wav │ ├── 188.wav │ └── 209.wav │ ├── pheme-empress-300 │ ├── 001.wav │ ├── 002.wav │ ├── 190.wav │ ├── 227.wav │ ├── 235.wav │ ├── 243.wav │ └── 270.wav │ ├── pheme-no-empress-300 │ ├── 190.wav │ ├── 227.wav │ ├── 235.wav │ ├── 243.wav │ └── 270.wav │ └── pheme-no-spkr-300 │ ├── 019.wav │ ├── 042.wav │ ├── 080.wav │ ├── 188.wav │ └── 209.wav ├── modules ├── __init__.py ├── conformer.py ├── masking_logic.py ├── s2a_model.py ├── speech_tokenizer.py ├── t2s_model.py ├── tokenizer.py └── vocoder.py ├── requirements.txt ├── train_s2a.py ├── train_t2s.py ├── transformer_infer.py └── utils ├── __init__.py ├── data_prep.py ├── get_tokens_speech_tokenizer.py └── symbol_table.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | exclude = .git,__pycache__,build,dist 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | app.py 2 | ckpt/* 3 | */__pycache__/* 4 | __pycache__/* 5 | exp/* 6 | datasets/* 7 | wandb/* 8 | venv/* 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. 
Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution 4.0 International Public License 58 | 59 | By exercising the Licensed Rights (defined below), You accept and agree 60 | to be bound by the terms and conditions of this Creative Commons 61 | Attribution 4.0 International Public License ("Public License"). To the 62 | extent this Public License may be interpreted as a contract, You are 63 | granted the Licensed Rights in consideration of Your acceptance of 64 | these terms and conditions, and the Licensor grants You such rights in 65 | consideration of benefits the Licensor receives from making the 66 | Licensed Material available under these terms and conditions. 67 | 68 | 69 | Section 1 -- Definitions. 70 | 71 | a. 
Adapted Material means material subject to Copyright and Similar 72 | Rights that is derived from or based upon the Licensed Material 73 | and in which the Licensed Material is translated, altered, 74 | arranged, transformed, or otherwise modified in a manner requiring 75 | permission under the Copyright and Similar Rights held by the 76 | Licensor. For purposes of this Public License, where the Licensed 77 | Material is a musical work, performance, or sound recording, 78 | Adapted Material is always produced where the Licensed Material is 79 | synched in timed relation with a moving image. 80 | 81 | b. Adapter's License means the license You apply to Your Copyright 82 | and Similar Rights in Your contributions to Adapted Material in 83 | accordance with the terms and conditions of this Public License. 84 | 85 | c. Copyright and Similar Rights means copyright and/or similar rights 86 | closely related to copyright including, without limitation, 87 | performance, broadcast, sound recording, and Sui Generis Database 88 | Rights, without regard to how the rights are labeled or 89 | categorized. For purposes of this Public License, the rights 90 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 91 | Rights. 92 | 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. Share means to provide material to the public by any means or 116 | process that requires permission under the Licensed Rights, such 117 | as reproduction, public display, public performance, distribution, 118 | dissemination, communication, or importation, and to make material 119 | available to the public including in ways that members of the 120 | public may access the material from a place and at a time 121 | individually chosen by them. 122 | 123 | j. Sui Generis Database Rights means rights other than copyright 124 | resulting from Directive 96/9/EC of the European Parliament and of 125 | the Council of 11 March 1996 on the legal protection of databases, 126 | as amended and/or succeeded, as well as other essentially 127 | equivalent rights anywhere in the world. 128 | 129 | k. You means the individual or entity exercising the Licensed Rights 130 | under this Public License. Your has a corresponding meaning. 131 | 132 | 133 | Section 2 -- Scope. 134 | 135 | a. License grant. 136 | 137 | 1. 
Subject to the terms and conditions of this Public License, 138 | the Licensor hereby grants You a worldwide, royalty-free, 139 | non-sublicensable, non-exclusive, irrevocable license to 140 | exercise the Licensed Rights in the Licensed Material to: 141 | 142 | a. reproduce and Share the Licensed Material, in whole or 143 | in part; and 144 | 145 | b. produce, reproduce, and Share Adapted Material. 146 | 147 | 2. Exceptions and Limitations. For the avoidance of doubt, where 148 | Exceptions and Limitations apply to Your use, this Public 149 | License does not apply, and You do not need to comply with 150 | its terms and conditions. 151 | 152 | 3. Term. The term of this Public License is specified in Section 153 | 6(a). 154 | 155 | 4. Media and formats; technical modifications allowed. The 156 | Licensor authorizes You to exercise the Licensed Rights in 157 | all media and formats whether now known or hereafter created, 158 | and to make technical modifications necessary to do so. The 159 | Licensor waives and/or agrees not to assert any right or 160 | authority to forbid You from making technical modifications 161 | necessary to exercise the Licensed Rights, including 162 | technical modifications necessary to circumvent Effective 163 | Technological Measures. For purposes of this Public License, 164 | simply making modifications authorized by this Section 2(a) 165 | (4) never produces Adapted Material. 166 | 167 | 5. Downstream recipients. 168 | 169 | a. Offer from the Licensor -- Licensed Material. Every 170 | recipient of the Licensed Material automatically 171 | receives an offer from the Licensor to exercise the 172 | Licensed Rights under the terms and conditions of this 173 | Public License. 174 | 175 | b. No downstream restrictions. You may not offer or impose 176 | any additional or different terms or conditions on, or 177 | apply any Effective Technological Measures to, the 178 | Licensed Material if doing so restricts exercise of the 179 | Licensed Rights by any recipient of the Licensed 180 | Material. 181 | 182 | 6. No endorsement. Nothing in this Public License constitutes or 183 | may be construed as permission to assert or imply that You 184 | are, or that Your use of the Licensed Material is, connected 185 | with, or sponsored, endorsed, or granted official status by, 186 | the Licensor or others designated to receive attribution as 187 | provided in Section 3(a)(1)(A)(i). 188 | 189 | b. Other rights. 190 | 191 | 1. Moral rights, such as the right of integrity, are not 192 | licensed under this Public License, nor are publicity, 193 | privacy, and/or other similar personality rights; however, to 194 | the extent possible, the Licensor waives and/or agrees not to 195 | assert any such rights held by the Licensor to the limited 196 | extent necessary to allow You to exercise the Licensed 197 | Rights, but not otherwise. 198 | 199 | 2. Patent and trademark rights are not licensed under this 200 | Public License. 201 | 202 | 3. To the extent possible, the Licensor waives any right to 203 | collect royalties from You for the exercise of the Licensed 204 | Rights, whether directly or through a collecting society 205 | under any voluntary or waivable statutory or compulsory 206 | licensing scheme. In all other cases the Licensor expressly 207 | reserves any right to collect such royalties. 208 | 209 | 210 | Section 3 -- License Conditions. 211 | 212 | Your exercise of the Licensed Rights is expressly made subject to the 213 | following conditions. 214 | 215 | a. Attribution. 
216 | 217 | 1. If You Share the Licensed Material (including in modified 218 | form), You must: 219 | 220 | a. retain the following if it is supplied by the Licensor 221 | with the Licensed Material: 222 | 223 | i. identification of the creator(s) of the Licensed 224 | Material and any others designated to receive 225 | attribution, in any reasonable manner requested by 226 | the Licensor (including by pseudonym if 227 | designated); 228 | 229 | ii. a copyright notice; 230 | 231 | iii. a notice that refers to this Public License; 232 | 233 | iv. a notice that refers to the disclaimer of 234 | warranties; 235 | 236 | v. a URI or hyperlink to the Licensed Material to the 237 | extent reasonably practicable; 238 | 239 | b. indicate if You modified the Licensed Material and 240 | retain an indication of any previous modifications; and 241 | 242 | c. indicate the Licensed Material is licensed under this 243 | Public License, and include the text of, or the URI or 244 | hyperlink to, this Public License. 245 | 246 | 2. You may satisfy the conditions in Section 3(a)(1) in any 247 | reasonable manner based on the medium, means, and context in 248 | which You Share the Licensed Material. For example, it may be 249 | reasonable to satisfy the conditions by providing a URI or 250 | hyperlink to a resource that includes the required 251 | information. 252 | 253 | 3. If requested by the Licensor, You must remove any of the 254 | information required by Section 3(a)(1)(A) to the extent 255 | reasonably practicable. 256 | 257 | 4. If You Share Adapted Material You produce, the Adapter's 258 | License You apply must not prevent recipients of the Adapted 259 | Material from complying with this Public License. 260 | 261 | 262 | Section 4 -- Sui Generis Database Rights. 263 | 264 | Where the Licensed Rights include Sui Generis Database Rights that 265 | apply to Your use of the Licensed Material: 266 | 267 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 268 | to extract, reuse, reproduce, and Share all or a substantial 269 | portion of the contents of the database; 270 | 271 | b. if You include all or a substantial portion of the database 272 | contents in a database in which You have Sui Generis Database 273 | Rights, then the database in which You have Sui Generis Database 274 | Rights (but not its individual contents) is Adapted Material; and 275 | 276 | c. You must comply with the conditions in Section 3(a) if You Share 277 | all or a substantial portion of the contents of the database. 278 | 279 | For the avoidance of doubt, this Section 4 supplements and does not 280 | replace Your obligations under this Public License where the Licensed 281 | Rights include other Copyright and Similar Rights. 282 | 283 | 284 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 285 | 286 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 287 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 288 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 289 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 290 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 291 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 292 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 293 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 294 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 295 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 
296 | 297 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 298 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 299 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 300 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 301 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 302 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 303 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 304 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 305 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 306 | 307 | c. The disclaimer of warranties and limitation of liability provided 308 | above shall be interpreted in a manner that, to the extent 309 | possible, most closely approximates an absolute disclaimer and 310 | waiver of all liability. 311 | 312 | 313 | Section 6 -- Term and Termination. 314 | 315 | a. This Public License applies for the term of the Copyright and 316 | Similar Rights licensed here. However, if You fail to comply with 317 | this Public License, then Your rights under this Public License 318 | terminate automatically. 319 | 320 | b. Where Your right to use the Licensed Material has terminated under 321 | Section 6(a), it reinstates: 322 | 323 | 1. automatically as of the date the violation is cured, provided 324 | it is cured within 30 days of Your discovery of the 325 | violation; or 326 | 327 | 2. upon express reinstatement by the Licensor. 328 | 329 | For the avoidance of doubt, this Section 6(b) does not affect any 330 | right the Licensor may have to seek remedies for Your violations 331 | of this Public License. 332 | 333 | c. For the avoidance of doubt, the Licensor may also offer the 334 | Licensed Material under separate terms or conditions or stop 335 | distributing the Licensed Material at any time; however, doing so 336 | will not terminate this Public License. 337 | 338 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 339 | License. 340 | 341 | 342 | Section 7 -- Other Terms and Conditions. 343 | 344 | a. The Licensor shall not be bound by any additional or different 345 | terms or conditions communicated by You unless expressly agreed. 346 | 347 | b. Any arrangements, understandings, or agreements regarding the 348 | Licensed Material not stated herein are separate from and 349 | independent of the terms and conditions of this Public License. 350 | 351 | 352 | Section 8 -- Interpretation. 353 | 354 | a. For the avoidance of doubt, this Public License does not, and 355 | shall not be interpreted to, reduce, limit, restrict, or impose 356 | conditions on any use of the Licensed Material that could lawfully 357 | be made without permission under this Public License. 358 | 359 | b. To the extent possible, if any provision of this Public License is 360 | deemed unenforceable, it shall be automatically reformed to the 361 | minimum extent necessary to make it enforceable. If the provision 362 | cannot be reformed, it shall be severed from this Public License 363 | without affecting the enforceability of the remaining terms and 364 | conditions. 365 | 366 | c. No term or condition of this Public License will be waived and no 367 | failure to comply consented to unless expressly agreed to by the 368 | Licensor. 369 | 370 | d. 
Nothing in this Public License constitutes or may be interpreted 371 | as a limitation upon, or waiver of, any privileges and immunities 372 | that apply to the Licensor or You, including from the legal 373 | processes of any jurisdiction or authority. 374 | 375 | 376 | ======================================================================= 377 | 378 | Creative Commons is not a party to its public 379 | licenses. Notwithstanding, Creative Commons may elect to apply one of 380 | its public licenses to material it publishes and in those instances 381 | will be considered the “Licensor.” The text of the Creative Commons 382 | public licenses is dedicated to the public domain under the CC0 Public 383 | Domain Dedication. Except for the limited purpose of indicating that 384 | material is shared under a Creative Commons public license or as 385 | otherwise permitted by the Creative Commons policies published at 386 | creativecommons.org/policies, Creative Commons does not authorize the 387 | use of the trademark "Creative Commons" or any other trademark or logo 388 | of Creative Commons without its prior written consent including, 389 | without limitation, in connection with any unauthorized modifications 390 | to any of its public licenses or any other arrangements, 391 | understandings, or agreements concerning use of licensed material. For 392 | the avoidance of doubt, this paragraph does not form part of the 393 | public licenses. 394 | 395 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pheme Model 2 | 3 | This repo contains the recipes and models used for training Pheme TTS models. It is the official implementation of the 4 | paper: [Pheme: Efficient and Conversational Speech Generation](https://arxiv.org/pdf/2401.02839.pdf). A demo is 5 | available [here](https://huggingface.co/spaces/PolyAI/pheme), and a selection of audio samples can be 6 | found [here](https://polyai-ldn.github.io/pheme/). 7 | 8 | Our Pheme TTS framework validates several hypotheses: 9 | 10 | 1. We can train Transformer-based conversational TTS models with much less training data than, e.g., VALL-E or 11 | SoundStorm (roughly 10x less data). 12 | 2. Training can be performed with conversational, podcast, and noisy data such as GigaSpeech. 13 | 3. Efficiency is paramount; this includes parameter efficiency (compact models), data efficiency (less training data) 14 | and inference efficiency (reduced latency). 15 | 4. One fundamental ingredient is the separation of semantic and acoustic tokens, together with an adequate speech tokenizer. 16 | 5. Inference can be run in parallel through MaskGit-style decoding, with 15x speed-ups compared to similarly sized 17 | autoregressive models (see the sketch below). 18 | 6. Single-speaker quality can be improved through student-teacher training with (synthetic) data generated by 19 | third-party providers. 
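For readers unfamiliar with MaskGit-style decoding (point 5 above), the idea is that every token position starts out masked and the whole sequence is filled in over a small, fixed number of refinement steps, keeping only the most confident predictions at each step. The snippet below is a minimal illustrative sketch of that procedure, not the actual Pheme inference code: `model` stands for any network that returns per-position token logits, and `mask_id` is a placeholder for the mask token id.

```python
import math

import torch


def maskgit_decode(model, length, n_steps=16, mask_id=1024):
    """Iterative parallel decoding: unmask the most confident positions first."""
    tokens = torch.full((1, length), mask_id, dtype=torch.long)
    for step in range(n_steps):
        logits = model(tokens)  # assumed shape: (1, length, vocab_size)
        conf, candidates = logits.softmax(-1).max(-1)
        still_masked = tokens == mask_id
        # Only positions that are still masked compete for being filled in.
        conf = conf.masked_fill(~still_masked, -math.inf)
        # Cosine schedule: the number of positions kept masked shrinks to zero.
        n_keep_masked = int(length * math.cos(math.pi / 2 * (step + 1) / n_steps))
        n_unmask = int(still_masked.sum()) - n_keep_masked
        if n_unmask <= 0:
            continue
        idx = conf.topk(n_unmask, dim=-1).indices[0]
        tokens[0, idx] = candidates[0, idx]
    return tokens
```

Because each step predicts all positions in parallel, the number of forward passes stays constant (cf. the 16 steps reported in the speed tables below) instead of growing with the sequence length, which is where the latency advantage over autoregressive decoding comes from.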
20 | 21 | # Set Up the Environment 22 | 23 | Set up conda environment: 24 | 25 | ``` 26 | conda create --name pheme3 python=3.10 27 | conda activate pheme3 28 | 29 | pip3 install torch torchvision torchaudio 30 | pip3 install -r requirements.txt --no-deps 31 | ``` 32 | 33 | Download pre-trained SpeechTokenizer and unique token list models: 34 | 35 | ``` bash 36 | st_dir="ckpt/speechtokenizer/" 37 | mkdir -p ${st_dir} 38 | cd ${st_dir} 39 | wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/SpeechTokenizer.pt" 40 | wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/config.json" 41 | cd .. 42 | wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/unique_text_tokens.k2symbols" 43 | ``` 44 | 45 | You need to create an access token to use the speaker embedding of pyannote. 46 | 47 | ``` 48 | export HUGGING_FACE_HUB_TOKEN=YOUR_PRIVATE_TOKEN 49 | ``` 50 | 51 | Download pre-trained T2S and S2A models (the 100M Pheme variant): 52 | 53 | ``` bash 54 | git clone https://huggingface.co/PolyAI/pheme_small ckpt/pheme 55 | mkdir -p "ckpt/t2s" 56 | mkdir -p "ckpt/s2a" 57 | mv ckpt/pheme/config_t2s.json ckpt/t2s/config.json 58 | mv ckpt/pheme/generation_config.json ckpt/t2s/generation_config.json 59 | mv ckpt/pheme/t2s.bin ckpt/t2s/pytorch_model.bin 60 | mv ckpt/pheme/config_s2a.json ckpt/s2a/config.json 61 | mv ckpt/pheme/s2a.ckpt ckpt/s2a/s2a.ckpt 62 | ``` 63 | 64 | or the larger version (300M) at `https://huggingface.co/PolyAI/pheme` 65 | 66 | # Prompt-based Generation 67 | 68 | The generation can be invoked by: 69 | 70 | ``` 71 | python transformer_infer.py 72 | ``` 73 | 74 | # Training 75 | 76 | ## Data Preparation 77 | 78 | The package requires data of the format: `datasets/example/train.json` with `datasets/audios/` where you store wav 79 | files. 80 | The manifest should follow the format: 81 | 82 | ``` 83 | { 84 | "LJ001-0051.wav": { 85 | "text": "and paying great attention to the press work or actual process of printing,", 86 | "raw-text": "and paying great attention to the press work or actual process of printing,", 87 | "duration": 4.860090702947846, 88 | "phoneme": "æ|n|d|_|p|eɪ|ɪ|ŋ|_|ɡ|ɹ|eɪ|t|_|ɐ|t|ɛ|n|ʃ|ə|n|_|t|ə|_|ð|ə|_|\"|p|ɹ|ɛ|s|_|w|ɜː|k|\"|_|ɔː|ɹ|_|æ|k|tʃ|uː|əl|_|p|ɹ|ɑː|s|ɛ|s|_|ʌ|v|_|p|ɹ|ɪ|n|t|ɪ|ŋ|," 89 | }, 90 | "LJ001-0120.wav": { 91 | ... 92 | }, 93 | ... 94 | } 95 | 96 | ``` 97 | Create train/valid/test manifests 98 | ``` 99 | PYTHONPATH=. python utils/data_prep.py 100 | ``` 101 | Resample audio files to 16kHz 102 | ``` 103 | find LJSpeech-1.1/wavs/ -name "*.wav" | parallel ffmpeg -i {} -ar 16000 -ac 1 audios/{/} 104 | ``` 105 | The following command will create semantic and acoustic tokens based on the `audios` folder. 106 | 107 | ``` 108 | PYTHONPATH=. 
python utils/get_tokens_speech_tokenizer.py \ 109 | --config_path ckpt/speechtokenizer/config.json \ 110 | --ckpt_path ckpt/speechtokenizer/SpeechTokenizer.pt \ 111 | --encoding_input datasets/ljspeech-training-data/audios \ 112 | --encoding_output datasets/ljspeech-training-data/audios-speech-tokenizer 113 | ``` 114 | 115 | ## T2S 116 | 117 | ``` 118 | TRAIN_MANIFEST="./datasets/ljspeech-training-data/train.json" 119 | DEV_MANIFEST="./datasets/ljspeech-training-data/dev.json" 120 | OUT_DIR="./experiments/t2s-ljspeech" 121 | 122 | 123 | python train_t2s.py --metapath "${TRAIN_MANIFEST}" \ 124 | --val_metapath "${DEV_MANIFEST}" \ 125 | --output_dir "${OUT_DIR}" \ 126 | --model_size tiny --batch_size 64 \ 127 | --nworkers 12 --warmup_steps 10000 \ 128 | --save_steps 500 --n_epochs 100 \ 129 | --learning_rate 1e-3 130 | ``` 131 | 132 | ## S2A 133 | 134 | ``` 135 | TRAIN_MANIFEST="./datasets/ljspeech-training-data/train.json" 136 | DEV_MANIFEST="./datasets/ljspeech-training-data/dev.json" 137 | OUT_DIR="./experiments/s2a-ljspeech" 138 | 139 | python train_s2a.py --saving_path "${OUT_DIR}" --sampledir "${OUT_DIR}" --vocoder_type SPEECHTOKENIZER \ 140 | --n_codes 1024 --n_cluster_groups 7 --metapath "${TRAIN_MANIFEST}" \ 141 | --val_metapath "${DEV_MANIFEST}" \ 142 | --warmup_step 10000 --nworkers 12 --first_n_lvls 7 \ 143 | --batch_size 200 --ffd_size 1024 --hidden_size 768 --enc_nlayers 3 --dec_nlayers 6 --nheads 8 \ 144 | --depthwise_conv_kernel_size 5 \ 145 | --val_check_interval 60 --sample_rate 16000 --lr 5e-4 \ 146 | --check_val_every_n_epoch 1 --n_semantic_codes 1024 \ 147 | --distributed 148 | 149 | ``` 150 | 151 | ## Speed test with TensorRT-LLM 152 | 153 | ### A100 GPU / 100M Pheme Variant 154 | 155 | | Model | Batch Size | Steps | RTF | 156 | |------------------------|------------|-------|----------| 157 | | T2S-S2A Short sentence | 1 | 16 | 0.133 | 158 | | T2S-S2A Long sentence | 1 | 16 | 0.133 | 159 | 160 | ### A100 GPU / 300M Pheme Variant 161 | 162 | | Model | Batch Size | Steps | RTF | 163 | |------------------------|------------|-------|----------| 164 | | T2S-S2A Short sentence | 1 | 16 | 0.143 | 165 | | T2S-S2A Long sentence | 1 | 16 | 0.143 | 166 | 167 | ## Acknowledgements 168 | 169 | [MQTTS](https://github.com/b04901014/MQTTS)\ 170 | [SpeechTokenizer](https://github.com/ZhangXInFD/soundstorm-speechtokenizer)\ 171 | [maskgit](https://github.com/google-research/maskgit)\ 172 | [SoundStorm](https://github.com/lifeiteng/SoundStorm) 173 | 174 | ## TODO 175 | 176 | 1. Add TensorRT-LLM image 177 | 178 | ## Citation 179 | 180 | If you use this code or components of the model in your own work, please cite our work as: 181 | 182 | ```Tex 183 | @misc{budzianowski2024pheme, 184 | title={Pheme: Efficient and Conversational Speech Generation}, 185 | author={Paweł Budzianowski and Taras Sereda and Tomasz Cichy and Ivan Vulić}, 186 | year={2024}, 187 | eprint={2401.02839}, 188 | archivePrefix={arXiv}, 189 | primaryClass={eess.AS} 190 | } 191 | ``` 192 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/__init__.py -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | """Constants file. 
2 | 3 | Copyright PolyAI Limited. 4 | """ 5 | SPKR_EMB_SIZE = 512 6 | 7 | PAD = 1024 8 | 9 | SPKR_1 = 1025 10 | SPKR_2 = 1026 11 | 12 | BOS_TOKEN_ID = 0 13 | PAD_TOKEN_ID = 0 14 | EOS_TOKEN_ID = 2 -------------------------------------------------------------------------------- /data/collation.py: -------------------------------------------------------------------------------- 1 | """Collators for T2S and S2A. 2 | 3 | Copyright PolyAI Limited. 4 | """ 5 | from pathlib import Path 6 | from typing import List, Tuple, Union 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from utils.symbol_table import SymbolTable 12 | 13 | 14 | class GlobalCollater: 15 | def __init__(self, n_codes, n_semantic_codes): 16 | self.n_codes = n_codes 17 | self.sem_mask_id = n_semantic_codes 18 | 19 | def collate(self, batch): 20 | output = { 21 | 'speaker': [], 22 | 'tts_quantize_input': [], 23 | 'tts_quantize_output': [], 24 | 'quantize_mask': [], 25 | 'f_names': [], 26 | 'semantic_tokens': [], 27 | 'quantization_lengths': [], 28 | } 29 | # Get the max length of everything 30 | max_len_q = 0 31 | for _, q_s, q_e, _, _ in batch: 32 | if len(q_s) > max_len_q: 33 | max_len_q = len(q_s) 34 | 35 | output['quantization_lengths'].append(len(q_s)) 36 | 37 | # Pad each element, create mask 38 | for spkr, qs, qe, itm_name, s_tokens in batch: 39 | # Deal with quantizations 40 | q_mask = np.array( 41 | [False] * len(qs) + [True] * (max_len_q - len(qs))) 42 | qs = np.pad( 43 | qs, 44 | [[0, max_len_q-len(qs)], [0, 0]], 45 | constant_values=self.n_codes 46 | ) 47 | qe = np.pad( 48 | qe, 49 | [[0, max_len_q-len(qe)], [0, 0]], 50 | constant_values=self.n_codes 51 | ) 52 | 53 | # Deal with semantics 54 | s_tokens = s_tokens.flatten() 55 | s_tokens = np.pad( 56 | s_tokens, 57 | (0, max_len_q-len(s_tokens)), 58 | constant_values=self.sem_mask_id 59 | ) 60 | 61 | # Speaker padding 62 | spkr = np.concatenate( 63 | (spkr, np.zeros((max_len_q - len(spkr), 512)))) 64 | 65 | # Aggregate 66 | output['speaker'].append(spkr) 67 | output['tts_quantize_input'].append(qs) 68 | output['tts_quantize_output'].append(qe) 69 | output['quantize_mask'].append(q_mask) 70 | output['f_names'].append(itm_name) 71 | output["semantic_tokens"].append(s_tokens) 72 | 73 | for k in output.keys(): 74 | if k == 'f_names': 75 | continue 76 | output[k] = np.array(output[k]) 77 | if 'mask' in k: 78 | output[k] = torch.BoolTensor(output[k]) 79 | elif k in [ 80 | 'tts_quantize_input', 'tts_quantize_output', 81 | 'semantic_tokens', 'quantization_lengths' 82 | ]: 83 | output[k] = torch.LongTensor(output[k]) 84 | else: 85 | output[k] = torch.FloatTensor(output[k]) 86 | return output 87 | 88 | 89 | class TextTokenCollater: 90 | def __init__( 91 | self, 92 | text_tokens: List[str], 93 | add_eos: bool = True, 94 | add_bos: bool = True, 95 | pad_symbol: str = "", 96 | bos_symbol: str = "", 97 | eos_symbol: str = "", 98 | spkr_1_symbol: str = "spkr_1", 99 | spkr_2_symbol: str = "spkr_2", 100 | ): 101 | self.pad_symbol = pad_symbol 102 | 103 | self.add_eos = add_eos 104 | self.add_bos = add_bos 105 | 106 | self.bos_symbol = bos_symbol 107 | self.eos_symbol = eos_symbol 108 | self.spkr_1_symbol = spkr_1_symbol 109 | self.spkr_2_symbol = spkr_2_symbol 110 | 111 | unique_tokens = ( 112 | [pad_symbol] 113 | + ([bos_symbol] if add_bos else []) 114 | + ([eos_symbol] if add_eos else []) 115 | + ([spkr_1_symbol]) 116 | + ([spkr_2_symbol]) 117 | + sorted(text_tokens) 118 | ) 119 | 120 | self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} 121 | self.idx2token = 
[token for token in unique_tokens] 122 | 123 | def __call__( 124 | self, texts: List[str], texts_2: Union[None, List[str]] = None 125 | ) -> Tuple[torch.Tensor, torch.Tensor]: 126 | tokens_seqs = [[p for p in text] for text in texts] 127 | 128 | if texts_2 is None: 129 | seqs = [ 130 | ([self.bos_symbol] if self.add_bos else []) 131 | + [self.spkr_1_symbol] 132 | + list(seq) 133 | + ([self.eos_symbol] if self.add_eos else []) 134 | for seq in tokens_seqs 135 | ] 136 | else: 137 | tokens_seqs_2 = [[p for p in text] for text in texts_2] 138 | seqs = [ 139 | ([self.bos_symbol] if self.add_bos else []) 140 | + [self.spkr_1_symbol] 141 | + list(seq) 142 | + ([self.spkr_2_symbol]) 143 | + list(seq_2) 144 | + ([self.eos_symbol] if self.add_eos else []) 145 | for seq, seq_2 in zip(tokens_seqs, tokens_seqs_2) 146 | ] 147 | 148 | tokens_batch = torch.from_numpy( 149 | np.array( 150 | [[self.token2idx[token] for token in seq] for seq in seqs], 151 | dtype=np.int64, 152 | ) 153 | ) 154 | 155 | return tokens_batch 156 | 157 | 158 | def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: 159 | text_tokens_path = Path(text_tokens_file) 160 | unique_tokens = SymbolTable.from_file(text_tokens_path) 161 | collater = TextTokenCollater( 162 | unique_tokens.symbols, add_bos=True, add_eos=True 163 | ) 164 | return collater 165 | 166 | 167 | def get_text_semantic_token_collater( 168 | text_tokens_file: str, n_semantic_tokens=1024) -> TextTokenCollater: 169 | text_tokens_path = Path(text_tokens_file) 170 | unique_tokens = SymbolTable.from_file(text_tokens_path) 171 | for semantic_idx in range(n_semantic_tokens): 172 | unique_tokens.add(str(semantic_idx)) 173 | 174 | collater = TextTokenCollater( 175 | unique_tokens.symbols, add_bos=True, add_eos=True 176 | ) 177 | return collater 178 | 179 | 180 | if __name__ == '__main__': 181 | text_tokens_file = 'ckpt/unique_text_tokens.k2symbols' 182 | collater = get_text_semantic_token_collater(text_tokens_file) 183 | -------------------------------------------------------------------------------- /data/data_module.py: -------------------------------------------------------------------------------- 1 | """Data module. 2 | 3 | Copyright PolyAI Limited. 
4 | """ 5 | import typing 6 | from pathlib import Path 7 | from typing import List 8 | 9 | import lightning.pytorch as pl 10 | from torch.utils import data 11 | 12 | from data.collation import GlobalCollater 13 | from data.sampler import RandomBucketSampler 14 | from data.single_speaker_dataset import QuantizeDataset 15 | from utils import breakpoint_on_error 16 | 17 | 18 | class ConcatDataset(data.ConcatDataset): 19 | def __init__(self, datasets) -> None: 20 | super().__init__(datasets) 21 | self.lengths = [] 22 | for dataset in datasets: 23 | self.lengths.extend(dataset.lengths) 24 | 25 | 26 | class DataModule(pl.LightningDataModule): 27 | def __init__( 28 | self, hp, metapath: List[str], val_metapath: List[str], 29 | world_size, local_rank 30 | ): 31 | super().__init__() 32 | self.hp = hp 33 | self.metapath = metapath 34 | self.val_metapath = val_metapath 35 | self.world_size = world_size 36 | self.local_rank = local_rank 37 | self.collater = GlobalCollater( 38 | self.hp.n_codes, self.hp.n_semantic_codes) 39 | 40 | def setup(self, stage: str) -> None: 41 | if stage == "fit": 42 | self.train_data = self.concatenate_datasets( 43 | self.metapath, dataset_class=QuantizeDataset 44 | ) 45 | 46 | if stage == "valid": 47 | self.val_data = [] 48 | self.val_data_keys = [] 49 | self.prepare_val_datasets() 50 | assert len(self.val_data) > 0 51 | assert len(self.val_data_keys) > 0 52 | 53 | @breakpoint_on_error 54 | def concatenate_datasets( 55 | self, metapaths, dataset_class: typing.Type[QuantizeDataset]): 56 | data = [] 57 | for _, metapath in enumerate(metapaths): 58 | metapath = Path(metapath) 59 | # assumption that audios and audios-embeddings 60 | # are in the same folder as metapath 61 | datadir = metapath.with_name("audios") 62 | assert datadir.exists() 63 | data.append( 64 | dataset_class( 65 | self.hp, 66 | metapath, 67 | datadir=datadir, 68 | speaker_embedding_dir=None, 69 | ) 70 | ) 71 | return ConcatDataset(data) 72 | 73 | def prepare_val_datasets(self): 74 | for manifest in self.val_metapath: 75 | self.val_data.append( 76 | self.concatenate_datasets( 77 | [manifest], dataset_class=QuantizeDataset) 78 | ) 79 | name = Path(manifest).parent.name 80 | self.val_data_keys.append(name) 81 | 82 | assert len(self.val_data) == len(self.val_data_keys) 83 | 84 | def train_dataloader(self): 85 | length = self.train_data.lengths 86 | sampler = RandomBucketSampler( 87 | self.hp.train_bucket_size, 88 | length, 89 | self.hp.batch_size, 90 | drop_last=True, 91 | distributed=self.hp.distributed, 92 | world_size=self.world_size, 93 | rank=self.local_rank, 94 | ) 95 | dataloader = data.DataLoader( 96 | self.train_data, 97 | num_workers=self.hp.nworkers, 98 | batch_sampler=sampler, 99 | collate_fn=self.collater.collate, 100 | pin_memory=True 101 | ) 102 | 103 | return dataloader 104 | 105 | def val_dataloader(self): 106 | val_loaders = [] 107 | for dataset in self.val_data: 108 | val_loaders.append( 109 | data.DataLoader( 110 | dataset, 111 | num_workers=self.hp.nworkers, 112 | batch_size=int(self.hp.batch_size), 113 | collate_fn=self.collater.collate, 114 | shuffle=False, 115 | pin_memory=True 116 | ) 117 | ) 118 | 119 | return val_loaders 120 | -------------------------------------------------------------------------------- /data/sampler.py: -------------------------------------------------------------------------------- 1 | """Original sampling logic of MQTTS. 2 | 3 | Copyright PolyAI Limited. 
4 | """ 5 | import math 6 | import random 7 | 8 | import numpy as np 9 | from torch.utils import data 10 | 11 | 12 | def StandardSampler(dataset, shuffle, distributed=False, 13 | world_size=None, rank=None): 14 | if distributed: 15 | return data.distributed.DistributedSampler( 16 | dataset, shuffle=shuffle, num_replicas=world_size, rank=rank) 17 | if shuffle: 18 | return data.RandomSampler(dataset) 19 | return data.SequentialSampler(dataset) 20 | 21 | 22 | def RandomBucketSampler( 23 | nbuckets, length, batch_size, drop_last, distributed=False, 24 | world_size=None, rank=None): 25 | if distributed: 26 | return DistributedRandomBucketSampler( 27 | nbuckets, length, batch_size, drop_last, world_size, rank) 28 | return SingleRandomBucketSampler(nbuckets, length, batch_size, drop_last) 29 | 30 | 31 | class SingleRandomBucketSampler(data.Sampler): 32 | def __init__(self, nbuckets, length, batch_size, drop_last): 33 | self.length = length 34 | self.batch_size = batch_size 35 | self.drop_last = drop_last 36 | indices = np.argsort([-x for x in length]) 37 | split = len(indices) // nbuckets 38 | self.indices = [] 39 | for i in range(nbuckets): 40 | self.indices.append(indices[i*split:(i+1)*split]) 41 | if nbuckets * split < len(length): 42 | self.indices.append(indices[nbuckets*split:]) 43 | 44 | def __iter__(self): 45 | random.shuffle(self.indices) 46 | for x in self.indices: 47 | random.shuffle(x) 48 | idxs = [i for x in self.indices for i in x] 49 | batches, batch, sum_len, max_len = [], [], 0, 0 50 | for idx in idxs: 51 | batch.append(idx) 52 | sum_len += self.length[idx] 53 | max_len = max(self.length[idx], max_len) 54 | if max_len * len(batch) > self.batch_size: 55 | batches.append(batch[:-1]) 56 | batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx] # noqa 57 | if len(batch) > 0 and not self.drop_last: 58 | batches.append(batch) 59 | random.shuffle(batches) 60 | return iter(batches) 61 | 62 | 63 | class DistributedRandomBucketSampler(data.Sampler): 64 | def __init__(self, nbuckets, length, batch_size, 65 | drop_last, num_replicas, rank, seed=1234): 66 | if rank >= num_replicas or rank < 0: 67 | raise ValueError( 68 | "Invalid rank {}, rank should be in the interval" 69 | " [0, {}]".format(rank, num_replicas - 1)) 70 | indices = np.argsort(length) 71 | split = len(indices) // nbuckets 72 | self.length = length 73 | self.batch_size = batch_size 74 | self.drop_last = drop_last 75 | self.indices = [] 76 | for i in range(nbuckets): 77 | self.indices.append(indices[i*split:(i+1)*split]) 78 | if nbuckets * split < len(length): 79 | self.indices.append(indices[nbuckets*split:]) 80 | self.num_replicas = num_replicas 81 | self.rank = rank 82 | self.epoch = 0 83 | self.seed = seed 84 | 85 | def __iter__(self): 86 | # Deterministic shuffling 87 | random.Random(self.epoch + self.seed).shuffle(self.indices) 88 | for i, x in enumerate(self.indices): 89 | seed = self.epoch + self.seed + i * 5 90 | random.Random(seed).shuffle(x) 91 | indices = [i for x in self.indices for i in x] 92 | 93 | # Batching 94 | batches, batch, sum_len, max_len = [], [], 0, 0 95 | for idx in indices: 96 | batch.append(idx) 97 | sum_len += self.length[idx] 98 | max_len = max(self.length[idx], max_len) 99 | if max_len * len(batch) > self.batch_size: 100 | batches.append(batch[:-1]) 101 | batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx] # noqa 102 | # Subsample 103 | num_samples = math.ceil( 104 | (len(batches) - self.num_replicas) / self.num_replicas) 105 | total_size = num_samples * 
self.num_replicas 106 | batches = batches[:total_size] 107 | batches = batches[self.rank*num_samples: (self.rank+1)*num_samples] 108 | assert len(batches) == num_samples 109 | 110 | # Stochastic suffling 111 | random.shuffle(batches) 112 | return iter(batches) 113 | 114 | def set_epoch(self, epoch): 115 | self.epoch = epoch 116 | -------------------------------------------------------------------------------- /data/semantic_dataset.py: -------------------------------------------------------------------------------- 1 | """Semantic tokens loading logic. 2 | 3 | Copyright PolyAI Limited. 4 | """ 5 | import json 6 | import logging 7 | import random 8 | import re 9 | from logging import getLogger 10 | from pathlib import Path 11 | from typing import List, Pattern, Union 12 | 13 | import numpy as np 14 | import torch 15 | from phonemizer.backend import EspeakBackend 16 | from phonemizer.backend.espeak.language_switch import LanguageSwitch 17 | from phonemizer.backend.espeak.words_mismatch import WordMismatch 18 | from phonemizer.punctuation import Punctuation 19 | from phonemizer.separator import Separator 20 | from torch.utils.data import DataLoader, Dataset 21 | from tqdm import tqdm 22 | 23 | from data.collation import get_text_semantic_token_collater 24 | 25 | 26 | class TextTokenizer: 27 | """Phonemize Text.""" 28 | 29 | def __init__( 30 | self, 31 | language="en-us", 32 | backend="espeak", 33 | separator=Separator(word="_", syllable="-", phone="|"), 34 | preserve_punctuation=True, 35 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), 36 | with_stress: bool = False, 37 | tie: Union[bool, str] = False, 38 | language_switch: LanguageSwitch = "keep-flags", 39 | words_mismatch: WordMismatch = "ignore", 40 | ) -> None: 41 | logger = getLogger("phonemizer") 42 | logger.setLevel(logging.ERROR) 43 | if backend == "espeak": 44 | phonemizer = EspeakBackend( 45 | language, 46 | punctuation_marks=punctuation_marks, 47 | preserve_punctuation=preserve_punctuation, 48 | with_stress=with_stress, 49 | tie=tie, 50 | language_switch=language_switch, 51 | words_mismatch=words_mismatch, 52 | logger=logger, 53 | ) 54 | else: 55 | raise NotImplementedError(f"{backend}") 56 | 57 | self.backend = phonemizer 58 | self.separator = separator 59 | 60 | def to_list(self, phonemized: str) -> List[str]: 61 | fields = [] 62 | for word in phonemized.split(self.separator.word): 63 | # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. 64 | pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) 65 | fields.extend( 66 | [p for p in pp if p != self.separator.phone] + [self.separator.word] 67 | ) 68 | assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( 69 | self.separator.phone 70 | ) 71 | return fields[:-1] 72 | 73 | def __call__(self, text, strip=True) -> List[List[str]]: 74 | if isinstance(text, str): 75 | text = [text] 76 | 77 | phonemized = self.backend.phonemize( 78 | text, separator=self.separator, strip=strip, njobs=1 79 | ) 80 | return [self.to_list(p) for p in phonemized] 81 | 82 | 83 | class Collator: 84 | def collate(self, batch): 85 | input_ids = [item["input_ids"] for item in batch] 86 | output_sequences = [item["labels"] for item in batch] 87 | 88 | # Pad sequences to the maximum length in the batch 89 | input_ids = torch.nn.utils.rnn.pad_sequence( 90 | input_ids, batch_first=True, padding_value=0 91 | ) 92 | output_sequences = torch.nn.utils.rnn.pad_sequence( 93 | output_sequences, batch_first=True, padding_value=-100 94 | ) 95 | # 1 - token is unmasked, 0 - token is masked. 
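        # Note: padding_value=-100 above matches the default ignore_index of torch.nn.CrossEntropyLoss, so padded label positions can be skipped by the loss; input positions padded with 0 are excluded via the attention mask computed below.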
96 | attention_mask = input_ids != 0 97 | return { 98 | "input_ids": input_ids, 99 | "attention_mask": attention_mask, 100 | "labels": output_sequences, 101 | } 102 | 103 | class ConcatenateSemanticDataset(Dataset): 104 | def __init__( 105 | self, manifest_path: str, symbol_table_path: str, 106 | n_samples: int = 0, max_duration=15): 107 | self.data = [] 108 | self.phonemizer = TextTokenizer() 109 | self.text_collater = get_text_semantic_token_collater( 110 | symbol_table_path) 111 | self.manifest_path = manifest_path 112 | self.n_samples = n_samples 113 | self.max_duration = max_duration 114 | if manifest_path is not None: 115 | self._build() 116 | 117 | def __len__(self): 118 | if self.n_samples: 119 | return min(self.n_samples, len(self.data)) 120 | return len(self.data) 121 | 122 | def remove_unknown_symbols(self, text: List[str]): 123 | res = [] 124 | for sym in text: 125 | if sym not in self.text_collater.token2idx: 126 | # print(f'{sym} is unk') 127 | continue 128 | res.append(sym) 129 | return res 130 | 131 | def __getitem__(self, idx): 132 | item = self.data[idx] 133 | 134 | input_ids = item["phoneme"].split("|") 135 | input_ids = self.remove_unknown_symbols(input_ids) 136 | 137 | input_ids_2 = None 138 | if item.get("phoneme_2"): 139 | input_ids_2 = item["phoneme_2"].split("|") 140 | input_ids_2 = [self.remove_unknown_symbols(input_ids_2)] 141 | 142 | input_ids = self.text_collater( 143 | [input_ids], input_ids_2).to(dtype=torch.long) 144 | input_ids = input_ids.to(dtype=torch.long) 145 | 146 | labels = np.load(item["semantic_path"]) 147 | labels = [str(lbl) for lbl in labels] 148 | 149 | labels_2 = None 150 | if item.get("semantic_path_2"): 151 | labels_2 = np.load(item["semantic_path_2"]) 152 | labels_2 = [[str(lbl) for lbl in labels_2]] 153 | 154 | labels = self.text_collater([labels], labels_2).to(dtype=torch.long) 155 | 156 | return {"input_ids": input_ids.squeeze(0), "labels": labels.squeeze(0)} 157 | 158 | # TODO - remove this to not load to the memory 159 | def _build(self): 160 | for manifest_path in self.manifest_path: 161 | dataset_path = Path(manifest_path).parent 162 | 163 | with open(manifest_path, "r") as manifest_file: 164 | manifest_data = json.load(manifest_file) 165 | 166 | for key, value in tqdm(manifest_data.items()): 167 | if float(value["duration"]) > self.max_duration: 168 | continue 169 | text = value["text"] 170 | phoneme = value["phoneme"] 171 | npy_path = f"{dataset_path}/audios-speech-tokenizer/semantic/{key.split('.wav')[0]}.npy" # noqa 172 | datapoint = { 173 | "text": text, 174 | "semantic_path": npy_path, 175 | "phoneme": phoneme 176 | } 177 | self.data.append(datapoint) 178 | 179 | print(f"Total length of the dataset {manifest_path}: {len(self.data)}") 180 | 181 | random.shuffle(self.data) 182 | 183 | 184 | if __name__ == "__main__": 185 | # Create an instance of the dataset 186 | manifest_path = "datasets/ljspeech-training-data/dev.json" 187 | text_tokens_file = "ckpt/unique_text_tokens.k2symbols" 188 | seq2seq_dataset = ConcatenateSemanticDataset( 189 | [manifest_path, manifest_path], text_tokens_file) 190 | 191 | # seq2seq_dataset.phonemize_and_rewrite_manifest() 192 | batch_size = 1 # Adjust to your desired batch size 193 | dataloader = DataLoader( 194 | seq2seq_dataset, 195 | batch_size=batch_size, 196 | shuffle=True, 197 | collate_fn=Collator().collate, 198 | ) 199 | 200 | for batch in dataloader: 201 | print(batch["input_ids"]) 202 | print(batch["labels"]) 203 | print(batch["input_ids"][0].unique().max()) 204 | 
print(batch["input_ids"][0].unique().min()) 205 | print(batch["input_ids"].shape) 206 | print(batch["labels"].shape) 207 | break # Stop after the first batch if needed 208 | -------------------------------------------------------------------------------- /data/single_speaker_dataset.py: -------------------------------------------------------------------------------- 1 | """Main loading function. 2 | 3 | Copyright PolyAI Limited. 4 | """ 5 | import json 6 | import os 7 | import random 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import soundfile as sf 12 | import torch 13 | from librosa.util import normalize 14 | from pyannote.audio import Inference 15 | from torch.utils import data 16 | 17 | import constants as c 18 | 19 | 20 | def random_crop(x, maxseqlen): 21 | if x.shape[0] >= maxseqlen: 22 | offset = random.randrange(x.shape[0] - maxseqlen + 1) 23 | x = x[offset: offset + maxseqlen] 24 | else: 25 | offset = 0 26 | return x, offset 27 | 28 | 29 | def dynamic_range_compression(x, C=0.3, M=6.5, clip_val=1e-5): 30 | return (np.log(np.clip(x, a_min=clip_val, a_max=None)) + M) * C 31 | 32 | 33 | def dynamic_range_decompression(x, C=0.3, M=6.5): 34 | return np.exp(x / C - M) 35 | 36 | 37 | class QuantizeDataset(data.Dataset): 38 | def __init__(self, hp, metapath, datadir=None, speaker_embedding_dir=None): 39 | self.hp = hp 40 | self.datadir = Path(datadir) 41 | self.speaker_embedding_dir = speaker_embedding_dir 42 | self.sem_mask_id = hp.n_semantic_codes 43 | 44 | print(f"Loading metadata in {metapath}...") 45 | with open(metapath, "r") as f: 46 | self.text = json.load(f) 47 | if 0 < self.hp.max_dataset_samples < len(self.text): 48 | self.new_text = {} 49 | num = 0 50 | for k, v in self.text.items(): 51 | if num >= self.hp.max_dataset_samples: 52 | break 53 | self.new_text[k] = v 54 | num += 1 55 | self.text = self.new_text 56 | 57 | self.datasetbase = [x for x in self.text.keys()] 58 | self.dataset = [ 59 | os.path.join(self.datadir, x) for x in self.datasetbase] 60 | 61 | if self.speaker_embedding_dir is None: 62 | self.spkr_embedding = Inference( 63 | "pyannote/embedding", 64 | window="whole", 65 | use_auth_token=os.environ["HUGGING_FACE_HUB_TOKEN"], 66 | ) 67 | 68 | # Print statistics: 69 | n = len(self.dataset) 70 | print(f"Total {n} examples") 71 | 72 | self.lengths = [float(v["duration"]) for v in self.text.values()] 73 | total_duration = sum(self.lengths) 74 | avglen = total_duration / len(self.lengths) 75 | maxlen = max(self.lengths) 76 | minlen = min(self.lengths) 77 | print( 78 | f"Average duration of audio: {avglen} sec, " 79 | "Maximum duration: {maxlen} sec, Minimum duration: {minlen} sec" 80 | ) 81 | 82 | def __len__(self): 83 | return len(self.dataset) 84 | 85 | def load_quantization(self, _name): 86 | if self.hp.vocoder_type == 'NATIVE': 87 | metadata = self.text[_name] 88 | quantization = np.array(metadata["quantization"]).T # ..., 4 89 | elif self.hp.vocoder_type == 'DAC': 90 | codes_path = self.datadir.parent / 'audios-dac' / (os.path.splitext(_name)[0] + ".npy") # noqa 91 | quantization = np.load(codes_path).T # ..., 12 92 | elif self.hp.vocoder_type == 'ENCODEC': 93 | codes_path = self.datadir.parent / 'audios-encodec' / (os.path.splitext(_name)[0] + ".npy") # noqa 94 | quantization = np.load(codes_path).squeeze(0).T # ..., 8 95 | elif self.hp.vocoder_type == 'SPEECHTOKENIZER': 96 | codes_path = self.datadir.parent / 'audios-speech-tokenizer/acoustic' / (os.path.splitext(_name)[0] + ".npy") # noqa 97 | quantization = np.load(codes_path).T # ..., 7 98 | 
else: 99 | raise ValueError(f"Unknown vocoder_type {self.hp.vocoder_type}") 100 | 101 | return quantization 102 | 103 | def __getitem__(self, i): 104 | dataname = self.dataset[i] 105 | _name = self.datasetbase[i] 106 | metadata = self.text[_name] 107 | 108 | # Speaker 1 109 | acoustic_tokens = self.load_quantization(_name) 110 | acoustic_tokens = np.pad( 111 | acoustic_tokens, [[1, 0],[0,0]], constant_values=c.SPKR_1) 112 | 113 | npy_path = self.datadir.parent / 'audios-speech-tokenizer/semantic' / (os.path.splitext(_name)[0] + ".npy") # noqa 114 | semantic_tokens = np.load(npy_path)[None] 115 | semantic_tokens = np.pad( 116 | semantic_tokens,[[0,0], [1, 0]], constant_values=c.SPKR_1) 117 | 118 | if "name_2" in metadata: 119 | wav, _ = sf.read(dataname.split(".")[0] + "_1.wav") 120 | else: 121 | wav, _ = sf.read(dataname) 122 | audio = normalize(wav) * 0.95 123 | speaker_embedding = self.spkr_embedding( 124 | {"waveform": torch.FloatTensor(audio).unsqueeze(0), 125 | "sample_rate": self.hp.sample_rate,} 126 | ).reshape(1, -1) 127 | speaker_embedding = np.repeat( 128 | speaker_embedding, semantic_tokens.shape[1], axis=0) 129 | 130 | # Speaker 2 131 | if "text_2" in metadata: 132 | _name = _name.split(".wav")[0] + "_2.wav" 133 | acoustic_tokens_2 = self.load_quantization(_name) 134 | acoustic_tokens_2 = np.pad( 135 | acoustic_tokens_2, [[1, 0],[0,0]], constant_values=c.SPKR_2) 136 | 137 | npy_path = self.datadir.parent / 'audios-speech-tokenizer/semantic' / (os.path.splitext(_name)[0] + ".npy") # noqa 138 | semantic_tokens_2 = np.load(npy_path)[None] 139 | semantic_tokens_2 = np.pad( 140 | semantic_tokens_2,[[0,0], [1, 0]], constant_values=c.SPKR_2) 141 | 142 | wav, _ = sf.read(dataname.split(".wav")[0] + "_2.wav") 143 | audio = normalize(wav) * 0.95 144 | speaker_embedding_2 = self.spkr_embedding( 145 | {"waveform": torch.FloatTensor(audio).unsqueeze(0), 146 | "sample_rate": self.hp.sample_rate,} 147 | ).reshape(1, -1) 148 | speaker_embedding_2 = np.repeat( 149 | speaker_embedding_2, semantic_tokens_2.shape[1], axis=0) 150 | 151 | # Merge both speakers 152 | acoustic_tokens = np.concatenate( 153 | (acoustic_tokens, acoustic_tokens_2), axis=0) 154 | semantic_tokens = np.concatenate( 155 | (semantic_tokens, semantic_tokens_2), axis=1) 156 | speaker_embedding = np.concatenate( 157 | (speaker_embedding, speaker_embedding_2), axis=0) 158 | 159 | speaker_embedding = speaker_embedding[:self.hp.max_length, :] 160 | acoustic_tokens = acoustic_tokens[:self.hp.max_length, :] 161 | semantic_tokens = semantic_tokens[:, :self.hp.max_length] 162 | 163 | # # HACK - we have no 8 lvls pfb30 164 | # acoustic_tokens = np.concatenate((semantic_tokens.T, acoustic_tokens), axis=1) 165 | # # END HACK 166 | 167 | return speaker_embedding, acoustic_tokens, acoustic_tokens, dataname, semantic_tokens # noqa 168 | -------------------------------------------------------------------------------- /datasets/example/train.json: -------------------------------------------------------------------------------- 1 | { 2 | "LJ001-0051.wav": { 3 | "text": "and paying great attention to the \"press work\" or actual process of printing,", 4 | "raw-text": "and paying great attention to the \"press work\" or actual process of printing,", 5 | "duration": 4.860090702947846, 6 | "phoneme": "æ|n|d|_|p|eɪ|ɪ|ŋ|_|ɡ|ɹ|eɪ|t|_|ɐ|t|ɛ|n|ʃ|ə|n|_|t|ə|_|ð|ə|_|\"|p|ɹ|ɛ|s|_|w|ɜː|k|\"|_|ɔː|ɹ|_|æ|k|tʃ|uː|əl|_|p|ɹ|ɑː|s|ɛ|s|_|ʌ|v|_|p|ɹ|ɪ|n|t|ɪ|ŋ|," 7 | }, 8 | "LJ001-0120.wav": { 9 | "text": "In the old print each figure has its definite individuality, and 
one cannot be mistaken for the other;", 10 | "raw-text": "In the old print each figure has its definite individuality, and one cannot be mistaken for the other;", 11 | "duration": 6.973106575963719, 12 | "phoneme": "ɪ|n|ð|ɪ|_|oʊ|l|d|_|p|ɹ|ɪ|n|t|_|iː|tʃ|_|f|ɪ|ɡ|j|ɚ|_|h|ɐ|z|_|ɪ|t|s|_|d|ɛ|f|ɪ|n|ə|t|_|ɪ|n|d|ɪ|v|ɪ|d|uː|æ|l|ɪ|ɾ|i|,|_|æ|n|d|_|w|ʌ|n|_|k|æ|n|ɑː|t|_|b|iː|_|m|ɪ|s|t|eɪ|k|ə|n|_|f|ɚ|ð|ɪ|_|ʌ|ð|ɚ|;" 13 | } 14 | } -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/acoustic/POD0000004393_S0000029.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/acoustic/POD0000004393_S0000029.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/acoustic/POD0000007005_S0000568.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/acoustic/POD0000007005_S0000568.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/acoustic/POD0000009720_S0000244.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/acoustic/POD0000009720_S0000244.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/acoustic/POD0000014360_S0000082.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/acoustic/POD0000014360_S0000082.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/acoustic/POD0000015908_S0000037.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/acoustic/POD0000015908_S0000037.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/acoustic/POD1000000022_S0000028.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/acoustic/POD1000000022_S0000028.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/acoustic/male_voice.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/acoustic/male_voice.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/semantic/POD0000004393_S0000029.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/semantic/POD0000004393_S0000029.npy 
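The training manifest shown above (datasets/example/train.json) keys each wav file to its normalized text, raw text, duration in seconds, and a pipe-separated phoneme string, while folders such as demo/audios-speech-tokenizer hold the precomputed semantic and acoustic SpeechTokenizer outputs as .npy arrays that share the audio file's basename. The sketch below shows one way these two on-disk formats could be read for a single utterance; the function name, argument layout, and return structure are illustrative assumptions, not repository code.

```python
# Illustrative sketch (not repository code): read one utterance's manifest entry
# and its precomputed token arrays, assuming a layout like the demo/ folder.
import json
from pathlib import Path

import numpy as np


def load_utterance(manifest_path: Path, tokens_root: Path, key: str):
    """Return (metadata, phonemes, semantic_tokens, acoustic_tokens) for one wav key."""
    metadata = json.loads(manifest_path.read_text())[key]  # key e.g. "LJ001-0051.wav"
    phonemes = metadata["phoneme"].split("|")               # phoneme symbols; "_" marks word boundaries
    stem = Path(key).stem
    semantic_tokens = np.load(tokens_root / "semantic" / f"{stem}.npy")
    acoustic_tokens = np.load(tokens_root / "acoustic" / f"{stem}.npy")
    return metadata, phonemes, semantic_tokens, acoustic_tokens
```

The dataset classes under data/ perform the same lookups internally (adding speaker markers and padding), so a helper like this is mainly useful for quick inspection of prepared data.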
-------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/semantic/POD0000007005_S0000568.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/semantic/POD0000007005_S0000568.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/semantic/POD0000009720_S0000244.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/semantic/POD0000009720_S0000244.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/semantic/POD0000014360_S0000082.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/semantic/POD0000014360_S0000082.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/semantic/POD0000015908_S0000037.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/semantic/POD0000015908_S0000037.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/semantic/POD1000000022_S0000028.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/semantic/POD1000000022_S0000028.npy -------------------------------------------------------------------------------- /demo/audios-speech-tokenizer/semantic/male_voice.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios-speech-tokenizer/semantic/male_voice.npy -------------------------------------------------------------------------------- /demo/audios/POD0000004393_S0000029.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios/POD0000004393_S0000029.wav -------------------------------------------------------------------------------- /demo/audios/POD0000007005_S0000568.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios/POD0000007005_S0000568.wav -------------------------------------------------------------------------------- /demo/audios/POD0000009720_S0000244.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios/POD0000009720_S0000244.wav -------------------------------------------------------------------------------- /demo/audios/POD0000014360_S0000082.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios/POD0000014360_S0000082.wav -------------------------------------------------------------------------------- /demo/audios/POD0000015908_S0000037.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios/POD0000015908_S0000037.wav -------------------------------------------------------------------------------- /demo/audios/POD1000000022_S0000028.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios/POD1000000022_S0000028.wav -------------------------------------------------------------------------------- /demo/audios/male_voice.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/audios/male_voice.wav -------------------------------------------------------------------------------- /demo/male_voice.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/demo/male_voice.wav -------------------------------------------------------------------------------- /demo/manifest.json: -------------------------------------------------------------------------------- 1 | {"audio_filepath":"male_voice.wav","text":"Welcome to Casino Lakes Charles. I'm very happy to help you today. We have a broad range of goods for you!","speaker":0,"audio_prompt_filepath":"audios/male_voice.wav"} 2 | {"audio_filepath":"POD0000015908_S0000037.wav","text":"another important thing was that there was no long-term follow-up of the patients to see if they had really stayed cancer free.","speaker":0,"audio_prompt_filepath":"audios/POD0000015908_S0000037.wav"} 3 | {"audio_filepath":"POD0000009720_S0000244.wav","text":"and the whole thing is just so cozy that he wants to be part of it. he wants to be in their club.","speaker":0,"audio_prompt_filepath":"audios/POD0000009720_S0000244.wav"} 4 | {"audio_filepath":"POD0000014360_S0000082.wav","text":"and this is where a large amount of the profits come, such as elsevier making eight hundred and forty six million dollars profit last year.","speaker":0,"audio_prompt_filepath":"audios/POD0000014360_S0000082.wav"} 5 | {"audio_filepath":"POD0000004393_S0000029.wav","text": "just like with uber, when there is less demand, a sports franchise can also lower prices to try to drive up ticket sales.","speaker":0,"audio_prompt_filepath":"audios/POD0000004393_S0000029.wav"} 6 | {"audio_filepath":"POD0000007005_S0000568.wav","text":"but let's let's just cover it now. 
if you could make a plea or a suggestion to people involved with nonprofits out there and say,","speaker":0,"audio_prompt_filepath":"audios/POD0000007005_S0000568.wav"} 7 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | -------------------------------------------------------------------------------- /docs/_layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | {% seo %} 7 | 8 | 9 | 10 | 11 | 12 | 13 | {% include head-custom.html %} 14 | 15 | 16 | Skip to the content. 17 | 18 | 30 | 31 |
32 | {{ content }} 33 | 34 | 40 |
41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /docs/assets/css/style.scss: -------------------------------------------------------------------------------- 1 | --- 2 | --- 3 | 4 | @import "{{ site.theme }}"; 5 | 6 | .page-header { 7 | background-image: linear-gradient(170deg, #2876c4, #d9ee50); 8 | } 9 | .main-content { 10 | max-width: 90%; 11 | font-size: 0.8rem; 12 | } 13 | 14 | audio { 15 | width: 140px; 16 | } 17 | 18 | .footnotes { 19 | list-style: none; 20 | padding-left: 0; 21 | } 22 | 23 | .footnotes li { 24 | font-size: 0.8em; 25 | } 26 | 27 | .footnotes a { 28 | text-decoration: none; 29 | } 30 | -------------------------------------------------------------------------------- /docs/assets/img/polyai-logo.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/assets/img/polyai-logo.webp -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ## PHEME: Efficient and Conversational Speech Generation. 2 | 3 | - Abstract. In recent years, speech generation has seen remarkable progress, now achieving one-shot generation capability that is often virtually indistinguishable from real human voice. Integrating such advancements in speech generation with large language models might revolutionize a wide range of applications. However, certain applications, such as assistive conversational systems, require natural and conversational speech generation which also operates efficiently in real time. Current state-of-the-art models like VALL-E and SoundStorm, powered by hierarchical neural audio codecs, require large neural components and extensive training data to work well. In contrast, MQTTS aims to build more compact conversational TTS models while capitalizing on smaller-scale real-life conversational speech data. However, its autoregressive nature yields high inference latency and thus limits its real-time usage. In order to mitigate the current limitations of the state-of-the-art TTS models while capitalizing on their strengths, in this work we propose the *PHEME* model series that **1)** offers compact yet high-performing models, **2)** allows for parallel speech generation of **3)** natural conversational speech, and **4)** it can be trained efficiently on smaller-scale conversational data, cutting data demands by more than 10x but still matching the quality of the autoregressive TTS models. We also show that through simple teacher-student distillation we can meet significant improvements in voice quality for single-speaker setups on top of pretrained *PHEME* checkpoints, relying solely on synthetic speech generated by much larger teacher models. 4 | - [Code](https://github.com/PolyAI-LDN/pheme) 5 | - [Demo](https://huggingface.co/spaces/PolyAI/pheme) 6 | - [Paper](https://arxiv.org/pdf/2401.02839.pdf) 7 | 8 | 9 | ### GigaSpeech One-shot1 TTS Examples 10 |
11 | 12 | 13 | 1. One-shot - inference setup for voices unseen at training time, when prompts and speaker embeddings are provided as additional model inputs (a minimal sketch of this setup is given below). 14 | 15 | 16 |
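For context, footnote 1 means that at inference time the model is conditioned on a short prompt recording, its transcript, and a speaker embedding computed from that recording. The sketch below outlines how such a prompt could be assembled; it mirrors the audio normalization and embedding call used in the repository's dataset code, but the function itself, the `spkr_embedding` callable, and the 16 kHz sample rate are assumptions for illustration, not the project's actual inference API.

```python
# Illustrative sketch (not the project's inference API): assemble a one-shot
# prompt from a reference recording and a speaker-embedding model.
import soundfile as sf
import torch
from librosa.util import normalize  # assumed source of the peak-normalization helper


def prepare_prompt(prompt_wav: str, prompt_text: str, spkr_embedding, sample_rate: int = 16_000):
    wav, _ = sf.read(prompt_wav)        # load the prompt recording
    audio = normalize(wav) * 0.95       # same scaling as in the dataset code
    embedding = spkr_embedding(
        {"waveform": torch.FloatTensor(audio).unsqueeze(0),
         "sample_rate": sample_rate}
    ).reshape(1, -1)                    # one embedding vector for the prompt speaker
    return {"prompt_text": prompt_text,
            "prompt_audio": audio,
            "speaker_embedding": embedding}
```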
17 | 18 | 19 | | Prompt audio | Reference audio | PHEME (100M) | PHEME (300M) no speaker embeddings | PHEME (300M) | Prompt text | Reference text | 20 | | :----------------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | 21 | | | | | | | let's just say in her own words, once i sat down and watched it i never moved, i was enthralled by it. | and she told me the next time she went back she would take me with her. and i waited, of course, like i said, thirteen years. | 22 | | | | | | | in early twenty-twenty, blue apron put the word out that it was interested in possibly getting scooped up. maybe by a big grocery chain. or someone else with deep pockets who wanted to own a meal kit delivery business. | at the same time, garcia says, the company acted like it was in turnaround mode. it decided to streamline operations, including shutting down its fulfillment center in texas | 23 | | | | | | | aside from influencing basically everyone who matters he was one of the first if not, in fact the first artist to bring an electric guitar player with him on to the grand oleopry stage. | if you want to call it a honky tonk, and it happened after ernest tubb. it was influenced by ernest tubb. before i get to the story and episode, i'd like to address one other thing. | 24 | | | | | | | so it's ah i think there's a range of risks, but generally speaking ah there's going to be a study increase in the floor of the skill level as these ah a i technologies diffuse. | that is, there will be more and more ah capabilities available to people at the bottom of the scale, that is individuals as well as people with more access to computing power, ah money, and data at the higher end. | 25 | | | | | | | so after they put in their name, phone number, email address onto your landing page. where would you like to send them? would you like to send them to your facebook page your website? | book an appointment to a buyer on facebook messenger bot, a seller messenger bot. where would you like to send them? so for this example i'm just gonna say book an appointment. | 26 | 27 | 28 | 29 | ### Artificial Voice TTS Examples 30 | 31 | | Prompt audio | Reference audio | PHEME (300M) no training on artificial voice | PHEME (300M) | Prompt text | Reference text | 32 | | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | 33 | | | | | | Our garden terrace is a lovely spot for afternoon tea. | The city’s ghost walk is a spooky and fascinating evening adventure. | 34 | | | | | | If you need a quiet place to work, our library is just perfect. | Our hotel’s evening bonfires are a great place to socialize. | 35 | | | | | | There’s a delightful chocolate factory tour, great for families. | Our rooftop jazz nights feature some of the best local talent. | 36 | | | | | | The rooftop bar hosts a live DJ on Friday nights.
| Our in-house sommelier leads an exquisite wine and cheese pairing event. | 37 | | | | | | The comedy club in town is known for its hilarious acts. | The annual food fair showcases the best of local cuisine. | 38 | 39 | ### Inference speed with Triton-LLM (RTFs, lower is better) for short and long sentences 40 | 41 | | Model | *short* | *long* | GPU | 42 | | ------------------ | --------- | --------- |--------- | 43 | | MQTTS (100M) | 1.930 | 1.842 | A100 | 44 | | PHEME-SMALL (100M) | **0.133** | **0.133** | A100 | 45 | | PHEME-LARGE (300M) | 0.143 | 0.143 | A100 | 46 | 47 | -------------------------------------------------------------------------------- /docs/samples/empress/114.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/114.wav -------------------------------------------------------------------------------- /docs/samples/empress/148.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/148.wav -------------------------------------------------------------------------------- /docs/samples/empress/161.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/161.wav -------------------------------------------------------------------------------- /docs/samples/empress/189.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/189.wav -------------------------------------------------------------------------------- /docs/samples/empress/217.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/217.wav -------------------------------------------------------------------------------- /docs/samples/empress/226.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/226.wav -------------------------------------------------------------------------------- /docs/samples/empress/234.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/234.wav -------------------------------------------------------------------------------- /docs/samples/empress/242.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/242.wav -------------------------------------------------------------------------------- /docs/samples/empress/262.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/262.wav -------------------------------------------------------------------------------- /docs/samples/empress/269.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/269.wav -------------------------------------------------------------------------------- /docs/samples/empress/29.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/29.wav -------------------------------------------------------------------------------- /docs/samples/empress/46.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/empress/46.wav -------------------------------------------------------------------------------- /docs/samples/gigaspeech/POD1000000004_S0000246.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/gigaspeech/POD1000000004_S0000246.wav -------------------------------------------------------------------------------- /docs/samples/gigaspeech/POD1000000004_S0000247.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/gigaspeech/POD1000000004_S0000247.wav -------------------------------------------------------------------------------- /docs/samples/gigaspeech/POD1000000018_S0000253.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/gigaspeech/POD1000000018_S0000253.wav -------------------------------------------------------------------------------- /docs/samples/gigaspeech/POD1000000018_S0000254.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/gigaspeech/POD1000000018_S0000254.wav -------------------------------------------------------------------------------- /docs/samples/gigaspeech/POD1000000048_S0000035.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/gigaspeech/POD1000000048_S0000035.wav -------------------------------------------------------------------------------- /docs/samples/gigaspeech/POD1000000048_S0000036.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/gigaspeech/POD1000000048_S0000036.wav -------------------------------------------------------------------------------- /docs/samples/gigaspeech/YOU1000000006_S0000051.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/gigaspeech/YOU1000000006_S0000051.wav -------------------------------------------------------------------------------- /docs/samples/gigaspeech/YOU1000000006_S0000052.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/gigaspeech/YOU1000000006_S0000052.wav -------------------------------------------------------------------------------- /docs/samples/gigaspeech/YOU1000000044_S0000798.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/gigaspeech/YOU1000000044_S0000798.wav -------------------------------------------------------------------------------- /docs/samples/gigaspeech/YOU1000000044_S0000799.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/gigaspeech/YOU1000000044_S0000799.wav -------------------------------------------------------------------------------- /docs/samples/pheme-100/019.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-100/019.wav -------------------------------------------------------------------------------- /docs/samples/pheme-100/042.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-100/042.wav -------------------------------------------------------------------------------- /docs/samples/pheme-100/080.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-100/080.wav -------------------------------------------------------------------------------- /docs/samples/pheme-100/188.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-100/188.wav -------------------------------------------------------------------------------- /docs/samples/pheme-100/209.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-100/209.wav -------------------------------------------------------------------------------- /docs/samples/pheme-300/019.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-300/019.wav -------------------------------------------------------------------------------- /docs/samples/pheme-300/042.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-300/042.wav -------------------------------------------------------------------------------- /docs/samples/pheme-300/080.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-300/080.wav -------------------------------------------------------------------------------- /docs/samples/pheme-300/188.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-300/188.wav -------------------------------------------------------------------------------- /docs/samples/pheme-300/209.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-300/209.wav -------------------------------------------------------------------------------- /docs/samples/pheme-empress-300/001.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-empress-300/001.wav -------------------------------------------------------------------------------- /docs/samples/pheme-empress-300/002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-empress-300/002.wav -------------------------------------------------------------------------------- /docs/samples/pheme-empress-300/190.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-empress-300/190.wav -------------------------------------------------------------------------------- /docs/samples/pheme-empress-300/227.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-empress-300/227.wav -------------------------------------------------------------------------------- /docs/samples/pheme-empress-300/235.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-empress-300/235.wav -------------------------------------------------------------------------------- /docs/samples/pheme-empress-300/243.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-empress-300/243.wav -------------------------------------------------------------------------------- /docs/samples/pheme-empress-300/270.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-empress-300/270.wav -------------------------------------------------------------------------------- /docs/samples/pheme-no-empress-300/190.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-no-empress-300/190.wav -------------------------------------------------------------------------------- /docs/samples/pheme-no-empress-300/227.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-no-empress-300/227.wav 
-------------------------------------------------------------------------------- /docs/samples/pheme-no-empress-300/235.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-no-empress-300/235.wav -------------------------------------------------------------------------------- /docs/samples/pheme-no-empress-300/243.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-no-empress-300/243.wav -------------------------------------------------------------------------------- /docs/samples/pheme-no-empress-300/270.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-no-empress-300/270.wav -------------------------------------------------------------------------------- /docs/samples/pheme-no-spkr-300/019.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-no-spkr-300/019.wav -------------------------------------------------------------------------------- /docs/samples/pheme-no-spkr-300/042.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-no-spkr-300/042.wav -------------------------------------------------------------------------------- /docs/samples/pheme-no-spkr-300/080.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-no-spkr-300/080.wav -------------------------------------------------------------------------------- /docs/samples/pheme-no-spkr-300/188.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-no-spkr-300/188.wav -------------------------------------------------------------------------------- /docs/samples/pheme-no-spkr-300/209.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/docs/samples/pheme-no-spkr-300/209.wav -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolyAI-LDN/pheme/3a520340b5094db0aa7275f28942dad2fc2cf7a3/modules/__init__.py -------------------------------------------------------------------------------- /modules/conformer.py: -------------------------------------------------------------------------------- 1 | """Conformer definition adjusted given the Lucidrain's repo. 2 | https://github.com/lucidrains/soundstorm-pytorch/blob/main/soundstorm_pytorch/soundstorm.py # noqa 3 | 4 | Copyright PolyAI Limited. 
5 | """ 6 | from collections import namedtuple 7 | from functools import wraps 8 | from typing import Dict, Union 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from einops import rearrange, reduce 13 | from einops.layers.torch import EinMix, Rearrange 14 | from torch import einsum, nn 15 | 16 | 17 | # rotary embedding 18 | class RotaryEmbedding(nn.Module): 19 | def __init__(self, dim, theta = 10000): 20 | super().__init__() 21 | inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) 22 | self.register_buffer("inv_freq", inv_freq, persistent = False) 23 | 24 | @property 25 | def device(self): 26 | return next(self.buffers()).device 27 | 28 | def forward(self, seq_len): 29 | t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq) 30 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq) 31 | freqs = torch.cat((freqs, freqs), dim = -1) 32 | return freqs 33 | 34 | def rotate_half(x): 35 | x1, x2 = x.chunk(2, dim=-1) 36 | return torch.cat((-x2, x1), dim=-1) 37 | 38 | def apply_rotary_pos_emb(pos, t): 39 | return (t * pos.cos()) + (rotate_half(t) * pos.sin()) 40 | 41 | 42 | # constants 43 | EfficientAttentionConfig = namedtuple( 44 | 'EfficientAttentionConfig', 45 | ['enable_flash', 'enable_math', 'enable_mem_efficient'] 46 | ) 47 | 48 | # helpers 49 | def exists(val): 50 | return val is not None 51 | 52 | def default(val, d): 53 | return val if exists(val) else d 54 | 55 | def divisible_by(numer, denom): 56 | return (numer % denom) == 0 57 | 58 | def calc_same_padding(kernel_size): 59 | pad = kernel_size // 2 60 | return (pad, pad - (kernel_size + 1) % 2) 61 | 62 | def eval_decorator(fn): 63 | @wraps(fn) 64 | def inner(model, *args, **kwargs): 65 | was_training = model.training 66 | model.eval() 67 | out = fn(model, *args, **kwargs) 68 | model.train(was_training) 69 | return out 70 | return inner 71 | 72 | 73 | def once(fn): 74 | called = False 75 | @wraps(fn) 76 | def inner(x): 77 | nonlocal called 78 | if called: 79 | return 80 | called = True 81 | return fn(x) 82 | return inner 83 | 84 | print_once = once(print) 85 | 86 | 87 | # t5 relative positional bias 88 | class T5RelativePositionBias(nn.Module): 89 | def __init__( 90 | self, 91 | scale = 1., 92 | num_buckets = 32, 93 | max_distance = 128, 94 | heads = 8 95 | ): 96 | super().__init__() 97 | self.scale = scale 98 | self.num_buckets = num_buckets 99 | self.max_distance = max_distance 100 | self.relative_attention_bias = nn.Embedding(num_buckets, heads) 101 | 102 | @staticmethod 103 | def _relative_position_bucket( 104 | relative_position, 105 | num_buckets = 32, 106 | max_distance = 128 107 | ): 108 | ret = 0 109 | n = -relative_position 110 | 111 | num_buckets //= 2 112 | ret += (n < 0).long() * num_buckets 113 | n = torch.abs(n) 114 | 115 | max_exact = num_buckets // 2 116 | is_small = n < max_exact 117 | 118 | val_if_large = max_exact + ( 119 | torch.log(n.float() / max_exact) / math.log( 120 | max_distance / max_exact) * (num_buckets - max_exact) 121 | ).long() 122 | 123 | val_if_large = torch.min( 124 | val_if_large, 125 | torch.full_like(val_if_large, num_buckets - 1) 126 | ) 127 | 128 | ret += torch.where(is_small, n, val_if_large) 129 | return ret 130 | 131 | @property 132 | def device(self): 133 | return next(self.parameters()).device 134 | 135 | def forward(self, n): 136 | pos = torch.arange(n, device = self.device).long() 137 | rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1') 138 | 139 | rp_bucket = self._relative_position_bucket( 140 | rel_pos, num_buckets = 
self.num_buckets, 141 | max_distance = self.max_distance) 142 | values = self.relative_attention_bias(rp_bucket) 143 | 144 | bias = rearrange(values, 'i j h -> h i j') 145 | return bias * self.scale 146 | 147 | 148 | # main class 149 | class Attend(nn.Module): 150 | def __init__( 151 | self, 152 | causal = False, 153 | dropout = 0., 154 | flash = False 155 | ): 156 | super().__init__() 157 | self.dropout = dropout 158 | self.attn_dropout = nn.Dropout(dropout) 159 | 160 | self.causal = causal 161 | self.flash = flash 162 | 163 | # determine efficient attention configs for cuda and cpu 164 | self.cpu_config = EfficientAttentionConfig(True, True, True) 165 | self.cuda_config = None 166 | 167 | if not torch.cuda.is_available() or not flash: 168 | return 169 | 170 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 171 | 172 | if device_properties.major == 8 and device_properties.minor == 0: 173 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda') # noqa 174 | self.cuda_config = EfficientAttentionConfig(True, True, True) 175 | else: 176 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') # noqa 177 | self.cuda_config = EfficientAttentionConfig(False, True, True) 178 | 179 | def get_mask(self, i, j, device): 180 | return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) # noqa 181 | 182 | def flash_attn(self, q, k, v, mask = None, attn_bias = None): 183 | _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device # noqa 184 | 185 | # single headed key / values 186 | 187 | if k.ndim == 3: 188 | k = rearrange(k, 'b n d -> b 1 n d') 189 | 190 | if v.ndim == 3: 191 | v = rearrange(v, 'b n d -> b 1 n d') 192 | 193 | # Check if mask exists and expand to compatible shape 194 | # The mask is B L, so it would have to be expanded to B H N L 195 | if exists(mask) and mask.ndim != 4: 196 | mask = rearrange(mask, 'b j -> b 1 1 j') 197 | mask = mask.expand(-1, heads, q_len, -1) 198 | 199 | # Check if there is a compatible device for flash attention 200 | config = self.cuda_config if is_cuda else self.cpu_config 201 | causal = self.causal 202 | 203 | # handle attention bias 204 | if exists(attn_bias): 205 | mask_value = -torch.finfo(q.dtype).max // 2 206 | causal_mask = self.get_mask(q_len, k_len, device) 207 | attn_bias = attn_bias.masked_fill(causal_mask, mask_value) 208 | 209 | if exists(mask): 210 | attn_bias = attn_bias.masked_fill(~mask, mask_value) 211 | 212 | mask = attn_bias 213 | causal = False 214 | 215 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale 216 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 217 | out = F.scaled_dot_product_attention( 218 | q, k, v, 219 | attn_mask = mask, 220 | dropout_p = self.dropout if self.training else 0., 221 | is_causal = causal 222 | ) 223 | 224 | return out 225 | 226 | def forward(self, q, k, v, mask = None, attn_bias = None): 227 | """ 228 | einstein notation 229 | b - batch 230 | h - heads 231 | n, i, j - sequence length (base sequence length, source, target) 232 | d - feature dimension 233 | """ 234 | 235 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device 236 | 237 | scale = q.shape[-1] ** -0.5 238 | 239 | kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' 240 | 241 | if self.flash: 242 | assert not exists(attn_bias) 243 | return self.flash_attn(q, k, v, mask = mask) 244 | 245 | # similarity 246 | 247 | sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) 
* scale 248 | 249 | # attention bias 250 | 251 | if exists(attn_bias): 252 | sim = sim + attn_bias 253 | 254 | # causal mask 255 | if self.causal: 256 | causal_mask = self.get_mask(q_len, k_len, device) 257 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 258 | 259 | # key padding mask 260 | if exists(mask): 261 | if mask.ndim != 4: 262 | mask = rearrange(mask, 'b j -> b 1 1 j') 263 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 264 | 265 | # attention 266 | attn = sim.softmax(dim=-1) 267 | attn = self.attn_dropout(attn) 268 | 269 | # aggregate values 270 | out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) 271 | 272 | return out 273 | 274 | 275 | class Swish(nn.Module): 276 | def forward(self, x): 277 | return x * x.sigmoid() 278 | 279 | 280 | class GLU(nn.Module): 281 | def __init__(self, dim): 282 | super().__init__() 283 | self.dim = dim 284 | 285 | def forward(self, x): 286 | out, gate = x.chunk(2, dim=self.dim) 287 | return out * gate.sigmoid() 288 | 289 | 290 | class DepthWiseConv1d(nn.Module): 291 | def __init__(self, chan_in, chan_out, kernel_size, padding): 292 | super().__init__() 293 | self.padding = padding 294 | self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in) 295 | 296 | def forward(self, x): 297 | x = F.pad(x, self.padding) 298 | return self.conv(x) 299 | 300 | 301 | class Scale(nn.Module): 302 | def __init__(self, scale, fn): 303 | super().__init__() 304 | self.fn = fn 305 | self.scale = scale 306 | 307 | def forward(self, x, **kwargs): 308 | return self.fn(x, **kwargs) * self.scale 309 | 310 | 311 | class ChanLayerNorm(nn.Module): 312 | def __init__(self, dim): 313 | super().__init__() 314 | self.gamma = nn.Parameter(torch.ones(1, dim, 1)) 315 | 316 | def forward(self, x): 317 | eps = 1e-6 if x.dtype == torch.float32 else 1e-4 318 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 319 | mean = torch.mean(x, dim = 1, keepdim = True) 320 | return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma 321 | 322 | 323 | class PreNorm(nn.Module): 324 | def __init__(self, dim, fn): 325 | super().__init__() 326 | self.fn = fn 327 | self.norm = nn.LayerNorm(dim) 328 | 329 | def forward(self, x, **kwargs): 330 | x = self.norm(x) 331 | return self.fn(x, **kwargs) 332 | 333 | 334 | class Attention(nn.Module): 335 | def __init__( 336 | self, 337 | dim, 338 | heads = 8, 339 | dim_head = 64, 340 | dropout = 0., 341 | flash = True 342 | ): 343 | super().__init__() 344 | inner_dim = dim_head * heads 345 | self.heads= heads 346 | self.scale = dim_head ** -0.5 347 | 348 | self.attend = Attend( 349 | flash = flash, 350 | dropout = dropout 351 | ) 352 | 353 | self.dropout = nn.Dropout(dropout) 354 | 355 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 356 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 357 | self.to_out = nn.Linear(inner_dim, dim) 358 | 359 | def forward( 360 | self, 361 | x, 362 | context = None, 363 | mask = None, 364 | rotary_emb = None, 365 | attn_bias = None 366 | ): 367 | n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context) 368 | context = default(context, x) 369 | 370 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) 371 | q, k, v = map( 372 | lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 373 | 374 | if exists(rotary_emb): 375 | q = apply_rotary_pos_emb(rotary_emb, q) 376 | k = apply_rotary_pos_emb(rotary_emb, k) 377 | 378 | out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias) 379 | 380 | out = 
rearrange(out, 'b h n d -> b n (h d)') 381 | return self.to_out(out) 382 | 383 | 384 | class FeedForward(nn.Module): 385 | def __init__( 386 | self, 387 | dim, 388 | mult = 4, 389 | dropout = 0. 390 | ): 391 | super().__init__() 392 | self.net = nn.Sequential( 393 | nn.Linear(dim, dim * mult), 394 | Swish(), 395 | nn.Dropout(dropout), 396 | nn.Linear(dim * mult, dim), 397 | nn.Dropout(dropout) 398 | ) 399 | 400 | def forward(self, x): 401 | return self.net(x) 402 | 403 | 404 | class ConformerConvModule(nn.Module): 405 | def __init__( 406 | self, 407 | dim, 408 | causal = False, 409 | expansion_factor = 2, 410 | kernel_size = 31, 411 | dropout = 0. 412 | ): 413 | super().__init__() 414 | 415 | inner_dim = dim * expansion_factor 416 | padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0) 417 | 418 | self.net = nn.Sequential( 419 | nn.LayerNorm(dim), 420 | Rearrange('b n c -> b c n'), 421 | nn.Conv1d(dim, inner_dim * 2, 1), 422 | GLU(dim=1), 423 | DepthWiseConv1d( 424 | inner_dim, inner_dim, kernel_size = kernel_size, 425 | padding = padding 426 | ), 427 | Swish(), 428 | ChanLayerNorm(inner_dim), 429 | nn.Conv1d(inner_dim, dim, 1), 430 | Rearrange('b c n -> b n c'), 431 | nn.Dropout(dropout) 432 | ) 433 | 434 | def forward(self, x): 435 | return self.net(x) 436 | 437 | 438 | # Conformer Block 439 | class ConformerBlock(nn.Module): 440 | def __init__( 441 | self, 442 | *, 443 | dim, 444 | dim_head = 64, 445 | heads = 8, 446 | ff_mult = 4, 447 | conv_expansion_factor = 2, 448 | conv_kernel_size = 31, 449 | attn_dropout = 0., 450 | attn_flash = True, 451 | ff_dropout = 0., 452 | conv_dropout = 0., 453 | conv_causal = False 454 | ): 455 | super().__init__() 456 | self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) 457 | self.attn = Attention( 458 | dim = dim, dim_head = dim_head, heads = heads, 459 | dropout = attn_dropout, flash = attn_flash 460 | ) 461 | self.conv = ConformerConvModule( 462 | dim = dim, causal = conv_causal, 463 | expansion_factor = conv_expansion_factor, 464 | kernel_size = conv_kernel_size, dropout = conv_dropout 465 | ) 466 | self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) 467 | 468 | self.attn = PreNorm(dim, self.attn) 469 | self.ff1 = Scale(0.5, PreNorm(dim, self.ff1)) 470 | self.ff2 = Scale(0.5, PreNorm(dim, self.ff2)) 471 | 472 | self.post_norm = nn.LayerNorm(dim) 473 | 474 | def forward( 475 | self, 476 | x, 477 | mask = None, 478 | rotary_emb = None, 479 | attn_bias = None 480 | ): 481 | x = self.ff1(x) + x 482 | x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x # noqa 483 | x = self.conv(x) + x 484 | x = self.ff2(x) + x 485 | x = self.post_norm(x) 486 | return x 487 | 488 | 489 | # Conformer 490 | class Conformer(nn.Module): 491 | def __init__( 492 | self, 493 | dim, 494 | *, 495 | num_layers, 496 | dim_head = 64, 497 | heads = 8, 498 | ff_mult = 4, 499 | conv_expansion_factor = 2, 500 | conv_kernel_size = 31, 501 | attn_dropout = 0., 502 | ff_dropout = 0., 503 | conv_dropout = 0., 504 | conv_causal = False, 505 | attn_flash = True, 506 | t5_rel_pos_bias = False 507 | ): 508 | super().__init__() 509 | 510 | assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias' # noqa 511 | 512 | self.dim = dim 513 | self.layers = nn.ModuleList([]) 514 | 515 | self.rotary_emb = RotaryEmbedding( 516 | dim_head) if not t5_rel_pos_bias else None 517 | self.rel_pos_bias = T5RelativePositionBias( 518 | dim_head ** 0.5, heads = heads) if 
t5_rel_pos_bias else None 519 | 520 | for _ in range(num_layers): 521 | self.layers.append(ConformerBlock( 522 | dim = dim, 523 | dim_head = dim_head, 524 | heads = heads, 525 | ff_mult = ff_mult, 526 | conv_expansion_factor = conv_expansion_factor, 527 | conv_kernel_size = conv_kernel_size, 528 | attn_dropout = attn_dropout, 529 | ff_dropout = ff_dropout, 530 | conv_dropout = conv_dropout, 531 | conv_causal = conv_causal, 532 | attn_flash = attn_flash 533 | )) 534 | 535 | def forward(self, x, mask = None): 536 | seq_len = x.shape[-2] 537 | 538 | rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None # noqa 539 | attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None #noqa 540 | 541 | for block in self.layers: 542 | x = block( 543 | x, 544 | mask = mask, 545 | rotary_emb = rotary_emb, 546 | attn_bias = attn_bias 547 | ) 548 | return x 549 | 550 | 551 | # conformer with sum reduction across quantized tokens at the beginning, 552 | # along with heads 553 | class ConformerWrapper(nn.Module): 554 | def __init__( 555 | self, 556 | *, 557 | codebook_size, 558 | num_quantizers, 559 | conformer: Union[Conformer, Dict[str, any]], 560 | grouped_quantizers = 1 561 | ): 562 | super().__init__() 563 | self.conformer = conformer 564 | 565 | if isinstance(conformer, dict): 566 | self.conformer = Conformer(**self.conformer) 567 | 568 | dim = self.conformer.dim 569 | 570 | self.embedding_proj = nn.Sequential( 571 | nn.Linear(dim * grouped_quantizers, dim), 572 | nn.LayerNorm(dim) 573 | ) if grouped_quantizers > 1 else nn.Identity() 574 | 575 | num_codes_with_mask = codebook_size + 1 576 | num_effective_quantizers = num_quantizers * grouped_quantizers 577 | 578 | self.code_embeds = nn.Embedding( 579 | num_codes_with_mask * num_effective_quantizers, dim) 580 | 581 | self.register_buffer( 582 | 'quantizer_offsets', 583 | torch.arange(num_effective_quantizers) * num_codes_with_mask, 584 | persistent = False 585 | ) 586 | self.register_buffer( 587 | 'mask_tokens', self.quantizer_offsets + num_codes_with_mask, 588 | persistent = False 589 | ) 590 | 591 | self.dim = dim 592 | self.codebook_size = codebook_size 593 | 594 | self.num_codes_with_mask = num_codes_with_mask 595 | self.num_quantizers = num_quantizers 596 | self.grouped_quantizers = grouped_quantizers 597 | 598 | self.heads = nn.Sequential( 599 | nn.Linear(dim, dim * num_effective_quantizers), 600 | Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers) 601 | ) 602 | 603 | # each quantizer codebook would require its own logits weight 604 | # and bias matrices 605 | # the amazing einops makes this easy with 'EinMix' 606 | self.to_logits = nn.Sequential( 607 | nn.LayerNorm(dim), 608 | Rearrange('b (n gq) d -> b n gq d', gq = num_effective_quantizers), 609 | EinMix( 610 | 'b n gq d -> b n gq l', 611 | weight_shape = 'gq d l', 612 | bias_shape = 'gq l', 613 | gq = num_effective_quantizers, 614 | l = codebook_size, 615 | d = dim 616 | ), 617 | Rearrange('b ... d -> b (...) 
d') 618 | ) 619 | 620 | def forward( 621 | self, 622 | x, 623 | *, 624 | mask = None, 625 | cond = None, 626 | sum_embeds = None, 627 | return_embeddings = False, 628 | return_logits_and_embeddings = False 629 | ): 630 | """ 631 | einops notation: 632 | b - batch 633 | n - sequence 634 | g - groups 635 | q - quantizers 636 | d - feature dimension 637 | """ 638 | 639 | n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers 640 | assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers' # noqa 641 | 642 | x = rearrange(x, 'b (n gq) -> b n gq', gq = g * q) 643 | x = x + self.quantizer_offsets 644 | 645 | x = self.code_embeds(x) 646 | 647 | x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g = g) 648 | 649 | x = self.embedding_proj(x) 650 | 651 | if exists(sum_embeds): 652 | x = x + sum_embeds 653 | 654 | if exists(cond): 655 | if cond.ndim == 2: 656 | cond = rearrange(cond, 'b d -> b 1 d') 657 | 658 | x = x + cond 659 | 660 | x = self.conformer(x, mask = mask) 661 | embeds = self.heads(x) 662 | 663 | if return_embeddings or not exists(self.to_logits): 664 | return embeds 665 | 666 | logits = self.to_logits(embeds) 667 | 668 | if return_logits_and_embeddings: 669 | return logits, embeds 670 | 671 | return logits 672 | -------------------------------------------------------------------------------- /modules/masking_logic.py: -------------------------------------------------------------------------------- 1 | """Masking and sampling logic adapted from MaskGIT original paper: 2 | https://github.com/google-research/maskgit 3 | 4 | Copyright PolyAI Limited. 5 | """ 6 | from dataclasses import dataclass 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | @dataclass 14 | class State: 15 | """Holds decoding state data.""" 16 | # The position of the decoding loop in the length dimension. 17 | cur_index: None 18 | # The active sequence log probabilities and finished sequence scores. 19 | cur_seqs: None 20 | final_seqs: None 21 | 22 | 23 | def state_init(init_indices, num_iter, start_iter=0): 24 | """Initializes the decoding state data structure.""" 25 | cur_index_0 = start_iter 26 | cur_seqs_0 = init_indices 27 | final_seqs_0 = torch.unsqueeze(init_indices, 1) 28 | final_seqs_0 = torch.tile(final_seqs_0, (1, num_iter, 1)) 29 | return State( 30 | cur_index=cur_index_0, cur_seqs=cur_seqs_0, final_seqs=final_seqs_0) 31 | 32 | 33 | def schedule(ratio, method="cosine"): 34 | if method == "uniform": 35 | mask_ratio = 1. - ratio 36 | elif "pow" in method: 37 | exponent = float(method.replace("pow", "")) 38 | mask_ratio = 1. - ratio**exponent 39 | elif method == "cosine": 40 | mask_ratio = np.cos(ratio * (np.pi/2)) 41 | 42 | mask_ratio = np.clip(mask_ratio, 1e-6, 1.) 43 | return mask_ratio 44 | 45 | 46 | def mask_by_random_topk(mask_len, probs, temperature=1.0): 47 | noise = gumbel_noise_like(probs) 48 | confidence = torch.log(probs) + temperature * noise 49 | sorted_confidence, _ = torch.sort(confidence, dim=-1) 50 | # Obtains cut off threshold given the mask lengths. 51 | cut_off = torch.take_along_dim(sorted_confidence, mask_len.long(), dim=-1) 52 | # Masks tokens with lower confidence. 
53 | masking = (confidence < cut_off) 54 | return masking 55 | 56 | 57 | def gumbel_noise_like(t): 58 | noise = torch.zeros_like(t).uniform_(1e-20, 1) 59 | return -torch.log(-torch.log(noise)) 60 | 61 | 62 | def sample_from_logits( 63 | logits, 64 | sample: bool = True, 65 | temperature: float = 1.0, 66 | top_k: int = None, 67 | top_p: float = None, 68 | return_probs: bool = False 69 | ): 70 | shp = logits.shape[:-1] 71 | 72 | # Apply top_k sampling 73 | if top_k is not None: 74 | v, _ = logits.topk(top_k) 75 | logits[logits < v[..., [-1]]] = -float("inf") 76 | 77 | # Apply top_p (nucleus) sampling 78 | if top_p is not None and top_p < 1.0: 79 | v, sorted_indices = logits.sort(descending=True) 80 | cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1) 81 | 82 | sorted_indices_to_remove = cumulative_probs > top_p 83 | # Right shift indices_to_remove to keep 1st token over threshold 84 | sorted_indices_to_remove = F.pad( 85 | sorted_indices_to_remove, (1, 0), value=False)[..., :-1] 86 | 87 | # Compute indices_to_remove in unsorted array 88 | indices_to_remove = sorted_indices_to_remove.scatter( 89 | -1, sorted_indices, sorted_indices_to_remove 90 | ) 91 | 92 | logits[indices_to_remove] = -float("inf") 93 | 94 | # Perform multinomial sampling after normalizing logits 95 | probs = ( 96 | F.softmax(logits / temperature, dim=-1) 97 | if temperature > 0 98 | else logits.softmax(dim=-1) 99 | ) 100 | token = ( 101 | probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp) 102 | if sample 103 | else logits.argmax(-1) 104 | ) 105 | 106 | if return_probs: 107 | token_probs = probs.take_along_dim( 108 | token.unsqueeze(-1), dim=-1).squeeze(-1) 109 | return token, token_probs 110 | else: 111 | return token 112 | -------------------------------------------------------------------------------- /modules/s2a_model.py: -------------------------------------------------------------------------------- 1 | """A2S model definition. 2 | 3 | Copyright PolyAI Limited. 
4 | """ 5 | from typing import Union 6 | 7 | import pytorch_lightning as pl 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from einops import rearrange 13 | 14 | import constants as c 15 | from modules import masking_logic 16 | from modules.conformer import Conformer 17 | from modules.masking_logic import (State, mask_by_random_topk, 18 | sample_from_logits, state_init) 19 | from utils import load_checkpoint 20 | 21 | 22 | class Pheme(pl.LightningModule): 23 | def __init__(self, hp): 24 | super().__init__() 25 | self.hp = hp 26 | self.model = TTSConformer(hp) 27 | self.cross_entropy = nn.CrossEntropyLoss( 28 | label_smoothing=self.hp.label_smoothing, 29 | ignore_index=self.hp.n_codes 30 | ) 31 | if self.hp.pretrained_path: 32 | self.load() 33 | else: 34 | self.apply(self.init_weights) 35 | 36 | if self.hp.only_inference: 37 | self.model.eval() 38 | 39 | self.save_hyperparameters() 40 | 41 | def load(self): 42 | state_dict = load_checkpoint(self.hp.pretrained_path) 43 | print(f"Parameters loaded from {self.hp.pretrained_path}") 44 | self.load_state_dict(state_dict, strict=True) 45 | 46 | def init_weights(self, module): 47 | if isinstance(module, nn.Linear): 48 | module.weight.data.normal_(mean=0.0, std=0.02) 49 | if module.bias is not None: 50 | module.bias.data.zero_() 51 | if isinstance(module, nn.Embedding): 52 | module.weight.data.normal_(mean=0.0, std=0.02) 53 | module._fill_padding_idx_with_zero() 54 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): 55 | module.bias.data.zero_() 56 | module.weight.data.fill_(1.0) 57 | elif isinstance(module, nn.Conv1d): 58 | module.weight.data.normal_(mean=0.0, std=0.02) 59 | if module.bias is not None: 60 | module.bias.data.zero_() 61 | 62 | def configure_optimizers(self): 63 | optimizer_adam = optim.AdamW( 64 | self.parameters(), lr=self.hp.lr, 65 | betas=(self.hp.adam_beta1, self.hp.adam_beta2)) 66 | 67 | # Learning rate scheduler 68 | num_training_steps = self.hp.training_step 69 | num_warmup_steps = self.hp.warmup_step 70 | num_flat_steps = int(self.hp.optim_flat_percent * num_training_steps) 71 | 72 | def lambda_lr(current_step: int): 73 | if current_step < num_warmup_steps: 74 | return float(current_step) / float(max(1, num_warmup_steps)) 75 | elif current_step < (num_warmup_steps + num_flat_steps): 76 | return 1.0 77 | return max( 78 | 0.0, 79 | float(num_training_steps - current_step) 80 | / float( 81 | max(1, num_training_steps - (num_warmup_steps + num_flat_steps)) # noqa 82 | ), 83 | ) 84 | 85 | scheduler_adam = { 86 | "scheduler": optim.lr_scheduler.LambdaLR( 87 | optimizer_adam, lambda_lr), 88 | "interval": "step", 89 | } 90 | return [optimizer_adam], [scheduler_adam] 91 | 92 | def top_k_accuracy(self, y_true, y_pred_probabilities, k): 93 | _, sorted_indices = torch.sort(y_pred_probabilities, descending=True) 94 | 95 | # Get the top-k predictions 96 | top_k_indices = sorted_indices[:, :k] 97 | expanded_y_true = y_true.unsqueeze(1).expand_as(top_k_indices) 98 | 99 | # Check if true labels exist in top-k predictions 100 | hits = torch.sum(torch.eq(top_k_indices, expanded_y_true)) 101 | accuracy = hits.item() / (len(y_true) + 1e-7) 102 | 103 | return accuracy 104 | 105 | def training_step(self, batch, batch_idx): 106 | # Sample training level 107 | rvq_level = torch.randint( 108 | 0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups),(1,)).item() 109 | 110 | target, chosen_tokens, _, _ = self.model( 111 | batch["tts_quantize_input"], rvq_level, 
batch["semantic_tokens"], 112 | batch["quantization_lengths"], 113 | speaker_emb=batch["speaker"], 114 | min_seq_length=batch["quantization_lengths"].min().item()) 115 | 116 | # Mask targets and labels 117 | mask = chosen_tokens 118 | target = target[mask] 119 | 120 | labels = batch["tts_quantize_input"][:, :, rvq_level] 121 | labels = labels[mask] 122 | 123 | loss = self.cross_entropy(target, labels) 124 | acc = (target.argmax(-1) == labels).float().mean() 125 | self.log("train/loss", loss, on_step=True, prog_bar=True) 126 | self.log("train/acc", acc, on_step=True, prog_bar=True) 127 | self.log( 128 | f"train/acc_lvl_{rvq_level}", acc, on_step=True, prog_bar=False) 129 | 130 | return loss 131 | 132 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 133 | speaker_emb = batch["speaker"] 134 | acoustic_tokens = batch["tts_quantize_input"] 135 | semantic_tokens = batch["semantic_tokens"] 136 | 137 | if self.hp.only_inference: 138 | self.inference( 139 | acoustic_tokens, semantic_tokens, self.hp.first_n_lvls) 140 | else: 141 | rvq_level = torch.randint( 142 | 0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups),(1,) 143 | ).item() 144 | 145 | # FIXME: edge case 146 | if len(semantic_tokens.shape) == 3: 147 | semantic_tokens = rearrange(semantic_tokens, "B 1 T -> B T") 148 | 149 | target, chosen_tokens, _, _ = self.model( 150 | acoustic_tokens, rvq_level, semantic_tokens, 151 | torch.tensor([acoustic_tokens.shape[1]]).to(self.device), 152 | speaker_emb=speaker_emb, 153 | min_seq_length=acoustic_tokens.shape[1] 154 | ) 155 | 156 | target = target[chosen_tokens] 157 | labels = acoustic_tokens[:, :, rvq_level][chosen_tokens] 158 | loss = self.cross_entropy(target, labels) 159 | 160 | acc = (target.argmax(-1) == labels).float().mean() 161 | acc_5 = self.top_k_accuracy(labels, target, 5) 162 | 163 | self.log( 164 | f"val/dataset_{dataloader_idx}/loss", 165 | loss, 166 | on_epoch=True, 167 | logger=True, 168 | add_dataloader_idx=False, 169 | ) 170 | self.log( 171 | f"val/dataset_{dataloader_idx}/acc_lvl", 172 | acc, 173 | on_epoch=True, 174 | logger=True, 175 | add_dataloader_idx=False, 176 | ) 177 | self.log( 178 | f"val/dataset_{dataloader_idx}/acc_lvl_{rvq_level}", 179 | acc, 180 | on_epoch=True, 181 | logger=True, 182 | add_dataloader_idx=False, 183 | ) 184 | self.log( 185 | f"val/dataset_{dataloader_idx}/acc_top_5", 186 | acc_5, 187 | on_epoch=True, 188 | logger=True, 189 | add_dataloader_idx=False, 190 | ) 191 | self.log( 192 | f"val/dataset_{dataloader_idx}/acc_top_5_lvl_{rvq_level}", 193 | acc_5, 194 | on_epoch=True, 195 | logger=True, 196 | add_dataloader_idx=False, 197 | ) 198 | 199 | def compute_stats(self, logits, labels, mask_ratio=0, rvq_level=0): 200 | acc = (logits.argmax(-1) == labels).float().mean() 201 | acc_5 = self.top_k_accuracy(labels, logits, 5) 202 | acc_10 = self.top_k_accuracy(labels, logits, 10) 203 | 204 | idx = torch.randperm(logits.shape[0]) 205 | logits_shuffled = logits[idx] 206 | random = self.top_k_accuracy(labels, logits_shuffled, 10) 207 | print(f"Mask ratio: {mask_ratio}, Level {rvq_level}: acc {acc}," 208 | f"acc 5 {acc_5}, acc 10 {acc_10}, quasi random {random}") 209 | 210 | 211 | class TTSConformer(pl.LightningModule): 212 | def __init__(self, hp): 213 | super().__init__() 214 | self.hp = hp 215 | self.padding_id = self.hp.n_codes 216 | 217 | additional_codes = [c.PAD, c.SPKR_1, c.SPKR_2] 218 | 219 | self.embedding = nn.ModuleList( 220 | [ 221 | nn.Embedding( 222 | self.hp.n_codes + len(additional_codes), 223 | self.hp.hidden_size, 224 | 
padding_idx=self.padding_id) 225 | for _ in range(self.hp.n_cluster_groups) 226 | ] 227 | ) 228 | 229 | # Additional modules 230 | self.semantic_embedding = nn.Embedding( 231 | self.hp.n_semantic_codes + len(additional_codes), 232 | self.hp.hidden_size, 233 | padding_idx=self.padding_id) 234 | 235 | if self.hp.use_spkr_emb: 236 | self.spkr_linear = nn.Linear(c.SPKR_EMB_SIZE, self.hp.hidden_size) 237 | 238 | self.conformer = Conformer( 239 | dim=self.hp.hidden_size, 240 | num_layers=self.hp.enc_nlayers, 241 | heads=self.hp.nheads, 242 | dim_head=64, 243 | ff_mult=4, # 512*4=2048 244 | conv_expansion_factor=2, 245 | conv_kernel_size=self.hp.depthwise_conv_kernel_size, 246 | attn_dropout=self.hp.dropout, 247 | ff_dropout=self.hp.dropout, 248 | conv_dropout=self.hp.dropout, 249 | attn_flash=True, 250 | t5_rel_pos_bias=False 251 | ) 252 | 253 | self.heads = nn.ModuleList( 254 | [ 255 | nn.Linear( 256 | self.hp.hidden_size, 257 | self.hp.n_codes + len(additional_codes) 258 | ) 259 | for _ in range(self.hp.n_cluster_groups) 260 | ] 261 | ) 262 | 263 | def build_mask_from_lengths(self, length, max_len=None): 264 | max_len = max_len or length.max().item() 265 | mask = torch.arange( 266 | max_len, device=length.device)[None, :] >= length[:, None] 267 | return mask.bool() 268 | 269 | @torch.no_grad() 270 | def create_mask( 271 | self, B, T, lengths, mask_ratio=None, start_t=None, 272 | min_seq_length=None 273 | ): 274 | # 1. Define the random length of condition tokens given the shortest 275 | # audio in the batch 276 | if start_t is None: 277 | start_t = torch.randint(1, min_seq_length - 1, (1,)).item() 278 | 279 | # 2. Mask other tokens - sample different masking levels per 280 | if mask_ratio is None: 281 | ratio = torch.rand(1).item() 282 | mask_ratio = masking_logic.schedule(ratio) 283 | 284 | # Create a random tensor with values between 0 and 1 285 | random_tensor = torch.rand( 286 | (B, T - start_t), dtype=torch.float).to(self.device) 287 | # Create a mask where values less than p are set to True 288 | initial_mask = random_tensor < mask_ratio 289 | length_mask = self.build_mask_from_lengths( 290 | lengths - start_t, T - start_t) 291 | # we can't pick up tokens past token lengths 292 | initial_mask = torch.logical_and(initial_mask, ~length_mask) 293 | 294 | # Constrain ratio to always include some samples 295 | # If all are False let's pick up at least one: 296 | if torch.sum(initial_mask) == 0: 297 | choose_steps = torch.randint(low=0, high=(T - start_t), size=(B,)) 298 | initial_mask[torch.arange(B), choose_steps] = torch.tensor( 299 | True, device=self.device) 300 | 301 | # 3. 
Add condition tokens containing information 302 | acoustic_token_mask = torch.cat( 303 | (torch.full((B, start_t), False, device=self.device), initial_mask), # noqa 304 | 1 305 | ) 306 | 307 | return acoustic_token_mask, start_t, mask_ratio 308 | 309 | def process_input( 310 | self, data, lengths, rvq_level, min_seq_length=None, 311 | mask_ratio=None, start_t=None, acoustic_token_mask=None 312 | ): 313 | """ 314 | data: (B, T, code_level, D) 315 | rvq_level: int 316 | """ 317 | B = data.size(0) 318 | T = data.size(1) 319 | level_data = data[:, :, rvq_level, :] # [B, T, C, D] -> [B, T, D] 320 | 321 | # Choose acoustic tokens to mask 322 | if acoustic_token_mask is None: 323 | acoustic_token_mask, start_t, mask_ratio = self.create_mask( 324 | B, T, lengths, mask_ratio=mask_ratio, start_t=start_t, 325 | min_seq_length=min_seq_length) 326 | # Remove code information from chosen tokens 327 | level_data[acoustic_token_mask, :] = 0 328 | 329 | # Embed only lower rvq_level 330 | lower_code_data = data[:, :, :rvq_level, :].sum(dim=2) 331 | 332 | # Combine with chosen tokens at rvq_level. 333 | # Note: all tokens at rvq_level+1: will be discarded. 334 | summed_data = torch.add(lower_code_data, level_data) 335 | 336 | return summed_data, acoustic_token_mask, mask_ratio, start_t 337 | 338 | def forward( 339 | self, x, code_level, semantic_tokens, lengths, 340 | speaker_emb=None, min_seq_length=10, mask_ratio=None, start_t=None, 341 | acoustic_token_mask=None 342 | ): 343 | # FIXME: parallelize this 344 | batch = [] 345 | for lvl, embed in enumerate(self.embedding[:(code_level + 1)]): 346 | batch.append(embed(x[:, :, lvl])) # [B T D] 347 | 348 | x = torch.stack(batch, dim=2) # [B T C D] 349 | x, acoustic_token_mask, mask_ratio, start_t = self.process_input( 350 | x, lengths, code_level, min_seq_length=min_seq_length, 351 | mask_ratio=mask_ratio, start_t=start_t, 352 | acoustic_token_mask=acoustic_token_mask 353 | ) 354 | 355 | # Add phoneme embeddings 356 | # Cross attention for all tokens? 
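        # Semantic-token embeddings (and, optionally, a projected speaker
        # embedding) are simply added to the acoustic frame representation
        # before the Conformer, rather than attended to.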
357 | 358 | # Add semantic tokens 359 | # HACK ME 360 | semantic_emb = self.semantic_embedding(semantic_tokens) 361 | x = torch.add(x, semantic_emb) 362 | # FIXME pfb30 363 | 364 | # Merge different modalities 365 | if self.hp.use_spkr_emb: 366 | spkr_emb = F.normalize(speaker_emb, dim=-1) 367 | spkr_emb = self.spkr_linear( 368 | F.dropout(spkr_emb, self.hp.speaker_embed_dropout) 369 | ) 370 | x = torch.add(x, spkr_emb) 371 | 372 | output_frames = self.conformer(x, None) 373 | 374 | x = self.heads[code_level](output_frames) 375 | 376 | return x, acoustic_token_mask, mask_ratio, start_t 377 | 378 | @torch.no_grad() 379 | def inference( 380 | self, codes, semantic_tokens, 381 | length: torch.LongTensor, rvq_levels=7, 382 | mask_ratio=0.99, maskgit_inference=True, 383 | start_t: Union[torch.LongTensor, None] = None, 384 | speaker_emb=None, steps=16 385 | ): 386 | # Use half of the recording for the conditioning 387 | if start_t is None: 388 | start_t = torch.tensor(int((codes.shape[1]) / 2)).long() 389 | 390 | start_t = start_t.item() 391 | 392 | for rvq_level in range(rvq_levels): 393 | original_codes = torch.clone(codes) 394 | if rvq_level == 0 and maskgit_inference: 395 | codes = self.multi_step_inference( 396 | original_codes, semantic_tokens, length, 397 | start_t=start_t, vamp_filtering=False, 398 | speaker_emb=speaker_emb, steps=16 399 | ) 400 | else: 401 | codes = self.one_step_inference( 402 | original_codes, semantic_tokens, length, 403 | code_level=rvq_level, 404 | mask_ratio=mask_ratio, start_t=start_t, 405 | speaker_emb=speaker_emb 406 | ) 407 | 408 | codes = rearrange(codes, 'T C -> 1 T C') 409 | 410 | # Remove any padding left 411 | codes = rearrange(codes, '1 T C -> 1 C T') 412 | codes = torch.where(codes >= self.hp.n_codes, 0, codes) 413 | acoustic_tokens = codes 414 | semantic_tokens = rearrange(semantic_tokens, 'b c -> b 1 c') 415 | semantic_tokens = torch.where( 416 | semantic_tokens >= self.hp.n_codes, 0, semantic_tokens) 417 | codes = torch.cat([semantic_tokens, acoustic_tokens], dim=1) 418 | 419 | return codes 420 | 421 | @torch.no_grad() 422 | def one_step_inference( 423 | self, original_codes, semantic_tokens, lengths, code_level=0, 424 | mask_ratio=0.99, start_t=0, inference_setup="argmax", speaker_emb=None 425 | ): 426 | codes = torch.clone(original_codes) 427 | logits, _, _, _ = self.forward( 428 | codes, code_level, semantic_tokens, lengths, 429 | mask_ratio=mask_ratio, start_t=start_t, 430 | speaker_emb=speaker_emb, acoustic_token_mask=False) 431 | 432 | if inference_setup == "argmax": 433 | probs = torch.nn.functional.softmax(logits, dim=-1) 434 | top_indeces = torch.argmax(probs, dim=-1) 435 | 436 | if inference_setup == "sampling": 437 | top_indeces = torch.distributions.Categorical( 438 | logits=logits).sample() 439 | 440 | codes = rearrange(codes, '1 T C -> T C') 441 | codes[start_t:, code_level] = top_indeces[0, start_t:] 442 | 443 | return codes 444 | 445 | @torch.no_grad() 446 | def multi_step_inference( 447 | self, original_codes, semantic_tokens, lengths, 448 | start_t: torch.LongTensor=None, 449 | choice_temperature=1.0, start_iter=0, 450 | steps=16, vamp_filtering=False, speaker_emb=None 451 | ): 452 | codes = torch.clone(original_codes) 453 | code_level = 0 454 | _, seq_len, _ = original_codes.shape 455 | mask_token_id = self.padding_id 456 | 457 | # Get true codes for the prompt 458 | prompt_mask = codes[:, :start_t, code_level] 459 | 460 | # Fill up rest with masks 461 | mask = torch.full( 462 | (1, seq_len - start_t), mask_token_id, 
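            # Everything past the prompt starts out as the mask (padding) id
            # and is filled in over `steps` MaskGIT iterations below.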
device=self.device) 463 | inputs = torch.cat((prompt_mask, mask), 1) 464 | 465 | num_mask_tokens_at_start = torch.sum(inputs == mask_token_id, axis=-1) 466 | 467 | # Initializes state 468 | state = state_init(inputs, steps, start_iter=start_iter) 469 | 470 | def loop_cond_fn(state): 471 | """Beam search loop termination condition.""" 472 | not_at_end = (state.cur_index < steps) 473 | return not_at_end 474 | 475 | while loop_cond_fn(state): 476 | """Beam search loop state update function.""" 477 | step = state.cur_index 478 | # Current input ids: [batch_size, seq_length]. 479 | cur_ids = state.cur_seqs 480 | 481 | # Calls model on current seqs to get next-iteration seqs. 482 | with torch.no_grad(): 483 | logits, _, _, _ = self.forward( 484 | rearrange(inputs, 'B T -> B T 1'), 485 | code_level, 486 | semantic_tokens, lengths, 487 | acoustic_token_mask=False, 488 | speaker_emb=speaker_emb) 489 | 490 | # Samples the ids using categorical sampling: 491 | if vamp_filtering: 492 | typical_mass = 0.2 493 | typical_min_tokens = 1 494 | top_p = None 495 | sample_cutoff = 0.5 496 | typical_filtering = False 497 | sampled_ids, selected_probs = sample_from_logits( 498 | logits, sample=((step / steps) <= sample_cutoff), 499 | temperature=choice_temperature, 500 | typical_filtering=typical_filtering, 501 | typical_mass=typical_mass, 502 | typical_min_tokens=typical_min_tokens, 503 | top_k=None, top_p=top_p, return_probs=True, 504 | ) 505 | else: 506 | sampled_ids = torch.distributions.Categorical( 507 | logits=logits).sample() 508 | 509 | # Just updates the masked tokens. 510 | unknown_map = (cur_ids == mask_token_id) 511 | sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids) 512 | # Defines the mask ratio for the next round. The number to mask out 513 | # is determined by mask_ratio * unknown_number_in_the_beginning. 514 | ratio = 1. * (step + 1) / steps 515 | mask_ratio = masking_logic.schedule(ratio) 516 | 517 | # Updates final seqs with the current sampled_ids. 518 | final_seqs = torch.clone(state.final_seqs) 519 | final_seqs[:, step, :] = sampled_ids 520 | # Computes the probabilities of each selected tokens. 521 | probs = torch.nn.functional.softmax(logits, dim=-1) 522 | # Extract the probabilities of sampled ids 523 | selected_probs = torch.squeeze( 524 | torch.take_along_dim( 525 | probs, torch.unsqueeze(sampled_ids, -1) , -1), 526 | -1 527 | ) 528 | 529 | # Ignores the tokens given in the input 530 | # by overwriting their confidence. 531 | selected_probs = torch.where( 532 | unknown_map, selected_probs, torch.inf) 533 | # Gets mask lens for each sample in the 534 | # batch according to the mask ratio. 535 | num_to_mask = torch.unsqueeze( 536 | torch.floor(num_mask_tokens_at_start * mask_ratio), 1) 537 | 538 | # Keeps at least one of prediction in this 539 | # round and also masks out at least 540 | # one and for the next iteration 541 | num_to_mask = torch.maximum( 542 | torch.tensor(1), 543 | torch.minimum( 544 | torch.sum(unknown_map, dim=-1, keepdim=True) - 1, 545 | num_to_mask) 546 | ) 547 | # Adds noise for randomness 548 | masking = mask_by_random_topk( 549 | num_to_mask, selected_probs, choice_temperature * (1. - ratio)) 550 | # Masks tokens with lower confidence. 
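            # Low-confidence predictions are reset to the mask id so they are
            # re-sampled in the next iteration; the cosine schedule shrinks the
            # number of masked positions as `ratio` approaches 1.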
551 | sampled_ids = torch.where(masking, mask_token_id, sampled_ids) 552 | 553 | state = State( 554 | cur_index=state.cur_index + 1, 555 | cur_seqs=sampled_ids, 556 | final_seqs=final_seqs 557 | ) 558 | 559 | codes = torch.clone(original_codes) 560 | codes = rearrange(codes, '1 T C -> T C') 561 | codes[:, 0] = state.final_seqs[0][-1] 562 | 563 | return codes 564 | -------------------------------------------------------------------------------- /modules/speech_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Speech tokenizer class. 2 | 3 | Copyright PolyAI Limited. 4 | """ 5 | import logging 6 | import os 7 | 8 | import numpy as np 9 | import torch 10 | import torchaudio 11 | from speechtokenizer import SpeechTokenizer as ST 12 | 13 | from modules.tokenizer import BaseTokenizer 14 | 15 | 16 | class SpeechTokenizer(BaseTokenizer): 17 | def __init__(self, config_path: str, ckpt_path: str, device: torch.device): 18 | self.device = device 19 | self.model = ST.load_from_checkpoint( 20 | config_path, ckpt_path).to(self.device) 21 | self.model.eval() 22 | 23 | def encode_file( 24 | self, folder_path: str, destination_folder: str, filename: str): 25 | dest_path = os.path.join( 26 | destination_folder, "semantic", 27 | os.path.splitext(filename)[0] + ".npy" 28 | ) 29 | dest_path2 = os.path.join( 30 | destination_folder, "acoustic", 31 | os.path.splitext(filename)[0] + ".npy" 32 | ) 33 | if os.path.exists(dest_path) and os.path.exists(dest_path2): 34 | pass 35 | else: 36 | self._create_subfolders(destination_folder=destination_folder) 37 | 38 | file_path = os.path.join(folder_path, filename) 39 | wav_info = torchaudio.info(file_path) 40 | wav_dur_sec = wav_info.num_frames / wav_info.sample_rate 41 | if wav_dur_sec > 60: 42 | logging.info( 43 | f"Skipping {file_path} is too long: {wav_dur_sec:.3f} sec," 44 | "can cause CUDA OOM" 45 | ) 46 | return 47 | wav, sr = torchaudio.load(file_path) 48 | if sr != self.model.sample_rate: 49 | logging.warning( 50 | "Wav sample rate %(wav_sr)s does not match the model" 51 | "sampling rate %(model_sr)s. Resampling audio", 52 | {"wav_sr": sr, "model_sr": self.model.sample_rate}, 53 | ) 54 | wav = torchaudio.functional.resample( 55 | wav, sr, self.model.sample_rate) 56 | wav = wav.unsqueeze(0) 57 | wav = wav.to(self.device) 58 | 59 | # Extract discrete codes from SpeechTokenizer 60 | with torch.no_grad(): 61 | codes = self.model.encode(wav) # codes: (n_q, B, T) 62 | 63 | semantic_tokens = codes[0, 0, :] 64 | acoustic_tokens = codes[1:, 0, :] 65 | 66 | # Save the encoding as .npy 67 | dest_path = os.path.join( 68 | destination_folder, "acoustic", 69 | os.path.splitext(filename)[0] + ".npy" 70 | ) 71 | np.save(dest_path, acoustic_tokens.cpu().numpy()) 72 | 73 | dest_path = os.path.join( 74 | destination_folder, "semantic", 75 | os.path.splitext(filename)[0] + ".npy" 76 | ) 77 | np.save(dest_path, semantic_tokens.cpu().numpy()) 78 | 79 | @staticmethod 80 | def _create_subfolders(destination_folder: str): 81 | if not os.path.exists(destination_folder + "/acoustic"): 82 | os.makedirs(destination_folder + "/acoustic") 83 | 84 | if not os.path.exists(destination_folder + "/semantic"): 85 | os.makedirs(destination_folder + "/semantic") 86 | -------------------------------------------------------------------------------- /modules/t2s_model.py: -------------------------------------------------------------------------------- 1 | """T2S model definition. 2 | 3 | Copyright PolyAI Limited. 
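T2S is a T5 encoder-decoder that maps phoneme/text token ids to semantic
token ids; `compute_custom_metrics` reports top-1/5/10 accuracy over
non-padding positions.

A minimal generation sketch, assuming an exported checkpoint directory such
as ckpt/t2s/ (illustrative only; see transformer_infer.py for the full
inference path, including the semantic prompt fed as decoder input):

    from transformers import T5ForConditionalGeneration
    from data.collation import get_text_semantic_token_collater
    from data.semantic_dataset import TextTokenizer

    collater = get_text_semantic_token_collater("ckpt/unique_text_tokens.k2symbols")
    phonemizer = TextTokenizer()
    t2s = T5ForConditionalGeneration.from_pretrained("ckpt/t2s/")
    input_ids = collater([phonemizer("Hello there!")[0]]).long()
    semantic_ids = t2s.generate(
        input_ids, do_sample=True, top_k=210, temperature=0.7,
        max_new_tokens=750)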
4 | """ 5 | import os 6 | 7 | import numpy as np 8 | from torch import nn 9 | from transformers import EvalPrediction, T5Config, T5ForConditionalGeneration 10 | 11 | from data.collation import get_text_semantic_token_collater 12 | 13 | 14 | def compute_custom_metrics(eval_prediction: EvalPrediction): 15 | # eval_prediction: tuple 16 | # eval_prediction[0]: tensor of decoder outputs(logits) (n_batch, n_semantic, n_tokens) # noqa 17 | # eval_prediction[1]: tensor of encoder outputs (n_batch, n_text/n_phone, n_hidden) # noqa 18 | logits = eval_prediction.predictions[0] 19 | labels = eval_prediction.label_ids 20 | n_vocab = logits.shape[-1] 21 | mask = labels == -100 22 | top_1 = np.argmax(logits, axis=-1) == labels 23 | top_1[mask] = False 24 | top_5 = np.argsort(logits, axis=-1)[:, :, -5:] 25 | top_5 = np.any(top_5 == np.expand_dims(labels, axis=-1), axis=-1) 26 | top_5[mask] = False 27 | 28 | top_10 = np.argsort(logits, axis=-1)[:, :, -10:] 29 | top_10 = np.any(top_10 == np.expand_dims(labels, axis=-1), axis=-1) 30 | top_10[mask] = False 31 | 32 | top_1_accuracy = np.sum(top_1) / np.sum(~mask) 33 | top_5_accuracy = np.sum(top_5) / np.sum(~mask) 34 | top_10_accuracy = np.sum(top_10) / np.sum(~mask) 35 | 36 | return { 37 | "top_1_accuracy": top_1_accuracy, 38 | "top_5_accuracy": top_5_accuracy, 39 | "top_10_accuracy": top_10_accuracy, 40 | } 41 | 42 | 43 | class T2S(nn.Module): 44 | def __init__(self, hp): 45 | super().__init__() 46 | self.text_tokens_file = "ckpt/unique_text_tokens.k2symbols" 47 | self.collater = get_text_semantic_token_collater(self.text_tokens_file) 48 | self.model_size = hp.model_size 49 | self.vocab_size = len(self.collater.idx2token) 50 | self.config = self._define_model_config(self.model_size) 51 | 52 | print(f"{self.config = }") 53 | self.t2s = T5ForConditionalGeneration(self.config) 54 | 55 | def _define_model_config(self, model_size): 56 | if model_size == "test": 57 | # n_params = 16M 58 | d_ff = 16 59 | d_model = 8 60 | d_kv = 32 61 | num_heads = 1 62 | num_decoder_layers = 1 63 | num_layers = 1 64 | elif model_size == "tiny": 65 | # n_params = 16M 66 | d_ff = 1024 67 | d_model = 256 68 | d_kv = 32 69 | num_heads = 4 70 | num_decoder_layers = 4 71 | num_layers = 4 72 | elif model_size == "t5small": 73 | # n_params = 60M 74 | d_ff = 2048 75 | d_model = 512 76 | d_kv = 64 77 | num_heads = 8 78 | num_decoder_layers = 6 79 | num_layers = 6 80 | elif model_size == "large": 81 | # n_params = 100M 82 | d_ff = 2048 83 | d_model = 512 84 | d_kv = 64 85 | num_heads = 8 86 | num_decoder_layers = 14 87 | num_layers = 14 88 | elif model_size == "Large": 89 | # n_params = 114M 90 | d_ff = 4096 91 | d_model = 512 92 | d_kv = 64 93 | num_heads = 8 94 | num_decoder_layers = 6 95 | num_layers = 10 96 | else: 97 | raise ValueError(f"unknown {model_size}") 98 | 99 | config = T5Config( 100 | d_ff=d_ff, 101 | d_model=d_model, 102 | d_kv=d_kv, 103 | num_heads=num_heads, 104 | num_decoder_layers=num_decoder_layers, 105 | num_layers=num_layers, 106 | decoder_start_token_id=0, 107 | eos_token_id=2, 108 | vocab_size=self.vocab_size, 109 | ) 110 | 111 | return config 112 | -------------------------------------------------------------------------------- /modules/tokenizer.py: -------------------------------------------------------------------------------- 1 | """Base tokenizer class. 2 | 3 | Copyright PolyAI Limited. 
4 | """ 5 | import os 6 | from concurrent.futures import ThreadPoolExecutor, as_completed 7 | 8 | from tqdm import tqdm 9 | 10 | from utils import measure_duration 11 | 12 | 13 | class BaseTokenizer: 14 | @measure_duration 15 | def encode_files_with_model_seq( 16 | self, folder_path: str, destination_folder: str): 17 | # Ensure destination folder exists 18 | if not os.path.exists(destination_folder): 19 | os.makedirs(destination_folder) 20 | 21 | # Go through each file in the folder 22 | filenames = os.listdir(folder_path) 23 | # encoding files has no side effects 24 | for filename in tqdm(filenames): 25 | self.encode_file( 26 | folder_path=folder_path, 27 | destination_folder=destination_folder, 28 | filename=filename, 29 | ) 30 | 31 | def get_chunk(self, folder_path, start_percent=0, end_percent=100): 32 | filenames = os.listdir(folder_path) 33 | total_files = len(filenames) 34 | 35 | start_idx = int(total_files * (start_percent / 100)) 36 | end_idx = int(total_files * (end_percent / 100)) 37 | 38 | return filenames[start_idx:end_idx] 39 | 40 | @measure_duration 41 | def encode_files_with_model_concurrent( 42 | self, filenames: list, folder_path: str, destination_folder: str, 43 | n_threads: int = os.cpu_count() 44 | ): 45 | # Ensure destination folder exists 46 | if not os.path.exists(destination_folder): 47 | os.makedirs(destination_folder) 48 | 49 | # encoding files has no side effects 50 | with ThreadPoolExecutor(max_workers=n_threads) as executor: 51 | futures = [ 52 | executor.submit( 53 | self.encode_file, 54 | folder_path=folder_path, 55 | destination_folder=destination_folder, 56 | filename=filename, 57 | ) 58 | for filename in filenames 59 | ] 60 | # Wait for all tasks to complete 61 | with tqdm(total=len(futures)) as pbar: 62 | for future in as_completed(futures): 63 | res = future.result() 64 | pbar.update(n=1) 65 | 66 | # Explicitly shut down the thread pool 67 | executor.shutdown() 68 | 69 | def encode_file( 70 | self, folder_path: str, destination_folder: str, filename: str): 71 | raise NotImplementedError 72 | -------------------------------------------------------------------------------- /modules/vocoder.py: -------------------------------------------------------------------------------- 1 | """Vocoder wrapper. 2 | 3 | Copyright PolyAI Limited. 
4 | """ 5 | import enum 6 | 7 | import numpy as np 8 | import soundfile as sf 9 | import torch 10 | import torch.nn as nn 11 | from speechtokenizer import SpeechTokenizer 12 | 13 | 14 | class VocoderType(enum.Enum): 15 | SPEECHTOKENIZER = ("SPEECHTOKENIZER", 320) 16 | 17 | def __init__(self, name, compression_ratio): 18 | self._name_ = name 19 | self.compression_ratio = compression_ratio 20 | 21 | def get_vocoder(self, ckpt_path, config_path, **kwargs): 22 | if self.name == "SPEECHTOKENIZER": 23 | if ckpt_path: 24 | vocoder = STWrapper(ckpt_path, config_path) 25 | else: 26 | vocoder = STWrapper() 27 | else: 28 | raise ValueError(f"Unknown vocoder type {self.name}") 29 | return vocoder 30 | 31 | 32 | class STWrapper(nn.Module): 33 | def __init__( 34 | self, 35 | ckpt_path: str = './ckpt/speechtokenizer/SpeechTokenizer.pt', 36 | config_path = './ckpt/speechtokenizer/config.json', 37 | ): 38 | super().__init__() 39 | self.model = SpeechTokenizer.load_from_checkpoint( 40 | config_path, ckpt_path) 41 | 42 | def eval(self): 43 | self.model.eval() 44 | 45 | @torch.no_grad() 46 | def decode(self, codes: torch.Tensor, verbose: bool = False): 47 | original_device = codes.device 48 | 49 | codes = codes.to(self.device) 50 | audio_array = self.model.decode(codes) 51 | 52 | return audio_array.to(original_device) 53 | 54 | def decode_to_file(self, codes_path, out_path) -> None: 55 | codes = np.load(codes_path) 56 | codes = torch.from_numpy(codes) 57 | wav = self.decode(codes).cpu().numpy() 58 | sf.write(out_path, wav, samplerate=self.model.sample_rate) 59 | 60 | @torch.no_grad() 61 | def encode(self, wav, verbose=False, n_quantizers: int = None): 62 | original_device = wav.device 63 | wav = wav.to(self.device) 64 | codes = self.model.encode(wav) # codes: (n_q, B, T) 65 | return codes.to(original_device) 66 | 67 | def encode_to_file(self, wav_path, out_path) -> None: 68 | wav, _ = sf.read(wav_path, dtype='float32') 69 | wav = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0) 70 | codes = self.encode(wav).cpu().numpy() 71 | np.save(out_path, codes) 72 | 73 | def remove_weight_norm(self): 74 | pass 75 | 76 | @property 77 | def device(self): 78 | return next(self.model.parameters()).device 79 | 80 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiofiles==23.2.1 2 | aiohttp==3.9.1 3 | aiosignal==1.3.1 4 | alembic==1.13.0 5 | altair==5.2.0 6 | annotated-types==0.6.0 7 | antlr4-python3-runtime==4.9.3 8 | anyio==3.7.1 9 | asteroid-filterbanks==0.4.0 10 | async-timeout==4.0.3 11 | attrs==23.1.0 12 | audioread==3.0.1 13 | Babel==2.13.1 14 | certifi==2023.11.17 15 | cffi==1.16.0 16 | charset-normalizer==3.3.2 17 | click==8.1.7 18 | clldutils==3.20.0 19 | colorama==0.4.6 20 | colorlog==6.8.0 21 | contourpy==1.2.0 22 | csvw==3.2.1 23 | cycler==0.12.1 24 | decorator==5.1.1 25 | dlinfo==1.2.1 26 | docopt==0.6.2 27 | einops==0.7.0 28 | exceptiongroup==1.2.0 29 | fastapi==0.104.1 30 | ffmpy==0.3.1 31 | filelock==3.13.1 32 | fonttools==4.46.0 33 | frozenlist==1.4.0 34 | fsspec==2023.10.0 35 | gradio==3.48.0 36 | gradio_client==0.6.1 37 | greenlet==3.0.1 38 | h11==0.14.0 39 | httpcore==1.0.2 40 | httpx==0.25.2 41 | huggingface-hub==0.19.4 42 | HyperPyYAML==1.2.2 43 | idna==3.6 44 | importlib-resources==6.1.1 45 | isodate==0.6.1 46 | Jinja2==3.1.2 47 | joblib==1.3.2 48 | jsonschema==4.20.0 49 | jsonschema-specifications==2023.11.2 50 | julius==0.2.7 51 | kiwisolver==1.4.5 52 | 
language-tags==1.2.0 53 | lazy_loader==0.3 54 | librosa==0.10.1 55 | lightning==2.1.2 56 | lightning-utilities==0.10.0 57 | llvmlite==0.41.1 58 | lxml==4.9.3 59 | Mako==1.3.0 60 | Markdown==3.5.1 61 | markdown-it-py==3.0.0 62 | MarkupSafe==2.1.3 63 | matplotlib==3.8.2 64 | mdurl==0.1.2 65 | mpmath==1.3.0 66 | msgpack==1.0.7 67 | multidict==6.0.4 68 | networkx==3.2.1 69 | numba==0.58.1 70 | numpy==1.26.2 71 | nvidia-cublas-cu12==12.1.3.1 72 | nvidia-cuda-cupti-cu12==12.1.105 73 | nvidia-cuda-nvrtc-cu12==12.1.105 74 | nvidia-cuda-runtime-cu12==12.1.105 75 | nvidia-cudnn-cu12==8.9.2.26 76 | nvidia-cufft-cu12==11.0.2.54 77 | nvidia-curand-cu12==10.3.2.106 78 | nvidia-cusolver-cu12==11.4.5.107 79 | nvidia-cusparse-cu12==12.1.0.106 80 | nvidia-nccl-cu12==2.18.1 81 | nvidia-nvjitlink-cu12==12.3.101 82 | nvidia-nvtx-cu12==12.1.105 83 | omegaconf==2.3.0 84 | optuna==3.4.0 85 | orjson==3.9.10 86 | packaging==23.2 87 | pandas==2.1.3 88 | phonemizer==3.2.1 89 | Pillow==10.1.0 90 | platformdirs==4.0.0 91 | pooch==1.8.0 92 | primePy==1.3 93 | protobuf==4.25.1 94 | pyannote.audio @ https://github.com/pyannote/pyannote-audio/archive/develop.zip 95 | pyannote.core==5.0.0 96 | pyannote.database==5.0.1 97 | pyannote.metrics==3.2.1 98 | pyannote.pipeline==3.0.1 99 | pycparser==2.21 100 | pydantic==2.5.2 101 | pydantic_core==2.14.5 102 | pydub==0.25.1 103 | Pygments==2.17.2 104 | pylatexenc==2.10 105 | pyparsing==3.1.1 106 | python-dateutil==2.8.2 107 | python-multipart==0.0.6 108 | pytorch-lightning==2.1.2 109 | pytorch-metric-learning==2.3.0 110 | pytz==2023.3.post1 111 | PyYAML==6.0.1 112 | rdflib==7.0.0 113 | referencing==0.31.1 114 | regex==2023.10.3 115 | requests==2.31.0 116 | rfc3986==1.5.0 117 | rich==13.7.0 118 | rpds-py==0.13.2 119 | ruamel.yaml==0.18.5 120 | ruamel.yaml.clib==0.2.8 121 | safetensors==0.4.1 122 | scikit-learn==1.3.2 123 | scipy==1.11.4 124 | segments==2.2.1 125 | semantic-version==2.10.0 126 | semver==3.0.2 127 | sentencepiece==0.1.99 128 | shellingham==1.5.4 129 | six==1.16.0 130 | sniffio==1.3.0 131 | sortedcontainers==2.4.0 132 | soundfile==0.12.1 133 | soxr==0.3.7 134 | speechbrain==0.5.16 135 | speechtokenizer==0.1.2 136 | SQLAlchemy==2.0.23 137 | starlette==0.27.0 138 | sympy==1.12 139 | tabulate==0.9.0 140 | tensorboardX==2.6.2.2 141 | threadpoolctl==3.2.0 142 | tokenizers==0.15.0 143 | tomlkit==0.12.0 144 | toolz==0.12.0 145 | torch==2.1.1 146 | torch-audiomentations==0.11.0 147 | torch-pitch-shift==1.2.4 148 | torchaudio==2.1.1 149 | torchmetrics==1.2.1 150 | torchvision==0.16.1 151 | tqdm==4.66.1 152 | transformers==4.35.2 153 | triton==2.1.0 154 | typer==0.9.0 155 | typing_extensions==4.8.0 156 | tzdata==2023.3 157 | uritemplate==4.1.1 158 | urllib3==2.1.0 159 | uvicorn==0.24.0.post1 160 | websockets==11.0.3 161 | yarl==1.9.3 162 | accelerate==0.26.1 163 | wandb==0.16.2 164 | -------------------------------------------------------------------------------- /train_s2a.py: -------------------------------------------------------------------------------- 1 | """S2A training logic. 2 | 3 | Copyright PolyAI Limited. 
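The script trains the S2A model with PyTorch Lightning; checkpoints, the
dumped config.json and TensorBoard logs go to --saving_path, and a Weights &
Biases run is also created. An illustrative invocation using the bundled
example manifest (all other flags keep their defaults):

    python train_s2a.py \
        --saving_path exp/s2a \
        --metapath datasets/example/train.json \
        --val_metapath datasets/example/train.json \
        --use_spkr_emb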
4 | """ 5 | import argparse 6 | import json 7 | import os 8 | from pathlib import Path 9 | from typing import List 10 | 11 | from pytorch_lightning import Trainer 12 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 13 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger 14 | 15 | from data.data_module import DataModule 16 | from modules.s2a_model import Pheme 17 | from modules.vocoder import VocoderType 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | # Paths 23 | parser.add_argument("--saving_path", type=str, default="./ckpt") 24 | parser.add_argument("--resume_checkpoint", type=str, default=None) 25 | parser.add_argument( 26 | "--vocoder_type", 27 | type=str, 28 | choices=[voc_type.name for voc_type in VocoderType], 29 | default=VocoderType.SPEECHTOKENIZER.name, 30 | ) 31 | parser.add_argument("--vocoder_config_path", type=str, default=None) 32 | parser.add_argument("--vocoder_ckpt_path", type=str, default=None) 33 | parser.add_argument( 34 | "--metapath", type=str, nargs="+", help="paths to train metadata", 35 | required=True 36 | ) 37 | parser.add_argument( 38 | "--val_metapath", type=str, nargs="+", default=[], 39 | help="paths to validation metadata", 40 | ) 41 | parser.add_argument("--pretrained_path", type=str, default=None) 42 | parser.add_argument("--speaker_embedding_dir", type=str, default=None) 43 | parser.add_argument("--sampledir", type=str, default="./logs") 44 | 45 | # Optimizer 46 | parser.add_argument("--lr", type=float, default=1e-4) 47 | parser.add_argument("--batch_size", type=float, default=200) 48 | parser.add_argument("--max_length", type=int, default=1600) 49 | parser.add_argument("--train_bucket_size", type=int, default=8192) 50 | parser.add_argument("--training_step", type=int, default=800000) 51 | parser.add_argument("--optim_flat_percent", type=float, default=0.0) 52 | parser.add_argument("--warmup_step", type=int, default=50) 53 | parser.add_argument("--adam_beta1", type=float, default=0.9) 54 | parser.add_argument("--adam_beta2", type=float, default=0.98) 55 | 56 | # Architecture 57 | parser.add_argument("--ffd_size", type=int, default=3072) 58 | parser.add_argument("--hidden_size", type=int, default=768) 59 | parser.add_argument("--enc_nlayers", type=int, default=6) 60 | parser.add_argument("--dec_nlayers", type=int, default=6) 61 | parser.add_argument("--nheads", type=int, default=12) 62 | parser.add_argument("--dropout", type=float, default=0.1) 63 | parser.add_argument("--depthwise_conv_kernel_size", type=int, default=5) 64 | parser.add_argument("--aligner_softmax_temp", type=float, default=1.0) 65 | parser.add_argument("--layer_norm_eps", type=float, default=1e-5) 66 | parser.add_argument("--use_sem_tokens", type=bool, default=True) 67 | parser.add_argument("--use_spkr_emb", action="store_true") 68 | parser.add_argument("--use_text_emb", action="store_true") 69 | parser.add_argument("--only_inference", action="store_true") 70 | 71 | # Dropout 72 | parser.add_argument("--speaker_embed_dropout", type=float, default=0.05) 73 | parser.add_argument("--label_smoothing", type=float, default=0.0) 74 | 75 | # Trainer 76 | parser.add_argument("--val_check_interval", type=int, default=1) 77 | parser.add_argument("--max_dataset_samples", type=int, default=-1) 78 | parser.add_argument("--check_val_every_n_epoch", type=int, default=1) 79 | parser.add_argument( 80 | "--precision", type=str, choices=["16", "32", "bf16"], default="bf16" 81 | ) 82 | parser.add_argument("--nworkers", type=int, 
default=16) 83 | parser.add_argument("--distributed", action="store_true") 84 | parser.add_argument( 85 | "--accelerator", 86 | type=str, 87 | choices=["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"], 88 | default="gpu", 89 | ) 90 | parser.add_argument("--version", type=int, default=None) 91 | parser.add_argument("--accumulate_grad_batches", type=int, default=1) 92 | 93 | # Data 94 | parser.add_argument("--sample_rate", type=int, default=16000) 95 | parser.add_argument("--n_codes", type=int, default=1024) 96 | parser.add_argument("--n_cluster_groups", type=int, default=7) 97 | parser.add_argument("--first_n_lvls", type=int, default=7) 98 | parser.add_argument("--use_pretrained_ckpt_cfg", action="store_true") 99 | parser.add_argument("--n_semantic_codes", type=int, default=1024) 100 | 101 | # Distribution 102 | parser.add_argument("--sagemaker", action="store_true") 103 | 104 | args = parser.parse_args() 105 | 106 | return args 107 | 108 | 109 | def split_metapath(in_paths: List[str]): 110 | podidx_paths, other_paths = [], [] 111 | 112 | for itm_path in in_paths: 113 | if itm_path.endswith("jsonl"): 114 | podidx_paths.append(itm_path) 115 | else: 116 | other_paths.append(itm_path) 117 | 118 | return podidx_paths, other_paths 119 | 120 | 121 | if __name__ == "__main__": 122 | args = parse_args() 123 | os.makedirs(args.saving_path, exist_ok=True) 124 | 125 | with open(os.path.join(args.saving_path, "config.json"), "w") as f: 126 | json.dump(args.__dict__, f, indent=2) 127 | 128 | if args.pretrained_path: 129 | if ( 130 | Path(args.pretrained_path).with_name("config.json").exists() 131 | and args.use_pretrained_ckpt_cfg 132 | ): 133 | with open( 134 | Path(args.pretrained_path).with_name("config.json"), "r") as f: 135 | prev_cfg = json.load(f) 136 | for k, v in prev_cfg.items(): 137 | if isinstance(v, (int,)): 138 | if args.__dict__[k] != v: 139 | print(f"updating {k}!", args.__dict__[k], v) 140 | args.__dict__[k] = v 141 | 142 | fname_prefix = f"" 143 | checkpoint_callback = ModelCheckpoint( 144 | dirpath=args.saving_path, 145 | filename=(fname_prefix + "{epoch}-{step}"), 146 | every_n_train_steps=( 147 | None if args.val_check_interval == 1.0 else args.val_check_interval # noqa 148 | ), 149 | every_n_epochs=( 150 | None if args.check_val_every_n_epoch == 1 else args.check_val_every_n_epoch # noqa 151 | ), 152 | verbose=True, 153 | save_last=True, 154 | save_top_k=3, 155 | monitor="val/dataset_0/acc_top_5", 156 | mode='max' 157 | ) 158 | lr_monitor = LearningRateMonitor(logging_interval="step") 159 | 160 | logger_tb = TensorBoardLogger( 161 | args.saving_path, name="VQ-TTS", version=args.version) 162 | logger_wandb = WandbLogger(project="mqtts", log_model=True, config=args) 163 | 164 | distribution_kwargs = { 165 | "accelerator": "gpu", 166 | "strategy": "ddp_find_unused_parameters_true" if args.distributed else "auto", # noqa 167 | } 168 | 169 | wrapper = Trainer( 170 | precision=args.precision, 171 | callbacks=[checkpoint_callback, lr_monitor], 172 | num_sanity_val_steps=20, 173 | max_steps=args.training_step, 174 | accumulate_grad_batches=args.accumulate_grad_batches, 175 | logger=[logger_tb, logger_wandb], 176 | check_val_every_n_epoch=args.check_val_every_n_epoch, 177 | profiler="simple", 178 | use_distributed_sampler=False, 179 | # distribution 180 | **distribution_kwargs, 181 | ) 182 | model = Pheme(args) 183 | logger_wandb.watch(model=model) 184 | _, other_metapath = split_metapath(args.metapath) 185 | _, other_val_metapath = split_metapath(args.val_metapath) 186 | 187 | print( 188 
| f"Received datasets: \n{other_metapath = } " 189 | f"\n \n{other_val_metapath = }" 190 | ) 191 | 192 | other_meta = {} 193 | if len(other_metapath) > 0: 194 | other_meta["fit"] = other_metapath 195 | if len(other_val_metapath) > 0: 196 | other_meta["valid"] = other_val_metapath 197 | 198 | data_module = DataModule( 199 | args, other_metapath, other_val_metapath, 200 | wrapper.world_size, wrapper.local_rank 201 | ) 202 | data_module.setup(stage="fit") 203 | train_data_module = data_module 204 | 205 | valid_dataloaders = [] 206 | data_module.setup(stage="valid") 207 | valid_dataloaders.extend(data_module.val_dataloader()) 208 | 209 | wrapper.fit( 210 | model, 211 | train_dataloaders=train_data_module.train_dataloader(), 212 | val_dataloaders=valid_dataloaders, 213 | ckpt_path=args.resume_checkpoint, 214 | ) 215 | -------------------------------------------------------------------------------- /train_t2s.py: -------------------------------------------------------------------------------- 1 | """Train T2S to generate semantic tokens. 2 | 3 | Copyright PolyAI Limited. 4 | """ 5 | import argparse 6 | import logging 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | import torch 11 | from transformers import Trainer, TrainingArguments 12 | 13 | from data.semantic_dataset import Collator, ConcatenateSemanticDataset 14 | from modules.t2s_model import T2S, compute_custom_metrics 15 | from utils import split_metapath 16 | 17 | 18 | # Synchronize the GPU 19 | torch.cuda.synchronize() 20 | 21 | # Check for CUDA errors 22 | if torch.cuda.is_available(): 23 | device = torch.cuda.current_device() 24 | print(torch.cuda.get_device_properties(device)) 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | "--metapath", type=str, nargs="+", help="paths to train metadata", 31 | required=True 32 | ) 33 | parser.add_argument( 34 | "--val_metapath", type=str, nargs="+", default=[], 35 | help="paths to validation metadata", 36 | ) 37 | parser.add_argument( 38 | "--train_path", type=str, 39 | default="datasets/giga-training-data/train.json" 40 | ) 41 | parser.add_argument( 42 | "--eval_path", type=str, 43 | default="datasets/giga-training-data/dev.json" 44 | ) 45 | parser.add_argument("--output_dir", type=str, required=True) 46 | parser.add_argument( 47 | "--model_size", choices=["test", "tiny", "t5small", "large", "Large"], 48 | default="tiny" 49 | ) 50 | parser.add_argument("--eval_accumulation_steps", type=int, default=10) 51 | parser.add_argument("--warmup_steps", type=int, default=5000) 52 | parser.add_argument("--save_steps", type=int, default=500) 53 | parser.add_argument("--batch_size", type=int, default=16) 54 | parser.add_argument("--n_epochs", type=int, default=20) 55 | parser.add_argument("--nworkers", type=int, default=8) 56 | parser.add_argument("--max_duration", type=int, default=15) 57 | parser.add_argument("--eval_n_samples", type=int, default=400) 58 | parser.add_argument("--learning_rate", type=float, default=5E-4) 59 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 60 | parser.add_argument("--resume_from_checkpoint", type=str, default=None) 61 | args = parser.parse_args() 62 | 63 | return args 64 | 65 | 66 | if __name__ == "__main__": 67 | logging.basicConfig(level=logging.INFO) 68 | args = parse_args() 69 | 70 | model = T2S(args) 71 | n_params = sum([param.numel() for param in model.parameters()]) 72 | print(f"Model has {n_params = }") 73 | 74 | train_path = split_metapath(args.metapath) 75 | 
eval_paths = split_metapath(args.val_metapath) 76 | 77 | dataset_train = ConcatenateSemanticDataset( 78 | manifest_path=train_path, 79 | symbol_table_path=model.text_tokens_file, 80 | max_duration=args.max_duration 81 | ) 82 | 83 | dataset_eval = ConcatenateSemanticDataset( 84 | manifest_path=eval_paths, 85 | symbol_table_path=model.text_tokens_file, 86 | n_samples=args.eval_n_samples, 87 | max_duration=args.max_duration 88 | ) 89 | 90 | current_timestamp = datetime.now() 91 | current_timestamp = current_timestamp.strftime("%Y-%m-%d-%H:%M:%S") 92 | if args.resume_from_checkpoint is not None: 93 | output_dir = Path(args.resume_from_checkpoint).parent 94 | else: 95 | output_dir = Path(args.output_dir) 96 | 97 | training_args = TrainingArguments( 98 | output_dir=output_dir, 99 | learning_rate=args.learning_rate, 100 | per_device_train_batch_size=args.batch_size, 101 | per_device_eval_batch_size=args.batch_size, 102 | gradient_accumulation_steps=args.gradient_accumulation_steps, 103 | num_train_epochs=args.n_epochs, 104 | save_steps=args.save_steps, 105 | eval_steps=args.save_steps, 106 | save_total_limit=3, 107 | dataloader_num_workers=args.nworkers, 108 | evaluation_strategy="steps", 109 | save_strategy="steps", 110 | load_best_model_at_end=True, 111 | report_to=["all"], 112 | bf16=False, 113 | warmup_steps=args.warmup_steps, 114 | ddp_find_unused_parameters=False, 115 | eval_accumulation_steps=args.eval_accumulation_steps 116 | ) 117 | 118 | trainer = Trainer( 119 | model=model.t2s, 120 | args=training_args, 121 | data_collator=Collator().collate, 122 | train_dataset=dataset_train, 123 | eval_dataset=dataset_eval, 124 | compute_metrics=compute_custom_metrics, 125 | ) 126 | 127 | trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) 128 | -------------------------------------------------------------------------------- /transformer_infer.py: -------------------------------------------------------------------------------- 1 | """Inference logic. 2 | 3 | Copyright PolyAI Limited. 4 | """ 5 | import argparse 6 | import json 7 | import logging 8 | import os 9 | import time 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import soundfile as sf 14 | import torch 15 | from einops import rearrange 16 | from librosa.util import normalize 17 | from pyannote.audio import Inference 18 | from transformers import GenerationConfig, T5ForConditionalGeneration 19 | 20 | import constants as c 21 | from data.collation import get_text_semantic_token_collater 22 | from data.semantic_dataset import TextTokenizer 23 | from modules.s2a_model import Pheme 24 | from modules.vocoder import VocoderType 25 | 26 | # How many times one token can be generated 27 | MAX_TOKEN_COUNT = 100 28 | 29 | logging.basicConfig(level=logging.DEBUG) 30 | device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" 31 | 32 | 33 | def parse_arguments(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--text", type=str, 37 | default="I gotta say, I would never expect that to happen!" 
38 | ) 39 | parser.add_argument( 40 | "--manifest_path", type=str, default="demo/manifest.json") 41 | parser.add_argument("--outputdir", type=str, default="demo/") 42 | parser.add_argument("--featuredir", type=str, default="demo/") 43 | parser.add_argument( 44 | "--text_tokens_file", type=str, 45 | default="ckpt/unique_text_tokens.k2symbols" 46 | ) 47 | parser.add_argument("--t2s_path", type=str, default="ckpt/t2s/") 48 | parser.add_argument( 49 | "--s2a_path", type=str, default="ckpt/s2a/s2a.ckpt") 50 | 51 | parser.add_argument("--target_sample_rate", type=int, default=16_000) 52 | 53 | parser.add_argument("--temperature", type=float, default=0.7) 54 | parser.add_argument("--top_k", type=int, default=210) 55 | parser.add_argument("--voice", type=str, default="male_voice") 56 | 57 | return parser.parse_args() 58 | 59 | 60 | class PhemeClient(): 61 | def __init__(self, args): 62 | self.args = args 63 | self.outputdir = args.outputdir 64 | self.target_sample_rate = args.target_sample_rate 65 | self.featuredir = Path(args.featuredir).expanduser() 66 | self.collater = get_text_semantic_token_collater(args.text_tokens_file) 67 | self.phonemizer = TextTokenizer() 68 | 69 | self.load_manifest(args.manifest_path) 70 | 71 | # T2S model 72 | self.t2s = T5ForConditionalGeneration.from_pretrained(args.t2s_path) 73 | self.t2s.to(device) 74 | self.t2s.eval() 75 | 76 | # S2A model 77 | self.s2a = Pheme.load_from_checkpoint(args.s2a_path) 78 | self.s2a.to(device=device) 79 | self.s2a.eval() 80 | 81 | # Vocoder 82 | vocoder = VocoderType["SPEECHTOKENIZER"].get_vocoder(None, None) 83 | self.vocoder = vocoder.to(device) 84 | self.vocoder.eval() 85 | 86 | self.spkr_embedding = Inference( 87 | "pyannote/embedding", 88 | window="whole", 89 | use_auth_token=os.environ["HUGGING_FACE_HUB_TOKEN"], 90 | ) 91 | 92 | def load_manifest(self, input_path): 93 | input_file = {} 94 | with open(input_path, "rb") as f: 95 | for line in f: 96 | temp = json.loads(line) 97 | input_file[temp["audio_filepath"].split(".wav")[0]] = temp 98 | self.input_file = input_file 99 | 100 | def lazy_decode(self, decoder_output, symbol_table): 101 | semantic_tokens = map(lambda x: symbol_table[x], decoder_output) 102 | semantic_tokens = [int(x) for x in semantic_tokens if x.isdigit()] 103 | 104 | return np.array(semantic_tokens) 105 | 106 | def infer_text(self, text, voice, sampling_config): 107 | semantic_prompt = np.load(self.args.featuredir + "/audios-speech-tokenizer/semantic/" + f"{voice}.npy") # noqa 108 | phones_seq = self.phonemizer(text)[0] 109 | input_ids = self.collater([phones_seq]) 110 | input_ids = input_ids.type(torch.IntTensor).to(device) 111 | 112 | labels = [str(lbl) for lbl in semantic_prompt] 113 | labels = self.collater([labels])[:, :-1] 114 | decoder_input_ids = labels.to(device).long() 115 | logging.debug(f"decoder_input_ids: {decoder_input_ids}") 116 | 117 | counts = 1E10 118 | while (counts > MAX_TOKEN_COUNT): 119 | output_ids = self.t2s.generate( 120 | input_ids, decoder_input_ids=decoder_input_ids, 121 | generation_config=sampling_config).sequences 122 | 123 | # check repetitiveness 124 | _, counts = torch.unique_consecutive(output_ids, return_counts=True) 125 | counts = max(counts).item() 126 | 127 | output_semantic = self.lazy_decode( 128 | output_ids[0], self.collater.idx2token) 129 | 130 | # remove the prompt 131 | return output_semantic[len(semantic_prompt):].reshape(1, -1) 132 | 133 | def _load_speaker_emb(self, element_id_prompt): 134 | wav, _ = sf.read(self.featuredir / element_id_prompt) 135 | audio = 
normalize(wav) * 0.95 136 | speaker_emb = self.spkr_embedding( 137 | { 138 | "waveform": torch.FloatTensor(audio).unsqueeze(0), 139 | "sample_rate": self.target_sample_rate 140 | } 141 | ).reshape(1, -1) 142 | 143 | return speaker_emb 144 | 145 | def _load_prompt(self, prompt_file_path): 146 | element_id_prompt = Path(prompt_file_path).stem 147 | acoustic_path_prompt = self.featuredir / "audios-speech-tokenizer/acoustic" / f"{element_id_prompt}.npy" # noqa 148 | semantic_path_prompt = self.featuredir / "audios-speech-tokenizer/semantic" / f"{element_id_prompt}.npy" # noqa 149 | 150 | acoustic_prompt = np.load(acoustic_path_prompt).squeeze().T 151 | semantic_prompt = np.load(semantic_path_prompt)[None] 152 | 153 | return acoustic_prompt, semantic_prompt 154 | 155 | def infer_acoustic(self, output_semantic, prompt_file_path): 156 | semantic_tokens = output_semantic.reshape(1, -1) 157 | acoustic_tokens = np.full( 158 | [semantic_tokens.shape[1], 7], fill_value=c.PAD) 159 | 160 | acoustic_prompt, semantic_prompt = self._load_prompt(prompt_file_path) # noqa 161 | 162 | # Prepend prompt 163 | acoustic_tokens = np.concatenate( 164 | [acoustic_prompt, acoustic_tokens], axis=0) 165 | semantic_tokens = np.concatenate([ 166 | semantic_prompt, semantic_tokens], axis=1) 167 | 168 | # Add speaker 169 | acoustic_tokens = np.pad( 170 | acoustic_tokens, [[1, 0], [0, 0]], constant_values=c.SPKR_1) 171 | semantic_tokens = np.pad( 172 | semantic_tokens, [[0,0], [1, 0]], constant_values=c.SPKR_1) 173 | 174 | speaker_emb = None 175 | if self.s2a.hp.use_spkr_emb: 176 | speaker_emb = self._load_speaker_emb(prompt_file_path) 177 | speaker_emb = np.repeat( 178 | speaker_emb, semantic_tokens.shape[1], axis=0) 179 | speaker_emb = torch.from_numpy(speaker_emb).to(device) 180 | else: 181 | speaker_emb = None 182 | 183 | acoustic_tokens = torch.from_numpy( 184 | acoustic_tokens).unsqueeze(0).to(device).long() 185 | semantic_tokens = torch.from_numpy(semantic_tokens).to(device).long() 186 | start_t = torch.tensor( 187 | [acoustic_prompt.shape[0]], dtype=torch.long, device=device) 188 | length = torch.tensor([ 189 | semantic_tokens.shape[1]], dtype=torch.long, device=device) 190 | 191 | codes = self.s2a.model.inference( 192 | acoustic_tokens, 193 | semantic_tokens, 194 | start_t=start_t, 195 | length=length, 196 | maskgit_inference=True, 197 | speaker_emb=speaker_emb 198 | ) 199 | 200 | # Remove the prompt 201 | synth_codes = codes[:, :, start_t:] 202 | synth_codes = rearrange(synth_codes, "b c t -> c b t") 203 | 204 | return synth_codes 205 | 206 | def generate_audio(self, text, voice, sampling_config, prompt_file_path): 207 | start_time = time.time() 208 | output_semantic = self.infer_text( 209 | text, voice, sampling_config 210 | ) 211 | logging.debug(f"semantic_tokens: {time.time() - start_time}") 212 | 213 | start_time = time.time() 214 | codes = self.infer_acoustic(output_semantic, prompt_file_path) 215 | logging.debug(f"acoustic_tokens: {time.time() - start_time}") 216 | 217 | start_time = time.time() 218 | audio_array = self.vocoder.decode(codes) 219 | audio_array = rearrange(audio_array, "1 1 T -> T").cpu().numpy() 220 | logging.debug(f"vocoder time: {time.time() - start_time}") 221 | 222 | return audio_array 223 | 224 | @torch.no_grad() 225 | def infer( 226 | self, text, voice="male_voice", temperature=0.7, 227 | top_k=210, max_new_tokens=750, 228 | ): 229 | sampling_config = GenerationConfig.from_pretrained( 230 | self.args.t2s_path, 231 | top_k=top_k, 232 | num_beams=1, 233 | do_sample=True, 234 | 
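            # Sampling settings for the T2S decoder; infer_text re-runs
            # generation whenever any token repeats more than MAX_TOKEN_COUNT
            # times in a row (a simple loop-detection heuristic).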
temperature=temperature, 235 | num_return_sequences=1, 236 | max_new_tokens=max_new_tokens, 237 | return_dict_in_generate=True, 238 | output_scores=True 239 | ) 240 | 241 | voice_data = self.input_file[voice] 242 | prompt_file_path = voice_data["audio_prompt_filepath"] 243 | text = voice_data["text"] + " " + text 244 | 245 | audio_array = self.generate_audio( 246 | text, voice, sampling_config, prompt_file_path) 247 | 248 | return audio_array 249 | 250 | 251 | if __name__ == "__main__": 252 | args = parse_arguments() 253 | args.outputdir = Path(args.outputdir).expanduser() 254 | args.outputdir.mkdir(parents=True, exist_ok=True) 255 | args.manifest_path = Path(args.manifest_path).expanduser() 256 | 257 | client = PhemeClient(args) 258 | audio_array = client.infer(args.text, voice=args.voice) 259 | sf.write(os.path.join( 260 | args.outputdir, f"{args.voice}.wav"), audio_array, 261 | args.target_sample_rate 262 | ) 263 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Copyright PolyAI Limited.""" 2 | import logging 3 | import pdb 4 | import sys 5 | import traceback 6 | from functools import wraps 7 | from time import time 8 | from typing import List 9 | 10 | import torch 11 | 12 | from .symbol_table import SymbolTable 13 | 14 | 15 | def load_checkpoint(ckpt_path: str) -> dict: 16 | """ 17 | Loads checkpoint, while matching phone embedding size. 18 | """ 19 | state_dict: dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] 20 | new_state_dict = dict() 21 | for p_name in state_dict.keys(): 22 | if p_name.startswith("vocoder"): 23 | continue 24 | 25 | new_state_dict[p_name] = state_dict[p_name] 26 | 27 | return new_state_dict 28 | 29 | 30 | def breakpoint_on_error(fn): 31 | """Creates a breakpoint on error 32 | 33 | Use as a wrapper 34 | 35 | Args: 36 | fn: the function 37 | 38 | Returns: 39 | inner function 40 | """ 41 | 42 | def inner(*args, **kwargs): 43 | try: 44 | return fn(*args, **kwargs) 45 | except Exception: 46 | """Standard python way of creating a breakpoint on error""" 47 | extype, value, tb = sys.exc_info() 48 | print(f"extype={extype},\nvalue={value}") 49 | traceback.print_exc() 50 | pdb.post_mortem(tb) 51 | 52 | return inner 53 | 54 | 55 | def measure_duration(f): 56 | @wraps(f) 57 | def wrap(*args, **kw): 58 | ts = time() 59 | result = f(*args, **kw) 60 | te = time() 61 | logging.debug("func:%r took: %2.4f sec" % (f.__name__, te - ts)) 62 | return result 63 | 64 | return wrap 65 | 66 | 67 | def split_metapath(in_paths: List[str]): 68 | other_paths = [] 69 | 70 | for itm_path in in_paths: 71 | other_paths.append(itm_path) 72 | 73 | return other_paths 74 | -------------------------------------------------------------------------------- /utils/data_prep.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import orjson 6 | import soundfile as sf 7 | from torchaudio.datasets import LJSPEECH 8 | from tqdm import tqdm 9 | 10 | from data.semantic_dataset import TextTokenizer 11 | 12 | def split_and_write_manifests(args): 13 | data_root = args.data_root 14 | dataset = LJSPEECH(data_root, download=True) 15 | np.random.seed(42) 16 | dataset_idxs = np.arange(start=0, stop=len(dataset)) 17 | np.random.shuffle(dataset_idxs) 18 | test_idxs, val_idxs, train_idxs = ( 19 | dataset_idxs[:300], 20 | dataset_idxs[300:600], 21 | 
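        # Seeded shuffle, then a fixed 300/300/rest split of LJSpeech into
        # test/dev/train manifests.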
dataset_idxs[600:], 22 | ) 23 | 24 | print(f"{len(test_idxs)=}") 25 | print(f"{len(val_idxs)=}") 26 | print(f"{len(train_idxs)=}") 27 | dataset_items = dataset._flist 28 | test_data, val_data, train_data = dict(), dict(), dict() 29 | phonemizer = TextTokenizer() 30 | for idx, itm in tqdm(enumerate(dataset_items)): 31 | file_id, raw_text, text = itm 32 | file_id = file_id + ".wav" 33 | wav_path = dataset._path / file_id 34 | wav_obj = sf.SoundFile(wav_path) 35 | duration = wav_obj.frames / wav_obj.samplerate 36 | 37 | phones = phonemizer(text)[0] 38 | phones = "|".join(phones) 39 | 40 | datapoint = { 41 | file_id: { 42 | "text": text, 43 | "raw-text": raw_text, 44 | "duration": duration, 45 | "phoneme": phones, 46 | } 47 | } 48 | if idx in test_idxs: 49 | test_data.update(datapoint) 50 | elif idx in val_idxs: 51 | val_data.update(datapoint) 52 | elif idx in train_idxs: 53 | train_data.update(datapoint) 54 | 55 | test_manifest_path = data_root / "test.json" 56 | val_manifest_path = data_root / "dev.json" 57 | train_manifest_path = data_root / "train.json" 58 | 59 | with open(test_manifest_path, "wb") as f: 60 | f.write(orjson.dumps(test_data, option=orjson.OPT_INDENT_2)) 61 | 62 | with open(val_manifest_path, "wb") as f: 63 | f.write(orjson.dumps(val_data, option=orjson.OPT_INDENT_2)) 64 | 65 | with open(train_manifest_path, "wb") as f: 66 | f.write(orjson.dumps(train_data, option=orjson.OPT_INDENT_2)) 67 | 68 | 69 | def main(): 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument("--data_root", type=Path, default="./datasets/ljspeech-training-data") 72 | args = parser.parse_args() 73 | args.data_root.mkdir(exist_ok=True) 74 | 75 | split_and_write_manifests(args) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /utils/get_tokens_speech_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Get tokens using the SpeechTokenizer. 2 | 3 | Apply SpeechTokenizer to extract acoustic and semantic tokens. 4 | The tokens will be extracted to 5 | encoding_output/acoustic and encoding_output/semantic. 6 | 7 | python utils/get_tokens_speech_tokenizer.py \ 8 | --config_path ckpt/speechtokenizer/config.json \ 9 | --ckpt_path ckpt/speechtokenizer/SpeechTokenizer.pt \ 10 | --encoding_input datasets/example/audios \ 11 | --encoding_output datasets/example/audios-speech-tokenizer 12 | 13 | Copyright PolyAI Limited. 
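Note: main() shards the input files across torch.cuda.device_count() GPUs, with four worker processes per GPU, so at least one visible CUDA device is assumed; on a CPU-only machine the chunk-size computation would divide by zero.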
14 | """ 15 | 16 | import argparse 17 | import logging 18 | import multiprocessing 19 | import os 20 | import pathlib 21 | from concurrent.futures import ProcessPoolExecutor, as_completed 22 | 23 | import torch 24 | from modules.speech_tokenizer import SpeechTokenizer 25 | from tqdm import tqdm 26 | 27 | from utils import measure_duration 28 | 29 | PROJECT_ROOT = str(pathlib.Path(__file__).parent.parent.resolve()) 30 | logging.basicConfig(level=logging.DEBUG) 31 | 32 | 33 | @measure_duration 34 | def main(args): 35 | n_gpus = torch.cuda.device_count() 36 | n_workers = n_gpus * 4 37 | filenames = os.listdir(args.encoding_input) 38 | chunk_size = (len(filenames) + n_workers - 1) // n_workers 39 | futures = [] 40 | with ProcessPoolExecutor() as executor: 41 | for idx in range(n_workers): 42 | device = torch.device(f"cuda:{idx%n_gpus}") 43 | _filenames = filenames[idx * chunk_size : (idx + 1) * chunk_size] 44 | futures.append(executor.submit(tokenize, _filenames, device, args)) 45 | 46 | for f in as_completed(futures): 47 | f.result() 48 | 49 | 50 | def tokenize(filenames, device, args): 51 | 52 | tokenizer = SpeechTokenizer( 53 | config_path=args.config_path, ckpt_path=args.ckpt_path, device=device 54 | ) 55 | for filename in tqdm(filenames): 56 | tokenizer.encode_file( 57 | folder_path=args.encoding_input, 58 | destination_folder=args.encoding_output, 59 | filename=filename, 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | multiprocessing.set_start_method("spawn") 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument( 67 | "--config_path", 68 | type=str, 69 | help="Path to the SpeechTokenizer config", 70 | default=PROJECT_ROOT + "/ckpt/speechtokenizer/config.json", 71 | ) 72 | parser.add_argument( 73 | "--ckpt_path", 74 | type=str, 75 | help="Path to the SpeechTokenizer checkpoint", 76 | default=PROJECT_ROOT + "/ckpt/speechtokenizer/SpeechTokenizer.pt", 77 | ) 78 | parser.add_argument( 79 | "--encoding_input", 80 | type=str, 81 | help="Path to the input folder for encoding", 82 | default=PROJECT_ROOT + "/datasets/example/audios", 83 | ) 84 | parser.add_argument( 85 | "--encoding_output", 86 | type=str, 87 | help="Path where to save the encoded tokens", 88 | default=PROJECT_ROOT + "/datasets/example/audios-speech-tokenizer", 89 | ) 90 | 91 | args = parser.parse_args() 92 | print("Parsed args") 93 | print(args) 94 | 95 | main(args) 96 | -------------------------------------------------------------------------------- /utils/symbol_table.py: -------------------------------------------------------------------------------- 1 | """Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | 15 | Copyright PolyAI Limited. 16 | """ 17 | from dataclasses import dataclass, field 18 | from typing import Dict, Generic, List, Optional, TypeVar, Union 19 | 20 | Symbol = TypeVar('Symbol') 21 | 22 | 23 | # Disable __repr__ otherwise it could freeze e.g. Jupyter. 
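# (A generated __repr__ would format every entry of _id2sym and _sym2id, which can be very large for real vocabularies.)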
24 | @dataclass(repr=False) 25 | class SymbolTable(Generic[Symbol]): 26 | '''SymbolTable that maps symbol IDs, found on the FSA arcs, to 27 | actual objects. These objects can be arbitrary Python objects 28 | that can serve as keys in a dictionary (i.e. they need to be 29 | hashable and immutable). 30 | 31 | The SymbolTable can only be written to/read from disk if the 32 | symbols are strings. 33 | ''' 34 | _id2sym: Dict[int, Symbol] = field(default_factory=dict) 35 | '''Map an integer to a symbol. 36 | ''' 37 | 38 | _sym2id: Dict[Symbol, int] = field(default_factory=dict) 39 | '''Map a symbol to an integer. 40 | ''' 41 | 42 | _next_available_id: int = 1 43 | '''A helper internal field that helps adding new symbols 44 | to the table efficiently. 45 | ''' 46 | 47 | eps: Symbol = '' 48 | '''Null symbol, always mapped to index 0. 49 | ''' 50 | 51 | def __post_init__(self): 52 | for idx, sym in self._id2sym.items(): 53 | assert self._sym2id[sym] == idx 54 | assert idx >= 0 55 | 56 | for sym, idx in self._sym2id.items(): 57 | assert idx >= 0 58 | assert self._id2sym[idx] == sym 59 | 60 | if 0 not in self._id2sym: 61 | self._id2sym[0] = self.eps 62 | self._sym2id[self.eps] = 0 63 | else: 64 | assert self._id2sym[0] == self.eps 65 | assert self._sym2id[self.eps] == 0 66 | 67 | self._next_available_id = max(self._id2sym) + 1 68 | 69 | @staticmethod 70 | def from_str(s: str) -> 'SymbolTable': 71 | '''Build a symbol table from a string. 72 | 73 | The string consists of lines. Every line has two fields separated 74 | by space(s), tab(s) or both. The first field is the symbol and the 75 | second the integer id of the symbol. 76 | 77 | Args: 78 | s: 79 | The input string with the format described above. 80 | Returns: 81 | An instance of :class:`SymbolTable`. 82 | ''' 83 | id2sym: Dict[int, str] = dict() 84 | sym2id: Dict[str, int] = dict() 85 | 86 | for line in s.split('\n'): 87 | fields = line.split() 88 | if len(fields) == 0: 89 | continue # skip empty lines 90 | assert len(fields) == 2, \ 91 | f'Expect a line with 2 fields. Given: {len(fields)}' 92 | sym, idx = fields[0], int(fields[1]) 93 | assert sym not in sym2id, f'Duplicated symbol {sym}' 94 | assert idx not in id2sym, f'Duplicated id {idx}' 95 | id2sym[idx] = sym 96 | sym2id[sym] = idx 97 | 98 | eps = id2sym.get(0, '') 99 | 100 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps) 101 | 102 | @staticmethod 103 | def from_file(filename: str) -> 'SymbolTable': 104 | '''Build a symbol table from file. 105 | 106 | Every line in the symbol table file has two fields separated by 107 | space(s), tab(s) or both. The following is an example file: 108 | 109 | .. code-block:: 110 | 111 | 0 112 | a 1 113 | b 2 114 | c 3 115 | 116 | Args: 117 | filename: 118 | Name of the symbol table file. Its format is documented above. 119 | 120 | Returns: 121 | An instance of :class:`SymbolTable`. 122 | 123 | ''' 124 | with open(filename, 'r', encoding='utf-8') as f: 125 | return SymbolTable.from_str(f.read().strip()) 126 | 127 | def to_str(self) -> str: 128 | ''' 129 | Returns: 130 | Return a string representation of this object. You can pass 131 | it to the method ``from_str`` to recreate an identical object. 132 | ''' 133 | s = '' 134 | for idx, symbol in sorted(self._id2sym.items()): 135 | s += f'{symbol} {idx}\n' 136 | return s 137 | 138 | def to_file(self, filename: str): 139 | '''Serialize the SymbolTable to a file. 140 | 141 | Every line in the symbol table file has two fields separated by 142 | space(s), tab(s) or both. 
The following is an example file: 143 | 144 | .. code-block:: 145 | 146 | 0 147 | a 1 148 | b 2 149 | c 3 150 | 151 | Args: 152 | filename: 153 | Name of the symbol table file. Its format is documented above. 154 | ''' 155 | with open(filename, 'w') as f: 156 | for idx, symbol in sorted(self._id2sym.items()): 157 | print(symbol, idx, file=f) 158 | 159 | def add(self, symbol: Symbol, index: Optional[int] = None) -> int: 160 | '''Add a new symbol to the SymbolTable. 161 | 162 | Args: 163 | symbol: 164 | The symbol to be added. 165 | index: 166 | Optional int id to which the symbol should be assigned. 167 | If it is not available, a ValueError will be raised. 168 | 169 | Returns: 170 | The int id to which the symbol has been assigned. 171 | ''' 172 | # Already in the table? Return its ID. 173 | if symbol in self._sym2id: 174 | return self._sym2id[symbol] 175 | # Specific ID not provided - use next available. 176 | if index is None: 177 | index = self._next_available_id 178 | # Specific ID provided but not available. 179 | if index in self._id2sym: 180 | raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - " 181 | f"already occupied by {self._id2sym[index]}") 182 | self._sym2id[symbol] = index 183 | self._id2sym[index] = symbol 184 | 185 | # Update next available ID if needed 186 | if self._next_available_id <= index: 187 | self._next_available_id = index + 1 188 | 189 | return index 190 | 191 | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]: 192 | '''Get a symbol for an id or get an id for a symbol 193 | 194 | Args: 195 | k: 196 | If it is an id, it tries to find the symbol corresponding 197 | to the id; if it is a symbol, it tries to find the id 198 | corresponding to the symbol. 199 | 200 | Returns: 201 | An id or a symbol depending on the given `k`. 202 | ''' 203 | if isinstance(k, int): 204 | return self._id2sym[k] 205 | else: 206 | return self._sym2id[k] 207 | 208 | def merge(self, other: 'SymbolTable') -> 'SymbolTable': 209 | '''Create a union of two SymbolTables. 210 | Raises an AssertionError if the same IDs are occupied by 211 | different symbols. 212 | 213 | Args: 214 | other: 215 | A symbol table to merge with ``self``. 216 | 217 | Returns: 218 | A new symbol table. 
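        Example (a minimal sketch of the behaviour described above; the
        symbols are illustrative)::

            a = SymbolTable()
            a.add('foo')            # assigned id 1 (id 0 is reserved for eps)
            a.add('bar')            # assigned id 2
            b = SymbolTable()
            b.add('bar', index=2)
            b.add('baz')            # assigned id 3
            merged = a.merge(b)     # symbols: '', 'bar', 'baz', 'foo'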
219 | ''' 220 | self._check_compatible(other) 221 | 222 | id2sym = {**self._id2sym, **other._id2sym} 223 | sym2id = {**self._sym2id, **other._sym2id} 224 | 225 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps) 226 | 227 | def _check_compatible(self, other: 'SymbolTable') -> None: 228 | # Epsilon compatibility 229 | assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \ 230 | f'{self.eps} != {other.eps}' 231 | # IDs compatibility 232 | common_ids = set(self._id2sym).intersection(other._id2sym) 233 | for idx in common_ids: 234 | assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \ 235 | f'self[idx] = "{self[idx]}", ' \ 236 | f'other[idx] = "{other[idx]}"' 237 | # Symbols compatibility 238 | common_symbols = set(self._sym2id).intersection(other._sym2id) 239 | for sym in common_symbols: 240 | assert self[sym] == other[sym], f'Symbol conflict for symbol: {sym}, ' \ 241 | f'self[sym] = "{self[sym]}", ' \ 242 | f'other[sym] = "{other[sym]}"' 243 | 244 | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]: 245 | return self.get(item) 246 | 247 | def __contains__(self, item: Union[int, Symbol]) -> bool: 248 | if isinstance(item, int): 249 | return item in self._id2sym 250 | else: 251 | return item in self._sym2id 252 | 253 | def __len__(self) -> int: 254 | return len(self._id2sym) 255 | 256 | def __eq__(self, other: 'SymbolTable') -> bool: 257 | if len(self) != len(other): 258 | return False 259 | 260 | for s in self.symbols: 261 | if self[s] != other[s]: 262 | return False 263 | 264 | return True 265 | 266 | @property 267 | def ids(self) -> List[int]: 268 | '''Returns a list of integer IDs corresponding to the symbols. 269 | ''' 270 | ans = list(self._id2sym.keys()) 271 | ans.sort() 272 | return ans 273 | 274 | @property 275 | def symbols(self) -> List[Symbol]: 276 | '''Returns a list of symbols (e.g., strings) corresponding to 277 | the integer IDs. 278 | ''' 279 | ans = list(self._sym2id.keys()) 280 | ans.sort() 281 | return ans 282 | --------------------------------------------------------------------------------
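A short end-to-end sketch of the SymbolTable API above (illustrative only; it assumes the repository root is on PYTHONPATH, and the phone-like symbols are made up):

    from utils.symbol_table import SymbolTable

    # Build a table from text; id 0 is reserved for the epsilon/null symbol.
    table = SymbolTable.from_str("<eps> 0\nAH 1\nB 2")
    assert table["AH"] == 1 and table[2] == "B"     # symbol -> id and id -> symbol
    assert "B" in table and 0 in table

    table.add("CH")                                 # gets the next free id, 3
    restored = SymbolTable.from_str(table.to_str())
    assert restored == table                        # text round trip preserves the mapping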