├── .bumpversion.cfg
├── .circleci
│   └── config.yml
├── .gitignore
├── LICENSE
├── README.md
├── export.py
├── musiclang_predict
│   ├── __init__.py
│   ├── c
│   │   ├── Makefile
│   │   ├── run.c
│   │   └── run.h
│   ├── chelpers.py
│   ├── corpus.py
│   ├── corpus
│   │   ├── bach_847.mid
│   │   ├── bob_marley_jammin.mid
│   │   ├── boney_m_ma_baker.mid
│   │   ├── mozart_alla_turca.mid
│   │   └── white_stripes_seven_nation_army.mid
│   ├── predict.py
│   └── tokenizers
│       ├── __init__.py
│       ├── bpe_iterator.py
│       ├── bpe_tokenizer.py
│       ├── chord_tokenizer.py
│       ├── template_extractor.py
│       └── tokenizer.py
├── requirements.txt
├── setup.py
└── tests
    ├── __init__.py
    └── test_tokenizer.py
--------------------------------------------------------------------------------
/.bumpversion.cfg:
--------------------------------------------------------------------------------
 1 | [bumpversion]
 2 | current_version = 1.1.1
 3 | commit = True
 4 | message = Bump version: {current_version} → {new_version} [skip ci]
 5 | tag = True
 6 | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)?
 7 | serialize =
 8 | 	{major}.{minor}.{patch}
 9 | 	{major}.{minor}.{patch}
10 | 
11 | [bumpversion:part:release]
12 | optional_value = prod
13 | first_value = dev
14 | values =
15 | 	dev
16 | 	prod
17 | 
18 | [bumpversion:file:pyproject.toml]
19 | 
20 | [bumpversion:file:setup.py]
21 | 
--------------------------------------------------------------------------------
/.circleci/config.yml:
--------------------------------------------------------------------------------
 1 | # CircleCI configuration file
 2 | version: 2.1
 3 | 
 4 | 
 5 | orbs:
 6 |   python: circleci/python@2.1.1
 7 | 
 8 | jobs:
 9 |   test: # this can be any name you choose
10 |     executor: python/default
11 |     steps:
12 |       - checkout # checkout source code
13 |       - python/install-packages:
14 |           pkg-manager: pip
15 |       - run: pip install pytest
16 |       - run:
17 |           name: Run tests
18 |           command: python -m pytest
19 |       - persist_to_workspace:
20 |           root: ~/project
21 |           paths:
22 |             - .
23 |   deploy-on-pypi:
24 |     executor: python/default
25 |     steps:
26 |       - checkout
27 |       - run: git config user.email $EMAIL
28 |       - run: git config user.name "MusicLang bot"
29 |       - run: pip install bumpversion
30 |       - run: bumpversion patch
31 |       - run: git push
32 |       - python/dist
33 |       - run: ls dist/
34 |       - run: pip install twine && twine upload dist/* --non-interactive --verbose
35 | 
36 |   deploy-on-pypi-minor:
37 |     executor: python/default
38 |     steps:
39 |       - checkout
40 |       - run: git config user.email $EMAIL
41 |       - run: git config user.name "MusicLang bot"
42 |       - run: pip install bumpversion
43 |       - run: bumpversion minor
44 |       - run: git push
45 |       - python/dist
46 |       - run: ls dist/
47 |       - run: pip install twine && twine upload dist/* --non-interactive --verbose
48 | 
49 | 
50 |   deploy-on-pypi-major:
51 |     executor: python/default
52 |     steps:
53 |       - checkout
54 |       - run: git config user.email $EMAIL
55 |       - run: git config user.name "MusicLang bot"
56 |       - run: pip install bumpversion
57 |       - run: bumpversion major
58 |       - run: git push
59 |       - python/dist
60 |       - run: ls dist/
61 |       - run: pip install twine && twine upload dist/* --non-interactive --verbose
62 | 
63 | 
64 | workflows:
65 |   version: 2
66 |   build-test-and-deploy:
67 |     jobs:
68 |       - test
69 |       - approve-deployment:
70 |           requires:
71 |             - test
72 |           type: approval
73 |           filters:
74 |             branches:
75 |               only: main
76 |       - approve-deployment-minor:
77 |           requires:
78 |             - test
79 |           type: approval
80 |           filters:
81 |             branches:
82 |               only: main
83 | 
84 |       - approve-deployment-major:
85 |           requires:
86 |             - test
87 |           type: approval
88 |           filters:
89 |             branches:
90 |               only: main
91 | 
92 |       - deploy-on-pypi:
93 |           requires:
94 |             - approve-deployment
95 |           filters:
96 |             branches:
97 |               only: main
98 | 
99 |       - deploy-on-pypi-minor:
100 |           requires:
101 |             - approve-deployment-minor
102 |           filters:
103 |             branches:
104 |               only: main
105 | 
106 |       - deploy-on-pypi-major:
107 |           requires:
108 |             - approve-deployment-major
109 |           filters:
110 |             branches:
111 |               only: main
112 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | *.wav
 2 | *.idea
 3 | .ipynb_checkpoints/
 4 | *.ipynb
 5 | *__pycache__*
 6 | *.band
 7 | *.band*
 8 | *pytest_cache*
 9 | .pytest_cache/
10 | *.ogg
11 | *.flac
12 | *.so
13 | musiclang_predict.egg-info/*
14 | main.py
15 | *.mid
16 | !musiclang_predict/corpus/*.mid
17 | musiclang_predict.egg-info/
18 | *.midi
19 | venv/
20 | **.DS_Store
21 | notebooks/data
22 | data
23 | locals/
24 | *.pyc
25 | build/
26 | musiclang.egg-info/
27 | documentation/build/
28 | documentation/_build/
29 | **.pickle
30 | dist/
31 | documentation/auto_examples/
32 | *.ipynb
33 | main*.py
34 | build.sh
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 |                     GNU GENERAL PUBLIC LICENSE
 2 |                        Version 3, 29 June 2007
 3 | 
 4 |  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 5 |  Everyone is permitted to copy and distribute verbatim copies
 6 |  of this license document, but changing it is not allowed.
 7 | 
 8 |                             Preamble
 9 | 
10 |   The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 | 
13 |   The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works.
By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 
76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. 
However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 
196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 
309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 
360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 
476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 
652 |   If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 | 
655 |     <program>  Copyright (C) <year>  <name of author>
656 |     This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 |     This is free software, and you are welcome to redistribute it
658 |     under certain conditions; type `show c' for details.
659 | 
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License.  Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 | 
664 |   You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 | 
669 |   The GNU General Public License does not permit incorporating your program
670 | into proprietary programs.  If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library.  If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License.  But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![MusicLang logo](https://github.com/MusicLang/musiclang/blob/main/documentation/images/MusicLang.png?raw=true "MusicLang")
2 | 
3 | 

MusicLang Predict, your controllable music copilot.

4 | 5 |

6 | 🤗 HuggingFace | 7 | Discord | 8 | Follow us! 9 |

10 |
11 | 12 |

☞ You want to generate music that you can export to your favourite DAW as MIDI?

13 |

☞ You want to control the chord progression of the generated music?

14 |

☞ You need it to run fast on your laptop, without a GPU?

15 | 16 |
17 |

MusicLang is a contraction of “music” & “language”: we bring advanced controllability to music generation by manipulating symbolic music.

18 |
19 | 
20 | Table of contents
21 | 
22 | - Quickstart 🚀
23 |   - Try it quickly 📙
24 |   - Install MusicLang ♫
25 | - Examples 🎹
26 |   - 1. Generate your first music 🕺
27 |   - 2. Controlling chord progression generation 🪩
28 |   - 3. Generation from an existing music 💃
29 | - What's coming next at MusicLang? 👀
30 | - How does MusicLang work? 🔬
31 |   - 1. Annotate chords and scales progression of MIDIs using MusicLang analysis
32 |   - 2. The MusicLang tokenizer: Toward controllable symbolic music generation
33 |   - 3. Examples of sound made with MusicLang ❤️
34 | - Contributing & spread the word 🤝
35 | - License ⚖️
36 | 
37 | 

## Quickstart 🚀

38 |

### Try it quickly 📙

39 |
40 | 
41 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MA2mek826c05BjbWk2nRkVv2rW7kIU_S?usp=sharing)
42 | [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/musiclang/musiclang-predict)
43 | 
44 | Go to our ♾ Colab or to our 🤗 HuggingFace space: both have a lot of cool examples, from generating creative musical ideas to continuing a song with a specified chord progression.
45 | 
46 | 47 |

### Install MusicLang ♫

48 |
49 | 
50 | Install the `musiclang-predict` package:
51 | 
52 | ```bash
53 | pip install musiclang_predict
54 | ```
55 | 

## Examples 🎹

56 | 57 |

### 1. Generate your first music 🕺

58 |
59 | 
60 | Open your favourite notebook and start generating music in a few lines:
61 | 
62 | ```python
63 | from musiclang_predict import MusicLangPredictor
64 | nb_tokens = 1024
65 | temperature = 0.9  # Don't go over 1.0, at your own risk!
66 | top_p = 1.0  # <= 1.0; usually 1.0 is best to avoid overly repetitive music
67 | seed = 16  # change to vary the result, or set to 0 to unset the seed
68 | 
69 | ml = MusicLangPredictor('musiclang/musiclang-v2')  # Only available model for now
70 | 
71 | score = ml.predict(
72 |     nb_tokens=nb_tokens,  # 1024 tokens ~ 25s of music (depending on the number of instruments generated)
73 |     temperature=temperature,
74 |     topp=top_p,
75 |     rng_seed=seed  # change to vary the result, or set to 0 to unset the seed
76 | )
77 | score.to_midi('test.mid')  # Open that file in your favourite DAW, score editor or even in VLC
78 | ```
79 | 
80 | 
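Want a few different takes to pick from? Here is a minimal sketch, using only the calls shown above, that renders several candidates by varying the seed (the file names are arbitrary):

```python
from musiclang_predict import MusicLangPredictor

ml = MusicLangPredictor('musiclang/musiclang-v2')

# Any nonzero rng_seed makes a run reproducible, so each seed gives one
# repeatable candidate; 512 tokens keeps the takes short.
for seed in (16, 17, 18):
    score = ml.predict(nb_tokens=512, temperature=0.9, topp=1.0, rng_seed=seed)
    score.to_midi(f'candidate_{seed}.mid')
```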

### 2. Controlling chord progression generation 🪩

81 |
82 | 
83 | You had a specific harmony in mind, right? MusicLang allows fine control over the chord progression of the generated music.
84 | Just specify it as a string like below, choose a time signature, and let the magic happen.
85 | 
86 | ```python
87 | from musiclang_predict import MusicLangPredictor
88 | 
89 | # Control the chord progression
90 | # Chord qualities available: M, m, 7, m7b5, sus2, sus4, m7, M7, dim, dim0.
91 | # You can also specify the bass if it belongs to the chord (e.g. Bm/D)
92 | chord_progression = "Am CM Dm E7 Am"  # 1 chord = 1 bar
93 | time_signature = (4, 4)  # 4/4 time signature, don't be too crazy here
94 | nb_tokens = 1024
95 | temperature = 0.8
96 | top_p = 1.0
97 | seed = 42
98 | 
99 | ml = MusicLangPredictor('musiclang/musiclang-v2')
100 | 
101 | score = ml.predict_chords(
102 |     chord_progression,
103 |     time_signature=time_signature,
104 |     temperature=temperature,
105 |     topp=top_p,
106 |     rng_seed=seed  # set to 0 to unset the seed
107 | )
108 | score.to_midi('test.mid', tempo=120, time_signature=(4, 4))
109 | ```
110 | 
111 | > Disclaimer: the chord progression is not guaranteed to be exactly the one you specified; it's a generative model, after all. Deviations are more frequent with an exotic chord progression or a high temperature.
112 | 
113 | 
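If you build progressions programmatically, a small helper keeps the strings well-formed. The helper below is hypothetical (not part of the package) and relies only on the format documented above: one chord per bar, qualities such as M, m or 7, and an optional bass note after a slash:

```python
# Hypothetical helper, not part of musiclang_predict: build a progression
# string in the documented "1 chord = 1 bar" format.
def make_progression(chords):
    parts = []
    for root, quality, bass in chords:
        token = f"{root}{quality}"
        if bass:  # the bass must belong to the chord, e.g. Bm/D
            token += f"/{bass}"
        parts.append(token)
    return " ".join(parts)

progression = make_progression([
    ("A", "m", None), ("C", "M", None), ("D", "m", None),
    ("E", "7", None), ("A", "m", None),
])
assert progression == "Am CM Dm E7 Am"
```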

### 3. Generation from an existing music 💃

114 |
115 | 
116 | What if you want to use MusicLang with an existing piece of music? Don't worry, we've got you covered. You can use your music as a template to generate new music.
117 | Let's continue with some Bach music and explore a chord progression he might have used:
118 | ```python
119 | from musiclang_predict import MusicLangPredictor
120 | from musiclang_predict import corpus
121 | 
122 | song_name = 'bach_847'  # corpus.list_corpus() to get the list of available songs
123 | chord_progression = "Cm C7/E Fm F#dim G7 Cm"
124 | nb_tokens = 1024
125 | temperature = 0.8
126 | top_p = 1.0
127 | seed = 3666
128 | 
129 | ml = MusicLangPredictor('musiclang/musiclang-v2')
130 | 
131 | score = ml.predict_chords(
132 |     chord_progression,
133 |     score=corpus.get_midi_path_from_corpus(song_name),
134 |     time_signature=(4, 4),
135 |     nb_tokens=1024,
136 |     prompt_chord_range=(0, 4),
137 |     temperature=temperature,
138 |     topp=top_p,
139 |     rng_seed=seed  # set to 0 to unset the seed
140 | )
141 | 
142 | score.to_midi('test.mid', tempo=110, time_signature=(4, 4))
143 | ```
144 | 
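The bundled corpus is just a folder of MIDI files, so you can browse it before picking a template. A short sketch built on the two corpus helpers mentioned in the example above:

```python
from musiclang_predict import corpus

# The comment in the example above mentions corpus.list_corpus();
# get_midi_path_from_corpus() then resolves a song name to its MIDI file.
print(corpus.list_corpus())
print(corpus.get_midi_path_from_corpus('bach_847'))
```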

## What's coming next at MusicLang? 👀

146 |
147 | 
148 | We are working on a lot of cool features; some are already encoded in the model:
149 | - Control over the instruments used in each bar and their properties (note density, pitch range, average velocity);
150 | - Performance improvements in the C inference script;
151 | - A faster distilled model for real-time generation that can be embedded in plugins or mobile applications;
152 | - An integration into a DAW as a plugin;
153 | - Specialized smaller models depending on our users' needs;
154 | - And more to come! 😎
155 | 

## How does MusicLang work? 🔬

157 |
158 | 
159 | If you want to learn more about how we are moving toward symbolic music generation, go to our [technical blog](https://musiclang.github.io/). The tokenization and the model are described there in great detail:
160 | 
161 | 

### 1. Annotate chords and scales progression of MIDIs using MusicLang analysis

162 |

### 2. The MusicLang tokenizer: Toward controllable symbolic music generation

163 |

### 3. Examples of sound made with MusicLang ❤️

164 |
165 | 
166 | We are using a LLaMA 2 architecture (many thanks to Andrej Karpathy's awesome [llama2.c](https://github.com/karpathy/llama2.c)), trained on a large dataset of MIDI files (the CC0-licensed [LAKH](https://colinraffel.com/projects/lmd/)).
167 | We rely heavily on preprocessing the MIDI files to get an enriched tokenization that describes the chord & scale for each bar.
168 | This is also helpful for normalizing melodies relative to the current chord/scale.
169 | 
170 | 
171 | 
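To give an intuition of what normalizing melodies relative to the current chord/scale means, here is a purely illustrative sketch (this is not MusicLang's actual tokenization): instead of an absolute pitch, a note is stored as its nearest scale degree plus an octave offset, so the same melodic shape maps to the same tokens under any chord.

```python
# Illustrative only -- not MusicLang's real vocabulary. Encode a MIDI pitch
# relative to the current scale as (nearest scale degree, octave offset).
C_MINOR = [0, 2, 3, 5, 7, 8, 10]  # pitch classes of the C minor scale

def to_scale_degree(midi_pitch, scale=C_MINOR, tonic=60):
    octave, pc = divmod(midi_pitch - tonic, 12)
    degree = min(range(len(scale)), key=lambda i: abs(scale[i] - pc))
    return degree, octave

print(to_scale_degree(67))  # G4 -> (4, 0): the fifth of C minor
```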

## Contributing & spread the word 🤝

172 |
173 | 
174 | We are looking for contributors to help us improve the model, the tokenization, the performance, and the documentation.
175 | If you are interested in this project, open an issue, a pull request, or even [contact us directly](https://www.musiclang.io/contact).
176 | 
177 | Whether you're contributing code or just saying hello, we'd love to hear about the work you are creating with MusicLang. Here's how you can reach out to us:
178 | * Join our [Discord](https://discord.gg/2g7eA5vP) to ask your questions and get support
179 | * Follow us on [LinkedIn](https://www.linkedin.com/company/musiclang/)
180 | * Add your star on [GitHub](https://github.com/musiclang/musiclang_predict?tab=readme-ov-file) or [HuggingFace](https://huggingface.co/musiclang/musiclang-4k)
181 | 

## License ⚖️

183 |
184 | 185 | MusicLang Predict (This package) is licensed under the GPL-3.0 License. 186 | However please note that specific licenses applies to our models. If you would like to use the model in your commercial product, please 187 | [contact us](https://www.musiclang.io/contact). We are looking forward to hearing from you ! 188 | 189 | The MusicLang base language package on which the model rely ([musiclang package](https://github.com/musiclang/musiclang)) is licensed under the BSD 3-Clause License. 190 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script has functions and utilties for model export. 3 | Basically, we have a bunch of versions of the model, and we 4 | want to export them to .bin files to be read from and inferenced in C. 5 | 6 | Among the "input" versions of PyTorch files/models: 7 | - Official Llama 2 weights released by Meta 8 | - Huggingface weights available on the hub 9 | - llama2.c (this repo) trained models 10 | 11 | Among the "output" versions of .bin files: 12 | - v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED) 13 | - v1-vN: Improved .bin files with a proper header, cache alignment, etc. 14 | 15 | This script aspires to provide all of these conversions. 16 | """ 17 | import os 18 | import gzip 19 | import shutil 20 | import struct 21 | import argparse 22 | import json 23 | from pathlib import Path 24 | 25 | import numpy as np 26 | import torch 27 | from torch import nn 28 | 29 | from model import ModelArgs, Transformer 30 | 31 | # ----------------------------------------------------------------------------- 32 | # common utilities 33 | 34 | def serialize_fp32(file, tensor): 35 | """ writes one fp32 tensor to file that is open in wb mode """ 36 | d = tensor.detach().cpu().view(-1).to(torch.float32).numpy() 37 | b = struct.pack(f'{len(d)}f', *d) 38 | file.write(b) 39 | 40 | def serialize_int8(file, tensor): 41 | """ writes one int8 tensor to file that is open in wb mode """ 42 | d = tensor.detach().cpu().view(-1).numpy().astype(np.int8) 43 | b = struct.pack(f'{len(d)}b', *d) 44 | file.write(b) 45 | 46 | def quantize_q80(w, group_size): 47 | """ 48 | takes a tensor and returns the Q8_0 quantized version 49 | i.e. symmetric quantization into int8, range [-127,127] 50 | """ 51 | assert w.numel() % group_size == 0 52 | ori_shape = w.shape 53 | w = w.float() # convert to float32 54 | w = w.reshape(-1, group_size) 55 | # find the max in each group 56 | wmax = torch.abs(w).max(dim=1).values 57 | # calculate the scaling factor such that float = quant * scale 58 | scale = wmax / 127.0 59 | # scale into range [-127, 127] 60 | quant = w / scale[:,None] 61 | # round to nearest integer 62 | int8val = torch.round(quant).to(torch.int8) 63 | # dequantize by rescaling 64 | fp32val = (int8val.float() * scale[:,None]).view(-1) 65 | fp32valr = fp32val.reshape(-1, group_size) 66 | # calculate the max error in each group 67 | err = torch.abs(fp32valr - w).max(dim=1).values 68 | # find the max error across all groups 69 | maxerr = err.max().item() 70 | return int8val, scale, maxerr 71 | 72 | # ----------------------------------------------------------------------------- 73 | # legacy 74 | 75 | def legacy_export(model, filepath): 76 | """ Original export of llama2.c bin files, i.e. 
version v0 """ 77 | out_file = open(filepath, 'wb') 78 | 79 | # first write out the header 80 | hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] 81 | p = model.params 82 | shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight) 83 | # legacy format uses negative/positive vocab size as a shared classifier flag 84 | if not shared_classifier: 85 | p.vocab_size = -p.vocab_size 86 | n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads 87 | header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, 88 | n_kv_heads, p.vocab_size, p.max_seq_len) 89 | out_file.write(header) 90 | 91 | # next write out the embedding weights 92 | serialize_fp32(out_file, model.tok_embeddings.weight) 93 | 94 | # now all the layers 95 | # attention weights 96 | for layer in model.layers: 97 | serialize_fp32(out_file, layer.attention_norm.weight) 98 | for layer in model.layers: 99 | serialize_fp32(out_file, layer.attention.wq.weight) 100 | for layer in model.layers: 101 | serialize_fp32(out_file, layer.attention.wk.weight) 102 | for layer in model.layers: 103 | serialize_fp32(out_file, layer.attention.wv.weight) 104 | for layer in model.layers: 105 | serialize_fp32(out_file, layer.attention.wo.weight) 106 | # ffn weights 107 | for layer in model.layers: 108 | serialize_fp32(out_file, layer.ffn_norm.weight) 109 | for layer in model.layers: 110 | serialize_fp32(out_file, layer.feed_forward.w1.weight) 111 | for layer in model.layers: 112 | serialize_fp32(out_file, layer.feed_forward.w2.weight) 113 | for layer in model.layers: 114 | serialize_fp32(out_file, layer.feed_forward.w3.weight) 115 | # final rmsnorm 116 | serialize_fp32(out_file, model.norm.weight) 117 | # freqs_cis 118 | serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len]) 119 | serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len]) 120 | 121 | # final classifier weights 122 | if not shared_classifier: 123 | serialize_fp32(out_file, model.output.weight) 124 | 125 | # write to binary file 126 | out_file.close() 127 | print(f"wrote {filepath}") 128 | 129 | # ----------------------------------------------------------------------------- 130 | # new version 131 | 132 | def version1_export(model, filepath): 133 | """ 134 | Export the model weights in full float32 .bin file to be read from C. 135 | This is same as legacy_export, but with a proper header. 136 | """ 137 | version = 1 138 | 139 | out_file = open(filepath, 'wb') 140 | # first write out the header. 
the header will be 256 bytes 141 | # 1) write magic, which will be uint32 of "ak42" in ASCII 142 | out_file.write(struct.pack('I', 0x616b3432)) 143 | # 2) write version, which will be int 144 | out_file.write(struct.pack('i', version)) 145 | # 3) write the params, which will be 7 ints 146 | p = model.params 147 | hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] 148 | n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads 149 | header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, 150 | n_kv_heads, p.vocab_size, p.max_seq_len) 151 | out_file.write(header) 152 | # 4) write some other flags 153 | shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight) 154 | out_file.write(struct.pack('B', int(shared_classifier))) 155 | pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos 156 | assert pad >= 0 157 | out_file.write(b'\0' * pad) 158 | 159 | # now let's write out all the params 160 | weights = [ 161 | *[layer.attention_norm.weight for layer in model.layers], 162 | *[layer.ffn_norm.weight for layer in model.layers], 163 | model.norm.weight, 164 | model.tok_embeddings.weight, 165 | *[layer.attention.wq.weight for layer in model.layers], 166 | *[layer.attention.wk.weight for layer in model.layers], 167 | *[layer.attention.wv.weight for layer in model.layers], 168 | *[layer.attention.wo.weight for layer in model.layers], 169 | *[layer.feed_forward.w1.weight for layer in model.layers], 170 | *[layer.feed_forward.w2.weight for layer in model.layers], 171 | *[layer.feed_forward.w3.weight for layer in model.layers], 172 | ] 173 | if not shared_classifier: 174 | weights.append(model.output.weight) 175 | for w in weights: 176 | serialize_fp32(out_file, w) 177 | 178 | # write to binary file 179 | out_file.close() 180 | print(f"wrote {filepath}") 181 | 182 | def version2_export(model, filepath, group_size=64): 183 | """ 184 | Export the model weights in Q8_0 into .bin file to be read from C. 185 | That is: 186 | - quantize all weights to symmetric int8, in range [-127, 127] 187 | - all other tensors (the rmsnorm params) are kept and exported in fp32 188 | - quantization is done in groups of group_size to reduce the effects of any outliers 189 | """ 190 | version = 2 191 | 192 | # let's first do some validation for this export type 193 | while model.params.dim % group_size != 0: 194 | group_size //= 2 195 | print(f"BACKOFF: reducing group size to {group_size} to fit hidden_dim") 196 | weights = [ 197 | model.tok_embeddings.weight, 198 | *[layer.attention.wq.weight for layer in model.layers], 199 | *[layer.attention.wk.weight for layer in model.layers], 200 | *[layer.attention.wv.weight for layer in model.layers], 201 | *[layer.attention.wo.weight for layer in model.layers], 202 | *[layer.feed_forward.w1.weight for layer in model.layers], 203 | *[layer.feed_forward.w2.weight for layer in model.layers], 204 | *[layer.feed_forward.w3.weight for layer in model.layers], 205 | ] 206 | shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight) 207 | if not shared_classifier: 208 | weights.append(model.output.weight) 209 | for w in weights: 210 | assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}" 211 | 212 | # write 213 | out_file = open(filepath, 'wb') 214 | # first write out the header. 
the header will be 256 bytes 215 | # 1) write magic, which will be uint32 of "ak42" in ASCII 216 | out_file.write(struct.pack('I', 0x616b3432)) 217 | # 2) write version, which will be int 218 | out_file.write(struct.pack('i', version)) 219 | # 3) write the params, which will be 7 ints 220 | p = model.params 221 | hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] 222 | n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads 223 | header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, 224 | n_kv_heads, p.vocab_size, p.max_seq_len) 225 | out_file.write(header) 226 | # 4) write some other flags 227 | out_file.write(struct.pack('B', int(shared_classifier))) 228 | out_file.write(struct.pack('i', group_size)) # group size used for quantization 229 | pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos 230 | assert pad >= 0 231 | out_file.write(b'\0' * pad) 232 | # now that the header is done, let's write out the model 233 | 234 | # first let's write out all the params that we are keeping in fp32: the norms 235 | for layer in model.layers: # attention norms 236 | serialize_fp32(out_file, layer.attention_norm.weight) 237 | for layer in model.layers: # MLP norms 238 | serialize_fp32(out_file, layer.ffn_norm.weight) 239 | serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm 240 | 241 | # now let's write out all the params that we are quantizing to Q8_0 242 | # note we skip classifier weights, which are shared with the embedding 243 | ew = [] 244 | for i, w in enumerate(weights): 245 | # quantize this weight 246 | q, s, err = quantize_q80(w, group_size) 247 | # save the int8 weights to file 248 | serialize_int8(out_file, q) # save the tensor in int8 249 | serialize_fp32(out_file, s) # save scale factors 250 | # logging 251 | ew.append((err, w.shape)) 252 | print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}") 253 | 254 | # print the highest error across all weights, should be very small, e.g. O(~0.001) 255 | ew.sort(reverse=True) 256 | print(f"max quantization group error across all weights: {ew[0][0]}") 257 | 258 | # write to binary file 259 | out_file.close() 260 | print(f"wrote {filepath}") 261 | 262 | def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32): 263 | """ Generate the pytorch_model.bin state_dict and config.json for HuggingFace """ 264 | 265 | try: 266 | from transformers.models.llama.configuration_llama import LlamaConfig 267 | except ImportError: 268 | print("Error: transformers package is required to load huggingface models") 269 | print("Please run `pip install transformers` to install it") 270 | return None 271 | 272 | # Generate LlamaModel state_dict 273 | hf_state_dict = {} 274 | 275 | # Sometimes we have repeated key values for the heads 276 | dim = llama_model.params.dim 277 | num_key_value_heads = llama_model.params.n_kv_heads 278 | n_rep = llama_model.params.n_heads // num_key_value_heads 279 | key_value_dim = dim // n_rep 280 | 281 | # HuggingFace needs the weights permuted. 
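    # A quick sanity sketch: permute_original defined below is assumed to be the
    # inverse of permute_reverse in load_hf_model further down, i.e. for a square
    # attention weight w one would expect:
    #
    #   assert torch.equal(permute_reverse(permute_original(w)), w)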
282 | # See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122 283 | def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim): 284 | return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) 285 | 286 | # Transfer weights from llama model to the HF state dictionary format 287 | hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype) 288 | hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype) 289 | 290 | # Add each layer's weights to the HF state dictionary 291 | for i, layer in enumerate(llama_model.layers): 292 | layer_id = layer.layer_id 293 | hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype) 294 | hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype) 295 | hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype) 296 | hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype) 297 | hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype) 298 | hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype) 299 | hf_state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = llama_model.layers[layer_id].feed_forward.w1.weight.clone().to(dtype) 300 | hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype) 301 | hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype) 302 | 303 | # llama2.c usually uses tied weights -> reference the embed_tokens.weights instead 304 | hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight'] 305 | 306 | # We check that the embeddings are tied, else use manual output weights 307 | _embeddings_are_tied: bool = torch.equal(llama_model.tok_embeddings.weight, llama_model.output.weight) 308 | if not _embeddings_are_tied: 309 | hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype) 310 | 311 | 312 | # Generate LlamaConfig (seen in transformers.models.llama.configuration_llama) 313 | 314 | # Extract necessary attributes from llama.c model 315 | vocab_size = llama_model.params.vocab_size 316 | hidden_size = llama_model.params.dim 317 | intermediate_size = llama_model.layers[0].feed_forward.w1.weight.shape[0] 318 | num_hidden_layers = llama_model.params.n_layers 319 | num_attention_heads = llama_model.params.n_heads 320 | num_key_value_heads = llama_model.params.n_kv_heads 321 | max_position_embeddings = llama_model.params.max_seq_len 322 | rms_norm_eps = llama_model.params.norm_eps 323 | 324 | # TODO check values for: 325 | # pretraining_tp, initializer_range, use_cache, 326 | # rope_theta, and rope_scaling. 
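    # Once the state dict and config are saved below, a minimal round-trip check
    # could look like this (sketch; requires the transformers package):
    #
    #   from transformers import AutoModelForCausalLM
    #   reloaded = AutoModelForCausalLM.from_pretrained(filepath)
    #   assert reloaded.config.vocab_size == vocab_size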
327 | 328 | config = LlamaConfig( 329 | vocab_size=vocab_size, 330 | hidden_size=hidden_size, 331 | intermediate_size=intermediate_size, 332 | num_hidden_layers=num_hidden_layers, 333 | num_attention_heads=num_attention_heads, 334 | num_key_value_heads=num_key_value_heads, 335 | max_position_embeddings=max_position_embeddings, 336 | rms_norm_eps=rms_norm_eps, 337 | tie_word_embeddings=_embeddings_are_tied, 338 | # Manual 339 | architectures=["LlamaForCausalLM"], 340 | hidden_act="silu", 341 | ) 342 | 343 | 344 | # Save files in directory filepath 345 | # First make the directory if it doesn't exist 346 | os.makedirs(filepath, exist_ok=True) 347 | 348 | # Save the state dictionary in .bin format, and config as .json 349 | torch.save(hf_state_dict, os.path.join(filepath, "pytorch_model.bin")) 350 | config.save_pretrained(filepath) 351 | 352 | 353 | # ----------------------------------------------------------------------------- 354 | # Load / import functions 355 | 356 | def load_checkpoint(checkpoint): 357 | 358 | # load the provided model checkpoint 359 | checkpoint_dict = torch.load(checkpoint, map_location='cpu') 360 | gptconf = ModelArgs(**checkpoint_dict['model_args']) 361 | model = Transformer(gptconf) 362 | state_dict = checkpoint_dict['model'] 363 | unwanted_prefix = '_orig_mod.' 364 | for k,v in list(state_dict.items()): 365 | if k.startswith(unwanted_prefix): 366 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 367 | model.load_state_dict(state_dict, strict=False) 368 | model.eval() 369 | return model 370 | 371 | def load_meta_model(model_path): 372 | params_path = os.path.join(model_path, 'params.json') 373 | with open(params_path) as f: 374 | params = json.load(f) 375 | print(params) 376 | 377 | model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth'))) 378 | models = [torch.load(p, map_location='cpu') for p in model_paths] 379 | 380 | def concat_weights(models): 381 | state_dict = {} 382 | for name in list(models[0]): 383 | tensors = [model[name] for model in models] 384 | if len(tensors) == 1 or len(tensors[0].shape) == 1: 385 | state_dict[name] = tensors[0] 386 | continue 387 | is_axis_1 = ( 388 | name.startswith('tok_embeddings.') 389 | or name.endswith('.attention.wo.weight') 390 | or name.endswith('.feed_forward.w2.weight') 391 | ) 392 | axis = 1 if is_axis_1 else 0 393 | state_dict[name] = torch.cat(tensors, dim=axis) 394 | for model in models: 395 | del model[name] 396 | return state_dict 397 | 398 | state_dict = concat_weights(models) 399 | del models 400 | 401 | # set ModelArgs 402 | config = ModelArgs() 403 | config.dim = params["dim"] 404 | config.n_layers = params["n_layers"] 405 | config.n_heads = params["n_heads"] 406 | config.n_kv_heads = params.get('n_kv_heads') or params['n_heads'] 407 | config.multiple_of = params["multiple_of"] 408 | config.norm_eps = params["norm_eps"] 409 | 410 | config.vocab_size = state_dict['tok_embeddings.weight'].shape[0] 411 | config.max_seq_len = 2048 412 | 413 | 414 | # create a new Transformer object and set weights 415 | model = Transformer(config) 416 | 417 | model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight']) 418 | model.norm.weight = nn.Parameter(state_dict['norm.weight']) 419 | 420 | for layer in model.layers: 421 | i = layer.layer_id 422 | layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight']) 423 | layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight']) 424 | layer.attention.wk.weight = 
nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight']) 425 | layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight']) 426 | layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight']) 427 | layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight']) 428 | layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight']) 429 | layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight']) 430 | layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight']) 431 | 432 | # final classifier 433 | model.output.weight = nn.Parameter(state_dict['output.weight']) 434 | model.eval() 435 | return model 436 | 437 | def load_hf_model(model_path): 438 | 439 | try: 440 | from transformers import AutoModelForCausalLM 441 | except ImportError: 442 | print("Error: transformers package is required to load huggingface models") 443 | print("Please run `pip install transformers` to install it") 444 | return None 445 | 446 | # load HF model 447 | hf_model = AutoModelForCausalLM.from_pretrained(model_path) 448 | hf_dict = hf_model.state_dict() 449 | 450 | # convert LlamaConfig to ModelArgs 451 | config = ModelArgs() 452 | config.dim = hf_model.config.hidden_size 453 | config.n_layers = hf_model.config.num_hidden_layers 454 | config.n_heads = hf_model.config.num_attention_heads 455 | config.n_kv_heads = hf_model.config.num_attention_heads 456 | config.vocab_size = hf_model.config.vocab_size 457 | config.hidden_dim = hf_model.config.intermediate_size 458 | config.norm_eps = hf_model.config.rms_norm_eps 459 | config.max_seq_len = hf_model.config.max_position_embeddings 460 | 461 | # create a new Transformer object and set weights 462 | model = Transformer(config) 463 | 464 | model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight']) 465 | model.norm.weight = nn.Parameter(hf_dict['model.norm.weight']) 466 | 467 | # huggingface permutes WQ and WK, this function reverses it 468 | def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim): 469 | return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) 470 | 471 | for layer in model.layers: 472 | i = layer.layer_id 473 | layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight']) 474 | layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight'])) 475 | layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight'])) 476 | layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight']) 477 | layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight']) 478 | layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight']) 479 | layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight']) 480 | layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight']) 481 | layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight']) 482 | 483 | # final classifier 484 | model.output.weight = nn.Parameter(hf_dict['lm_head.weight']) 485 | model.eval() 486 | return model 487 | 488 | 489 | # ----------------------------------------------------------------------------- 490 | # API entrypoint 491 | 492 | def 
model_export(model, filepath, version, dtype=torch.float32):
493 |     """
494 |     Versions docs:
495 |     v-1: huggingface export, i.e. intended for use outside of this repo, in HF
496 |     v0: legacy llama2.c float format, DEPRECATED
497 |     v1: float32 export
498 |     v2: int8 quantized Q8_0 export, similar to llama.cpp, in groups
499 |     # TODO: add dtype export support for other versions (?)
500 |     """
501 |     if version == 0:
502 |         legacy_export(model, filepath)
503 |     elif version == 1:
504 |         version1_export(model, filepath)
505 |     elif version == 2:
506 |         version2_export(model, filepath)
507 |     elif version == -1:
508 |         hf_export(model, filepath, dtype=dtype)  # pass dtype by keyword so it doesn't land in group_size
509 |     else:
510 |         raise ValueError(f"unknown version {version}")
511 | 
512 | def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
513 |     """
514 |     (This was submitted via a PR earlier. Leaving it here, but "orphaned" for now)
515 |     Saves the model as a TorchScript.
516 |     The resulting file can be loaded in C++ code and then used for training or
517 |     inference with:
518 |         #include <torch/script.h>
519 |         torch::jit::Module module = torch::jit::load("model.pt")
520 |     Note that the serialized model includes the initial parameters and with the default
521 |     ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
522 |     the model parameters separately you can zero out the parameters before saving it and
523 |     it will gzip down to 780K.
524 |     """
525 | 
526 |     # If requested, zero params before saving the model. This is useful in
527 |     # conjunction with gzip_output.
528 |     if zero_params:
529 |         for p in model.parameters():
530 |             p.detach().zero_()
531 | 
532 |     torch.jit.save(torch.jit.script(model), filepath)
533 | 
534 |     if gzip_output:
535 |         with open(filepath, "rb") as f_in:
536 |             with gzip.open(f"{filepath}.gz", "wb") as f_out:
537 |                 shutil.copyfileobj(f_in, f_out)
538 |         os.unlink(filepath)
539 | 
540 | # -----------------------------------------------------------------------------
541 | # CLI entrypoint
542 | 
543 | if __name__ == "__main__":
544 | 
545 |     parser = argparse.ArgumentParser()
546 |     parser.add_argument("filepath", type=str, help="the output filepath")
547 |     parser.add_argument("--version", default=0, type=int, help="the version to export with")
548 |     parser.add_argument("--dtype", type=str, help="dtype of the model (fp16, fp32)", default="fp32")
549 |     group = parser.add_mutually_exclusive_group(required=True)
550 |     group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
551 |     group.add_argument("--meta-llama", type=str, help="meta llama model path")
552 |     group.add_argument("--hf", type=str, help="huggingface model path")
553 |     args = parser.parse_args()
554 |     dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
555 | 
556 |     if args.checkpoint:
557 |         model = load_checkpoint(args.checkpoint)
558 |     elif args.meta_llama:
559 |         model = load_meta_model(args.meta_llama)
560 |     elif args.hf:
561 |         model = load_hf_model(args.hf)
562 | 
563 |     if model is None:
564 |         parser.error("Can't load input model!")
565 | 
566 |     # export
567 |     model_export(model, args.filepath, args.version, dtype)  # use the parsed torch dtype, not the raw string
--------------------------------------------------------------------------------
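For reference, the CLI above can be invoked as follows; "stories15M.pt" and "out_dir/" are hypothetical paths, and --version 2 selects the Q8_0 export defined earlier:

    python export.py model.bin --version 2 --checkpoint stories15M.pt
    python export.py out_dir/ --version -1 --dtype fp16 --hf path/to/hf_model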
/musiclang_predict/__init__.py:
--------------------------------------------------------------------------------
1 | from .tokenizers import MusicLangTokenizer, midi_file_to_template, score_to_template, MusicLangBPETokenizer
2 | from .predict import MusicLangPredictor
3 | 
4 | __all__ = [
5 |     'MusicLangTokenizer', 'predict', 'midi_file_to_template', 'score_to_template',
6 |     'MusicLangBPETokenizer', 'MusicLangPredictor'
7 | ]
--------------------------------------------------------------------------------
/musiclang_predict/c/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for compiling C code into a shared library for use with Python
2 | 
3 | # Compiler settings - Can be changed to clang if desired
4 | CC = gcc
5 | 
6 | # Compiler flags:
7 | # -Ofast for aggressive optimizations
8 | # -fPIC for position-independent code (needed for shared library)
9 | # -shared for creating a shared library
10 | CFLAGS = -Ofast -fPIC -shared
11 | 
12 | # Target library name
13 | TARGET_LIB_LINUX = librun.so
14 | TARGET_LIB_WIN = run.dll
15 | 
16 | # Source files - Automatically finds all .c files
17 | SRC = $(wildcard *.c)
18 | 
19 | # Object files
20 | OBJ = $(SRC:.c=.o)
21 | 
22 | # OS specific part
23 | UNAME_S := $(shell uname -s)
24 | ifeq ($(UNAME_S),Linux)
25 |     TARGET_LIB = $(TARGET_LIB_LINUX)
26 | endif
27 | ifeq ($(UNAME_S),Darwin)
28 |     TARGET_LIB = $(TARGET_LIB_LINUX)
29 | endif
30 | ifeq ($(UNAME_S),Windows_NT)
31 |     TARGET_LIB = $(TARGET_LIB_WIN)
32 |     CC = x86_64-w64-mingw32-gcc
33 | endif
34 | 
35 | # Rule to make the shared library
36 | all: $(TARGET_LIB)
37 | 
38 | $(TARGET_LIB): $(SRC)
39 | 	$(CC) $(CFLAGS) -o $@ $^ -lm
40 | 
41 | # Rule for cleaning up
42 | clean:
43 | 	rm -f $(TARGET_LIB) $(OBJ)
44 | 
45 | # Rule for installing (optional, might require additional actions for Python usage)
46 | install:
47 | 	cp $(TARGET_LIB) /path/to/your/python/project
--------------------------------------------------------------------------------
/musiclang_predict/c/run.c:
--------------------------------------------------------------------------------
1 | /* Inference for Llama-2 Transformer model in pure C */
2 | 
3 | #include <stdio.h>
4 | #include <stdlib.h>
5 | #include <stdbool.h> // Ensure this is included for the bool type
6 | #include <ctype.h>
7 | #include <time.h>
8 | #include <math.h>
9 | #include <string.h>
10 | #include <fcntl.h>
11 | #include "run.h"
12 | #if defined _WIN32
13 |     #include "win.h"
14 | #else
15 |     #include <unistd.h>
16 |     #include <sys/mman.h>
17 | #endif
18 | // ----------------------------------------------------------------------------
19 | // Transformer model
20 | 
21 | 
22 | 
23 | void malloc_run_state(RunState* s, Config* p) {
24 |     // we calloc instead of malloc to keep valgrind happy
25 |     int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
26 |     s->x = calloc(p->dim, sizeof(float));
27 |     s->xb = calloc(p->dim, sizeof(float));
28 |     s->xb2 = calloc(p->dim, sizeof(float));
29 |     s->hb = calloc(p->hidden_dim, sizeof(float));
30 |     s->hb2 = calloc(p->hidden_dim, sizeof(float));
31 |     s->q = calloc(p->dim, sizeof(float));
32 |     s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
33 |     s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
34 |     s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
35 |     s->logits = calloc(p->vocab_size, sizeof(float));
36 |     // ensure all mallocs went fine
37 |     if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
38 |      || !s->key_cache || !s->value_cache || !s->att || !s->logits) {
39 |         fprintf(stderr, "malloc failed!\n");
40 |         exit(EXIT_FAILURE);
41 |     }
42 | }
43 | 
44 | void free_run_state(RunState* s) {
45 |     free(s->x);
46 |     free(s->xb);
47 |     free(s->xb2);
48 |     free(s->hb);
49 |     free(s->hb2);
50 |     free(s->q);
51 |     free(s->att);
52 |     free(s->logits);
53 |     free(s->key_cache);
54 |     free(s->value_cache);
55 | }
56 | 
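/* A sketch of the file layout assumed by the pointer walk in memory_map_weights
 * below, matching the legacy (v0) path of export.py: token embeddings, per-layer
 * attention rmsnorms, wq, wk, wv, wo, per-layer ffn rmsnorms, w1, w2, w3, the
 * final rmsnorm, the (skipped) legacy RoPE tables, then wcls only when the
 * classifier is not shared with the embeddings. */
57 | void memory_map_weights(TransformerWeights *w, Config* p, 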
float* ptr, int shared_weights) { 58 | int head_size = p->dim / p->n_heads; 59 | // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models 60 | unsigned long long n_layers = p->n_layers; 61 | w->token_embedding_table = ptr; 62 | ptr += p->vocab_size * p->dim; 63 | w->rms_att_weight = ptr; 64 | ptr += n_layers * p->dim; 65 | w->wq = ptr; 66 | ptr += n_layers * p->dim * (p->n_heads * head_size); 67 | w->wk = ptr; 68 | ptr += n_layers * p->dim * (p->n_kv_heads * head_size); 69 | w->wv = ptr; 70 | ptr += n_layers * p->dim * (p->n_kv_heads * head_size); 71 | w->wo = ptr; 72 | ptr += n_layers * (p->n_heads * head_size) * p->dim; 73 | w->rms_ffn_weight = ptr; 74 | ptr += n_layers * p->dim; 75 | w->w1 = ptr; 76 | ptr += n_layers * p->dim * p->hidden_dim; 77 | w->w2 = ptr; 78 | ptr += n_layers * p->hidden_dim * p->dim; 79 | w->w3 = ptr; 80 | ptr += n_layers * p->dim * p->hidden_dim; 81 | w->rms_final_weight = ptr; 82 | ptr += p->dim; 83 | ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE) 84 | ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE) 85 | w->wcls = shared_weights ? w->token_embedding_table : ptr; 86 | } 87 | 88 | void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights, 89 | int* fd, float** data, ssize_t* file_size) { 90 | FILE *file = fopen(checkpoint, "rb"); 91 | if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); } 92 | // read in the config header 93 | if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); } 94 | // negative vocab size is hacky way of signaling unshared weights. bit yikes. 95 | int shared_weights = config->vocab_size > 0 ? 1 : 0; 96 | config->vocab_size = abs(config->vocab_size); 97 | // figure out the file size 98 | fseek(file, 0, SEEK_END); // move file pointer to end of file 99 | *file_size = ftell(file); // get the file size, in bytes 100 | fclose(file); 101 | // memory map the Transformer weights into the data pointer 102 | *fd = open(checkpoint, O_RDONLY); // open in read only mode 103 | if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); } 104 | *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0); 105 | if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); } 106 | float* weights_ptr = *data + sizeof(Config)/sizeof(float); 107 | memory_map_weights(weights, config, weights_ptr, shared_weights); 108 | } 109 | 110 | void build_transformer(Transformer *t, char* checkpoint_path) { 111 | // read in the Config and the Weights from the checkpoint 112 | read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size); 113 | // allocate the RunState buffers 114 | malloc_run_state(&t->state, &t->config); 115 | } 116 | 117 | void free_transformer(Transformer* t) { 118 | // close the memory mapping 119 | if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); } 120 | if (t->fd != -1) { close(t->fd); } 121 | // free the RunState buffers 122 | free_run_state(&t->state); 123 | } 124 | 125 | // ---------------------------------------------------------------------------- 126 | // neural net blocks; the dynamics of the Transformer 127 | 128 | void rmsnorm(float* o, float* x, float* weight, int size) { 129 | // calculate sum of squares 130 | float ss = 0.0f; 131 | for (int j = 0; j < size; j++) { 132 | ss += x[j] * x[j]; 133 | } 134 | ss /= size; 135 | ss += 1e-5f; 136 | ss = 1.0f / sqrtf(ss); 137 | // 
normalize and scale 138 | for (int j = 0; j < size; j++) { 139 | o[j] = weight[j] * (ss * x[j]); 140 | } 141 | } 142 | 143 | void softmax(float* x, int size) { 144 | // find max value (for numerical stability) 145 | float max_val = x[0]; 146 | for (int i = 1; i < size; i++) { 147 | if (x[i] > max_val) { 148 | max_val = x[i]; 149 | } 150 | } 151 | // exp and sum 152 | float sum = 0.0f; 153 | for (int i = 0; i < size; i++) { 154 | x[i] = expf(x[i] - max_val); 155 | sum += x[i]; 156 | } 157 | // normalize 158 | for (int i = 0; i < size; i++) { 159 | x[i] /= sum; 160 | } 161 | } 162 | 163 | void matmul(float* xout, float* x, float* w, int n, int d) { 164 | // W (d,n) @ x (n,) -> xout (d,) 165 | // by far the most amount of time is spent inside this little function 166 | int i; 167 | #pragma omp parallel for private(i) 168 | for (i = 0; i < d; i++) { 169 | float val = 0.0f; 170 | for (int j = 0; j < n; j++) { 171 | val += w[i * n + j] * x[j]; 172 | } 173 | xout[i] = val; 174 | } 175 | } 176 | 177 | float* forward(Transformer* transformer, int token, int pos) { 178 | 179 | // a few convenience variables 180 | Config* p = &transformer->config; 181 | TransformerWeights* w = &transformer->weights; 182 | RunState* s = &transformer->state; 183 | float *x = s->x; 184 | int dim = p->dim; 185 | int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; 186 | int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery 187 | int hidden_dim = p->hidden_dim; 188 | int head_size = dim / p->n_heads; 189 | 190 | // copy the token embedding into x 191 | float* content_row = w->token_embedding_table + token * dim; 192 | memcpy(x, content_row, dim*sizeof(*x)); 193 | 194 | // forward all the layers 195 | for(unsigned long long l = 0; l < p->n_layers; l++) { 196 | 197 | // attention rmsnorm 198 | rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim); 199 | 200 | // key and value point to the kv cache 201 | int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience 202 | s->k = s->key_cache + loff + pos * kv_dim; 203 | s->v = s->value_cache + loff + pos * kv_dim; 204 | 205 | // qkv matmuls for this position 206 | matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim); 207 | matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim); 208 | matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim); 209 | 210 | // RoPE relative positional encoding: complex-valued rotate q and k in each head 211 | for (int i = 0; i < dim; i+=2) { 212 | int head_dim = i % head_size; 213 | float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size); 214 | float val = pos * freq; 215 | float fcr = cosf(val); 216 | float fci = sinf(val); 217 | int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only 218 | for (int v = 0; v < rotn; v++) { 219 | float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key) 220 | float v0 = vec[i]; 221 | float v1 = vec[i+1]; 222 | vec[i] = v0 * fcr - v1 * fci; 223 | vec[i+1] = v0 * fci + v1 * fcr; 224 | } 225 | } 226 | 227 | // multihead attention. 
iterate over all heads 228 | int h; 229 | #pragma omp parallel for private(h) 230 | for (h = 0; h < p->n_heads; h++) { 231 | // get the query vector for this head 232 | float* q = s->q + h * head_size; 233 | // attention scores for this head 234 | float* att = s->att + h * p->seq_len; 235 | // iterate over all timesteps, including the current one 236 | for (int t = 0; t <= pos; t++) { 237 | // get the key vector for this head and at this timestep 238 | float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size; 239 | // calculate the attention score as the dot product of q and k 240 | float score = 0.0f; 241 | for (int i = 0; i < head_size; i++) { 242 | score += q[i] * k[i]; 243 | } 244 | score /= sqrtf(head_size); 245 | // save the score to the attention buffer 246 | att[t] = score; 247 | } 248 | 249 | // softmax the scores to get attention weights, from 0..pos inclusively 250 | softmax(att, pos + 1); 251 | 252 | // weighted sum of the values, store back into xb 253 | float* xb = s->xb + h * head_size; 254 | memset(xb, 0, head_size * sizeof(float)); 255 | for (int t = 0; t <= pos; t++) { 256 | // get the value vector for this head and at this timestep 257 | float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size; 258 | // get the attention weight for this timestep 259 | float a = att[t]; 260 | // accumulate the weighted value into xb 261 | for (int i = 0; i < head_size; i++) { 262 | xb[i] += a * v[i]; 263 | } 264 | } 265 | } 266 | 267 | // final matmul to get the output of the attention 268 | matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim); 269 | 270 | // residual connection back into x 271 | for (int i = 0; i < dim; i++) { 272 | x[i] += s->xb2[i]; 273 | } 274 | 275 | // ffn rmsnorm 276 | rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim); 277 | 278 | // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) 279 | // first calculate self.w1(x) and self.w3(x) 280 | matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim); 281 | matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim); 282 | 283 | // SwiGLU non-linearity 284 | for (int i = 0; i < hidden_dim; i++) { 285 | float val = s->hb[i]; 286 | // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid 287 | val *= (1.0f / (1.0f + expf(-val))); 288 | // elementwise multiply with w3(x) 289 | val *= s->hb2[i]; 290 | s->hb[i] = val; 291 | } 292 | 293 | // final matmul to get the output of the ffn 294 | matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim); 295 | 296 | // residual connection 297 | for (int i = 0; i < dim; i++) { 298 | x[i] += s->xb[i]; 299 | } 300 | } 301 | 302 | // final rmsnorm 303 | rmsnorm(x, x, w->rms_final_weight, dim); 304 | 305 | // classifier into logits 306 | matmul(s->logits, x, w->wcls, p->dim, p->vocab_size); 307 | return s->logits; 308 | } 309 | 310 | // ---------------------------------------------------------------------------- 311 | // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens 312 | 313 | 314 | int compare_tokens(const void *a, const void *b) { 315 | return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); 316 | } 317 | 318 | void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) { 319 | // i should have written the vocab_size into the tokenizer file... 
sigh 320 | t->vocab_size = vocab_size; 321 | // malloc space to hold the scores and the strings 322 | t->vocab = (char**)malloc(vocab_size * sizeof(char*)); 323 | t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); 324 | t->sorted_vocab = NULL; // initialized lazily 325 | for (int i = 0; i < 256; i++) { 326 | t->byte_pieces[i * 2] = (unsigned char)i; 327 | t->byte_pieces[i * 2 + 1] = '\0'; 328 | } 329 | // read in the file 330 | FILE *file = fopen(tokenizer_path, "rb"); 331 | if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); } 332 | if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } 333 | int len; 334 | for (int i = 0; i < vocab_size; i++) { 335 | if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);} 336 | if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } 337 | t->vocab[i] = (char *)malloc(len + 1); 338 | if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } 339 | t->vocab[i][len] = '\0'; // add the string terminating token 340 | } 341 | fclose(file); 342 | } 343 | 344 | void free_tokenizer(Tokenizer* t) { 345 | for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); } 346 | free(t->vocab); 347 | free(t->vocab_scores); 348 | free(t->sorted_vocab); 349 | } 350 | 351 | char* decode(Tokenizer* t, int prev_token, int token) { 352 | char *piece = t->vocab[token]; 353 | // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) 354 | if (prev_token == 1 && piece[0] == ' ') { piece++; } 355 | // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' 356 | // parse this and convert and return the actual byte 357 | unsigned char byte_val; 358 | if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) { 359 | piece = (char*)t->byte_pieces + byte_val * 2; 360 | } 361 | return piece; 362 | } 363 | 364 | void safe_printf(char *piece) { 365 | // piece might be a raw byte token, and we only want to print printable chars or whitespace 366 | // because some of the other bytes can be various control codes, backspace, etc. 367 | if (piece == NULL) { return; } 368 | if (piece[0] == '\0') { return; } 369 | if (piece[1] == '\0') { 370 | unsigned char byte_val = piece[0]; 371 | if (!(isprint(byte_val) || isspace(byte_val))) { 372 | return; // bad byte, don't print it 373 | } 374 | } 375 | printf("%s", piece); 376 | } 377 | 378 | int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { 379 | // efficiently find the perfect match for str in vocab, return its index or -1 if not found 380 | TokenIndex tok = { .str = str }; // acts as the key to search for 381 | TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); 382 | return res != NULL ? 
res->id : -1;
383 | }
384 | 
385 | void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
386 |     // encode the string text (input) into an upper-bound preallocated tokens[] array
387 |     // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
388 |     if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
389 | 
390 |     if (t->sorted_vocab == NULL) {
391 |         // lazily malloc and sort the vocabulary
392 |         t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
393 |         for (int i = 0; i < t->vocab_size; i++) {
394 |             t->sorted_vocab[i].str = t->vocab[i];
395 |             t->sorted_vocab[i].id = i;
396 |         }
397 |         qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
398 |     }
399 | 
400 |     // create a temporary buffer that will store merge candidates of always two consecutive tokens
401 |     // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
402 |     char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
403 |     size_t str_len = 0;
404 | 
405 |     // start at 0 tokens
406 |     *n_tokens = 0;
407 | 
408 |     // add optional BOS (=1) token, if desired
409 |     if (bos) tokens[(*n_tokens)++] = 1;
410 | 
411 |     // add_dummy_prefix is true by default
412 |     // so prepend a dummy prefix token to the input string, but only if text != ""
413 |     // TODO: pretty sure this isn't correct in the general case but I don't have the
414 |     // energy to read more of the sentencepiece code to figure out what it's doing
415 |     if (text[0] != '\0') {
416 |         int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
417 |         tokens[(*n_tokens)++] = dummy_prefix;
418 |     }
419 | 
420 |     // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
421 |     // Code point ↔ UTF-8 conversion
422 |     // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
423 |     // U+0000            U+007F           0xxxxxxx
424 |     // U+0080            U+07FF           110xxxxx  10xxxxxx
425 |     // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
426 |     // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx
427 | 
428 |     // process the raw (UTF-8) byte sequence of the input string
429 |     for (char *c = text; *c != '\0'; c++) {
430 | 
431 |         // reset buffer if the current byte is ASCII or a leading byte
432 |         // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
433 |         // 0x80 is 10000000
434 |         // in UTF-8, all continuation bytes start with "10" in first two bits
435 |         // so in English this is: "if this byte is not a continuation byte"
436 |         if ((*c & 0xC0) != 0x80) {
437 |             // this byte must be either a leading byte (11...) or an ASCII char (0x...)
438 |             // => reset our location, as we're starting a new UTF-8 codepoint
439 |             str_len = 0;
440 |         }
441 | 
442 |         // append the current byte to the buffer
443 |         str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
444 |         str_buffer[str_len] = '\0';
445 | 
446 |         // while the next character is a continuation byte, continue appending
447 |         // but if there are too many of them, just stop to avoid overrunning str_buffer size. 
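        // For example, the three-byte codepoint U+20AC ('€') arrives as
        // 0xE2 0x82 0xAC: 0xE2 is a leading byte (str_len was reset above), and
        // the check below keeps appending while 0x82 / 0xAC match 10xxxxxx.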
448 |         if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
449 |             continue;
450 |         }
451 | 
452 |         // ok c+1 is not a continuation byte, so we've read in a full codepoint
453 |         int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
454 | 
455 |         if (id != -1) {
456 |             // we found this codepoint in vocab, add it as a token
457 |             tokens[(*n_tokens)++] = id;
458 |         } else {
459 |             // byte_fallback encoding: just encode each byte as a token
460 |             // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
461 |             // so the individual bytes only start at index 3
462 |             for (int i=0; i < str_len; i++) {
463 |                 tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
464 |             }
465 |         }
466 |         str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
467 |     }
468 | 
469 |     // merge the best consecutive pair each iteration, according to the scores in vocab_scores
470 |     while (1) {
471 |         float best_score = -1e10;
472 |         int best_id = -1;
473 |         int best_idx = -1;
474 | 
475 |         for (int i=0; i < (*n_tokens-1); i++) {
476 |             // check if we can merge the pair (tokens[i], tokens[i+1])
477 |             sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
478 |             int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
479 |             if (id != -1 && t->vocab_scores[id] > best_score) {
480 |                 // this merge pair exists in vocab! record its score and position
481 |                 best_score = t->vocab_scores[id];
482 |                 best_id = id;
483 |                 best_idx = i;
484 |             }
485 |         }
486 | 
487 |         if (best_idx == -1) {
488 |             break; // we couldn't find any more pairs to merge, so we're done
489 |         }
490 | 
491 |         // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
492 |         tokens[best_idx] = best_id;
493 |         // delete token at position best_idx+1, shift the entire sequence back 1
494 |         for (int i = best_idx+1; i < (*n_tokens-1); i++) {
495 |             tokens[i] = tokens[i+1];
496 |         }
497 |         (*n_tokens)--; // token length decreased
498 |     }
499 | 
500 |     // add optional EOS (=2) token, if desired
501 |     if (eos) tokens[(*n_tokens)++] = 2;
502 | 
503 |     free(str_buffer);
504 | }
505 | 
506 | // ----------------------------------------------------------------------------
507 | // The Sampler, which takes logits and returns a sampled token
508 | // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
509 | 
510 | 
511 | int sample_argmax(float* probabilities, int n) {
512 |     // return the index that has the highest probability
513 |     int max_i = 0;
514 |     float max_p = probabilities[0];
515 |     for (int i = 1; i < n; i++) {
516 |         if (probabilities[i] > max_p) {
517 |             max_i = i;
518 |             max_p = probabilities[i];
519 |         }
520 |     }
521 |     return max_i;
522 | }
523 | 
524 | int sample_mult(float* probabilities, int n, float coin) {
525 |     // sample index from probabilities (they must sum to 1!)
526 |     // coin is a random number in [0, 1), usually from random_f32()
527 |     float cdf = 0.0f;
528 |     for (int i = 0; i < n; i++) {
529 |         cdf += probabilities[i];
530 |         if (coin < cdf) {
531 |             return i;
532 |         }
533 |     }
534 |     return n - 1; // in case of rounding errors
535 | }
536 | 
537 | int compare(const void* a, const void* b) {
538 |     ProbIndex* a_ = (ProbIndex*) a;
539 |     ProbIndex* b_ = (ProbIndex*) b;
540 |     if (a_->prob > b_->prob) return -1;
541 |     if (a_->prob < b_->prob) return 1;
542 |     return 0;
543 | }
544 | 
545 | int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
546 |     // top-p sampling (or "nucleus sampling") samples from the smallest set of
547 |     // tokens that exceed probability topp. 
This way we never sample tokens that 548 | // have very low probabilities and are less likely to go "off the rails". 549 | // coin is a random number in [0, 1), usually from random_f32() 550 | 551 | int n0 = 0; 552 | // quicksort indices in descending order of probabilities 553 | // values smaller than (1 - topp) / (n - 1) cannot be part of the result 554 | // so for efficiency we crop these out as candidates before sorting 555 | const float cutoff = (1.0f - topp) / (n - 1); 556 | for (int i = 0; i < n; i++) { 557 | if (probabilities[i] >= cutoff) { 558 | probindex[n0].index = i; 559 | probindex[n0].prob = probabilities[i]; 560 | n0++; 561 | } 562 | } 563 | qsort(probindex, n0, sizeof(ProbIndex), compare); 564 | 565 | // truncate the list where cumulative probability exceeds topp 566 | float cumulative_prob = 0.0f; 567 | int last_idx = n0 - 1; // in case of rounding errors consider all elements 568 | for (int i = 0; i < n0; i++) { 569 | cumulative_prob += probindex[i].prob; 570 | if (cumulative_prob > topp) { 571 | last_idx = i; 572 | break; // we've exceeded topp by including last_idx 573 | } 574 | } 575 | 576 | // sample from the truncated list 577 | float r = coin * cumulative_prob; 578 | float cdf = 0.0f; 579 | for (int i = 0; i <= last_idx; i++) { 580 | cdf += probindex[i].prob; 581 | if (r < cdf) { 582 | return probindex[i].index; 583 | } 584 | } 585 | return probindex[last_idx].index; // in case of rounding errors 586 | } 587 | 588 | void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) { 589 | sampler->vocab_size = vocab_size; 590 | sampler->temperature = temperature; 591 | sampler->topp = topp; 592 | sampler->rng_state = rng_seed; 593 | // buffer only used with nucleus sampling; may not need but it's ~small 594 | sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex)); 595 | } 596 | 597 | void free_sampler(Sampler* sampler) { 598 | free(sampler->probindex); 599 | } 600 | 601 | unsigned int random_u32(unsigned long long *state) { 602 | // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A 603 | *state ^= *state >> 12; 604 | *state ^= *state << 25; 605 | *state ^= *state >> 27; 606 | return (*state * 0x2545F4914F6CDD1Dull) >> 32; 607 | } 608 | float random_f32(unsigned long long *state) { // random float32 in [0,1) 609 | return (random_u32(state) >> 8) / 16777216.0f; 610 | } 611 | 612 | int sample(Sampler* sampler, float* logits) { 613 | // sample the token given the logits and some hyperparameters 614 | int next; 615 | if (sampler->temperature == 0.0f) { 616 | // greedy argmax sampling: take the token with the highest probability 617 | next = sample_argmax(logits, sampler->vocab_size); 618 | } else { 619 | // apply the temperature to the logits 620 | for (int q=0; qvocab_size; q++) { logits[q] /= sampler->temperature; } 621 | // apply softmax to the logits to get the probabilities for next token 622 | softmax(logits, sampler->vocab_size); 623 | // flip a (float) coin (this is our source of entropy for sampling) 624 | float coin = random_f32(&sampler->rng_state); 625 | // we sample from this distribution to get the next token 626 | if (sampler->topp <= 0 || sampler->topp >= 1) { 627 | // simply sample from the predicted probability distribution 628 | next = sample_mult(logits, sampler->vocab_size, coin); 629 | } else { 630 | // top-p (nucleus) sampling, clamping the least likely tokens to zero 631 | next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin); 632 | } 
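        // i.e. the pipeline above is: logits /= temperature, softmax, then either
        // plain multinomial sampling or top-p truncation. As a worked example,
        // logits {2, 1, 0} at temperature 0.5 become {4, 2, 0}, whose softmax is
        // roughly {0.87, 0.12, 0.02}: lower temperature sharpens toward the argmax.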
633 | } 634 | return next; 635 | } 636 | 637 | // ---------------------------------------------------------------------------- 638 | // utilities: time 639 | 640 | long time_in_ms() { 641 | // return time in milliseconds, for benchmarking the model speed 642 | struct timespec time; 643 | clock_gettime(CLOCK_REALTIME, &time); 644 | return time.tv_sec * 1000 + time.tv_nsec / 1000000; 645 | } 646 | 647 | // ---------------------------------------------------------------------------- 648 | // generation loop 649 | 650 | 651 | // Update to use attention_already_generated_index 652 | char* generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, char *post_prompt, int steps, char stop_char, bool attention_already_generated) { 653 | // generate a sequence of tokens from the transformer, using the sampler 654 | char *result = malloc(1); // Start with an empty string 655 | size_t result_len = 0; // Keep track of the length 656 | char *empty_prompt = ""; 657 | if (prompt == NULL) { prompt = empty_prompt; } 658 | if (post_prompt == NULL) { post_prompt = empty_prompt; } 659 | 660 | // encode the (string) prompt into tokens sequence 661 | int num_prompt_tokens = 0; 662 | int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS 663 | 664 | encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens); 665 | if (num_prompt_tokens < 1) { 666 | fprintf(stderr, "something is wrong, expected at least 1 prompt token\n"); 667 | exit(EXIT_FAILURE); 668 | } 669 | 670 | // Encode post prompt separately without BOS 671 | int num_post_prompt_tokens = 0; 672 | int* post_prompt_tokens = (int*)malloc((strlen(post_prompt)+2) * sizeof(int)); // +2 for '\0', ?EOS 673 | encode(tokenizer, post_prompt, 0, 0, post_prompt_tokens, &num_post_prompt_tokens); 674 | if (num_post_prompt_tokens > 0) { 675 | // Update prompt_tokens 676 | prompt_tokens = realloc(prompt_tokens, (num_prompt_tokens + num_post_prompt_tokens) * sizeof(int)); 677 | memcpy(prompt_tokens + num_prompt_tokens, post_prompt_tokens, num_post_prompt_tokens * sizeof(int)); 678 | num_prompt_tokens += num_post_prompt_tokens; 679 | } 680 | 681 | 682 | // Initialize start pos because we already calculated the attention for the prompt 683 | int start_pos ; 684 | int token; 685 | if(attention_already_generated){ 686 | start_pos = num_prompt_tokens - 1; 687 | // Update result and result_len 688 | for (int i = 0; i < num_prompt_tokens - 1; i++) { 689 | char* piece = decode(tokenizer, prompt_tokens[i], prompt_tokens[i + 1]); 690 | size_t piece_len = strlen(piece); 691 | result = realloc(result, result_len + piece_len + 1); // +1 for null terminator 692 | memcpy(result + result_len, piece, piece_len); 693 | result_len += piece_len; 694 | result[result_len] = '\0'; // Ensure null-termination 695 | } 696 | token = prompt_tokens[num_prompt_tokens - 1]; 697 | } else { 698 | start_pos = 0; 699 | token = prompt_tokens[0]; // kick off with the first token in the prompt 700 | } 701 | 702 | // start the main loop 703 | long start = 0; // used to time our code, only initialized after first iteration 704 | int next; // will store the next token in the sequence 705 | //int token = prompt_tokens[0]; // kick off with the first token in the prompt 706 | int prompt_len = num_prompt_tokens - 1; // length of the prompt, excluding BOS 707 | int pos = start_pos; // position in the sequence, excluding BOS 708 | //fprintf(stderr, "pos %d steps %d start_pos %d \n", pos, steps, start_pos); 709 | while (pos < steps + 
start_pos) {
710 | 
711 |         // forward the transformer to get logits for the next token
712 |         float* logits = forward(transformer, token, pos);
713 | 
714 |         // advance the state machine
715 |         if (pos < num_prompt_tokens - 1) {
716 |             // if we are still processing the input prompt, force the next prompt token
717 |             next = prompt_tokens[pos + 1];
718 |         } else {
719 |             // otherwise sample the next token from the logits
720 |             next = sample(sampler, logits);
721 |         }
722 |         pos++;
723 | 
724 |         // data-dependent terminating condition: the BOS (=1) token delimits sequences
725 |         if (next == 1) {
726 |             //fprintf(stderr, "Prompt len: %d, Next: %d, Pos: %d, Stop Token: %d\n", prompt_len, next, pos, stop_char);
727 |             break;
728 |         }
729 | 
730 |         // print the token as string, decode it with the Tokenizer object
731 |         char* piece = decode(tokenizer, token, next);
732 | 
733 |         // If stop_char is in piece, stop after emitting it
734 |         if (pos > prompt_len && strchr(piece, stop_char) != NULL) {
735 |             // Cut the piece just after the stop_char. Note: piece points into the
736 |             // tokenizer's vocab (or byte_pieces) storage, so we must not write a
737 |             // terminator into it; compute the truncated length instead.
738 |             size_t piece_len = (size_t)(strchr(piece, stop_char) - piece) + 1;
739 |             // Add the piece to the result
740 |             result = realloc(result, result_len + piece_len + 1); // +1 for null terminator
741 |             memcpy(result + result_len, piece, piece_len);
742 |             result_len += piece_len;
743 |             result[result_len] = '\0'; // Ensure null-termination
744 |             token = next;
745 |             break;
746 |         }
747 | 
748 |         size_t piece_len = strlen(piece);
749 |         result = realloc(result, result_len + piece_len + 1); // +1 for null terminator
750 |         memcpy(result + result_len, piece, piece_len);
751 |         result_len += piece_len;
752 |         result[result_len] = '\0'; // Ensure null-termination
753 |         token = next;
754 | 
755 |         // init the timer here because the first iteration can be slower
756 |         if (start == 0) { start = time_in_ms(); }
757 |     }
758 | 
759 |     // report achieved tok/s (pos-1 because the timer starts after first iteration)
760 |     if (pos > 1) {
761 |         long end = time_in_ms();
762 |         (void)end; // the tok/s report below is disabled in library mode
763 |         //fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
764 |     }
765 | 
766 |     free(prompt_tokens);
767 |     free(post_prompt_tokens); // this buffer was allocated separately above
768 | 
769 |     return result; // return the generated string
770 | }
771 | 
772 | 
773 | 
774 | void read_stdin(const char* guide, char* buffer, size_t bufsize) {
775 |     // read a line from stdin, up to but not including \n
776 |     //printf("%s", guide);
777 |     if (fgets(buffer, bufsize, stdin) != NULL) {
778 |         size_t len = strlen(buffer);
779 |         if (len > 0 && buffer[len - 1] == '\n') {
780 |             buffer[len - 1] = '\0'; // strip newline
781 |         }
782 |     }
783 | }
784 | 
785 | 
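// A minimal driver sketch for the API above (paths are hypothetical; run_model
// and the create/free helpers are defined in the CLI section just below):
//
//   Transformer* t = create_transformer("model.bin");
//   char* out = run_model(t, false, "tokenizer.bin", '_', 0.95f, 1.0f,
//                         42ULL, 256, "prompt", "", "generate", NULL);
//   printf("%s\n", out);
//   free_returned_string(out);
//   free_transformer_external(t);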
786 | 
787 | // ----------------------------------------------------------------------------
788 | // CLI, include only if not testing
789 | #ifndef TESTING
790 | 
791 | void error_usage() {
792 |     fprintf(stderr, "Usage:   run <checkpoint> [options]\n");
793 |     fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
794 |     fprintf(stderr, "Options:\n");
795 |     fprintf(stderr, "  -t <float>  temperature in [0,inf], default 1.0\n");
796 |     fprintf(stderr, "  -p <float>  p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
797 |     fprintf(stderr, "  -s <int>    random seed, default time(NULL)\n");
798 |     fprintf(stderr, "  -n <int>    number of steps to run for, default 256. 0 = max_seq_len\n");
799 |     fprintf(stderr, "  -i <string> input prompt\n");
800 |     fprintf(stderr, "  -z <string> optional path to custom tokenizer\n");
801 |     fprintf(stderr, "  -m <string> mode: generate|chat, default: generate\n");
802 |     fprintf(stderr, "  -y <string> (optional) system prompt in chat mode\n");
803 |     exit(EXIT_FAILURE);
804 | }
805 | void free_returned_string(char* str) {
806 |     free(str);
807 | }
808 | 
809 | // Initializes a Transformer and returns a pointer to it
810 | Transformer* create_transformer(char* checkpoint_path) {
811 |     Transformer* transformer = malloc(sizeof(Transformer));
812 |     build_transformer(transformer, checkpoint_path);
813 |     return transformer;
814 | }
815 | 
816 | // Frees a Transformer
817 | void free_transformer_external(Transformer* transformer) {
818 |     free_transformer(transformer);
819 |     free(transformer);
820 | }
821 | 
822 | 
823 | char* run_model(Transformer* transformer, bool attention_already_generated, char *tokenizer_path, char stop_char, float temperature, float topp, unsigned long long rng_seed, int steps, char *prompt, char *post_prompt, char *mode, char *system_prompt) {
824 |     // parameter validation/overrides
825 |     if (rng_seed == 0) rng_seed = (unsigned int)time(NULL); // rng_seed is unsigned, so only 0 means "unseeded"
826 |     if (temperature < 0.0) temperature = 0.0;
827 |     if (topp < 0.0 || topp > 1.0) topp = 0.9;
828 |     if (steps < 0) steps = 0;
829 | 
830 |     //printf("running with temperature=%f, topp=%f, seed=%llu, steps=%d\n", temperature, topp, rng_seed, steps);
831 | 
832 |     Tokenizer tokenizer;
833 |     build_tokenizer(&tokenizer, tokenizer_path, transformer->config.vocab_size);
834 | 
835 |     Sampler sampler;
836 |     build_sampler(&sampler, transformer->config.vocab_size, temperature, topp, rng_seed);
837 | 
838 |     char* generated_text = generate(transformer, &tokenizer, &sampler, prompt, post_prompt, steps, stop_char, attention_already_generated);
839 | 
840 |     free_sampler(&sampler);
841 |     free_tokenizer(&tokenizer);
842 |     // Don't free transformer here
843 |     return generated_text;
844 | }
845 | #endif
--------------------------------------------------------------------------------
/musiclang_predict/c/run.h:
--------------------------------------------------------------------------------
1 | // run.h
2 | #ifndef RUN_H
3 | #define RUN_H
4 | 
5 | #include <stdint.h>    // int8_t
6 | #include <sys/types.h> // ssize_t
7 | 
8 | typedef struct {
9 |     int dim;
10 |     int hidden_dim;
11 |     int n_layers;
12 |     int n_heads;
13 |     int n_kv_heads;
14 |     int vocab_size;
15 |     int seq_len;
16 | } Config;
17 | 
18 | typedef struct {
19 |     float* token_embedding_table;
20 |     float* rms_att_weight;
21 |     float* rms_ffn_weight;
22 |     float* wq;
23 |     float* wk;
24 |     float* wv;
25 |     float* wo;
26 |     float* w1;
27 |     float* w2;
28 |     float* w3;
29 |     float* rms_final_weight;
30 |     float* wcls;
31 | } TransformerWeights;
32 | 
33 | typedef struct {
34 |     float *x;
35 |     float *xb;
36 |     float *xb2;
37 |     float *hb;
38 |     float *hb2;
39 |     float *q;
40 |     float *k;
41 |     float *v;
42 |     float *att;
43 |     float *logits;
44 |     float* key_cache;
45 |     float* value_cache;
46 | } RunState;
47 | 
48 | typedef struct {
49 |     Config config;
50 |     TransformerWeights weights;
51 |     RunState state;
52 |     int fd;
53 |     float* data;
54 |     ssize_t file_size;
55 | } Transformer;
56 | 
57 | typedef struct {
58 |     char *str;
59 |     int id;
60 | } TokenIndex;
61 | 
62 | typedef struct {
63 |     char** vocab;
64 |     float* vocab_scores;
65 |     TokenIndex *sorted_vocab;
66 |     int vocab_size;
67 |     unsigned int max_token_length;
68 |     unsigned char byte_pieces[512];
69 | } Tokenizer;
70 | 
71 | typedef struct {
72 |     float prob;
73 |     int index;
74 | } ProbIndex;
75 | 
76 | 
typedef struct { 77 | int vocab_size; 78 | ProbIndex* probindex; 79 | float temperature; 80 | float topp; 81 | unsigned long long rng_state; 82 | } Sampler; 83 | 84 | void malloc_run_state(RunState* s, Config* p); 85 | void free_run_state(RunState* s); 86 | void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights); 87 | void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights, int* fd, float** data, ssize_t* file_size); 88 | void build_transformer(Transformer *t, char* checkpoint_path); 89 | void free_transformer(Transformer* t); 90 | void rmsnorm(float* o, float* x, float* weight, int size); 91 | void softmax(float* x, int size); 92 | void matmul(float* xout, float* x, float* w, int n, int d); 93 | float* forward(Transformer* transformer, int token, int pos); 94 | void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size); 95 | void free_tokenizer(Tokenizer* t); 96 | char* decode(Tokenizer* t, int prev_token, int token); 97 | void safe_printf(char *piece); 98 | int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size); 99 | void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens); 100 | void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed); 101 | void free_sampler(Sampler* sampler); 102 | //char* generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps, char stop_char, bool attention_already_generated); 103 | void read_stdin(const char* guide, char* buffer, size_t bufsize); 104 | //char* run_model(Transformer* transformer, bool attention_already_generated, char *tokenizer_path, char stop_char, float temperature, float topp, unsigned long long rng_seed, int steps, char *prompt, char *mode, char *system_prompt); 105 | #endif // RUN_H 106 | -------------------------------------------------------------------------------- /musiclang_predict/chelpers.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | from ctypes import c_void_p, c_char_p, c_float, c_ulonglong, c_int, c_char, c_bool 4 | 5 | current_file_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | 8 | def load_library(): 9 | # Assuming your package structure follows a certain pattern 10 | lib_name = "librun.so" if os.name != "nt" else "run.dll" 11 | lib_path = os.path.join(current_file_path, "c", lib_name) 12 | return ctypes.CDLL(lib_path) 13 | 14 | lib = load_library() 15 | 16 | # Set up function prototypes correctly 17 | lib.create_transformer.restype = c_void_p 18 | lib.create_transformer.argtypes = [c_char_p] 19 | 20 | lib.free_transformer_external.restype = None 21 | lib.free_transformer_external.argtypes = [c_void_p] 22 | 23 | lib.run_model.restype = c_void_p # Corrected to reflect using a transformer pointer 24 | lib.run_model.argtypes = [ 25 | c_void_p, c_bool, c_char_p, c_char, c_float, c_float, 26 | c_ulonglong, c_int, c_char_p, c_char_p, c_char_p, c_char_p 27 | ] 28 | 29 | lib.free_returned_string.restype = None 30 | lib.free_returned_string.argtypes = [c_void_p] 31 | 32 | 33 | 34 | # Function to wrap create_transformer 35 | def create_transformer(checkpoint_path): 36 | trans = lib.create_transformer(checkpoint_path.encode('utf-8')) 37 | return trans 38 | 39 | def free_transformer_external(transformer_ptr): 40 | lib.free_transformer_external(transformer_ptr) 41 | 42 | 43 | def run_transformer_model(transformer_ptr, 
attention_already_generated, tokenizer_path, temperature=1.0, topp=1.0, rng_seed=0, steps=256, prompt=None, post_prompt=None, mode="generate", system_prompt=None, stop_char="_"): 44 | # Convert string parameters to bytes, if they are not None and are of type str 45 | if prompt is not None and isinstance(prompt, str): 46 | prompt = prompt.encode('utf-8') 47 | if post_prompt is not None and isinstance(post_prompt, str): 48 | post_prompt = post_prompt.encode('utf-8') 49 | if mode is not None and isinstance(mode, str): 50 | mode = mode.encode('utf-8') 51 | if system_prompt is not None and isinstance(system_prompt, str): 52 | system_prompt = system_prompt.encode('utf-8') 53 | stop_char = stop_char.encode('utf-8')[0] if stop_char is not None and isinstance(stop_char, str) else ord(" ") 54 | 55 | # Call the modified C function with attention_already_generated 56 | result_ptr = lib.run_model( 57 | transformer_ptr, attention_already_generated, tokenizer_path.encode('utf-8'), stop_char, temperature, topp, rng_seed, 58 | steps, prompt, post_prompt, mode, system_prompt 59 | ) 60 | 61 | # Convert the result to a Python string, free the returned string 62 | result_str = ctypes.cast(result_ptr, ctypes.c_char_p).value.decode('utf-8') 63 | lib.free_returned_string(result_ptr) 64 | 65 | return result_str 66 | 67 | 68 | 69 | def run_for_n_bars(n_bars, checkpoint_path, tokenizer_path, temperature=1.0, topp=1.0, rng_seed=0, steps=256, 70 | prompt=None, mode="generate", system_prompt=None, 71 | stop_char="_"): 72 | transformer_ptr = create_transformer(checkpoint_path) 73 | 74 | for i in range(n_bars): 75 | generated_text = run_transformer_model( 76 | transformer_ptr=transformer_ptr, 77 | attention_already_generated=(i > 0), # only the first call must build attention for the prompt, as in predict.py 78 | tokenizer_path=tokenizer_path, 79 | temperature=temperature, 80 | topp=topp, 81 | rng_seed=rng_seed, 82 | steps=steps, 83 | prompt=prompt, 84 | mode=mode, 85 | system_prompt=system_prompt, 86 | stop_char=stop_char 87 | ) 88 | print(generated_text) 89 | prompt = generated_text 90 | 91 | free_transformer_external(transformer_ptr) 92 | return prompt 93 | 94 | 95 | -------------------------------------------------------------------------------- /musiclang_predict/corpus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | current_file_path = os.path.abspath(__file__) 5 | current_file_dir = os.path.dirname(current_file_path) 6 | midi_files = glob.glob(os.path.join(current_file_dir, "corpus", "*.mid")) 7 | 8 | 9 | def list_corpus(): 10 | return [os.path.basename(m).split('.')[0] for m in midi_files] 11 | 12 | 13 | def get_midi_path_from_corpus(name): 14 | return os.path.join(current_file_dir, "corpus", f"{name}.mid") 15 | 16 | -------------------------------------------------------------------------------- /musiclang_predict/corpus/bach_847.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusicLang/musiclang_predict/505e8c2fed5aa834619e62abdac16658fc2df711/musiclang_predict/corpus/bach_847.mid -------------------------------------------------------------------------------- /musiclang_predict/corpus/bob_marley_jammin.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusicLang/musiclang_predict/505e8c2fed5aa834619e62abdac16658fc2df711/musiclang_predict/corpus/bob_marley_jammin.mid --------------------------------------------------------------------------------
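Usage note: the ctypes layer above (chelpers.py) can be driven end to end from Python. Below is a minimal sketch of a single generation call, assuming the shared library has been built via the Makefile and that local model.bin / tokenizer.bin files exist; both file paths are hypothetical placeholders:

from musiclang_predict.chelpers import (
    create_transformer, run_transformer_model, free_transformer_external,
)

transformer_ptr = create_transformer("model.bin")  # hypothetical checkpoint path
try:
    # Keyword names mirror run_transformer_model above; attention_already_generated=False
    # makes the C side compute attention for the full prompt.
    text = run_transformer_model(
        transformer_ptr=transformer_ptr,
        attention_already_generated=False,
        tokenizer_path="tokenizer.bin",  # hypothetical tokenizer path
        temperature=0.9,
        topp=1.0,
        rng_seed=0,  # 0 means the C side seeds from the clock
        steps=64,
        prompt=None,
        post_prompt=None,
        mode="generate",
        system_prompt=None,
        stop_char="_",
    )
    print(text)
finally:
    free_transformer_external(transformer_ptr)  # always release the C-side transformer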
/musiclang_predict/corpus/boney_m_ma_baker.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusicLang/musiclang_predict/505e8c2fed5aa834619e62abdac16658fc2df711/musiclang_predict/corpus/boney_m_ma_baker.mid -------------------------------------------------------------------------------- /musiclang_predict/corpus/mozart_alla_turca.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusicLang/musiclang_predict/505e8c2fed5aa834619e62abdac16658fc2df711/musiclang_predict/corpus/mozart_alla_turca.mid -------------------------------------------------------------------------------- /musiclang_predict/corpus/white_stripes_seven_nation_army.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusicLang/musiclang_predict/505e8c2fed5aa834619e62abdac16658fc2df711/musiclang_predict/corpus/white_stripes_seven_nation_army.mid -------------------------------------------------------------------------------- /musiclang_predict/predict.py: -------------------------------------------------------------------------------- 1 | from musiclang import Score 2 | from musiclang.library import * 3 | 4 | from musiclang_predict.chelpers import run_for_n_bars, run_transformer_model, create_transformer, free_transformer_external 5 | import os 6 | import huggingface_hub 7 | from huggingface_hub import hf_hub_download 8 | from musiclang_predict import MusicLangTokenizer 9 | 10 | STOP_CHAR = None 11 | 12 | 13 | TEST_CHORD = (I % I.M) 14 | 15 | 16 | def get_nb_tokens_chord(tokenizer): 17 | return len(tokenizer.tokenize_to_bytes(TEST_CHORD)) 18 | 19 | MIDI_EXTENSIONS = ['mid', 'midi', 'MID', 'MIDI'] 20 | XML_EXTENSIONS = ['xml', 'mxl', 'musicxml', 'XML', 'MXL', 'MUSICXML'] 21 | 22 | 23 | 24 | class MusicLangPredictor: 25 | 26 | CHORD_CHANGE_CHAR = "_" # FIXME : It should be generic to the tokenizer 27 | MELODY_END_CHAR = "0" 28 | 29 | def __init__(self, path, tokenizer_file="tokenizer.bin", model_file="model.bin"): 30 | self.path = path 31 | self.tokenizer_path = hf_hub_download(repo_id=self.path, filename=tokenizer_file) 32 | self.model_path = hf_hub_download(repo_id=self.path, filename=model_file) 33 | self.pretokenizer = MusicLangTokenizer(self.path) 34 | self.CHORD_CHANGE_CHAR = self.pretokenizer.tokens_to_bytes('CHORD_CHANGE') 35 | self.MELODY_END_CHAR = self.pretokenizer.tokens_to_bytes('MELODY_END') 36 | 37 | self._nb_tokens_chord = get_nb_tokens_chord(self.pretokenizer) 38 | 39 | def parse_score(self, score, prompt_chord_range=None): 40 | # Tokenize the score to bytes 41 | 42 | if isinstance(score, str) and score.split('.')[-1] in MIDI_EXTENSIONS: 43 | # Load score 44 | from musiclang import Score 45 | score = Score.from_midi(score, chord_range=prompt_chord_range) 46 | elif isinstance(score, str) and score.split('.')[-1] in XML_EXTENSIONS: 47 | # Load score 48 | from musiclang import Score 49 | if prompt_chord_range is not None: 50 | raise ValueError("Sorry ... 
Chord range is not supported yet for XML files, convert it to midi first") 51 | score = Score.from_xml(score) 52 | score = self.pretokenizer.tokenize_to_bytes(score) + self.CHORD_CHANGE_CHAR 53 | return score 54 | 55 | def predict(self, score=None, prompt_chord_range=None, nb_chords=None, nb_tokens: int = 256, temperature=0.9, topp=1.0, rng_seed=0): 56 | """ 57 | Generate a score from a prompt 58 | :param score: (Optional) MusicLang Score, midi or xml file, default None 59 | The prompt used to continue the generation on 60 | :param prompt_chord_range: (Optional) tuple (int, int), default None. Chord range to use for the prompt 61 | :param nb_chords: (Optional) int, default None. Number of chords to generate; when set, generation stops after nb_chords chords rather than after nb_tokens tokens 62 | :param nb_tokens: (Optional) int, default 256 63 | Number of tokens to generate 64 | :param temperature: (Optional) float, default 0.9 65 | Temperature to use for the generation 66 | :param topp: (Optional) float, default 1.0 67 | Top-p to use for the generation 68 | :param rng_seed: (Optional) int, default 0 69 | Random seed to use for the generation. Use 0 for no seed 70 | :return: MusicLang Score 71 | The generated score 72 | """ 73 | if score is not None: 74 | score = self.parse_score(score, prompt_chord_range) 75 | 76 | transformer_ptr = create_transformer(self.model_path) 77 | if nb_chords is None: 78 | score = run_transformer_model( 79 | transformer_ptr=transformer_ptr, 80 | attention_already_generated=False, 81 | tokenizer_path=self.tokenizer_path, 82 | temperature=temperature, 83 | topp=topp, 84 | rng_seed=rng_seed, 85 | steps=nb_tokens, 86 | prompt=score, 87 | post_prompt=None, 88 | mode="generate", 89 | system_prompt=None, 90 | stop_char=None) 91 | else: 92 | attention_already_generated = False 93 | if score is None: 94 | score = self.CHORD_CHANGE_CHAR 95 | for i in range(nb_chords): 96 | score = run_transformer_model( 97 | transformer_ptr=transformer_ptr, 98 | attention_already_generated=attention_already_generated, 99 | tokenizer_path=self.tokenizer_path, 100 | temperature=temperature, 101 | topp=topp, 102 | rng_seed=rng_seed, 103 | steps=nb_tokens, 104 | prompt=score, 105 | post_prompt=None, 106 | mode="generate", 107 | system_prompt=None, 108 | stop_char=self.CHORD_CHANGE_CHAR) 109 | attention_already_generated = True 110 | # Untokenize to score 111 | score = self.pretokenizer.untokenize_from_bytes(score) 112 | free_transformer_external(transformer_ptr) 113 | return score 114 | 115 | def ts_to_duration(self, ts): 116 | from fractions import Fraction 117 | return Fraction(4 * ts[0], ts[1]) 118 | 119 | @property 120 | def nb_tokens_chord(self): 121 | return self._nb_tokens_chord 122 | 123 | def chord_to_tokens(self, chord): 124 | return self.pretokenizer.tokenize_to_bytes(chord)[1:self.nb_tokens_chord] 125 | 126 | def instrument_to_tokens(self, instrument, voice): 127 | token_instrument = self.pretokenizer.INSTRUMENT_NAME + '__' + instrument 128 | token_voice = self.pretokenizer.INSTRUMENT_PART + '__' + str(voice) 129 | return self.pretokenizer.tokens_to_bytes(" ".join([token_instrument])) # voice tokens ('INSTRUMENT_PART') are disabled by default, so only the name token is emitted 130 | 131 | def fix_instrument_name(self, instrument): 132 | instrument = instrument.lower() 133 | instrument = instrument.replace(" ", "_") 134 | if instrument == "drums": 135 | return "drums_0" 136 | return instrument 137 | 138 | 139 | def midi_to_template(self, midi_file, chord_range=None): 140 | """ 141 | Convert a midi file to a template usable for the predict_chords_and_instruments method 142 | :param midi_file: str, path to the midi file 143 | :return: tuple (template, time_signature) 144 | """ 145 | score = Score.from_midi(midi_file, chord_range=chord_range) 146 |
time_signature = score.config['time_signature'][1], score.config['time_signature'][2] 147 | chords = score.to_chord_repr().split(' ') 148 | template = [] 149 | for idx, chord in enumerate(chords): 150 | instruments = score[idx].instruments 151 | instruments = [ins.split('__')[0] for ins in instruments] 152 | template.append((chord, instruments)) 153 | 154 | return template, time_signature 155 | 156 | 157 | def predict_chords_and_instruments(self, template, time_signature=(4, 4), score=None, prompt_chord_range=None, nb_tokens: int = 4096, temperature=0.9, topp=1.0, rng_seed=0): 158 | """ 159 | A template is a list of tuples (chord, instruments), 160 | e.g. [("EM", ['piano', 'violin']), ("EM", ['piano', 'violin', 'drums_0'])] 161 | 162 | :param template: list of (chord, instruments) tuples, one entry per chord of the generated score 163 | :param time_signature: tuple (int, int), time signature used to set the chord durations 164 | :param score: (Optional) prompt score (MusicLang Score, midi or xml file) 165 | :param prompt_chord_range: (Optional) tuple (int, int), chord range of the prompt 166 | :param nb_tokens: (Optional) int, maximum number of tokens generated per step 167 | :param temperature: (Optional) float, sampling temperature 168 | :param topp: (Optional) float, top-p sampling parameter 169 | :param rng_seed: (Optional) int, random seed. Use 0 for no seed 170 | :return: MusicLang Score 171 | The generated score 172 | """ 173 | 174 | transformer_ptr = create_transformer(self.model_path) 175 | chord_duration = self.ts_to_duration(time_signature) 176 | chords = " ".join([chord for chord, instr in template]) 177 | chords = Score.from_chord_repr(chords) 178 | chords = chords.set_duration(chord_duration) 179 | chord_tokens = [self.chord_to_tokens(chord) for chord in chords] 180 | 181 | attention_already_generated = False 182 | if score is not None: 183 | score = self.parse_score(score, prompt_chord_range) 184 | else: 185 | score = self.CHORD_CHANGE_CHAR 186 | for idx in range(len(template)): 187 | chord = chord_tokens[idx] 188 | instruments = template[idx][1] 189 | idx_instruments = {} 190 | for idx_inst, inst in enumerate(instruments): 191 | voice_index = idx_instruments.get(inst, 0) 192 | idx_instruments[inst] = voice_index + 1 193 | post_prompt = "" 194 | if idx_inst == 0: 195 | post_prompt = chord 196 | post_prompt = post_prompt + self.instrument_to_tokens(inst, voice_index) 197 | generated_text = run_transformer_model( 198 | transformer_ptr=transformer_ptr, 199 | attention_already_generated=attention_already_generated, 200 | tokenizer_path=self.tokenizer_path, 201 | temperature=temperature, 202 | topp=topp, 203 | rng_seed=rng_seed, 204 | steps=nb_tokens, 205 | prompt=score, 206 | post_prompt=post_prompt, 207 | mode="generate", 208 | system_prompt=None, 209 | stop_char=self.MELODY_END_CHAR) 210 | 211 | score = generated_text 212 | attention_already_generated = True 213 | 214 | score = self.pretokenizer.untokenize_from_bytes(score) 215 | free_transformer_external(transformer_ptr) 216 | 217 | return score 218 | 219 | def predict_chords(self, chords: str, time_signature=(4, 4), score=None, prompt_chord_range=None, nb_tokens: int = 4096, temperature=0.9, topp=1.0, rng_seed=0): 220 | 221 | transformer_ptr = create_transformer(self.model_path) 222 | chord_duration = self.ts_to_duration(time_signature) 223 | chords = Score.from_chord_repr(chords) 224 | chords = chords.set_duration(chord_duration) 225 | chord_tokens = [self.chord_to_tokens(chord) for chord in chords] 226 | attention_already_generated = False 227 | if score is not None: 228 | score = self.parse_score(score, prompt_chord_range) 229 | else: 230 | score = self.CHORD_CHANGE_CHAR 231 | for chord in chord_tokens: 232 | generated_text = run_transformer_model( 233 | transformer_ptr=transformer_ptr, 234 | attention_already_generated=attention_already_generated, 235 | tokenizer_path=self.tokenizer_path, 236 | temperature=temperature, 237 | topp=topp, 238 | rng_seed=rng_seed, 239 |
steps=nb_tokens, 240 | prompt=score, 241 | post_prompt=chord, 242 | mode="generate", 243 | system_prompt=None, 244 | stop_char=self.CHORD_CHANGE_CHAR) 245 | 246 | score = generated_text 247 | attention_already_generated = True 248 | 249 | score = self.pretokenizer.untokenize_from_bytes(score) 250 | free_transformer_external(transformer_ptr) 251 | 252 | return score 253 | 254 | -------------------------------------------------------------------------------- /musiclang_predict/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .chord_tokenizer import ChordTokenizer, ChordDetokenizer 2 | from .tokenizer import MusicLangTokenizer 3 | from .bpe_tokenizer import MusicLangBPETokenizer 4 | from .template_extractor import midi_file_to_template, score_to_template 5 | -------------------------------------------------------------------------------- /musiclang_predict/tokenizers/bpe_iterator.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | class BPEIterator: 5 | r""" 6 | An iterable class to be used when training a musiclang tokenizer with BPE. 7 | 8 | It loads tokens text files to be used with the Hugging Face 9 | tokenizers library to build a vocabulary with BPE. 10 | 11 | It splits the tokens into different sequences, using all the control tokens as separators. 12 | E.g. CHORD_CHANGE INSTRUMENT_NAME__piano INSTRUMENT_PART__0 NOTE_TYPE_s NOTE_VAL__0 NOTE_OCTAVE__0 ... will be split as 13 | ['CHORD_CHANGE', 'INSTRUMENT_NAME__piano', 'INSTRUMENT_PART__0', 'NOTE_TYPE_s NOTE_VAL__0 NOTE_OCTAVE__0 ...'] 14 | separating note tokens from control tokens. 15 | 16 | """ 17 | 18 | def __init__(self, tokenizer, files_paths, control_tokens=[]) -> None: 19 | self.tokenizer = tokenizer 20 | self.files_paths = files_paths 21 | self.control_tokens = control_tokens 22 | self.__iter_count = 0 23 | 24 | def load_file(self, path): 25 | """ 26 | Load a tokens text file and convert it to its byte representation. 27 | 28 | :param path: path to the file to load. 29 | :return: the byte representation of the file. 30 | """ 31 | with open(path, 'r') as f: 32 | text = f.read() 33 | 34 | # list of str (bytes) 35 | bytes_ = self.tokenizer.tokens_to_bytes(text) 36 | bytes_ = bytes_[:8000] 37 | 38 | # Split 39 | split_pattern = '|'.join([re.escape(token) for token in self.control_tokens]) 40 | bytes_ = re.split(f'({split_pattern})', bytes_) 41 | bytes_ = [b for b in bytes_ if len(b) > 0] 42 | return bytes_ 43 | 44 | def __len__(self): 45 | """ 46 | Return the number of files in the training corpus. 47 | 48 | :return: number of files in the training corpus. 49 | """ 50 | return len(self.files_paths) 51 | 52 | def __getitem__(self, idx: int): 53 | """ 54 | Convert the ``idx``th file to its byte representation. 55 | 56 | :param idx: idx of the file to convert. 57 | :return: byte representation of the file. 58 | """ 59 | return self.load_file(self.files_paths[idx]) 60 | 61 | def __iter__(self): # noqa:D105 62 | return self 63 | 64 | def __next__(self): # noqa:D105 65 | if self.__iter_count >= len(self): 66 | self.__iter_count = 0 67 | raise StopIteration 68 | 69 | self.__iter_count += 1 70 | return self[self.__iter_count - 1] 71 | 72 | def __str__(self): 73 | """ 74 | Return the ``str`` representation of the iterator. 75 | 76 | :return: string description.
77 | """ 78 | return f"{self.tokenizer} - {len(self)} files" 79 | 80 | -------------------------------------------------------------------------------- /musiclang_predict/tokenizers/bpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | 4 | 5 | class MusicLangBPETokenizer: 6 | 7 | def __init__(self, tokenizer_path, pretokenizer_path): 8 | from musiclang_predict import MusicLangTokenizer 9 | 10 | self.pretokenizer = MusicLangTokenizer(pretokenizer_path) 11 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 12 | 13 | def tokens_to_bytes(self, tokens): 14 | """ 15 | Convert a list of tokens to a string of bytes. 16 | :param tokens: 17 | :return: 18 | """ 19 | return self.pretokenizer.tokens_to_bytes(tokens) 20 | 21 | def __call__(self, score, **kwargs): 22 | """ 23 | Tokenize the input text. 24 | :param text: input text. 25 | :return: list of tokens. 26 | """ 27 | tokens = self.pretokenizer.tokenize(score) 28 | bytes_ = self.tokens_to_bytes(tokens) 29 | 30 | return self.tokenizer(bytes_, **kwargs).input_ids 31 | 32 | def ids_to_text(self, ids): 33 | bytes_ = self.tokenizer.decode(ids).replace(' ', '') 34 | text = self.pretokenizer.bytes_to_tokens(bytes_) 35 | return text 36 | 37 | 38 | def untokenize(self, ids): 39 | return self.ids_to_score(ids) 40 | 41 | def ids_to_score(self, ids): 42 | """ 43 | Convert a list of tokens to a string of bytes. 44 | :param tokens: 45 | :return: 46 | """ 47 | bytes_ = self.tokenizer.decode(ids).replace(' ', '') 48 | text = self.pretokenizer.bytes_to_tokens(bytes_) 49 | return self.pretokenizer.untokenize(text) -------------------------------------------------------------------------------- /musiclang_predict/tokenizers/chord_tokenizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | import itertools 3 | 4 | 5 | CHORD_REGEX = re.compile(r'\([^)]+%[^)]+\)') 6 | 7 | 8 | class ChordTokenizer: 9 | """Chord tokenizer transforms a score in text format to a list of tokens""" 10 | 11 | def tokenize(self, text): 12 | """Tokenize a text (or score) to extract chord tokens 13 | 14 | Parameters 15 | ---------- 16 | text : 17 | Score or str, if Score first convert it to its string representation 18 | 19 | Returns 20 | ------- 21 | tokens: List[str] 22 | List of tokens 23 | 24 | """ 25 | chords = re.findall(CHORD_REGEX, str(text)) 26 | # Deduplicate the chords, because we don't care about rythm here 27 | deduplicated_chords = [k for k, _g in itertools.groupby(chords)] 28 | return deduplicated_chords 29 | 30 | 31 | def tokenize_all(self, texts): 32 | """Tokenize all texts 33 | 34 | Parameters 35 | ---------- 36 | texts : 37 | return: List of List of tokens 38 | 39 | Returns 40 | ------- 41 | tokens_list: List[List[str]] 42 | A list of token per text 43 | 44 | """ 45 | data = [] 46 | for text in texts: 47 | data.append(self.tokenize(text)) 48 | 49 | return data 50 | 51 | def tokenize_file(self, file): 52 | """Tokenize one file 53 | 54 | Parameters 55 | ---------- 56 | file : str 57 | Filepath where to get the tokens 58 | 59 | Returns 60 | ------- 61 | tokens: List[str] 62 | List of tokens 63 | 64 | """ 65 | with open(file, 'r') as f: 66 | return self.tokenize(f.read()) 67 | 68 | def tokenize_directory(self, directory): 69 | """ 70 | Tokenize a directory, call the tokenize_file method for each text in the directory. 
71 | 72 | Parameters 73 | ---------- 74 | directory : str 75 | Directory to tokenize 76 | 77 | 78 | Returns 79 | ------- 80 | 81 | data: List[List[tokens]] 82 | A list of tokens per text 83 | 84 | """ 85 | import os 86 | files = [os.path.join(dirpath, file) 87 | for (dirpath, dirnames, filenames) in os.walk(directory) 88 | for file in filenames] 89 | 90 | data = [] 91 | for file in files: 92 | data.append(self.tokenize_file(file)) 93 | 94 | return data 95 | 96 | 97 | class ChordDetokenizer: 98 | """Convert tokens from chord tokenizer to chords""" 99 | 100 | def detokenize(self, tokens): 101 | """ 102 | Convert a tokens list to a musiclang score. 103 | 104 | Parameters 105 | ---------- 106 | tokens : 107 | List of chord tokens to evaluate 108 | 109 | Returns 110 | ------- 111 | score: musiclang.Score 112 | Score detokenized (only chords) 113 | 114 | """ 115 | from musiclang.write.library import I, II, III, IV, V, VI, VII 116 | return sum(eval('+'.join(tokens)), None) -------------------------------------------------------------------------------- /musiclang_predict/tokenizers/template_extractor.py: -------------------------------------------------------------------------------- 1 | from musiclang import Score 2 | 3 | def midi_file_to_template(midi_file, chord_range=None, max_instruments=8, quantization=(4, 3)): 4 | """ 5 | Extract a song template from a midi file. It will extract the chord progression, the orchestration, 6 | the average density, the average amplitude and the average octave for each instrument of each bar. 7 | 8 | It will also extract metadata about the soundtrack like the tonality, the tempo and the time signature. 9 | :param midi_file: str, path to midi file 10 | :param chord_range: tuple, range of chords to extract (start, end) (default=None) 11 | :param max_instruments: int, maximum number of instruments to extract (default=8) 12 | :return: dict, template 13 | """ 14 | score_prompt = Score.from_midi(midi_file, quantization=quantization, chord_range=chord_range) 15 | return score_to_template(score_prompt, max_instruments=max_instruments) 16 | 17 | 18 | def score_to_template(score, max_instruments=8): 19 | """ 20 | Extract a song template from a musiclang score. It will extract the chord progression, the orchestration, 21 | the average density, the average amplitude and the average octave for each instrument of each bar. 22 | 23 | It will also extract metadata about the soundtrack like the tonality, the tempo and the time signature.
24 | :param score: Score, the musiclang score to extract the template from 25 | (use midi_file_to_template to build a template directly from a midi file) 26 | :param max_instruments: int, maximum number of instruments to extract (default=8) 27 | :return: dict, template 28 | """ 29 | score_prompt = score.to_score() 30 | densities_per_chords = [chord.to_score().extract_densities() for chord in score_prompt] 31 | amplitudes_per_chords = [chord.to_score().extract_mean_amplitudes() for chord in score_prompt] 32 | octaves_per_chords = [chord.to_score().extract_mean_octaves() for chord in score_prompt] 33 | tonality, chord_list = score_prompt.to_romantext_chord_list() 34 | data_chords = [ 35 | { 36 | "orchestration": [{'instrument_name': key.split('__')[0], 37 | 'instrument_voice': int(key.split('__')[1]), 38 | 'amplitude': amplitudes_per_chords[idx][key], 39 | 'octave': octaves_per_chords[idx][key], 40 | 'density': float(val)} 41 | for key, val in densities_per_chord.items()], 42 | "chord": chord_list[idx] 43 | } 44 | for idx, densities_per_chord in enumerate(densities_per_chords) 45 | 46 | ] 47 | 48 | for data_chord in data_chords: 49 | data_chord['orchestration'] = data_chord['orchestration'][:max_instruments] 50 | 51 | ts = score_prompt.config['time_signature'] 52 | if len(ts) == 4: 53 | ts = ts[1], ts[2] 54 | data = { 55 | 'tonality': tonality, 56 | 'tempo': score_prompt.config['tempo'], 57 | 'time_signature': ts, 58 | 'chords': data_chords 59 | } 60 | return data 61 | -------------------------------------------------------------------------------- /musiclang_predict/tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from musiclang import Score, Chord, Note, Melody, Tonality 4 | import os 5 | import tempfile 6 | 7 | 8 | import json 9 | from fractions import Fraction as frac 10 | import joblib 11 | import numpy as np 12 | from multiprocessing import Pool 13 | import functools 14 | from huggingface_hub import hf_hub_download 15 | from tqdm import tqdm 16 | import gc 17 | 18 | 19 | from tokenizers import Tokenizer, models, trainers 20 | 21 | from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerFast 22 | 23 | from tokenizers.models import WordLevel 24 | from tokenizers.pre_tokenizers import WhitespaceSplit 25 | from tokenizers.trainers import WordLevelTrainer 26 | 27 | 28 | from musiclang_predict.tokenizers.bpe_iterator import BPEIterator 29 | 30 | NOTE_DURATION_MAX_DENOMINATOR = 8 31 | CHORD_DURATION_MAX_DENOMINATOR = 4 32 | BASE_CHAR_ID = 33 33 | TOKENIZER_CONFIG_BASE = { 34 | "model_max_len": 4096, 35 | "model_max_length": 4096, 36 | "pad_token": "[PAD]", 37 | "pad_token_id": 3, 38 | "eos_token": "[END]", 39 | "eos_token_id": 45, 40 | "bos_token": "[START]", 41 | "bos_token_id": 5, 42 | "mask_token": "[MASK]", 43 | "mask_token_id": 4, 44 | "unk_token": "[UNK]", 45 | "unk_token_id": 0, 46 | "cls_token": "[CLS]", 47 | "cls_token_id": 1, 48 | "sep_token": "[SEP]", 49 | "sep_token_id": 2, 50 | "add_prefix_space": False, 51 | "tokenizer_class": "PreTrainedTokenizer", 52 | "type": "Split" 53 | } 54 | 55 | 56 | default_options = { 57 | 'chord_change_token': True, 58 | 'melody_end_token': True, 59 | 'chord_duration_token': True, 60 | 'density_token': True, 61 | 'chord_extension_token': True, 62 | 'next_chord_token': True, 63 | 'next_chord_duration_token': True, 64 | 'will_end_token': True, 65 | 'dissonance_token': True, 66 |
'amplitude_token': True, 67 | 'average_octave_token': True, 68 | 'voice_token': False, 69 | 'random_instrument_permutation': True, 70 | 'end_token': True, 71 | 'silence_continuation_eco': True 72 | } 73 | class MusicLangTokenizer: 74 | """ 75 | Convert a score into a list of tokens 76 | """ 77 | 78 | SCORE_START = 'SCORE_START' 79 | UNKNOWN = 'UNKNOWN' 80 | 81 | CHORD_DEGREE = 'CHORD_DEGREE' 82 | TONALITY_DEGREE = 'TONALITY_DEGREE' 83 | TONALITY_MODE = 'TONALITY_MODE' 84 | CHORD_OCTAVE = 'CHORD_OCTAVE' 85 | CHORD_DURATION_NUM = 'CHORD_DURATION_NUM' 86 | CHORD_DURATION_DEN = 'CHORD_DURATION_DEN' 87 | 88 | CHORD_EXTENSION = 'CHORD_EXTENSION' 89 | 90 | NEXT_CHORD_DEGREE = 'NEXT_CHORD_DEGREE' 91 | NEXT_TONALITY_DEGREE = 'NEXT_TONALITY_DEGREE' 92 | NEXT_TONALITY_MODE = 'NEXT_TONALITY_MODE' 93 | NEXT_CHORD_OCTAVE = 'NEXT_CHORD_OCTAVE' 94 | NEXT_CHORD_EXTENSION = 'NEXT_CHORD_EXTENSION' 95 | NEXT_CHORD_DURATION_NUM = 'NEXT_CHORD_DURATION_NUM' 96 | NEXT_CHORD_DURATION_DEN = 'NEXT_CHORD_DURATION_DEN' 97 | 98 | CHORD_CHANGE = 'CHORD_CHANGE' 99 | MELODY_END = 'MELODY_END' 100 | WILL_END = 'WILL_END' 101 | DISSONANCE = 'DISSONANCE' 102 | AMPLITUDE = 'AMPLITUDE' 103 | 104 | INSTRUMENT_NAME = 'INSTRUMENT_NAME' 105 | INSTRUMENT_PART = 'INSTRUMENT_PART' 106 | DENSITY = 'DENSITY' 107 | AVERAGE_OCTAVE = 'AVERAGE_OCTAVE' 108 | 109 | NOTE_TYPE = 'NOTE_TYPE' 110 | NOTE_VAL = 'NOTE_VAL' 111 | NOTE_OCTAVE = 'NOTE_OCTAVE' 112 | NOTE_AMP = 'NOTE_AMP' 113 | NOTE_DURATION_NUM = 'NOTE_DURATION_NUM' 114 | NOTE_DURATION_DEN = 'NOTE_DURATION_DEN' 115 | END = 'END' 116 | 117 | def __init__(self, tokenizer_path=None, options=None, hub_tokenizer_path='tokenizer-base.json'): 118 | self.dict = {} 119 | self.tokenizer = None 120 | self.denoms = [i for i in range(1, NOTE_DURATION_MAX_DENOMINATOR + 1)] 121 | if tokenizer_path is None: 122 | import warnings 123 | warnings.warn("No tokenizer_path provided. Using a new tokenizer. 
You probably should train it using 'train_tokenizer_from_token_files' method.") 124 | else: 125 | try: 126 | with open(tokenizer_path, 'r') as f: 127 | self.dict = json.load(f) 128 | except Exception as e: 129 | tokenizer_path_hub = hf_hub_download(repo_id=tokenizer_path, filename=hub_tokenizer_path) 130 | with open(tokenizer_path_hub, 'r') as f: 131 | self.dict = json.load(f) 132 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 133 | self.denoms = self.get_avalaible_denominators() 134 | # Replace str to int for keys of id_to_token 135 | if options is not None: 136 | self.dict['options'] = options 137 | elif 'options' not in self.dict: 138 | self.dict['options'] = default_options 139 | 140 | @property 141 | def vocab_size(self): 142 | return self.tokenizer.vocab_size 143 | 144 | def __getitem__(self, item): 145 | if isinstance(item, str): 146 | return self.tokenizer.vocab[item] 147 | elif isinstance(item, int): 148 | return self.tokenizer.decode(item) 149 | else: 150 | raise ValueError(f"Invalid type {type(item)} for item {item}") 151 | 152 | def tokenize_midi_file(self, midi_file, quantization=8, fast_chord_inference=True, chord_range=None, interceptors=None): 153 | """ 154 | Tokenize a single midi file and return a list of tokens 155 | Parameters 156 | ---------- 157 | midi_file: str, path to the midi file 158 | quantization: int, quantization level (default=8) 159 | fast_chord_inference: bool, use the fast chord inference (default=True) 160 | chord_range: tuple (int, int), optional chord range to tokenize 161 | 162 | Returns 163 | ------- 164 | tokens: List[str] 165 | """ 166 | score = Score.from_midi(midi_file, quantization=quantization, fast_chord_inference=fast_chord_inference, chord_range=chord_range) 167 | if interceptors is not None: 168 | for interceptor in interceptors: 169 | score = interceptor(score) 170 | tokens = self.tokenize(score) 171 | return tokens 172 | 173 | def tokenize_midi_files(self, midi_files, quantization=8, fast_chord_inference=True, chord_range=None): 174 | all_tokens = [] 175 | for midi_file in midi_files: 176 | tokens = self.tokenize_midi_file(midi_file, quantization=quantization, fast_chord_inference=fast_chord_inference, chord_range=chord_range) 177 | all_tokens.append(tokens) 178 | return all_tokens 179 | 180 | 181 | def calculate_tokens_to_ids_dict(self, token_files): 182 | 183 | unique_tokens = set() 184 | for token_file in token_files: 185 | with open(token_file, 'r') as f: 186 | tokens = f.read().split('\n') 187 | unique_tokens.update(tokens) 188 | unique_tokens = list(sorted(list(unique_tokens))) 189 | token_to_id = {token: idx for idx, token in enumerate(unique_tokens)} 190 | id_to_token = {idx: token for idx, token in enumerate(unique_tokens)} 191 | self.dict = {'options': self.dict['options'], 'token_to_id': token_to_id, 'id_to_token': id_to_token} # store the computed mappings alongside the options (key names are an assumption) 192 | 193 | 194 | def tokenize_from_file(self, path): 195 | """ 196 | Tokenize a file and return a list of tokens 197 | Parameters 198 | ---------- 199 | path: str, path to file 200 | 201 | Returns 202 | ------- 203 | tokens: List[str], list of tokens 204 | """ 205 | with open(path, 'r') as f: 206 | tokens = f.read().split() 207 | return tokens 208 | 209 | def tokenize_sequence(self, seq): 210 | """ 211 | Tokenize a sequence (a pandas DataFrame with the appropriate columns) 212 | Parameters 213 | ---------- 214 | seq: pd.DataFrame 215 | 216 | Returns 217 | ------- 218 | tokens: str 219 | """ 220 | 221 | score = Score.from_sequence(seq) 222 | return self.tokenize(score) 223 | 224 | 225 | 226 | def train_tokenizer_from_token_files(self, token_files, output=None, hub_output=None, **kwargs): 227 | """ 228 | Train a tokenizer from a list of token files.
229 | It will also save a tokenizer-base.json file with the options used to train the tokenizer 230 | (needed for tokenization from musiclang language) 231 | 232 | Make sure you have logged in to huggingface using `huggingface-cli login` before training and pushing to the hub 233 | Parameters 234 | ---------- 235 | token_files: List of token text files to train the vocabulary on 236 | output: Path to save the tokenizer and the config file 237 | (either output or hub_output must not be None) 238 | hub_output: Path to save the tokenizer and the config file in the huggingface hub 239 | 240 | Returns 241 | ------- 242 | tokenizer: Tokenizer 243 | """ 244 | 245 | 246 | def train_tokenizer(data_files, vocab_size=30_000, 247 | special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "[START]", "[END]"]): 248 | # Create a tokenizer 249 | tokenizer = Tokenizer(WordLevel(unk_token="[UNK]")) 250 | # Pre-tokenizer to split the text into words 251 | tokenizer.pre_tokenizer = WhitespaceSplit() 252 | # tokenizer.pre_tokenizer = PreTokenizer.custom("whitespace_split", "regex", pattern=r"\s+") 253 | # Special tokens 254 | tokenizer.add_special_tokens(special_tokens) 255 | # Trainer 256 | trainer = WordLevelTrainer( 257 | # vocab_size=vocab_size, 258 | special_tokens=special_tokens, 259 | ) 260 | 261 | # Train the tokenizer 262 | tokenizer.train(files=data_files, trainer=trainer) 263 | return tokenizer 264 | 265 | # Train on the provided token text files 266 | data_files = token_files 267 | tokenizer = train_tokenizer(data_files) 268 | # Save the tokenizer to a temp directory 269 | with tempfile.TemporaryDirectory() as tmpdirname: 270 | tokenizer.save(os.path.join(tmpdirname, "tokenizer.json")) 271 | 272 | # Save TOKENIZER_CONFIG_BASE as tokenizer_config.json in the directory 273 | options = copy.deepcopy(TOKENIZER_CONFIG_BASE) 274 | for key, value in kwargs.items(): 275 | options[key] = value 276 | 277 | with open(os.path.join(tmpdirname, 'tokenizer_config.json'), 'w') as f: 278 | json.dump(options, f, indent=4) 279 | # Reload tokenizer to push it or save it 280 | tokenizer = AutoTokenizer.from_pretrained(tmpdirname) 281 | if output is not None: 282 | tokenizer.save_pretrained(output) 283 | # Save base options 284 | with open(os.path.join(output, 'tokenizer-base.json'), 'w') as f: 285 | json.dump({"options": self.dict['options']}, f, indent=4) 286 | if hub_output is not None: 287 | tokenizer.push_to_hub(hub_output) 288 | # Push base options to hub (1. save to a tempfile, 2. 
then use hf api to push to hub) 289 | with tempfile.TemporaryDirectory() as tmpdirname: 290 | with open(os.path.join(tmpdirname, 'tokenizer-base.json'), 'w') as f: 291 | json.dump({"options": self.dict['options']}, f, indent=4) 292 | from huggingface_hub import HfApi 293 | hf_api = HfApi() 294 | hf_api.upload_file( 295 | path_or_fileobj=os.path.join(tmpdirname, 'tokenizer-base.json'), 296 | path_in_repo="tokenizer-base.json", 297 | repo_id=hub_output, 298 | repo_type="model", 299 | ) 300 | 301 | if output is None and hub_output is None: 302 | # Only warn, never raise 303 | print("WARNING: no output or hub_output provided, the tokenizer was not saved or pushed to the hub") 304 | 305 | return tokenizer 306 | 307 | def __call__(self, score, **kwargs): 308 | tokens = " ".join(self.tokenize(score)) 309 | return self.tokenizer(tokens, **kwargs) 310 | 311 | def decode(self, *args, **kwargs): 312 | return self.tokenizer.decode(*args, **kwargs) 313 | 314 | def tokenize_chords(self, score): 315 | return self.tokenize(score, only_chords=True) 316 | 317 | def ids_to_tokens(self, ids): 318 | """ 319 | Convert a list of token ids to a list of tokens 320 | Parameters 321 | ---------- 322 | ids: List[int], token ids to convert 323 | 324 | Returns 325 | ------- 326 | tokens: List[str] 327 | """ 328 | return self.tokenizer.convert_ids_to_tokens(ids) 329 | 330 | def ids_to_score(self, ids): 331 | """ 332 | Convert a list of token ids to a score 333 | Parameters 334 | ---------- 335 | ids: List[int], token ids to convert 336 | 337 | Returns 338 | ------- 339 | score: MusicLang Score 340 | """ 341 | tokens = self.ids_to_tokens(ids) 342 | return self.untokenize(tokens) 343 | 344 | def file_ids_to_score(self, path): 345 | """ 346 | Convert a file with ids to a score 347 | Parameters 348 | ---------- 349 | path: Path to numpy array with ids 350 | 351 | Returns 352 | ------- 353 | score: MusicLang Score 354 | """ 355 | tokens_ids = joblib.load(path) 356 | tokens = self.ids_to_tokens(tokens_ids) 357 | 358 | return self.untokenize(tokens) 359 | 360 | 361 | def file_tokens_to_score(self, path): 362 | """ 363 | Convert a file with tokens to a score 364 | Parameters 365 | ---------- 366 | path: Path to file with tokens 367 | 368 | Returns 369 | ------- 370 | score: MusicLang Score 371 | """ 372 | with open(path, 'r') as f: 373 | tokens = f.read().split('\n') 374 | return self.untokenize(tokens) 375 | 376 | def tokenize_to_ids(self, score, include_end=True): 377 | """ 378 | Tokenize a score and return a list of token ids 379 | Parameters 380 | ---------- 381 | score: MusicLang.Score, score to tokenize 382 | include_end: bool, if False strip the END token from the end of the list (default=True) 383 | 384 | Returns 385 | ------- 386 | ids: List[int], list of token ids 387 | 388 | """ 389 | tokens = self.tokenize(score) 390 | ids = self.tokens_to_ids(tokens) 391 | if not include_end: 392 | ids = ids[:-1] 393 | return ids 394 | 395 | def tokens_to_ids(self, tokens): 396 | """ 397 | Convert a list of tokens to a list of token ids 398 | Parameters 399 | ---------- 400 | tokens: List[str], tokens to convert 401 | 402 | Returns 403 | ------- 404 | ids: List[int] 405 | """ 406 | if self.tokenizer is None: 407 | raise ValueError("No tokens to ids available. 
You must train the tokenizer first using train_tokenizer_from_token_files method.") 408 | res_with_tokenizer = self.tokenizer(" ".join(tokens)) 409 | 410 | return res_with_tokenizer['input_ids'] 411 | 412 | 413 | def train_bpe(self, files_paths, output_dir=None, hub_path=None, vocab_size=30_000, type='sentence_piece'): 414 | """ 415 | :param files_paths: list of str 416 | Tokens text files paths 417 | :param output_dir: Output of the BPE model, if None, a temporary directory will be created 418 | :param hub_path: str 419 | Path to the hub where to push the tokenizer 420 | :param vocab_size: int 421 | Number of tokens in the BPE model 422 | :param type: str (sentence_piece or bpe) 423 | :return: the trained tokenizer 424 | """ 425 | 426 | 427 | from tokenizers import SentencePieceBPETokenizer 428 | from tokenizers import normalizers 429 | # Initialize the BPEIterator with your custom tokenizer and list of file paths 430 | control_tokens = self.get_control_tokens_bytes() 431 | bpe_iterator = BPEIterator(self, files_paths, control_tokens=control_tokens) 432 | 433 | # Initialize the Hugging Face tokenizer with a BPE model 434 | if type == 'sentence_piece': 435 | tokenizer = SentencePieceBPETokenizer(add_prefix_space=False) 436 | tokenizer.enable_padding(pad_id=1, pad_token="<pad>") # pad token string assumed; pad_id=1 must match its position in the special tokens below 437 | tokenizer.normalizer = normalizers.Sequence([]) 438 | tokenizer.train_from_iterator(bpe_iterator, special_tokens=['<unk>', '<pad>', '<s>', '</s>', '<mask>'], vocab_size=vocab_size, show_progress=True) # special-token strings assumed 439 | elif type == 'bpe': 440 | tokenizer = Tokenizer(models.BPE()) 441 | trainer = trainers.BpeTrainer(vocab_size=vocab_size, show_progress=True, max_token_length=32) 442 | tokenizer.train_from_iterator(bpe_iterator, trainer=trainer) 443 | else: 444 | raise ValueError(f"Invalid type {type}, must be 'sentence_piece' or 'bpe'") 445 | 446 | # Save the trained tokenizer 447 | file = "tokenizer.json" 448 | delete_temp_dir = False 449 | if output_dir is None: 450 | output_dir = tempfile.mkdtemp() 451 | delete_temp_dir = True 452 | 453 | tokenizer_file = os.path.join(output_dir, file) 454 | tokenizer.save(tokenizer_file) 455 | tokenizer = Tokenizer.from_file(tokenizer_file) 456 | tokenizer.model.save(output_dir) 457 | 458 | if hub_path is not None: 459 | tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file, model_max_length=self.tokenizer.model_max_length) 460 | tokenizer.push_to_hub(hub_path) 461 | 462 | # Delete the temporary directory 463 | if delete_temp_dir: 464 | import shutil 465 | shutil.rmtree(output_dir) 466 | 467 | # Push to hub 468 | 469 | return tokenizer 470 | 471 | 472 | def tokens_to_bytes(self, tokens, as_one_str=True): 473 | if not isinstance(tokens, str): 474 | tokens = " ".join(tokens) 475 | ids = self.tokenizer(tokens)['input_ids'] 476 | bytes_ids = [chr(i + BASE_CHAR_ID) for i in ids] 477 | if as_one_str: 478 | return ''.join(bytes_ids) 479 | return bytes_ids 480 | 481 | 482 | def get_control_tokens(self): 483 | """ 484 | Get all the tokens that are control tokens and that should not be part of BPE merges 485 | :return: control tokens as strings 486 | """ 487 | 488 | # Control tokens are all tokens that are not note tokens 489 | control_tokens = [voc for voc in self.tokenizer.vocab.keys() if not voc.startswith('NOTE_')] 490 | return control_tokens 491 | 492 | def get_control_tokens_bytes(self): 493 | """ 494 | Get all the tokens that look like control tokens 495 | :return: control tokens in their byte representation 496 | """ 497 | # Control tokens are all tokens that are not note tokens 498 | control_tokens = [chr(i + BASE_CHAR_ID) for i in self.tokenizer.vocab.values() if not 
self.tokenizer.decode(i).startswith('NOTE_')] 499 | return control_tokens 500 | 501 | def bytes_to_tokens(self, bytes, to_str=True): 502 | ids = [ord(b) - BASE_CHAR_ID for b in bytes] 503 | ids = [id for id in ids if id >= 0] 504 | result = self.ids_to_tokens(ids) 505 | if to_str: 506 | return " ".join(result) 507 | return result 508 | 509 | def save(self, filepath): 510 | with open(filepath, 'w') as f: 511 | json.dump(self.dict, f, indent=4) 512 | 513 | def tokenize_to_files_single(self, midi_files, output_dir, quantization=8, 514 | fast_chord_inference=True, chord_range=None, sep="\n", 515 | allow_error=True 516 | ): 517 | """ 518 | Tokenize a list of midi files and save the tokens to files 519 | Parameters 520 | ---------- 521 | midi_files 522 | output_dir 523 | quantization 524 | fast_chord_inference 525 | chord_range 526 | 527 | Returns 528 | ------- 529 | 530 | """ 531 | for midi_file in tqdm(midi_files, 'Tokenizing files'): 532 | try: 533 | tokens = self.tokenize_midi_file(midi_file, quantization=quantization, fast_chord_inference=fast_chord_inference, chord_range=chord_range) 534 | filename = os.path.basename(midi_file).replace('.mid', '.txt') 535 | output_file = os.path.join(output_dir, filename) 536 | with open(output_file, 'w') as f: 537 | f.write(sep.join(tokens)) 538 | except Exception as e: 539 | if allow_error: 540 | print(f"Error while tokenizing {midi_file}: {e}") 541 | else: 542 | raise e 543 | 544 | 545 | def worker(self, args): 546 | gc.collect() # Force garbage collection at the start of the worker 547 | midi_file, output_dir, quantization, fast_chord_inference, interceptors, chord_range, sep, allow_error = args 548 | try: 549 | tokens = self.tokenize_midi_file(midi_file, quantization=quantization, 550 | fast_chord_inference=fast_chord_inference, 551 | chord_range=chord_range, 552 | interceptors=interceptors 553 | ) 554 | filename = os.path.basename(midi_file).replace('.mid', '.txt') 555 | output_file = os.path.join(output_dir, filename) 556 | with open(output_file, 'w') as f: 557 | f.write(sep.join(tokens)) 558 | except Exception as e: 559 | if allow_error: 560 | print(f"Error while tokenizing {midi_file}: {e}") 561 | else: 562 | raise e 563 | finally: 564 | gc.collect() # Force garbage collection at the end of the worker 565 | 566 | def tokenize_to_files(self, midi_files, output_dir, quantization=8, 567 | fast_chord_inference=True, chord_range=None, sep="\n", 568 | allow_error=True, interceptors=None, num_processes=4 569 | ): 570 | args = [(midi_file, output_dir, quantization, fast_chord_inference, interceptors, chord_range, sep, allow_error) 571 | for midi_file in midi_files] 572 | 573 | with Pool(num_processes, maxtasksperchild=1) as pool: 574 | for _ in tqdm(pool.imap_unordered(self.worker, args), total=len(midi_files), desc='Tokenizing files'): 575 | pass 576 | 577 | 578 | def density_to_density_str(self, density): 579 | # density bands used below: low < 0.55 <= medium < 1.9 <= high < 3 <= very_high 580 | if density < 0.55: 581 | return 'low' 582 | elif density < 1.9: 583 | return 'medium' 584 | elif density < 3: 585 | return 'high' 586 | else: 587 | return 'very_high' 588 | 589 | def tokenize_to_bytes(self, score, as_one_str=True): 590 | tokens = self.tokenize(score) 591 | return self.tokens_to_bytes(tokens, as_one_str=as_one_str) 592 | 593 | def untokenize_from_bytes(self, bytes): 594 | tokens = self.bytes_to_tokens(bytes, to_str=False) 595 | return self.untokenize(tokens) 596 | 597 | 598 | def tokenize(self, score, only_chords=False, 
for_prompt=False): 599 | if self.dict['options'].get('random_instrument_permutation', False): 600 | import random 601 | instruments = score.instruments 602 | random.shuffle(instruments) 603 | sort_dict = {ins: idx for idx, ins in enumerate(instruments)} 604 | 605 | if isinstance(score, Chord): 606 | score = Score(chords=[score]) 607 | tokens = [] 608 | for idx, chord in enumerate(score): 609 | densities = chord.to_score().extract_densities() 610 | octave_means = chord.to_score().extract_mean_octaves() 611 | amplitude_means = chord.to_score().extract_mean_amplitudes() 612 | 613 | tokens_chord = self.tokenize_chord(chord, only_chords=only_chords) 614 | tokens += tokens_chord 615 | if self.dict['options'].get('next_chord_token', False) and not only_chords: 616 | # Last element 617 | if idx < len(score.chords) - 1: 618 | next_chord = score.chords[idx + 1] 619 | tokens_next_chord = self.tokenize_next_chord(next_chord) 620 | tokens += tokens_next_chord 621 | else: 622 | if self.dict['options'].get('will_end_token', True): 623 | tokens += [self.WILL_END] 624 | 625 | if not only_chords: 626 | 627 | if self.dict['options'].get('random_instrument_permutation', False): 628 | items = sorted(chord.score.items(), key=lambda x: sort_dict[x[0]]) 629 | else: 630 | items = chord.score.items() 631 | 632 | for ins, melody in items: 633 | ins_name, ins_part = ins.split('__') 634 | tokens_ins_name = self.INSTRUMENT_NAME + '__' + ins_name 635 | tokens.append(tokens_ins_name) 636 | if self.dict['options'].get('voice_token', True): 637 | tokens_ins_part = self.INSTRUMENT_PART + '__' + ins_part 638 | tokens.append(tokens_ins_part) 639 | if self.dict['options'].get('density_token', False): 640 | instrument_density = densities[ins] 641 | density_str = self.density_to_density_str(instrument_density) 642 | tokens_density = self.DENSITY + '__' + density_str 643 | tokens += [tokens_density] 644 | if self.dict['options'].get('average_octave_token', False): 645 | tokens_average_octave = self.AVERAGE_OCTAVE + '__' + str(octave_means[ins]) 646 | tokens += [tokens_average_octave] 647 | if self.dict['options'].get('amplitude_token', False): 648 | tokens_amplitude = self.AMPLITUDE + '__' + str(amplitude_means[ins]) 649 | tokens += [tokens_amplitude] 650 | for note in melody: 651 | tokens_note = self.tokenize_note(note) 652 | tokens += tokens_note 653 | if self.dict['options'].get('melody_end_token', False): 654 | tokens.append(self.MELODY_END) 655 | 656 | if self.dict['options'].get('end_token', True): 657 | tokens.append(self.END) 658 | 659 | return tokens 660 | 661 | def tokenize_chord_duration(self, chord_duration): 662 | token_chord_duration_num = self.CHORD_DURATION_NUM + '__' + str(chord_duration.numerator) 663 | token_chord_duration_den = self.CHORD_DURATION_DEN + '__' + str(chord_duration.denominator) 664 | return [token_chord_duration_num, token_chord_duration_den] 665 | 666 | 667 | def get_avalaible_denominators(self): 668 | if self.tokenizer is not None: 669 | words = self.tokenizer.vocab 670 | self.denoms = [int(w.split('__')[1]) for w in words if 'NOTE_DURATION_DEN' in w] 671 | return self.denoms 672 | 673 | def tokenize_note(self, note): 674 | note_type = self.NOTE_TYPE + '__' + note.type 675 | note_degree = self.NOTE_VAL + '__' + str(note.val) 676 | note_octave = self.NOTE_OCTAVE + '__' + str(note.octave) 677 | note_amp = self.NOTE_AMP + '__' + note.amp_figure 678 | 679 | # Limit the duration's denominator to NOTE_DURATION_MAX_DENOMINATOR 680 | note_duration = frac(note.duration).limit_denominator(NOTE_DURATION_MAX_DENOMINATOR) 681 | 
available_denominators = self.denoms 682 | note_duration_den = min(available_denominators, key=lambda x: abs(note_duration.denominator - x)) 683 | if note_duration.numerator == 0: 684 | return [] 685 | 686 | # Find best approximation using denominator 687 | note_duration = frac(note_duration.numerator, note_duration_den) 688 | note_duration_num = self.NOTE_DURATION_NUM + '__' + str(note_duration.numerator) 689 | note_duration_den = self.NOTE_DURATION_DEN + '__' + str(note_duration.denominator) 690 | # Create the list 691 | tokens = [note_type] 692 | if (note.type not in ['r', 'l']) or not self.dict['options'].get('silence_continuation_eco', False): 693 | tokens += [note_degree, note_octave, note_amp] 694 | 695 | tokens += [note_duration_num, note_duration_den] 696 | 697 | return tokens 698 | 699 | def tokenize_chord(self, chord, only_chords=False): 700 | tokens = [] 701 | if self.dict['options']['chord_change_token']: 702 | tokens.append(self.CHORD_CHANGE) 703 | 704 | chord_degree = self.CHORD_DEGREE + '__' + str(chord.element) 705 | chord_octave = self.CHORD_OCTAVE + '__' + str(chord.full_octave) 706 | chord_extension = self.CHORD_EXTENSION + '__' + str(chord.extension) 707 | tonality_degree = self.TONALITY_DEGREE + '__' + str(chord.tonality.degree) 708 | tonality_mode = self.TONALITY_MODE + '__' + chord.tonality.mode 709 | 710 | # Create the list 711 | tokens += [chord_degree, tonality_degree, tonality_mode, chord_octave] 712 | 713 | if self.dict['options'].get('chord_extension_token', False): 714 | tokens += [chord_extension] 715 | 716 | if self.dict['options'].get('chord_duration_token', False) and not only_chords: 717 | chord_duration = frac(chord.duration).limit_denominator(CHORD_DURATION_MAX_DENOMINATOR) 718 | chord_duration_num = self.CHORD_DURATION_NUM + '__' + str(chord_duration.numerator) 719 | chord_duration_den = self.CHORD_DURATION_DEN + '__' + str(chord_duration.denominator) 720 | tokens += [chord_duration_num, chord_duration_den] 721 | 722 | return tokens 723 | 724 | def tokenize_next_chord(self, chord): 725 | tokens = [] 726 | 727 | chord_degree = self.NEXT_CHORD_DEGREE + '__' + str(chord.element) 728 | chord_octave = self.NEXT_CHORD_OCTAVE + '__' + str(chord.octave) 729 | tonality_degree = self.NEXT_TONALITY_DEGREE + '__' + str(chord.tonality.degree) 730 | tonality_mode = self.NEXT_TONALITY_MODE + '__' + chord.tonality.mode 731 | 732 | tokens += [chord_degree, tonality_degree, tonality_mode, chord_octave] 733 | 734 | if self.dict['options'].get('next_chord_duration_token', True): 735 | chord_duration = frac(chord.duration).limit_denominator(CHORD_DURATION_MAX_DENOMINATOR) 736 | chord_duration_num = self.NEXT_CHORD_DURATION_NUM + '__' + str(chord_duration.numerator) 737 | chord_duration_den = self.NEXT_CHORD_DURATION_DEN + '__' + str(chord_duration.denominator) 738 | tokens += [chord_duration_num, chord_duration_den] 739 | 740 | if self.dict['options']['chord_extension_token']: 741 | chord_extension = self.NEXT_CHORD_EXTENSION + '__' + str(chord.extension) 742 | tokens += [chord_extension] 743 | 744 | return tokens 745 | 746 | 747 | 748 | 749 | 750 | def untokenize(self, tokens): 751 | if isinstance(tokens, str): 752 | # Split by \n or whitespace 753 | tokens = tokens.split() 754 | 755 | score = Score() 756 | current_chord = None 757 | current_melody = None 758 | current_instrument_name = None 759 | current_instrument_part = None 760 | 761 | current_chord_duration_num = None 762 | current_chord_duration_den = None 763 | chord_duration = None 764 | current_melody_duration = 
0 765 | note_duration_num = 0 766 | note_duration_den = 0 767 | note_val = 0 768 | note_type = 'r' 769 | note_octave = 0 770 | note_amp = 'mf' 771 | 772 | current_instrument_idx = {} 773 | 774 | for token in tokens: 775 | # Split token into key and value 776 | if token in [self.END, self.CHORD_CHANGE, self.WILL_END, self.SCORE_START]: 777 | continue 778 | 779 | if token == self.MELODY_END: 780 | # Check if melody duration is equal to chord duration, else add a rest 781 | if current_melody_duration < chord_duration: 782 | delta = chord_duration - current_melody_duration 783 | note_duration_num = int(delta.numerator) 784 | note_duration_den = int(delta.denominator) 785 | note_duration = frac(note_duration_num, note_duration_den) 786 | if note_duration > 0: 787 | note = Note(type='r', val=0, octave=0, duration=note_duration) 788 | current_melody.notes.append(note) 789 | # Then continue 790 | current_chord.score[current_instrument_name + '__' + current_instrument_part] = current_melody 791 | continue 792 | 793 | try: 794 | key, value = token.split('__') 795 | except: 796 | raise ValueError(f"Invalid token {token}") 797 | 798 | if key == self.CHORD_DEGREE: 799 | if current_chord is not None: 800 | score.chords.append(current_chord) 801 | current_chord = Chord(element=int(value), tonality=Tonality(0, 'M')) 802 | current_instrument_idx = {} 803 | 804 | elif key == self.CHORD_OCTAVE: 805 | current_chord.octave = int(value) 806 | 807 | elif key == self.CHORD_DURATION_NUM: 808 | current_chord_duration_num = int(value) 809 | 810 | elif key == self.CHORD_DURATION_DEN: 811 | current_chord_duration_den = int(value) 812 | chord_duration = frac(current_chord_duration_num, current_chord_duration_den) 813 | 814 | elif key == self.TONALITY_DEGREE: 815 | current_chord.tonality.degree = int(value) # Assuming Tonality can be constructed from a string 816 | 817 | elif key == self.TONALITY_MODE: 818 | current_chord.tonality.mode = value 819 | 820 | elif key == self.CHORD_EXTENSION: 821 | current_chord.extension = value 822 | 823 | elif key == self.INSTRUMENT_NAME: 824 | current_instrument_name = value 825 | current_melody_duration = 0 826 | note_duration_num = 0 827 | note_duration_den = 0 828 | current_melody = Melody(notes=[]) 829 | 830 | if current_instrument_name not in current_instrument_idx: 831 | current_instrument_idx[current_instrument_name] = 0 832 | else: 833 | current_instrument_idx[current_instrument_name] += 1 834 | 835 | if not self.dict['options'].get('voice_token', False): 836 | current_instrument_part = str(current_instrument_idx[current_instrument_name]) 837 | 838 | elif key == self.INSTRUMENT_PART: 839 | # Assuming that instrument part is not used directly in Melody 840 | current_instrument_part = value 841 | 842 | elif key == self.NOTE_TYPE: 843 | note_type = value 844 | 845 | elif key == self.NOTE_VAL: 846 | note_val = int(value) 847 | 848 | elif key == self.NOTE_OCTAVE: 849 | note_octave = int(value) 850 | 851 | elif key == self.NOTE_AMP: 852 | note_amp = value 853 | 854 | elif key == self.NOTE_DURATION_NUM: 855 | note_duration_num = int(value) 856 | 857 | elif key == self.NOTE_DURATION_DEN: 858 | note_duration_den = int(value) 859 | current_note_duration = frac(note_duration_num, note_duration_den) 860 | if current_melody_duration + current_note_duration > chord_duration: 861 | delta = current_melody_duration + current_note_duration - chord_duration 862 | note_duration_num = int(delta.numerator) 863 | note_duration_den = int(delta.denominator) 864 | note_duration = current_note_duration - 
frac(note_duration_num, note_duration_den) 865 | if note_duration > 0: 866 | note = Note(type=note_type, val=note_val, octave=note_octave, duration=note_duration) 867 | note = note.set_amp(note_amp) 868 | current_melody.notes.append(note) 869 | current_melody_duration += note_duration 870 | current_chord.score[current_instrument_name + '__' + current_instrument_part] = current_melody 871 | else: 872 | note = Note(type=note_type, val=note_val, octave=note_octave, duration=current_note_duration) 873 | note = note.set_amp(note_amp) 874 | current_melody_duration += current_note_duration 875 | current_melody.notes.append(note) 876 | 877 | current_instrument = current_instrument_name + '__' + current_instrument_part 878 | current_chord.score[current_instrument] = current_melody 879 | 880 | # Add the last chord to the score 881 | if current_chord is not None: 882 | score.chords.append(current_chord) 883 | 884 | return score -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | musiclang>=0.23 2 | sentencepiece 3 | torch 4 | transformers 5 | tokenizers 6 | torchtoolkit 7 | accelerate 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from setuptools import setup, Extension, find_packages 3 | from setuptools.command.build_py import build_py 4 | import subprocess 5 | from pathlib import Path 6 | import os 7 | 8 | def custom_command(): 9 | subprocess.check_call(['make', '-C', 'musiclang_predict/c/']) 10 | 11 | 12 | class CustomInstallBuildPy(build_py): 13 | """Custom build command that runs a Makefile.""" 14 | 15 | def run(self): 16 | # Build the C shared library before the Python sources are collected 17 | custom_command() 18 | # Call the superclass methods to handle Python extension building, if any 19 | build_py.run(self) 20 | 21 | 22 | module = Extension('musiclang_predict.c', 23 | sources=['musiclang_predict/c/run.c'], 24 | include_dirs=[], 25 | extra_compile_args=['-Ofast', '-fPIC', '-shared']) 26 | 27 | 28 | this_directory = Path(__file__).parent 29 | long_description = (this_directory / "README.md").read_text(encoding='utf-8') 30 | 31 | setuptools.setup( 32 | name="musiclang-predict", 33 | version="1.2.0", 34 | author="Florian GARDIN", 35 | author_email="fgardin.pro@gmail.com", 36 | description=("Controllable symbolic music generation with generative AI" 37 | ), 38 | cmdclass={ 39 | 'build_py': CustomInstallBuildPy, 40 | }, 41 | #ext_modules=[module], 42 | long_description=long_description, 43 | long_description_content_type="text/markdown", 44 | project_urls={ 45 | 'Documentation': 'https://github.com/MusicLang/musiclang_predict', 46 | 'Source': 'https://github.com/MusicLang/musiclang_predict', 47 | 'Tracker': 'https://github.com/MusicLang/musiclang_predict/issues', 48 | }, 49 | classifiers=[ 50 | "Programming Language :: Python :: 3", 51 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 52 | "Operating System :: OS Independent", 53 | ], 54 | install_requires=[ 55 | "musiclang>=0.25", 56 | "torch", 57 | "transformers", 58 | "tokenizers", 59 | "torchtoolkit", 60 | "accelerate" 61 | ], 62 | packages=setuptools.find_packages(include='*'), 63 | package_data={'musiclang_predict': ['c/*.h', 'c/*.c', 'c/*.so', 'c/*.dll', 'c/Makefile', 'corpus/*.mid'], 64 | }, 65 | include_package_data=True, 66 | python_requires=">=3.6", 67 | )
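# Note on the build step: CustomInstallBuildPy shells out to `make -C musiclang_predict/c/`,
# which is expected to produce the shared library (librun.so on POSIX, run.dll on Windows)
# that chelpers.load_library() later opens with ctypes. A plausible manual equivalent,
# mirroring the extra_compile_args above (the actual Makefile recipe is not shown here
# and may differ):
#   cc -Ofast -fPIC -shared musiclang_predict/c/run.c -o musiclang_predict/c/librun.so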
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MusicLang/musiclang_predict/505e8c2fed5aa834619e62abdac16658fc2df711/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | from musiclang_predict import MusicLangTokenizer 2 | from musiclang.library import * 3 | 4 | 5 | default_options = { 6 | 'chord_change_token': True, 7 | 'melody_end_token': True, 8 | 'chord_duration_token': True, 9 | 'density_token': True, 10 | 'chord_extension_token': True, 11 | 'next_chord_token': True, 12 | 'next_chord_duration_token': True, 13 | 'will_end_token': True, 14 | 'dissonance_token': True, 15 | 'amplitude_token': True, 16 | 'average_octave_token': True, 17 | 'voice_token': True, 18 | 'random_instrument_permutation': False 19 | } 20 | 21 | def test_tokenizer(): 22 | 23 | pass 24 | 25 | 26 | 27 | def test_cut_bar(): 28 | pass 29 | 30 | def test_add_bar(): 31 | pass --------------------------------------------------------------------------------
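End-to-end usage sketch: the pieces above compose into a short generation script. The snippet below assumes a Hugging Face repo id hosting the expected model.bin, tokenizer.bin and tokenizer-base.json files (the repo id is a hypothetical placeholder) and that musiclang's Score exposes a to_midi method:

from musiclang_predict import MusicLangPredictor
from musiclang_predict.corpus import get_midi_path_from_corpus

# Downloads model.bin and tokenizer.bin from the hub and builds the pretokenizer
ml = MusicLangPredictor("musiclang/example-model")  # hypothetical repo id

# Continue one of the bundled corpus MIDIs, stopping after 8 generated chords
score = ml.predict(
    score=get_midi_path_from_corpus("bach_847"),
    prompt_chord_range=(0, 4),
    nb_chords=8,
    temperature=0.9,
)
score.to_midi("continuation.mid")  # assumes musiclang's Score.to_midi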