├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE-CODE ├── LICENSE-MODEL ├── README.md ├── RealEdit.txt ├── cog.yaml ├── config.py ├── data ├── __init__.py ├── gigaspeech.py ├── phonemize_encodec_encode_hf.py └── tokenizer.py ├── demo ├── 5895_34622_000026_000002.wav ├── 84_121550_000074_000000.wav ├── pam.wav └── temp │ ├── 84_121550_000074_000000.txt │ └── mfa_alignments │ ├── 5895_34622_000026_000002.csv │ └── 84_121550_000074_000000.csv ├── edit_utils.py ├── environment.yml ├── gradio_app.ipynb ├── gradio_app.py ├── gradio_requirements.txt ├── inference_speech_editing.ipynb ├── inference_speech_editing_scale.py ├── inference_tts.ipynb ├── inference_tts_scale.py ├── main.py ├── models ├── codebooks_patterns.py ├── modules │ ├── __init__.py │ ├── activation.py │ ├── embedding.py │ ├── sampling.py │ ├── scaling.py │ ├── transformer.py │ └── utils.py └── voicecraft.py ├── predict.py ├── pretrained_models └── .gitkeep ├── start-jupyter.bat ├── start-jupyter.sh ├── steps ├── __init__.py ├── optim.py ├── trainer.py └── trainer_utils.py ├── tts_demo.py ├── voicecraft-gradio-colab.ipynb └── z_scripts ├── e830M.sh └── e830M_ft.sh /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | 8 | thumbs.db 9 | .DS_Store 10 | .idea 11 | *.log 12 | *.pdf 13 | *.mkv 14 | *.mp4 15 | *.png 16 | *.wav 17 | *.mp3 18 | *.pth 19 | *.th 20 | *.json 21 | 22 | *durip* 23 | *rtx* 24 | *l40* 25 | *a40* 26 | 27 | src/audiocraft 28 | 29 | !/demo/ 30 | !/demo/* 31 | /demo/temp/*.txt 32 | !/demo/temp/84_121550_000074_000000.txt 33 | .cog/tmp/* -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM jupyter/base-notebook:python-3.9.13 2 | 3 | USER root 4 | 5 | # Install OS dependencies 6 | RUN apt-get update && apt-get install -y git-core ffmpeg espeak-ng && \ 7 | apt-get clean && \ 8 | rm -rf /var/lib/apt/lists/* 9 | 10 | # Update Conda, create the voicecraft environment, and install dependencies 11 | RUN conda update -y -n base -c conda-forge conda && \ 12 | conda create -y -n voicecraft python=3.9.16 && \ 13 | conda run -n voicecraft conda install -y -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068 && \ 14 | conda run -n voicecraft mfa model download dictionary english_us_arpa && \ 15 | conda run -n voicecraft mfa model download acoustic english_us_arpa && \ 16 | conda run -n voicecraft pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft && \ 17 | conda run -n voicecraft pip install xformers==0.0.22 && \ 18 | conda run -n voicecraft pip install torch==2.0.1 && \ 19 | conda run -n voicecraft pip install torchaudio==2.0.2 && \ 20 | conda run -n voicecraft pip install tensorboard==2.16.2 
&& \ 21 | conda run -n voicecraft pip install phonemizer==3.2.1 && \ 22 | conda run -n voicecraft pip install datasets==2.16.0 && \ 23 | conda run -n voicecraft pip install torchmetrics==0.11.1 && \ 24 | conda run -n voicecraft pip install huggingface_hub==0.22.2 25 | 26 | 27 | # Install the Jupyter kernel 28 | RUN conda install -n voicecraft ipykernel --update-deps --force-reinstall -y && \ 29 | conda run -n voicecraft python -m ipykernel install --name=voicecraft -------------------------------------------------------------------------------- /LICENSE-CODE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 
50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. 
Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. 
Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. 
You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. 
Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the “Licensor.” The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 436 | 437 | Creative Commons may be contacted at creativecommons.org. 438 | -------------------------------------------------------------------------------- /LICENSE-MODEL: -------------------------------------------------------------------------------- 1 | Coqui Public Model License 1.0.0 2 | https://coqui.ai/cpml.txt 3 | 4 | This license allows only non-commercial use of a machine learning model and its outputs. 5 | 6 | Acceptance 7 | In order to get any license under these terms, you must agree to them as both strict obligations and conditions to all your licenses. 8 | 9 | Licenses 10 | The licensor grants you a copyright license to do everything you might do with the model that would otherwise infringe the licensor's copyright in it, for any non-commercial purpose. The licensor grants you a patent license that covers patent claims the licensor can license, or becomes able to license, that you would infringe by using the model in the form provided by the licensor, for any non-commercial purpose. 11 | 12 | Non-commercial Purpose 13 | Non-commercial purposes include any of the following uses of the model or its output, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output. 14 | 15 | Personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, amateur pursuits, or religious observance. 16 | Use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development. Use of the model to train other models for commercial use is not a non-commercial purpose. 17 | Use by any charitable organization for charitable purposes, or for testing or evaluation. Use for revenue-generating activity, including projects directly funded by government grants, is not a non-commercial purpose. 18 | Notices 19 | You must ensure that anyone who gets a copy of any part of the model, or any modification of the model, or their output, from you also gets a copy of these terms or the URL for them above. 
20 | 21 | No Other Rights 22 | These terms do not allow you to sublicense or transfer any of your licenses to anyone else, or prevent the licensor from granting licenses to anyone else. These terms do not imply any other licenses. 23 | 24 | Patent Defense 25 | If you make any written claim that the model infringes or contributes to infringement of any patent, your licenses for the model granted under these terms ends immediately. If your company makes such a claim, your patent license ends immediately for work on behalf of your company. 26 | 27 | Violations 28 | The first time you are notified in writing that you have violated any of these terms, or done anything with the model or its output that is not covered by your licenses, your licenses can nonetheless continue if you come into full compliance with these terms, and take practical steps to correct past violations, within 30 days of receiving notice. Otherwise, all your licenses end immediately. 29 | 30 | No Liability 31 | AS FAR AS THE LAW ALLOWS, THE MODEL AND ITS OUTPUT COME AS IS, WITHOUT ANY WARRANTY OR CONDITION, AND THE LICENSOR WILL NOT BE LIABLE TO YOU FOR ANY DAMAGES ARISING OUT OF THESE TERMS OR THE USE OR NATURE OF THE MODEL OR ITS OUTPUT, UNDER ANY KIND OF LEGAL CLAIM. IF THIS PROVISION IS NOT ENFORCEABLE IN YOUR JURISDICTION, YOUR LICENSES ARE VOID. 32 | 33 | Definitions 34 | The licensor is the individual or entity offering these terms, and the model is the model the licensor makes available under these terms, including any documentation or similar information about the model. 35 | 36 | You refers to the individual or entity agreeing to these terms. 37 | 38 | Your company is any legal entity, sole proprietorship, or other kind of organization that you work for, plus all organizations that have control over, are under the control of, or are under common control with that organization. Control means ownership of substantially all the assets of an entity, or the power to direct its management and policies by vote, contract, or otherwise. Control can be direct or indirect. 39 | 40 | Your licenses are all the licenses granted to you under these terms. 41 | 42 | Use means anything you do with the model or its output requiring one of your licenses. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild 2 | [![Paper](https://img.shields.io/badge/arXiv-2403.16973-brightgreen.svg?style=flat-square)](https://arxiv.org/pdf/2403.16973.pdf) [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IOjpglQyMTO2C3Y94LD9FY0Ocn-RJRg6?usp=sharing) [![Replicate](https://replicate.com/cjwbw/voicecraft/badge)](https://replicate.com/cjwbw/voicecraft) [![YouTube demo](https://img.shields.io/youtube/views/eikybOi8iwU)](https://youtu.be/eikybOi8iwU) [![Demo page](https://img.shields.io/badge/Audio_Samples-blue?logo=Github&style=flat-square)](https://jasonppy.github.io/VoiceCraft_web/) 3 | 4 | 5 | ### TL;DR 6 | VoiceCraft is a token infilling neural codec language model, that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data including audiobooks, internet videos, and podcasts. 
7 | 8 | To clone or edit an unseen voice, VoiceCraft needs only a few seconds of reference. 9 | 10 | ## How to run inference 11 | There are four ways (besides running Gradio in Colab): 12 | 13 | 1. More flexible inference beyond the Gradio UI in Google Colab. See [quickstart colab](#quickstart-colab) 14 | 2. With docker. See [quickstart docker](#quickstart-docker) 15 | 3. Without docker. See [environment setup](#environment-setup). You can also run Gradio locally if you choose this option 16 | 4. As a standalone script that you can easily integrate into other projects. 17 | See [quickstart command line](#quickstart-command-line). 18 | 19 | When you are inside the docker image or have installed all dependencies, check out [`inference_tts.ipynb`](./inference_tts.ipynb). 20 | 21 | If you want to do model development such as training/finetuning, I recommend following [environment setup](#environment-setup) and [training](#training). 22 | 23 | ## News 24 | :star: 03/15/2025: Changed inference sampling from topp=1 to topk=40, which massively improves editing and TTS performance 25 | 26 | :star: 04/22/2024: 330M/830M TTS Enhanced Models are up [here](https://huggingface.co/pyp1); load them through [`gradio_app.py`](./gradio_app.py) or [`inference_tts.ipynb`](./inference_tts.ipynb)! Replicate demo is up, major thanks to [@chenxwh](https://github.com/chenxwh)! 27 | 28 | :star: 04/11/2024: VoiceCraft Gradio is now available on HuggingFace Spaces [here](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio)! Major thanks to [@zuev-stepan](https://github.com/zuev-stepan), [@Sewlell](https://github.com/Sewlell), [@pgsoar](https://github.com/pgosar), [@Ph0rk0z](https://github.com/Ph0rk0z). 29 | 30 | :star: 04/05/2024: I finetuned giga330M with the TTS objective on gigaspeech and 1/5 of librilight. Weights are [here](https://huggingface.co/pyp1/VoiceCraft/tree/main). Make sure maximal prompt + generation length <= 16 seconds (due to our limited compute, we had to drop utterances longer than 16s in training data). Even stronger models forthcoming, stay tuned! 31 | 32 | :star: 03/28/2024: Model weights for giga330M and giga830M are up on HuggingFace🤗 [here](https://huggingface.co/pyp1/VoiceCraft/tree/main)! 33 | 34 | ## TODO 35 | - [x] Codebase upload 36 | - [x] Environment setup 37 | - [x] Inference demo for speech editing and TTS 38 | - [x] Training guidance 39 | - [x] RealEdit dataset and training manifest 40 | - [x] Model weights 41 | - [x] Better guidance on training/finetuning 42 | - [x] Colab notebooks 43 | - [x] HuggingFace Spaces demo 44 | - [x] Command line 45 | - [ ] Improve efficiency 46 | 47 | ## QuickStart Colab 48 | 49 | :star: To try out speech editing or TTS inference with VoiceCraft, the simplest way is to use Google Colab. 50 | Instructions to run are on the Colab itself. 51 | 52 | 1. To try [Speech Editing](https://colab.research.google.com/drive/1FV7EC36dl8UioePY1xXijXTMl7X47kR_?usp=sharing) 53 | 2. To try [TTS Inference](https://colab.research.google.com/drive/1lch_6it5-JpXgAQlUTRRI2z2_rk5K67Z?usp=sharing) 54 | 55 | ## QuickStart Command Line 56 | 57 | :star: To use it as a standalone script, check out tts_demo.py and speech_editing_demo.py. 58 | Be sure to first [set up your environment](#environment-setup). 59 | Without arguments, they will run the standard demo arguments used as an example elsewhere 60 | in this repository. You can use the command line arguments to specify unique input audios, 61 | target transcripts, and inference hyperparameters. 
Run the help command for more information: 62 | `python3 tts_demo.py -h` 63 | 64 | ## QuickStart Docker 65 | :star: To try out TTS inference with VoiceCraft, you can also use docker. Thanks to [@ubergarm](https://github.com/ubergarm) and [@jayc88](https://github.com/jay-c88) for making this happen. 66 | 67 | Tested on Linux and Windows and should work with any host with docker installed. 68 | ```bash 69 | # 1. clone the repo in a directory on a drive with plenty of free space 70 | git clone git@github.com:jasonppy/VoiceCraft.git 71 | cd VoiceCraft 72 | 73 | # 2. assumes you have docker installed with nvidia container-toolkit (windows has this built into the driver) 74 | # https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/1.13.5/install-guide.html 75 | # sudo apt-get install -y nvidia-container-toolkit-base || yay -Syu nvidia-container-toolkit || echo etc... 76 | 77 | # 3. First build the docker image 78 | docker build --tag "voicecraft" . 79 | 80 | # 4. Try to start an existing container, otherwise create a new one passing in all GPUs 81 | ./start-jupyter.sh # linux 82 | start-jupyter.bat # windows 83 | 84 | # 5. now open a webpage on the host box to the URL shown at the bottom of: 85 | docker logs jupyter 86 | 87 | # 6. optionally look inside from another terminal 88 | docker exec -it jupyter /bin/bash 89 | export USER=(your_linux_username_used_above) 90 | export HOME=/home/$USER 91 | sudo apt-get update 92 | 93 | # 7. confirm video card(s) are visible inside container 94 | nvidia-smi 95 | 96 | # 8. Now in browser, open inference_tts.ipynb and work through one cell at a time 97 | echo GOOD LUCK 98 | ``` 99 | 100 | ## Environment setup 101 | ```bash 102 | conda create -n voicecraft python=3.9.16 103 | conda activate voicecraft 104 | 105 | pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft 106 | pip install xformers==0.0.22 107 | pip install torchaudio==2.0.2 torch==2.0.1 # this assumes your system is compatible with CUDA 11.7, otherwise check out https://pytorch.org/get-started/previous-versions/#v201 108 | apt-get install ffmpeg # if you don't already have ffmpeg installed 109 | apt-get install espeak-ng # backend for the phonemizer installed below 110 | pip install tensorboard==2.16.2 111 | pip install phonemizer==3.2.1 112 | pip install datasets==2.16.0 113 | pip install torchmetrics==0.11.1 114 | pip install huggingface_hub==0.22.2 115 | # install MFA for forced alignment, this could take a few minutes 116 | conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068 117 | # install MFA english dictionary and model 118 | mfa model download dictionary english_us_arpa 119 | mfa model download acoustic english_us_arpa 120 | # pip install huggingface_hub 121 | # conda install pocl # the above gives a warning about installing pocl, not sure if it's really needed 122 | 123 | # to run ipynb 124 | conda install -n voicecraft ipykernel --no-deps --force-reinstall 125 | ``` 126 | 127 | If you encounter version issues when running things, check out [environment.yml](./environment.yml) for exact version matching. 
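If the setup completed without errors, a quick sanity check like the following should run cleanly. This is a minimal, illustrative sketch (not a file in the repo); run it inside the activated `voicecraft` environment. `CompressionSolver` is the same audiocraft entry point that `data/phonemize_encodec_encode_hf.py` uses later for encodec encoding.

```python
# Minimal environment sanity check (illustrative only, not part of the repo).
import torch
import torchaudio
import phonemizer  # noqa: F401  # needs the espeak-ng backend installed above
from audiocraft.solvers import CompressionSolver  # noqa: F401  # used later for encodec encoding

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torchaudio:", torchaudio.__version__)
```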
128 | 129 | ## Inference Examples 130 | Check out [`inference_speech_editing.ipynb`](./inference_speech_editing.ipynb) and [`inference_tts.ipynb`](./inference_tts.ipynb) 131 | 132 | ## Gradio 133 | ### Run in colab 134 | 135 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IOjpglQyMTO2C3Y94LD9FY0Ocn-RJRg6?usp=sharing) 136 | 137 | ### Run locally 138 | After environment setup, install additional dependencies: 139 | ```bash 140 | apt-get install -y espeak espeak-data libespeak1 libespeak-dev 141 | apt-get install -y festival* 142 | apt-get install -y build-essential 143 | apt-get install -y flac libasound2-dev libsndfile1-dev vorbis-tools 144 | apt-get install -y libxml2-dev libxslt-dev zlib1g-dev 145 | pip install -r gradio_requirements.txt 146 | ``` 147 | 148 | Run the gradio server from the terminal or [`gradio_app.ipynb`](./gradio_app.ipynb): 149 | ```bash 150 | python gradio_app.py 151 | ``` 152 | It is ready to use at the [default url](http://127.0.0.1:7860). 153 | 154 | ### How to use it 155 | 1. (optionally) Select models 156 | 2. Load models 157 | 3. Transcribe 158 | 4. (optionally) Tweak some parameters 159 | 5. Run 160 | 6. (optionally) Rerun part-by-part in Long TTS mode 161 | 162 | ### Some features 163 | Smart transcript: write only what you want to generate 164 | 165 | TTS mode: Zero-shot TTS 166 | 167 | Edit mode: Speech editing 168 | 169 | Long TTS mode: Easy TTS on long texts 170 | 171 | 172 | ## Training 173 | To train a VoiceCraft model, you need to prepare the following parts: 174 | 1. utterances and their transcripts 175 | 2. encode the utterances into codes using e.g. Encodec 176 | 3. convert transcripts into phoneme sequences, and a phoneme set (we name it vocab.txt) 177 | 4. manifest (i.e. metadata) 178 | 179 | Steps 1, 2, and 3 are handled in [./data/phonemize_encodec_encode_hf.py](./data/phonemize_encodec_encode_hf.py), where 180 | 1. Gigaspeech is downloaded through HuggingFace. Note that you need to sign an agreement in order to download the dataset (it needs your auth token) 181 | 2. phoneme sequences and encodec codes are also extracted using the script. 182 | 183 | An example run: 184 | 185 | ```bash 186 | conda activate voicecraft 187 | export CUDA_VISIBLE_DEVICES=0 188 | cd ./data 189 | python phonemize_encodec_encode_hf.py \ 190 | --dataset_size xs \ 191 | --download_to path/to/store_huggingface_downloads \ 192 | --save_dir path/to/store_extracted_codes_and_phonemes \ 193 | --encodec_model_path path/to/encodec_model \ 194 | --mega_batch_size 120 \ 195 | --batch_size 32 \ 196 | --max_len 30000 197 | ``` 198 | where encodec_model_path is available [here](https://huggingface.co/pyp1/VoiceCraft). This model is trained on Gigaspeech XL; it has 56M parameters and 4 codebooks, each with 2048 codes. Details are described in our [paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf). If you encounter OOM during extraction, try decreasing the batch_size and/or max_len. 199 | The extracted codes, phonemes, and vocab.txt will be stored at `path/to/store_extracted_codes_and_phonemes/${dataset_size}/{encodec_16khz_4codebooks,phonemes,vocab.txt}`. 200 | 201 | As for the manifest, please download train.txt and validation.txt from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main), and put them under `path/to/store_extracted_codes_and_phonemes/manifest/`. 
Please also download vocab.txt from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main) if you want to use our pretrained VoiceCraft model (so that the phoneme-to-token matching is the same). 202 | 203 | Now, you are good to start training! 204 | 205 | ```bash 206 | conda activate voicecraft 207 | cd ./z_scripts 208 | bash e830M.sh 209 | ``` 210 | 211 | The same procedure applies to preparing your own custom dataset. 212 | 213 | ## Finetuning 214 | You also need to do steps 1-4 as in Training, and I recommend using AdamW for optimization when finetuning a pretrained model, for better stability. Check out the script `./z_scripts/e830M_ft.sh`. 215 | 216 | If your dataset introduces new phonemes (which is very likely) that don't exist in the giga checkpoint, make sure you combine the original phonemes with the phonemes from your data when constructing the vocab. You also need to adjust `--text_vocab_size` and `--text_pad_token` so that the former is bigger than or equal to your vocab size, and the latter has the same value as `--text_vocab_size` (i.e. `--text_pad_token` is always the last token). Also, since the text embedding is now of a different size, make sure you modify the weight-loading part so that it won't crash (you could skip loading `text_embedding` or only load the existing part, and randomly initialize the new rows; a hedged sketch of this partial loading is given at the end of this README). 217 | 218 | ## License 219 | The codebase is under CC BY-NC-SA 4.0 ([LICENSE-CODE](./LICENSE-CODE)), and the model weights are under Coqui Public Model License 1.0.0 ([LICENSE-MODEL](./LICENSE-MODEL)). Note that we use some code from other repositories that are under different licenses: `./models/codebooks_patterns.py` is under the MIT license; `./models/modules`, `./steps/optim.py`, `data/tokenizer.py` are under the Apache License, Version 2.0; the phonemizer we used is under the GNU GPL 3.0 license. 220 | 221 | ## Acknowledgement 222 | We thank Feiteng for his [VALL-E reproduction](https://github.com/lifeiteng/vall-e), and we thank the audiocraft team for open-sourcing [encodec](https://github.com/facebookresearch/audiocraft). 223 | 224 | ## Citation 225 | ``` 226 | @article{peng2024voicecraft, 227 | author = {Peng, Puyuan and Huang, Po-Yao and Mohamed, Abdelrahman and Harwath, David}, 228 | title = {VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild}, 229 | journal = {arXiv}, 230 | year = {2024}, 231 | } 232 | ``` 233 | 234 | ## Disclaimer 235 | Any organization or individual is prohibited from using any technology mentioned in this paper to generate or edit someone's speech without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws. 
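To complement the Finetuning section above, here is a minimal, hedged sketch of the partial weight loading it describes: copy the embedding rows that exist in the giga checkpoint and leave the rows for newly added phonemes at their random initialization. It assumes a plain PyTorch `state_dict`-style checkpoint and a parameter name containing `text_embedding`; adapt the key names and checkpoint nesting to the actual VoiceCraft model and checkpoint you load.

```python
# Hypothetical helper, not part of the repo: partially load a pretrained checkpoint
# when --text_vocab_size has grown because your data introduces new phonemes.
import torch

def load_with_larger_text_vocab(model, ckpt_path):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state = ckpt.get("model", ckpt)  # some checkpoints nest the weights under "model"
    own = model.state_dict()
    for name, tensor in state.items():
        if name not in own:
            continue
        if "text_embedding" in name and own[name].shape != tensor.shape:
            # copy rows for the phonemes that existed in the original vocab,
            # keep the extra rows at their random initialization
            n_shared = min(own[name].shape[0], tensor.shape[0])
            own[name][:n_shared] = tensor[:n_shared]
        elif own[name].shape == tensor.shape:
            own[name] = tensor
    model.load_state_dict(own)
    return model
```

Usage would look like `model = load_with_larger_text_vocab(model, "path/to/pretrained.pth")` (the path is a placeholder), called once before training starts.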
236 | 237 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | system_packages: 7 | - libgl1-mesa-glx 8 | - libglib2.0-0 9 | - ffmpeg 10 | - espeak-ng 11 | python_version: "3.11" 12 | python_packages: 13 | - torch==2.1.0 14 | - torchaudio==2.1.0 15 | - xformers 16 | - phonemizer==3.2.1 17 | - whisperx==3.1.1 18 | - openai-whisper>=20231117 19 | run: 20 | - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft 21 | - pip install "pydantic<2.0.0" 22 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 23 | - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" 24 | predict: "predict.py:Predictor" 25 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def MyParser(): 5 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 6 | # general training 7 | parser.add_argument("--seed", type=int, default=1) 8 | parser.add_argument("--precision", type=str, default="float16") 9 | parser.add_argument("--num_workers", type=int, default=8) 10 | parser.add_argument("--resume", action="store_true", default=False) 11 | parser.add_argument("--tb_write_every_n_steps", type=int, default=100) 12 | parser.add_argument("--print_every_n_steps", type=int, default=400) 13 | parser.add_argument("--val_every_n_steps", type=int, default=800) 14 | parser.add_argument("--lr", type=float, default=0.05) 15 | parser.add_argument("--batch_size", type=int, default=100, help="this is the effective batch size, no matter whether using gradient_accumulation_steps, not used if we specified max_num_tokens") 16 | parser.add_argument("--max_num_tokens", type=int, default=100000, help="max number of encodec tokens per gpu, this is only used when using dynamic batching, will ignore batch size. Note this is the final effective batch size per GPU, i.e. gradient accumulated batch size per gpu") 17 | parser.add_argument("--val_max_num_tokens", type=int, default=None, help="FOR validation") 18 | parser.add_argument("--num_buckets", type=int, default=6, help='used for dynamic batching, bucketing the samples based on the number of tokens') 19 | parser.add_argument("--dynamic_batching", type=int, default=0) 20 | parser.add_argument("--weight_decay", type=float, default=1e-2) 21 | parser.add_argument("--warmup_fraction", type=float, default=0.01, help="use linear warmup, the proportion of the training steps that are used for warming up") 22 | parser.add_argument("--num_epochs", type=int, default=10) 23 | parser.add_argument("--num_steps", type=int, default=None, help="if not None, will ignore n_epochs and use num_steps as the total number of amount of training, can try e.g. 400000 i.e. 
400k steps") 24 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 25 | parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="the value for torch.nn.utils.clip_grad_norm_(), not used if we use ScaledAdam optimizer") 26 | parser.add_argument("--early_stop_step", type=int, default=3200, help="stop training after this many steps of non-improvement") 27 | parser.add_argument("--early_stop_threshold", type=float, default=-1.0, help="early stop after the improvement is below this threshold for certain number of steps") 28 | 29 | # optimizer focused 30 | parser.add_argument("--optimizer_name", type=str, default="AdamW", help="can also use ScaledAdam, in which case we'll also use the Eden scheduler") 31 | parser.add_argument("--reduce_lr_start_step", type=int, default=3000, help='after which significantly reduce the lr. a param for the eden optimizer') 32 | parser.add_argument("--pseudo_epoch_size", type=int, default=3000, help="only use for Eden scheduler.") 33 | parser.add_argument("--reduce_lr_start_epoch", type=int, default=4) 34 | parser.add_argument("--clipping_update_period", type=int, default=600) 35 | 36 | 37 | # path 38 | parser.add_argument("--exp_dir", type=str, default=None, help="will be combined with dataset name") 39 | parser.add_argument("--dataset", type=str, help="e.g. 'libritts', 'gigaspeech', they are folder name in the data dir also") 40 | parser.add_argument("--dataset_dir", type=str, help="need to be compatible with corresponding dataset py file") 41 | parser.add_argument("--phn_folder_name", type=str, default="phonemes", help="for libritts I also have arpa phns, in which case should be phonemes_arpa") 42 | parser.add_argument("--encodec_folder_name", type=str, default="encodec_16khz_4codebooks", help="folder where encodec codes are stored") 43 | parser.add_argument("--manifest_name", type=str, default="manifest", help="metadata filename") 44 | 45 | # data focused 46 | parser.add_argument("--pad_x", type=int, default=1, help="whether or not always pad x to have text_max_length. select 1 to get the maximal memory consumption, but the actual case should be smaller, better to have it being 0") 47 | parser.add_argument("--audio_max_length", type=float, default=20, help="in second, crop or drop the audio is length is longer than this") 48 | parser.add_argument("--audio_min_length", type=float, default=2, help="in second, drop the audio if length is shorter than this") 49 | parser.add_argument("--text_max_length", type=int, default=400, help='if too long, we crop or drop') 50 | parser.add_argument("--text_min_length", type=float, default=10, help="if too short, will drop") 51 | parser.add_argument("--encodec_sr", type=int, default=50, help="for my encodec that takes 16kHz audio with a downsample rate of 320, the codec sample rate is 50Hz, i.e. 
50 codes (x n_codebooks) per second") 52 | parser.add_argument("--drop_long", type=int, default=0, help="if this is true, will drop examples whose encodec sequence or phone sequence is too long, rather than cropping, to reduce hallucination") 53 | 54 | # encodec and token rearrangement 55 | parser.add_argument('--mask_len_min', type=int, default=1, help='Minimum mask length') 56 | parser.add_argument('--mask_len_max', type=int, default=600, help='Maximum mask length') 57 | parser.add_argument("--eos", type=int, default=-1, help="this is to be used with reduced_eog, where we end the utterance with eos, and end the generated segment with eog, also when this is used, the n_special should be 4") 58 | parser.add_argument("--reduced_eog", type=int, default=0, help="for the non-final segments, do not insert eog at the end, this could hopefully solve the early stopping issue when doing tts") 59 | parser.add_argument("--special_first", type=int, default=0, help="if 1, special tokens need to be the first few tokens, e.g. 0, 1, 2, which means we need to adjust the preprocessing and postprocessing of the encodec codes. note that we hard-coded 3 special tokens") 60 | parser.add_argument("--n_special", type=int, default=3, help="empty, eog, pad, (eos)") 61 | parser.add_argument("--codebook_weight", type=str, default=None, help="e.g. ['5','1','0.5','0.1']") 62 | parser.add_argument("--max_mask_portion",type=float,default=0.7,help="should mask an utterance for more than this portion") 63 | parser.add_argument("--max_n_spans", type=int, default=3, help='maximal number of spans, only used when using multicm3, this is used to decide the number of mask_embedding, and the max clamp value if using a Poisson distribution; if using a uniform distribution to sample the number of spans, it will be uniform(1,max_n_spans)') 64 | parser.add_argument("--shuffle_mask_embedding", type=int, default=0, help="whether to shuffle the mask embedding, so that mask:0 is not the most well trained, default is not shuffling. The default has its benefit, as it makes sure that mask:0 always appears first") 65 | parser.add_argument("--mask_sample_dist", type=str, default="poisson1", help="uniform or poissonx, e.g. 
poisson1, meaning the parameter lambda is 1, it will most likely sample 1 mask") 66 | parser.add_argument("--min_gap", type=int, default=5, help="after starts are sampled, delete the later one if it is closer to the former start than the min_gap") 67 | parser.add_argument('--n_codebooks', type=int, default=4) 68 | parser.add_argument('--text_vocab_size', type=int, default=100, help='Size of text vocabulary') 69 | parser.add_argument('--text_pad_token', type=int, default=100, help='padding of the text tokens, not attended') 70 | parser.add_argument('--audio_vocab_size', type=str, default='2048', help="Size of audio vocabulary") 71 | parser.add_argument("--empty_token", default=2048, type=int, help="indicating no token at the position for the codebook") 72 | parser.add_argument('--eog', type=int, default=2049, help='End of generation token') 73 | parser.add_argument('--audio_pad_token', type=int, default=2050, help='padding of the encodec codes, not attended') 74 | 75 | # model focused 76 | parser.add_argument('--d_model', type=int, default=2048, help='Model dimension') 77 | parser.add_argument('--audio_embedding_dim', type=int, default=2048, help='dimension for encodec continuous embedding (before being quantized)') 78 | parser.add_argument('--text_embedding_dropout', type=float, default=0.1, help='Dropout for text embedding') 79 | parser.add_argument('--audio_embedding_dropout', type=float, default=0, help='Dropout for audio embedding') 80 | parser.add_argument('--text_positional_embedding_dropout', type=float, default=0.1, help='Dropout for text positional embedding') 81 | parser.add_argument('--audio_positional_embedding_dropout', type=float, default=0.1, help='Dropout for audio positional embedding') 82 | parser.add_argument('--trm_dropout', type=float, default=0.1, help='Dropout for transformer') 83 | parser.add_argument('--nhead', type=int, default=16, help='Number of attention heads') 84 | parser.add_argument('--num_decoder_layers', type=int, default=16, help='Number of decoder layers') 85 | parser.add_argument('--load_model_from', type=str, default=None, help='Path to load model from, this takes effect last, so it will overwrite all previous loads, including resume') 86 | return parser -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonppy/VoiceCraft/a702dfd2ced6d4fd6b04bdc160c832c6efc8f6c5/data/__init__.py -------------------------------------------------------------------------------- /data/gigaspeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import copy 5 | import logging 6 | import shutil 7 | 8 | class dataset(torch.utils.data.Dataset): 9 | def __init__(self, args, split): 10 | super().__init__() 11 | self.args = args 12 | self.split = split 13 | assert self.split in ['train', 'validation', 'test'] 14 | manifest_fn = os.path.join(self.args.dataset_dir, self.args.manifest_name, self.split+".txt") 15 | 16 | with open(manifest_fn, "r") as rf: 17 | data = [l.strip().split("\t") for l in rf.readlines()] 18 | lengths_list = [int(item[-1]) for item in data] 19 | self.data = [] 20 | self.lengths_list = [] 21 | for d, l in zip(data, lengths_list): 22 | if l >= self.args.encodec_sr*self.args.audio_min_length: 23 | if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length: 24 | continue 25 | self.data.append(d) 26 
| self.lengths_list.append(l) 27 | logging.info(f"number of data points for {self.split} split: {len(self.lengths_list)}") 28 | 29 | # phoneme vocabulary 30 | vocab_fn = os.path.join(self.args.dataset_dir,"vocab.txt") 31 | shutil.copy(vocab_fn, os.path.join(self.args.exp_dir, "vocab.txt")) 32 | with open(vocab_fn, "r") as f: 33 | temp = [l.strip().split(" ") for l in f.readlines() if len(l) != 0] 34 | self.phn2num = {item[1]:int(item[0]) for item in temp} 35 | 36 | self.symbol_set = set(["", "", "", ""]) 37 | 38 | def __len__(self): 39 | return len(self.lengths_list) 40 | 41 | def _load_phn_enc(self, index): 42 | item = self.data[index] 43 | pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt") 44 | ef = os.path.join(self.args.dataset_dir, self.args.encodec_folder_name, item[1]+".txt") 45 | try: 46 | with open(pf, "r") as p, open(ef, "r") as e: 47 | phns = [l.strip() for l in p.readlines()] 48 | assert len(phns) == 1, phns 49 | x = [self.phn2num[item] for item in phns[0].split(" ") if item not in self.symbol_set] # drop ["", "", "", ""], as they are not in training set annotation 50 | encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks] 51 | 52 | assert len(encos) == self.args.n_codebooks, ef 53 | if self.args.special_first: 54 | y = [[int(n)+self.args.n_special for n in l] for l in encos] 55 | else: 56 | y = [[int(n) for n in l] for l in encos] 57 | except Exception as e: 58 | logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted") 59 | logging.info(f"error message: {e}") 60 | return [], [[]] 61 | 62 | return x, y 63 | 64 | def __getitem__(self, index): 65 | x, y = self._load_phn_enc(index) 66 | x_len, y_len = len(x), len(y[0]) 67 | 68 | if x_len == 0 or y_len == 0: 69 | return { 70 | "x": None, 71 | "x_len": None, 72 | "y": None, 73 | "y_len": None, 74 | "y_mask_interval": None, # index y_mask_interval[1] is the position of start_of_continue token 75 | "extra_mask_start": None # this is only used in VE1 76 | } 77 | while y_len < self.args.encodec_sr*self.args.audio_min_length: 78 | assert not self.args.dynamic_batching 79 | index = random.choice(range(len(self))) # regenerate an index 80 | x, y = self._load_phn_enc(index) 81 | x_len, y_len = len(x), len(y[0]) 82 | if self.args.drop_long: 83 | while x_len > self.args.text_max_length or y_len > self.args.encodec_sr*self.args.audio_max_length: 84 | index = random.choice(range(len(self))) # regenerate an index 85 | x, y = self._load_phn_enc(index) 86 | x_len, y_len = len(x), len(y[0]) 87 | 88 | ### padding and cropping below ### 89 | ### padding and cropping below ### 90 | # adjust the length of encodec codes, pad to max_len or randomly crop 91 | orig_y_len = copy.copy(y_len) 92 | max_len = int(self.args.audio_max_length * self.args.encodec_sr) 93 | if y_len > max_len: 94 | audio_start = random.choice(range(0, y_len-max_len)) 95 | for i in range(len(y)): 96 | y[i] = y[i][audio_start:(audio_start+max_len)] 97 | y_len = max_len 98 | else: 99 | audio_start = 0 100 | if not self.args.dynamic_batching: 101 | pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len) 102 | for i in range(len(y)): 103 | y[i] = y[i] + pad 104 | 105 | # adjust text 106 | # if audio is cropped, and text is longer than max, crop max based on how audio is cropped 107 | if audio_start > 0 and len(x) > self.args.text_max_length: # if audio is longer than max and text is long than max, start text the way audio 
started 108 | x = x[int(len(x)*audio_start/orig_y_len):] 109 | if len(x) > self.args.text_max_length: # if text is still longer than max, cut the end 110 | x = x[:self.args.text_max_length] 111 | 112 | x_len = len(x) 113 | if x_len > self.args.text_max_length: 114 | text_start = random.choice(range(0, x_len - self.args.text_max_length)) 115 | x = x[text_start:text_start+self.args.text_max_length] 116 | x_len = self.args.text_max_length 117 | elif self.args.pad_x and x_len <= self.args.text_max_length: 118 | pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len) 119 | x = x + pad 120 | ### padding and cropping above ### 121 | ### padding and cropping above ### 122 | 123 | return { 124 | "x": torch.LongTensor(x), 125 | "x_len": x_len, 126 | "y": torch.LongTensor(y), 127 | "y_len": y_len 128 | } 129 | 130 | 131 | def collate(self, batch): 132 | out = {key:[] for key in batch[0]} 133 | for item in batch: 134 | if item['x'] is None: # deal with load failure 135 | continue 136 | for key, val in item.items(): 137 | out[key].append(val) 138 | res = {} 139 | if self.args.pad_x: 140 | res["x"] = torch.stack(out["x"], dim=0) 141 | else: 142 | res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token) 143 | res["x_lens"] = torch.LongTensor(out["x_len"]) 144 | if self.args.dynamic_batching: 145 | if out['y'][0].ndim==2: 146 | res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token) 147 | res['y'] = res['y'].permute(1,2,0) # T B K -> B K T 148 | else: 149 | assert out['y'][0].ndim==1, out['y'][0].shape 150 | res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token) 151 | else: 152 | res['y'] = torch.stack(out['y'], dim=0) 153 | res["y_lens"] = torch.LongTensor(out["y_len"]) 154 | res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1) 155 | res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1) 156 | return res -------------------------------------------------------------------------------- /data/phonemize_encodec_encode_hf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | def parse_args(): 3 | parser = argparse.ArgumentParser(description="encode the gigaspeech dataset using encodec model") 4 | parser.add_argument("--dataset_size", type=str, default='xs', help='sizes of gigaspeech, xs, s, m, l, xl.
we use xl for VoiceCraft training, xs is good for debugging') 5 | parser.add_argument('--download_to', type=str, default="/data/scratch/pyp/datasets/gigaspeech_debug", help="dir where you want the huggingface gigaspeech dataset to be downloaded to") 6 | parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest_debug", help="path to the manifest, phonemes, and encodec codes dirs") 7 | parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th") 8 | parser.add_argument('--n_workers', type=int, default=4, help="Number of parallel worker processes") 9 | parser.add_argument('--mega_batch_size', type=int, default=100, help="Number of samples in each mega batch for multiprocess dataloading") 10 | parser.add_argument('--batch_size', type=int, default=4, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus") 11 | parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate') 12 | parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate') 13 | parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate') 14 | parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number') 15 | parser.add_argument('--max_len', type=int, default=30000, help='max length of audio in samples, if exceed, will cut a batch into half to process, decrease this number if OOM on your machine') 16 | return parser.parse_args() 17 | if __name__ == "__main__": 18 | import logging 19 | formatter = ( 20 | "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" 21 | ) 22 | logging.basicConfig(format=formatter, level=logging.INFO) 23 | args = parse_args() 24 | 25 | import os 26 | import numpy as np 27 | import torch 28 | import tqdm 29 | import time 30 | from datasets import load_dataset, DownloadConfig 31 | 32 | from tokenizer import TextTokenizer, tokenize_text 33 | 34 | # get the path 35 | phn_save_root = os.path.join(args.save_dir, args.dataset_size, "phonemes") 36 | codes_save_root = os.path.join(args.save_dir, args.dataset_size, "encodec_16khz_4codebooks") 37 | vocab_fn = os.path.join(args.save_dir, args.dataset_size, "vocab.txt") 38 | os.makedirs(phn_save_root, exist_ok=True) 39 | os.makedirs(codes_save_root, exist_ok=True) 40 | 41 | 42 | def sort_by_audio_len(lens): 43 | inds = np.argsort(lens).tolist() 44 | logging.info(f"longest: {lens[inds[-1]]*args.model_code_sr} encodec codes, {lens[inds[-1]]:.2f} sec.") 45 | logging.info(f"shortest: {lens[inds[0]]*args.model_code_sr} encodec codes, {lens[inds[0]]:.2f} sec.") 46 | logging.info(f"median: {lens[inds[len(inds)//2]]*args.model_code_sr} encodec codes, {lens[inds[len(inds)//2]]:.2f} sec.") 47 | logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]*args.model_code_sr} encodec codes, {lens[inds[int(len(inds)*0.95)]]:.2f} sec.") 48 | return inds[::-1] 49 | 50 | def write_array_to_txt_file(array, filename): 51 | with open(filename, 'w') as f: 52 | for a in array[:-1]: 53 | f.write(' '.join(map(str, a))+'\n') 54 | f.write(' '.join(map(str, array[-1]))) 55 | 56 | 57 | ### phonemization 58 | # load tokenizer 59 | # load the encodec model 60 | from audiocraft.solvers import CompressionSolver 61 | model = CompressionSolver.model_from_checkpoint(args.encodec_model_path) 62 | model 
= model.cuda() 63 | model = model.eval() 64 | text_tokenizer = TextTokenizer() 65 | 66 | 67 | # https://github.com/SpeechColab/GigaSpeech 68 | # there are only four different punctuation marks 69 | # need to check whether there are other strings starting with < 70 | punc2sym = {" <COMMA>": ",", " <PERIOD>": ".", " <QUESTIONMARK>": "?", " <EXCLAMATIONPOINT>": "!"} # note the space in front of each punc name 71 | gar2sym = {"<SIL>": "#%#", "<MUSIC>": "##%", "<NOISE>": "%%#", "<OTHER>":"%#%"} # so that they are safely kept as recognizable symbols when using tokenize_text, and can be mapped back afterwards 72 | punc2sym.update(gar2sym) 73 | 74 | word2sym = { "h æ ʃ h ɐ ʃ p ɚ s ɛ n t": "<MUSIC>", "h æ ʃ p ɚ s ɛ n t h æ ʃ": "<SIL>", "p ɚ s ɛ n t h ɐ ʃ p ɚ s ɛ n t": "<OTHER>", "p ɚ s ɛ n t p ɚ s ɛ n t h æ ʃ": "<NOISE>"} 75 | forbidden_words = set(['#%#', '##%', '%%#', '%#%']) 76 | 77 | dc = DownloadConfig(cache_dir=args.download_to) 78 | stime = time.time() 79 | logging.info("loading the dataset...") 80 | gs = load_dataset("speechcolab/gigaspeech", args.dataset_size, use_auth_token=True, cache_dir = args.download_to, download_config=dc) 81 | logging.info(f"time spent loading the dataset: {time.time() - stime:.2f} seconds") 82 | 83 | splits = ['validation', 'test', 'train'] 84 | 85 | logging.info(f"gigaspeech dataset {args.dataset_size} info: {gs}") 86 | logging.info(f"phonemizing...") 87 | phn_vocab = set() 88 | all_lens = [] 89 | 90 | # you will see a ton of [WARNING] words_mismatch.py:88......, it's not an issue 91 | for split in tqdm.tqdm(splits): 92 | skip = 0 93 | logging.info(f"now processing split {split}...") 94 | for item in tqdm.tqdm(gs[split]): 95 | save_fn = os.path.join(phn_save_root, item['segment_id']+".txt") 96 | text = item['text'] 97 | if sum(word in forbidden_words for word in text.split(" ")): 98 | logging.info(f"skip {item['segment_id']}, because it contains forbidden words. Its transcript: {text}") 99 | skip += 1 100 | continue 101 | for k, v in punc2sym.items(): 102 | text = text.replace(k, v) 103 | phn = tokenize_text(text_tokenizer, text) 104 | phn_seq = " ".join(phn) 105 | for k, v in word2sym.items(): 106 | phn_seq = phn_seq.replace(k, v) 107 | phn_vocab.update(phn_seq.split(" ")) 108 | all_lens.append(len(phn_seq.split(" "))) 109 | with open(save_fn, "w") as f: 110 | f.write(phn_seq) 111 | logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbidden words") 112 | 113 | print(f"phn vocab size: {len(list(phn_vocab))}") 114 | print("phn sequence stats: ") 115 | print(f"longest: {max(all_lens)}") 116 | print(f"shortest: {min(all_lens)}") 117 | print(f"median: {np.quantile(all_lens, 0.5)}") 118 | print(f"95 percentile longest: {np.quantile(all_lens, 0.95)}") 119 | print("write vocabulary to ", vocab_fn) 120 | with open(vocab_fn, "w") as f: 121 | for i, phn in enumerate(list(phn_vocab)): 122 | if i < len(list(phn_vocab)) - 1: 123 | f.write(f"{str(i)} {phn}\n") 124 | else: 125 | f.write(f"{str(i)} {phn}") 126 | 127 | class mydataset(torch.utils.data.Dataset): 128 | def __init__(self, split): 129 | super().__init__() 130 | self.data = gs[split] 131 | def __len__(self): 132 | return len(self.data) 133 | def __getitem__(self, ind): 134 | try: 135 | segment_id, audio, sr, text, begin_time, end_time = self.data[ind]['segment_id'], torch.from_numpy(self.data[ind]['audio']['array']).float(), self.data[ind]['audio']['sampling_rate'], self.data[ind]['text'], self.data[ind]['begin_time'], self.data[ind]['end_time'] 136 | except: 137 | return None, None, None, None, None, None 138 | 139 | return segment_id, audio, sr, text, begin_time, end_time 140 | def collate(self, batch): 141 | res =
{'segment_id': [], "audio": [], "sr": [], "text": [], "begin_time": [], "end_time": []} 142 | for item in batch: 143 | if item[0] != None: 144 | res['segment_id'].append(item[0]) 145 | res['audio'].append(item[1]) 146 | res['sr'].append(item[2]) 147 | res['text'].append(item[3]) 148 | res['begin_time'].append(item[4]) 149 | res['end_time'].append(item[5]) 150 | return res 151 | 152 | 153 | ## encodec codes extraction 154 | logging.info("encodec encoding...") 155 | train_dataset = mydataset('train') 156 | train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate) 157 | validation_dataset = mydataset('validation') 158 | validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate) 159 | test_dataset = mydataset('test') 160 | test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate) 161 | splits = ['validation', 'test', 'train'] 162 | loaders = [validation_loader, test_loader, train_loader] 163 | # splits = ['validation'] # for debug 164 | # loaders = [validation_loader] 165 | for split, loader in zip(splits, loaders): 166 | skip = 0 167 | logging.info(f"now processing split {split}...") 168 | mega_n_steps = int(np.ceil(len(gs[split]) / args.mega_batch_size)) 169 | logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {args.mega_batch_size} samples") 170 | for m, mega_batch in enumerate(loader): 171 | logging.info(f"====================================") 172 | logging.info(f"====================================") 173 | logging.info(f"now processing mega step {m+1}/{mega_n_steps}") 174 | lengths = np.array(mega_batch['end_time']) - np.array(mega_batch['begin_time']) 175 | sorted_inds = sort_by_audio_len(lengths) 176 | for j in range(len(sorted_inds))[::-1]: 177 | if lengths[sorted_inds[j]] < 0.2 or lengths[sorted_inds[j]] > args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s) 178 | skip += 1 179 | del sorted_inds[j] 180 | 181 | n_steps = int(np.ceil(len(sorted_inds) / args.batch_size)) 182 | for n in tqdm.tqdm(range(n_steps), disable=True): 183 | inds_used = sorted_inds[n*args.batch_size:(n+1)*args.batch_size] 184 | audio_batch = [mega_batch['audio'][id] for id in inds_used] 185 | sr_batch = [mega_batch['sr'][id] for id in inds_used] 186 | segment_id_batch = [mega_batch['segment_id'][id] for id in inds_used] 187 | text_batch = [mega_batch['text'][id] for id in inds_used] 188 | padded_wav = torch.nn.utils.rnn.pad_sequence(audio_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T] 189 | all_lens = [lengths[id] for id in inds_used] 190 | with torch.no_grad(): 191 | if max(all_lens) > args.max_len and len(all_lens) > 1: # NOTE decrease args.max_len if OOM, or chunk it into more than 2 forward passes 192 | codes = [] 193 | inwav = padded_wav.cuda() 194 | codes.append(model.encode(inwav[:len(inwav)//2])[0].cpu()) 195 | codes.append(model.encode(inwav[len(inwav)//2:])[0].cpu()) 196 | codes = torch.cat(codes, dim=0) 197 | else: 198 | encoded_frames = model.encode(padded_wav.cuda()) 199 | # logging.info(f"encoded_frames: {encoded_frames[0].shape}") 200 | codes = encoded_frames[0].cpu() 201 | 202 | for i, length in 
enumerate(all_lens): 203 | save_fn = os.path.join(codes_save_root, segment_id_batch[i]+".txt") 204 | actual_len = round(length * args.model_code_sr) # 320 is downsample rate for this model 205 | cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist() 206 | write_array_to_txt_file(cur_code, save_fn) 207 | -------------------------------------------------------------------------------- /data/tokenizer.py: -------------------------------------------------------------------------------- 1 | # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import re 17 | from dataclasses import asdict, dataclass 18 | from typing import Any, Dict, List, Optional, Pattern, Union 19 | 20 | import numpy as np 21 | import torch 22 | import torchaudio 23 | # from lhotse.features import FeatureExtractor 24 | # from lhotse.utils import Seconds, compute_num_frames 25 | from phonemizer.backend import EspeakBackend 26 | from phonemizer.backend.espeak.language_switch import LanguageSwitch 27 | from phonemizer.backend.espeak.words_mismatch import WordMismatch 28 | from phonemizer.punctuation import Punctuation 29 | from phonemizer.separator import Separator 30 | 31 | 32 | 33 | class TextTokenizer: 34 | """Phonemize Text.""" 35 | 36 | def __init__( 37 | self, 38 | language="en-us", 39 | backend="espeak", 40 | separator=Separator(word="_", syllable="-", phone="|"), 41 | preserve_punctuation=True, 42 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), 43 | with_stress: bool = False, 44 | tie: Union[bool, str] = False, 45 | language_switch: LanguageSwitch = "keep-flags", 46 | words_mismatch: WordMismatch = "ignore", 47 | ) -> None: 48 | phonemizer = EspeakBackend( 49 | language, 50 | punctuation_marks=punctuation_marks, 51 | preserve_punctuation=preserve_punctuation, 52 | with_stress=with_stress, 53 | tie=tie, 54 | language_switch=language_switch, 55 | words_mismatch=words_mismatch, 56 | ) 57 | 58 | self.backend = phonemizer 59 | self.separator = separator 60 | 61 | def to_list(self, phonemized: str) -> List[str]: 62 | fields = [] 63 | for word in phonemized.split(self.separator.word): 64 | # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. 
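# e.g. for the word chunk "h|ɪ|z." shown above, re.findall(r"\w+|[^\w\s]", ...) gives ['h', '|', 'ɪ', '|', 'z', '.'];
# the comprehension below then drops the phone separator '|' and appends the word separator '_',
# so each phone and each punctuation mark ends up as its own token.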
65 | pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) 66 | fields.extend( 67 | [p for p in pp if p != self.separator.phone] 68 | + [self.separator.word] 69 | ) 70 | assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( 71 | self.separator.phone 72 | ) 73 | return fields[:-1] 74 | 75 | def __call__(self, text, strip=True) -> List[List[str]]: 76 | if isinstance(text, str): 77 | text = [text] 78 | 79 | phonemized = self.backend.phonemize( 80 | text, separator=self.separator, strip=strip, njobs=1 81 | ) 82 | return [self.to_list(p) for p in phonemized] 83 | 84 | 85 | def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: 86 | phonemes = tokenizer([text.strip()]) 87 | return phonemes[0] # k2symbols 88 | 89 | def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int): 90 | assert wav.shape[0] in [1, 2], "Audio must be mono or stereo." 91 | if target_channels == 1: 92 | wav = wav.mean(0, keepdim=True) 93 | elif target_channels == 2: 94 | *shape, _, length = wav.shape 95 | wav = wav.expand(*shape, target_channels, length) 96 | elif wav.shape[0] == 1: 97 | wav = wav.expand(target_channels, -1) 98 | wav = torchaudio.transforms.Resample(sr, target_sr)(wav) 99 | return wav 100 | 101 | class AudioTokenizer: 102 | """EnCodec audio.""" 103 | 104 | def __init__( 105 | self, 106 | device: Any = None, 107 | signature = None 108 | ) -> None: 109 | from audiocraft.solvers import CompressionSolver 110 | model = CompressionSolver.model_from_checkpoint(signature) 111 | self.sample_rate = model.sample_rate 112 | self.channels = model.channels 113 | 114 | if not device: 115 | device = torch.device("cpu") 116 | if torch.cuda.is_available(): 117 | device = torch.device("cuda:0") 118 | 119 | self._device = device 120 | 121 | self.codec = model.to(device) 122 | 123 | @property 124 | def device(self): 125 | return self._device 126 | 127 | def encode(self, wav: torch.Tensor) -> torch.Tensor: 128 | codes = self.codec.encode(wav.to(self.device)) 129 | return [(codes[0], None)] 130 | 131 | def decode(self, frames: torch.Tensor) -> torch.Tensor: 132 | frames = frames[0][0] # [1,4,T] 133 | return self.codec.decode(frames) 134 | 135 | 136 | 137 | def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1): 138 | # Load and pre-process the audio waveform 139 | if offset != -1 and num_frames!=-1: 140 | wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames) 141 | else: 142 | wav, sr = torchaudio.load(audio_path) 143 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) 144 | wav = wav.unsqueeze(0) 145 | 146 | # Extract discrete codes from EnCodec 147 | with torch.no_grad(): 148 | encoded_frames = tokenizer.encode(wav) 149 | return encoded_frames 150 | -------------------------------------------------------------------------------- /demo/5895_34622_000026_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonppy/VoiceCraft/a702dfd2ced6d4fd6b04bdc160c832c6efc8f6c5/demo/5895_34622_000026_000002.wav -------------------------------------------------------------------------------- /demo/84_121550_000074_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonppy/VoiceCraft/a702dfd2ced6d4fd6b04bdc160c832c6efc8f6c5/demo/84_121550_000074_000000.wav -------------------------------------------------------------------------------- /demo/pam.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonppy/VoiceCraft/a702dfd2ced6d4fd6b04bdc160c832c6efc8f6c5/demo/pam.wav -------------------------------------------------------------------------------- /demo/temp/84_121550_000074_000000.txt: -------------------------------------------------------------------------------- 1 | But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks, -------------------------------------------------------------------------------- /demo/temp/mfa_alignments/5895_34622_000026_000002.csv: -------------------------------------------------------------------------------- 1 | Begin,End,Label,Type,Speaker 2 | 0.04,0.58,gwynplaine,words,temp 3 | 0.58,0.94,had,words,temp 4 | 0.94,1.45,besides,words,temp 5 | 1.45,1.62,for,words,temp 6 | 1.62,1.86,his,words,temp 7 | 1.86,2.16,work,words,temp 8 | 2.16,2.31,and,words,temp 9 | 2.31,2.49,for,words,temp 10 | 2.49,2.71,his,words,temp 11 | 2.71,3.03,feats,words,temp 12 | 3.03,3.12,of,words,temp 13 | 3.12,3.61,strength,words,temp 14 | 3.95,4.25,round,words,temp 15 | 4.25,4.45,his,words,temp 16 | 4.45,4.7,neck,words,temp 17 | 4.7,4.81,and,words,temp 18 | 4.81,5.04,over,words,temp 19 | 5.04,5.22,his,words,temp 20 | 5.22,5.83,shoulders,words,temp 21 | 6.16,6.31,an,words,temp 22 | 6.41,7.15,esclavine,words,temp 23 | 7.15,7.29,of,words,temp 24 | 7.29,7.7,leather,words,temp 25 | 0.04,0.1,G,phones,temp 26 | 0.1,0.13,W,phones,temp 27 | 0.13,0.22,IH1,phones,temp 28 | 0.22,0.3,N,phones,temp 29 | 0.3,0.38,P,phones,temp 30 | 0.38,0.42,L,phones,temp 31 | 0.42,0.53,EY1,phones,temp 32 | 0.53,0.58,N,phones,temp 33 | 0.58,0.71,HH,phones,temp 34 | 0.71,0.86,AE1,phones,temp 35 | 0.86,0.94,D,phones,temp 36 | 0.94,0.97,B,phones,temp 37 | 0.97,1.01,IH0,phones,temp 38 | 1.01,1.14,S,phones,temp 39 | 1.14,1.34,AY1,phones,temp 40 | 1.34,1.4,D,phones,temp 41 | 1.4,1.45,Z,phones,temp 42 | 1.45,1.52,F,phones,temp 43 | 1.52,1.55,AO1,phones,temp 44 | 1.55,1.62,R,phones,temp 45 | 1.62,1.69,HH,phones,temp 46 | 1.69,1.76,IH1,phones,temp 47 | 1.76,1.86,Z,phones,temp 48 | 1.86,1.95,W,phones,temp 49 | 1.95,2.07,ER1,phones,temp 50 | 2.07,2.16,K,phones,temp 51 | 2.16,2.23,AH0,phones,temp 52 | 2.23,2.26,N,phones,temp 53 | 2.26,2.31,D,phones,temp 54 | 2.31,2.38,F,phones,temp 55 | 2.38,2.41,AO1,phones,temp 56 | 2.41,2.49,R,phones,temp 57 | 2.49,2.55,HH,phones,temp 58 | 2.55,2.62,IH1,phones,temp 59 | 2.62,2.71,Z,phones,temp 60 | 2.71,2.8,F,phones,temp 61 | 2.8,2.9,IY1,phones,temp 62 | 2.9,2.98,T,phones,temp 63 | 2.98,3.03,S,phones,temp 64 | 3.03,3.07,AH0,phones,temp 65 | 3.07,3.12,V,phones,temp 66 | 3.12,3.2,S,phones,temp 67 | 3.2,3.26,T,phones,temp 68 | 3.26,3.32,R,phones,temp 69 | 3.32,3.39,EH1,phones,temp 70 | 3.39,3.48,NG,phones,temp 71 | 3.48,3.53,K,phones,temp 72 | 3.53,3.61,TH,phones,temp 73 | 3.95,4.03,R,phones,temp 74 | 4.03,4.16,AW1,phones,temp 75 | 4.16,4.21,N,phones,temp 76 | 4.21,4.25,D,phones,temp 77 | 4.25,4.29,HH,phones,temp 78 | 4.29,4.36,IH1,phones,temp 79 | 4.36,4.45,Z,phones,temp 80 | 4.45,4.53,N,phones,temp 81 | 4.53,4.62,EH1,phones,temp 82 | 4.62,4.7,K,phones,temp 83 | 4.7,4.74,AH0,phones,temp 84 | 4.74,4.77,N,phones,temp 85 | 4.77,4.81,D,phones,temp 86 | 4.81,4.92,OW1,phones,temp 87 | 4.92,4.97,V,phones,temp 88 | 4.97,5.04,ER0,phones,temp 89 | 5.04,5.11,HH,phones,temp 90 | 5.11,5.18,IH1,phones,temp 91 | 5.18,5.22,Z,phones,temp 92 | 5.22,5.34,SH,phones,temp 93 | 5.34,5.47,OW1,phones,temp 94 | 5.47,5.51,L,phones,temp 
95 | 5.51,5.58,D,phones,temp 96 | 5.58,5.71,ER0,phones,temp 97 | 5.71,5.83,Z,phones,temp 98 | 6.16,6.23,AE1,phones,temp 99 | 6.23,6.31,N,phones,temp 100 | 6.41,7.15,spn,phones,temp 101 | 7.15,7.21,AH0,phones,temp 102 | 7.21,7.29,V,phones,temp 103 | 7.29,7.36,L,phones,temp 104 | 7.36,7.44,EH1,phones,temp 105 | 7.44,7.49,DH,phones,temp 106 | 7.49,7.7,ER0,phones,temp 107 | -------------------------------------------------------------------------------- /demo/temp/mfa_alignments/84_121550_000074_000000.csv: -------------------------------------------------------------------------------- 1 | Begin,End,Label,Type,Speaker 2 | 0.03,0.18,but,words,temp 3 | 0.18,0.32,when,words,temp 4 | 0.32,0.48,i,words,temp 5 | 0.48,0.64,had,words,temp 6 | 0.64,1.19,approached,words,temp 7 | 1.22,1.58,so,words,temp 8 | 1.58,1.91,near,words,temp 9 | 1.91,2.07,to,words,temp 10 | 2.07,2.42,them,words,temp 11 | 2.53,2.61,the,words,temp 12 | 2.61,3.01,common,words,temp 13 | 3.05,3.62,object,words,temp 14 | 3.68,3.93,which,words,temp 15 | 3.93,4.02,the,words,temp 16 | 4.02,4.34,sense,words,temp 17 | 4.34,4.97,deceives,words,temp 18 | 5.04,5.54,lost,words,temp 19 | 5.54,6.0,not,words,temp 20 | 6.0,6.14,by,words,temp 21 | 6.14,6.67,distance,words,temp 22 | 6.79,7.05,any,words,temp 23 | 7.05,7.18,of,words,temp 24 | 7.18,7.34,its,words,temp 25 | 7.34,7.87,marks,words,temp 26 | 0.03,0.06,B,phones,temp 27 | 0.06,0.09,AH1,phones,temp 28 | 0.09,0.18,T,phones,temp 29 | 0.18,0.23,W,phones,temp 30 | 0.23,0.27,EH1,phones,temp 31 | 0.27,0.32,N,phones,temp 32 | 0.32,0.48,AY1,phones,temp 33 | 0.48,0.49,HH,phones,temp 34 | 0.49,0.6,AE1,phones,temp 35 | 0.6,0.64,D,phones,temp 36 | 0.64,0.7,AH0,phones,temp 37 | 0.7,0.83,P,phones,temp 38 | 0.83,0.88,R,phones,temp 39 | 0.88,0.99,OW1,phones,temp 40 | 0.99,1.12,CH,phones,temp 41 | 1.12,1.19,T,phones,temp 42 | 1.22,1.4,S,phones,temp 43 | 1.4,1.58,OW1,phones,temp 44 | 1.58,1.7,N,phones,temp 45 | 1.7,1.84,IH1,phones,temp 46 | 1.84,1.91,R,phones,temp 47 | 1.91,2.01,T,phones,temp 48 | 2.01,2.07,AH0,phones,temp 49 | 2.07,2.13,DH,phones,temp 50 | 2.13,2.3,EH1,phones,temp 51 | 2.3,2.42,M,phones,temp 52 | 2.53,2.55,DH,phones,temp 53 | 2.55,2.61,AH0,phones,temp 54 | 2.61,2.73,K,phones,temp 55 | 2.73,2.85,AA1,phones,temp 56 | 2.85,2.9,M,phones,temp 57 | 2.9,2.95,AH0,phones,temp 58 | 2.95,3.01,N,phones,temp 59 | 3.05,3.22,AA1,phones,temp 60 | 3.22,3.27,B,phones,temp 61 | 3.27,3.34,JH,phones,temp 62 | 3.34,3.48,EH0,phones,temp 63 | 3.48,3.54,K,phones,temp 64 | 3.54,3.62,T,phones,temp 65 | 3.68,3.69,HH,phones,temp 66 | 3.69,3.76,W,phones,temp 67 | 3.76,3.8,IH1,phones,temp 68 | 3.8,3.93,CH,phones,temp 69 | 3.93,3.95,DH,phones,temp 70 | 3.95,4.02,AH0,phones,temp 71 | 4.02,4.12,S,phones,temp 72 | 4.12,4.21,EH1,phones,temp 73 | 4.21,4.27,N,phones,temp 74 | 4.27,4.34,S,phones,temp 75 | 4.34,4.42,D,phones,temp 76 | 4.42,4.45,IH0,phones,temp 77 | 4.45,4.59,S,phones,temp 78 | 4.59,4.79,IY1,phones,temp 79 | 4.79,4.87,V,phones,temp 80 | 4.87,4.97,Z,phones,temp 81 | 5.04,5.12,L,phones,temp 82 | 5.12,5.33,AO1,phones,temp 83 | 5.33,5.42,S,phones,temp 84 | 5.42,5.54,T,phones,temp 85 | 5.54,5.7,N,phones,temp 86 | 5.7,5.89,AA1,phones,temp 87 | 5.89,6.0,T,phones,temp 88 | 6.0,6.05,B,phones,temp 89 | 6.05,6.14,AY1,phones,temp 90 | 6.14,6.24,D,phones,temp 91 | 6.24,6.3,IH1,phones,temp 92 | 6.3,6.38,S,phones,temp 93 | 6.38,6.45,T,phones,temp 94 | 6.45,6.51,AH0,phones,temp 95 | 6.51,6.57,N,phones,temp 96 | 6.57,6.67,S,phones,temp 97 | 6.79,6.89,EH1,phones,temp 98 | 6.89,6.95,N,phones,temp 99 | 6.95,7.05,IY0,phones,temp 100 | 
7.05,7.13,AH0,phones,temp 101 | 7.13,7.18,V,phones,temp 102 | 7.18,7.22,IH0,phones,temp 103 | 7.22,7.29,T,phones,temp 104 | 7.29,7.34,S,phones,temp 105 | 7.34,7.39,M,phones,temp 106 | 7.39,7.5,AA1,phones,temp 107 | 7.5,7.58,R,phones,temp 108 | 7.58,7.7,K,phones,temp 109 | 7.7,7.87,S,phones,temp 110 | -------------------------------------------------------------------------------- /edit_utils.py: -------------------------------------------------------------------------------- 1 | def get_span(orig, new, editType): 2 | orig_list = orig.split(" ") 3 | new_list = new.split(" ") 4 | 5 | flag = False # this indicate whether the actual edit follow the specified editType 6 | if editType == "deletion": 7 | assert len(orig_list) > len(new_list), f"the edit type is deletion, but new is not shorter than original:\n new: {new}\n orig: {orig}" 8 | diff = len(orig_list) - len(new_list) 9 | for i, (o, n) in enumerate(zip(orig_list, new_list)): 10 | if o != n: # assume the index of the first different word is the starting index of the orig_span 11 | 12 | orig_span = [i, i + diff - 1] # assume that the indices are starting and ending index of the deleted part 13 | new_span = [i-1, i] # but for the new span, the starting and ending index is the two words that surround the deleted part 14 | flag = True 15 | break 16 | 17 | 18 | elif editType == "insertion": 19 | assert len(orig_list) < len(new_list), f"the edit type is insertion, but the new is not longer than the original:\n new: {new}\n orig: {orig}" 20 | diff = len(new_list) - len(orig_list) 21 | for i, (o, n) in enumerate(zip(orig_list, new_list)): 22 | if o != n: # insertion is just the opposite of deletion 23 | new_span = [i, i + diff - 1] # NOTE if only inserted one word, s and e will be the same 24 | orig_span = [i-1, i] 25 | flag = True 26 | break 27 | 28 | elif editType == "substitution": 29 | new_span = [] 30 | orig_span = [] 31 | for i, (o, n) in enumerate(zip(orig_list, new_list)): 32 | if o != n: 33 | new_span = [i] 34 | orig_span = [i] 35 | break 36 | assert len(new_span) == 1 and len(orig_span) == 1, f"new_span: {new_span}, orig_span: {orig_span}" 37 | for j, (o, n) in enumerate(zip(orig_list[::-1], new_list[::-1])): 38 | if o != n: 39 | new_span.append(len(new_list) - j -1) 40 | orig_span.append(len(orig_list) - j - 1) 41 | flag = True 42 | break 43 | else: 44 | raise RuntimeError(f"editType unknown: {editType}") 45 | 46 | if not flag: 47 | raise RuntimeError(f"wrong editing with the specified edit type:\n original: {orig}\n new: {new}\n, editType: {editType}") 48 | 49 | return orig_span, new_span -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: voicecraft 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=2_gnu 8 | - aom=3.8.2=h59595ed_0 9 | - asttokens=2.4.1=pyhd8ed1ab_0 10 | - atk-1.0=2.38.0=hd4edc92_1 11 | - audioread=3.0.1=py39hf3d152e_1 12 | - backcall=0.2.0=pyh9f0ad1d_0 13 | - baumwelch=0.3.7=h00ab1b0_5 14 | - biopython=1.79=py39hb9d737c_3 15 | - brotli=1.1.0=hd590300_1 16 | - brotli-bin=1.1.0=hd590300_1 17 | - brotli-python=1.1.0=py39h3d6467e_1 18 | - bzip2=1.0.8=hd590300_5 19 | - ca-certificates=2024.2.2=hbcca054_0 20 | - cairo=1.18.0=h3faef2a_0 21 | - certifi=2024.2.2=pyhd8ed1ab_0 22 | - cffi=1.16.0=py39h7a31438_0 23 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 24 | - click=8.1.7=unix_pyh707e725_0 25 | - 
colorama=0.4.6=pyhd8ed1ab_0 26 | - comm=0.2.2=pyhd8ed1ab_0 27 | - contourpy=1.2.0=py39h7633fee_0 28 | - cycler=0.12.1=pyhd8ed1ab_0 29 | - dataclassy=1.0.1=pyhd8ed1ab_0 30 | - dav1d=1.2.1=hd590300_0 31 | - debugpy=1.8.1=py39h3d6467e_0 32 | - decorator=5.1.1=pyhd8ed1ab_0 33 | - executing=2.0.1=pyhd8ed1ab_0 34 | - expat=2.6.2=h59595ed_0 35 | - ffmpeg=6.1.1=gpl_h38e077a_106 36 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 37 | - font-ttf-inconsolata=3.000=h77eed37_0 38 | - font-ttf-source-code-pro=2.038=h77eed37_0 39 | - font-ttf-ubuntu=0.83=h77eed37_1 40 | - fontconfig=2.14.2=h14ed4e7_0 41 | - fonts-conda-ecosystem=1=0 42 | - fonts-conda-forge=1=0 43 | - fonttools=4.49.0=py39hd1e30aa_0 44 | - freetype=2.12.1=h267a509_2 45 | - fribidi=1.0.10=h36c2ea0_0 46 | - gdk-pixbuf=2.42.10=h829c605_5 47 | - gettext=0.21.1=h27087fc_0 48 | - giflib=5.2.1=h0b41bf4_3 49 | - gmp=6.3.0=h59595ed_1 50 | - gnutls=3.7.9=hb077bed_0 51 | - graphite2=1.3.13=h58526e2_1001 52 | - graphviz=9.0.0=h78e8752_1 53 | - greenlet=3.0.3=py39h3d6467e_0 54 | - gtk2=2.24.33=h280cfa0_4 55 | - gts=0.7.6=h977cf35_4 56 | - harfbuzz=8.3.0=h3d44ed6_0 57 | - hdbscan=0.8.33=py39h44dd56e_4 58 | - icu=73.2=h59595ed_0 59 | - idna=3.6=pyhd8ed1ab_0 60 | - importlib-metadata=7.0.2=pyha770c72_0 61 | - importlib-resources=6.3.0=pyhd8ed1ab_0 62 | - importlib_metadata=7.0.2=hd8ed1ab_0 63 | - importlib_resources=6.3.0=pyhd8ed1ab_0 64 | - ipykernel=6.29.3=pyhd33586a_0 65 | - jedi=0.19.1=pyhd8ed1ab_0 66 | - joblib=1.3.2=pyhd8ed1ab_0 67 | - jupyter_client=8.6.1=pyhd8ed1ab_0 68 | - jupyter_core=5.7.2=py39hf3d152e_0 69 | - kaldi=5.5.1068=cpu_h31769b2_2 70 | - keyutils=1.6.1=h166bdaf_0 71 | - kiwisolver=1.4.5=py39h7633fee_1 72 | - kneed=0.8.5=pyhd8ed1ab_0 73 | - krb5=1.21.2=h659d440_0 74 | - lame=3.100=h166bdaf_1003 75 | - lazy_loader=0.3=pyhd8ed1ab_0 76 | - lcms2=2.16=hb7c19ff_0 77 | - ld_impl_linux-64=2.40=h41732ed_0 78 | - lerc=4.0.0=h27087fc_0 79 | - libabseil=20240116.1=cxx17_h59595ed_2 80 | - libass=0.17.1=h8fe9dca_1 81 | - libblas=3.9.0=21_linux64_openblas 82 | - libbrotlicommon=1.1.0=hd590300_1 83 | - libbrotlidec=1.1.0=hd590300_1 84 | - libbrotlienc=1.1.0=hd590300_1 85 | - libcblas=3.9.0=21_linux64_openblas 86 | - libclang-cpp15=15.0.7=default_hb11cfb5_4 87 | - libdeflate=1.19=hd590300_0 88 | - libdrm=2.4.120=hd590300_0 89 | - libedit=3.1.20191231=he28a2e2_2 90 | - libexpat=2.6.2=h59595ed_0 91 | - libffi=3.4.2=h7f98852_5 92 | - libflac=1.4.3=h59595ed_0 93 | - libgcc-ng=13.2.0=h807b86a_5 94 | - libgd=2.3.3=h119a65a_9 95 | - libgfortran-ng=13.2.0=h69a702a_5 96 | - libgfortran5=13.2.0=ha4646dd_5 97 | - libglib=2.80.0=hf2295e7_0 98 | - libgomp=13.2.0=h807b86a_5 99 | - libhwloc=2.9.3=default_h554bfaf_1009 100 | - libiconv=1.17=hd590300_2 101 | - libidn2=2.3.7=hd590300_0 102 | - libjpeg-turbo=3.0.0=hd590300_1 103 | - liblapack=3.9.0=21_linux64_openblas 104 | - liblapacke=3.9.0=21_linux64_openblas 105 | - libllvm14=14.0.6=hcd5def8_4 106 | - libllvm15=15.0.7=hb3ce162_4 107 | - libllvmspirv15=15.0.0=h0cdce71_1 108 | - libnsl=2.0.1=hd590300_0 109 | - libogg=1.3.4=h7f98852_1 110 | - libopenblas=0.3.26=pthreads_h413a1c8_0 111 | - libopenvino=2024.0.0=h2e90f83_1 112 | - libopenvino-auto-batch-plugin=2024.0.0=hd5fc58b_1 113 | - libopenvino-auto-plugin=2024.0.0=hd5fc58b_1 114 | - libopenvino-hetero-plugin=2024.0.0=h3ecfda7_1 115 | - libopenvino-intel-cpu-plugin=2024.0.0=h2e90f83_1 116 | - libopenvino-intel-gpu-plugin=2024.0.0=h2e90f83_1 117 | - libopenvino-ir-frontend=2024.0.0=h3ecfda7_1 118 | - libopenvino-onnx-frontend=2024.0.0=h757c851_1 119 | - 
libopenvino-paddle-frontend=2024.0.0=h757c851_1 120 | - libopenvino-pytorch-frontend=2024.0.0=h59595ed_1 121 | - libopenvino-tensorflow-frontend=2024.0.0=hca94c1a_1 122 | - libopenvino-tensorflow-lite-frontend=2024.0.0=h59595ed_1 123 | - libopus=1.3.1=h7f98852_1 124 | - libpciaccess=0.18=hd590300_0 125 | - libpng=1.6.43=h2797004_0 126 | - libpq=16.2=h33b98f1_0 127 | - libprotobuf=4.25.3=h08a7969_0 128 | - librosa=0.10.1=pyhd8ed1ab_0 129 | - librsvg=2.56.3=he3f83f7_1 130 | - libsndfile=1.2.2=hc60ed4a_1 131 | - libsodium=1.0.18=h36c2ea0_1 132 | - libsqlite=3.45.2=h2797004_0 133 | - libstdcxx-ng=13.2.0=h7e041cc_5 134 | - libtasn1=4.19.0=h166bdaf_0 135 | - libtiff=4.6.0=ha9c0a0a_2 136 | - libunistring=0.9.10=h7f98852_0 137 | - libuuid=2.38.1=h0b41bf4_0 138 | - libva=2.21.0=hd590300_0 139 | - libvorbis=1.3.7=h9c3ff4c_0 140 | - libvpx=1.14.0=h59595ed_0 141 | - libwebp=1.3.2=h658648e_1 142 | - libwebp-base=1.3.2=hd590300_0 143 | - libxcb=1.15=h0b41bf4_0 144 | - libxcrypt=4.4.36=hd590300_1 145 | - libxml2=2.12.5=h232c23b_0 146 | - libzlib=1.2.13=hd590300_5 147 | - llvm-spirv-15=15.0.0=h0cdce71_1 148 | - mad=0.15.1b=h9c3ff4c_1 149 | - markdown-it-py=3.0.0=pyhd8ed1ab_0 150 | - matplotlib-base=3.8.3=py39he9076e7_0 151 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 152 | - mdurl=0.1.2=pyhd8ed1ab_0 153 | - montreal-forced-aligner=2.2.17=pyhd8ed1ab_0 154 | - mpg123=1.32.4=h59595ed_0 155 | - msgpack-python=1.0.7=py39h7633fee_0 156 | - munkres=1.1.4=pyh9f0ad1d_0 157 | - ncurses=6.4=h59595ed_2 158 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 159 | - nettle=3.9.1=h7ab15ed_0 160 | - ngram=1.3.14=h924138e_2 161 | - numba=0.59.0=py39h615d6bd_1 162 | - numpy=1.26.4=py39h474f0d3_0 163 | - ocl-icd=2.3.2=hd590300_0 164 | - openfst=1.8.2=h924138e_2 165 | - openh264=2.4.1=h59595ed_0 166 | - openjpeg=2.5.2=h488ebb8_0 167 | - openssl=3.2.1=hd590300_0 168 | - p11-kit=0.24.1=hc5aa10d_0 169 | - packaging=24.0=pyhd8ed1ab_0 170 | - pandas=2.2.1=py39hddac248_0 171 | - pango=1.52.1=ha41ecd1_0 172 | - parso=0.8.3=pyhd8ed1ab_0 173 | - patsy=0.5.6=pyhd8ed1ab_0 174 | - pcre2=10.43=hcad00b1_0 175 | - pexpect=4.9.0=pyhd8ed1ab_0 176 | - pgvector-python=0.2.5=pyhe093146_0 177 | - pickleshare=0.7.5=py_1003 178 | - pillow=10.2.0=py39had0adad_0 179 | - pip=24.0=pyhd8ed1ab_0 180 | - pixman=0.43.2=h59595ed_0 181 | - platformdirs=4.2.0=pyhd8ed1ab_0 182 | - pocl=5.0=h03a6ac1_2 183 | - pocl-core=5.0=hdaecddf_2 184 | - pocl-cpu=5.0=he901f76_2 185 | - pocl-cpu-minimal=5.0=h5ccd973_2 186 | - pocl-cuda=5.0=hdaecddf_2 187 | - pocl-remote=5.0=h5ccd973_2 188 | - pooch=1.8.1=pyhd8ed1ab_0 189 | - postgresql=16.2=h7387d8b_0 190 | - prompt-toolkit=3.0.42=pyha770c72_0 191 | - prompt_toolkit=3.0.42=hd8ed1ab_0 192 | - psutil=5.9.8=py39hd1e30aa_0 193 | - psycopg2=2.9.9=py39h89197e3_0 194 | - pthread-stubs=0.4=h36c2ea0_1001 195 | - ptyprocess=0.7.0=pyhd3deb0d_0 196 | - pugixml=1.14=h59595ed_0 197 | - pure_eval=0.2.2=pyhd8ed1ab_0 198 | - pycparser=2.21=pyhd8ed1ab_0 199 | - pygments=2.17.2=pyhd8ed1ab_0 200 | - pyparsing=3.1.2=pyhd8ed1ab_0 201 | - pysocks=1.7.1=pyha2e5f31_6 202 | - pysoundfile=0.12.1=pypyhd8ed1ab_1 203 | - python=3.9.18=h0755675_1_cpython 204 | - python-tzdata=2024.1=pyhd8ed1ab_0 205 | - python_abi=3.9=4_cp39 206 | - pytz=2024.1=pyhd8ed1ab_0 207 | - pyyaml=6.0.1=py39hd1e30aa_1 208 | - pyzmq=25.1.2=py39h8c080ef_0 209 | - readline=8.2=h8228510_1 210 | - requests=2.31.0=pyhd8ed1ab_0 211 | - rich=13.7.1=pyhd8ed1ab_0 212 | - rich-click=1.7.4=pyhd8ed1ab_0 213 | - scikit-learn=1.2.2=py39hc236052_2 214 | - scipy=1.12.0=py39h474f0d3_2 215 | - 
seaborn=0.13.2=hd8ed1ab_0 216 | - seaborn-base=0.13.2=pyhd8ed1ab_0 217 | - setuptools=69.2.0=pyhd8ed1ab_0 218 | - six=1.16.0=pyh6c4a22f_0 219 | - snappy=1.1.10=h9fff704_0 220 | - sox=14.4.2=ha5cc309_1018 221 | - soxr=0.1.3=h0b41bf4_3 222 | - soxr-python=0.3.7=py39h44dd56e_0 223 | - sqlalchemy=2.0.28=py39hd1e30aa_0 224 | - sqlite=3.45.2=h2c6b66d_0 225 | - stack_data=0.6.2=pyhd8ed1ab_0 226 | - statsmodels=0.14.1=py39h44dd56e_0 227 | - svt-av1=1.8.0=h59595ed_0 228 | - tbb=2021.11.0=h00ab1b0_1 229 | - threadpoolctl=3.3.0=pyhc1e730c_0 230 | - tk=8.6.13=noxft_h4845f30_101 231 | - tornado=6.4=py39hd1e30aa_0 232 | - tqdm=4.66.2=pyhd8ed1ab_0 233 | - traitlets=5.14.2=pyhd8ed1ab_0 234 | - typing-extensions=4.10.0=hd8ed1ab_0 235 | - typing_extensions=4.10.0=pyha770c72_0 236 | - tzcode=2024a=h3f72095_0 237 | - tzdata=2024a=h0c530f3_0 238 | - unicodedata2=15.1.0=py39hd1e30aa_0 239 | - urllib3=2.2.1=pyhd8ed1ab_0 240 | - wcwidth=0.2.13=pyhd8ed1ab_0 241 | - wheel=0.42.0=pyhd8ed1ab_0 242 | - x264=1!164.3095=h166bdaf_2 243 | - x265=3.5=h924138e_3 244 | - xorg-fixesproto=5.0=h7f98852_1002 245 | - xorg-kbproto=1.0.7=h7f98852_1002 246 | - xorg-libice=1.1.1=hd590300_0 247 | - xorg-libsm=1.2.4=h7391055_0 248 | - xorg-libx11=1.8.7=h8ee46fc_0 249 | - xorg-libxau=1.0.11=hd590300_0 250 | - xorg-libxdmcp=1.1.3=h7f98852_0 251 | - xorg-libxext=1.3.4=h0b41bf4_2 252 | - xorg-libxfixes=5.0.3=h7f98852_1004 253 | - xorg-libxrender=0.9.11=hd590300_0 254 | - xorg-renderproto=0.11.1=h7f98852_1002 255 | - xorg-xextproto=7.3.0=h0b41bf4_1003 256 | - xorg-xproto=7.0.31=h7f98852_1007 257 | - xz=5.2.6=h166bdaf_0 258 | - yaml=0.2.5=h7f98852_2 259 | - zeromq=4.3.5=h59595ed_1 260 | - zipp=3.17.0=pyhd8ed1ab_0 261 | - zlib=1.2.13=hd590300_5 262 | - zstd=1.5.5=hfc55251_0 263 | - pip: 264 | - absl-py==2.1.0 265 | - aiofiles==23.2.1 266 | - aiohttp==3.9.3 267 | - aiosignal==1.3.1 268 | - altair==5.2.0 269 | - antlr4-python3-runtime==4.9.3 270 | - anyio==4.3.0 271 | - async-timeout==4.0.3 272 | - attrs==23.2.0 273 | - av==11.0.0 274 | - babel==2.14.0 275 | - beautifulsoup4==4.12.3 276 | - bibtexparser==2.0.0b7 277 | - bleach==6.1.0 278 | - blis==0.7.11 279 | - catalogue==2.0.10 280 | - clldutils==3.22.2 281 | - cloudpickle==3.0.0 282 | - cmake==3.28.3 283 | - colorlog==6.8.2 284 | - confection==0.1.4 285 | - csvw==3.3.0 286 | - cymem==2.0.8 287 | - cython==0.29.37 288 | - datasets==2.16.0 289 | - defusedxml==0.7.1 290 | - demucs==4.0.1 291 | - dill==0.3.6 292 | - dlinfo==1.2.1 293 | - docopt==0.6.2 294 | - dora-search==0.1.12 295 | - einops==0.7.0 296 | - encodec==0.1.1 297 | - exceptiongroup==1.2.0 298 | - fastapi==0.110.0 299 | - fastjsonschema==2.19.1 300 | - ffmpy==0.3.2 301 | - filelock==3.13.1 302 | - flashy==0.0.2 303 | - frozenlist==1.4.1 304 | - fsspec==2023.10.0 305 | - gradio==3.50.2 306 | - gradio-client==0.6.1 307 | - grpcio==1.62.1 308 | - h11==0.14.0 309 | - httpcore==1.0.4 310 | - httpx==0.27.0 311 | - huggingface-hub==0.22.2 312 | - hydra-colorlog==1.2.0 313 | - hydra-core==1.3.2 314 | - ipython==8.12.3 315 | - isodate==0.6.1 316 | - jinja2==3.1.3 317 | - jsonschema==4.21.1 318 | - jsonschema-specifications==2023.12.1 319 | - julius==0.2.7 320 | - jupyterlab-pygments==0.3.0 321 | - lameenc==1.7.0 322 | - langcodes==3.3.0 323 | - language-tags==1.2.0 324 | - lit==18.1.1 325 | - llvmlite==0.42.0 326 | - lxml==5.1.0 327 | - markdown==3.5.2 328 | - markupsafe==2.1.5 329 | - mistune==3.0.2 330 | - mpmath==1.3.0 331 | - msgpack==1.0.8 332 | - multidict==6.0.5 333 | - multiprocess==0.70.14 334 | - murmurhash==1.0.10 335 | - 
nbclient==0.10.0 336 | - nbconvert==7.16.3 337 | - nbformat==5.10.3 338 | - networkx==3.2.1 339 | - num2words==0.5.13 340 | - nvidia-cublas-cu11==11.10.3.66 341 | - nvidia-cuda-cupti-cu11==11.7.101 342 | - nvidia-cuda-nvrtc-cu11==11.7.99 343 | - nvidia-cuda-runtime-cu11==11.7.99 344 | - nvidia-cudnn-cu11==8.5.0.96 345 | - nvidia-cufft-cu11==10.9.0.58 346 | - nvidia-curand-cu11==10.2.10.91 347 | - nvidia-cusolver-cu11==11.4.0.1 348 | - nvidia-cusparse-cu11==11.7.4.91 349 | - nvidia-nccl-cu11==2.14.3 350 | - nvidia-nvtx-cu11==11.7.91 351 | - omegaconf==2.3.0 352 | - openunmix==1.2.1 353 | - orjson==3.9.15 354 | - pandocfilters==1.5.1 355 | - pathlib-abc==0.1.1 356 | - pathy==0.11.0 357 | - pgvector==0.2.2 358 | - phonemizer==3.2.1 359 | - pipreqs==0.5.0 360 | - praatio==6.2.0 361 | - preshed==3.0.9 362 | - protobuf==4.25.3 363 | - pyarrow==15.0.2 364 | - pyarrow-hotfix==0.6 365 | - pydantic==1.10.14 366 | - pydub==0.25.1 367 | - pylatexenc==2.10 368 | - pynini==2.1.6 369 | - pypinyin==0.48.0 370 | - python-dateutil==2.9.0.post0 371 | - python-multipart==0.0.9 372 | - rdflib==7.0.0 373 | - referencing==0.33.0 374 | - regex==2023.12.25 375 | - responses==0.18.0 376 | - retrying==1.3.4 377 | - rfc3986==1.5.0 378 | - rpds-py==0.18.0 379 | - safetensors==0.4.2 380 | - segments==2.2.1 381 | - semantic-version==2.10.0 382 | - sentencepiece==0.2.0 383 | - smart-open==6.4.0 384 | - sniffio==1.3.1 385 | - soupsieve==2.5 386 | - spacy==3.5.2 387 | - spacy-legacy==3.0.12 388 | - spacy-loggers==1.0.5 389 | - srsly==2.4.8 390 | - starlette==0.36.3 391 | - submitit==1.5.1 392 | - sympy==1.12 393 | - tabulate==0.9.0 394 | - tensorboard==2.16.2 395 | - tensorboard-data-server==0.7.2 396 | - thinc==8.1.12 397 | - tinycss2==1.2.1 398 | - tokenizers==0.15.2 399 | - toolz==0.12.1 400 | - torch==2.0.1 401 | - torchaudio==2.0.2 402 | - torchmetrics==0.11.1 403 | - transformers==4.38.2 404 | - treetable==0.2.5 405 | - triton==2.0.0 406 | - typer==0.7.0 407 | - uritemplate==4.1.1 408 | - uvicorn==0.28.0 409 | - wasabi==1.1.2 410 | - webencodings==0.5.1 411 | - websockets==11.0.3 412 | - werkzeug==3.0.1 413 | - xformers==0.0.22 414 | - xxhash==3.4.1 415 | - yarg==0.1.9 416 | - yarl==1.9.4 417 | prefix: /home/pyp/miniconda3/envs/voicecraft 418 | -------------------------------------------------------------------------------- /gradio_app.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "9b6a0c92", 6 | "metadata": {}, 7 | "source": [ 8 | "### Only do the below if you are using docker" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "961faa43", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "!source ~/.bashrc && \\\n", 19 | " apt-get update && \\\n", 20 | " apt-get install -y espeak espeak-data libespeak1 libespeak-dev && \\\n", 21 | " apt-get install -y festival* && \\\n", 22 | " apt-get install -y build-essential && \\\n", 23 | " apt-get install -y flac libasound2-dev libsndfile1-dev vorbis-tools && \\\n", 24 | " apt-get install -y libxml2-dev libxslt-dev zlib1g-dev" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "598d75cf", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "!source ~/.bashrc && \\\n", 35 | " conda activate voicecraft && \\\n", 36 | " pip install -r gradio_requirements.txt" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "id": "8b9c4436", 42 | "metadata": {}, 43 | "source": [ 44 | 
"# STOP\n", 45 | "You have to do this part manually using the mouse/keyboard and the tabs at the top.\n", 46 | "\n", 47 | "* Refresh your browser to make sure it picks up the new kernel.\n", 48 | "* Kernel -> Change Kernel -> Select Kernel -> voicecraft\n", 49 | "* Kernel -> Restart Kernel -> Yes\n", 50 | "\n", 51 | "Now you can run the rest of the notebook and get an audio sample output. It will automatically download more models and such. The next time you use this container, you can just start below here as the dependencies will remain available until you delete the docker container." 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "f089aa96", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "from gradio_app import app\n", 62 | "app.launch()" 63 | ] 64 | } 65 | ], 66 | "metadata": { 67 | "kernelspec": { 68 | "display_name": "voicecraft", 69 | "language": "python", 70 | "name": "voicecraft" 71 | }, 72 | "language_info": { 73 | "codemirror_mode": { 74 | "name": "ipython", 75 | "version": 3 76 | }, 77 | "file_extension": ".py", 78 | "mimetype": "text/x-python", 79 | "name": "python", 80 | "nbconvert_exporter": "python", 81 | "pygments_lexer": "ipython3", 82 | "version": "3.9.19" 83 | } 84 | }, 85 | "nbformat": 4, 86 | "nbformat_minor": 5 87 | } 88 | -------------------------------------------------------------------------------- /gradio_requirements.txt: -------------------------------------------------------------------------------- 1 | gradio==3.50.2 2 | nltk>=3.8.1 3 | openai-whisper>=20231117 4 | aeneas>=1.7.3.0 5 | whisperx>=3.1.1 6 | huggingface_hub==0.22.2 7 | num2words==0.5.13 8 | -------------------------------------------------------------------------------- /inference_speech_editing_scale.py: -------------------------------------------------------------------------------- 1 | import argparse, pickle 2 | import logging 3 | import os, random 4 | import numpy as np 5 | import torch 6 | import torchaudio 7 | 8 | from data.tokenizer import ( 9 | AudioTokenizer, 10 | TextTokenizer, 11 | tokenize_audio, 12 | tokenize_text 13 | ) 14 | 15 | from models import voicecraft 16 | import argparse, time, tqdm 17 | 18 | # this script only works for the musicgen architecture 19 | def get_args(): 20 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 21 | parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file") 22 | parser.add_argument("--audio_root", type=str, default="path/to/audio_folder") 23 | parser.add_argument("--exp_dir", type=str, default="path/to/model_folder") 24 | parser.add_argument("--left_margin", type=float, default=0.08, help="extra space on the left to the word boundary") 25 | parser.add_argument("--right_margin", type=float, default=0.08, help="extra space on the right to the word boundary") 26 | parser.add_argument("--seed", type=int, default=1) 27 | parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for') 28 | parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes') 29 | parser.add_argument("--top_k", type=int, default=-1, help="sampling param") 30 | parser.add_argument("--top_p", type=float, default=0.8, help="sampling param") 31 | parser.add_argument("--temperature", type=float, default=1.0, help="sampling param") 32 | parser.add_argument("--output_dir", type=str, default=None) 33 | parser.add_argument("--device", type=str, default="cuda") 34 | 
parser.add_argument("--signature", type=str, default=None, help="path to the encodec model") 35 | parser.add_argument("--stop_repetition", type=int, default=2, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it") 36 | parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without') 37 | parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default") 38 | return parser.parse_args() 39 | 40 | @torch.no_grad() 41 | def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, device, decode_config): 42 | # phonemize 43 | text_tokens = [phn2num[phn] for phn in 44 | tokenize_text( 45 | text_tokenizer, text=target_text.strip() 46 | ) if phn in phn2num 47 | ] 48 | text_tokens = torch.LongTensor(text_tokens).unsqueeze(0) 49 | text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]]) 50 | 51 | encoded_frames = tokenize_audio(audio_tokenizer, audio_fn) 52 | original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K] 53 | assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape 54 | logging.info(f"with direct encodec encoding before input, original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.") 55 | 56 | # forward 57 | stime = time.time() 58 | encoded_frames = model.inference( 59 | text_tokens.to(device), 60 | text_tokens_lens.to(device), 61 | original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8] 62 | mask_interval=mask_interval.unsqueeze(0).to(device), 63 | top_k=decode_config['top_k'], 64 | top_p=decode_config['top_p'], 65 | temperature=decode_config['temperature'], 66 | stop_repetition=decode_config['stop_repetition'], 67 | kvcache=decode_config['kvcache'], 68 | silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens']) == str else decode_config['silence_tokens'], 69 | ) # output is [1,K,T] 70 | logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.") 71 | if type(encoded_frames) == tuple: 72 | encoded_frames = encoded_frames[0] 73 | logging.info(f"generated encoded_frames.shape: {encoded_frames.shape}, which is {encoded_frames.shape[-1]/decode_config['codec_sr']} sec.") 74 | 75 | 76 | # decode (both original and generated) 77 | original_sample = audio_tokenizer.decode( 78 | [(original_audio.transpose(2,1), None)] # [1,T,8] -> [1,8,T] 79 | ) 80 | generated_sample = audio_tokenizer.decode( 81 | [(encoded_frames, None)] 82 | ) 83 | 84 | return original_sample, generated_sample 85 | 86 | def get_model(exp_dir, device=None): 87 | with open(os.path.join(exp_dir, "args.pkl"), "rb") as f: 88 | model_args = pickle.load(f) 89 | 90 | logging.info("load model weights...") 91 | model = voicecraft.VoiceCraft(model_args) 92 | ckpt_fn = os.path.join(exp_dir, "best_bundle.pth") 93 | ckpt = torch.load(ckpt_fn, map_location='cpu')['model'] 94 | phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num'] 95 | model.load_state_dict(ckpt) 96 | del ckpt 97 | logging.info("done loading weights...") 98 | if device == None: 99 | device = torch.device("cpu") 100 | if torch.cuda.is_available(): 101 | device = torch.device("cuda:0") 102 | model.to(device) 103 | model.eval() 
104 | return model, model_args, phn2num 105 | 106 | 107 | def get_mask_interval(ali_fn, word_span_ind, editType): 108 | with open(ali_fn, "r") as rf: 109 | data = [l.strip().split(",") for l in rf.readlines()] 110 | data = data[1:] 111 | tmp = word_span_ind.split(",") 112 | s, e = int(tmp[0]), int(tmp[-1]) 113 | start = None 114 | for j, item in enumerate(data): 115 | if j == s and item[3] == "words": 116 | if editType == 'insertion': 117 | start = float(item[1]) 118 | else: 119 | start = float(item[0]) 120 | if j == e and item[3] == "words": 121 | if editType == 'insertion': 122 | end = float(item[0]) 123 | else: 124 | end = float(item[1]) 125 | assert start is not None 126 | break 127 | return (start, end) 128 | 129 | if __name__ == "__main__": 130 | def seed_everything(seed): 131 | os.environ['PYTHONHASHSEED'] = str(seed) 132 | random.seed(seed) 133 | np.random.seed(seed) 134 | torch.manual_seed(seed) 135 | torch.cuda.manual_seed(seed) 136 | torch.backends.cudnn.benchmark = False 137 | torch.backends.cudnn.deterministic = True 138 | formatter = ( 139 | "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" 140 | ) 141 | logging.basicConfig(format=formatter, level=logging.INFO) 142 | args = get_args() 143 | # args.device = 'cpu' 144 | if hasattr(args, "allowed_repeat_tokens"): args.allowed_repeat_tokens = eval(args.allowed_repeat_tokens) # guard: get_args() does not define this flag 145 | seed_everything(args.seed) 146 | 147 | # load model 148 | stime = time.time() 149 | logging.info(f"loading model from {args.exp_dir}") 150 | model, model_args, phn2num = get_model(args.exp_dir) 151 | if not os.path.isfile(model_args.exp_dir): 152 | model_args.exp_dir = args.exp_dir 153 | logging.info(f"loading model done, took {time.time() - stime:.4f} sec") 154 | 155 | # setup text and audio tokenizer 156 | text_tokenizer = TextTokenizer(backend="espeak") 157 | audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu 158 | 159 | with open(args.manifest_fn, "r") as rf: 160 | manifest = [l.strip().split("\t") for l in rf.readlines()] 161 | manifest = manifest[1:] 162 | 163 | # wav_fn txt_fn alignment_fn num_words word_span_ind 164 | audio_fns = [] 165 | target_texts = [] 166 | mask_intervals = [] 167 | edit_types = [] 168 | new_spans = [] 169 | orig_spans = [] 170 | os.makedirs(args.output_dir, exist_ok=True) 171 | if getattr(args, "crop_concat", False): # guard: crop_concat is not defined in get_args(), so default to False 172 | mfa_temp = f"{args.output_dir}/mfa_temp" 173 | os.makedirs(mfa_temp, exist_ok=True) 174 | for item in manifest: 175 | audio_fn = os.path.join(args.audio_root, item[0]) 176 | temp = torchaudio.info(audio_fn) 177 | audio_dur = temp.num_frames/temp.sample_rate 178 | audio_fns.append(audio_fn) 179 | target_text = item[2].split("|")[-1] 180 | edit_types.append(item[5].split("|")) 181 | new_spans.append(item[4].split("|")) 182 | orig_spans.append(item[3].split("|")) 183 | target_texts.append(target_text) # the last transcript is the target 184 | # mi needs to be created from word_ind_span and alignment_fn, along with args.left_margin and args.right_margin 185 | mis = [] 186 | all_ind_intervals = item[3].split("|") 187 | editTypes = item[5].split("|") 188 | smaller_indx = [] 189 | alignment_fn = os.path.join(args.audio_root, "aligned", item[0].replace(".wav", ".csv")) 190 | if not os.path.isfile(alignment_fn): 191 | alignment_fn = alignment_fn.replace("/aligned/", "/aligned_csv/") 192 | assert os.path.isfile(alignment_fn), alignment_fn 193 | for ind_inter,editType in zip(all_ind_intervals, editTypes): 194 | # print(ind_inter) 195 | mi = get_mask_interval(alignment_fn, ind_inter, editType) 196 | mi =
(max(mi[0] - args.left_margin, 1/args.codec_sr), min(mi[1] + args.right_margin, audio_dur)) # in seconds 197 | mis.append(mi) 198 | smaller_indx.append(mi[0]) 199 | ind = np.argsort(smaller_indx) 200 | mis = [mis[id] for id in ind] 201 | mask_intervals.append(mis) 202 | 203 | 204 | 205 | for i, (audio_fn, target_text, mask_interval) in enumerate(tqdm.tqdm(zip(audio_fns, target_texts, mask_intervals))): 206 | orig_mask_interval = mask_interval 207 | mask_interval = [[round(cmi[0]*args.codec_sr), round(cmi[1]*args.codec_sr)] for cmi in mask_interval] 208 | # logging.info(f"i: {i}, mask_interval: {mask_interval}") 209 | mask_interval = torch.LongTensor(mask_interval) # [M,2] 210 | orig_audio, new_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, args.device, vars(args)) 211 | 212 | # save segments for comparison 213 | orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu() 214 | # logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}") 215 | 216 | save_fn_new = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_new_seed{args.seed}.wav" 217 | 218 | torchaudio.save(save_fn_new, new_audio, args.codec_audio_sr) 219 | 220 | save_fn_orig = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_orig.wav" 221 | if not os.path.isfile(save_fn_orig): 222 | orig_audio, orig_sr = torchaudio.load(audio_fn) 223 | if orig_sr != args.codec_audio_sr: 224 | orig_audio = torchaudio.transforms.Resample(orig_sr, args.codec_audio_sr)(orig_audio) 225 | torchaudio.save(save_fn_orig, orig_audio, args.codec_audio_sr) 226 | 227 | -------------------------------------------------------------------------------- /inference_tts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "VoiceCraft Inference Text To Speech Demo\n", 8 | "===" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "### Select 'voicecraft' as the kernel" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "# import libs\n", 25 | "# if this throws an error, something went wrong installing dependencies or changing the kernel above!\n", 26 | "import os\n", 27 | "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" \n", 28 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", 29 | "os.environ[\"USER\"] = \"me\" # TODO change this to your username\n", 30 | "\n", 31 | "import torch\n", 32 | "import torchaudio\n", 33 | "import numpy as np\n", 34 | "import random\n", 35 | "from argparse import Namespace\n", 36 | "\n", 37 | "from data.tokenizer import (\n", 38 | " AudioTokenizer,\n", 39 | " TextTokenizer,\n", 40 | ")\n", 41 | "from huggingface_hub import hf_hub_download" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# # install MFA models and dictionaries if you haven't done so already, already done in the dockerfile or envrionment setup\n", 51 | "# !source ~/.bashrc && \\\n", 52 | "# conda activate voicecraft && \\\n", 53 | "# mfa model download dictionary english_us_arpa && \\\n", 54 | "# mfa model download acoustic english_us_arpa" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stderr", 64 | "output_type": "stream", 65 | "text": [ 
66 | "Dora directory: /tmp/audiocraft_me\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "# load model, encodec, and phn2num\n", 72 | "# # load model, tokenizer, and other necessary files\n", 73 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 74 | "voicecraft_name=\"830M_TTSEnhanced.pth\" # or giga330M.pth, 330M_TTSEnhanced.pth, giga830M.pth\n", 75 | "\n", 76 | "# the new way of loading the model, with huggingface, recommended\n", 77 | "from models import voicecraft\n", 78 | "model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", 79 | "phn2num = model.args.phn2num\n", 80 | "config = vars(model.args)\n", 81 | "model.to(device)\n", 82 | "\n", 83 | "\n", 84 | "# # the old way of loading the model\n", 85 | "# from models import voicecraft\n", 86 | "# filepath = hf_hub_download(repo_id=\"pyp1/VoiceCraft\", filename=voicecraft_name, repo_type=\"model\")\n", 87 | "# ckpt = torch.load(filepath, map_location=\"cpu\")\n", 88 | "# model = voicecraft.VoiceCraft(ckpt[\"config\"])\n", 89 | "# model.load_state_dict(ckpt[\"model\"])\n", 90 | "# config = vars(model.args)\n", 91 | "# phn2num = ckpt[\"phn2num\"]\n", 92 | "# model.to(device)\n", 93 | "# model.eval()\n", 94 | "\n", 95 | "\n", 96 | "encodec_fn = \"./pretrained_models/encodec_4cb2048_giga.th\"\n", 97 | "if not os.path.exists(encodec_fn):\n", 98 | " os.system(f\"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th\")\n", 99 | " os.system(f\"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th\")\n", 100 | "audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=device) # will also put the neural codec model on gpu\n", 101 | "\n", 102 | "text_tokenizer = TextTokenizer(backend=\"espeak\")\n" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "# Prepare your audio\n", 112 | "# point to the original audio whose speech you want to clone\n", 113 | "# write down the transcript for the file, or run whisper to get the transcript (and you can modify it if it's not accurate), save it as a .txt file\n", 114 | "orig_audio = \"./demo/5895_34622_000026_000002.wav\"\n", 115 | "orig_transcript = \"Gwynplaine had, besides, for his work and for his feats of strength, round his neck and over his shoulders, an esclavine of leather.\"\n", 116 | "\n", 117 | "# move the audio and transcript to temp folder\n", 118 | "temp_folder = \"./demo/temp\"\n", 119 | "os.makedirs(temp_folder, exist_ok=True)\n", 120 | "os.system(f\"cp {orig_audio} {temp_folder}\")\n", 121 | "filename = os.path.splitext(orig_audio.split(\"/\")[-1])[0]\n", 122 | "with open(f\"{temp_folder}/{filename}.txt\", \"w\") as f:\n", 123 | " f.write(orig_transcript)\n", 124 | "# run MFA to get the alignment\n", 125 | "align_temp = f\"{temp_folder}/mfa_alignments\"\n", 126 | "!source ~/.bashrc && \\\n", 127 | " conda activate voicecraft && \\\n", 128 | " mfa align -v --clean -j 1 --output_format csv {temp_folder} \\\n", 129 | " english_us_arpa english_us_arpa {align_temp}\n", 130 | "\n", 131 | "# # if the above fails, it could be because the audio is too hard for the alignment model, increasing the beam size usually solves the issue\n", 132 | "# !source ~/.bashrc && \\\n", 133 | "# conda activate voicecraft && \\\n", 134 | "# mfa align -v --clean -j 1 --output_format csv {temp_folder} \\\n", 135 | "# english_us_arpa english_us_arpa {align_temp} --beam 1000 --retry_beam 2000\n", 136 | "\n" 137 | 
] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "# take a look at demo/temp/mfa_alignments, decide which part of the audio to use as prompt\n", 146 | "cut_off_sec = 3.6 # NOTE: according to the forced-alignment file demo/temp/mfa_alignments/5895_34622_000026_000002.csv, the word \"strength\" stops at 3.561 sec, so we use the first 3.6 sec as the prompt. this will be different for different audio\n", 147 | "target_transcript = \"Gwynplaine had, besides, for his work and for his feats of strength, I cannot believe that the same model can also do text to speech synthesis too!\"\n", 148 | "# NOTE: 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec.\n", 149 | "audio_fn = f\"{temp_folder}/{filename}.wav\"\n", 150 | "info = torchaudio.info(audio_fn)\n", 151 | "audio_dur = info.num_frames / info.sample_rate\n", 152 | "\n", 153 | "assert cut_off_sec < audio_dur, f\"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}\"\n", 154 | "prompt_end_frame = int(cut_off_sec * info.sample_rate)\n", 155 | "\n", 156 | "# run the model to get the output\n", 157 | "# hyperparameters for inference\n", 158 | "codec_audio_sr = 16000\n", 159 | "codec_sr = 50\n", 160 | "top_k = 40 # can also try 20, 30, 50\n", 161 | "top_p = 1 # 1 means do not do top-p sampling\n", 162 | "temperature = 1\n", 163 | "silence_tokens=[1388,1898,131]\n", 164 | "kvcache = 1 # NOTE if OOM, change this to 0, or try the 330M model\n", 165 | "\n", 166 | "# NOTE adjust the below three arguments if the generation is not as good\n", 167 | "stop_repetition = 3 # NOTE if the model generates long silence, reduce stop_repetition to 3, 2 or even 1\n", 168 | "sample_batch_size = 3 # NOTE: if there are long silences or unnaturally stretched words, increase sample_batch_size to 4 or higher. What this does is have the model run sample_batch_size generations of the same audio and pick the one that's the shortest. 
So if the speech rate of the generated is too fast change it to a smaller number.\n", 169 | "seed = 1 # change seed if you are still unhappy with the result\n", 170 | "\n", 171 | "def seed_everything(seed):\n", 172 | " os.environ['PYTHONHASHSEED'] = str(seed)\n", 173 | " random.seed(seed)\n", 174 | " np.random.seed(seed)\n", 175 | " torch.manual_seed(seed)\n", 176 | " torch.cuda.manual_seed(seed)\n", 177 | " torch.backends.cudnn.benchmark = False\n", 178 | " torch.backends.cudnn.deterministic = True\n", 179 | "seed_everything(seed)\n", 180 | "\n", 181 | "decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens, \"sample_batch_size\": sample_batch_size}\n", 182 | "from inference_tts_scale import inference_one_sample\n", 183 | "concated_audio, gen_audio = inference_one_sample(model, Namespace(**config), phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, device, decode_config, prompt_end_frame)\n", 184 | " \n", 185 | "# save segments for comparison\n", 186 | "concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()\n", 187 | "# logging.info(f\"length of the resynthesize orig audio: {orig_audio.shape}\")\n", 188 | "\n", 189 | "\n", 190 | "# display the audio\n", 191 | "from IPython.display import Audio\n", 192 | "print(\"concatenate prompt and generated:\")\n", 193 | "display(Audio(concated_audio, rate=codec_audio_sr))\n", 194 | "\n", 195 | "print(\"generated:\")\n", 196 | "display(Audio(gen_audio, rate=codec_audio_sr))\n", 197 | "\n", 198 | "# # save the audio\n", 199 | "# # output_dir\n", 200 | "# output_dir = \"/home/pyp/VoiceCraft/demo/generated_tts\"\n", 201 | "# os.makedirs(output_dir, exist_ok=True)\n", 202 | "# seg_save_fn_gen = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_gen_seed{seed}.wav\"\n", 203 | "# seg_save_fn_concat = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_concat_seed{seed}.wav\" \n", 204 | "\n", 205 | "# torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr)\n", 206 | "# torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr)\n", 207 | "\n", 208 | "# you are might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [] 217 | } 218 | ], 219 | "metadata": { 220 | "kernelspec": { 221 | "display_name": "voicecraft", 222 | "language": "python", 223 | "name": "python3" 224 | }, 225 | "language_info": { 226 | "codemirror_mode": { 227 | "name": "ipython", 228 | "version": 3 229 | }, 230 | "file_extension": ".py", 231 | "mimetype": "text/x-python", 232 | "name": "python", 233 | "nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.9.18" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 4 240 | } 241 | -------------------------------------------------------------------------------- /inference_tts_scale.py: -------------------------------------------------------------------------------- 1 | import argparse, pickle 2 | import logging 3 | import os, random 4 | import numpy as np 5 | import torch 6 | import torchaudio 7 | 8 | from data.tokenizer import ( 9 | AudioTokenizer, 10 | TextTokenizer, 11 | tokenize_audio, 12 | tokenize_text 13 | ) 14 | 15 | from models import voicecraft 16 | import argparse, time, 
tqdm 17 | 18 | 19 | # this script only works for the musicgen architecture 20 | def get_args(): 21 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file") 23 | parser.add_argument("--audio_root", type=str, default="path/to/audio_folder") 24 | parser.add_argument("--exp_dir", type=str, default="path/to/model_folder") 25 | parser.add_argument("--seed", type=int, default=1) 26 | parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for') 27 | parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes') 28 | parser.add_argument("--top_k", type=int, default=40, help="sampling param") 29 | parser.add_argument("--top_p", type=float, default=1, help="sampling param") 30 | parser.add_argument("--temperature", type=float, default=1.0, help="sampling param") 31 | parser.add_argument("--output_dir", type=str, default=None) 32 | parser.add_argument("--device", type=str, default="cuda") 33 | parser.add_argument("--signature", type=str, default=None, help="path to the encodec model") 34 | parser.add_argument("--crop_concat", type=int, default=0) 35 | parser.add_argument("--stop_repetition", type=int, default=-1, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it") 36 | parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without') 37 | parser.add_argument("--sample_batch_size", type=int, default=1, help="batch size for sampling, NOTE that it's not running inference for several samples, but duplicate one input sample batch_size times, and during inference, we only return the shortest generation") 38 | parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default") 39 | return parser.parse_args() 40 | 41 | 42 | @torch.no_grad() 43 | def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame): 44 | # phonemize 45 | text_tokens = [phn2num[phn] for phn in 46 | tokenize_text( 47 | text_tokenizer, text=target_text.strip() 48 | ) if phn in phn2num 49 | ] 50 | text_tokens = torch.LongTensor(text_tokens).unsqueeze(0) 51 | text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]]) 52 | 53 | # encode audio 54 | encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame) 55 | original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K] 56 | assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape 57 | logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.") 58 | 59 | # forward 60 | stime = time.time() 61 | if decode_config['sample_batch_size'] <= 1: 62 | logging.info(f"running inference with batch size 1") 63 | concat_frames, gen_frames = model.inference_tts( 64 | text_tokens.to(device), 65 | text_tokens_lens.to(device), 66 | original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8] 67 | top_k=decode_config['top_k'], 68 | top_p=decode_config['top_p'], 69 | temperature=decode_config['temperature'], 70 | 
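# NOTE (added comment): the remaining decode_config entries mirror the CLI flags defined in get_args() above --
# stop_repetition stops a token once its consecutive repetition count exceeds this value,
# kvcache enables the key-value cache (per the --kvcache help, roughly 4-8x faster than without),
# and silence_tokens may arrive as a string, hence the eval() below before it is passed to the model.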
stop_repetition=decode_config['stop_repetition'], 71 | kvcache=decode_config['kvcache'], 72 | silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens'] 73 | ) # output is [1,K,T] 74 | else: 75 | logging.info(f"running inference with batch size {decode_config['sample_batch_size']}, i.e. return the shortest among {decode_config['sample_batch_size']} generations.") 76 | concat_frames, gen_frames = model.inference_tts_batch( 77 | text_tokens.to(device), 78 | text_tokens_lens.to(device), 79 | original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8] 80 | top_k=decode_config['top_k'], 81 | top_p=decode_config['top_p'], 82 | temperature=decode_config['temperature'], 83 | stop_repetition=decode_config['stop_repetition'], 84 | kvcache=decode_config['kvcache'], 85 | batch_size = decode_config['sample_batch_size'], 86 | silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens'] 87 | ) # output is [1,K,T] 88 | logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.") 89 | 90 | logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.") 91 | 92 | # for timestamp, codes in enumerate(gen_frames[0].transpose(1,0)): 93 | # logging.info(f"{timestamp}: {codes.tolist()}") 94 | # decode (both original and generated) 95 | concat_sample = audio_tokenizer.decode( 96 | [(concat_frames, None)] # [1,T,8] -> [1,8,T] 97 | ) 98 | gen_sample = audio_tokenizer.decode( 99 | [(gen_frames, None)] 100 | ) 101 | #Empty cuda cache between runs 102 | if torch.cuda.is_available(): 103 | torch.cuda.empty_cache() 104 | # return 105 | return concat_sample, gen_sample 106 | 107 | def get_model(exp_dir, device=None): 108 | with open(os.path.join(exp_dir, "args.pkl"), "rb") as f: 109 | model_args = pickle.load(f) 110 | 111 | logging.info("load model weights...") 112 | model = voicecraft.VoiceCraft(model_args) 113 | ckpt_fn = os.path.join(exp_dir, "best_bundle.pth") 114 | ckpt = torch.load(ckpt_fn, map_location='cpu')['model'] 115 | phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num'] 116 | model.load_state_dict(ckpt) 117 | del ckpt 118 | logging.info("done loading weights...") 119 | if device == None: 120 | device = torch.device("cpu") 121 | if torch.cuda.is_available(): 122 | device = torch.device("cuda:0") 123 | model.to(device) 124 | model.eval() 125 | return model, model_args, phn2num 126 | 127 | if __name__ == "__main__": 128 | def seed_everything(seed): 129 | os.environ['PYTHONHASHSEED'] = str(seed) 130 | random.seed(seed) 131 | np.random.seed(seed) 132 | torch.manual_seed(seed) 133 | torch.cuda.manual_seed(seed) 134 | torch.backends.cudnn.benchmark = False 135 | torch.backends.cudnn.deterministic = True 136 | formatter = ( 137 | "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" 138 | ) 139 | logging.basicConfig(format=formatter, level=logging.INFO) 140 | args = get_args() 141 | # args.device='cpu' 142 | seed_everything(args.seed) 143 | 144 | os.makedirs(args.output_dir, exist_ok=True) 145 | # load model 146 | 147 | with open(args.manifest_fn, "r") as rf: 148 | manifest = [l.strip().split("\t") for l in rf.readlines()] 149 | manifest = manifest[1:] 150 | manifest = [[item[0], item[2], item[3], item[1], item[5]] for item in manifest] 151 | 152 | stime = time.time() 153 | logging.info(f"loading model from {args.exp_dir}") 154 | model, model_args, phn2num 
= get_model(args.exp_dir) 155 | logging.info(f"loading model done, took {time.time() - stime:.4f} sec") 156 | 157 | # setup text and audio tokenizer 158 | text_tokenizer = TextTokenizer(backend="espeak") 159 | audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu 160 | 161 | audio_fns = [] 162 | texts = [] 163 | prompt_end_frames = [] 164 | new_audio_fns = [] 165 | text_to_syn = [] 166 | 167 | for item in manifest: 168 | audio_fn = os.path.join(args.audio_root, item[0]) 169 | audio_fns.append(audio_fn) 170 | temp = torchaudio.info(audio_fn) 171 | prompt_end_frames.append(round(float(item[2])*temp.sample_rate)) 172 | texts.append(item[1]) 173 | new_audio_fns.append(item[-2]) 174 | all_text = item[1].split(" ") 175 | start_ind = int(item[-1].split(",")[0]) 176 | text_to_syn.append(" ".join(all_text[start_ind:])) 177 | 178 | for i, (audio_fn, text, prompt_end_frame, new_audio_fn, to_syn) in enumerate(tqdm.tqdm((zip(audio_fns, texts, prompt_end_frames, new_audio_fns, text_to_syn)))): 179 | output_expected_sr = args.codec_audio_sr 180 | concated_audio, gen_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, text, args.device, vars(args), prompt_end_frame) 181 | 182 | # save segments for comparison 183 | concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu() 184 | if output_expected_sr != args.codec_audio_sr: 185 | gen_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(gen_audio) 186 | concated_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(concated_audio) 187 | 188 | seg_save_fn_gen = f"{args.output_dir}/gen_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav" 189 | seg_save_fn_concat = f"{args.output_dir}/concat_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav" 190 | 191 | torchaudio.save(seg_save_fn_gen, gen_audio, args.codec_audio_sr) 192 | torchaudio.save(seg_save_fn_concat, concated_audio, args.codec_audio_sr) 193 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | import pickle 4 | import argparse 5 | import logging 6 | import torch.distributed as dist 7 | from config import MyParser 8 | from steps import trainer 9 | 10 | 11 | if __name__ == "__main__": 12 | formatter = ( 13 | "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" 14 | ) 15 | logging.basicConfig(format=formatter, level=logging.INFO) 16 | 17 | torch.cuda.empty_cache() 18 | args = MyParser().parse_args() 19 | logging.info(args) 20 | exp_dir = Path(args.exp_dir) 21 | exp_dir.mkdir(exist_ok=True, parents=True) 22 | logging.info(f"exp_dir: {str(exp_dir)}") 23 | 24 | if args.resume: 25 | resume = args.resume 26 | assert(bool(args.exp_dir)) 27 | with open("%s/args.pkl" % args.exp_dir, "rb") as f: 28 | old_args = pickle.load(f) 29 | new_args = vars(args) 30 | old_args = vars(old_args) 31 | for key in new_args: 32 | if key not in old_args or old_args[key] != new_args[key]: 33 | old_args[key] = new_args[key] 34 | args = argparse.Namespace(**old_args) 35 | args.resume = resume 36 | else: 37 | with open("%s/args.pkl" % args.exp_dir, "wb") as f: 38 | pickle.dump(args, f) 39 | 40 | dist.init_process_group(backend='nccl', init_method='env://') 41 | rank = dist.get_rank() 42 | world_size = dist.get_world_size() 43 | torch.cuda.set_device(rank) 44 | my_trainer = trainer.Trainer(args, 
world_size, rank) 45 | my_trainer.train() -------------------------------------------------------------------------------- /models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonppy/VoiceCraft/a702dfd2ced6d4fd6b04bdc160c832c6efc8f6c5/models/modules/__init__.py -------------------------------------------------------------------------------- /models/modules/embedding.py: -------------------------------------------------------------------------------- 1 | # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import math 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | 22 | class TokenEmbedding(nn.Module): 23 | def __init__( 24 | self, 25 | dim_model: int, 26 | vocab_size: int, 27 | dropout: float = 0.0, 28 | ): 29 | super().__init__() 30 | 31 | self.vocab_size = vocab_size 32 | self.dim_model = dim_model 33 | 34 | self.dropout = torch.nn.Dropout(p=dropout) 35 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) 36 | 37 | @property 38 | def weight(self) -> torch.Tensor: 39 | return self.word_embeddings.weight 40 | 41 | def embedding(self, index: int) -> torch.Tensor: 42 | return self.word_embeddings.weight[index : index + 1] 43 | 44 | def forward(self, x: torch.Tensor): 45 | X = self.word_embeddings(x) 46 | X = self.dropout(X) 47 | 48 | return X 49 | 50 | 51 | class SinePositionalEmbedding(nn.Module): 52 | def __init__( 53 | self, 54 | dim_model: int, 55 | dropout: float = 0.0, 56 | scale: bool = False, 57 | alpha: bool = False, 58 | ): 59 | super().__init__() 60 | self.dim_model = dim_model 61 | self.x_scale = math.sqrt(dim_model) if scale else 1.0 62 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 63 | self.dropout = torch.nn.Dropout(p=dropout) 64 | 65 | self.reverse = False 66 | self.pe = None 67 | self.extend_pe(torch.tensor(0.0).expand(1, 4000)) 68 | 69 | def extend_pe(self, x): 70 | """Reset the positional encodings.""" 71 | if self.pe is not None: 72 | if self.pe.size(1) >= x.size(1): 73 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 74 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 75 | return 76 | pe = torch.zeros(x.size(1), self.dim_model) 77 | if self.reverse: 78 | position = torch.arange( 79 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 80 | ).unsqueeze(1) 81 | else: 82 | position = torch.arange( 83 | 0, x.size(1), dtype=torch.float32 84 | ).unsqueeze(1) 85 | div_term = torch.exp( 86 | torch.arange(0, self.dim_model, 2, dtype=torch.float32) 87 | * -(math.log(10000.0) / self.dim_model) 88 | ) 89 | pe[:, 0::2] = torch.sin(position * div_term) 90 | pe[:, 1::2] = torch.cos(position * div_term) 91 | pe = pe.unsqueeze(0) 92 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach() 93 | 94 | def forward(self, x: torch.Tensor) -> torch.Tensor: 95 | self.extend_pe(x) 
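# The sinusoidal table self.pe was (re)built above by extend_pe to cover x's sequence length.
# The lines below compute x * x_scale + alpha * pe[:, :T] followed by dropout, where x_scale is
# sqrt(dim_model) when scale=True (else 1.0) and alpha is a scalar nn.Parameter that is only
# trainable when the module was constructed with alpha=True.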
96 | output = x.unsqueeze(-1) if x.ndim == 2 else x 97 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] 98 | return self.dropout(output) -------------------------------------------------------------------------------- /models/modules/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def top_k_top_p_filtering( 5 | logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 6 | ): 7 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 8 | Args: 9 | logits: logits distribution shape (batch size, vocabulary size) 10 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 11 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 12 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 13 | Make sure we keep at least min_tokens_to_keep per batch example in the output 14 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 15 | """ 16 | if top_k > 0: 17 | top_k = min( 18 | max(top_k, min_tokens_to_keep), logits.size(-1) 19 | ) # Safety check 20 | # Remove all tokens with a probability less than the last token of the top-k 21 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 22 | logits[indices_to_remove] = filter_value 23 | 24 | if top_p < 1.0: 25 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 26 | cumulative_probs = torch.cumsum( 27 | F.softmax(sorted_logits, dim=-1), dim=-1 28 | ) 29 | 30 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept) 31 | sorted_indices_to_remove = cumulative_probs > top_p 32 | if min_tokens_to_keep > 1: 33 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 34 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 35 | # Shift the indices to the right to keep also the first token above the threshold 36 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ 37 | ..., :-1 38 | ].clone() 39 | sorted_indices_to_remove[..., 0] = 0 40 | 41 | # scatter sorted tensors to original indexing 42 | indices_to_remove = sorted_indices_to_remove.scatter( 43 | 1, sorted_indices, sorted_indices_to_remove 44 | ) 45 | logits[indices_to_remove] = filter_value 46 | return logits 47 | 48 | def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): 49 | # temperature: (`optional`) float 50 | # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. 51 | # top_k: (`optional`) int 52 | # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. 53 | # top_p: (`optional`) float 54 | # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. 
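# Illustrative usage (added sketch; assumes a [batch, vocab] logits tensor):
#   logits = torch.randn(1, 2048)
#   token = topk_sampling(logits, top_k=40, top_p=1.0, temperature=1.0)  # LongTensor of shape [1, 1]
# i.e. temperature rescales the logits, top_k/top_p prune the distribution, and one token id is drawn per row.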
55 | 56 | # Temperature (higher temperature => more likely to sample low probability tokens) 57 | if temperature != 1.0: 58 | logits = logits / temperature 59 | # Top-p/top-k filtering 60 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 61 | # Sample 62 | token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) 63 | return token -------------------------------------------------------------------------------- /models/modules/utils.py: -------------------------------------------------------------------------------- 1 | # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng 2 | import torch 3 | 4 | 5 | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: 6 | """ 7 | Args: 8 | lengths: 9 | A 1-D tensor containing sentence lengths. 10 | max_len: 11 | The length of masks. 12 | Returns: 13 | Return a 2-D bool tensor, where masked positions 14 | are filled with `True` and non-masked positions are 15 | filled with `False`. 16 | 17 | >>> lengths = torch.tensor([1, 3, 2, 5]) 18 | >>> make_pad_mask(lengths) 19 | tensor([[False, True, True, True, True], 20 | [False, False, False, True, True], 21 | [False, False, True, True, True], 22 | [False, False, False, False, False]]) 23 | """ 24 | assert lengths.ndim == 1, lengths.ndim 25 | max_len = max(max_len, lengths.max()) 26 | n = lengths.size(0) 27 | seq_range = torch.arange(0, max_len, device=lengths.device) 28 | expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) 29 | 30 | return expaned_lengths >= lengths.unsqueeze(-1) 31 | 32 | def generate_partial_autoregressive_mask(sz, start, end): 33 | mask = torch.zeros(sz, sz).bool() 34 | mask[start:end, start:end] = torch.triu(torch.ones(end-start, end-start,dtype=torch.bool), diagonal=1) 35 | mask[:start, start:end] = True 36 | mask[end:, start:end] = True 37 | return mask 38 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | 4 | import os 5 | import time 6 | import random 7 | import getpass 8 | import shutil 9 | import subprocess 10 | import torch 11 | import numpy as np 12 | import torchaudio 13 | from cog import BasePredictor, Input, Path, BaseModel 14 | 15 | os.environ["USER"] = getpass.getuser() 16 | 17 | from data.tokenizer import ( 18 | AudioTokenizer, 19 | TextTokenizer, 20 | ) 21 | from models import voicecraft 22 | from inference_tts_scale import inference_one_sample 23 | from edit_utils import get_span 24 | from inference_speech_editing_scale import ( 25 | inference_one_sample as inference_one_sample_editing, 26 | ) 27 | 28 | 29 | MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft-models.tar" # all the models are cached and uploaded to replicate.delivery for faster booting 30 | MODEL_CACHE = "model_cache" 31 | 32 | 33 | class ModelOutput(BaseModel): 34 | whisper_transcript_orig_audio: str 35 | generated_audio: Path 36 | 37 | 38 | class WhisperxAlignModel: 39 | def __init__(self): 40 | from whisperx import load_align_model 41 | 42 | self.model, self.metadata = load_align_model( 43 | language_code="en", device="cuda:0" 44 | ) 45 | 46 | def align(self, segments, audio_path): 47 | from whisperx import align, load_audio 48 | 49 | audio = load_audio(audio_path) 50 | return align( 51 | segments, 52 | self.model, 53 | self.metadata, 54 | audio, 55 
| device="cuda:0", 56 | return_char_alignments=False, 57 | )["segments"] 58 | 59 | 60 | class WhisperxModel: 61 | def __init__(self, model_name, align_model: WhisperxAlignModel, device="cuda"): 62 | from whisperx import load_model 63 | 64 | # the model weights are cached from Systran/faster-whisper-base.en etc 65 | self.model = load_model( 66 | model_name, 67 | device, 68 | asr_options={ 69 | "suppress_numerals": True, 70 | "max_new_tokens": None, 71 | "clip_timestamps": None, 72 | "hallucination_silence_threshold": None, 73 | }, 74 | ) 75 | self.align_model = align_model 76 | 77 | def transcribe(self, audio_path): 78 | segments = self.model.transcribe(audio_path, language="en", batch_size=8)[ 79 | "segments" 80 | ] 81 | return self.align_model.align(segments, audio_path) 82 | 83 | 84 | def download_weights(url, dest): 85 | start = time.time() 86 | print("downloading url: ", url) 87 | print("downloading to: ", dest) 88 | subprocess.check_call(["pget", "-x", url, dest], close_fds=False) 89 | print("downloading took: ", time.time() - start) 90 | 91 | 92 | class Predictor(BasePredictor): 93 | def setup(self): 94 | """Load the model into memory to make running multiple predictions efficient""" 95 | self.device = "cuda" 96 | 97 | if not os.path.exists(MODEL_CACHE): 98 | download_weights(MODEL_URL, MODEL_CACHE) 99 | 100 | encodec_fn = f"{MODEL_CACHE}/encodec_4cb2048_giga.th" 101 | self.models, self.ckpt, self.phn2num = {}, {}, {} 102 | for voicecraft_name in [ 103 | "giga830M.pth", 104 | "giga330M.pth", 105 | "gigaHalfLibri330M_TTSEnhanced_max16s.pth", 106 | ]: 107 | ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}" 108 | 109 | self.ckpt[voicecraft_name] = torch.load(ckpt_fn, map_location="cpu") 110 | self.models[voicecraft_name] = voicecraft.VoiceCraft( 111 | self.ckpt[voicecraft_name]["config"] 112 | ) 113 | self.models[voicecraft_name].load_state_dict( 114 | self.ckpt[voicecraft_name]["model"] 115 | ) 116 | self.models[voicecraft_name].to(self.device) 117 | self.models[voicecraft_name].eval() 118 | 119 | self.phn2num[voicecraft_name] = self.ckpt[voicecraft_name]["phn2num"] 120 | 121 | self.text_tokenizer = TextTokenizer(backend="espeak") 122 | self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device) 123 | 124 | align_model = WhisperxAlignModel() 125 | self.transcribe_models = { 126 | k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model) 127 | for k in ["base.en", "small.en", "medium.en"] 128 | } 129 | 130 | def predict( 131 | self, 132 | task: str = Input( 133 | description="Choose a task", 134 | choices=[ 135 | "speech_editing-substitution", 136 | "speech_editing-insertion", 137 | "speech_editing-deletion", 138 | "zero-shot text-to-speech", 139 | ], 140 | default="zero-shot text-to-speech", 141 | ), 142 | voicecraft_model: str = Input( 143 | description="Choose a model", 144 | choices=["giga830M.pth", "giga330M.pth", "giga330M_TTSEnhanced.pth"], 145 | default="giga330M_TTSEnhanced.pth", 146 | ), 147 | orig_audio: Path = Input(description="Original audio file"), 148 | orig_transcript: str = Input( 149 | description="Optionally provide the transcript of the input audio. Leave it blank to use the WhisperX model below to generate the transcript. Inaccurate transcription may lead to error TTS or speech editing", 150 | default="", 151 | ), 152 | whisperx_model: str = Input( 153 | description="If orig_transcript is not provided above, choose a WhisperX model for generating the transcript. Inaccurate transcription may lead to error TTS or speech editing. 
You can modify the generated transcript and provide it directly to orig_transcript above", 154 | choices=[ 155 | "base.en", 156 | "small.en", 157 | "medium.en", 158 | ], 159 | default="base.en", 160 | ), 161 | target_transcript: str = Input( 162 | description="Transcript of the target audio file", 163 | ), 164 | cut_off_sec: float = Input( 165 | description="Only used for the zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech. 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec", 166 | default=3.01, 167 | ), 168 | kvcache: int = Input( 169 | description="Set to 0 to use less VRAM, but with slower inference", 170 | choices=[0, 1], 171 | default=1, 172 | ), 173 | left_margin: float = Input( 174 | description="Margin to the left of the editing segment", 175 | default=0.08, 176 | ), 177 | right_margin: float = Input( 178 | description="Margin to the right of the editing segment", 179 | default=0.08, 180 | ), 181 | temperature: float = Input( 182 | description="Adjusts randomness of outputs; greater than 1 is more random and 0 is deterministic. Changing it is not recommended", 183 | default=1, 184 | ), 185 | top_p: float = Input( 186 | description="Default value for TTS is 0.9, and 0.8 for speech editing", 187 | default=1, 188 | ), 189 | stop_repetition: int = Input( 190 | default=3, 191 | description="Default value for TTS is 3, and -1 for speech editing. -1 means do not adjust prob of silence tokens. If there are long silences or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4", 192 | ), 193 | sample_batch_size: int = Input( 194 | description="Default value for TTS is 4, and 1 for speech editing. The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one", 195 | default=4, 196 | ), 197 | seed: int = Input( 198 | description="Random seed. 
Leave blank to randomize the seed", default=None 199 | ), 200 | ) -> ModelOutput: 201 | """Run a single prediction on the model""" 202 | 203 | if seed is None: 204 | seed = int.from_bytes(os.urandom(2), "big") 205 | print(f"Using seed: {seed}") 206 | 207 | seed_everything(seed) 208 | 209 | segments = self.transcribe_models[whisperx_model].transcribe( 210 | str(orig_audio) 211 | ) 212 | 213 | state = get_transcribe_state(segments) 214 | 215 | whisper_transcript = state["transcript"].strip() 216 | 217 | if len(orig_transcript.strip()) == 0: 218 | orig_transcript = whisper_transcript 219 | 220 | print(f"The transcript from the Whisper model: {whisper_transcript}") 221 | 222 | temp_folder = "exp_dir" 223 | if os.path.exists(temp_folder): 224 | shutil.rmtree(temp_folder) 225 | 226 | os.makedirs(temp_folder) 227 | 228 | filename = "orig_audio" 229 | audio_fn = str(orig_audio) 230 | 231 | info = torchaudio.info(audio_fn) 232 | audio_dur = info.num_frames / info.sample_rate 233 | 234 | # hyperparameters for inference 235 | codec_audio_sr = 16000 236 | codec_sr = 50 237 | top_k = 40 238 | silence_tokens = [1388, 1898, 131] 239 | 240 | if voicecraft_model == "giga330M_TTSEnhanced.pth": 241 | voicecraft_model = "gigaHalfLibri330M_TTSEnhanced_max16s.pth" 242 | 243 | if task == "zero-shot text-to-speech": 244 | assert ( 245 | cut_off_sec < audio_dur 246 | ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}" 247 | prompt_end_frame = int(cut_off_sec * info.sample_rate) 248 | 249 | idx = find_closest_cut_off_word(state["word_bounds"], cut_off_sec) 250 | orig_transcript_until_cutoff_time = " ".join( 251 | [word_bound["word"] for word_bound in state["word_bounds"][: idx + 1]] 252 | ) 253 | else: 254 | edit_type = task.split("-")[-1] 255 | orig_span, new_span = get_span( 256 | orig_transcript, target_transcript, edit_type 257 | ) 258 | if orig_span[0] > orig_span[1]: 259 | raise RuntimeError(f"example {audio_fn} failed") 260 | if orig_span[0] == orig_span[1]: 261 | orig_span_save = [orig_span[0]] 262 | else: 263 | orig_span_save = orig_span 264 | if new_span[0] == new_span[1]: 265 | new_span_save = [new_span[0]] 266 | else: 267 | new_span_save = new_span 268 | orig_span_save = ",".join([str(item) for item in orig_span_save]) 269 | new_span_save = ",".join([str(item) for item in new_span_save]) 270 | 271 | start, end = get_mask_interval_from_word_bounds( 272 | state["word_bounds"], orig_span_save, edit_type 273 | ) 274 | 275 | # span in codec frames 276 | morphed_span = ( 277 | max(start - left_margin, 1 / codec_sr), 278 | min(end + right_margin, audio_dur), 279 | ) # in seconds 280 | mask_interval = [ 281 | [round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)] 282 | ] 283 | mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now 284 | 285 | decode_config = { 286 | "top_k": top_k, 287 | "top_p": top_p, 288 | "temperature": temperature, 289 | "stop_repetition": stop_repetition, 290 | "kvcache": kvcache, 291 | "codec_audio_sr": codec_audio_sr, 292 | "codec_sr": codec_sr, 293 | "silence_tokens": silence_tokens, 294 | } 295 | 296 | if task == "zero-shot text-to-speech": 297 | decode_config["sample_batch_size"] = sample_batch_size 298 | _, gen_audio = inference_one_sample( 299 | self.models[voicecraft_model], 300 | self.ckpt[voicecraft_model]["config"], 301 | self.phn2num[voicecraft_model], 302 | self.text_tokenizer, 303 | self.audio_tokenizer, 304 | audio_fn, 305 | orig_transcript_until_cutoff_time.strip() 306 | + " " 307 | + target_transcript.strip(), 308 | 
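# For zero-shot TTS, the text handed to the model is the original transcript up to the
# cut-off word concatenated with target_transcript, so the generation continues in the prompt speaker's voice.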
self.device, 309 | decode_config, 310 | prompt_end_frame, 311 | ) 312 | else: 313 | _, gen_audio = inference_one_sample_editing( 314 | self.models[voicecraft_model], 315 | self.ckpt[voicecraft_model]["config"], 316 | self.phn2num[voicecraft_model], 317 | self.text_tokenizer, 318 | self.audio_tokenizer, 319 | audio_fn, 320 | target_transcript, 321 | mask_interval, 322 | self.device, 323 | decode_config, 324 | ) 325 | 326 | # save segments for comparison 327 | gen_audio = gen_audio[0].cpu() 328 | 329 | out = "/tmp/out.wav" 330 | torchaudio.save(out, gen_audio, codec_audio_sr) 331 | return ModelOutput( 332 | generated_audio=Path(out), whisper_transcript_orig_audio=whisper_transcript 333 | ) 334 | 335 | 336 | def seed_everything(seed): 337 | os.environ["PYTHONHASHSEED"] = str(seed) 338 | random.seed(seed) 339 | np.random.seed(seed) 340 | torch.manual_seed(seed) 341 | torch.cuda.manual_seed(seed) 342 | torch.backends.cudnn.benchmark = False 343 | torch.backends.cudnn.deterministic = True 344 | 345 | 346 | def get_transcribe_state(segments): 347 | words_info = [word_info for segment in segments for word_info in segment["words"]] 348 | return { 349 | "transcript": " ".join([segment["text"].strip() for segment in segments]), 350 | "word_bounds": [ 351 | {"word": word["word"], "start": word["start"], "end": word["end"]} 352 | for word in words_info 353 | ], 354 | } 355 | 356 | 357 | def find_closest_cut_off_word(word_bounds, cut_off_sec): 358 | min_distance = float("inf") 359 | 360 | for i, word_bound in enumerate(word_bounds): 361 | distance = abs(word_bound["start"] - cut_off_sec) 362 | 363 | if distance < min_distance: 364 | min_distance = distance 365 | 366 | if word_bound["end"] > cut_off_sec: 367 | break 368 | 369 | return i 370 | 371 | 372 | def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType): 373 | tmp = word_span_ind.split(",") 374 | s, e = int(tmp[0]), int(tmp[-1]) 375 | start = None 376 | for j, item in enumerate(word_bounds): 377 | if j == s: 378 | if editType == "insertion": 379 | start = float(item["end"]) 380 | else: 381 | start = float(item["start"]) 382 | if j == e: 383 | if editType == "insertion": 384 | end = float(item["start"]) 385 | else: 386 | end = float(item["end"]) 387 | assert start is not None 388 | break 389 | return (start, end) 390 | -------------------------------------------------------------------------------- /pretrained_models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonppy/VoiceCraft/a702dfd2ced6d4fd6b04bdc160c832c6efc8f6c5/pretrained_models/.gitkeep -------------------------------------------------------------------------------- /start-jupyter.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | echo Creating and running the Jupyter container... 4 | 5 | docker run -it -d ^ 6 | --gpus all ^ 7 | -p 8888:8888 ^ 8 | -p 7860:7860 ^ 9 | --name jupyter ^ 10 | --user root ^ 11 | -e NB_USER="%username%" ^ 12 | -e CHOWN_HOME=yes ^ 13 | -e GRANT_SUDO=yes ^ 14 | -e JUPYTER_TOKEN=mytoken ^ 15 | -w "/home/%username%" ^ 16 | -v "%cd%":"/home/%username%/work" ^ 17 | voicecraft 18 | 19 | if %errorlevel% == 0 ( 20 | echo Jupyter container created and running. 21 | 22 | echo Jupyter container is running. 23 | echo To access the Jupyter web UI, please follow these steps: 24 | echo 1. Open your web browser 25 | echo 2. Navigate to http://localhost:8888/?token=mytoken 26 | echo 3. !! 
The default token is "mytoken" and should be changed. !! 27 | pause 28 | ) else ( 29 | echo Failed to create and run the Jupyter container. 30 | ) 31 | -------------------------------------------------------------------------------- /start-jupyter.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ## Assumes you have docker installed with nvidia container container-toolkit 3 | # https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/1.13.5/install-guide.html 4 | # sudo apt-get install -y nvidia-container-toolkit-base || yay -Syu nvidia-container-toolkit || echo etc... 5 | ## Try to start an existing container otherwise create a new one 6 | docker start jupyter 2> /dev/null || \ 7 | docker run -it \ 8 | -d \ 9 | --gpus all \ 10 | -p 8888:8888 \ 11 | -p 7860:7860 \ 12 | --name jupyter \ 13 | --user root \ 14 | -e NB_USER="$USER" \ 15 | -e CHOWN_HOME=yes \ 16 | -e GRANT_SUDO=yes \ 17 | -w "/home/${NB_USER}" \ 18 | -v "$PWD":"/home/$USER/work" \ 19 | voicecraft 20 | 21 | ## `docker logs jupyter` to get the URL link and token e.g. 22 | ## http://127.0.0.1:8888/lab?token=blahblahblahblabhlaabhalbhalbhal 23 | -------------------------------------------------------------------------------- /steps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonppy/VoiceCraft/a702dfd2ced6d4fd6b04bdc160c832c6efc8f6c5/steps/__init__.py -------------------------------------------------------------------------------- /steps/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os, random 3 | import torch 4 | import math, pickle 5 | from tqdm import tqdm 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | import torch.nn as nn 9 | import torch.distributed as dist 10 | from torch.utils.tensorboard import SummaryWriter 11 | import numpy as np 12 | from torch.utils.data.distributed import DistributedSampler 13 | import logging 14 | from data import gigaspeech 15 | from models import voicecraft 16 | 17 | from .trainer_utils import DistributedDynamicBatchSampler, StatefulDistributedSampler, AverageMeter, print_model_info 18 | from .optim import ScaledAdam, Eden 19 | 20 | 21 | class Trainer: 22 | 23 | def __init__(self, args, world_size, rank): 24 | self.start_time = time.time() 25 | self.args = args 26 | self.world_size, self.rank = world_size, rank 27 | self.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu") 28 | if self.rank == 0: 29 | self.writer = SummaryWriter(args.exp_dir) 30 | self.seed_everything(seed=self.args.seed) 31 | self.meters = self._setup_meters() 32 | 33 | self.progress, self.total_progress = self._setup_progress() 34 | 35 | self.model, self.trainables, self.optim_states, self.scheduler_states = self._setup_models() 36 | 37 | self.train_dataset_length, self.train_sampler, self.train_loader, self.valid_loader = self._setup_dataloader() 38 | if self.args.num_steps != None: 39 | self.total_step = self.args.num_steps 40 | self.args.num_epochs = math.ceil(self.total_step / math.floor(self.train_dataset_length / self.args.batch_size)) if not self.args.dynamic_batching else None 41 | else: 42 | self.total_step = int(math.floor(self.train_dataset_length / self.args.batch_size))*self.args.num_epochs 43 | 44 | self.optimizer, self.scheduler = self._setup_optimizer() 45 | self.scaler = torch.cuda.amp.GradScaler() 46 | self.model = 
torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank], find_unused_parameters=False) 47 | 48 | if self.rank == 0: 49 | self.early_stop_accu_steps = 0 50 | if self.args.dynamic_batching: 51 | logging.info(f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in a inference batch: {self.args.val_max_num_tokens}") 52 | else: 53 | logging.info(f"batch size (summed over all GPUs): {self.args.batch_size}") 54 | 55 | def train(self): 56 | flag = True 57 | skip_flag = False 58 | data_start_time = time.time() 59 | while flag: 60 | self.train_sampler.set_epoch(self.progress['epoch']) 61 | for i, batch in enumerate(self.train_loader): 62 | data_end_time = time.time() 63 | self.model.train() 64 | if self.progress['step'] > self.total_step: 65 | flag = False 66 | self.validate_and_save() 67 | if self.rank == 0: 68 | self.writer.close() 69 | break 70 | if isinstance(self.scheduler, Eden): 71 | self.scheduler.step_epoch(self.progress['step']//self.args.pseudo_epoch_size + 1) 72 | if self.args.optimizer_name == "ScaledAdam": 73 | cur_lr = self.scheduler.get_last_lr()[0] 74 | else: 75 | lrs = [param_group['lr'] for param_group in self.optimizer.param_groups] 76 | assert lrs[0] == lrs[1] 77 | cur_lr = lrs[0] 78 | 79 | if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0: 80 | self.writer.add_scalar("train/lr", cur_lr, self.progress['step']) 81 | 82 | all_inds = list(range(len(batch['y']))) 83 | sum_losses = 0 84 | sum_top10acc = 0 85 | sum_ntoken = 0 86 | sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)] 87 | for j in range(self.args.gradient_accumulation_steps): 88 | cur_ind = all_inds[j::self.args.gradient_accumulation_steps] 89 | cur_batch = {key: batch[key][cur_ind] for key in batch} 90 | with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision=="float16" else torch.float32): 91 | out = self.model(cur_batch) 92 | if out == None: 93 | continue 94 | 95 | record_loss = out['loss'].detach().to(self.rank) 96 | top10acc = out['top10acc'].to(self.rank) 97 | effective_ntoken = out['effective_ntoken'].to(self.rank) 98 | is_nan = torch.tensor(int(torch.isnan(record_loss).any()), dtype=torch.float32, device=self.rank) 99 | 100 | dist.all_reduce(record_loss, op=dist.ReduceOp.SUM) 101 | dist.all_reduce(top10acc, op=dist.ReduceOp.SUM) 102 | dist.all_reduce(effective_ntoken, op=dist.ReduceOp.SUM) 103 | dist.all_reduce(is_nan, op=dist.ReduceOp.SUM) 104 | 105 | # check if loss is nan 106 | if is_nan.item() > 0: 107 | logging.info(f"loss at step {self.progress['step']} is nan, therefore skip this batch") 108 | skip_flag = True 109 | continue 110 | 111 | sum_losses += record_loss.item() 112 | sum_top10acc += top10acc.item() 113 | sum_ntoken += effective_ntoken.item() 114 | 115 | if 'top10acc_by_codebook' in out: 116 | for cb in range(self.args.n_codebooks): 117 | top10acc_cbi = out['top10acc_by_codebook'][cb] 118 | dist.all_reduce(top10acc_cbi, op=dist.ReduceOp.SUM) 119 | sum_top10acc_cbi[cb] += top10acc_cbi.item() 120 | 121 | if self.rank == 0: 122 | average_loss = sum_losses / sum_ntoken 123 | average_top10acc = sum_top10acc / sum_ntoken 124 | self.meters['train_loss'].update(average_loss, batch['x'].shape[0]*self.world_size) 125 | self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size) 126 | self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size) 127 | average_top10acc_cbi = [sum_top10acc_cbi[cb] / sum_ntoken * 
self.args.n_codebooks for cb in range(self.args.n_codebooks)] 128 | for cb in range(self.args.n_codebooks): 129 | self.meters[f'train_top10acc_cb{cb+1}'].update(average_top10acc_cbi[cb], batch['x'].shape[0]*self.world_size) 130 | 131 | if self.progress['step'] % self.args.tb_write_every_n_steps == 0: 132 | self.writer.add_scalar('train/loss', average_loss, self.progress['step']) 133 | self.writer.add_scalar('train/top10acc', average_top10acc, self.progress['step']) 134 | self.writer.add_scalar("train/ntokens", sum_ntoken, self.progress['step']) 135 | for cb in range(self.args.n_codebooks): 136 | self.writer.add_scalar(f'train/top10acc_cb{cb+1}', average_top10acc_cbi[cb], self.progress['step']) 137 | 138 | if self.args.optimizer_name == "ScaledAdam": 139 | self.scaler.scale(out['loss']).backward() 140 | else: 141 | self.scaler.scale(out['loss']/out['effective_ntoken']).backward() 142 | 143 | if skip_flag: 144 | self.optimizer.zero_grad() 145 | skip_flag = False 146 | continue 147 | 148 | if self.args.optimizer_name != "ScaledAdam": 149 | self.scaler.unscale_(self.optimizer) 150 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip_val) 151 | self.scaler.step(self.optimizer) 152 | self.scaler.update() 153 | 154 | self.optimizer.zero_grad() 155 | 156 | if self.args.optimizer_name == "ScaledAdam": 157 | self.scheduler.step_batch(self.progress['step']) 158 | else: 159 | self.scheduler.step() 160 | 161 | if self.rank == 0: 162 | self.meters['data_time'].update(data_end_time - data_start_time) 163 | self.meters['train_time'].update(time.time() - data_end_time) 164 | if self.progress['step'] % self.args.tb_write_every_n_steps == 0: 165 | self.writer.add_scalar("train/data_time", data_end_time - data_start_time, self.progress['step']) 166 | self.writer.add_scalar("train/train_time", time.time() - data_end_time, self.progress['step']) 167 | 168 | 169 | # logging 170 | if self.progress['step'] % self.args.print_every_n_steps == 0: 171 | log_out = {} 172 | log_out['cur_epoch'] = f"{self.progress['epoch']}/{self.args.num_epochs}" if self.args.num_epochs is not None else f"{self.progress['epoch']}" 173 | log_out['cur_step'] = f"{int(self.progress['cur_step']+1)}" 174 | log_out['total_step'] = f"{self.progress['step']}/{self.args.num_steps}" 175 | log_out['lr'] = f"{cur_lr:.7f}" 176 | log_out['ntokens'] = f"{sum_ntoken}" 177 | for key in self.meters: 178 | if self.meters[key].val != 0 or self.meters[key].avg != 0: 179 | log_out[key] = f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" if isinstance(self.meters[key].val, float) else f"{self.meters[key].val}" 180 | logging.info(log_out) 181 | if np.isnan(self.meters['train_loss'].avg): 182 | logging.warning("training diverged...") 183 | raise RuntimeError("training diverged...") 184 | 185 | # validation and save models 186 | if self.progress['step'] % self.args.val_every_n_steps == 0: 187 | dist.barrier() 188 | self.validate_and_save() 189 | 190 | self.progress['step'] += 1 191 | self.progress['cur_step'] += 1 192 | 193 | data_start_time = time.time() 194 | self.progress['epoch'] += 1 195 | self.progress['cur_step'] = 0 # reset cur_step to be 0 196 | dist.destroy_process_group() 197 | 198 | def validate_and_save(self): 199 | self.model.eval() 200 | 201 | score = self.validate(self.valid_loader) 202 | 203 | if self.rank == 0: 204 | if self.args.early_stop_threshold > 0: 205 | if self.progress['best_score'] - score < self.args.early_stop_threshold: 206 | self.early_stop_accu_steps += self.args.val_every_n_steps 207 | if 
self.early_stop_accu_steps >= self.args.early_stop_step-1: 208 | logging.info(f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}") 209 | logging.info(f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}") 210 | dist.destroy_process_group() 211 | raise RuntimeError("early stop") 212 | else: 213 | self.early_stop_accu_steps = 0 214 | 215 | if (score < self.progress['best_score']): 216 | self.progress['best_step'] = self.progress['step'] 217 | self.progress['best_score'] = score 218 | save_path = os.path.join(self.args.exp_dir,"best_bundle.pth") 219 | torch.save( 220 | { 221 | "model": self.model.module.state_dict(), 222 | "optimizer": self.optimizer.state_dict(), 223 | "scheduler": self.scheduler.state_dict(), 224 | "config": self.args, 225 | "phn2num": self.train_loader.dataset.phn2num 226 | },save_path 227 | ) 228 | logging.info(f"save *best* models at {save_path} at global step {self.progress['step']}") 229 | self._save_progress() 230 | save_path = os.path.join(self.args.exp_dir,"bundle.pth") 231 | torch.save( 232 | { 233 | "model": self.model.module.state_dict(), 234 | "optimizer": self.optimizer.state_dict(), 235 | "scheduler": self.scheduler.state_dict(), 236 | "config": self.args, 237 | "phn2num": self.train_loader.dataset.phn2num 238 | },save_path 239 | ) 240 | logging.info(f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}") 241 | 242 | dist.barrier() 243 | 244 | def validate(self, valid_loader=None, hide_progress=True): 245 | if valid_loader == None: 246 | valid_loader = self.valid_loader 247 | self.model.eval() 248 | 249 | start_val_time = time.time() 250 | sum_losses = 0 251 | sum_top10acc = 0 252 | sum_ntoken = 0 253 | sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)] 254 | 255 | with torch.no_grad(): 256 | for i, batch in enumerate(tqdm(valid_loader, disable=hide_progress)): 257 | out = self.model(batch) 258 | sum_losses += out['loss'] 259 | sum_top10acc += out['top10acc'] 260 | sum_ntoken += out['effective_ntoken'] 261 | if 'top10acc_by_codebook' in out: 262 | for cb in range(self.args.n_codebooks): 263 | sum_top10acc_cbi[cb] += out['top10acc_by_codebook'][cb] 264 | 265 | dist.all_reduce(sum_losses, op=dist.ReduceOp.SUM) 266 | dist.all_reduce(sum_top10acc, op=dist.ReduceOp.SUM) 267 | dist.all_reduce(sum_ntoken, op=dist.ReduceOp.SUM) 268 | 269 | if 'top10acc_by_codebook' in out: 270 | for cb in range(self.args.n_codebooks): 271 | dist.all_reduce(sum_top10acc_cbi[cb], op=dist.ReduceOp.SUM) 272 | 273 | if self.rank == 0: 274 | val_loss = sum_losses / sum_ntoken 275 | val_top10acc = sum_top10acc / sum_ntoken 276 | # logging 277 | self.meters['val_loss'].update(val_loss) 278 | logging.info(f"val loss: {val_loss:.5f}") 279 | self.writer.add_scalar("val/loss", val_loss, self.progress['step']) 280 | 281 | self.meters['val_top10acc'].update(val_top10acc) 282 | logging.info(f"val top10acc: {val_top10acc:.5f}") 283 | self.writer.add_scalar("val/top10acc", val_top10acc, self.progress['step']) 284 | for cb in range(self.args.n_codebooks): 285 | average_top10acc_cbi = sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks 286 | self.meters[f'val_top10acc_cb{cb+1}'].update(average_top10acc_cbi) 287 | self.writer.add_scalar(f'val/top10acc_cb{cb+1}', average_top10acc_cbi, self.progress['step']) 288 | 289 | logging.info(f"validation takes: 
{time.time() - start_val_time:.2f}s") 290 | logging.info(f"Step [{self.progress['step']}/{self.total_step}]\t Time elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}") 291 | return val_loss.item() 292 | else: 293 | return None 294 | 295 | def _setup_meters(self): 296 | meters = {} 297 | meter_names = ['train_loss', 'val_loss', 'train_top10acc', 'val_top10acc', 'data_time', 'train_time'] 298 | meter_names += ['train_dur_loss', 'train_dur_acc', 'val_dur_loss', 'val_dur_acc'] 299 | meter_names += [f'train_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)] 300 | meter_names += [f'val_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)] 301 | for name in meter_names: 302 | meters[name] = AverageMeter() 303 | return meters 304 | def _setup_progress(self): 305 | progress = {} 306 | progress['best_step'] = 1 307 | progress['best_score'] = np.inf # this records loss value 308 | progress['step'] = 1 309 | progress['epoch'] = 1 310 | progress['cur_step'] = 0 # step in the current epoch, for resuming the sampler 311 | total_progress = [] 312 | # if self.args.resume or self.args.validate: 313 | if self.args.resume: 314 | progress_pkl = "%s/progress.pkl" % self.args.exp_dir 315 | with open(progress_pkl, "rb") as f: 316 | total_progress = pickle.load(f) 317 | progress['best_step'], progress['best_score'], progress['step'], progress['epoch'], progress['cur_step'], _ = total_progress[-1] 318 | if self.rank == 0: 319 | logging.info("\nResume training from:") 320 | logging.info(" epoch = %s" % progress['epoch']) 321 | logging.info(" cur_step = %s" % progress['cur_step']) 322 | logging.info(" step = %s" % progress['step']) 323 | logging.info(" best_step = %s" % progress['best_step']) 324 | logging.info(" best_score = %s" % progress['best_score']) 325 | return progress, total_progress 326 | 327 | def _save_progress(self): 328 | self.total_progress.append([self.progress['best_step'], self.progress['best_score'], int(self.progress['step']+1), self.progress['epoch'], int(self.progress['cur_step']+1), time.time() - self.start_time]) 329 | with open("%s/progress.pkl" % self.args.exp_dir, "wb") as f: 330 | pickle.dump(self.total_progress, f) 331 | 332 | def _setup_dataloader(self): 333 | assert self.args.dataset == 'gigaspeech', "only gigaspeech is supported for now" 334 | train_dataset, val_dataset = gigaspeech.dataset(self.args, 'train'), gigaspeech.dataset(self.args, 'validation') 335 | 336 | if self.args.dynamic_batching: 337 | train_sampler = DistributedDynamicBatchSampler(train_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=train_dataset.lengths_list, verbose=True, epoch=0) 338 | valid_sampler = DistributedDynamicBatchSampler(val_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=val_dataset.lengths_list, verbose=True, epoch=0) 339 | else: 340 | train_sampler = StatefulDistributedSampler(train_dataset, self.args.batch_size//self.world_size, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True) 341 | valid_sampler = DistributedSampler(val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, seed=self.args.seed, drop_last=False) 342 | 343 | if self.progress['step'] > 1: 344 | train_sampler.set_epoch_resume(self.progress['epoch'], self.progress['cur_step']) 345 | 346 | if self.args.dynamic_batching: 347 | train_loader = 
torch.utils.data.DataLoader(train_dataset, 348 | batch_sampler=train_sampler, 349 | num_workers=self.args.num_workers//self.world_size, 350 | collate_fn=train_dataset.collate, persistent_workers=True 351 | ) 352 | valid_loader = torch.utils.data.DataLoader(val_dataset, 353 | batch_sampler=valid_sampler, 354 | num_workers=self.args.num_workers//self.world_size, 355 | collate_fn=val_dataset.collate, persistent_workers=True 356 | ) 357 | else: 358 | train_loader = torch.utils.data.DataLoader(train_dataset, 359 | batch_size=self.args.batch_size//self.world_size, sampler=train_sampler, num_workers=self.args.num_workers//self.world_size, 360 | collate_fn=train_dataset.collate, persistent_workers=True 361 | ) 362 | valid_loader = torch.utils.data.DataLoader(val_dataset, 363 | batch_size=self.args.batch_size//self.world_size, sampler=valid_sampler, 364 | num_workers=self.args.num_workers//self.world_size, 365 | collate_fn=val_dataset.collate, persistent_workers=True 366 | ) 367 | return len(train_dataset), train_sampler, train_loader, valid_loader 368 | 369 | 370 | 371 | def _setup_models(self): 372 | model = voicecraft.VoiceCraft(self.args) 373 | 374 | if self.rank == 0: 375 | logging.info(model) 376 | logging.info("model parameters") 377 | print_model_info(model) 378 | 379 | if self.progress['step'] > 1: 380 | bundle = torch.load(os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu") 381 | model.load_state_dict(bundle['model']) 382 | optim_states = bundle['optimizer'] 383 | scheduler_states = bundle['scheduler'] 384 | if self.rank == 0: 385 | logging.info("loaded parameters and data indices from epoch %d, global step %d" % (self.progress['epoch'], self.progress['step'])) 386 | del bundle['model'] 387 | else: 388 | optim_states = None 389 | scheduler_states = None 390 | 391 | if self.args.load_model_from != None and self.progress['step'] <= 1: 392 | sd = torch.load(self.args.load_model_from, map_location="cpu")['model'] 393 | model.load_state_dict(sd) 394 | del sd 395 | 396 | if self.args.optimizer_name == "ScaledAdam": 397 | trainables = [p for p in model.parameters() if p.requires_grad] 398 | else: 399 | no_decay = [".bias", ".audio_embeddings.weight", ".text_embeddings.weight", ".norm.weight", ".norm1.weight", ".norm2.weight"] 400 | optimizer_grouped_parameters = [ 401 | { 402 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 403 | "weight_decay": self.args.weight_decay, 404 | }, 405 | { 406 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 407 | "weight_decay": 0.0, 408 | }, 409 | ] 410 | if len(optimizer_grouped_parameters[1]['params']) == 0: 411 | logging.info("there is no embedding weights, bias, and layernorm parameters in the model, which should be True, check model parameter names") 412 | trainables = optimizer_grouped_parameters[0] 413 | else: 414 | trainables = optimizer_grouped_parameters 415 | model.to(self.device) 416 | 417 | return model, trainables, optim_states, scheduler_states 418 | 419 | 420 | def _setup_optimizer(self): 421 | if self.args.optimizer_name == "ScaledAdam": 422 | parameters_names = [] 423 | parameters_names.append([n for n,p in self.model.named_parameters() if p.requires_grad]) 424 | optimizer = ScaledAdam( 425 | self.trainables, 426 | lr=self.args.lr, 427 | betas=(0.9, 0.95), 428 | clipping_scale=2.0, 429 | parameters_names=parameters_names, 430 | show_dominant_parameters=False, 431 | 
clipping_update_period=self.args.clipping_update_period, 432 | ) 433 | scheduler = Eden(optimizer, self.args.reduce_lr_start_step, self.args.reduce_lr_start_epoch, warmup_batches=self.total_step * self.args.warmup_fraction) 434 | 435 | else: 436 | optimizer = AdamW(self.trainables, lr=self.args.lr) 437 | warmup_steps = self.total_step * self.args.warmup_fraction 438 | def lr_lambda(current_step: int): 439 | if current_step < warmup_steps: 440 | return float(current_step) / float(max(1, warmup_steps)) 441 | return max( 442 | 0.0, float(self.total_step - current_step) / float(max(1, self.total_step - warmup_steps)) 443 | ) 444 | 445 | scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1) 446 | 447 | # if resume 448 | if self.progress['step'] > 1: 449 | optimizer.load_state_dict(self.optim_states) 450 | for state in optimizer.state.values(): 451 | for k, v in state.items(): 452 | if isinstance(v, torch.Tensor): 453 | state[k] = v.cuda() 454 | del self.optim_states 455 | 456 | scheduler.load_state_dict(self.scheduler_states) 457 | 458 | optimizer.zero_grad() 459 | return optimizer, scheduler 460 | 461 | def seed_everything(self, seed=1): 462 | os.environ['PYTHONHASHSEED'] = str(seed) 463 | random.seed(seed) 464 | np.random.seed(seed) 465 | torch.manual_seed(seed) 466 | torch.cuda.manual_seed(seed) 467 | torch.backends.cudnn.benchmark = False 468 | torch.backends.cudnn.deterministic = True -------------------------------------------------------------------------------- /tts_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script will allow you to run TTS inference with Voicecraft 3 | Before getting started, be sure to follow the environment setup. 4 | """ 5 | 6 | from inference_tts_scale import inference_one_sample 7 | from models import voicecraft 8 | from data.tokenizer import ( 9 | AudioTokenizer, 10 | TextTokenizer, 11 | ) 12 | import argparse 13 | import random 14 | import numpy as np 15 | import torchaudio 16 | import torch 17 | import os 18 | os.environ["USER"] = "me" # TODO change this to your username 19 | 20 | device = "cuda" if torch.cuda.is_available() else "cpu" 21 | 22 | 23 | def parse_arguments(): 24 | parser = argparse.ArgumentParser( 25 | description="VoiceCraft TTS Inference: see the script for more information on the options") 26 | 27 | parser.add_argument("-m", "--model_name", type=str, default="giga830M", choices=[ 28 | "giga330M", "giga830M", "giga330M_TTSEnhanced", "giga830M_TTSEnhanced"], 29 | help="VoiceCraft model to use") 30 | parser.add_argument("-st", "--silence_tokens", type=int, nargs="*", 31 | default=[1388, 1898, 131], help="Silence token IDs") 32 | parser.add_argument("-casr", "--codec_audio_sr", type=int, 33 | default=16000, help="Codec audio sample rate.") 34 | parser.add_argument("-csr", "--codec_sr", type=int, default=50, 35 | help="Codec sample rate.") 36 | 37 | parser.add_argument("-k", "--top_k", type=float, 38 | default=0, help="Top k value.") 39 | parser.add_argument("-p", "--top_p", type=float, 40 | default=0.8, help="Top p value.") 41 | parser.add_argument("-t", "--temperature", type=float, 42 | default=1, help="Temperature value.") 43 | parser.add_argument("-kv", "--kvcache", type=float, choices=[0, 1], 44 | default=0, help="Kvcache value.") 45 | parser.add_argument("-sr", "--stop_repetition", type=int, 46 | default=-1, help="Stop repetition for generation") 47 | parser.add_argument("--sample_batch_size", type=int, 48 | default=3, help="Batch size for sampling") 49 | 
parser.add_argument("-s", "--seed", type=int, 50 | default=1, help="Seed value.") 51 | parser.add_argument("-bs", "--beam_size", type=int, default=50, 52 | help="beam size for MFA alignment") 53 | parser.add_argument("-rbs", "--retry_beam_size", type=int, default=200, 54 | help="retry beam size for MFA alignment") 55 | parser.add_argument("--output_dir", type=str, default="./generated_tts", 56 | help="directory to save generated audio") 57 | parser.add_argument("-oa", "--original_audio", type=str, 58 | default="./demo/5895_34622_000026_000002.wav", help="location of audio file") 59 | parser.add_argument("-ot", "--original_transcript", type=str, 60 | default="Gwynplaine had, besides, for his work and for his feats of strength, round his neck and over his shoulders, an esclavine of leather.", 61 | help="original transcript") 62 | parser.add_argument("-tt", "--target_transcript", type=str, 63 | default="I cannot believe that the same model can also do text to speech synthesis too!", 64 | help="target transcript") 65 | parser.add_argument("-co", "--cut_off_sec", type=float, default=3.6, 66 | help="cut off point in seconds for input prompt") 67 | parser.add_argument("-ma", "--margin", type=float, default=0.04, 68 | help="margin in seconds between the end of the cutoff words and the start of the next word. If the next word is not immediately following the cutoff word, the algorithm is more tolerant to word alignment errors") 69 | parser.add_argument("-cuttol", "--cutoff_tolerance", type=float, default=1, help="tolerance in seconds for the cutoff time, if given cut_off_sec plus the tolerance, we still are not able to find the next word, we will use the best cutoff time found, i.e. likely no margin or very small margin between the end of the cutoff word and the start of the next word") 70 | 71 | args = parser.parse_args() 72 | return args 73 | 74 | 75 | args = parse_arguments() 76 | voicecraft_name = args.model_name 77 | # hyperparameters for inference 78 | codec_audio_sr = args.codec_audio_sr 79 | codec_sr = args.codec_sr 80 | top_k = args.top_k 81 | top_p = args.top_p # defaults to 0.9 can also try 0.8, but 0.9 seems to work better 82 | temperature = args.temperature 83 | silence_tokens = args.silence_tokens 84 | kvcache = args.kvcache # NOTE if OOM, change this to 0, or try the 330M model 85 | 86 | # NOTE adjust the below three arguments if the generation is not as good 87 | # NOTE if the model generate long silence, reduce the stop_repetition to 3, 2 or even 1 88 | stop_repetition = args.stop_repetition 89 | 90 | # NOTE: if the if there are long silence or unnaturally strecthed words, 91 | # increase sample_batch_size to 4 or higher. What this will do to the model is that the 92 | # model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. 93 | # So if the speech rate of the generated is too fast change it to a smaller number. 
94 | sample_batch_size = args.sample_batch_size 95 | seed = args.seed # change seed if you are still unhappy with the result 96 | 97 | # load the model 98 | if voicecraft_name == "330M": 99 | voicecraft_name = "giga330M" 100 | elif voicecraft_name == "830M": 101 | voicecraft_name = "giga830M" 102 | elif voicecraft_name == "330M_TTSEnhanced": 103 | voicecraft_name = "330M_TTSEnhanced" 104 | elif voicecraft_name == "830M_TTSEnhanced": 105 | voicecraft_name = "830M_TTSEnhanced" 106 | model = voicecraft.VoiceCraft.from_pretrained( 107 | f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}") 108 | phn2num = model.args.phn2num 109 | config = vars(model.args) 110 | model.to(device) 111 | 112 | encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th" 113 | if not os.path.exists(encodec_fn): 114 | os.system( 115 | f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th -O ./pretrained_models/encodec_4cb2048_giga.th") 116 | # will also put the neural codec model on gpu 117 | audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=device) 118 | 119 | text_tokenizer = TextTokenizer(backend="espeak") 120 | 121 | # Prepare your audio 122 | # point to the original audio whose speech you want to clone 123 | # write down the transcript for the file, or run whisper to get the transcript (and you can modify it if it's not accurate), save it as a .txt file 124 | orig_audio = args.original_audio 125 | orig_transcript = args.original_transcript 126 | 127 | # move the audio and transcript to temp folder 128 | temp_folder = "./demo/temp" 129 | os.makedirs(temp_folder, exist_ok=True) 130 | os.system(f"cp {orig_audio} {temp_folder}") 131 | filename = os.path.splitext(orig_audio.split("/")[-1])[0] 132 | with open(f"{temp_folder}/{filename}.txt", "w") as f: 133 | f.write(orig_transcript) 134 | # run MFA to get the alignment 135 | align_temp = f"{temp_folder}/mfa_alignments" 136 | beam_size = args.beam_size 137 | retry_beam_size = args.retry_beam_size 138 | alignments = f"{temp_folder}/mfa_alignments/{filename}.csv" 139 | if not os.path.isfile(alignments): 140 | os.system(f"mfa align -v --clean -j 1 --output_format csv {temp_folder} \ 141 | english_us_arpa english_us_arpa {align_temp} --beam {beam_size} --retry_beam {retry_beam_size}") 142 | # if the above fails, it could be because the audio is too hard for the alignment model, 143 | # increasing the beam_size and retry_beam_size usually solves the issue 144 | 145 | def find_closest_word_boundary(alignments, cut_off_sec, margin, cutoff_tolerance = 1): 146 | with open(alignments, 'r') as file: 147 | # skip header 148 | next(file) 149 | cutoff_time = None 150 | cutoff_index = None 151 | cutoff_time_best = None 152 | cutoff_index_best = None 153 | lines = [l for l in file.readlines()] 154 | for i, line in enumerate(lines): 155 | end = float(line.strip().split(',')[1]) 156 | if end >= cut_off_sec and cutoff_time == None: 157 | cutoff_time = end 158 | cutoff_index = i 159 | if end >= cut_off_sec and end < cut_off_sec + cutoff_tolerance and float(lines[i+1].strip().split(',')[0]) - end >= margin: 160 | cutoff_time_best = end + margin * 2 / 3 161 | cutoff_index_best = i 162 | break 163 | if cutoff_time_best != None: 164 | cutoff_time = cutoff_time_best 165 | cutoff_index = cutoff_index_best 166 | return cutoff_time, cutoff_index 167 | 168 | # take a look at demo/temp/mfa_alignment, decide which part of the audio to use as prompt 169 | # NOTE: according to forced-alignment file demo/temp/mfa_alignments/5895_34622_000026_000002.wav, the 
word "strength" stop as 3.561 sec, so we use first 3.6 sec as the prompt. this should be different for different audio 170 | cut_off_sec = args.cut_off_sec 171 | margin = args.margin 172 | audio_fn = f"{temp_folder}/{filename}.wav" 173 | 174 | cut_off_sec, cut_off_word_idx = find_closest_word_boundary(alignments, cut_off_sec, margin, args.cutoff_tolerance) 175 | target_transcript = " ".join(orig_transcript.split(" ")[:cut_off_word_idx+1]) + " " + args.target_transcript 176 | # NOTE: 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec. 177 | info = torchaudio.info(audio_fn) 178 | audio_dur = info.num_frames / info.sample_rate 179 | 180 | assert cut_off_sec < audio_dur, f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}" 181 | prompt_end_frame = int(cut_off_sec * info.sample_rate) 182 | 183 | 184 | def seed_everything(seed): 185 | os.environ['PYTHONHASHSEED'] = str(seed) 186 | random.seed(seed) 187 | np.random.seed(seed) 188 | torch.manual_seed(seed) 189 | torch.cuda.manual_seed(seed) 190 | torch.backends.cudnn.benchmark = False 191 | torch.backends.cudnn.deterministic = True 192 | 193 | 194 | seed_everything(seed) 195 | 196 | # inference 197 | decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, 198 | "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr, "silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size} 199 | concated_audio, gen_audio = inference_one_sample(model, argparse.Namespace( 200 | **config), phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, device, decode_config, prompt_end_frame) 201 | 202 | # save segments for comparison 203 | concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu() 204 | # logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}") 205 | 206 | # save the audio 207 | # output_dir 208 | output_dir = args.output_dir 209 | os.makedirs(output_dir, exist_ok=True) 210 | seg_save_fn_gen = f"{output_dir}/{os.path.basename(audio_fn)[:-4]}_gen_seed{seed}.wav" 211 | seg_save_fn_concat = f"{output_dir}/{os.path.basename(audio_fn)[:-4]}_concat_seed{seed}.wav" 212 | 213 | torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr) 214 | torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr) 215 | 216 | # you might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored 217 | -------------------------------------------------------------------------------- /voicecraft-gradio-colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "id": "Y87ixxsUVIhM" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "!git clone https://github.com/jasonppy/VoiceCraft" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": { 28 | "id": "-w3USR91XdxY" 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "!pip install tensorboard\n", 33 | "!pip install phonemizer\n", 34 | "!pip install datasets\n", 35 | "!pip install torchmetrics\n", 36 | "\n", 37 | "!apt-get install -y espeak espeak-data libespeak1 libespeak-dev\n", 38 | "!apt-get install -y 
festival*\n", 39 | "!apt-get install -y build-essential\n", 40 | "!apt-get install -y flac libasound2-dev libsndfile1-dev vorbis-tools\n", 41 | "!apt-get install -y libxml2-dev libxslt-dev zlib1g-dev\n", 42 | "\n", 43 | "!pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft\n", 44 | "\n", 45 | "!pip install -r \"/content/VoiceCraft/gradio_requirements.txt\"\n", 46 | "!pip install typer==0.7.0" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": { 52 | "id": "jNuzjrtmv2n1" 53 | }, 54 | "source": [ 55 | "# Let it restarted, it won't let your entire installation be aborted." 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": { 61 | "id": "AnqGEwZ4NxtJ" 62 | }, 63 | "source": [ 64 | "# Note before launching the `gradio_app.py`\n", 65 | "\n", 66 | "***You will get JSON warning if you move anything beside `sample_batch_size`, `stop_repetition` and `seed`.*** Which for most advanced setting, `kvache` and `temperature` unable to set in different value.\n", 67 | "\n", 68 | "You will download a .file File when you download the output audio for some reason. You will need to **convert the file from .snd to .wav/.mp3 manually**. Or if you enable showing file type in the name in Windows or wherever you are, change the file name to \"xxx.wav\" or \"xxx.mp3\". (know the solution? pull request my repository)\n", 69 | "\n", 70 | "Frequency of VRAM spikes no longer exist as well in April 5 Update.\n", 71 | "* Nevermind, I have observed some weird usage on Colab's GPU Memory Monitor. It can spike up to 13.5GB VRAM even in WhisperX mode. (April 11)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": { 77 | "id": "dE0W76cMN3Si" 78 | }, 79 | "source": [ 80 | "Don't make your `prompt end time` too long, 6-9s is fine. Or else it will **either raise up JSON issue or cut off your generated audio**. This one is due to how VoiceCraft worked (so probably unfixable). It will add those text you want to get audio from at the end of the input audio transcript. It was way too much word for application or code to handle as it added up with original transcript. So please keep it short.\n", 81 | "\n", 82 | "Your total audio length (`prompt end time` + add-up audio) must not exceed 16 or 17s." 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": { 88 | "id": "nnu2cY4t8P6X" 89 | }, 90 | "source": [ 91 | "For voice cloning, I suggest you to probably have a monotone input to feed the voice cloning. Of course you can always try input that have tons of tone variety, but I find that as per April 11 Update, it's much more easy to replicate in monotone rather than audio that have laugh, scream, crying inside.\n", 92 | "\n", 93 | "The inference speed is much stable. With sample text, T4 (Free Tier Colab GPU) can do 6-15s on 6s-8s of `prompt end time`." 
94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": { 100 | "id": "NDt4r4DiXAwG" 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "!python /content/VoiceCraft/gradio_app.py --demo-path=/content/VoiceCraft/demo --tmp-path=/content/VoiceCraft/demo/temp --models-path=/content/VoiceCraft/pretrained_models --share" 105 | ] 106 | } 107 | ], 108 | "metadata": { 109 | "accelerator": "GPU", 110 | "colab": { 111 | "authorship_tag": "ABX9TyPsqFhtOeQ18CXHnRkWAQSk", 112 | "gpuType": "T4", 113 | "include_colab_link": true, 114 | "provenance": [] 115 | }, 116 | "kernelspec": { 117 | "display_name": "Python 3", 118 | "name": "python3" 119 | }, 120 | "language_info": { 121 | "name": "python" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 0 126 | } 127 | -------------------------------------------------------------------------------- /z_scripts/e830M.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate voicecraft 4 | export CUDA_VISIBLE_DEVICES=0,1,2,3 5 | export WORLD_SIZE=4 6 | 7 | dataset=gigaspeech 8 | mkdir -p ./logs/${dataset} 9 | 10 | exp_root="path/to/store/exp_results" 11 | exp_name=e830M 12 | dataset_dir="path/to/stored_extracted_codes_and_phonemes/xl" # xs if you only extracted xs in previous step 13 | encodec_codes_folder_name="encodec_16khz_4codebooks" 14 | 15 | # export CUDA_LAUNCH_BLOCKING=1 # for debugging 16 | 17 | torchrun --nnodes=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:41977 --nproc_per_node=${WORLD_SIZE} \ 18 | ../main.py \ 19 | --reduced_eog 1 \ 20 | --drop_long 1 \ 21 | --eos 2051 \ 22 | --n_special 4 \ 23 | --pad_x 0 \ 24 | --codebook_weight "[5,1,0.5,0.1]" \ 25 | --encodec_sr 50 \ 26 | --num_steps 50000 \ 27 | --lr 0.05 \ 28 | --warmup_fraction 0.01 \ 29 | --optimizer_name "ScaledAdam" \ 30 | --pseudo_epoch_size 3000 \ 31 | --reduce_lr_start_step 3000 \ 32 | --reduce_lr_start_epoch 4 \ 33 | --clipping_update_period 1000 \ 34 | --d_model 2048 \ 35 | --audio_embedding_dim 2048 \ 36 | --nhead 16 \ 37 | --num_decoder_layers 16 \ 38 | --max_num_tokens 100000 \ 39 | --gradient_accumulation_steps 26 \ 40 | --val_max_num_tokens 6000 \ 41 | --num_buckets 6 \ 42 | --audio_max_length 20 \ 43 | --audio_min_length 2 \ 44 | --text_max_length 400 \ 45 | --text_min_length 10 \ 46 | --mask_len_min 1 \ 47 | --mask_len_max 600 \ 48 | --tb_write_every_n_steps 10 \ 49 | --print_every_n_steps 400 \ 50 | --val_every_n_steps 1600 \ 51 | --text_vocab_size 100 \ 52 | --text_pad_token 100 \ 53 | --phn_folder_name "phonemes" \ 54 | --manifest_name "manifest" \ 55 | --encodec_folder_name ${encodec_codes_folder_name} \ 56 | --audio_vocab_size 2048 \ 57 | --empty_token 2048 \ 58 | --eog 2049 \ 59 | --audio_pad_token 2050 \ 60 | --n_codebooks 4 \ 61 | --max_n_spans 3 \ 62 | --shuffle_mask_embedding 0 \ 63 | --mask_sample_dist poisson1 \ 64 | --max_mask_portion 0.9 \ 65 | --min_gap 5 \ 66 | --num_workers 8 \ 67 | --dynamic_batching 1 \ 68 | --dataset $dataset \ 69 | --exp_dir "${exp_root}/${dataset}/${exp_name}" \ 70 | --dataset_dir ${dataset_dir} 71 | # >> ./logs/${dataset}/${exp_name}.log 2>&1 -------------------------------------------------------------------------------- /z_scripts/e830M_ft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate voicecraft 4 | export CUDA_VISIBLE_DEVICES=0,1,2,3 5 | export WORLD_SIZE=4 6 
| 7 | dataset=gigaspeech 8 | mkdir -p ./logs/${dataset} 9 | 10 | exp_root="path/to/store/exp_results" 11 | exp_name=e830M_ft 12 | dataset_dir="path/to/stored_extracted_codes_and_phonemes/xl" # xs if you only extracted xs in previous step 13 | encodec_codes_folder_name="encodec_16khz_4codebooks" 14 | load_model_from="./pretrained_models/giga830M.pth" 15 | 16 | # export CUDA_LAUNCH_BLOCKING=1 # for debugging 17 | 18 | torchrun --nnodes=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:41977 --nproc_per_node=${WORLD_SIZE} \ 19 | ../main.py \ 20 | --load_model_from ${load_model_from} \ 21 | --reduced_eog 1 \ 22 | --drop_long 1 \ 23 | --eos 2051 \ 24 | --n_special 4 \ 25 | --pad_x 0 \ 26 | --codebook_weight "[3,1,1,1]" \ 27 | --encodec_sr 50 \ 28 | --num_steps 500000 \ 29 | --lr 0.00001 \ 30 | --warmup_fraction 0.1 \ 31 | --optimizer_name "AdamW" \ 32 | --d_model 2048 \ 33 | --audio_embedding_dim 2048 \ 34 | --nhead 16 \ 35 | --num_decoder_layers 16 \ 36 | --max_num_tokens 20000 \ 37 | --gradient_accumulation_steps 12 \ 38 | --val_max_num_tokens 6000 \ 39 | --num_buckets 6 \ 40 | --audio_max_length 20 \ 41 | --audio_min_length 2 \ 42 | --text_max_length 400 \ 43 | --text_min_length 10 \ 44 | --mask_len_min 1 \ 45 | --mask_len_max 600 \ 46 | --tb_write_every_n_steps 10 \ 47 | --print_every_n_steps 400 \ 48 | --val_every_n_steps 1600 \ 49 | --text_vocab_size 100 \ 50 | --text_pad_token 100 \ 51 | --phn_folder_name "phonemes" \ 52 | --manifest_name "manifest" \ 53 | --encodec_folder_name ${encodec_codes_folder_name} \ 54 | --audio_vocab_size 2048 \ 55 | --empty_token 2048 \ 56 | --eog 2049 \ 57 | --audio_pad_token 2050 \ 58 | --n_codebooks 4 \ 59 | --max_n_spans 3 \ 60 | --shuffle_mask_embedding 0 \ 61 | --mask_sample_dist poisson1 \ 62 | --max_mask_portion 0.9 \ 63 | --min_gap 5 \ 64 | --num_workers 8 \ 65 | --dynamic_batching 1 \ 66 | --dataset $dataset \ 67 | --exp_dir "${exp_root}/${dataset}/${exp_name}" \ 68 | --dataset_dir ${dataset_dir} 69 | # >> ./logs/${dataset}/${exp_name}.log 2>&1 --------------------------------------------------------------------------------
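As a quick cross-check of the fine-tuning configuration above, here is a minimal, standalone sketch (not part of the repository) of how the AdamW learning rate evolves under the warmup-then-linear-decay rule in steps/trainer.py, using the e830M_ft.sh settings (--lr 0.00001, --warmup_fraction 0.1, --num_steps 500000). It assumes the trainer's total_step equals --num_steps; the milestone steps printed at the end are illustrative values chosen here, not values taken from the scripts.

base_lr = 1e-5                   # --lr in e830M_ft.sh
total_step = 500_000             # --num_steps in e830M_ft.sh
warmup_steps = total_step * 0.1  # --warmup_fraction 0.1

def lr_lambda(current_step: int) -> float:
    # Same rule as the AdamW/LambdaLR branch of the trainer:
    # linear warmup to base_lr, then linear decay to zero.
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return max(0.0, float(total_step - current_step) / float(max(1, total_step - warmup_steps)))

for step in (0, 25_000, 50_000, 275_000, 500_000):
    print(f"step {step:>7}: lr = {base_lr * lr_lambda(step):.2e}")
# prints 0.0 at step 0, the peak 1e-05 at step 50000 (end of warmup),
# then decays back to 0.0 by step 500000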