├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── demo-preview.png └── github-banner.png ├── cog ├── cog.yaml └── predict.py ├── docker └── Dockerfile ├── docs ├── architecture.md └── assets │ ├── MARS5_Architecture.png │ ├── NAR_inpainting_diagram.png │ ├── example_ref.wav │ ├── intro_vid.mp4 │ ├── mars5_AR_arch.png │ ├── mars5_NAR_arch.png │ └── simplified_diagram.png ├── hubconf.py ├── inference.py ├── mars5 ├── ar_generate.py ├── diffuser.py ├── minbpe │ ├── base.py │ ├── codebook.py │ └── regex.py ├── model.py ├── nn_future.py ├── samplers.py ├── trim.py └── utils.py ├── mars5_demo.ipynb ├── mars5_explainable_AI.ipynb └── requirements.txt /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | ko_fi: cambai 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | example.flac -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU AFFERO GENERAL PUBLIC LICENSE 2 | Version 3, 19 November 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU Affero General Public License is a free, copyleft license for 11 | software and other kinds of works, specifically designed to ensure 12 | cooperation with the community in the case of network server software. 13 | 14 | The licenses for most software and other practical works are designed 15 | to take away your freedom to share and change the works. By contrast, 16 | our General Public Licenses are intended to guarantee your freedom to 17 | share and change all versions of a program--to make sure it remains free 18 | software for all its users. 19 | 20 | When we speak of free software, we are referring to freedom, not 21 | price. Our General Public Licenses are designed to make sure that you 22 | have the freedom to distribute copies of free software (and charge for 23 | them if you wish), that you receive source code or can get it if you 24 | want it, that you can change the software or use pieces of it in new 25 | free programs, and that you know you can do these things. 26 | 27 | Developers that use our General Public Licenses protect your rights 28 | with two steps: (1) assert copyright on the software, and (2) offer 29 | you this License which gives you legal permission to copy, distribute 30 | and/or modify the software. 31 | 32 | A secondary benefit of defending all users' freedom is that 33 | improvements made in alternate versions of the program, if they 34 | receive widespread use, become available for other developers to 35 | incorporate. Many developers of free software are heartened and 36 | encouraged by the resulting cooperation. However, in the case of 37 | software used on network servers, this result may fail to come about. 38 | The GNU General Public License permits making a modified version and 39 | letting the public access it on a server without ever releasing its 40 | source code to the public. 
41 | 42 | The GNU Affero General Public License is designed specifically to 43 | ensure that, in such cases, the modified source code becomes available 44 | to the community. It requires the operator of a network server to 45 | provide the source code of the modified version running there to the 46 | users of that server. Therefore, public use of a modified version, on 47 | a publicly accessible server, gives the public access to the source 48 | code of the modified version. 49 | 50 | An older license, called the Affero General Public License and 51 | published by Affero, was designed to accomplish similar goals. This is 52 | a different license, not a version of the Affero GPL, but Affero has 53 | released a new version of the Affero GPL which permits relicensing under 54 | this license. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | TERMS AND CONDITIONS 60 | 61 | 0. Definitions. 62 | 63 | "This License" refers to version 3 of the GNU Affero General Public License. 64 | 65 | "Copyright" also means copyright-like laws that apply to other kinds of 66 | works, such as semiconductor masks. 67 | 68 | "The Program" refers to any copyrightable work licensed under this 69 | License. Each licensee is addressed as "you". "Licensees" and 70 | "recipients" may be individuals or organizations. 71 | 72 | To "modify" a work means to copy from or adapt all or part of the work 73 | in a fashion requiring copyright permission, other than the making of an 74 | exact copy. The resulting work is called a "modified version" of the 75 | earlier work or a work "based on" the earlier work. 76 | 77 | A "covered work" means either the unmodified Program or a work based 78 | on the Program. 79 | 80 | To "propagate" a work means to do anything with it that, without 81 | permission, would make you directly or secondarily liable for 82 | infringement under applicable copyright law, except executing it on a 83 | computer or modifying a private copy. Propagation includes copying, 84 | distribution (with or without modification), making available to the 85 | public, and in some countries other activities as well. 86 | 87 | To "convey" a work means any kind of propagation that enables other 88 | parties to make or receive copies. Mere interaction with a user through 89 | a computer network, with no transfer of a copy, is not conveying. 90 | 91 | An interactive user interface displays "Appropriate Legal Notices" 92 | to the extent that it includes a convenient and prominently visible 93 | feature that (1) displays an appropriate copyright notice, and (2) 94 | tells the user that there is no warranty for the work (except to the 95 | extent that warranties are provided), that licensees may convey the 96 | work under this License, and how to view a copy of this License. If 97 | the interface presents a list of user commands or options, such as a 98 | menu, a prominent item in the list meets this criterion. 99 | 100 | 1. Source Code. 101 | 102 | The "source code" for a work means the preferred form of the work 103 | for making modifications to it. "Object code" means any non-source 104 | form of a work. 105 | 106 | A "Standard Interface" means an interface that either is an official 107 | standard defined by a recognized standards body, or, in the case of 108 | interfaces specified for a particular programming language, one that 109 | is widely used among developers working in that language. 
110 | 111 | The "System Libraries" of an executable work include anything, other 112 | than the work as a whole, that (a) is included in the normal form of 113 | packaging a Major Component, but which is not part of that Major 114 | Component, and (b) serves only to enable use of the work with that 115 | Major Component, or to implement a Standard Interface for which an 116 | implementation is available to the public in source code form. A 117 | "Major Component", in this context, means a major essential component 118 | (kernel, window system, and so on) of the specific operating system 119 | (if any) on which the executable work runs, or a compiler used to 120 | produce the work, or an object code interpreter used to run it. 121 | 122 | The "Corresponding Source" for a work in object code form means all 123 | the source code needed to generate, install, and (for an executable 124 | work) run the object code and to modify the work, including scripts to 125 | control those activities. However, it does not include the work's 126 | System Libraries, or general-purpose tools or generally available free 127 | programs which are used unmodified in performing those activities but 128 | which are not part of the work. For example, Corresponding Source 129 | includes interface definition files associated with source files for 130 | the work, and the source code for shared libraries and dynamically 131 | linked subprograms that the work is specifically designed to require, 132 | such as by intimate data communication or control flow between those 133 | subprograms and other parts of the work. 134 | 135 | The Corresponding Source need not include anything that users 136 | can regenerate automatically from other parts of the Corresponding 137 | Source. 138 | 139 | The Corresponding Source for a work in source code form is that 140 | same work. 141 | 142 | 2. Basic Permissions. 143 | 144 | All rights granted under this License are granted for the term of 145 | copyright on the Program, and are irrevocable provided the stated 146 | conditions are met. This License explicitly affirms your unlimited 147 | permission to run the unmodified Program. The output from running a 148 | covered work is covered by this License only if the output, given its 149 | content, constitutes a covered work. This License acknowledges your 150 | rights of fair use or other equivalent, as provided by copyright law. 151 | 152 | You may make, run and propagate covered works that you do not 153 | convey, without conditions so long as your license otherwise remains 154 | in force. You may convey covered works to others for the sole purpose 155 | of having them make modifications exclusively for you, or provide you 156 | with facilities for running those works, provided that you comply with 157 | the terms of this License in conveying all material for which you do 158 | not control copyright. Those thus making or running the covered works 159 | for you must do so exclusively on your behalf, under your direction 160 | and control, on terms that prohibit them from making any copies of 161 | your copyrighted material outside their relationship with you. 162 | 163 | Conveying under any other circumstances is permitted solely under 164 | the conditions stated below. Sublicensing is not allowed; section 10 165 | makes it unnecessary. 166 | 167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 
168 | 169 | No covered work shall be deemed part of an effective technological 170 | measure under any applicable law fulfilling obligations under article 171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 172 | similar laws prohibiting or restricting circumvention of such 173 | measures. 174 | 175 | When you convey a covered work, you waive any legal power to forbid 176 | circumvention of technological measures to the extent such circumvention 177 | is effected by exercising rights under this License with respect to 178 | the covered work, and you disclaim any intention to limit operation or 179 | modification of the work as a means of enforcing, against the work's 180 | users, your or third parties' legal rights to forbid circumvention of 181 | technological measures. 182 | 183 | 4. Conveying Verbatim Copies. 184 | 185 | You may convey verbatim copies of the Program's source code as you 186 | receive it, in any medium, provided that you conspicuously and 187 | appropriately publish on each copy an appropriate copyright notice; 188 | keep intact all notices stating that this License and any 189 | non-permissive terms added in accord with section 7 apply to the code; 190 | keep intact all notices of the absence of any warranty; and give all 191 | recipients a copy of this License along with the Program. 192 | 193 | You may charge any price or no price for each copy that you convey, 194 | and you may offer support or warranty protection for a fee. 195 | 196 | 5. Conveying Modified Source Versions. 197 | 198 | You may convey a work based on the Program, or the modifications to 199 | produce it from the Program, in the form of source code under the 200 | terms of section 4, provided that you also meet all of these conditions: 201 | 202 | a) The work must carry prominent notices stating that you modified 203 | it, and giving a relevant date. 204 | 205 | b) The work must carry prominent notices stating that it is 206 | released under this License and any conditions added under section 207 | 7. This requirement modifies the requirement in section 4 to 208 | "keep intact all notices". 209 | 210 | c) You must license the entire work, as a whole, under this 211 | License to anyone who comes into possession of a copy. This 212 | License will therefore apply, along with any applicable section 7 213 | additional terms, to the whole of the work, and all its parts, 214 | regardless of how they are packaged. This License gives no 215 | permission to license the work in any other way, but it does not 216 | invalidate such permission if you have separately received it. 217 | 218 | d) If the work has interactive user interfaces, each must display 219 | Appropriate Legal Notices; however, if the Program has interactive 220 | interfaces that do not display Appropriate Legal Notices, your 221 | work need not make them do so. 222 | 223 | A compilation of a covered work with other separate and independent 224 | works, which are not by their nature extensions of the covered work, 225 | and which are not combined with it such as to form a larger program, 226 | in or on a volume of a storage or distribution medium, is called an 227 | "aggregate" if the compilation and its resulting copyright are not 228 | used to limit the access or legal rights of the compilation's users 229 | beyond what the individual works permit. Inclusion of a covered work 230 | in an aggregate does not cause this License to apply to the other 231 | parts of the aggregate. 232 | 233 | 6. Conveying Non-Source Forms. 
234 | 235 | You may convey a covered work in object code form under the terms 236 | of sections 4 and 5, provided that you also convey the 237 | machine-readable Corresponding Source under the terms of this License, 238 | in one of these ways: 239 | 240 | a) Convey the object code in, or embodied in, a physical product 241 | (including a physical distribution medium), accompanied by the 242 | Corresponding Source fixed on a durable physical medium 243 | customarily used for software interchange. 244 | 245 | b) Convey the object code in, or embodied in, a physical product 246 | (including a physical distribution medium), accompanied by a 247 | written offer, valid for at least three years and valid for as 248 | long as you offer spare parts or customer support for that product 249 | model, to give anyone who possesses the object code either (1) a 250 | copy of the Corresponding Source for all the software in the 251 | product that is covered by this License, on a durable physical 252 | medium customarily used for software interchange, for a price no 253 | more than your reasonable cost of physically performing this 254 | conveying of source, or (2) access to copy the 255 | Corresponding Source from a network server at no charge. 256 | 257 | c) Convey individual copies of the object code with a copy of the 258 | written offer to provide the Corresponding Source. This 259 | alternative is allowed only occasionally and noncommercially, and 260 | only if you received the object code with such an offer, in accord 261 | with subsection 6b. 262 | 263 | d) Convey the object code by offering access from a designated 264 | place (gratis or for a charge), and offer equivalent access to the 265 | Corresponding Source in the same way through the same place at no 266 | further charge. You need not require recipients to copy the 267 | Corresponding Source along with the object code. If the place to 268 | copy the object code is a network server, the Corresponding Source 269 | may be on a different server (operated by you or a third party) 270 | that supports equivalent copying facilities, provided you maintain 271 | clear directions next to the object code saying where to find the 272 | Corresponding Source. Regardless of what server hosts the 273 | Corresponding Source, you remain obligated to ensure that it is 274 | available for as long as needed to satisfy these requirements. 275 | 276 | e) Convey the object code using peer-to-peer transmission, provided 277 | you inform other peers where the object code and Corresponding 278 | Source of the work are being offered to the general public at no 279 | charge under subsection 6d. 280 | 281 | A separable portion of the object code, whose source code is excluded 282 | from the Corresponding Source as a System Library, need not be 283 | included in conveying the object code work. 284 | 285 | A "User Product" is either (1) a "consumer product", which means any 286 | tangible personal property which is normally used for personal, family, 287 | or household purposes, or (2) anything designed or sold for incorporation 288 | into a dwelling. In determining whether a product is a consumer product, 289 | doubtful cases shall be resolved in favor of coverage. For a particular 290 | product received by a particular user, "normally used" refers to a 291 | typical or common use of that class of product, regardless of the status 292 | of the particular user or of the way in which the particular user 293 | actually uses, or expects or is expected to use, the product. 
A product 294 | is a consumer product regardless of whether the product has substantial 295 | commercial, industrial or non-consumer uses, unless such uses represent 296 | the only significant mode of use of the product. 297 | 298 | "Installation Information" for a User Product means any methods, 299 | procedures, authorization keys, or other information required to install 300 | and execute modified versions of a covered work in that User Product from 301 | a modified version of its Corresponding Source. The information must 302 | suffice to ensure that the continued functioning of the modified object 303 | code is in no case prevented or interfered with solely because 304 | modification has been made. 305 | 306 | If you convey an object code work under this section in, or with, or 307 | specifically for use in, a User Product, and the conveying occurs as 308 | part of a transaction in which the right of possession and use of the 309 | User Product is transferred to the recipient in perpetuity or for a 310 | fixed term (regardless of how the transaction is characterized), the 311 | Corresponding Source conveyed under this section must be accompanied 312 | by the Installation Information. But this requirement does not apply 313 | if neither you nor any third party retains the ability to install 314 | modified object code on the User Product (for example, the work has 315 | been installed in ROM). 316 | 317 | The requirement to provide Installation Information does not include a 318 | requirement to continue to provide support service, warranty, or updates 319 | for a work that has been modified or installed by the recipient, or for 320 | the User Product in which it has been modified or installed. Access to a 321 | network may be denied when the modification itself materially and 322 | adversely affects the operation of the network or violates the rules and 323 | protocols for communication across the network. 324 | 325 | Corresponding Source conveyed, and Installation Information provided, 326 | in accord with this section must be in a format that is publicly 327 | documented (and with an implementation available to the public in 328 | source code form), and must require no special password or key for 329 | unpacking, reading or copying. 330 | 331 | 7. Additional Terms. 332 | 333 | "Additional permissions" are terms that supplement the terms of this 334 | License by making exceptions from one or more of its conditions. 335 | Additional permissions that are applicable to the entire Program shall 336 | be treated as though they were included in this License, to the extent 337 | that they are valid under applicable law. If additional permissions 338 | apply only to part of the Program, that part may be used separately 339 | under those permissions, but the entire Program remains governed by 340 | this License without regard to the additional permissions. 341 | 342 | When you convey a copy of a covered work, you may at your option 343 | remove any additional permissions from that copy, or from any part of 344 | it. (Additional permissions may be written to require their own 345 | removal in certain cases when you modify the work.) You may place 346 | additional permissions on material, added by you to a covered work, 347 | for which you have or can give appropriate copyright permission. 
348 | 349 | Notwithstanding any other provision of this License, for material you 350 | add to a covered work, you may (if authorized by the copyright holders of 351 | that material) supplement the terms of this License with terms: 352 | 353 | a) Disclaiming warranty or limiting liability differently from the 354 | terms of sections 15 and 16 of this License; or 355 | 356 | b) Requiring preservation of specified reasonable legal notices or 357 | author attributions in that material or in the Appropriate Legal 358 | Notices displayed by works containing it; or 359 | 360 | c) Prohibiting misrepresentation of the origin of that material, or 361 | requiring that modified versions of such material be marked in 362 | reasonable ways as different from the original version; or 363 | 364 | d) Limiting the use for publicity purposes of names of licensors or 365 | authors of the material; or 366 | 367 | e) Declining to grant rights under trademark law for use of some 368 | trade names, trademarks, or service marks; or 369 | 370 | f) Requiring indemnification of licensors and authors of that 371 | material by anyone who conveys the material (or modified versions of 372 | it) with contractual assumptions of liability to the recipient, for 373 | any liability that these contractual assumptions directly impose on 374 | those licensors and authors. 375 | 376 | All other non-permissive additional terms are considered "further 377 | restrictions" within the meaning of section 10. If the Program as you 378 | received it, or any part of it, contains a notice stating that it is 379 | governed by this License along with a term that is a further 380 | restriction, you may remove that term. If a license document contains 381 | a further restriction but permits relicensing or conveying under this 382 | License, you may add to a covered work material governed by the terms 383 | of that license document, provided that the further restriction does 384 | not survive such relicensing or conveying. 385 | 386 | If you add terms to a covered work in accord with this section, you 387 | must place, in the relevant source files, a statement of the 388 | additional terms that apply to those files, or a notice indicating 389 | where to find the applicable terms. 390 | 391 | Additional terms, permissive or non-permissive, may be stated in the 392 | form of a separately written license, or stated as exceptions; 393 | the above requirements apply either way. 394 | 395 | 8. Termination. 396 | 397 | You may not propagate or modify a covered work except as expressly 398 | provided under this License. Any attempt otherwise to propagate or 399 | modify it is void, and will automatically terminate your rights under 400 | this License (including any patent licenses granted under the third 401 | paragraph of section 11). 402 | 403 | However, if you cease all violation of this License, then your 404 | license from a particular copyright holder is reinstated (a) 405 | provisionally, unless and until the copyright holder explicitly and 406 | finally terminates your license, and (b) permanently, if the copyright 407 | holder fails to notify you of the violation by some reasonable means 408 | prior to 60 days after the cessation. 
409 | 410 | Moreover, your license from a particular copyright holder is 411 | reinstated permanently if the copyright holder notifies you of the 412 | violation by some reasonable means, this is the first time you have 413 | received notice of violation of this License (for any work) from that 414 | copyright holder, and you cure the violation prior to 30 days after 415 | your receipt of the notice. 416 | 417 | Termination of your rights under this section does not terminate the 418 | licenses of parties who have received copies or rights from you under 419 | this License. If your rights have been terminated and not permanently 420 | reinstated, you do not qualify to receive new licenses for the same 421 | material under section 10. 422 | 423 | 9. Acceptance Not Required for Having Copies. 424 | 425 | You are not required to accept this License in order to receive or 426 | run a copy of the Program. Ancillary propagation of a covered work 427 | occurring solely as a consequence of using peer-to-peer transmission 428 | to receive a copy likewise does not require acceptance. However, 429 | nothing other than this License grants you permission to propagate or 430 | modify any covered work. These actions infringe copyright if you do 431 | not accept this License. Therefore, by modifying or propagating a 432 | covered work, you indicate your acceptance of this License to do so. 433 | 434 | 10. Automatic Licensing of Downstream Recipients. 435 | 436 | Each time you convey a covered work, the recipient automatically 437 | receives a license from the original licensors, to run, modify and 438 | propagate that work, subject to this License. You are not responsible 439 | for enforcing compliance by third parties with this License. 440 | 441 | An "entity transaction" is a transaction transferring control of an 442 | organization, or substantially all assets of one, or subdividing an 443 | organization, or merging organizations. If propagation of a covered 444 | work results from an entity transaction, each party to that 445 | transaction who receives a copy of the work also receives whatever 446 | licenses to the work the party's predecessor in interest had or could 447 | give under the previous paragraph, plus a right to possession of the 448 | Corresponding Source of the work from the predecessor in interest, if 449 | the predecessor has it or can get it with reasonable efforts. 450 | 451 | You may not impose any further restrictions on the exercise of the 452 | rights granted or affirmed under this License. For example, you may 453 | not impose a license fee, royalty, or other charge for exercise of 454 | rights granted under this License, and you may not initiate litigation 455 | (including a cross-claim or counterclaim in a lawsuit) alleging that 456 | any patent claim is infringed by making, using, selling, offering for 457 | sale, or importing the Program or any portion of it. 458 | 459 | 11. Patents. 460 | 461 | A "contributor" is a copyright holder who authorizes use under this 462 | License of the Program or a work on which the Program is based. The 463 | work thus licensed is called the contributor's "contributor version". 
464 | 465 | A contributor's "essential patent claims" are all patent claims 466 | owned or controlled by the contributor, whether already acquired or 467 | hereafter acquired, that would be infringed by some manner, permitted 468 | by this License, of making, using, or selling its contributor version, 469 | but do not include claims that would be infringed only as a 470 | consequence of further modification of the contributor version. For 471 | purposes of this definition, "control" includes the right to grant 472 | patent sublicenses in a manner consistent with the requirements of 473 | this License. 474 | 475 | Each contributor grants you a non-exclusive, worldwide, royalty-free 476 | patent license under the contributor's essential patent claims, to 477 | make, use, sell, offer for sale, import and otherwise run, modify and 478 | propagate the contents of its contributor version. 479 | 480 | In the following three paragraphs, a "patent license" is any express 481 | agreement or commitment, however denominated, not to enforce a patent 482 | (such as an express permission to practice a patent or covenant not to 483 | sue for patent infringement). To "grant" such a patent license to a 484 | party means to make such an agreement or commitment not to enforce a 485 | patent against the party. 486 | 487 | If you convey a covered work, knowingly relying on a patent license, 488 | and the Corresponding Source of the work is not available for anyone 489 | to copy, free of charge and under the terms of this License, through a 490 | publicly available network server or other readily accessible means, 491 | then you must either (1) cause the Corresponding Source to be so 492 | available, or (2) arrange to deprive yourself of the benefit of the 493 | patent license for this particular work, or (3) arrange, in a manner 494 | consistent with the requirements of this License, to extend the patent 495 | license to downstream recipients. "Knowingly relying" means you have 496 | actual knowledge that, but for the patent license, your conveying the 497 | covered work in a country, or your recipient's use of the covered work 498 | in a country, would infringe one or more identifiable patents in that 499 | country that you have reason to believe are valid. 500 | 501 | If, pursuant to or in connection with a single transaction or 502 | arrangement, you convey, or propagate by procuring conveyance of, a 503 | covered work, and grant a patent license to some of the parties 504 | receiving the covered work authorizing them to use, propagate, modify 505 | or convey a specific copy of the covered work, then the patent license 506 | you grant is automatically extended to all recipients of the covered 507 | work and works based on it. 508 | 509 | A patent license is "discriminatory" if it does not include within 510 | the scope of its coverage, prohibits the exercise of, or is 511 | conditioned on the non-exercise of one or more of the rights that are 512 | specifically granted under this License. 
You may not convey a covered 513 | work if you are a party to an arrangement with a third party that is 514 | in the business of distributing software, under which you make payment 515 | to the third party based on the extent of your activity of conveying 516 | the work, and under which the third party grants, to any of the 517 | parties who would receive the covered work from you, a discriminatory 518 | patent license (a) in connection with copies of the covered work 519 | conveyed by you (or copies made from those copies), or (b) primarily 520 | for and in connection with specific products or compilations that 521 | contain the covered work, unless you entered into that arrangement, 522 | or that patent license was granted, prior to 28 March 2007. 523 | 524 | Nothing in this License shall be construed as excluding or limiting 525 | any implied license or other defenses to infringement that may 526 | otherwise be available to you under applicable patent law. 527 | 528 | 12. No Surrender of Others' Freedom. 529 | 530 | If conditions are imposed on you (whether by court order, agreement or 531 | otherwise) that contradict the conditions of this License, they do not 532 | excuse you from the conditions of this License. If you cannot convey a 533 | covered work so as to satisfy simultaneously your obligations under this 534 | License and any other pertinent obligations, then as a consequence you may 535 | not convey it at all. For example, if you agree to terms that obligate you 536 | to collect a royalty for further conveying from those to whom you convey 537 | the Program, the only way you could satisfy both those terms and this 538 | License would be to refrain entirely from conveying the Program. 539 | 540 | 13. Remote Network Interaction; Use with the GNU General Public License. 541 | 542 | Notwithstanding any other provision of this License, if you modify the 543 | Program, your modified version must prominently offer all users 544 | interacting with it remotely through a computer network (if your version 545 | supports such interaction) an opportunity to receive the Corresponding 546 | Source of your version by providing access to the Corresponding Source 547 | from a network server at no charge, through some standard or customary 548 | means of facilitating copying of software. This Corresponding Source 549 | shall include the Corresponding Source for any work covered by version 3 550 | of the GNU General Public License that is incorporated pursuant to the 551 | following paragraph. 552 | 553 | Notwithstanding any other provision of this License, you have 554 | permission to link or combine any covered work with a work licensed 555 | under version 3 of the GNU General Public License into a single 556 | combined work, and to convey the resulting work. The terms of this 557 | License will continue to apply to the part which is the covered work, 558 | but the work with which it is combined will remain governed by version 559 | 3 of the GNU General Public License. 560 | 561 | 14. Revised Versions of this License. 562 | 563 | The Free Software Foundation may publish revised and/or new versions of 564 | the GNU Affero General Public License from time to time. Such new versions 565 | will be similar in spirit to the present version, but may differ in detail to 566 | address new problems or concerns. 567 | 568 | Each version is given a distinguishing version number. 
If the 569 | Program specifies that a certain numbered version of the GNU Affero General 570 | Public License "or any later version" applies to it, you have the 571 | option of following the terms and conditions either of that numbered 572 | version or of any later version published by the Free Software 573 | Foundation. If the Program does not specify a version number of the 574 | GNU Affero General Public License, you may choose any version ever published 575 | by the Free Software Foundation. 576 | 577 | If the Program specifies that a proxy can decide which future 578 | versions of the GNU Affero General Public License can be used, that proxy's 579 | public statement of acceptance of a version permanently authorizes you 580 | to choose that version for the Program. 581 | 582 | Later license versions may give you additional or different 583 | permissions. However, no additional obligations are imposed on any 584 | author or copyright holder as a result of your choosing to follow a 585 | later version. 586 | 587 | 15. Disclaimer of Warranty. 588 | 589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 597 | 598 | 16. Limitation of Liability. 599 | 600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 608 | SUCH DAMAGES. 609 | 610 | 17. Interpretation of Sections 15 and 16. 611 | 612 | If the disclaimer of warranty and limitation of liability provided 613 | above cannot be given local legal effect according to their terms, 614 | reviewing courts shall apply local law that most closely approximates 615 | an absolute waiver of all civil liability in connection with the 616 | Program, unless a warranty or assumption of liability accompanies a 617 | copy of the Program in return for a fee. 618 | 619 | END OF TERMS AND CONDITIONS 620 | 621 | How to Apply These Terms to Your New Programs 622 | 623 | If you develop a new program, and you want it to be of the greatest 624 | possible use to the public, the best way to achieve this is to make it 625 | free software which everyone can redistribute and change under these terms. 626 | 627 | To do so, attach the following notices to the program. It is safest 628 | to attach them to the start of each source file to most effectively 629 | state the exclusion of warranty; and each file should have at least 630 | the "copyright" line and a pointer to where the full notice is found. 
631 | 632 | 633 | Copyright (C) 634 | 635 | This program is free software: you can redistribute it and/or modify 636 | it under the terms of the GNU Affero General Public License as published 637 | by the Free Software Foundation, either version 3 of the License, or 638 | (at your option) any later version. 639 | 640 | This program is distributed in the hope that it will be useful, 641 | but WITHOUT ANY WARRANTY; without even the implied warranty of 642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 643 | GNU Affero General Public License for more details. 644 | 645 | You should have received a copy of the GNU Affero General Public License 646 | along with this program. If not, see . 647 | 648 | Also add information on how to contact you by electronic and paper mail. 649 | 650 | If your software can interact with users remotely through a computer 651 | network, you should also make sure that it provides a way for users to 652 | get its source. For example, if your program is a web application, its 653 | interface could display a "Source" link that leads users to an archive 654 | of the code. There are many ways you could offer source, and different 655 | solutions will be better for different programs; see section 13 for the 656 | specific requirements. 657 | 658 | You should also get your employer (if you work as a programmer) or school, 659 | if any, to sign a "copyright disclaimer" for the program, 660 | if necessary. For more information on this, and how to apply and follow the GNU AGPL, see . 661 | 662 | If you would like to obtain a copy of the software under a different license (e.g. Apache), 663 | please send an email to Camb.AI at help@camb.ai indicating that 664 | you would like a copy of the software under a different license. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MARS5: A novel speech model for insane prosody 2 |
3 | 4 | ![MARS5 Banner](assets/github-banner.png) 5 | 6 |

7 | Why MARS5? | 8 | Model Architecture | 9 | Samples | 10 | Camb AI Website

11 | 12 | [![GitHub Repo stars](https://img.shields.io/github/stars/Camb-ai/MARS5-TTS?style=social)](https://github.com/Camb-ai/MARS5-TTS/stargazers) 13 | [![Join our Discord](https://discordapp.com/api/guilds/1107565548864290840/widget.png)](https://discord.gg/FFQNCSKSXX) 14 | [![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/CAMB-AI/MARS5-TTS) 15 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Camb-ai/mars5-tts/blob/master/mars5_demo.ipynb) 16 | 17 | 18 |
19 | 20 | # Updates 21 | <> July 5, 2024: Latest AR checkpoint released: higher stability of output. Very big update coming soon! 22 | 23 | 24 | # Approach 25 | 26 | This is the repo for the MARS5 English speech model (TTS) from CAMB.AI. 27 | 28 | The model follows a two-stage AR-NAR pipeline with a distinctively novel NAR component (see more info in the [Architecture](docs/architecture.md)). 29 | 30 | With just 5 seconds of audio and a snippet of text, MARS5 can generate speech even for prosodically hard and diverse scenarios like sports commentary, anime and more. Check out our demo: 31 | 32 | 33 | 34 | 35 | https://github.com/Camb-ai/MARS5-TTS/assets/23717819/3e191508-e03c-4ff9-9b02-d73ae0ebefdd 36 | 37 | 38 | Watch full video here: [![Youtube](https://img.shields.io/badge/YouTube-red)](https://www.youtube.com/watch?v=bmJSLPYrKtE) 39 | 40 | ![Mars 5 simplified diagram](docs/assets/MARS5_Architecture.png) 41 | 42 | **Figure**: The high-level architecture flow of MARS5. Given text and a reference audio, coarse (L0) encodec speech features are obtained through an autoregressive transformer model. Then, the text, reference, and coarse features are refined in a multinomial DDPM model to produce the remaining encodec codebook values. The output of the DDPM is then vocoded to produce the final audio. 43 | 44 | Because the model is trained on raw audio together with byte-pair-encoded text, it can be steered with things like punctuation and capitalization. 45 | E.g. To add a pause, add a comma to that part in the transcript. Or, to emphasize a word, put it in capital letters in the transcript. 46 | This enables a fairly natural way for guiding the prosody of the generated output. 47 | 48 | Speaker identity is specified using an audio reference file between 2-12 seconds, with lengths around 6s giving optimal results. 49 | Further, by providing the transcript of the reference, MARS5 enables one to do a '_deep clone_' which improves the quality of the cloning and output, at the cost of taking a bit longer to produce the audio. 50 | For more details on this and other performance and model details, please see the [docs folder](docs/architecture.md). 51 | 52 | ## Quick links 53 | 54 | - [CAMB.AI website](https://camb.ai/) (access MARS in 140+ languages for TTS and dubbing) 55 | - Technical details and architecture: [in the docs folder](docs/architecture.md) 56 | - Colab quickstart: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Camb-ai/mars5-tts/blob/master/mars5_demo.ipynb) 57 | - Sample page with a few hard prosodic samples: [https://camb-ai.github.io/MARS5-TTS/](https://camb-ai.github.io/MARS5-TTS/) 58 | - Online demo: [here](https://6b1a3a8e53ae.ngrok.app/) 59 | 60 | 61 | ## Quickstart 62 | 63 | 64 | We use `torch.hub` to make loading the model easy -- no cloning of the repo needed. The steps to perform inference are simple: 65 | 66 | 1. **Installation using pip**: 67 | 68 | Requirements: 69 | - Python >= 3.10 70 | - Torch >= 2.0 71 | - Torchaudio 72 | - Librosa 73 | - Vocos 74 | - Encodec 75 | - safetensors 76 | - regex 77 | 78 | ```bash 79 | pip install --upgrade torch torchaudio librosa vocos encodec safetensors regex 80 | ``` 81 | 82 | 2. **Load models**: load the MARS5 AR and NAR model from torch hub: 83 | 84 | ```python 85 | import torch, librosa 86 | 87 | mars5, config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True) 88 | # The `mars5` contains the AR and NAR model, as well as inference code. 
89 | # The `config_class` contains tunable inference config settings like temperature. 90 | ``` 91 | 92 | (Optional) Load Model from huggingface (make sure repository is cloned) 93 | ```python 94 | from inference import Mars5TTS, InferenceConfig as config_class 95 | import torch, librosa 96 | 97 | mars5 = Mars5TTS.from_pretrained("CAMB-AI/MARS5-TTS") 98 | ``` 99 | 100 | 101 | 3. **Pick a reference** and optionally its transcript: 102 | 103 | ```python 104 | # Load reference audio between 1-12 seconds. 105 | wav, sr = librosa.load('.wav', 106 | sr=mars5.sr, mono=True) 107 | wav = torch.from_numpy(wav) 108 | ref_transcript = "" 109 | ``` 110 | 111 | *Note: The reference transcript is optional. Pass it if you wish to do a deep clone.* 112 | 113 | MARS5 supports 2 kinds of inference: a shallow, fast inference whereby you do not need the transcript of the reference (we call this a _shallow clone_), and a second slower, but typically higher quality way, which we call a _deep clone_. 114 | To use the deep clone, you need the prompt transcript. See the [model architecture](docs/architecture.md) for more info on this. 115 | 116 | 4. **Perform the synthesis**: 117 | 118 | ```python 119 | # Pick whether you want a deep or shallow clone. Set to False if you don't know prompt transcript or want fast inference. Set to True if you know transcript and want highest quality. 120 | deep_clone = True 121 | # Below you can tune other inference settings, like top_k, temperature, top_p, etc... 122 | cfg = config_class(deep_clone=deep_clone, rep_penalty_window=100, 123 | top_k=100, temperature=0.7, freq_penalty=3) 124 | 125 | ar_codes, output_audio = mars5.tts("The quick brown rat.", wav, 126 | ref_transcript, 127 | cfg=cfg) 128 | # output_audio is (T,) shape float tensor corresponding to the 24kHz output audio. 129 | ``` 130 | 131 | **That's it!** These default settings provide pretty good results, but feel free to tune the inference settings to optimize the output for your particular usecase. See the [`InferenceConfig`](inference.py) code or the demo notebook for info and docs on all the different inference settings. 132 | 133 | _Some tips for best quality:_ 134 | - Make sure reference audio is clean and between 1 second and 12 seconds. 135 | - Use deep clone and provide an accurate transcript for the reference. 136 | - Use proper punctuation -- the model can be guided and made better or worse with proper use of punctuation and capitalization. 137 | 138 | ## Or Use Docker 139 | 140 | **Pull from DockerHub** 141 | 142 | You can directly pull the docker image from our [DockerHub page](https://hub.docker.com/r/cambai/mars5ttsimage). 143 | 144 | 145 | **Build On Your Own** 146 | 147 | You can build a custom image from the provided Dockerfile in this repo by running the following command. 148 | 149 | ```bash 150 | cd MARS5-TTS 151 | docker build -t mars5ttsimage ./docker 152 | ``` 153 | 154 | 155 | *Note: This image should be used as a base image on top of which you can add your custom inference script in a Dockerfile or docker-compose. Images that directly generate output will be added to Docker Hub and as Dockerfiles in this repo soon* 156 | 157 | ## Model Details 158 | 159 | **Checkpoints** 160 | 161 | The checkpoints for MARS5 are provided under the releases tab of this github repo. We provide two checkpoints: 162 | 163 | - AR fp16 checkpoint [~750M parameters], along with config embedded in the checkpoint. 164 | - NAR fp16 checkpoint [~450M parameters], along with config embedded in the checkpoint. 
165 | - The byte-pair encoding tokenizer used for the L0 encodec codes and the English text is embedded in each checkpoint under the `'vocab'` key, and follows roughly the same format as a saved minbpe tokenizer. 166 | 167 | The checkpoints are provided as both pytorch `.pt` checkpoints and safetensors `.safetensors` checkpoints. By default, `torch.hub.load()` loads the safetensors version, but you can specify which checkpoint format you prefer with the `ckpt_format='safetensors'` or `ckpt_format='pt'` argument in the `torch.hub.load()` call. E.g. to force safetensors format: 168 | 169 | `torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', ckpt_format='safetensors')` 170 | 171 | Or to force pytorch `.pt` format when loading the checkpoints: 172 | 173 | `torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', ckpt_format='pt')` 174 | 175 | **Hardware Requirements**: 176 | 177 | You must be able to store at least 750M+450M params on GPU, and do inference with 750M of active parameters. 178 | 179 | If you do not have the necessary hardware requirements and just want to use MARS5 in your applications, you can use it via our [API](https://docs.camb.ai/). If you need some extra credits to test it for your use case, feel free to reach out to `help@camb.ai`. 180 | 181 | ## Roadmap and tasks 182 | 183 | MARS5 is not perfect at the moment, and we are working on improving its quality, stability, and performance. 184 | Rough areas we are looking to improve, and where we welcome contributions: 185 | 186 | - Improving inference stability and consistency 187 | - Speed/performance optimizations 188 | - Improving reference audio selection when given long references. 189 | - Benchmark performance numbers for MARS5 on standard speech datasets. 190 | 191 | **Specific tasks** 192 | 193 | - [ ] Profile the GPU and CPU memory and runtime speed metrics of the current model, add to readme. 194 | - [ ] Port model operations not supported by MPS to equivalents to speed up Apple Mac inference. E.g. `site-packages/torch/nn/functional.py:4840: UserWarning: The operator 'aten::col2im' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications`. 195 | - [ ] Cleanly add more performant ODE samplers to DDPM inference code (even just DPM++2M would be great). 196 | - [ ] Make a demo/user-interface program to rapidly collect human preference ratings between two audio samples, one generated by the model, and one ground truth. 197 | - [ ] Implement a way to do long-form generation. E.g. one possibility is to chunk long input text into smaller pieces, synthesize the codes for each in turn, concatenate them, and vocode the final result. 198 | - [ ] Perform a search (e.g. beam or grid) on the autoregressive sampling settings to find the setting preset which gives the best quality. 199 | 200 | If you would like to contribute any improvement to MARS5, please feel free to do so (guidelines below). 201 | 202 | 203 | ## Contributions 204 | 205 | We welcome any contributions to improving the model. As you may find when experimenting, while it can produce really great results, it can still be further improved to create excellent outputs _consistently_. 206 | We'd also love to see how you used MARS5 in different scenarios; please use the [🙌 Show and tell](https://github.com/Camb-ai/MARS5-TTS/discussions/categories/show-and-tell) category in Discussions to share your examples.
207 | 208 | **Contribution format**: 209 | 210 | The preferred way to contribute to our repo is to fork the [master repository](https://github.com/Camb-ai/mars5-tts) on GitHub: 211 | 212 | 1. Fork the repo on github 213 | 2. Clone the repo, set upstream as this repo: `git remote add upstream git@github.com:Camb-ai/mars5-tts.git` 214 | 3. Make a new local branch and make your changes, commit changes. 215 | 4. Push changes to new upstream branch: `git push --set-upstream origin ` 216 | 5. On github, go to your fork and click 'Pull Request' to begin the PR process. Please make sure to include a description of what you did/fixed. 217 | 218 | ## License 219 | 220 | We are open-sourcing MARS5 in English under GNU AGPL 3.0. For commercial inquiries or to license the closed source version of MARS, please email help@camb.ai 221 | 222 | ## Join Our Team 223 | 224 | We're an ambitious team, globally distributed, with a singular aim of making everyone's voice count. At CAMB.AI, we're a research team of Interspeech-published, Carnegie Mellon, ex-Siri engineers and we're looking for you to join our team. 225 | 226 | We're actively hiring; please drop us an email at ack@camb.ai if you're interested. Visit our [careers page](https://www.camb.ai/careers) for more info. 227 | 228 | 229 | ## Community 230 | 231 | Join CAMB.AI community on [Forum](https://github.com/Camb-ai/MARS5-TTS/discussions) and 232 | [Discord](https://discord.gg/FFQNCSKSXX) to share any suggestions, feedback, or questions with our team. 233 | 234 | 235 | ## Support Camb.ai on Ko-fi ❤️! 236 | [![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/cambai) 237 | 238 | 239 | ## Acknowledgements 240 | 241 | Parts of code for this project are adapted from the following repositories -- please make sure to check them out! Thank you to the authors of: 242 | 243 | - AWS: For providing much needed compute resources (NVIDIA H100s) to enable training of the model. 
244 | - TransFusion: [https://github.com/RF5/transfusion-asr](https://github.com/RF5/transfusion-asr) 245 | - Multinomial diffusion: [https://github.com/ehoogeboom/multinomial_diffusion](https://github.com/ehoogeboom/multinomial_diffusion) 246 | - Mistral-src: [https://github.com/mistralai/mistral-src](https://github.com/mistralai/mistral-src) 247 | - minbpe: [https://github.com/karpathy/minbpe](https://github.com/karpathy/minbpe) 248 | - gemelo-ai's encodec Vocos: [https://github.com/gemelo-ai/vocos](https://github.com/gemelo-ai/vocos) 249 | - librosa for their `.trim()` code: [https://librosa.org/doc/main/generated/librosa.effects.trim.html](https://librosa.org/doc/main/generated/librosa.effects.trim.html) 250 | -------------------------------------------------------------------------------- /assets/demo-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Camb-ai/MARS5-TTS/84697c8673d6277d329381acfd796114b02d531f/assets/demo-preview.png -------------------------------------------------------------------------------- /assets/github-banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Camb-ai/MARS5-TTS/84697c8673d6277d329381acfd796114b02d531f/assets/github-banner.png -------------------------------------------------------------------------------- /cog/cog.yaml: -------------------------------------------------------------------------------- 1 | 2 | build: 3 | gpu: true 4 | system_packages: 5 | - "ffmpeg" 6 | - "libavcodec-dev" 7 | - "sox" 8 | 9 | python_version: "3.10" 10 | 11 | python_packages: 12 | - numpy 13 | - torch 14 | - torchaudio 15 | - torchvision 16 | - transformers 17 | - regex 18 | - librosa 19 | - vocos 20 | - encodec 21 | run: 22 | - python -c "import torch; torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True)" 23 | - pip install pyupload 24 | predict: "predict.py:Predictor" 25 | -------------------------------------------------------------------------------- /cog/predict.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Input, Path 2 | from typing import Dict 3 | from pathlib import Path 4 | import tempfile 5 | import torch 6 | import torchaudio 7 | import librosa 8 | import subprocess 9 | import os 10 | import soundfile as sf 11 | from pyupload.main import CatboxUploader 12 | from scipy.io.wavfile import write as write_wav 13 | 14 | SAMPLE_RATE = 16000 15 | 16 | # TODO: Replicate is unable to parse & render if we return plain Audio, hence used this method of temporary output url, see if it can resolve 17 | class Predictor(BasePredictor): 18 | def setup(self): 19 | self.mars5, self.config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True) 20 | print(">>>>> Model Loaded") 21 | 22 | def predict( 23 | self, 24 | text: str = Input(description="Text to synthesize"), 25 | ref_audio_file: Path = Input(description='Reference audio file to clone from <= 10 seconds', default="https://files.catbox.moe/be6df3.wav"), 26 | ref_audio_transcript: str = Input(description='Text in the reference audio file', default="We actually haven't managed to meet demand."), 27 | ) -> str: 28 | 29 | print(f">>>> Ref Audio file: {ref_audio_file}; ref_transcript: {ref_audio_transcript}") 30 | 31 | # Load the reference audio 32 | wav, sr = librosa.load(ref_audio_file, sr=self.mars5.sr, mono=True) 33 | wav = torch.from_numpy(wav) 34 | 35 | # 
configuration for the TTS model 36 | deep_clone = True 37 | cfg = self.config_class(deep_clone=deep_clone, rep_penalty_window=100, top_k=100, temperature=0.7, freq_penalty=3) 38 | 39 | # Generate the synthesized audio 40 | print(f">>> Running inference") 41 | ar_codes, wav_out = self.mars5.tts(text, wav, ref_audio_transcript, cfg=cfg) 42 | print(f">>>>> Done with inference") 43 | 44 | output_path = "/tmp/aud.wav" 45 | write_wav(output_path, self.mars5.sr, wav_out.numpy()) 46 | 47 | output_file_url = CatboxUploader(output_path).execute() 48 | print(f">>>> Output file url: {output_file_url}") 49 | return output_file_url 50 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3 2 | 3 | # metainformation 4 | LABEL org.opencontainers.image.source = "https://github.com/Camb-ai/MARS5-TTS" 5 | LABEL org.opencontainers.image.licenses = "AGPL-3.0 license" 6 | 7 | 8 | # enable passwordless ssh 9 | RUN mkdir ~/.ssh && \ 10 | printf "Host * \n ForwardAgent yes\nHost *\n StrictHostKeyChecking no" > ~/.ssh/config && \ 11 | ssh-keygen -t rsa -N "" -f ~/.ssh/id_rsa && \ 12 | cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys 13 | 14 | # enable RDMA support: this enables direct access to GPU memory 15 | RUN apt-get update && \ 16 | apt-get install -y infiniband-diags perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 && \ 17 | apt-get clean && \ 18 | rm -rf /var/lib/apt/lists/* 19 | 20 | # install MARS5-TTS 21 | RUN git clone https://github.com/Camb-ai/MARS5-TTS.git \ 22 | && cd ./MARS5-TTS \ 23 | && pip install -r requirements.txt 24 | -------------------------------------------------------------------------------- /docs/architecture.md: -------------------------------------------------------------------------------- 1 | # MARS5 Technical Details 2 | 3 | 4 | ## hubconf object/api 5 | 6 | 7 | After loading the model with `torch.hub.load`, two objects are returned: a `Mars5TTS` instance, and the dataclass of the inference config to use when calling the `mars5.tts()` method. 8 | Concretely, the main methods of the `mars5` object are: 9 | 10 | ```python 11 | 12 | # The init function, called automatically when you initialize the 13 | # model from torch.hub.load(). If you want, you can pass in your 14 | # own custom checkpoints here to initialize the model with your 15 | # own model, tokenizer, etc... 16 | def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None: 17 | # ... initialization code ... 18 | 19 | # Main text-to-speech function, converting text and a reference 20 | # audio to speech. 21 | def tts(self, text: str, ref_audio: Tensor, ref_transcript: str | None, 22 | cfg: InferenceConfig) -> Tensor: 23 | """ Perform TTS for `text`, given a reference audio `ref_audio` (of shape [sequence_length,], sampled at 24kHz) 24 | which has an associated `ref_transcript`. Perform inference using the inference 25 | config given by `cfg`, which controls the temperature, top_p, etc... 26 | Returns: 27 | - `ar_codes`: (seq_len,) long tensor of discrete coarse code outputs from the AR model. 28 | - `out_wav`: (T,) float output audio tensor sampled at 24kHz. 29 | """ 30 | 31 | # Utility function to vocode encodec tokens, if one wishes 32 | # to hear the raw AR model output by vocoding the `ar_codes` 33 | # returned above.
34 | def vocode(self, tokens: Tensor) -> Tensor: 35 | """ Vocodes tokens of shape (seq_len, n_q) """ 36 | ``` 37 | 38 | 39 | ## Model Design 40 | 41 | MARS5 follows a two-stage AR-NAR design according to the [diagram](/docs/assets/simplified_diagram.png) on the main page. 42 | 43 | #### AR component 44 | 45 | The AR model follows a Mistral-style decoder-only transformer model to predict Encodec L0 codes (the lowest/most coarse level quantization codes). 46 | Together, the AR and NAR models predict all 8 codebook entries of the Encodec 6kbps codec. 47 | The AR model design is given below: 48 | 49 | ![Mars 5 AR architecture](/docs/assets/mars5_AR_arch.png) 50 | 51 | **Figure**: Autoregressive component of MARS5. During training, the initial 6kbps encodec tokens of the speech are fed through a small decoder-only transformer, producing a single output vector corresponding to an implicit speaker embedding. 52 | This vector is concatenated with learnt embeddings corresponding to the text tokens and L0 speech tokens, after byte-pair encoding tokenization. 53 | 54 | 55 | The AR model is trained using the standard next-token prediction task of language models, with a cross-entropy loss on the next token and a smaller weight given to text tokens. 56 | During inference, we iteratively sample from the transformer to produce the desired L0 codes. 57 | When we use a _shallow clone_, the reference audio is fed into the speaker conditioning transformer to produce the implicit speaker embedding used in the input sequence. 58 | When we use a _deep clone_, the above is done, but we also concatenate the reference transcript with the desired text, and the reference audio tokens with the input sequence, before we start sampling the output. 59 | In pseudocode: 60 | 61 | ``` 62 | speaker_embedding <- speaker_conditioning_transformer(ref audio) 63 | if deep_clone: 64 | prompt = concatenate( speaker embedding, reference text, target text, reference L0 speech codes ) 65 | else: 66 | prompt = concatenate( speaker embedding, target text ) 67 | 68 | ar output <- autoregressively sample from prompt 69 | ``` 70 | 71 | While a deep clone provides more accurate cloning of the reference speaker identity and prosody, it requires knowledge of the reference transcript and takes longer to run inference. 72 | 73 | #### NAR component 74 | 75 | After the AR model has predicted the L0 encodec codes, we need a way to predict the remaining 7 codebooks of the 6kbps Encodec codec. 76 | This is what the NAR model is trained to do, using a multinomial diffusion framework. 77 | Concretely, the diffusion process is a discrete DDPM: at each timestep in the diffusion process, the model takes in a sequence of shape `(batch size, sequence length, n_codebooks)` and produces an output categorical distribution over each codebook, i.e. an output of shape `(batch size, sequence length, n_codebooks, 1024)`, since each encodec codebook has 1024 possible values. 78 | The architecture of the model looks as follows: 79 | 80 | 81 | ![Mars 5 NAR architecture](/docs/assets/mars5_NAR_arch.png) 82 | 83 | **Figure**: MARS5 non-autoregressive component. It follows an encoder-decoder transformer architecture, whereby the encoder computes an implicit speaker embedding like the AR model, and concatenates that along with the target text to form an input sequence to a transformer encoder. The transformer decoder predicts the distribution of all 8 encodec codebook tokens given a partly noised input at some diffusion timestep `t`. 
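To make these shapes concrete, below is a minimal, illustrative sketch of what one NAR denoising step consumes and produces. The dimensions and the stand-in for the model call are assumptions for illustration only; the actual implementation parametrizes the model as a prediction of $\hat{x}_0$ and computes a posterior before sampling (see [`mars5/diffuser.py`](/mars5/diffuser.py)):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: the tensor sizes and the stand-in "model output"
# below are assumptions for this sketch, not the real MARS5 NAR API.
bs, seq_len, n_codebooks, n_classes = 1, 750, 8, 1024

# Partly-noised Encodec tokens at diffusion timestep t: one integer code per
# codebook position. The L0 (coarsest) codebook row is kept fixed to the AR
# model's predictions rather than being noised.
x_t = torch.randint(0, n_classes, (bs, seq_len, n_codebooks))
t = torch.full((bs,), 137)  # current diffusion timestep, passed to the model

# Stand-in for the NAR model's output: logits over the 1024 possible values
# of every codebook at every sequence position.
logits = torch.randn(bs, seq_len, n_codebooks, n_classes)
log_probs = F.log_softmax(logits, dim=-1)  # categorical distribution per position

# Sample the slightly-less-noisy x_{t-1} from that categorical distribution
# (Gumbel-max trick, as the multinomial diffusion sampler does).
u = torch.rand_like(log_probs).clamp_min(1e-7)
gumbel = -torch.log((-torch.log(u)).clamp_min(1e-7))
x_tm1 = (log_probs + gumbel).argmax(dim=-1)  # (bs, seq_len, n_codebooks)
```

In the real sampler the L0 row of `x_tm1` is then re-imposed from the AR codes for most timesteps (controlled by `q0_override_steps`), so effectively only the 7 residual codebooks are denoised.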
84 | 85 | 86 | The encoder and decoder transformers are simple `nn.Transformer` variants with sinusoidal positional embeddings and SwiGLU activations. 87 | A multinomial diffusion manager controls the forward and reverse diffusion processes during inference and training according to a cosine diffusion schedule. 88 | Diffusion is performed independently of the sequence length or codebook index. 89 | 90 | During training and inference, the L0 codebooks of the input at timestep $t$ are overridden (i.e. not noised in the forward diffusion process) with either the ground truth L0 codes (during training) or the AR model's predictions (during inference). 91 | Like the AR model, the NAR model can perform inference in either a _shallow clone_ way or a _deep clone_ way. 92 | As with the AR model, the difference between the two is that, for a _deep clone_, we concatenate the reference text to the input text sequence, and the reference speech codes (the full values for all 8 codebooks) to the decoder input sequence $x$. 93 | During inference, we then treat the portion of $x$ corresponding to the reference codec codes, and all the AR L0 codes, as 'fixed' and effectively perform diffusion inpainting for the remaining missing codec codes. 94 | The figure below explains what the input to the decoder looks like for a deep clone: 95 | 96 | ![NAR decoder input for deep clone](/docs/assets/NAR_inpainting_diagram.png) 97 | 98 | This allows us to use diffusion inpainting techniques like [RePaint](https://arxiv.org/abs/2201.09865) to improve the quality of the output at the cost of more inference time. 99 | We've implemented this in the diffusion config used in the NAR inference code (see it [here](/mars5/diffuser.py)), and you can simply increase `jump_len` and `jump_n_sample` to values greater than 1 to use RePaint inpainting to improve NAR performance. 
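For example, here is a minimal sketch of a `DSH` (diffusion sampling hyperparameters) object with RePaint resampling enabled. The specific values are illustrative; note that in this version `Mars5TTS.tts()` constructs its `DSH` with `jump_len=1` and `jump_n_sample=1`, so using larger values means editing that call or invoking `perform_simple_inference` yourself:

```python
from mars5.diffuser import DSH

# Sketch: RePaint-style resampling for the NAR sampler. jump_len (j) and
# jump_n_sample (r) follow the RePaint paper; ~10 is the paper's default and
# trades extra inference time for better inpainting quality.
dsh_cfg = DSH(
    last_greedy=True,
    x_0_temp=0.7,            # temperature on the model's x0 prediction
    guidance_w=3,            # classifier-free guidance weight
    deep_clone=True,
    jump_len=10,             # j: length of each jump back in the schedule
    jump_n_sample=10,        # r: how many times each segment is resampled
    q0_override_steps=20,
    enable_kevin_scaled_inference=True,
    progress=True,           # show a progress bar over diffusion steps
)
```

The resulting `dsh_cfg` is then passed as the `dsh` argument of `perform_simple_inference`, in the same way `Mars5TTS.tts()` does.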
100 | 101 | 102 | -------------------------------------------------------------------------------- /docs/assets/MARS5_Architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Camb-ai/MARS5-TTS/84697c8673d6277d329381acfd796114b02d531f/docs/assets/MARS5_Architecture.png -------------------------------------------------------------------------------- /docs/assets/NAR_inpainting_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Camb-ai/MARS5-TTS/84697c8673d6277d329381acfd796114b02d531f/docs/assets/NAR_inpainting_diagram.png -------------------------------------------------------------------------------- /docs/assets/example_ref.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Camb-ai/MARS5-TTS/84697c8673d6277d329381acfd796114b02d531f/docs/assets/example_ref.wav -------------------------------------------------------------------------------- /docs/assets/intro_vid.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Camb-ai/MARS5-TTS/84697c8673d6277d329381acfd796114b02d531f/docs/assets/intro_vid.mp4 -------------------------------------------------------------------------------- /docs/assets/mars5_AR_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Camb-ai/MARS5-TTS/84697c8673d6277d329381acfd796114b02d531f/docs/assets/mars5_AR_arch.png -------------------------------------------------------------------------------- /docs/assets/mars5_NAR_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Camb-ai/MARS5-TTS/84697c8673d6277d329381acfd796114b02d531f/docs/assets/mars5_NAR_arch.png -------------------------------------------------------------------------------- /docs/assets/simplified_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Camb-ai/MARS5-TTS/84697c8673d6277d329381acfd796114b02d531f/docs/assets/simplified_diagram.png -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch', 'torchaudio', 'numpy', 'vocos', 'safetensors'] 2 | 3 | import logging 4 | import os 5 | from pathlib import Path 6 | from safetensors import safe_open 7 | 8 | import torch 9 | from inference import Mars5TTS, InferenceConfig 10 | 11 | ar_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.4/mars5_en_checkpoints_ar-3000000.pt" 12 | nar_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.pt" 13 | 14 | ar_sf_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.4/mars5_en_checkpoints_ar-3000000.safetensors" 15 | nar_sf_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.safetensors" 16 | 17 | def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors', 18 | ar_path=None, nar_path=None) -> Mars5TTS: 19 | """ Load mars5 english model on `device`, optionally show `progress`. 
""" 20 | if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | 22 | assert ckpt_format in ['safetensors', 'pt'], "checkpoint format must be 'safetensors' or 'pt'" 23 | 24 | logging.info(f"Using device: {device}") 25 | if pretrained == False: raise AssertionError('Only pretrained model currently supported.') 26 | logging.info("Loading AR checkpoint...") 27 | 28 | if ar_path is None: 29 | if ckpt_format == 'safetensors': 30 | ar_ckpt = _load_safetensors_ckpt(ar_sf_url, progress=progress) 31 | elif ckpt_format == 'pt': 32 | ar_ckpt = torch.hub.load_state_dict_from_url( 33 | ar_url, progress=progress, check_hash=False, map_location='cpu' 34 | ) 35 | else: ar_ckpt = torch.load(str(ar_path), map_location='cpu') 36 | 37 | logging.info("Loading NAR checkpoint...") 38 | if nar_path is None: 39 | if ckpt_format == 'safetensors': 40 | nar_ckpt = _load_safetensors_ckpt(nar_sf_url, progress=progress) 41 | elif ckpt_format == 'pt': 42 | nar_ckpt = torch.hub.load_state_dict_from_url( 43 | nar_url, progress=progress, check_hash=False, map_location='cpu' 44 | ) 45 | else: nar_ckpt = torch.load(str(nar_path), map_location='cpu') 46 | logging.info("Initializing modules...") 47 | mars5 = Mars5TTS(ar_ckpt, nar_ckpt, device=device) 48 | return mars5, InferenceConfig 49 | 50 | 51 | def _load_safetensors_ckpt(url, progress): 52 | """ Loads checkpoint from a safetensors file """ 53 | hub_dir = torch.hub.get_dir() 54 | model_dir = os.path.join(hub_dir, 'checkpoints') 55 | os.makedirs(model_dir, exist_ok=True) 56 | parts = torch.hub.urlparse(url) 57 | filename = os.path.basename(parts.path) 58 | cached_file = os.path.join(model_dir, filename) 59 | if not os.path.exists(cached_file): 60 | # download it 61 | torch.hub.download_url_to_file(url, cached_file, None, progress=progress) 62 | # load checkpoint 63 | ckpt = {} 64 | with safe_open(cached_file, framework='pt', device='cpu') as f: 65 | metadata = f.metadata() 66 | ckpt['vocab'] = {'texttok.model': metadata['texttok.model'], 'speechtok.model': metadata['speechtok.model']} 67 | ckpt['model'] = {} 68 | for k in f.keys(): ckpt['model'][k] = f.get_tensor(k) 69 | return ckpt 70 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import logging 6 | from typing import Optional, Dict, Type, Union 7 | from pathlib import Path 8 | from dataclasses import dataclass 9 | import io 10 | 11 | from mars5.model import CodecLM, ResidualTransformer 12 | from vocos import Vocos 13 | from encodec import EncodecModel 14 | from mars5.diffuser import MultinomialDiffusion, DSH, perform_simple_inference 15 | from mars5.minbpe.regex import RegexTokenizer, GPT4_SPLIT_PATTERN 16 | from mars5.minbpe.codebook import CodebookTokenizer 17 | from mars5.ar_generate import ar_generate 18 | from mars5.utils import nuke_weight_norm, construct_padding_mask 19 | from mars5.trim import trim 20 | from safetensors import safe_open 21 | from huggingface_hub import ModelHubMixin, hf_hub_download 22 | import logging 23 | 24 | @dataclass 25 | class InferenceConfig(): 26 | """ The defaults configuration variables for TTS inference. 
""" 27 | 28 | ## >>>> AR CONFIG 29 | # temperature influences probability distribution of logits 30 | # How to set this variable: high temperatures (T>1) favour less probable outputs while low temperatures reduce randomness 31 | temperature: float = 0.7 32 | 33 | # Used for sampling - Keeps tokens with the highest probabilities until a certain number (top_k) is reached 34 | top_k: int = 200 # 0 disables it 35 | # Used for sampling - keep the top tokens with cumulative probability >= top_p 36 | top_p: float = 0.2 # 1.0 disables it 37 | 38 | typical_p: float = 1.0 39 | freq_penalty: float = 3 # increasing it would penalize the model more for repetitions 40 | presence_penalty: float = 0.4 # increasing it would increase token diversity 41 | rep_penalty_window: int = 80 # how far in the past to consider when penalizing repetitions. Equates to 5s 42 | 43 | eos_penalty_decay: float = 0.5 # how much to penalize 44 | eos_penalty_factor: float = 1 # overal penalty weighting 45 | eos_estimated_gen_length_factor: float = 1.0 # multiple of len(text_phones) to assume an approximate output length is 46 | 47 | ## >>>> NAR CONFIG 48 | # defaults, that can be overridden with user specified inputs 49 | timesteps: int = 200 50 | x_0_temp: float = 0.7 51 | q0_override_steps: int = 20 # number of diffusion steps where NAR L0 predictions overrides AR L0 predictions. 52 | nar_guidance_w: float = 3 53 | 54 | max_prompt_dur: float = 12 # maximum length prompt is allowed, in seconds. 55 | 56 | # Maximum AR codes to generate in 1 inference. 57 | # Default of -1 leaves it same as training time max AR tokens. 58 | # Typical values up to ~2x training time can be tolerated, 59 | # with ~1.5x trianing time tokens having still mostly ok performance. 60 | generate_max_len_override: int = -1 61 | 62 | # Whether to deep clone from the reference. 63 | # Pros: improves intelligibility and speaker cloning performance. 64 | # Cons: requires reference transcript, and inference takes a bit longer. 65 | deep_clone: bool = True 66 | 67 | # kv caching helps with optimizing inference speed. 68 | # disabling/enabling kv caching won't affect output quality 69 | use_kv_cache: bool = True 70 | 71 | 72 | # Leading and trailing silences will be trimmed from final output 73 | # Trim_db is the threshold (in decibels) below reference to consider as silence 74 | trim_db: float = 27 75 | beam_width: int = 1 # only beam width of 1 is currently supported 76 | 77 | ref_audio_pad: float = 0 78 | 79 | class Mars5TTS(nn.Module, ModelHubMixin): 80 | def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None: 81 | super().__init__() 82 | 83 | if device is None: 84 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 85 | self.device = torch.device(device) 86 | 87 | self.codec = EncodecModel.encodec_model_24khz().to(device).eval() 88 | self.codec.set_target_bandwidth(6.0) 89 | 90 | # save and load text tokenize 91 | self.texttok = RegexTokenizer(GPT4_SPLIT_PATTERN) 92 | texttok_data = io.BytesIO(ar_ckpt['vocab']['texttok.model'].encode('utf-8')) 93 | self.texttok.load(texttok_data) 94 | 95 | # save and load speech tokenizer 96 | self.speechtok = CodebookTokenizer(GPT4_SPLIT_PATTERN) 97 | speechtok_data = io.BytesIO(ar_ckpt['vocab']['speechtok.model'].encode('utf-8')) 98 | self.speechtok.load(speechtok_data) 99 | 100 | # keep track of tokenization things. 
101 | self.n_vocab = len(self.texttok.vocab) + len(self.speechtok.vocab) 102 | self.n_text_vocab = len(self.texttok.vocab) + 1 103 | self.diffusion_n_classes: int = 1025 # 1 for padding idx 104 | # load AR model 105 | self.codeclm = CodecLM(n_vocab=self.n_vocab, dim=1536, dim_ff_scale=7/3) 106 | self.codeclm.load_state_dict(ar_ckpt['model']) 107 | self.codeclm = self.codeclm.to(self.device).eval() 108 | # load NAR model 109 | self.codecnar = ResidualTransformer(n_text_vocab=self.n_text_vocab, n_quant=self.diffusion_n_classes, 110 | p_cond_drop=0, dropout=0) 111 | self.codecnar.load_state_dict(nar_ckpt['model']) 112 | self.codecnar = self.codecnar.to(self.device).eval() 113 | self.default_T = 200 114 | 115 | self.sr = 24000 116 | self.latent_sr = 75 117 | 118 | # load vocoder 119 | self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device).eval() 120 | nuke_weight_norm(self.codec) 121 | nuke_weight_norm(self.vocos) 122 | 123 | @classmethod 124 | def _from_pretrained( 125 | cls: Type["Mars5TTS"], 126 | *, 127 | model_id: str, 128 | revision: Optional[str], 129 | cache_dir: Optional[Union[str, Path]], 130 | force_download: bool, 131 | proxies: Optional[Dict], 132 | local_files_only: bool, 133 | token: Optional[Union[str, bool]], 134 | device: str = None, 135 | **model_kwargs, 136 | ) -> "Mars5TTS": 137 | # Download files from Hub 138 | logging.info(f">>>>> Downloading AR model") 139 | ar_ckpt_path = hf_hub_download(repo_id=model_id, filename="mars5_ar.safetensors", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token) 140 | logging.info(f">>>>> Downloading NAR model") 141 | nar_ckpt_path = hf_hub_download(repo_id=model_id, filename="mars5_nar.safetensors", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token) 142 | 143 | ar_ckpt = {} 144 | with safe_open(ar_ckpt_path, framework='pt', device='cpu') as f: 145 | metadata = f.metadata() 146 | ar_ckpt['vocab'] = {'texttok.model': metadata['texttok.model'], 'speechtok.model': metadata['speechtok.model']} 147 | ar_ckpt['model'] = {} 148 | for k in f.keys(): ar_ckpt['model'][k] = f.get_tensor(k) 149 | nar_ckpt = {} 150 | with safe_open(nar_ckpt_path, framework='pt', device='cpu') as f: 151 | metadata = f.metadata() 152 | nar_ckpt['vocab'] = {'texttok.model': metadata['texttok.model'], 'speechtok.model': metadata['speechtok.model']} 153 | nar_ckpt['model'] = {} 154 | for k in f.keys(): nar_ckpt['model'][k] = f.get_tensor(k) 155 | 156 | 157 | # Init 158 | return cls(ar_ckpt=ar_ckpt, nar_ckpt=nar_ckpt, device=device) 159 | 160 | @torch.inference_mode 161 | def vocode(self, tokens: Tensor) -> Tensor: 162 | """ Vocodes tokens of shape (seq_len, n_q) """ 163 | tokens = tokens.T.to(self.device) 164 | features = self.vocos.codes_to_features(tokens) 165 | # A cool hidden feature of vocos vocoding: 166 | # setting the bandwidth below to 1 (corresponding to 3 kbps) 167 | # actually still works on 6kbps input tokens, but *smooths* the output 168 | # audio a bit, which can help improve quality if its a bit noisy. 169 | # Hence we use [1] and not [2] below. 
170 | bandwidth_id = torch.tensor([1], device=self.device) # 6 kbps 171 | wav_diffusion = self.vocos.decode(features, bandwidth_id=bandwidth_id) 172 | return wav_diffusion.cpu().squeeze()[None] 173 | 174 | @torch.inference_mode 175 | def get_speaker_embedding(self, ref_audio: Tensor) -> Tensor: 176 | """ Given `ref_audio` (bs, T) audio tensor, compute the implicit speakre embedding of shape (bs, dim). """ 177 | if ref_audio.dim() == 1: ref_audio = ref_audio[None] 178 | spk_reference = self.codec.encode(ref_audio[None].to(self.device))[0][0] 179 | spk_reference = spk_reference.permute(0, 2, 1) 180 | bs = spk_reference.shape[0] 181 | if bs != 1: 182 | raise AssertionError(f"Speaker embedding extraction only implemented using for bs=1 currently.") 183 | spk_seq = self.codeclm.ref_chunked_emb(spk_reference) # (bs, sl, dim) 184 | spk_ref_emb = self.codeclm.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim) 185 | 186 | spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim) 187 | # add pos encoding 188 | spk_seq = self.codeclm.pos_embedding(spk_seq) 189 | # codebook goes from indices 0->1023, padding is idx 1024 (the 1025th entry) 190 | src_key_padding_mask = construct_padding_mask(spk_reference[:, :, 0], 1024) 191 | src_key_padding_mask = torch.cat(( 192 | # append a zero here since we DO want to attend to initial position. 193 | torch.zeros(src_key_padding_mask.shape[0], 1, dtype=bool, device=src_key_padding_mask.device), 194 | src_key_padding_mask 195 | ), 196 | dim=1) 197 | # pass through transformer 198 | res = self.codeclm.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim). 199 | return res.squeeze(1) 200 | 201 | @torch.inference_mode 202 | def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None, 203 | cfg: Optional[InferenceConfig] = InferenceConfig()) -> Tensor: 204 | """ Perform TTS for `text`, given a reference audio `ref_audio` (of shape [sequence_length,], sampled at 24kHz) 205 | which has an associated `ref_transcript`. Perform inference using the inference 206 | config given by `cfg`, which controls the temperature, top_p, etc... 207 | Returns: 208 | - `ar_codes`: (seq_len,) long tensor of discrete coarse code outputs from the AR model. 209 | - `out_wav`: (T,) float output audio tensor sampled at 24kHz. 210 | """ 211 | 212 | if cfg.deep_clone and ref_transcript is None: 213 | raise AssertionError( 214 | ("Inference config deep clone is set to true, but reference transcript not specified! " 215 | "Please specify the transcript of the prompt, or set deep_clone=False in the inference `cfg` argument." 216 | )) 217 | ref_dur = ref_audio.shape[-1]/self.sr 218 | if ref_dur > cfg.max_prompt_dur: 219 | logging.warning((f"Reference audio duration is {ref_dur:.2f} > max suggested ref audio. " 220 | f"Expect quality degradations. We recommend you trim prompt to be shorter than max prompt length.")) 221 | 222 | # get text codes. 
223 | text_tokens = self.texttok.encode("<|startoftext|>"+text.strip()+"<|endoftext|>", 224 | allowed_special='all') 225 | 226 | text_tokens_full = self.texttok.encode("<|startoftext|>"+ ref_transcript + ' ' + str(text).strip()+"<|endoftext|>", 227 | allowed_special='all') 228 | 229 | if ref_audio.dim() == 1: ref_audio = ref_audio[None] 230 | if ref_audio.shape[0] != 1: ref_audio = ref_audio.mean(dim=0, keepdim=True) 231 | ref_audio = F.pad(ref_audio, (int(self.sr*cfg.ref_audio_pad), 0)) 232 | # get reference audio codec tokens 233 | prompt_codec = self.codec.encode(ref_audio[None].to(self.device))[0][0] # (bs, n_q, seq_len) 234 | 235 | n_speech_inp = 0 236 | n_start_skip = 0 237 | q0_str = ' '.join([str(t) for t in prompt_codec[0, 0].tolist()]) 238 | # Note, in the below, we do NOT want to encode the token as a part of it, since we will be continuing it!!! 239 | speech_tokens = self.speechtok.encode(q0_str.strip()) # + "<|endofspeech|>", allowed_special='all') 240 | spk_ref_codec = prompt_codec[0, :, :].T # (seq_len, n_q) 241 | 242 | raw_prompt_acoustic_len = len(prompt_codec[0,0].squeeze()) 243 | offset_speech_codes = [p+len(self.texttok.vocab) for p in speech_tokens] 244 | if not cfg.deep_clone: 245 | # shallow clone, so 246 | # 1. clip existing speech codes to be empty (n_speech_inp = 0) 247 | offset_speech_codes = offset_speech_codes[:n_speech_inp] 248 | else: 249 | # Deep clone, so 250 | # 1. set text to be text of prompt + target text 251 | text_tokens = text_tokens_full 252 | # 2. update n_speech_inp to be length of prompt, so we only display from ths `n_speech_inp` onwards in the final output. 253 | n_speech_inp = len(offset_speech_codes) 254 | prompt = torch.tensor(text_tokens + offset_speech_codes, dtype=torch.long, device=self.device) 255 | first_codec_idx = prompt.shape[-1] - n_speech_inp + 1 256 | 257 | # ---> perform AR code generation 258 | logging.debug(f"Raw acoustic prompt length: {raw_prompt_acoustic_len}") 259 | 260 | ar_codes = ar_generate(self.texttok, self.speechtok, self.codeclm, 261 | prompt, spk_ref_codec, first_codec_idx, 262 | max_len=cfg.generate_max_len_override if cfg.generate_max_len_override > 1 else 2000, 263 | fp16=True if torch.cuda.is_available() else False, 264 | temperature=cfg.temperature, topk=cfg.top_k, top_p=cfg.top_p, typical_p=cfg.typical_p, 265 | alpha_frequency=cfg.freq_penalty, alpha_presence=cfg.presence_penalty, penalty_window=cfg.rep_penalty_window, 266 | eos_penalty_decay=cfg.eos_penalty_decay, eos_penalty_factor=cfg.eos_penalty_factor, 267 | beam_width=cfg.beam_width, beam_length_penalty=1, 268 | n_phones_gen=round(cfg.eos_estimated_gen_length_factor*len(text)), 269 | vocode=False, use_kv_cache=cfg.use_kv_cache) 270 | 271 | # Parse AR output 272 | output_tokens = ar_codes - len(self.texttok.vocab) 273 | output_tokens = output_tokens.clamp(min=0).squeeze()[first_codec_idx:].cpu().tolist() 274 | gen_codes_decoded = self.speechtok.decode_int(output_tokens) 275 | gen_codes_decoded = torch.tensor([s for s in gen_codes_decoded if type(s) == int], dtype=torch.long, device=self.device) 276 | 277 | c_text = torch.tensor(text_tokens, dtype=torch.long, device=self.device)[None] 278 | c_codes = prompt_codec.permute(0, 2, 1) 279 | c_texts_lengths = torch.tensor([len(text_tokens)], dtype=torch.long, device=self.device) 280 | c_codes_lengths = torch.tensor([c_codes.shape[1],], dtype=torch.long, device=self.device) 281 | 282 | _x = gen_codes_decoded[None, n_start_skip:, None].repeat(1, 1, 8) # (seq_len) -> (1, seq_len, 8) 283 | x_padding_mask = 
torch.zeros((1, _x.shape[1]), dtype=torch.bool, device=_x.device) 284 | 285 | # ---> perform DDPM NAR inference 286 | T = self.default_T 287 | diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=T, device=self.device) 288 | 289 | dsh_cfg = DSH(last_greedy=True, x_0_temp=cfg.x_0_temp, 290 | guidance_w=cfg.nar_guidance_w, 291 | deep_clone=cfg.deep_clone, jump_len=1, jump_n_sample=1, 292 | q0_override_steps=cfg.q0_override_steps, 293 | enable_kevin_scaled_inference=True, # see TransFusion ASR for explanation of this 294 | progress=False) 295 | 296 | final_output = perform_simple_inference(self.codecnar,( 297 | c_text, c_codes, c_texts_lengths, c_codes_lengths, _x, x_padding_mask 298 | ), diff, diff.num_timesteps, torch.float16, dsh=dsh_cfg, retain_quant0=True) # (bs, seq_len, n_quant) 299 | 300 | skip_front = raw_prompt_acoustic_len if cfg.deep_clone else 0 301 | final_output = final_output[0, skip_front:].to(self.device) # (seq_len, n_quant) 302 | 303 | # vocode final output and trim silences 304 | final_audio = self.vocode(final_output).squeeze() 305 | final_audio, _ = trim(final_audio.cpu(), top_db=cfg.trim_db) 306 | 307 | return gen_codes_decoded, final_audio 308 | -------------------------------------------------------------------------------- /mars5/ar_generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchaudio 4 | import copy 5 | from torch import Tensor, nn 6 | import logging 7 | from .model import length_to_mask 8 | from .samplers import (apply_typical_p, early_eos_penalty, 9 | top_k_top_p_filtering, freq_rep_penalty) 10 | from .nn_future import RotatingBufferCache 11 | from .minbpe.codebook import CodebookTokenizer 12 | from .minbpe.regex import RegexTokenizer 13 | 14 | 15 | @torch.inference_mode() 16 | def ar_generate(texttok: RegexTokenizer, speechtok: CodebookTokenizer, 17 | codeclm: nn.Module, xx: Tensor, ss_gen: Tensor, first_codex_idx: int, 18 | max_len: int = 1500, fp16: bool = True, temperature: float = 1.0, topk: int = None, 19 | top_p=1.0, alpha_frequency=0, alpha_presence=0, penalty_window=100, 20 | typical_p=1.0, eos_penalty_factor=1.0, eos_penalty_decay=0, n_phones_gen=None, vocode=True, 21 | beam_width: int = 1, beam_length_penalty=2, use_kv_cache: bool = True) -> tuple[Tensor, Tensor]: 22 | """ Use the `codeclm` language model to autoregressively generate a completion of `xx` (seq_len), where the first `first_codex_idx`-1 23 | indices correspond to the input phones. The output generation is limited to at most `max_len` (measured as num latent codes). 24 | Returns both output first quantizer codes and synthesized audio using `codec`. Use decoding with `beam_width` to keep 25 | track of top `beam_width` outcomes, selecting the top one among them. 26 | 27 | - Optionally vocode if `vocode` (default True). 28 | - See `InferenceConfig` for other inference docs. 29 | """ 30 | assert xx.dim() == 1, "Only batch size of 1 is currently supported." 31 | assert beam_width == 1, "Only beam size of 1 is currently supported." 32 | # internally our batch size will be the beam width 33 | bs = beam_width 34 | x_inp = xx[None].repeat(bs, 1) # (bs, seq_len) 35 | ss_gen = ss_gen[None].repeat(bs, 1, 1) 36 | # We must subtract 1 in the line below so that we match the train-time conditions of having a 37 | # False padding value for the token position. This is needed so that we correctly use the 38 | # _acoustic_ and not the linguistic language embedding for the token. 
39 | offsets = torch.tensor([first_codex_idx - 1 for _ in range(bs)], dtype=torch.long, device=xx.device) 40 | valid_logit_idx_start = len(texttok.vocab) # vocab['s2i']['quant0-0000'] 41 | valid_logit_idx_end = len(texttok.vocab) + len(speechtok.vocab) + 1 # vocab['s2i']['quant1-0000'] 42 | # Make mask that is True where we have valid outputs, False otherwise (where we have text outputs). 43 | # logit_mask = torch.zeros(n_vocab, dtype=bool, device=x_inp.device) 44 | # logit_mask[valid_logit_idx_start:valid_logit_idx_end] = True 45 | # logit_mask[vocab['s2i']['']] = True 46 | cum_logprobs = torch.zeros(bs, dtype=torch.float, device=x_inp.device) 47 | eos_idx = len(texttok.vocab) + speechtok.special_tokens['<|endofspeech|>'] 48 | n_vocab = len(texttok.vocab) + len(speechtok.vocab) 49 | 50 | logging.info(f"Starting beam decoding with beam_width={beam_width}") 51 | 52 | prev_ids = [[] for _ in range(bs)] 53 | 54 | cache = None 55 | if use_kv_cache: 56 | # Initialise kv cache 57 | cache_window = min(codeclm.ar.args.sliding_window, x_inp.shape[-1] + max_len) 58 | cache = RotatingBufferCache(codeclm.ar.args.n_layers, bs, cache_window, codeclm.ar.args.n_kv_heads, codeclm.ar.args.head_dim) 59 | cache.to(device=x_inp.device, dtype=torch.float16 if fp16 else torch.float32) 60 | 61 | counter = 0 62 | while x_inp.shape[-1] < max_len: 63 | counter += 1 64 | gen_length = torch.tensor([x_inp.shape[-1] for _ in range(bs)], dtype=torch.long, device=xx.device) 65 | padding_mask = length_to_mask(gen_length, offsets) 66 | 67 | with torch.autocast('cuda', enabled=fp16): 68 | logits: Tensor = codeclm(x_inp, padding_mask, spk_reference=ss_gen, cache=cache, counter=counter) 69 | logits = logits.float() 70 | 71 | logits = logits[:, -1] # select last index, now (bs, logit_dim) 72 | 73 | # <---------------------- logit filtering ----------------------> 74 | filtered_logits = logits.clone() 75 | 76 | # apply repetition penalty before logit mask if any item in the beam has more than 1 prior token. 77 | if len(prev_ids[0]) > 1: 78 | filtered_logits = freq_rep_penalty(filtered_logits, previous=torch.tensor(prev_ids, dtype=torch.long), 79 | alpha_frequency=alpha_frequency, alpha_presence=alpha_presence, 80 | penalty_window=penalty_window) 81 | 82 | filtered_logits[..., :valid_logit_idx_start-1] = float('-inf') 83 | filtered_logits[..., valid_logit_idx_end:] = float('-inf') 84 | 85 | if n_phones_gen is not None: 86 | # apply eos penalty 87 | filtered_logits = early_eos_penalty(filtered_logits, len(prev_ids[0]), n_phones_gen, 88 | eos_penalty_decay, eos_penalty_factor, 89 | eos_index=eos_idx) 90 | 91 | filtered_logits = filtered_logits / temperature 92 | filtered_logits = top_k_top_p_filtering(filtered_logits, top_k=topk, top_p=top_p) 93 | filtered_logits = apply_typical_p(filtered_logits, mass=typical_p) 94 | 95 | # mask out anything that isn't first quantizer output codes 96 | filtered_logits[..., :valid_logit_idx_start-1] = float('-inf') 97 | filtered_logits[..., valid_logit_idx_end:] = float('-inf') 98 | logits = filtered_logits 99 | 100 | # <---------------------- next frame prediction ---------------------> 101 | 102 | logprobs = logits.log_softmax(dim=-1) 103 | 104 | # update assignments: if any beam ended in last step, it MUST also end in this step. 
105 | # so, below we multiply the logits with a True/False mask, setting to 106 | for j in range(bs): 107 | if x_inp[j, -1] == eos_idx: 108 | # do not add any additional probability to it, keeping it the same for all vocab idxs 109 | logprobs[j] = float('-inf') # zero probability of anything non-eos after 1 eos 110 | logprobs[j, eos_idx] = 0 # probability=1 of after 111 | 112 | candidate_cum_logprobs = cum_logprobs[:, None] + logprobs # (bs, 1) + (bs, vocab) -> (bs, vocab) 113 | 114 | logp_flat = logprobs.flatten() 115 | candidates = torch.multinomial(logp_flat.exp(), num_samples=beam_width, replacement=False) # (bs,) 116 | # Ravel it up: 117 | beam_idxs = candidates // n_vocab # (bs,) 118 | tok_inds_in_each_beam = candidates % n_vocab # (bs,) 119 | 120 | # check for breaks 121 | if torch.all(tok_inds_in_each_beam == eos_idx): 122 | # apply length penalty: 123 | non_eos_toks = (x_inp != eos_idx).sum(dim=-1) # (bs,) number of non eos toks 124 | gen_length = non_eos_toks - first_codex_idx 125 | penalties = (gen_length**beam_length_penalty) 126 | penalized_cum_tok_logp = candidate_cum_logprobs / penalties[:, None] 127 | 128 | eos_avg_logps = penalized_cum_tok_logp[:, eos_idx] 129 | best_beam_idx = eos_avg_logps.argmax() 130 | best_avg_logp = eos_avg_logps[best_beam_idx] 131 | best_beam = x_inp[best_beam_idx] 132 | logging.info((f"best beam = {best_beam_idx} @ penalized_cum_tok_logp = {best_avg_logp.item():.3f} |\n num toks: {non_eos_toks.cpu().tolist()}. " 133 | f"Candidates: {eos_avg_logps.cpu()} |\n non-eos toks: {non_eos_toks.cpu().tolist()} |\n penalties: {penalties.cpu().tolist()} | " 134 | f"raw cumulative probs: {candidate_cum_logprobs[:, eos_idx].cpu().tolist()}")) 135 | break 136 | 137 | # update beam histories: 138 | x_inp = x_inp[beam_idxs] 139 | # update next token 140 | next_sample = tok_inds_in_each_beam 141 | # update cum logprob 142 | cum_logprobs = cum_logprobs[beam_idxs] + logprobs[beam_idxs, tok_inds_in_each_beam] 143 | # update prior inds to point to correct beam 144 | prev_ids = [copy.deepcopy(prev_ids[beam_idx.item()]) for beam_idx in beam_idxs] 145 | # add new tokens to previous ids 146 | for j in range(bs): 147 | prev_ids[j].append(tok_inds_in_each_beam[j].item()) 148 | 149 | logging.debug("L%d | next sample: %s | beam: %s | cum_logp: %s", len(x_inp[0]), next_sample.cpu().tolist(), beam_idxs.cpu().tolist(), cum_logprobs.cpu()) 150 | 151 | # update cache with beam indexes 152 | if cache is not None: 153 | cache.cache_k = cache.cache_k[:, beam_idxs] 154 | cache.cache_v = cache.cache_v[:, beam_idxs] 155 | 156 | # add 1 None below to make (bs,) -> (bs, 1) so we can concat along seq len dim. 157 | x_inp = torch.cat([x_inp, next_sample[:, None]], dim=-1) 158 | 159 | 160 | if x_inp.shape[-1] >= max_len - 1: 161 | logging.warning(f"[autoregressive generation] output length = {x_inp.shape[-1]} -- inference likely failed or input too long!") 162 | best_beam = x_inp[0] 163 | 164 | if not vocode: return best_beam # (seq_len,) 165 | else: raise AssertionError() 166 | -------------------------------------------------------------------------------- /mars5/diffuser.py: -------------------------------------------------------------------------------- 1 | """ 2 | Discrete multinomial diffusion code adapted from https://github.com/RF5/transfusion-asr, 3 | which in turn is adapted from https://github.com/ehoogeboom/multinomial_diffusion. 
4 | 5 | Please see the original repo (https://github.com/ehoogeboom/multinomial_diffusion) and paper for full 6 | details on how multinomial diffusion works -- thanks to the original authors! 7 | """ 8 | 9 | import torch 10 | from torch import Tensor 11 | from torch.functional import F 12 | import numpy as np 13 | from dataclasses import dataclass 14 | from typing import Union 15 | 16 | # -------------- Multinomial utility functions ----------- 17 | 18 | MIN_LOG_ARG = 1e-7 # originally was 1e-40 19 | 20 | def log_1_min_a(a): return torch.log((1 - a.exp()).clamp_(min=1e-30)) 21 | 22 | def log_add_exp(a, b): 23 | maximum = torch.max(a, b) 24 | return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum)) 25 | 26 | def extract(a: Tensor, t, x_shape): 27 | """ Given 1D vector of alpha/alpha_cum/betas, get index at `t` of shape (bs,), and then 28 | broadcast it to number of dims in `x_shape`. 29 | """ 30 | b, *_ = t.shape 31 | out = a.gather(-1, t) 32 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 33 | 34 | def index_to_log_onehot(x, num_classes, dim=-1, dtype=torch.float32): 35 | """ Convert indices `x` (bs, ...) to approx one-hot log-probs of shape (bs, ..., num_classes) """ 36 | assert x.max().item() < num_classes, \ 37 | f'Error: {x.max().item()} >= {num_classes}' 38 | x_onehot = F.one_hot(x, num_classes) 39 | if dim == 1: 40 | permute_order = (0, -1) + tuple(range(1, len(x.size()))) 41 | x_onehot = x_onehot.permute(permute_order) 42 | else: 43 | pass 44 | 45 | log_x = torch.log(x_onehot.to(dtype).clamp(min=MIN_LOG_ARG)) # so min(log_x) will be -30 46 | 47 | return log_x 48 | 49 | def sum_except_batch(x: Tensor, num_dims=1) -> Tensor: 50 | ''' 51 | Sums all dimensions except the first. 52 | Args: 53 | x: Tensor, shape (batch_size, ...) 54 | num_dims: int, number of batch dims (default=1) 55 | Returns: 56 | x_sum: Tensor, shape (batch_size,) 57 | ''' 58 | return x.reshape(*x.shape[:num_dims], -1).sum(-1) 59 | 60 | # -------------- Multinomial diffusion class ------------- 61 | 62 | class MultinomialDiffusion(): 63 | def __init__(self, num_classes, timesteps=100, diffusion_s=0.008, 64 | loss_type='vb_stochastic', parametrization='x0', 65 | dtype=torch.float32, 66 | device='cpu'): 67 | super(MultinomialDiffusion, self).__init__() 68 | assert loss_type in ('vb_stochastic',) 69 | assert parametrization in ('x0', 'direct') 70 | 71 | self.num_classes = num_classes 72 | self.loss_type = loss_type 73 | self.num_timesteps = timesteps 74 | self.parametrization = parametrization 75 | 76 | alphas = self.cosine_beta_schedule(timesteps, diffusion_s) 77 | 78 | alphas = alphas.to(torch.float64) 79 | log_alpha = alphas.log() 80 | log_cumprod_alpha = torch.cumsum(log_alpha, dim=-1) 81 | 82 | log_1_min_alpha = log_1_min_a(log_alpha) # = log(betas) 83 | 84 | log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha) # = log(1- \bar{a}) 85 | a = log_add_exp(log_alpha, log_1_min_alpha) # log(1-beta + beta) = log(1) = 0 86 | 87 | assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5 88 | assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5 89 | assert (torch.cumsum(log_alpha, dim=-1) - log_cumprod_alpha).abs().sum().item() < 1.e-5 90 | 91 | # Convert to float32 and register buffers. 
92 | self.log_alpha = log_alpha.to(dtype).to(device) 93 | self.log_1_min_alpha = log_1_min_alpha.to(dtype).to(device) 94 | self.log_cumprod_alpha = log_cumprod_alpha.to(dtype).to(device) 95 | self.log_1_min_cumprod_alpha = log_1_min_cumprod_alpha.to(dtype).to(device) 96 | 97 | @staticmethod 98 | def cosine_beta_schedule(timesteps, s=0.008) -> Tensor: 99 | """ 100 | cosine schedule as proposed in https://arxiv.org/abs/2102.09672 . 101 | Returns alpha parameters, NOT Beta 102 | """ 103 | steps = timesteps + 1 104 | x = torch.linspace(0, timesteps, steps) 105 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 106 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 107 | alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1]) 108 | alphas = torch.clamp(alphas, 0.001, 1.0) 109 | return torch.sqrt(alphas) 110 | 111 | def multinomial_kl(self, log_prob1: Tensor, log_prob2: Tensor, dim=-1) -> Tensor: 112 | """ Get KL divergence between two categorical distributions specified with `log_prob1` and `log_prob2`. 113 | Assumed probability dim is `dim` (i.e. log_prob1.exp().sum(dim=`dim`) should be tensor of ones) 114 | """ 115 | kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=dim) 116 | return kl 117 | 118 | def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor: 119 | """ Compute q(x_t | x_{t-1}) = C(x_t | alpha_t * x_{t-1} + (1-alpha_t)/K in the log-domain 120 | given `log_x_t` as log one-hot encoding of x_t. 121 | 122 | Recall due to symmetry property we can compute 123 | this value using x_t instead of x_{t-1} (se appendix A of https://arxiv.org/pdf/2102.05379.pdf) 124 | """ 125 | dt = log_x_t.dtype 126 | log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt) 127 | log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt) 128 | 129 | # alpha_t * E[xt] + (1 - alpha_t) 1 / K 130 | log_probs = log_add_exp( 131 | log_x_t + log_alpha_t, 132 | log_1_min_alpha_t - np.log(self.num_classes) 133 | ) 134 | return log_probs 135 | 136 | def q_pred_one_timestep_scaled(self, log_x_t: Tensor, t: Tensor, c: int, jump_len: int) -> Tensor: 137 | """ Compute q(x_t | x_{t-1}) = C(x_t | alpha_t * x_{t-1} + (1-alpha_t)/K in the log-domain 138 | given `log_x_t` as log one-hot encoding of x_t. 139 | 140 | Recall due to symmetry property we can compute 141 | this value using x_t instead of x_{t-1} (se appendix A of https://arxiv.org/pdf/2102.05379.pdf) 142 | """ 143 | dt = log_x_t.dtype 144 | log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt) 145 | log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt) 146 | 147 | # Magic 148 | xax = torch.arange(0,log_x_t.shape[1],1).to(log_x_t.device) 149 | aa=log_x_t.shape[1]*(c/jump_len) 150 | sig = 1/(1+torch.exp(-(xax-aa+20)/8)) 151 | log_alpha_t = (torch.log(1/sig)[None,:,None] + log_alpha_t).clamp(-torch.inf, 0) 152 | log_1_min_alpha_t = torch.log(sig)[None,:,None] + log_1_min_alpha_t 153 | 154 | # alpha_t * E[xt] + (1 - alpha_t) 1 / K 155 | log_probs = log_add_exp( 156 | log_x_t + log_alpha_t, 157 | log_1_min_alpha_t - np.log(self.num_classes) 158 | ) 159 | return log_probs 160 | 161 | def q_pred(self, log_x_start: Tensor, t) -> Tensor: 162 | """ Compute q(x_t | x_0) = C(x_t | bar{alpha}_t * x_0 + (1 - bar{alpha}_t)/K ) in log domain, 163 | given `log_x_start` of log probs of x_0. 
164 | """ 165 | dt = log_x_start.dtype 166 | log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape).to(dt) 167 | log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape).to(dt) 168 | 169 | log_probs = log_add_exp( 170 | log_x_start + log_cumprod_alpha_t, 171 | log_1_min_cumprod_alpha - np.log(self.num_classes) 172 | ) 173 | 174 | return log_probs 175 | 176 | def q_posterior(self, log_x_start, log_x_t, t): 177 | """ Compute `q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)` 178 | where q(xt | xt-1, x0) = q(xt | xt-1). 179 | """ 180 | # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0) 181 | # where q(xt | xt-1, x0) = q(xt | xt-1). 182 | 183 | t_minus_1 = t - 1 184 | # Remove negative values, will not be used anyway for final decoder 185 | t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1) 186 | log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1) # log( q(x_{t-1} | x_0) ) 187 | # if t == 0, then log( q(x_0 | x_0) ) = log( one_hot(x_0) ), not even random at that point. 188 | # so, where t == 0 189 | num_axes = (1,) * (len(log_x_start.size()) - 1) 190 | t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start) # broadcast to non-batch axes 191 | log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0) 192 | # where it is zero, replace 193 | # with log one-hot encoding of x0. 194 | 195 | # Note: _NOT_ x_tmin1, which is how the formula is typically used!!! 196 | # Not very easy to see why this is true. But it is :) 197 | # log_EV_qxtmin_x0 ~ q(x_{t-1} | x_0) 198 | # q_pred_one_timestep(log_x_t, t) ~ q(x_t | x_{t-1}) (which due to symmetry can be computed using x_t) 199 | unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t) # numerator of bayes 200 | 201 | # approximate denominator with just a normalizing sum. 202 | log_EV_xtmin_given_xt_given_xstart = \ 203 | unnormed_logprobs \ 204 | - torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True) 205 | 206 | return log_EV_xtmin_given_xt_given_xstart 207 | 208 | def p_pred(self, log_x_t, t, log_x0_pred): 209 | """ Predict `p(x_{t-1} | x_t)` using `q(xt-1 | xt, hat{x0})`, where `hat{x0}` is given by 210 | log probabilities from model as `log_x0_pred` (bs, ...., K) and x_t is given by 211 | `log_x_t` of shape `(bs, ..., K)` 212 | """ 213 | # log_x_recon = self.predict_start(log_x, t=t) # model itself predicts x_0 214 | # log_x0_pred 215 | log_model_pred = self.q_posterior( 216 | log_x_start=log_x0_pred, log_x_t=log_x_t, t=t) 217 | return log_model_pred 218 | 219 | def log_sample_categorical(self, logprobs: Tensor, dim=-1) -> Tensor: 220 | """ Sample from categorical `logprobs` (bs, ..., probs), where position of probs is specified 221 | by `dim`. 222 | 223 | Returns sampled long indices of shape `(bs, ...)` 224 | """ 225 | uniform = torch.rand_like(logprobs) 226 | gumbel_noise = -torch.log( (-torch.log(uniform.clamp_(min=MIN_LOG_ARG)) ).clamp_(min=MIN_LOG_ARG)) 227 | sample = (gumbel_noise + logprobs).argmax(dim=dim) 228 | return sample 229 | 230 | def q_sample(self, log_x_start, t): 231 | """ Draw `x_t` ~ q(x_t | x_0) . 
`log_x_start` is of shape `(bs, ..., K)`, returns result of same shape """ 232 | log_EV_qxt_x0 = self.q_pred(log_x_start, t) 233 | sample = self.log_sample_categorical(log_EV_qxt_x0) 234 | # log_sample = index_to_log_onehot(sample, self.num_classes) 235 | 236 | return sample #log_sample 237 | 238 | def compute_Lt(self, log_x_start: Tensor, log_x_t: Tensor, log_x0_pred: Tensor, t, 239 | detach_mean=False, include_kl_prior=True): 240 | """ Get loss given one-hot log x_0, one-hot log x_t, t, and model prediction `log_x0_pred`. 241 | Parameters: 242 | - `log_x_start`: ground-truth input x0, converted to log one-hot (bs, ..., K) 243 | - `log_x_t`: sampled noisy input at `x_t`, converted to log one-hot (bs, ..., K) 244 | - `t`: diffusion timestep (bs,) 245 | - `log_x0_pred`: model prediction of log probabilities of x0, i.e. hat{x0}. 246 | - `include_kl_prior`: add last two terms to model loss (does not change optimization problem). 247 | """ 248 | dtype = log_x_start.dtype 249 | log_true_prob = self.q_posterior( 250 | log_x_start=log_x_start, log_x_t=log_x_t, t=t) 251 | 252 | log_model_prob = self.p_pred(log_x_t=log_x_t, t=t, log_x0_pred=log_x0_pred) 253 | 254 | if detach_mean: 255 | log_model_prob = log_model_prob.detach() 256 | 257 | kl = self.multinomial_kl(log_true_prob, log_model_prob) 258 | kl = sum_except_batch(kl) 259 | 260 | # Add L_0, -log(p(x_0 | x_1)) 261 | decoder_nll = - (log_x_start.exp() * log_model_prob).sum(dim=-1) 262 | decoder_nll = sum_except_batch(decoder_nll) 263 | 264 | mask = (t == torch.zeros_like(t)).to(dtype) 265 | loss = mask * decoder_nll + (1. - mask) * kl # only add L0 if t == 0. 266 | 267 | if include_kl_prior: 268 | pt = torch.ones_like(t, dtype=dtype) 269 | kl_prior = self.kl_prior(log_x_start) 270 | loss = (kl) + kl_prior 271 | 272 | return loss 273 | 274 | def kl_prior(self, log_x_start: Tensor) -> Tensor: 275 | """ This function computes -H_{q}(x_T | x_0)+H_{p}(x_T), which 276 | by some math (see wiki for KL div relation to conditional entropy). 277 | So KL(q(x_T | x_0) || 1/K) = -H_{q}(x_T | x_0)+H_{p}(x_T) for categorical distribution. 278 | 279 | Given `log_x_start` (bs, ..., probs), return KL prior of shape (bs,) 280 | """ 281 | b = log_x_start.size(0) 282 | device = log_x_start.device 283 | ones = torch.ones(b, device=device, dtype=torch.long) 284 | 285 | log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones) # q(x_T | x_0) 286 | log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob)) # log(1/K), broadcast to q(x_T|x_0) shape 287 | 288 | kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob) 289 | return sum_except_batch(kl_prior) 290 | 291 | 292 | def index2logit(x: Tensor, vocab_size: int, dtype=torch.float32): 293 | x = F.one_hot(x, num_classes=vocab_size).to(dtype) 294 | x = x * (vocab_size/(vocab_size - 1)) - 1/(vocab_size - 1) 295 | return x 296 | 297 | 298 | # ------------------------------ 299 | # Functions adapted from the full 300 | 301 | 302 | @dataclass 303 | class DSH(): 304 | # Diffusion Sampling Hyperparameters [DSH] (Section 4) 305 | jump_len: int = 1 # j in RePaint paper [default 10] (Section 4.1) 306 | jump_n_sample: int = 1 # r in RePaint paper [default 10] (Section 4.1) 307 | last_greedy: bool = False # whether to not sample at t=0, but take argmax prediction. 
[default False] 308 | x_0_temp: float = 1.0 # reweight temp for model prediction of x0 309 | guidance_w: float = 1.0 # classifier free guidance weight [default 1.5] (Section 4.3) 310 | enable_kevin_scaled_inference: bool = True # sequentially progressive diffusion [default True] (Section 4.2) 311 | T_override: Union[None, int] = None # allow variable transcription sizes during inference (Section 4.4) 312 | 313 | deep_clone: bool = False # whether to do deep clone. 314 | q0_override_steps: int = 0 # number of steps that we allow overriding the input quant level 0 inputs. 315 | progress: bool = False # whether to show progress bar 316 | 317 | 318 | def get_schedule(t_T, jump_len=10, jump_n_sample=10): 319 | jumps = {} 320 | for j in range(0, t_T - jump_len, jump_len): 321 | jumps[j] = jump_n_sample - 1 322 | t = t_T 323 | ts = [] 324 | while t >= 1: 325 | t = t-1 326 | ts.append(t) 327 | if jumps.get(t, 0) > 0: 328 | jumps[t] = jumps[t] - 1 329 | for _ in range(jump_len): 330 | t = t + 1 331 | ts.append(t) 332 | ts.append(-1) 333 | return ts 334 | 335 | 336 | def forward_diffusion(diff: MultinomialDiffusion, dtype, x, t, c=None, dsh=DSH): 337 | """Simple forward diffusion process p""" 338 | log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=dtype) 339 | if c is not None: x = diff.q_pred_one_timestep_scaled(log_x_t, t, c, dsh.jump_len) 340 | else: x = diff.q_pred_one_timestep(log_x_t, t) 341 | x = diff.log_sample_categorical(x) 342 | return x 343 | 344 | 345 | def reverse_diffusion(diff: MultinomialDiffusion, model, batch, x_known=None, m=None, 346 | last_greedy=False, temperature=1.0, alphas=None, ensemble_size=1, dsh=DSH): 347 | """Reverse diffusion process q: predict x_{t-1} given x, t, x_known, m. Optionally do not sample model output 348 | for t=0, but rather use the greedy argmax with `last_greedy`. 
349 | """ 350 | x = batch[4] 351 | t = batch[-1] 352 | if x_known is None: x_known = torch.zeros_like(x) 353 | if m is None: m = torch.zeros_like(x) 354 | 355 | # Equation 8b 356 | # for b in batch: 357 | # print(f"{b.shape}: {b}") 358 | x_0_pred = model(*batch) # (bs, seq_len, logit_dim, n_quant) 359 | x_0_pred = x_0_pred.permute(0, 1, 3, 2) # (bs, seq_len, n_quant, dim) 360 | 361 | if dsh.guidance_w != 1: 362 | uncond_x_0_pred = model(*(c.clone() if c is not None else None for c in batch), drop_cond=True) 363 | uncond_x_0_pred = uncond_x_0_pred.permute(0, 1, 3, 2) 364 | x_0_pred = dsh.guidance_w*x_0_pred + (1-dsh.guidance_w)*uncond_x_0_pred 365 | 366 | x_0_pred = x_0_pred / temperature 367 | log_x_0_pred = F.log_softmax(x_0_pred, dim=-1) 368 | log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=x_0_pred.dtype) 369 | 370 | # print("PRE: ", log_x_t.shape, t.shape, log_x_0_pred.shape) 371 | log_model_pred = diff.p_pred(log_x_t, t, log_x_0_pred) # p(x_{t-1} | x_{t}) 372 | 373 | a_t = alphas[t[0]] if alphas is not None else 0 374 | mat = torch.eye(ensemble_size, device=x.device)*(1-a_t) 375 | mat += 1/ensemble_size * a_t 376 | mat = torch.block_diag(*([mat]*(x.shape[0]//ensemble_size))) 377 | log_model_pred = ( (mat[..., None, None] ).log().to(x.dtype) + log_model_pred[None]) 378 | log_model_pred = torch.logsumexp(log_model_pred, dim=1) 379 | 380 | if (t==0).all() and last_greedy: # Do not sample at t=0 381 | x_tm1_unknown = log_model_pred.argmax(dim=-1) 382 | else: 383 | x_tm1_unknown = diff.log_sample_categorical(log_model_pred) 384 | 385 | # Equation 8a 386 | x_known_log = index_to_log_onehot(x_known, diff.num_classes, dtype=x_0_pred.dtype) 387 | if (t==0).all(): # Do not sample at t=0 388 | x_tm1_known = x_known 389 | else: 390 | x_tm1_known = diff.q_sample(x_known_log, t) 391 | 392 | # Equation 8c 393 | x_tm1 = x_tm1_known * m.long() + x_tm1_unknown * (1 - m.long()) 394 | return x_tm1, x_0_pred 395 | 396 | 397 | 398 | @torch.inference_mode() 399 | def perform_simple_inference(model: torch.nn.Module, batch: tuple, diff: MultinomialDiffusion, T, dtype=torch.float16, 400 | retain_quant0: bool = True, dsh=DSH): 401 | """ If `retain_quant0`, then do not sample quant0 in each forward or reverse diffusion step. """ 402 | 403 | # (bs=1, N), (bs, seq_len2, 8), (bs,) 404 | c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask = batch 405 | 406 | device = c_text.device 407 | bs = c_text.shape[0] 408 | x_quant0 = x[..., 0].clone() # (bs, seq_len) 0th quant level 409 | x = torch.randint(0, diff.num_classes, x.shape, dtype=x.dtype, device=device) 410 | # CRITICAL LINE: override quantization level 0 with provided quant0 level. 411 | x[..., 0] = x_quant0 412 | 413 | # RePaint paper resample scheduling 414 | times = get_schedule(T, jump_n_sample=dsh.jump_n_sample, jump_len=dsh.jump_len) 415 | 416 | x_known = torch.zeros_like(x) 417 | x_known[..., 0] = x[..., 0] # override L0 codes 418 | m = torch.zeros_like(x).bool() 419 | # (bs, seq_len, 8) 420 | m[..., 0] = True 421 | 422 | offset = 0 423 | if dsh.deep_clone: 424 | print(f"Note: using deep clone. 
Assuming input `c_phones` is concatenated prompt and output phones.", 425 | "Also assuming no padded indices in `c_codes`.") 426 | prompt = c_codes 427 | x = torch.cat((prompt, x), dim=1) # (bs=1, sl1 + sl2, 8) 428 | x_known = torch.cat((prompt, x_known), dim=1) 429 | x_padding_mask = torch.cat(( 430 | torch.zeros(x_padding_mask.shape[0], c_codes_lengths[0], dtype=torch.bool, device=x_padding_mask.device), 431 | x_padding_mask), dim=-1 432 | ) 433 | # (bs=1, :up to prompt duration, all 8 codebooks) = True/masked. 434 | m = torch.cat((torch.ones_like(prompt), m), dim=1) 435 | x_quant0 = torch.cat((prompt[..., 0], x_quant0), dim=-1) 436 | offset = c_codes_lengths[0] 437 | 438 | print(f"New x: {x.shape} | new x_known: {x_known.shape} . Base prompt: {prompt.shape}. New padding mask: {x_padding_mask.shape} | m shape: {m.shape}") 439 | 440 | c = 0 # sequentially progressive diffusion offset (Section 4.2) 441 | 442 | # ensemble bs (not in paper) 443 | alphas = torch.linspace(1, 0, T).to(device) 444 | 445 | pb = zip(times[:-1], times[1:]) 446 | if dsh.progress: 447 | from fastprogress import progress_bar 448 | pb = progress_bar(pb, total=len(times)-1) 449 | 450 | # See RePaint paper algorithm 451 | for t_last, t_cur in pb: 452 | 453 | t = torch.ones((bs,), dtype=torch.long, device=x.device) * (t_last) 454 | if t_cur < t_last: 455 | if c > dsh.jump_n_sample: 456 | c = 0 457 | c += 1/dsh.jump_len 458 | 459 | # Reverse diffusion: q 460 | cbatch = (c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask, t) 461 | x, x_0_pred = reverse_diffusion(diff, model, cbatch, x_known, m, temperature=dsh.x_0_temp, alphas=alphas, ensemble_size=1, dsh=dsh) 462 | else: 463 | # Forward diffusion: p 464 | if dsh.enable_kevin_scaled_inference: x = forward_diffusion(diff, dtype, x, t, c=c, dsh=dsh) 465 | else: x = forward_diffusion(diff, dtype, x, t, c=None, dsh=dsh) 466 | 467 | if retain_quant0 and dsh.q0_override_steps < t_last: 468 | x[..., 0] = x_quant0 469 | 470 | # crop offset: 471 | x = x[:, offset:] 472 | return x 473 | -------------------------------------------------------------------------------- /mars5/minbpe/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains the base Tokenizer class and a few common helper functions. 3 | The base class also contains the (common) save/load functionality. 4 | It would be possible to be a lot more strict about the interface and 5 | e.g. isolating all regex/pattern parts to the RegexTokenizer, but 6 | some concessions are made for simplicity. 
7 | """ 8 | import unicodedata 9 | import io 10 | 11 | # ----------------------------------------------------------------------------- 12 | # a few helper functions useful for both BasicTokenizer and RegexTokenizer 13 | 14 | def get_stats(ids, counts=None): 15 | """ 16 | Given a list of integers, return a dictionary of counts of consecutive pairs 17 | Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1} 18 | Optionally allows to update an existing dictionary of counts 19 | """ 20 | counts = {} if counts is None else counts 21 | for pair in zip(ids, ids[1:]): # iterate consecutive elements 22 | counts[pair] = counts.get(pair, 0) + 1 23 | return counts 24 | 25 | 26 | def merge(ids, pair, idx): 27 | """ 28 | In the list of integers (ids), replace all consecutive occurrences 29 | of pair with the new integer token idx 30 | Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4] 31 | """ 32 | newids = [] 33 | i = 0 34 | while i < len(ids): 35 | # if not at the very last position AND the pair matches, replace it 36 | if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]: 37 | newids.append(idx) 38 | i += 2 39 | else: 40 | newids.append(ids[i]) 41 | i += 1 42 | return newids 43 | 44 | # first two helper functions... 45 | def replace_control_characters(s: str) -> str: 46 | # we don't want to print control characters 47 | # which distort the output (e.g. \n or much worse) 48 | # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117 49 | # http://www.unicode.org/reports/tr44/#GC_Values_Table 50 | chars = [] 51 | for ch in s: 52 | if unicodedata.category(ch)[0] != "C": 53 | chars.append(ch) # this character is ok 54 | else: 55 | chars.append(f"\\u{ord(ch):04x}") # escape 56 | return "".join(chars) 57 | 58 | def render_token(t: bytes) -> str: 59 | # pretty print a token, escaping control characters 60 | s = t.decode('utf-8', errors='replace') 61 | s = replace_control_characters(s) 62 | return s 63 | 64 | # ----------------------------------------------------------------------------- 65 | # the base Tokenizer class 66 | 67 | class Tokenizer: 68 | """Base class for Tokenizers""" 69 | 70 | def __init__(self): 71 | # default: vocab size of 256 (all bytes), no merges, no patterns 72 | self.merges = {} # (int, int) -> int 73 | self.pattern = "" # str 74 | self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257} 75 | self.vocab = self._build_vocab() # int -> bytes 76 | 77 | def train(self, text, vocab_size, verbose=False): 78 | # Tokenizer can train a vocabulary of size vocab_size from text 79 | raise NotImplementedError 80 | 81 | def encode(self, text): 82 | # Tokenizer can encode a string into a list of integers 83 | raise NotImplementedError 84 | 85 | def decode(self, ids): 86 | # Tokenizer can decode a list of integers into a string 87 | raise NotImplementedError 88 | 89 | def _build_vocab(self): 90 | # vocab is simply and deterministically derived from merges 91 | vocab = {idx: bytes([idx]) for idx in range(256)} 92 | for (p0, p1), idx in self.merges.items(): 93 | vocab[idx] = vocab[p0] + vocab[p1] 94 | for special, idx in self.special_tokens.items(): 95 | vocab[idx] = special.encode("utf-8") 96 | return vocab 97 | 98 | def save(self, file_prefix): 99 | """ 100 | Saves two files: file_prefix.vocab and file_prefix.model 101 | This is inspired (but not equivalent to!) 
sentencepiece's model saving: 102 | - model file is the critical one, intended for load() 103 | - vocab file is just a pretty printed version for human inspection only 104 | """ 105 | # write the model: to be used in load() later 106 | model_file = file_prefix + ".model" 107 | with open(model_file, 'w') as f: 108 | # write the version, pattern and merges, that's all that's needed 109 | f.write("minbpe v1\n") 110 | f.write(f"{self.pattern}\n") 111 | # write the special tokens, first the number of them, then each one 112 | f.write(f"{len(self.special_tokens)}\n") 113 | for special, idx in self.special_tokens.items(): 114 | f.write(f"{special} {idx}\n") 115 | # the merges dict 116 | for idx1, idx2 in self.merges: 117 | f.write(f"{idx1} {idx2}\n") 118 | # write the vocab: for the human to look at 119 | vocab_file = file_prefix + ".vocab" 120 | inverted_merges = {idx: pair for pair, idx in self.merges.items()} 121 | with open(vocab_file, "w", encoding="utf-8") as f: 122 | for idx, token in self.vocab.items(): 123 | # note: many tokens may be partial utf-8 sequences 124 | # and cannot be decoded into valid strings. Here we're using 125 | # errors='replace' to replace them with the replacement char �. 126 | # this also means that we couldn't possibly use .vocab in load() 127 | # because decoding in this way is a lossy operation! 128 | s = render_token(token) 129 | # find the children of this token, if any 130 | if idx in inverted_merges: 131 | # if this token has children, render it nicely as a merge 132 | idx0, idx1 = inverted_merges[idx] 133 | s0 = render_token(self.vocab[idx0]) 134 | s1 = render_token(self.vocab[idx1]) 135 | f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n") 136 | else: 137 | # otherwise this is leaf token, just print it 138 | # (this should just be the first 256 tokens, the bytes) 139 | f.write(f"[{s}] {idx}\n") 140 | 141 | def load(self, model_file): 142 | """Inverse of save() but only for the model file""" 143 | if type(model_file) == io.BytesIO: open_fn = io.TextIOWrapper 144 | else: 145 | open_fn = open 146 | model_file = str(model_file) 147 | assert model_file.endswith(".model") 148 | # read the model file 149 | merges = {} 150 | special_tokens = {} 151 | idx = 256 152 | with open_fn(model_file, encoding="utf-8") as f: 153 | # read the version 154 | version = f.readline().strip() 155 | assert version == "minbpe v1" 156 | # read the pattern 157 | self.pattern = f.readline().strip() 158 | # read the special tokens 159 | num_special = int(f.readline().strip()) 160 | for _ in range(num_special): 161 | special, special_idx = f.readline().strip().split() 162 | special_tokens[special] = int(special_idx) 163 | # read the merges 164 | for line in f: 165 | idx1, idx2 = map(int, line.split()) 166 | merges[(idx1, idx2)] = idx 167 | idx += 1 168 | self.merges = merges 169 | self.special_tokens = special_tokens 170 | self.vocab = self._build_vocab() 171 | -------------------------------------------------------------------------------- /mars5/minbpe/codebook.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal (byte-level) Byte Pair Encoding tokenizer. 3 | 4 | Unlike RegexTokenizer: 5 | - Operates on integer codes from an encodec codebook. 
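- Expects its "text" input as a string of space-separated integer codes
  (e.g. the illustrative sequence "12 873 12 873"), with the base vocabulary
  being the codebook indices themselves (codebook_size entries, 1024 by default).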
6 | """ 7 | 8 | import regex as re 9 | from .base import Tokenizer, get_stats, merge 10 | import io 11 | 12 | 13 | class CodebookTokenizer(Tokenizer): 14 | 15 | def __init__(self, pattern=None, codebook_size=1024): 16 | """ 17 | - pattern: optional string to override the default (GPT-4 split pattern) 18 | - special_tokens: str -> int dictionary of special tokens 19 | example: {'<|endoftext|>': 100257} 20 | """ 21 | self.merges = {} # (int, int) -> int 22 | self.pattern = pattern 23 | self.compiled_pattern = re.compile(self.pattern) 24 | self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257} 25 | self.inverse_special_tokens = {} 26 | self.codebook_size = codebook_size 27 | self.vocab = self._build_vocab() # int -> bytes 28 | 29 | def train(self, text, vocab_size, verbose=False): 30 | assert vocab_size >= self.codebook_size 31 | num_merges = vocab_size - self.codebook_size 32 | 33 | # split the text up into text chunks 34 | # text is a continuous signal, there is no splitting it up. 35 | text_chunks = [text,] # re.findall(self.compiled_pattern, text) 36 | 37 | # input text preprocessing 38 | ids = [[int(idx) for idx in ch.split(' ')] for ch in text_chunks] 39 | 40 | # iteratively merge the most common pairs to create new tokens 41 | merges = {} # (int, int) -> int 42 | # vocab = {idx: bytes([idx]) for idx in range(self.codebook_size)} # idx -> bytes 43 | vocab = {idx: f" {idx:04d}".encode('utf-8') for idx in range(self.codebook_size)} # idx -> bytes 44 | 45 | for i in range(num_merges): 46 | # count the number of times every consecutive pair appears 47 | stats = {} 48 | for chunk_ids in ids: 49 | # passing in stats will update it in place, adding up counts 50 | get_stats(chunk_ids, stats) 51 | # find the pair with the highest count 52 | pair = max(stats, key=stats.get) 53 | # mint a new token: assign it the next available id 54 | idx = self.codebook_size + i 55 | # replace all occurrences of pair in ids with idx 56 | ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids] 57 | # save the merge 58 | merges[pair] = idx 59 | vocab[idx] = vocab[pair[0]] + vocab[pair[1]] 60 | # prints 61 | if verbose: 62 | print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") 63 | 64 | # save class variables 65 | self.merges = merges # used in encode() 66 | self.vocab = vocab # used in decode() 67 | 68 | def register_special_tokens(self, special_tokens): 69 | # special_tokens is a dictionary of str -> int 70 | # example: {"<|endoftext|>": 100257} 71 | self.special_tokens = special_tokens 72 | self.inverse_special_tokens = {v: k for k, v in special_tokens.items()} 73 | 74 | def decode(self, ids): 75 | # given ids (list of integers), return Python string 76 | part_bytes = [] 77 | for idx in ids: 78 | if idx in self.vocab: 79 | part_bytes.append(self.vocab[idx]) 80 | elif idx in self.inverse_special_tokens: 81 | part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8")) 82 | else: 83 | raise ValueError(f"invalid token id: {idx}") 84 | text_bytes = b"".join(part_bytes) 85 | text = text_bytes.decode("utf-8", errors="replace") 86 | return text 87 | 88 | def decode_int(self, ids) -> list[int]: 89 | ret: str = self.decode(ids) 90 | for s in self.special_tokens: 91 | ret = ret.replace(s, ' ' + s + ' ') 92 | ret = ret.strip() 93 | ret = [int(t) if t[0].isnumeric() else t for t in ret.split(' ') if len(t) > 0] 94 | return ret 95 | 96 | def _encode_chunk(self, text_bytes): 97 | # return the token ids 98 | # let's begin. 
first, convert all bytes to integers in range 0..255 99 | ids = list(text_bytes) 100 | while len(ids) >= 2: 101 | # find the pair with the lowest merge index 102 | stats = get_stats(ids) 103 | pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) 104 | # subtle: if there are no more merges available, the key will 105 | # result in an inf for every single pair, and the min will be 106 | # just the first pair in the list, arbitrarily 107 | # we can detect this terminating case by a membership check 108 | if pair not in self.merges: 109 | break # nothing else can be merged anymore 110 | # otherwise let's merge the best pair (lowest merge index) 111 | idx = self.merges[pair] 112 | ids = merge(ids, pair, idx) 113 | return ids 114 | 115 | def encode_ordinary(self, text): 116 | """Encoding that ignores any special tokens.""" 117 | # split text into chunks of text by categories defined in regex pattern 118 | text_chunks = [text,] #re.findall(self.compiled_pattern, text) 119 | # all chunks of text are encoded separately, then results are joined 120 | ids = [] 121 | for chunk in text_chunks: 122 | # chunk_bytes = chunk.encode("utf-8") # raw bytes 123 | chunk_ids = [int(idx) for idx in chunk.split(' ')] 124 | chunk_ids = self._encode_chunk(chunk_ids) 125 | ids.extend(chunk_ids) 126 | return ids 127 | 128 | def encode(self, text, allowed_special="none_raise"): 129 | """ 130 | Unlike encode_ordinary, this function handles special tokens. 131 | allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens 132 | if none_raise, then an error is raised if any special token is encountered in text 133 | this is the default tiktoken behavior right now as well 134 | any other behavior is either annoying, or a major footgun 135 | """ 136 | # decode the user desire w.r.t. handling of special tokens 137 | special = None 138 | if allowed_special == "all": 139 | special = self.special_tokens 140 | elif allowed_special == "none": 141 | special = {} 142 | elif allowed_special == "none_raise": 143 | special = {} 144 | assert all(token not in text for token in self.special_tokens) 145 | elif isinstance(allowed_special, set): 146 | special = {k: v for k, v in self.special_tokens.items() if k in allowed_special} 147 | else: 148 | raise ValueError(f"allowed_special={allowed_special} not understood") 149 | if not special: 150 | # shortcut: if no special tokens, just use the ordinary encoding 151 | return self.encode_ordinary(text) 152 | # otherwise, we have to be careful with potential special tokens in text 153 | # we handle special tokens by splitting the text 154 | # based on the occurrence of any exact match with any of the special tokens 155 | # we can use re.split for this. 
note that surrounding the pattern with () 156 | # makes it into a capturing group, so the special tokens will be included 157 | special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")" 158 | special_chunks = re.split(special_pattern, text) 159 | # now all the special characters are separated from the rest of the text 160 | # all chunks of text are encoded separately, then results are joined 161 | ids = [] 162 | for part in special_chunks: 163 | part = part.strip() 164 | if len(part) == 0: continue 165 | if part in special: 166 | # this is a special token, encode it separately as a special case 167 | ids.append(special[part]) 168 | else: 169 | # this is an ordinary sequence, encode it normally 170 | ids.extend(self.encode_ordinary(part)) 171 | return ids 172 | 173 | 174 | def load(self, model_file): 175 | """Inverse of save() but only for the model file""" 176 | 177 | if type(model_file) == io.BytesIO: 178 | open_fn = io.TextIOWrapper 179 | else: 180 | open_fn = open 181 | model_file = str(model_file) 182 | assert model_file.endswith(".model") 183 | # read the model file 184 | merges = {} 185 | special_tokens = {} 186 | idx = self.codebook_size 187 | with open_fn(model_file, encoding="utf-8") as f: 188 | # read the version 189 | version = f.readline().strip() 190 | assert version == "minbpe v1" 191 | # read the pattern 192 | self.pattern = f.readline().strip() 193 | # read the special tokens 194 | num_special = int(f.readline().strip()) 195 | for _ in range(num_special): 196 | special, special_idx = f.readline().strip().split() 197 | special_tokens[special] = int(special_idx) 198 | # read the merges 199 | for line in f: 200 | # print(line) 201 | idx1, idx2 = map(int, line.split()) 202 | merges[(idx1, idx2)] = idx 203 | idx += 1 204 | self.merges = merges 205 | self.special_tokens = special_tokens 206 | self.vocab = self._build_vocab() 207 | 208 | 209 | def _build_vocab(self): 210 | # vocab is simply and deterministically derived from merges 211 | vocab = {idx: f" {idx:04d}".encode('utf-8') for idx in range(self.codebook_size)} 212 | for (p0, p1), idx in self.merges.items(): 213 | vocab[idx] = vocab[p0] + vocab[p1] 214 | for special, idx in self.special_tokens.items(): 215 | vocab[idx] = special.encode("utf-8") 216 | return vocab -------------------------------------------------------------------------------- /mars5/minbpe/regex.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal (byte-level) Byte Pair Encoding tokenizer. 3 | 4 | Algorithmically follows along the GPT tokenizer: 5 | https://github.com/openai/gpt-2/blob/master/src/encoder.py 6 | 7 | Unlike BasicTokenizer: 8 | - RegexTokenizer handles an optional regex splitting pattern. 9 | - RegexTokenizer handles optional special tokens. 
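A rough usage sketch (illustrative values):

    tok = RegexTokenizer()                               # GPT-4 split pattern by default
    tok.train(text, vocab_size=512)
    tok.register_special_tokens({"<|endoftext|>": 512})
    ids = tok.encode("hello world <|endoftext|>", allowed_special="all")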
10 | """ 11 | 12 | import regex as re 13 | from .base import Tokenizer, get_stats, merge 14 | 15 | 16 | # the main GPT text split patterns, see 17 | # https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py 18 | GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" 19 | GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" 20 | 21 | 22 | class RegexTokenizer(Tokenizer): 23 | 24 | def __init__(self, pattern=None): 25 | """ 26 | - pattern: optional string to override the default (GPT-4 split pattern) 27 | - special_tokens: str -> int dictionary of special tokens 28 | example: {'<|endoftext|>': 100257} 29 | """ 30 | super().__init__() 31 | self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern 32 | self.compiled_pattern = re.compile(self.pattern) 33 | self.special_tokens = {} 34 | self.inverse_special_tokens = {} 35 | 36 | def train(self, text, vocab_size, verbose=False): 37 | assert vocab_size >= 256 38 | num_merges = vocab_size - 256 39 | 40 | # split the text up into text chunks 41 | text_chunks = re.findall(self.compiled_pattern, text) 42 | 43 | # input text preprocessing 44 | ids = [list(ch.encode("utf-8")) for ch in text_chunks] 45 | 46 | # iteratively merge the most common pairs to create new tokens 47 | merges = {} # (int, int) -> int 48 | vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes 49 | for i in range(num_merges): 50 | # count the number of times every consecutive pair appears 51 | stats = {} 52 | for chunk_ids in ids: 53 | # passing in stats will update it in place, adding up counts 54 | get_stats(chunk_ids, stats) 55 | # find the pair with the highest count 56 | pair = max(stats, key=stats.get) 57 | # mint a new token: assign it the next available id 58 | idx = 256 + i 59 | # replace all occurrences of pair in ids with idx 60 | ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids] 61 | # save the merge 62 | merges[pair] = idx 63 | vocab[idx] = vocab[pair[0]] + vocab[pair[1]] 64 | # prints 65 | if verbose: 66 | print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") 67 | 68 | # save class variables 69 | self.merges = merges # used in encode() 70 | self.vocab = vocab # used in decode() 71 | 72 | def register_special_tokens(self, special_tokens): 73 | # special_tokens is a dictionary of str -> int 74 | # example: {"<|endoftext|>": 100257} 75 | self.special_tokens = special_tokens 76 | self.inverse_special_tokens = {v: k for k, v in special_tokens.items()} 77 | 78 | def decode(self, ids): 79 | # given ids (list of integers), return Python string 80 | part_bytes = [] 81 | for idx in ids: 82 | if idx in self.vocab: 83 | part_bytes.append(self.vocab[idx]) 84 | elif idx in self.inverse_special_tokens: 85 | part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8")) 86 | else: 87 | raise ValueError(f"invalid token id: {idx}") 88 | text_bytes = b"".join(part_bytes) 89 | text = text_bytes.decode("utf-8", errors="replace") 90 | return text 91 | 92 | def _encode_chunk(self, text_bytes): 93 | # return the token ids 94 | # let's begin. 
first, convert all bytes to integers in range 0..255 95 | ids = list(text_bytes) 96 | while len(ids) >= 2: 97 | # find the pair with the lowest merge index 98 | stats = get_stats(ids) 99 | pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) 100 | # subtle: if there are no more merges available, the key will 101 | # result in an inf for every single pair, and the min will be 102 | # just the first pair in the list, arbitrarily 103 | # we can detect this terminating case by a membership check 104 | if pair not in self.merges: 105 | break # nothing else can be merged anymore 106 | # otherwise let's merge the best pair (lowest merge index) 107 | idx = self.merges[pair] 108 | ids = merge(ids, pair, idx) 109 | return ids 110 | 111 | def encode_ordinary(self, text): 112 | """Encoding that ignores any special tokens.""" 113 | # split text into chunks of text by categories defined in regex pattern 114 | text_chunks = re.findall(self.compiled_pattern, text) 115 | # all chunks of text are encoded separately, then results are joined 116 | ids = [] 117 | for chunk in text_chunks: 118 | chunk_bytes = chunk.encode("utf-8") # raw bytes 119 | chunk_ids = self._encode_chunk(chunk_bytes) 120 | ids.extend(chunk_ids) 121 | return ids 122 | 123 | def encode(self, text, allowed_special="none_raise"): 124 | """ 125 | Unlike encode_ordinary, this function handles special tokens. 126 | allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens 127 | if none_raise, then an error is raised if any special token is encountered in text 128 | this is the default tiktoken behavior right now as well 129 | any other behavior is either annoying, or a major footgun 130 | """ 131 | # decode the user desire w.r.t. handling of special tokens 132 | special = None 133 | if allowed_special == "all": 134 | special = self.special_tokens 135 | elif allowed_special == "none": 136 | special = {} 137 | elif allowed_special == "none_raise": 138 | special = {} 139 | assert all(token not in text for token in self.special_tokens) 140 | elif isinstance(allowed_special, set): 141 | special = {k: v for k, v in self.special_tokens.items() if k in allowed_special} 142 | else: 143 | raise ValueError(f"allowed_special={allowed_special} not understood") 144 | if not special: 145 | # shortcut: if no special tokens, just use the ordinary encoding 146 | return self.encode_ordinary(text) 147 | # otherwise, we have to be careful with potential special tokens in text 148 | # we handle special tokens by splitting the text 149 | # based on the occurrence of any exact match with any of the special tokens 150 | # we can use re.split for this. 
note that surrounding the pattern with () 151 | # makes it into a capturing group, so the special tokens will be included 152 | special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")" 153 | special_chunks = re.split(special_pattern, text) 154 | # now all the special characters are separated from the rest of the text 155 | # all chunks of text are encoded separately, then results are joined 156 | ids = [] 157 | for part in special_chunks: 158 | if part in special: 159 | # this is a special token, encode it separately as a special case 160 | ids.append(special[part]) 161 | else: 162 | # this is an ordinary sequence, encode it normally 163 | ids.extend(self.encode_ordinary(part)) 164 | return ids 165 | -------------------------------------------------------------------------------- /mars5/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | from .nn_future import (FNNSwiGLU, MistralTransformer, ModelArgs, 10 | RotatingBufferCache, SinePositionalEmbedding) 11 | from .utils import construct_padding_mask, length_to_mask 12 | 13 | LAYERNORM_EPS = 4e-5 14 | 15 | # ------------------------ 16 | # Code adapted from OpenAI guided diffusion repo 17 | 18 | def timestep_embedding(timesteps, dim, max_period=10000, dtype=torch.float32): 19 | """ 20 | Create sinusoidal timestep embeddings. 21 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 22 | These may be fractional. 23 | :param dim: the dimension of the output. 24 | :param max_period: controls the minimum frequency of the embeddings. 25 | :return: an [N x dim] Tensor of positional embeddings. 
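    For example (illustrative): timestep_embedding(torch.tensor([0., 10.]), dim=4)
    returns a (2, 4) tensor whose rows are [cos(t*f_0), cos(t*f_1), sin(t*f_0), sin(t*f_1)],
    with frequencies f_i = max_period**(-i / (dim // 2)).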
26 | """ 27 | half = dim // 2 28 | freqs = torch.exp( 29 | -math.log(max_period) * torch.arange(start=0, end=half) / half 30 | ).to(device=timesteps.device) 31 | args = timesteps[:, None].float() * freqs[None] 32 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype) 33 | if dim % 2: 34 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 35 | return embedding 36 | 37 | 38 | # -------------------------------- 39 | # autoregressive codec language model 40 | 41 | 42 | class CodecLM(nn.Module): 43 | 44 | def __init__(self, n_vocab, dim=1536, nhead=24, n_layers=26, n_spk_layers=2, dim_ff_scale=None, sliding_window=3000) -> None: 45 | super().__init__() 46 | 47 | if dim_ff_scale is None: hidden_dim = int(dim*4*(3/4)) 48 | else: hidden_dim = int(dim*dim_ff_scale) 49 | 50 | self.cfg = ModelArgs(n_vocab, dim=dim, n_layers=n_layers, n_heads=nhead, n_kv_heads=nhead, hidden_dim=hidden_dim, sliding_window=sliding_window) 51 | self.ar = MistralTransformer(self.cfg) 52 | 53 | self.embed = nn.Embedding(n_vocab, dim) 54 | 55 | # --- spk embedding network 56 | dim_ff = int(dim*4*(3/4)) 57 | self.pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True) 58 | self.ref_chunked_emb = ChunkedEmbedding(1024 + 1, 8, dim) # add 1 for pad idx 59 | self.spk_identity_emb = nn.Embedding(1, dim) 60 | # define custom decoder 61 | encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff, 62 | activation=FNNSwiGLU(dim, dim_ff), dropout=0, 63 | batch_first=True, norm_first=True, layer_norm_eps=LAYERNORM_EPS) 64 | encoder_layer.linear1 = nn.Identity() 65 | self.spk_encoder = nn.TransformerEncoder(encoder_layer, n_spk_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS)) 66 | # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerDecoder 67 | for l in self.spk_encoder.layers: l.activation = FNNSwiGLU(dim, dim_ff) 68 | 69 | 70 | @torch.inference_mode 71 | def get_spk_embedding(self, spk_reference, c_codes_lengths=None) -> Tensor: 72 | """ Gets speaker reference embeddings using `spk_reference` codes of shape (bs, seq_len, n_codebooks). """ 73 | bs = spk_reference.shape[0] 74 | if bs != 1: 75 | raise AssertionError(f"Speaker embedding extraction only implemented using for bs=1 currently.") 76 | spk_seq = self.ref_chunked_emb(spk_reference) # (bs, sl, dim) 77 | spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim) 78 | 79 | spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim) 80 | # add pos encoding 81 | spk_seq = self.pos_embedding(spk_seq) 82 | # codebook goes from indices 0->1023, padding is idx 1024 (the 1025th entry) 83 | src_key_padding_mask = construct_padding_mask(spk_reference[:, :, 0], 1024) 84 | src_key_padding_mask = torch.cat(( 85 | # append a zero here since we DO want to attend to initial position. 86 | torch.zeros(src_key_padding_mask.shape[0], 1, dtype=bool, device=src_key_padding_mask.device), 87 | src_key_padding_mask 88 | ), 89 | dim=1) 90 | # pass through transformer 91 | res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim). 
92 | return res.squeeze(1) 93 | 94 | 95 | def forward(self, x: Tensor, x_padding_mask: Optional[Tensor] = None, spk_reference: Optional[Tensor] = None, 96 | cache: Optional[RotatingBufferCache] = None, counter: int = 0) -> Tensor: 97 | """ Inputs: 98 | - `x`: (bs, seq_len, vocab_size) 99 | - `x_padding_mask`: (bs, seq_len) mask for each input, True for positions to *ignore*, False otherwise. 100 | Note that since this is an autoregressive model, this doesn't actually matter for infernece, so it is ignored at inference. 101 | - `spk_reference`: (bs, seq_len, n_codebooks) corresponding to the speaker reference to clone from. 102 | - `cache` and `counter`: used for kv caching, optional. 103 | 104 | Returns `x` of same shape (bs, seq_len, dim) 105 | """ 106 | x = self.embed(x) 107 | 108 | # --- speaker reference/embedding 109 | if spk_reference is not None: 110 | # compute ref 111 | bs = spk_reference.shape[0] 112 | spk_seq = self.ref_chunked_emb(spk_reference) # (bs, sl, dim) 113 | spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim) 114 | 115 | spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim) 116 | # add pos encoding 117 | spk_seq = self.pos_embedding(spk_seq) 118 | # codebook goes from indices 0->1023, padding is idx 1024 (the 1025th entry) 119 | src_key_padding_mask = construct_padding_mask(spk_reference[:, :, 0], 1024) 120 | src_key_padding_mask = torch.cat(( 121 | # append a zero here since we DO want to attend to initial position. 122 | torch.zeros(src_key_padding_mask.shape[0], 1, dtype=bool, device=src_key_padding_mask.device), 123 | src_key_padding_mask 124 | ), 125 | dim=1) 126 | # pass through transformer 127 | res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim). 128 | 129 | x = torch.cat([res, x], dim=1) 130 | 131 | positions = torch.arange(0, x.shape[1], device=x.device, dtype=torch.long) 132 | if cache is not None and counter != 1: 133 | # using only the last token to predict the next one 134 | x = x[:,-1,:].unsqueeze(1) 135 | positions = positions[-1:] 136 | 137 | x = self.ar(x, positions, cache) # (bs, seq_len, vocab) 138 | if spk_reference is not None and (cache is None or counter == 1): 139 | x = x[:, 1:] # strip out the first output token corresponding to the speaker embedding token. 140 | 141 | return x 142 | 143 | 144 | # ------------------------- 145 | # residual discrete diffusion model 146 | 147 | class ChunkedEmbedding(nn.Module): 148 | 149 | def __init__(self, codebook_size: int, n_quantizer: int, dim: int) -> None: 150 | super().__init__() 151 | assert dim % n_quantizer == 0, f"ChunkedEmbedding output dim ({dim}) must be divisible by n_quant {n_quantizer}" 152 | self.embs = nn.ModuleList([nn.Embedding(codebook_size, dim//n_quantizer) for _ in range(n_quantizer)]) 153 | 154 | def forward(self, x: Tensor) -> Tensor: 155 | """ Embeds each codebook index in `x` (bs, seq_len, n_quantizer) to an embedding vector, concatenating results. 
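Each codebook index passes through its own nn.Embedding of width dim // n_quantizer, so e.g. dim=1024 with 8 codebooks gives eight 128-dim vectors per frame, concatenated along the feature axis.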
156 | Returns output of shape (bs, seq_len, dim) 157 | """ 158 | y = torch.cat([self.embs[i](x[..., i]) for i in range(x.shape[-1])], dim=-1) 159 | return y 160 | 161 | 162 | 163 | class ResidualTransformer(nn.Module): 164 | 165 | def __init__(self, n_text_vocab, n_quant=1024, dim=1024, nhead=16, 166 | enc_layers=8, dec_layers=16, n_spk_layers=3, 167 | c_quant_levels=8, pred_quant_levels=8, 168 | t_emb_dim=1024, norm_first=True, p_cond_drop=0.1, dropout=0) -> None: 169 | super().__init__() 170 | 171 | self.cond_pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True) 172 | self.pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True) 173 | 174 | # *4 from heuristic, *2/3 from swiglu, since there are 3 linear matrices not 2. 175 | # so we must keep # params the same. 176 | dim_ff = int(dim*4*(3/4)) 177 | 178 | # define custom encoder 179 | encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff, 180 | activation=FNNSwiGLU(dim, dim_ff), dropout=dropout, 181 | batch_first=True, norm_first=norm_first, layer_norm_eps=LAYERNORM_EPS) 182 | encoder_layer.linear1 = nn.Identity() 183 | encoder = nn.TransformerEncoder(encoder_layer, enc_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS) if norm_first else None) 184 | 185 | # define custom decoder 186 | decoder_layer = nn.TransformerDecoderLayer(dim, nhead, dim_ff, 187 | activation=FNNSwiGLU(dim, dim_ff), dropout=dropout, 188 | batch_first=True, norm_first=norm_first, layer_norm_eps=LAYERNORM_EPS) 189 | decoder_layer.linear1 = nn.Identity() 190 | decoder = nn.TransformerDecoder(decoder_layer, dec_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS) if norm_first else None) 191 | 192 | # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerDecoder 193 | for l in decoder.layers: l.activation = FNNSwiGLU(dim, dim_ff) 194 | 195 | self.tfm = nn.Transformer(dim, nhead, dim_feedforward=dim_ff, batch_first=True, 196 | norm_first=norm_first, 197 | num_encoder_layers=enc_layers, 198 | num_decoder_layers=dec_layers, 199 | custom_encoder=encoder, 200 | custom_decoder=decoder, 201 | layer_norm_eps=LAYERNORM_EPS, 202 | dropout=dropout 203 | ) 204 | # Timestep embedding network 205 | self.t_emb_dim = t_emb_dim 206 | self.timestep_encoder_emb = nn.Sequential( 207 | nn.Linear(t_emb_dim, dim), 208 | nn.SiLU(), 209 | nn.Linear(dim, dim) 210 | ) 211 | self.timestep_decoder_emb = nn.Sequential( 212 | nn.Linear(t_emb_dim, dim), 213 | nn.SiLU(), 214 | nn.Linear(dim, dim) 215 | ) 216 | 217 | self.text_embed = nn.Embedding(n_text_vocab, dim) 218 | 219 | ## ----> reference / conditioning encoder: 220 | self.ref_embedder = ChunkedEmbedding(n_quant, c_quant_levels, dim) 221 | self.ref_pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True) 222 | self.spk_identity_emb = nn.Embedding(1, dim) 223 | spk_encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff, 224 | activation=FNNSwiGLU(dim, dim_ff), dropout=dropout, 225 | batch_first=True, norm_first=True, layer_norm_eps=LAYERNORM_EPS) 226 | spk_encoder_layer.linear1 = nn.Identity() 227 | self.spk_encoder = nn.TransformerEncoder(spk_encoder_layer, n_spk_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS)) 228 | # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerDecoder 229 | for l in self.spk_encoder.layers: l.activation = FNNSwiGLU(dim, dim_ff) 230 | # ----> end speaker encoder network 231 | 232 | # self.residual_encoder = nn.Embedding(n_quant, dim) # only encode first quantization level of decoder input. 
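# instead, embed all c_quant_levels codebooks of the decoder input codes with a chunked embedding: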
233 | self.residual_encoder = ChunkedEmbedding(n_quant, c_quant_levels, dim) 234 | 235 | self.residual_decoder = nn.ModuleList([ 236 | nn.Sequential( 237 | nn.LayerNorm(dim), 238 | nn.Linear(dim, n_quant) 239 | ) for i in range(pred_quant_levels) 240 | ]) 241 | self.n_quantizer = pred_quant_levels 242 | self.p_cond_drop = p_cond_drop 243 | 244 | 245 | @torch.inference_mode 246 | def get_spk_embedding(self, c_codes, c_codes_length) -> Tensor: 247 | """ Obtain speaker embedding vectors using `c_codes` from reference encodec sequences, and `c_codes_length` of lengths for each sequence """ 248 | bs = c_codes.shape[0] 249 | spk_seq = self.ref_embedder(c_codes) # (bs, sl, dim) 250 | spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim) 251 | spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim) 252 | # add pos encoding 253 | spk_seq = self.ref_pos_embedding(spk_seq) 254 | 255 | # add 1 to c_codes_length to account for the fact that we concatenate the spk_ref_emb to it. 256 | src_key_padding_mask = length_to_mask(c_codes_length+1, torch.zeros_like(c_codes_length), max_len=spk_seq.shape[1]) 257 | src_key_padding_mask = src_key_padding_mask.to(dtype=torch.bool, device=spk_seq.device) 258 | 259 | # pass through transformer 260 | res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim). 261 | return res.squeeze(1) 262 | 263 | 264 | def forward(self, c_text: Tensor, c_codes: Tensor, c_texts_length: Tensor, c_codes_length: Tensor, 265 | x: Tensor, x_padding_mask: Tensor, t: Tensor, drop_cond=False): 266 | """ Input: 267 | - `c_text`: (bs, seq_len1) the prompt text (BPE encoded) 268 | - `c_codes`: (bs, seq_len2, n_quant) the full tokenized codes of the reference speech 269 | - `c_texts_length`: (bs, ) the length of the codes in the text prompt 270 | - `c_codes_length`: (bs, ) the length of the prompt acoustic token codes in `c_codes`. 271 | - `x`: (bs, seq_len3) L0 residual codes 272 | - `x`: (bs, seq_len3, n_quant) L0 residual codes 273 | - `x_padding_mask`: (bs, seq_len3) masking for residual codes 274 | - `t`: (bs) timestep 275 | - `drop_cond`: bool, whether or not to forcibly drop the conditioning information. 276 | Returns: 277 | - outs: (bs, seq_len, n_quantizer, codebook_size) 278 | """ 279 | 280 | c_text = self.text_embed(c_text) # (bs, seq_len1, dim) 281 | 282 | ## ----> reference / conditioning encoder: 283 | bs = c_codes.shape[0] 284 | 285 | 286 | if self.training: 287 | zero_cond_inds = torch.rand_like(t, dtype=c_text.dtype) < self.p_cond_drop 288 | else: 289 | # never randomly zero when in eval mode 290 | zero_cond_inds = torch.zeros_like(t, dtype=torch.bool) 291 | if drop_cond: 292 | # force drop conditioning 293 | zero_cond_inds = torch.ones_like(t, dtype=torch.bool) 294 | 295 | c_codes_length[zero_cond_inds] = 0 296 | c_codes[zero_cond_inds] = 1024 297 | 298 | spk_seq = self.ref_embedder(c_codes) # (bs, sl, dim) 299 | spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim) 300 | spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim) 301 | # add pos encoding 302 | spk_seq = self.ref_pos_embedding(spk_seq) 303 | 304 | # add 1 to c_codes_length to account for the fact that we concatenate the spk_ref_emb to it. 
305 | src_key_padding_mask = length_to_mask(c_codes_length+1, torch.zeros_like(c_codes_length), max_len=spk_seq.shape[1]) 306 | src_key_padding_mask = src_key_padding_mask.to(dtype=torch.bool, device=spk_seq.device) 307 | 308 | # pass through transformer 309 | res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim). 310 | c_codes = res # (bs, 1, dim) 311 | c_codes_lengths_extract = torch.ones_like(c_codes_length) # manually override all the code lengths to equal 1, since we only have 1 spk embedding. 312 | ## ----> end reference / conditioning encoder: 313 | 314 | ## ----> timestep embeddings and parsing 315 | t_emb = timestep_embedding(t, self.t_emb_dim, dtype=c_text.dtype) 316 | t_emb_encoder = self.timestep_encoder_emb(t_emb) # (bs, t_dim) 317 | t_emb_decoder = self.timestep_decoder_emb(t_emb) 318 | 319 | ## ----> concatenating text/phone inputs and implicit speaker embedding. 320 | c_phones_unpacked = nn.utils.rnn.unpad_sequence(c_text, c_texts_length.cpu(), batch_first=True) 321 | c_codes_unpacked = nn.utils.rnn.unpad_sequence(c_codes, c_codes_lengths_extract.cpu(), batch_first=True) 322 | # >>> Concat [speaker codes, text codes] 323 | assert all(b.shape[0] == 1 for b in c_codes_unpacked) 324 | c_joined = [torch.cat((b, a), dim=0) for a, b in zip(c_phones_unpacked, c_codes_unpacked)] 325 | 326 | c = nn.utils.rnn.pad_sequence(c_joined, batch_first=True) 327 | c_joined_lengths = torch.tensor([p.shape[0] for p in c_joined], device=c.device, dtype=torch.long) 328 | c_padding_mask = length_to_mask(c_joined_lengths, torch.zeros_like(c_joined_lengths)) 329 | c = self.cond_pos_embedding(c) 330 | 331 | ## Format input: 332 | x = self.residual_encoder(x) # (bs, seq_len3, dim) 333 | 334 | x = self.pos_embedding(x) 335 | 336 | x = x + t_emb_decoder[:, None] 337 | c = c + t_emb_encoder[:, None] 338 | ## Perform prediction: 339 | output = self.tfm(c, x, src_key_padding_mask=c_padding_mask, 340 | tgt_key_padding_mask=x_padding_mask, 341 | memory_key_padding_mask=c_padding_mask) # (bs, seq_len, dim) 342 | outs = torch.stack([self.residual_decoder[i](output) for i in range(self.n_quantizer)], dim=-1) # (bs, seq_len, logit_dim, n_quant) 343 | return outs 344 | 345 | -------------------------------------------------------------------------------- /mars5/nn_future.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | import math 6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | 10 | # -------------------------- 11 | # activation functions 12 | 13 | class FNNSwiGLU(nn.Module): 14 | 15 | def __init__(self, dim, dim_ff) -> None: 16 | super().__init__() 17 | 18 | # we will receive in xW 19 | self.V = nn.Linear(dim, dim_ff, bias=False) 20 | self.W = nn.Linear(dim, dim_ff, bias=False) 21 | 22 | 23 | def forward(self, x: Tensor) -> Tensor: 24 | """ Compute SwiGLU output of x, the output of the first linear layer. i.e. 25 | FFNSwiGLU(x, W, V, W2) = (Swish1(xW) ⊗ xV )W2. 26 | NOTE: the transformer linear1 layer must be overwritten to identity. This layer only applies 27 | the Swish(xW) * xV. 
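Shapes: input (..., dim) -> output (..., dim_ff).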
The W2 multiplication is done in the main transformer layer 28 | """ 29 | return F.silu(self.W(x)) * self.V(x) 30 | 31 | 32 | # --------------------------------- 33 | # padding and position layers 34 | 35 | class SinePositionalEmbedding(nn.Module): 36 | def __init__( 37 | self, 38 | dim_model: int, 39 | dropout: float = 0.0, 40 | scale: bool = False, 41 | alpha: bool = False, 42 | ): 43 | super().__init__() 44 | self.dim_model = dim_model 45 | self.x_scale = math.sqrt(dim_model) if scale else 1.0 46 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 47 | self.dropout = torch.nn.Dropout(p=dropout) 48 | 49 | self.reverse = False 50 | self.pe = None 51 | self.extend_pe(torch.tensor(0.0).expand(1, 4000)) 52 | 53 | def extend_pe(self, x): 54 | """Reset the positional encodings.""" 55 | if self.pe is not None: 56 | if self.pe.size(1) >= x.size(1): 57 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 58 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 59 | return 60 | pe = torch.zeros(x.size(1), self.dim_model) 61 | if self.reverse: 62 | position = torch.arange( 63 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 64 | ).unsqueeze(1) 65 | else: 66 | position = torch.arange( 67 | 0, x.size(1), dtype=torch.float32 68 | ).unsqueeze(1) 69 | div_term = torch.exp( 70 | torch.arange(0, self.dim_model, 2, dtype=torch.float32) 71 | * -(math.log(10000.0) / self.dim_model) 72 | ) 73 | pe[:, 0::2] = torch.sin(position * div_term) 74 | pe[:, 1::2] = torch.cos(position * div_term) 75 | pe = pe.unsqueeze(0) 76 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach() 77 | 78 | def forward(self, x: torch.Tensor) -> torch.Tensor: 79 | """ Assumes x of shape (bs, seq_len, dim) """ 80 | self.extend_pe(x) 81 | output = x.unsqueeze(-1) if x.ndim == 2 else x 82 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] 83 | return self.dropout(output) 84 | 85 | 86 | # -------------------------------- 87 | # kv cache blocks 88 | 89 | class CacheView: 90 | def __init__(self, cache_k: torch.Tensor, cache_v: torch.Tensor): 91 | self.cache_k = cache_k 92 | self.cache_v = cache_v 93 | 94 | @property 95 | def sliding_window(self): 96 | return self.cache_k.shape[1] 97 | 98 | class RotatingBufferCache: 99 | """ 100 | This is an example that implements a less naive rotating buffer cache, allowing for variable length sequences. 
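Keys/values for absolute position p are written into slot p % sliding_window of each layer's buffer, so only the most recent sliding_window positions are retained per layer.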
101 | Allocated cache is rectangular which is wasteful (see PagedAttention for better mechanisms) 102 | """ 103 | def __init__(self, n_layers: int, max_batch_size: int, sliding_window: int, n_kv_heads: int, head_dim: int): 104 | 105 | self.sliding_window = sliding_window 106 | self.n_kv_heads = n_kv_heads 107 | self.head_dim = head_dim 108 | 109 | self.cache_k = torch.empty(( 110 | n_layers, 111 | max_batch_size, 112 | sliding_window, 113 | n_kv_heads, 114 | head_dim 115 | )) 116 | self.cache_v = torch.empty(( 117 | n_layers, 118 | max_batch_size, 119 | sliding_window, 120 | n_kv_heads, 121 | head_dim 122 | )) 123 | 124 | def get_view(self, layer_id: int) -> CacheView: 125 | return CacheView(self.cache_k[layer_id], self.cache_v[layer_id]) 126 | 127 | @property 128 | def device(self): 129 | return self.cache_k.device 130 | 131 | def to(self, device: torch.device, dtype: torch.dtype): 132 | self.cache_k = self.cache_k.to(device=device, dtype=dtype) 133 | self.cache_v = self.cache_v.to(device=device, dtype=dtype) 134 | return self 135 | 136 | 137 | # -------------------------------- 138 | # Mistral transformer blocks 139 | # Code for the follow blocks are adapted from 140 | # https://github.com/mistralai/mistral-src 141 | # Thank you Mistral team! 142 | 143 | @dataclass 144 | class ModelArgs: 145 | vocab_size: int 146 | 147 | dim: int = 1152 # default for mars3 and before: 1024 148 | n_layers: int = 24 149 | head_dim: int = 64 # = dim/n_heads 150 | hidden_dim: int = 3584 151 | n_heads: int = 16 152 | n_kv_heads: int = 16 # default: 8 153 | sliding_window: int = 1792 154 | norm_eps: float = 1e-5 155 | 156 | max_batch_size: int = 256 157 | 158 | 159 | def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int): 160 | if repeats == 1: return keys, values 161 | keys = torch.repeat_interleave(keys, repeats=repeats, dim=2) 162 | values = torch.repeat_interleave(values, repeats=repeats, dim=2) 163 | return keys, values 164 | 165 | 166 | def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 167 | """ 168 | freqs_cis: complex - (seq_len, head_dim / 2) 169 | x: complex - (bsz, seq_len, head_dim / 2) 170 | """ 171 | ndim = x.ndim 172 | assert 1 < ndim 173 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( 174 | freqs_cis.shape, 175 | (x.shape[1], x.shape[-1]), 176 | ) 177 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 178 | return freqs_cis.view(*shape) 179 | 180 | 181 | def apply_rotary_emb( 182 | xq: torch.Tensor, 183 | xk: torch.Tensor, 184 | freqs_cis: torch.Tensor, 185 | ) -> tuple[torch.Tensor, torch.Tensor]: 186 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 187 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 188 | freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) 189 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 190 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 191 | return xq_out.type_as(xq), xk_out.type_as(xk) 192 | 193 | 194 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: 195 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 196 | t = torch.arange(end, device=freqs.device) # type: ignore 197 | freqs = torch.outer(t, freqs).float() # type: ignore 198 | return torch.polar(torch.ones_like(freqs), freqs) # complex64 199 | 200 | 201 | class Attention(nn.Module): 202 | def __init__(self, args: ModelArgs): 203 | super().__init__() 204 | self.args = args 205 | 206 | 
self.n_heads: int = args.n_heads 207 | self.n_kv_heads: int = args.n_kv_heads 208 | 209 | self.repeats = self.n_heads // self.n_kv_heads 210 | self.sliding_window = self.args.sliding_window 211 | 212 | self.scale = self.args.head_dim**-0.5 213 | 214 | self.wq = nn.Linear( 215 | args.dim, 216 | args.n_heads * args.head_dim, 217 | bias=False 218 | ) 219 | self.wk = nn.Linear( 220 | args.dim, 221 | args.n_kv_heads * args.head_dim, 222 | bias=False 223 | ) 224 | self.wv = nn.Linear( 225 | args.dim, 226 | args.n_kv_heads * args.head_dim, 227 | bias=False 228 | ) 229 | self.wo = nn.Linear( 230 | args.n_heads * args.head_dim, 231 | args.dim, 232 | bias=False 233 | ) 234 | 235 | def forward( 236 | self, x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor, mask: Optional[torch.Tensor], cache: Optional[CacheView] 237 | ) -> torch.Tensor: 238 | 239 | bsz, seqlen, _ = x.shape 240 | 241 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 242 | xq = xq.view(bsz, seqlen, self.n_heads, self.args.head_dim) 243 | xk = xk.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim) 244 | xv = xv.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim) 245 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 246 | 247 | # The cache is a rotating buffer 248 | if cache is not None: 249 | scatter_pos = (positions[-self.sliding_window:] % self.sliding_window)[None, :, None, None] 250 | scatter_pos = scatter_pos.repeat(bsz, 1, self.n_kv_heads, self.args.head_dim) 251 | cache.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk[:, -self.sliding_window:]) 252 | cache.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv[:, -self.sliding_window:]) 253 | 254 | if positions.shape[0] > 1: 255 | # prefill 256 | key, value = repeat_kv(xk, xv, self.repeats) 257 | else: 258 | cur_pos = positions[-1].item() + 1 259 | key, value = repeat_kv(cache.cache_k[:bsz, :cur_pos, ...], cache.cache_v[:bsz, :cur_pos, ...], self.repeats) 260 | 261 | # print(f"Internal: {xq.shape}, key: {key.shape}, mask: {mask.shape} | {mask.dtype} | xq: {xq.dtype} | mask: {mask} ") 262 | # if mask is not None: 263 | # mask = mask[None, None, ...].expand(bsz, self.n_heads, -1, -1) 264 | # mask = mask.to(key.dtype) 265 | 266 | query = xq.transpose(1, 2) 267 | key = key.transpose(1, 2) 268 | value = value.transpose(1, 2) 269 | # # scores : [bsz, n_heads, seqlen | 1, seqlen] 270 | # scores = torch.matmul(query, key.transpose(2, 3)) * self.scale 271 | 272 | output = F.scaled_dot_product_attention(query, key, value, mask) # (bs, n_local_heads, slen, head_dim) 273 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 274 | return self.wo(output) 275 | 276 | 277 | class FeedForward(nn.Module): 278 | def __init__(self, args: ModelArgs): 279 | super().__init__() 280 | 281 | self.w1 = nn.Linear( 282 | args.dim, 283 | args.hidden_dim, 284 | bias=False 285 | ) 286 | self.w2 = nn.Linear( 287 | args.hidden_dim, 288 | args.dim, 289 | bias=False 290 | ) 291 | self.w3 = nn.Linear( 292 | args.dim, 293 | args.hidden_dim, 294 | bias=False 295 | ) 296 | 297 | def forward(self, x) -> torch.Tensor: 298 | return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) 299 | 300 | 301 | class RMSNorm(torch.nn.Module): 302 | def __init__(self, dim: int, eps: float = 1e-6): 303 | super().__init__() 304 | self.eps = eps 305 | self.weight = nn.Parameter(torch.ones(dim)) 306 | 307 | def _norm(self, x): 308 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 309 | 310 | def forward(self, x): 311 | output = self._norm(x.float()).type_as(x) 312 | 
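# scale the RMS-normalised activations (x / sqrt(mean(x^2) + eps)) by a learned per-dimension gain; unlike LayerNorm there is no mean-centering and no bias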
return output * self.weight 313 | 314 | 315 | class TransformerBlock(nn.Module): 316 | def __init__(self, args: ModelArgs): 317 | super().__init__() 318 | self.n_heads = args.n_heads 319 | self.dim = args.dim 320 | self.attention = Attention(args) 321 | self.feed_forward = FeedForward(args=args) 322 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 323 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 324 | self.args = args 325 | 326 | def forward( 327 | self, x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor, mask: Optional[torch.Tensor], cache: Optional[CacheView] 328 | ) -> torch.Tensor: 329 | r = self.attention.forward(self.attention_norm(x), freqs_cis, positions, mask, cache) 330 | h = x + r 331 | r = self.feed_forward.forward(self.ffn_norm(h)) 332 | out = h + r 333 | return out 334 | 335 | 336 | class MistralTransformer(nn.Module): 337 | def __init__(self, args: ModelArgs): 338 | super().__init__() 339 | self.args = args 340 | self.vocab_size = args.vocab_size 341 | self.n_layers = args.n_layers 342 | assert self.vocab_size > 0 343 | 344 | # self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) 345 | 346 | self.layers = torch.nn.ModuleList( 347 | [TransformerBlock(args=args) for _ in range(args.n_layers)] 348 | ) 349 | 350 | self.norm = RMSNorm(args.dim, eps=args.norm_eps) 351 | 352 | self.output = nn.Linear( 353 | args.dim, 354 | args.vocab_size, 355 | bias=False 356 | ) 357 | 358 | # self.freqs_cis 359 | self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000) 360 | 361 | @property 362 | def dtype(self) -> torch.dtype: 363 | return self.tok_embeddings.weight.dtype 364 | 365 | @property 366 | def device(self) -> torch.device: 367 | return self.tok_embeddings.weight.device 368 | 369 | def forward( 370 | self, 371 | input_ids: torch.Tensor, 372 | positions: torch.Tensor, 373 | cache: Optional[RotatingBufferCache] 374 | ): 375 | h = input_ids 376 | if self.freqs_cis.device != h.device: 377 | self.freqs_cis = self.freqs_cis.to(h.device) 378 | freqs_cis = self.freqs_cis[positions] 379 | 380 | mask: Optional[torch.Tensor] = None 381 | if input_ids.shape[1] > 1: 382 | seqlen = input_ids.shape[1] 383 | tensor = torch.full( 384 | (seqlen, seqlen), 385 | dtype=h.dtype, 386 | fill_value=1, 387 | device=h.device, 388 | ) 389 | mask = torch.tril(tensor, diagonal=0).to(h.dtype) 390 | # make the mask banded to account for sliding window 391 | mask = torch.triu(mask, diagonal=-self.args.sliding_window) 392 | mask = torch.log(mask) 393 | 394 | for layer_id, layer in enumerate(self.layers): 395 | cache_view = None if cache is None else cache.get_view(layer_id) 396 | h = layer(h, freqs_cis, positions, mask, cache_view) 397 | 398 | return self.output(self.norm(h)) 399 | 400 | 401 | -------------------------------------------------------------------------------- /mars5/samplers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for modifying categorical distributions to improve quality of sampling. 3 | 4 | Adapted from: 5 | - https://github.com/e-c-k-e-r/vall-e/blob/master/vall_e/samplers.py 6 | - Mirosoft UniLM 7 | - Matthew Baas's typical sampling code. 
8 | - https://github.com/LostRuins/koboldcpp 9 | """ 10 | 11 | import math 12 | import torch 13 | import torch.nn.functional as F 14 | import numpy as np 15 | import logging 16 | 17 | from torch import Tensor, nn 18 | 19 | 20 | def freq_rep_penalty(logits: Tensor, previous: Tensor, alpha_frequency: float, alpha_presence: float, penalty_window: int = 100) -> Tensor: 21 | """ Apply frequency and presence penalty according to openai's formuation. 22 | Concretely: given `logits` (bs, vocab_size) and `previous` (bs, seq_len,) 23 | 24 | Modified to support batched inference. 25 | 26 | See: https://platform.openai.com/docs/guides/text-generation/parameter-details 27 | """ 28 | bs = logits.shape[0] 29 | previous = previous[..., -penalty_window:] 30 | c = torch.zeros_like(logits, device=logits.device, dtype=torch.long) # (1, vocab_size) 31 | for i in range(bs): 32 | vals, cnts = previous[i].unique(return_counts=True) 33 | c[i, vals] = cnts.to(c.device) 34 | 35 | logits = logits - c * alpha_frequency - (c > 0).to(logits.dtype) * alpha_presence 36 | return logits 37 | 38 | 39 | def early_eos_penalty(logits: Tensor, n_generated: int, estimated_gen_length: int, decay: float, factor: float = 1, eos_index: int = 0) -> Tensor: 40 | """ Penalize the `eos_index` of `logits` (bs, vocab_size) up to `estimated_gen_length`, 41 | whereby we reduce the logit value by `factor`*(expected_length - current_length)^decay, 42 | `n_generated` is the current number of generated samples. `decay` anneals the penalty relative to the distance. 43 | 44 | Good values for decay are between 0 and 1. 0 = hard always apply penalty of 1, 1 = linearly scale penalty relative to distance. 45 | Setting factor = 0 disabled penatly. Increasing factor increases penalty. 46 | """ 47 | if n_generated > estimated_gen_length: return logits 48 | penalty = max(estimated_gen_length - n_generated, 1) 49 | 50 | bigger = logits[:, eos_index] > 0 51 | 52 | modifier = factor*(penalty ** decay) 53 | # logits[bigger, eos_index] /= modifier 54 | # logits[~bigger, eos_index] *= modifier 55 | logits[:, eos_index] -= modifier 56 | return logits 57 | 58 | 59 | # Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / 60 | # https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 61 | def top_k_top_p_filtering( logits: Tensor, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ) -> Tensor: 62 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 63 | Args: 64 | logits: logits distribution shape (batch size, vocabulary size) 65 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 66 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 67 | Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) 68 | Make sure we keep at least min_tokens per batch example in the output 69 | """ 70 | if top_k > 0: 71 | top_k = min(max(top_k, min_tokens), logits.size(-1)) # Safety check 72 | # Remove all tokens with a probability less than the last token of the top-k 73 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 74 | logits[indices_to_remove] = filter_value 75 | 76 | if top_p < 1.0: 77 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 78 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 79 | 80 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept) 81 | sorted_indices_to_remove = cumulative_probs > top_p 82 | if min_tokens > 1: 83 | # Keep at least min_tokens (set to min_tokens-1 because we add the first one below) 84 | sorted_indices_to_remove[..., :min_tokens] = 0 85 | # Shift the indices to the right to keep also the first token above the threshold 86 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 87 | sorted_indices_to_remove[..., 0] = 0 88 | 89 | # scatter sorted tensors to original indexing 90 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 91 | logits[indices_to_remove] = filter_value 92 | 93 | return logits 94 | 95 | 96 | def apply_typical_p(logprobs: Tensor, mass: float) -> Tensor: 97 | """ Warp categorical logprobs associated with `x` to be in line with `mass`. Last dimension is the bin dimension. 98 | `mass` corresponds to `tau` in the paper. 99 | """ 100 | if mass > 0.999: return logprobs 101 | # see: https://arxiv.org/abs/2202.00666 102 | # calculate entropy 103 | # normalized = logprobs #torch.nn.functional.log_softmax(scores, dim=-1) 104 | normalized = torch.nn.functional.log_softmax(logprobs, dim=-1) 105 | p = torch.exp(normalized) 106 | ent = -(normalized * p).nansum(-1, keepdim=True) 107 | 108 | # shift and sort 109 | shifted_scores = torch.abs((-normalized) - ent) 110 | sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) 111 | sorted_logits = logprobs.gather(-1, sorted_indices) 112 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 113 | 114 | # Remove tokens with cumulative mass above the threshold 115 | last_ind = (cumulative_probs < mass).sum(dim=1) 116 | last_ind[last_ind < 0] = 0 117 | sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) 118 | 119 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 120 | 121 | scores = logprobs.masked_fill(indices_to_remove, -float('Inf')) 122 | return scores -------------------------------------------------------------------------------- /mars5/trim.py: -------------------------------------------------------------------------------- 1 | """ Custom port of librosa trim code, to remove numba dependency. 2 | This allows us to use librosa.trim effect without the librosa or numba dependancy. 3 | 4 | All code below adapted from librosa open source github: 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import warnings 11 | 12 | 13 | def amplitude_to_db(S, ref=1.0, amin=1e-5, top_db=80.0): 14 | """Convert an amplitude spectrogram to dB-scaled spectrogram. 15 | 16 | This is equivalent to ``power_to_db(S**2)``, but is provided for convenience. 
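    Roughly, ``S_db = 20 * log10(S / ref)`` (with ``S`` floored at ``amin``), and values
    more than ``top_db`` below the peak are clipped to ``max(S_db) - top_db``.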
17 | 18 | Parameters 19 | ---------- 20 | S : np.ndarray 21 | input amplitude 22 | 23 | ref : scalar or callable 24 | If scalar, the amplitude ``abs(S)`` is scaled relative to ``ref``: 25 | ``20 * log10(S / ref)``. 26 | Zeros in the output correspond to positions where ``S == ref``. 27 | 28 | If callable, the reference value is computed as ``ref(S)``. 29 | 30 | amin : float > 0 [scalar] 31 | minimum threshold for ``S`` and ``ref`` 32 | 33 | top_db : float >= 0 [scalar] 34 | threshold the output at ``top_db`` below the peak: 35 | ``max(20 * log10(S)) - top_db`` 36 | 37 | 38 | Returns 39 | ------- 40 | S_db : np.ndarray 41 | ``S`` measured in dB 42 | 43 | See Also 44 | -------- 45 | power_to_db, db_to_amplitude 46 | 47 | Notes 48 | ----- 49 | This function caches at level 30. 50 | """ 51 | 52 | # S = np.asarray(S) 53 | S = torch.asarray(S) 54 | 55 | 56 | magnitude = S.abs() 57 | 58 | if callable(ref): 59 | # User supplied a function to calculate reference power 60 | ref_value = ref(magnitude) 61 | else: 62 | ref_value = torch.abs(ref) 63 | 64 | power = torch.square(magnitude, out=magnitude) 65 | 66 | return power_to_db(power, ref=ref_value ** 2, amin=amin ** 2, top_db=top_db) 67 | 68 | 69 | def _signal_to_frame_nonsilent( 70 | y, frame_length=2048, hop_length=512, top_db=60, ref=torch.max 71 | ): 72 | """Frame-wise non-silent indicator for audio input. 73 | 74 | This is a helper function for `trim` and `split`. 75 | 76 | Parameters 77 | ---------- 78 | y : np.ndarray, shape=(n,) or (2,n) 79 | Audio signal, mono or stereo 80 | 81 | frame_length : int > 0 82 | The number of samples per frame 83 | 84 | hop_length : int > 0 85 | The number of samples between frames 86 | 87 | top_db : number > 0 88 | The threshold (in decibels) below reference to consider as 89 | silence 90 | 91 | ref : callable or float 92 | The reference power 93 | 94 | Returns 95 | ------- 96 | non_silent : np.ndarray, shape=(m,), dtype=bool 97 | Indicator of non-silent frames 98 | """ 99 | # Convert to mono 100 | if y.ndim > 1: 101 | y_mono = torch.mean(y, dim=0) 102 | else: y_mono = y 103 | 104 | # Compute the MSE for the signal 105 | mse = rms(y=y_mono, frame_length=frame_length, hop_length=hop_length) ** 2 106 | 107 | return power_to_db(mse.squeeze(), ref=ref, top_db=None) > -top_db 108 | 109 | 110 | def trim(y, top_db=60, ref=torch.max, frame_length=2048, hop_length=512): 111 | """Trim leading and trailing silence from an audio signal. 112 | 113 | Parameters 114 | ---------- 115 | y : np.ndarray, shape=(n,) or (2,n) 116 | Audio signal, can be mono or stereo 117 | 118 | top_db : number > 0 119 | The threshold (in decibels) below reference to consider as 120 | silence 121 | 122 | ref : number or callable 123 | The reference power. By default, it uses `np.max` and compares 124 | to the peak power in the signal. 125 | 126 | frame_length : int > 0 127 | The number of samples per analysis frame 128 | 129 | hop_length : int > 0 130 | The number of samples between analysis frames 131 | 132 | Returns 133 | ------- 134 | y_trimmed : np.ndarray, shape=(m,) or (2, m) 135 | The trimmed signal 136 | 137 | index : np.ndarray, shape=(2,) 138 | the interval of ``y`` corresponding to the non-silent region: 139 | ``y_trimmed = y[index[0]:index[1]]`` (for mono) or 140 | ``y_trimmed = y[:, index[0]:index[1]]`` (for stereo). 
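    Note: unlike the librosa original, this port works with ``torch.Tensor`` inputs
    (mono ``(n,)`` or stereo ``(2, n)``) and returns the trim interval as a length-2
    torch tensor.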
141 | 142 | 143 | Examples 144 | -------- 145 | >>> # Load some audio 146 | >>> y, sr = librosa.load(librosa.ex('choice')) 147 | >>> # Trim the beginning and ending silence 148 | >>> yt, index = librosa.effects.trim(y) 149 | >>> # Print the durations 150 | >>> print(librosa.get_duration(y), librosa.get_duration(yt)) 151 | 25.025986394557822 25.007891156462584 152 | """ 153 | 154 | non_silent = _signal_to_frame_nonsilent( 155 | y, frame_length=frame_length, hop_length=hop_length, ref=ref, top_db=top_db 156 | ) 157 | 158 | # nonzero = np.flatnonzero(non_silent) 159 | nonzero = torch.nonzero(torch.ravel(non_silent)).squeeze()#[0] 160 | 161 | if nonzero.numel() > 0: 162 | # Compute the start and end positions 163 | # End position goes one frame past the last non-zero 164 | start = int(frames_to_samples(nonzero[0], hop_length)) 165 | end = min(y.shape[-1], int(frames_to_samples(nonzero[-1] + 1, hop_length))) 166 | else: 167 | # The signal only contains zeros 168 | start, end = 0, 0 169 | 170 | # Build the mono/stereo index 171 | full_index = [slice(None)] * y.ndim 172 | full_index[-1] = slice(start, end) 173 | 174 | # print(non_silent) 175 | # print(non_silent.shape, nonzero.shape) 176 | 177 | return y[tuple(full_index)], torch.asarray([start, end]) 178 | 179 | 180 | def rms( 181 | y=None, S=None, frame_length=2048, hop_length=512, center=True, pad_mode="reflect" 182 | ): 183 | """Compute root-mean-square (RMS) value for each frame, either from the 184 | audio samples ``y`` or from a spectrogram ``S``. 185 | 186 | Computing the RMS value from audio samples is faster as it doesn't require 187 | a STFT calculation. However, using a spectrogram will give a more accurate 188 | representation of energy over time because its frames can be windowed, 189 | thus prefer using ``S`` if it's already available. 190 | 191 | 192 | Parameters 193 | ---------- 194 | y : np.ndarray [shape=(n,)] or None 195 | (optional) audio time series. Required if ``S`` is not input. 196 | 197 | S : np.ndarray [shape=(d, t)] or None 198 | (optional) spectrogram magnitude. Required if ``y`` is not input. 199 | 200 | frame_length : int > 0 [scalar] 201 | length of analysis frame (in samples) for energy calculation 202 | 203 | hop_length : int > 0 [scalar] 204 | hop length for STFT. See `librosa.stft` for details. 205 | 206 | center : bool 207 | If `True` and operating on time-domain input (``y``), pad the signal 208 | by ``frame_length//2`` on either side. 209 | 210 | If operating on spectrogram input, this has no effect. 211 | 212 | pad_mode : str 213 | Padding mode for centered analysis. See `numpy.pad` for valid 214 | values. 215 | 216 | Returns 217 | ------- 218 | rms : np.ndarray [shape=(1, t)] 219 | RMS value for each frame 220 | 221 | 222 | Examples 223 | -------- 224 | >>> y, sr = librosa.load(librosa.ex('trumpet')) 225 | >>> librosa.feature.rms(y=y) 226 | array([[1.248e-01, 1.259e-01, ..., 1.845e-05, 1.796e-05]], 227 | dtype=float32) 228 | 229 | Or from spectrogram input 230 | 231 | >>> S, phase = librosa.magphase(librosa.stft(y)) 232 | >>> rms = librosa.feature.rms(S=S) 233 | 234 | >>> import matplotlib.pyplot as plt 235 | >>> fig, ax = plt.subplots(nrows=2, sharex=True) 236 | >>> times = librosa.times_like(rms) 237 | >>> ax[0].semilogy(times, rms[0], label='RMS Energy') 238 | >>> ax[0].set(xticks=[]) 239 | >>> ax[0].legend() 240 | >>> ax[0].label_outer() 241 | >>> librosa.display.specshow(librosa.amplitude_to_db(S, ref=np.max), 242 | ... 
y_axis='log', x_axis='time', ax=ax[1]) 243 | >>> ax[1].set(title='log Power spectrogram') 244 | 245 | Use a STFT window of constant ones and no frame centering to get consistent 246 | results with the RMS computed from the audio samples ``y`` 247 | 248 | >>> S = librosa.magphase(librosa.stft(y, window=np.ones, center=False))[0] 249 | >>> librosa.feature.rms(S=S) 250 | >>> plt.show() 251 | 252 | """ 253 | if y is not None: 254 | if y.dim() > 1: 255 | y = torch.mean(y, dim=0) 256 | 257 | if center: 258 | y = F.pad(y[None, None], (int(frame_length//2), int(frame_length//2)), mode=pad_mode)[0, 0] 259 | # y = np.pad(y, int(frame_length // 2), mode=pad_mode) 260 | 261 | x = frame(y, frame_length=frame_length, hop_length=hop_length) 262 | # print(y.shape, x.shape, x) 263 | # Calculate power 264 | power = torch.mean(x.abs() ** 2, dim=0, keepdim=True) 265 | elif S is not None: 266 | # Check the frame length 267 | if S.shape[0] != frame_length // 2 + 1: 268 | raise AssertionError( 269 | "Since S.shape[0] is {}, " 270 | "frame_length is expected to be {} or {}; " 271 | "found {}".format( 272 | S.shape[0], S.shape[0] * 2 - 2, S.shape[0] * 2 - 1, frame_length 273 | ) 274 | ) 275 | 276 | # power spectrogram 277 | x = torch.abs(S) ** 2 278 | 279 | # Adjust the DC and sr/2 component 280 | x[0] *= 0.5 281 | if frame_length % 2 == 0: 282 | x[-1] *= 0.5 283 | 284 | # Calculate power 285 | power = 2 * torch.sum(x, dim=0, keepdim=True) / frame_length ** 2 286 | else: 287 | raise AssertionError("Either `y` or `S` must be input.") 288 | 289 | return torch.sqrt(power) 290 | 291 | 292 | def frame(x, frame_length, hop_length, axis=-1): 293 | """Slice a data array into (overlapping) frames. 294 | 295 | This implementation uses low-level stride manipulation to avoid 296 | making a copy of the data. The resulting frame representation 297 | is a new view of the same input data. 298 | 299 | However, if the input data is not contiguous in memory, a warning 300 | will be issued and the output will be a full copy, rather than 301 | a view of the input data. 302 | 303 | For example, a one-dimensional input ``x = [0, 1, 2, 3, 4, 5, 6]`` 304 | can be framed with frame length 3 and hop length 2 in two ways. 305 | The first (``axis=-1``), results in the array ``x_frames``:: 306 | 307 | [[0, 2, 4], 308 | [1, 3, 5], 309 | [2, 4, 6]] 310 | 311 | where each column ``x_frames[:, i]`` contains a contiguous slice of 312 | the input ``x[i * hop_length : i * hop_length + frame_length]``. 313 | 314 | The second way (``axis=0``) results in the array ``x_frames``:: 315 | 316 | [[0, 1, 2], 317 | [2, 3, 4], 318 | [4, 5, 6]] 319 | 320 | where each row ``x_frames[i]`` contains a contiguous slice of the input. 321 | 322 | This generalizes to higher dimensional inputs, as shown in the examples below. 323 | In general, the framing operation increments by 1 the number of dimensions, 324 | adding a new "frame axis" either to the end of the array (``axis=-1``) 325 | or the beginning of the array (``axis=0``). 326 | 327 | 328 | Parameters 329 | ---------- 330 | x : np.ndarray 331 | Array to frame 332 | 333 | frame_length : int > 0 [scalar] 334 | Length of the frame 335 | 336 | hop_length : int > 0 [scalar] 337 | Number of steps to advance between frames 338 | 339 | axis : 0 or -1 340 | The axis along which to frame. 341 | 342 | If ``axis=-1`` (the default), then ``x`` is framed along its last dimension. 343 | ``x`` must be "F-contiguous" in this case. 344 | 345 | If ``axis=0``, then ``x`` is framed along its first dimension. 
346 | ``x`` must be "C-contiguous" in this case. 347 | 348 | Returns 349 | ------- 350 | x_frames : np.ndarray [shape=(..., frame_length, N_FRAMES) or (N_FRAMES, frame_length, ...)] 351 | A framed view of ``x``, for example with ``axis=-1`` (framing on the last dimension):: 352 | 353 | x_frames[..., j] == x[..., j * hop_length : j * hop_length + frame_length] 354 | 355 | If ``axis=0`` (framing on the first dimension), then:: 356 | 357 | x_frames[j] = x[j * hop_length : j * hop_length + frame_length] 358 | 359 | Raises 360 | ------ 361 | ParameterError 362 | If ``x`` is not an `np.ndarray`. 363 | 364 | If ``x.shape[axis] < frame_length``, there is not enough data to fill one frame. 365 | 366 | If ``hop_length < 1``, frames cannot advance. 367 | 368 | If ``axis`` is not 0 or -1. Framing is only supported along the first or last axis. 369 | 370 | 371 | See Also 372 | -------- 373 | numpy.asfortranarray : Convert data to F-contiguous representation 374 | numpy.ascontiguousarray : Convert data to C-contiguous representation 375 | numpy.ndarray.flags : information about the memory layout of a numpy `ndarray`. 376 | 377 | Examples 378 | -------- 379 | Extract 2048-sample frames from monophonic signal with a hop of 64 samples per frame 380 | 381 | >>> y, sr = librosa.load(librosa.ex('trumpet')) 382 | >>> frames = librosa.util.frame(y, frame_length=2048, hop_length=64) 383 | >>> frames 384 | array([[-1.407e-03, -2.604e-02, ..., -1.795e-05, -8.108e-06], 385 | [-4.461e-04, -3.721e-02, ..., -1.573e-05, -1.652e-05], 386 | ..., 387 | [ 7.960e-02, -2.335e-01, ..., -6.815e-06, 1.266e-05], 388 | [ 9.568e-02, -1.252e-01, ..., 7.397e-06, -1.921e-05]], 389 | dtype=float32) 390 | >>> y.shape 391 | (117601,) 392 | 393 | >>> frames.shape 394 | (2048, 1806) 395 | 396 | Or frame along the first axis instead of the last: 397 | 398 | >>> frames = librosa.util.frame(y, frame_length=2048, hop_length=64, axis=0) 399 | >>> frames.shape 400 | (1806, 2048) 401 | 402 | Frame a stereo signal: 403 | 404 | >>> y, sr = librosa.load(librosa.ex('trumpet', hq=True), mono=False) 405 | >>> y.shape 406 | (2, 117601) 407 | >>> frames = librosa.util.frame(y, frame_length=2048, hop_length=64) 408 | (2, 2048, 1806) 409 | 410 | Carve an STFT into fixed-length patches of 32 frames with 50% overlap 411 | 412 | >>> y, sr = librosa.load(librosa.ex('trumpet')) 413 | >>> S = np.abs(librosa.stft(y)) 414 | >>> S.shape 415 | (1025, 230) 416 | >>> S_patch = librosa.util.frame(S, frame_length=32, hop_length=16) 417 | >>> S_patch.shape 418 | (1025, 32, 13) 419 | >>> # The first patch contains the first 32 frames of S 420 | >>> np.allclose(S_patch[:, :, 0], S[:, :32]) 421 | True 422 | >>> # The second patch contains frames 16 to 16+32=48, and so on 423 | >>> np.allclose(S_patch[:, :, 1], S[:, 16:48]) 424 | True 425 | """ 426 | 427 | # if not isinstance(x, np.ndarray): 428 | # raise AssertionError( 429 | # "Input must be of type numpy.ndarray, " "given type(x)={}".format(type(x)) 430 | # ) 431 | x: torch.Tensor = x 432 | 433 | if x.shape[axis] < frame_length: 434 | raise AssertionError( 435 | "Input is too short (n={:d})" 436 | " for frame_length={:d}".format(x.shape[axis], frame_length) 437 | ) 438 | 439 | if hop_length < 1: 440 | raise AssertionError("Invalid hop_length: {:d}".format(hop_length)) 441 | 442 | if axis == -1 and not x.is_contiguous(): 443 | warnings.warn( 444 | "librosa.util.frame called with axis={} " 445 | "on a non-contiguous input. 
This will result in a copy.".format(axis) 446 | ) 447 | x = x.contiguous() 448 | elif axis == 0 and not x.is_contiguous(): 449 | warnings.warn( 450 | "librosa.util.frame called with axis={} " 451 | "on a non-contiguous input. This will result in a copy.".format(axis) 452 | ) 453 | x = x.contiguous() 454 | 455 | n_frames = 1 + (x.shape[axis] - frame_length) // hop_length 456 | strides = torch.asarray(x.numpy().strides) 457 | # print(strides, x) 458 | new_stride = torch.prod(strides[strides > 0] // x.itemsize) * x.itemsize 459 | 460 | if axis == -1: 461 | shape = list(x.shape)[:-1] + [frame_length, n_frames] 462 | strides = list(strides) + [hop_length * new_stride] 463 | 464 | elif axis == 0: 465 | shape = [n_frames, frame_length] + list(x.shape)[1:] 466 | strides = [hop_length * new_stride] + list(strides) 467 | 468 | else: 469 | raise AssertionError("Frame axis={} must be either 0 or -1".format(axis)) 470 | 471 | return torch.from_numpy(as_strided(x, shape=shape, strides=strides)) 472 | # return x.as_strided(size=shape, stride=strides) 473 | 474 | 475 | 476 | class DummyArray: 477 | """Dummy object that just exists to hang __array_interface__ dictionaries 478 | and possibly keep alive a reference to a base array. 479 | """ 480 | 481 | def __init__(self, interface, base=None): 482 | self.__array_interface__ = interface 483 | self.base = base 484 | 485 | 486 | 487 | def as_strided(x, shape=None, strides=None, subok=False, writeable=True): 488 | """ 489 | Create a view into the array with the given shape and strides. 490 | 491 | .. warning:: This function has to be used with extreme care, see notes. 492 | 493 | Parameters 494 | ---------- 495 | x : ndarray 496 | Array to create a new view into. 497 | shape : sequence of int, optional 498 | The shape of the new array. Defaults to ``x.shape``. 499 | strides : sequence of int, optional 500 | The strides of the new array. Defaults to ``x.strides``. 501 | subok : bool, optional 502 | .. versionadded:: 1.10 503 | 504 | If True, subclasses are preserved. 505 | writeable : bool, optional 506 | .. versionadded:: 1.12 507 | 508 | If set to False, the returned array will always be readonly. 509 | Otherwise it will be writable if the original array was. It 510 | is advisable to set this to False if possible (see Notes). 511 | 512 | Returns 513 | ------- 514 | view : ndarray 515 | 516 | See Also 517 | -------- 518 | broadcast_to : broadcast an array to a given shape. 519 | reshape : reshape an array. 520 | lib.stride_tricks.sliding_window_view : 521 | user-friendly and safe function for the creation of sliding window views. 522 | 523 | Notes 524 | ----- 525 | ``as_strided`` creates a view into the array given the exact strides 526 | and shape. This means it manipulates the internal data structure of 527 | ndarray and, if done incorrectly, the array elements can point to 528 | invalid memory and can corrupt results or crash your program. 529 | It is advisable to always use the original ``x.strides`` when 530 | calculating new strides to avoid reliance on a contiguous memory 531 | layout. 532 | 533 | Furthermore, arrays created with this function often contain self 534 | overlapping memory, so that two elements are identical. 535 | Vectorized write operations on such arrays will typically be 536 | unpredictable. They may even give different results for small, large, 537 | or transposed arrays. 538 | Since writing to these arrays has to be tested and done with great 539 | care, you may want to use ``writeable=False`` to avoid accidental write 540 | operations.
541 | 542 | For these reasons it is advisable to avoid ``as_strided`` when 543 | possible. 544 | """ 545 | # first convert input to array, possibly keeping subclass 546 | x = np.array(x, copy=False, subok=subok) 547 | interface = dict(x.__array_interface__) 548 | if shape is not None: 549 | interface['shape'] = tuple(shape) 550 | if strides is not None: 551 | interface['strides'] = tuple(strides) 552 | 553 | array = np.asarray(DummyArray(interface, base=x)) 554 | # The route via `__interface__` does not preserve structured 555 | # dtypes. Since dtype should remain unchanged, we set it explicitly. 556 | array.dtype = x.dtype 557 | 558 | view = _maybe_view_as_subclass(x, array) 559 | 560 | if view.flags.writeable and not writeable: 561 | view.flags.writeable = False 562 | 563 | return view 564 | 565 | 566 | def _maybe_view_as_subclass(original_array, new_array): 567 | if type(original_array) is not type(new_array): 568 | # if input was an ndarray subclass and subclasses were OK, 569 | # then view the result as that subclass. 570 | new_array = new_array.view(type=type(original_array)) 571 | # Since we have done something akin to a view from original_array, we 572 | # should let the subclass finalize (if it has it implemented, i.e., is 573 | # not None). 574 | if new_array.__array_finalize__: 575 | new_array.__array_finalize__(original_array) 576 | return new_array 577 | 578 | 579 | def power_to_db(S, ref=1.0, amin=1e-10, top_db=80.0): 580 | """Convert a power spectrogram (amplitude squared) to decibel (dB) units 581 | 582 | This computes the scaling ``10 * log10(S / ref)`` in a numerically 583 | stable way. 584 | 585 | Parameters 586 | ---------- 587 | S : np.ndarray 588 | input power 589 | 590 | ref : scalar or callable 591 | If scalar, the amplitude ``abs(S)`` is scaled relative to ``ref``:: 592 | 593 | 10 * log10(S / ref) 594 | 595 | Zeros in the output correspond to positions where ``S == ref``. 596 | 597 | If callable, the reference value is computed as ``ref(S)``. 598 | 599 | amin : float > 0 [scalar] 600 | minimum threshold for ``abs(S)`` and ``ref`` 601 | 602 | top_db : float >= 0 [scalar] 603 | threshold the output at ``top_db`` below the peak: 604 | ``max(10 * log10(S)) - top_db`` 605 | 606 | Returns 607 | ------- 608 | S_db : np.ndarray 609 | ``S_db ~= 10 * log10(S) - 10 * log10(ref)`` 610 | 611 | See Also 612 | -------- 613 | perceptual_weighting 614 | db_to_power 615 | amplitude_to_db 616 | db_to_amplitude 617 | 618 | Notes 619 | ----- 620 | This function caches at level 30. 
621 | 622 | 623 | Examples 624 | -------- 625 | Get a power spectrogram from a waveform ``y`` 626 | 627 | >>> y, sr = librosa.load(librosa.ex('trumpet')) 628 | >>> S = np.abs(librosa.stft(y)) 629 | >>> librosa.power_to_db(S**2) 630 | array([[-41.809, -41.809, ..., -41.809, -41.809], 631 | [-41.809, -41.809, ..., -41.809, -41.809], 632 | ..., 633 | [-41.809, -41.809, ..., -41.809, -41.809], 634 | [-41.809, -41.809, ..., -41.809, -41.809]], dtype=float32) 635 | 636 | Compute dB relative to peak power 637 | 638 | >>> librosa.power_to_db(S**2, ref=np.max) 639 | array([[-80., -80., ..., -80., -80.], 640 | [-80., -80., ..., -80., -80.], 641 | ..., 642 | [-80., -80., ..., -80., -80.], 643 | [-80., -80., ..., -80., -80.]], dtype=float32) 644 | 645 | Or compare to median power 646 | 647 | >>> librosa.power_to_db(S**2, ref=np.median) 648 | array([[16.578, 16.578, ..., 16.578, 16.578], 649 | [16.578, 16.578, ..., 16.578, 16.578], 650 | ..., 651 | [16.578, 16.578, ..., 16.578, 16.578], 652 | [16.578, 16.578, ..., 16.578, 16.578]], dtype=float32) 653 | 654 | 655 | And plot the results 656 | 657 | >>> import matplotlib.pyplot as plt 658 | >>> fig, ax = plt.subplots(nrows=2, sharex=True, sharey=True) 659 | >>> imgpow = librosa.display.specshow(S**2, sr=sr, y_axis='log', x_axis='time', 660 | ... ax=ax[0]) 661 | >>> ax[0].set(title='Power spectrogram') 662 | >>> ax[0].label_outer() 663 | >>> imgdb = librosa.display.specshow(librosa.power_to_db(S**2, ref=np.max), 664 | ... sr=sr, y_axis='log', x_axis='time', ax=ax[1]) 665 | >>> ax[1].set(title='Log-Power spectrogram') 666 | >>> fig.colorbar(imgpow, ax=ax[0]) 667 | >>> fig.colorbar(imgdb, ax=ax[1], format="%+2.0f dB") 668 | """ 669 | 670 | S = torch.asarray(S) 671 | 672 | if amin <= 0: 673 | raise AssertionError("amin must be strictly positive") 674 | 675 | # if np.issubdtype(S.dtype, np.complexfloating): 676 | # warnings.warn( 677 | # "power_to_db was called on complex input so phase " 678 | # "information will be discarded. To suppress this warning, " 679 | # "call power_to_db(np.abs(D)**2) instead." 680 | # ) 681 | # magnitude = np.abs(S) 682 | # else: 683 | magnitude = S 684 | 685 | if callable(ref): 686 | # User supplied a function to calculate reference power 687 | ref_value = ref(magnitude) 688 | else: 689 | ref_value = torch.abs(ref) 690 | 691 | log_spec = 10.0 * torch.log10(torch.maximum(torch.tensor(amin), magnitude)) 692 | log_spec -= 10.0 * torch.log10(torch.maximum(torch.tensor(amin), ref_value)) 693 | 694 | if top_db is not None: 695 | if top_db < 0: 696 | raise AssertionError("top_db must be non-negative") 697 | log_spec = torch.maximum(log_spec, log_spec.max() - top_db) 698 | 699 | return log_spec 700 | 701 | 702 | def frames_to_samples(frames, hop_length=512, n_fft=None): 703 | """Converts frame indices to audio sample indices. 704 | 705 | Parameters 706 | ---------- 707 | frames : number or np.ndarray [shape=(n,)] 708 | frame index or vector of frame indices 709 | 710 | hop_length : int > 0 [scalar] 711 | number of samples between successive frames 712 | 713 | n_fft : None or int > 0 [scalar] 714 | Optional: length of the FFT window. 715 | If given, time conversion will include an offset of ``n_fft // 2`` 716 | to counteract windowing effects when using a non-centered STFT. 
717 | 718 | Returns 719 | ------- 720 | times : number or np.ndarray 721 | time (in samples) of each given frame number:: 722 | 723 | times[i] = frames[i] * hop_length 724 | 725 | See Also 726 | -------- 727 | frames_to_time : convert frame indices to time values 728 | samples_to_frames : convert sample indices to frame indices 729 | 730 | Examples 731 | -------- 732 | >>> y, sr = librosa.load(librosa.ex('choice')) 733 | >>> tempo, beats = librosa.beat.beat_track(y, sr=sr) 734 | >>> beat_samples = librosa.frames_to_samples(beats) 735 | """ 736 | 737 | offset = 0 738 | if n_fft is not None: 739 | offset = int(n_fft // 2) 740 | 741 | return (torch.asarray(frames) * hop_length + offset).to(torch.int) -------------------------------------------------------------------------------- /mars5/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | def length_to_mask(length, offsets, max_len=None): 5 | """ 6 | Convert tensor of lengths into a mask. 7 | 8 | Args: 9 | length (Tensor): a tensor of lengths, shape = (batch_size,) 10 | offsets (Tensor): a tensor of offsets, shape = (batch_size,) 11 | max_len (int, optional): maximum length to be considered 12 | 13 | Returns: 14 | mask (Tensor): a mask tensor, shape = (batch_size, max_len), 15 | True in masked positions, False otherwise. 16 | """ 17 | # get the batch size 18 | batch_size = length.size(0) 19 | 20 | # if maximum length is not provided, then compute it from the 'length' tensor. 21 | if max_len is None: 22 | max_len = length.max().item() 23 | 24 | # Create a tensor of size `(batch_size, max_len)` filled with `True`. 25 | mask = torch.ones(size=(batch_size, max_len), dtype=torch.bool, device=length.device) 26 | 27 | # Create a tensor with consecutive numbers. 28 | range_tensor = torch.arange(max_len, device=length.device) 29 | 30 | # Expand the dim of 'length' tensor and 'offset' tensor to make it `(batch_size, max_len)`. 31 | # The added dimension will be used for broadcasting. 32 | length_exp = length.unsqueeze(-1) 33 | offsets_exp = offsets.unsqueeze(-1) 34 | 35 | # Create a boolean mask where `False` represents valid positions and `True` represents padding. 36 | mask = (range_tensor < offsets_exp) | (~(range_tensor < length_exp)) 37 | 38 | return mask 39 | 40 | 41 | def construct_padding_mask(input_tensor, pad_token): 42 | return (input_tensor == pad_token).cumsum(dim=1) > 0 43 | 44 | 45 | def nuke_weight_norm(module): 46 | """ 47 | Recursively remove weight normalization from a module and its children. 48 | 49 | Args: 50 | module (torch.nn.Module): The module from which to remove weight normalization. 51 | """ 52 | # Remove weight norm from current module if it exists 53 | try: 54 | torch.nn.utils.remove_weight_norm(module) 55 | logging.debug(f"Removed weight norm from {module.__class__.__name__}") 56 | except ValueError: 57 | # Ignore if the module does not have weight norm applied. 
58 | pass 59 | 60 | # Recursively call the function on children modules 61 | for child in module.children(): 62 | nuke_weight_norm(child) 63 | -------------------------------------------------------------------------------- /mars5_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "!pip install --upgrade vocos encodec librosa safetensors" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 16, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pprint\n", 19 | "import IPython.display as ipd\n", 20 | "import torch\n", 21 | "import librosa" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# load model\n", 31 | "mars5, config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "Now that the model is loaded, pick a reference audio to clone from. If you want to use deep clone, also specify its transcript. " 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "# download example ref audio\n", 48 | "!wget -O example.wav https://github.com/Camb-ai/mars5-tts/raw/master/docs/assets/example_ref.wav " 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "wav, sr = librosa.load('./example.wav', \n", 58 | " sr=mars5.sr, mono=True)\n", 59 | "wav = torch.from_numpy(wav)\n", 60 | "ref_transcript = \"We actually haven't managed to meet demand.\"\n", 61 | "print(\"Reference audio:\")\n", 62 | "ipd.display(ipd.Audio(wav.numpy(), rate=mars5.sr))\n", 63 | "print(f\"Reference transcript: {ref_transcript}\")" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "deep_clone = True # set to False if you don't know prompt transcript or want fast inference.\n", 73 | "# Below you can tune other inference settings, like top_k, temperature, top_p, etc...\n", 74 | "cfg = config_class(deep_clone=deep_clone, rep_penalty_window=100,\n", 75 | " top_k=100, temperature=0.7, freq_penalty=3)\n", 76 | "\n", 77 | "ar_codes, wav_out = mars5.tts(\"The quick brown rat.\", wav, \n", 78 | " ref_transcript,\n", 79 | " cfg=cfg)\n", 80 | "\n", 81 | "print('Synthesized output audio:')\n", 82 | "ipd.Audio(wav_out.numpy(), rate=mars5.sr)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "You can see all the inference settings available to tune in the inference config here:" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "pprint.pprint(config_class())" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "You can also listen to the vocoded raw coarse codes, for debugging purposes:" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "ar_wav = mars5.vocode(ar_codes.cpu()[:, None])\n", 115 | "ipd.Audio(ar_wav.numpy(), rate=mars5.sr)" 116 | ] 117 | } 118 | ], 119 | "metadata": 
{ 120 | "kernelspec": { 121 | "display_name": "matt-py311", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.11.9" 136 | } 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 2 140 | } 141 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | numpy==1.26.4 5 | regex 6 | librosa 7 | vocos 8 | encodec 9 | safetensors 10 | --------------------------------------------------------------------------------
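Usage sketch (not part of the repository): a minimal example of how the locally typical warp `apply_typical_p` shown earlier (it appears in the mars5/samplers.py section of this listing) might be applied at a single decoding step. The (batch, vocab) shape, the mass value, and the multinomial draw are illustrative assumptions; only that one function from the listing is used.

import torch
from mars5.samplers import apply_typical_p  # from the samplers section of this listing

# Stand-in (batch, vocab) log-probabilities for one decoding step.
logits = torch.randn(1, 1024)
logprobs = torch.log_softmax(logits, dim=-1)

# Warp towards locally typical tokens; bins outside the typical set are filled with -inf.
warped = apply_typical_p(logprobs, mass=0.95)

# Renormalise and sample the next token.
probs = torch.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)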
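Usage sketch (not part of the repository): the snippet below exercises the numba-free trim and rms ports from mars5/trim.py on their own. The path ./example.wav assumes the reference clip that the demo notebook above downloads, and librosa is used here only to read the file; everything else is the ported code with its default arguments.

import librosa
import torch
from mars5.trim import trim, rms  # ports defined in mars5/trim.py above

# Load the reference clip as a mono torch tensor (path is the demo notebook's download target).
wav, sr = librosa.load('./example.wav', sr=None, mono=True)
wav = torch.from_numpy(wav)

# Strip leading/trailing silence more than 60 dB below the peak power (the trim() defaults).
trimmed, idx = trim(wav, top_db=60)
print(f'kept samples {int(idx[0])}..{int(idx[1])} of {wav.shape[-1]}')

# Frame-wise RMS energy of the trimmed signal, shape (1, n_frames).
energy = rms(y=trimmed, frame_length=2048, hop_length=512)
print(energy.shape)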
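Usage sketch (not part of the repository): a tiny illustration of the padding-mask helper length_to_mask from mars5/utils.py; the lengths and offsets below are made-up values.

import torch
from mars5.utils import length_to_mask  # defined in mars5/utils.py above

lengths = torch.tensor([5, 3])  # valid length per batch item
offsets = torch.tensor([1, 0])  # leading positions to mask per batch item
mask = length_to_mask(lengths, offsets)
# mask marks padded/ignored positions with True:
# tensor([[ True, False, False, False, False],
#         [False, False, False,  True,  True]])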