├── .gitignore ├── LICENSE ├── README.md ├── download_tokenizer_weights.sh ├── images ├── .gitignore ├── a.jpg ├── b.jpg └── c.jpg ├── imagetokenizer ├── model │ ├── __init__.py │ ├── magvit2.py │ ├── modules │ │ ├── maskgit_vqgan.py │ │ ├── omni_codebook.py │ │ ├── omni_transformer.py │ │ ├── titok_transformer.py │ │ └── vae.py │ ├── omnitokenizer.py │ └── titok.py ├── quantize │ ├── lookup_free_quantize.py │ └── vector_quantize.py ├── utils │ └── omnitokenizer_utils.py └── version.py ├── ps.sh ├── setup.py ├── test_image_tokenizer.py └── upload_pypi.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | build/ 3 | alfred_py.egg-info/ 4 | alfred.egg-info/ 5 | dist/ 6 | build/ 7 | .vscode/ 8 | vendor/ 9 | 10 | *.pyc 11 | a.py 12 | __pycache__/vendor/ 13 | upload_tpi.sh 14 | __pycache__/ 15 | *.egg-info/ 16 | checkpoints/ 17 | results/ 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. 
For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. 
If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. 
Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | alfred Copyright (C) 2021 Lucas Jin 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ImageTokenizer: Unified Image and Video Tokenization
2 |
3 | Welcome to the **ImageTokenizer** repository! 🎉 This Python package is designed to simplify the process of image and video tokenization, a crucial step for various applications such as image/video generation and understanding. We provide a variety of popular tokenizers with a simple and unified interface, making your coding experience seamless and efficient. 🛠️
4 |
5 | > ⚠️💡 Note that this project is still in its early stages of development. We welcome any contributions from the community to help us improve and expand the package. Please make sure to **star** and **fork** the repository if you find it useful. We are also building some awesome applications on top of `imagetokenizer`, such as image/video generation and understanding. Stay tuned!
6 |
7 |
8 | ## Features
9 |
10 | - **Unified Interface**: A consistent API for all supported tokenizers.
11 | - **Extensive Support**: Covers a range of popular image and video tokenizers.
12 | - **Easy Integration**: Quick setup and integration with your projects.
13 | - **Different Tokenizers**: Supports Magvit2, OmniTokenizer, TiTok, etc.
14 |
15 |
16 | ## Updates
17 |
18 | - 🔥**2024.06.22**: TiTok is now supported! **It produces the fewest tokens of any tokenizer supported so far**;
19 | - 🔥**2024.06.22**: OmniTokenizer is now supported!
20 |
21 |
22 | ## Supported Tokenizers
23 |
24 | Here's a list of the currently supported image tokenizers:
25 |
26 | - **OmniTokenizer**: A versatile tokenizer capable of handling both images and videos.
27 | - **OpenMagvit2**: An open-source version of Magvit2, renowned for its excellent results.
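- **TiTok**: A compact 1D tokenizer that represents an image with a very small number of tokens (the fewest of the tokenizers listed here).

All of the tokenizers above are meant to expose the same `encode`/`decode` interface (see `Magvit2Tokenizer` in `imagetokenizer/model/magvit2.py`). The sketch below illustrates that flow on a dummy image tensor; it is only a minimal example under stated assumptions: the import path follows `imagetokenizer/model/__init__.py`, the 128×128 input matches the class defaults, and the random input tensor and omission of a pretrained checkpoint are for illustration only (outputs are not meaningful without real weights).

```python
import torch
from imagetokenizer.model import Magvit2Tokenizer

# Build the tokenizer with its default configuration (random weights).
# Pass ckpt_path="<downloaded checkpoint>" (see download_tokenizer_weights.sh)
# to load pretrained weights instead.
tokenizer = Magvit2Tokenizer(resolution=128).eval()

# A dummy batch standing in for a preprocessed image: (batch, channels, height, width).
image = torch.randn(1, 3, 128, 128)

with torch.no_grad():
    # encode() returns the quantized latents, the encoder features, and the codebook indices.
    quants, embedding, codebook_indices = tokenizer.encode(image)
    # decode() reconstructs an image tensor from the quantized latents.
    reconstruction = tokenizer.decode(quants)

print(quants.shape, reconstruction.shape)
```

The same pattern is intended to carry over to the other tokenizers by swapping the class name, which is the point of the unified interface.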
28 |
29 | ## Getting Started
30 |
31 | To get started with ImageTokenizer, follow these simple steps:
32 |
33 | ### Installation
34 |
35 | You can install ImageTokenizer using pip:
36 |
37 | ```bash
38 | pip install imagetokenizer
39 | ```
40 |
41 | ### Usage
42 |
43 | Here's a quick example of how to use Magvit2Tokenizer:
44 |
45 | ```python
46 | from imagetokenizer import Magvit2Tokenizer
47 |
48 | # Initialize the tokenizer
49 | image_tokenizer = Magvit2Tokenizer()
50 |
51 | # Tokenize an image
52 | quants, embedding, codebook_indices = image_tokenizer.encode("path_to_your_image.jpg")
53 |
54 | # Print the token indices
55 | print(codebook_indices)
56 |
57 | image = image_tokenizer.decode(quants)  # reconstruct the image from the quantized latents
58 | ```
59 |
60 | ### Documentation
61 |
62 | For more detailed information and examples, please refer to our [official documentation](#).
63 |
64 | ## Contributing
65 |
66 | We welcome contributions! If you have an idea for a new tokenizer or want to improve existing ones, feel free to submit a pull request or create an issue. 🔧
67 |
68 | ## License
69 |
70 | ImageTokenizer is open-source and available under the [GPL-3.0 License](LICENSE).
71 |
72 | ## Community
73 |
74 | - Join our [Slack Channel](#) to discuss and collaborate.
75 | - Follow us on [Twitter](#) for updates and news.
76 |
77 | ## Acknowledgements
78 |
79 | We would like to thank all the contributors and the community for their support and feedback. 🙏
80 |
--------------------------------------------------------------------------------
/download_tokenizer_weights.sh:
--------------------------------------------------------------------------------
1 | export HF_ENDPOINT=https://hf-mirror.com
2 |
3 | mkdir -p checkpoints
4 | cd checkpoints
5 |
6 | # download tokenizer weights
7 | huggingface-cli download TencentARC/Open-MAGVIT2 --local-dir magvit2
8 | huggingface-cli download fun-research/TiTok --local-dir titok
9 |
10 | wget $HF_ENDPOINT/Daniel0724/OmniTokenizer/resolve/main/imagenet_sthv2.ckpt -O omni_imagenet_sthv2.ckpt
--------------------------------------------------------------------------------
/images/.gitignore:
--------------------------------------------------------------------------------
1 | *_constructed*.png
2 |
--------------------------------------------------------------------------------
/images/a.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/ImageTokenizer/c9b0193e2d1e21988ed4bbc6fe96b98298b050ad/images/a.jpg
--------------------------------------------------------------------------------
/images/b.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/ImageTokenizer/c9b0193e2d1e21988ed4bbc6fe96b98298b050ad/images/b.jpg
--------------------------------------------------------------------------------
/images/c.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/ImageTokenizer/c9b0193e2d1e21988ed4bbc6fe96b98298b050ad/images/c.jpg
--------------------------------------------------------------------------------
/imagetokenizer/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .magvit2 import Magvit2Tokenizer
2 | from .omnitokenizer import OmniTokenizer
3 | from .titok import TiTok
4 |
--------------------------------------------------------------------------------
/imagetokenizer/model/magvit2.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | """ 5 | for inference only 6 | """ 7 | from collections import OrderedDict 8 | from torch import nn 9 | import torch 10 | from ..quantize.lookup_free_quantize import LFQ 11 | 12 | 13 | class Magvit2Tokenizer(nn.Module): 14 | 15 | def __init__( 16 | self, 17 | resolution=128, 18 | num_down=4, 19 | ### Quantize Related 20 | n_embed=262144, 21 | embed_dim=18, 22 | sample_minimization_weight=1.0, 23 | batch_maximization_weight=1.0, 24 | ckpt_path=None, 25 | ignore_keys=[], 26 | use_ema=False, 27 | token_factorization=False, 28 | ): 29 | super().__init__() 30 | ddconfig = { 31 | "double_z": False, 32 | "z_channels": 18, 33 | "resolution": resolution, 34 | "in_channels": 3, 35 | "out_ch": 3, 36 | "ch": 128, 37 | "ch_mult": [1, 2, 2, 4], # num_down = len(ch_mult)-1 38 | "num_res_blocks": 2, 39 | } 40 | if num_down == 4: 41 | ddconfig["ch_mult"] = [1, 1, 2, 2, 4] # num_down = len(ch_mult)-1 42 | elif num_down == 3: 43 | ddconfig["ch_mult"] = [1, 2, 2, 4] # num_down = len(ch_mult)-1 44 | if ckpt_path and "256" in ckpt_path: 45 | ddconfig["resolution"] = 256 46 | self.use_ema = use_ema 47 | self.encoder = Encoder(**ddconfig) 48 | self.decoder = Decoder(**ddconfig) 49 | self.quantize = LFQ( 50 | dim=embed_dim, 51 | codebook_size=n_embed, 52 | sample_minimization_weight=sample_minimization_weight, 53 | batch_maximization_weight=batch_maximization_weight, 54 | token_factorization=token_factorization, 55 | ) 56 | 57 | if ckpt_path is not None: 58 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, stage=None) 59 | 60 | def init_from_ckpt(self, path, ignore_keys=list(), stage=None): 61 | sd = torch.load(path, map_location="cpu")["state_dict"] 62 | ema_mapping = {} 63 | new_params = OrderedDict() 64 | if stage == "transformer": ### directly use ema encoder and decoder parameter 65 | if self.use_ema: 66 | for k, v in sd.items(): 67 | if "encoder" in k: 68 | if "model_ema" in k: 69 | k = k.replace( 70 | "model_ema.", "" 71 | ) # load EMA Encoder or Decoder 72 | new_k = ema_mapping[k] 73 | new_params[new_k] = v 74 | s_name = k.replace(".", "") 75 | ema_mapping.update({s_name: k}) 76 | continue 77 | if "decoder" in k: 78 | if "model_ema" in k: 79 | k = k.replace( 80 | "model_ema.", "" 81 | ) # load EMA Encoder or Decoder 82 | new_k = ema_mapping[k] 83 | new_params[new_k] = v 84 | s_name = k.replace(".", "") 85 | ema_mapping.update({s_name: k}) 86 | continue 87 | else: # also only load the Generator 88 | for k, v in sd.items(): 89 | if "encoder" in k: 90 | new_params[k] = v 91 | elif "decoder" in k: 92 | new_params[k] = v 93 | missing_keys, unexpected_keys = self.load_state_dict( 94 | new_params, strict=False 95 | ) 96 | else: ## simple resume 97 | missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) 98 | print(f"Restored from {path}") 99 | 100 | def encode(self, x, return_embed_fea=True): 101 | h = self.encoder(x) 102 | # print(f'h {h} {h.shape}') 103 | (quant, emb_loss, info) = self.quantize( 104 | h, return_loss_breakdown=False, return_loss=False 105 | ) 106 | # print(info) 107 | ### using token factorization the info is a tuple (each for embedding) 108 | if return_embed_fea: 109 | return quant, h, info 110 | else: 111 | return quant, emb_loss, info 112 | 113 | def decode(self, quant): 114 | dec = self.decoder(quant) 115 | return dec 116 | 117 | def forward(self, input): 118 | ( 119 | quant, 120 | diff, 121 | _, 122 | ) = self.encode(input) 123 | # 
print(quant) 124 | # print(f'quant: {quant.shape}, diff: {diff.shape}') 125 | dec = self.decode(quant) 126 | return dec 127 | 128 | 129 | def swish(x): 130 | # swish 131 | return x * torch.sigmoid(x) 132 | 133 | 134 | class ResBlock(nn.Module): 135 | def __init__(self, in_filters, out_filters, use_conv_shortcut=False) -> None: 136 | super().__init__() 137 | 138 | self.in_filters = in_filters 139 | self.out_filters = out_filters 140 | self.use_conv_shortcut = use_conv_shortcut 141 | 142 | self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6) 143 | self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6) 144 | 145 | self.conv1 = nn.Conv2d( 146 | in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False 147 | ) 148 | self.conv2 = nn.Conv2d( 149 | out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False 150 | ) 151 | 152 | if in_filters != out_filters: 153 | if self.use_conv_shortcut: 154 | self.conv_shortcut = nn.Conv2d( 155 | in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False 156 | ) 157 | else: 158 | self.nin_shortcut = nn.Conv2d( 159 | in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False 160 | ) 161 | 162 | def forward(self, x, **kwargs): 163 | residual = x 164 | 165 | x = self.norm1(x) 166 | x = swish(x) 167 | x = self.conv1(x) 168 | x = self.norm2(x) 169 | x = swish(x) 170 | x = self.conv2(x) 171 | if self.in_filters != self.out_filters: 172 | if self.use_conv_shortcut: 173 | residual = self.conv_shortcut(residual) 174 | else: 175 | residual = self.nin_shortcut(residual) 176 | 177 | return x + residual 178 | 179 | 180 | class Encoder(nn.Module): 181 | def __init__( 182 | self, 183 | *, 184 | ch, 185 | out_ch, 186 | in_channels, 187 | num_res_blocks, 188 | z_channels, 189 | ch_mult=(1, 2, 2, 4), 190 | resolution, 191 | double_z=False, 192 | ): 193 | super().__init__() 194 | 195 | self.in_channels = in_channels 196 | self.z_channels = z_channels 197 | self.resolution = resolution 198 | 199 | self.num_res_blocks = num_res_blocks 200 | self.num_blocks = len(ch_mult) 201 | 202 | self.conv_in = nn.Conv2d( 203 | in_channels, ch, kernel_size=(3, 3), padding=1, bias=False 204 | ) 205 | 206 | ## construct the model 207 | self.down = nn.ModuleList() 208 | 209 | in_ch_mult = (1,) + tuple(ch_mult) 210 | for i_level in range(self.num_blocks): 211 | block = nn.ModuleList() 212 | block_in = ch * in_ch_mult[i_level] # [1, 1, 2, 2, 4] 213 | block_out = ch * ch_mult[i_level] # [1, 2, 2, 4] 214 | for _ in range(self.num_res_blocks): 215 | block.append(ResBlock(block_in, block_out)) 216 | block_in = block_out 217 | 218 | down = nn.Module() 219 | down.block = block 220 | if i_level < self.num_blocks - 1: 221 | down.downsample = nn.Conv2d( 222 | block_out, block_out, kernel_size=(3, 3), stride=(2, 2), padding=1 223 | ) 224 | 225 | self.down.append(down) 226 | 227 | ### mid 228 | self.mid_block = nn.ModuleList() 229 | for res_idx in range(self.num_res_blocks): 230 | self.mid_block.append(ResBlock(block_in, block_in)) 231 | 232 | ### end 233 | self.norm_out = nn.GroupNorm(32, block_out, eps=1e-6) 234 | self.conv_out = nn.Conv2d(block_out, z_channels, kernel_size=(1, 1)) 235 | 236 | def forward(self, x): 237 | 238 | ## down 239 | x = self.conv_in(x) 240 | for i_level in range(self.num_blocks): 241 | for i_block in range(self.num_res_blocks): 242 | x = self.down[i_level].block[i_block](x) 243 | 244 | if i_level < self.num_blocks - 1: 245 | x = self.down[i_level].downsample(x) 246 | 247 | ## mid 248 | for res in range(self.num_res_blocks): 249 | x = self.mid_block[res](x) 
250 | 251 | x = self.norm_out(x) 252 | x = swish(x) 253 | x = self.conv_out(x) 254 | 255 | return x 256 | 257 | 258 | class Decoder(nn.Module): 259 | def __init__( 260 | self, 261 | *, 262 | ch, 263 | out_ch, 264 | in_channels, 265 | num_res_blocks, 266 | z_channels, 267 | ch_mult=(1, 2, 2, 4), 268 | resolution, 269 | double_z=False, 270 | ) -> None: 271 | super().__init__() 272 | 273 | self.ch = ch 274 | self.num_blocks = len(ch_mult) 275 | self.num_res_blocks = num_res_blocks 276 | self.resolution = resolution 277 | self.in_channels = in_channels 278 | 279 | block_in = ch * ch_mult[self.num_blocks - 1] 280 | 281 | self.conv_in = nn.Conv2d( 282 | z_channels, block_in, kernel_size=(3, 3), padding=1, bias=True 283 | ) 284 | 285 | self.mid_block = nn.ModuleList() 286 | for res_idx in range(self.num_res_blocks): 287 | self.mid_block.append(ResBlock(block_in, block_in)) 288 | 289 | self.up = nn.ModuleList() 290 | 291 | for i_level in reversed(range(self.num_blocks)): 292 | block = nn.ModuleList() 293 | block_out = ch * ch_mult[i_level] 294 | for i_block in range(self.num_res_blocks): 295 | block.append(ResBlock(block_in, block_out)) 296 | block_in = block_out 297 | 298 | up = nn.Module() 299 | up.block = block 300 | if i_level > 0: 301 | up.upsample = Upsampler(block_in) 302 | self.up.insert(0, up) 303 | 304 | self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6) 305 | 306 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=(3, 3), padding=1) 307 | 308 | def forward(self, z): 309 | 310 | z = self.conv_in(z) 311 | 312 | ## mid 313 | for res in range(self.num_res_blocks): 314 | z = self.mid_block[res](z) 315 | 316 | ## upsample 317 | for i_level in reversed(range(self.num_blocks)): 318 | for i_block in range(self.num_res_blocks): 319 | z = self.up[i_level].block[i_block](z) 320 | 321 | if i_level > 0: 322 | z = self.up[i_level].upsample(z) 323 | 324 | z = self.norm_out(z) 325 | z = swish(z) 326 | z = self.conv_out(z) 327 | 328 | return z 329 | 330 | 331 | def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor: 332 | """Depth-to-Space DCR mode (depth-column-row) core implementation. 333 | 334 | Args: 335 | x (torch.Tensor): input tensor. The channels-first (*CHW) layout is supported. 
336 | block_size (int): block side size 337 | """ 338 | # check inputs 339 | if x.dim() < 3: 340 | raise ValueError( 341 | f"Expecting a channels-first (*CHW) tensor of at least 3 dimensions" 342 | ) 343 | c, h, w = x.shape[-3:] 344 | 345 | s = block_size**2 346 | if c % s != 0: 347 | raise ValueError( 348 | f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels" 349 | ) 350 | 351 | outer_dims = x.shape[:-3] 352 | 353 | # splitting two additional dimensions from the channel dimension 354 | x = x.view(-1, block_size, block_size, c // s, h, w) 355 | 356 | # putting the two new dimensions along H and W 357 | x = x.permute(0, 3, 4, 1, 5, 2) 358 | 359 | # merging the two new dimensions with H and W 360 | x = x.contiguous().view(*outer_dims, c // s, h * block_size, w * block_size) 361 | 362 | return x 363 | 364 | 365 | class Upsampler(nn.Module): 366 | def __init__(self, dim, dim_out=None): 367 | super().__init__() 368 | dim_out = dim * 4 369 | self.conv1 = nn.Conv2d(dim, dim_out, (3, 3), padding=1) 370 | self.depth2space = depth_to_space 371 | 372 | def forward(self, x): 373 | """ 374 | input_image: [B C H W] 375 | """ 376 | out = self.conv1(x) 377 | out = self.depth2space(out, block_size=2) 378 | return out 379 | 380 | 381 | if __name__ == "__main__": 382 | x = torch.randn(size=(2, 3, 128, 128)) 383 | encoder = Encoder( 384 | ch=128, in_channels=3, num_res_blocks=2, z_channels=18, out_ch=3, resolution=128 385 | ) 386 | decoder = Decoder( 387 | out_ch=3, z_channels=18, num_res_blocks=2, ch=128, in_channels=3, resolution=128 388 | ) 389 | z = encoder(x) 390 | out = decoder(z) 391 | -------------------------------------------------------------------------------- /imagetokenizer/model/modules/maskgit_vqgan.py: -------------------------------------------------------------------------------- 1 | """This file contains code for MaskGIT-VQGAN. 2 | 3 | This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). 4 | All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 5 | 6 | Reference: 7 | https://github.com/huggingface/open-muse/blob/main/muse/modeling_maskgit_vqgan.py 8 | """ 9 | 10 | # Copyright 2023 Google LLC and The HuggingFace Inc. team. 11 | # 12 | # Licensed under the Apache License, Version 2.0 (the "License"); 13 | # you may not use this file except in compliance with the License. 14 | # You may obtain a copy of the License at 15 | # 16 | # http://www.apache.org/licenses/LICENSE-2.0 17 | # 18 | # Unless required by applicable law or agreed to in writing, software 19 | # distributed under the License is distributed on an "AS IS" BASIS, 20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | # See the License for the specific language governing permissions and 22 | # limitations under the License. 23 | 24 | r"""MaskGIT Tokenizer based on VQGAN. 25 | 26 | This tokenizer is a reimplementation of VQGAN [https://arxiv.org/abs/2012.09841] 27 | with several modifications. The non-local layers are removed from VQGAN for 28 | faster speed. 
29 | """ 30 | 31 | import math 32 | 33 | import torch 34 | import torch.nn.functional as F 35 | from torch import nn 36 | 37 | 38 | # Conv2D with same padding 39 | class Conv2dSame(nn.Conv2d): 40 | def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: 41 | return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) 42 | 43 | def forward(self, x: torch.Tensor) -> torch.Tensor: 44 | ih, iw = x.size()[-2:] 45 | 46 | pad_h = self.calc_same_pad( 47 | i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0] 48 | ) 49 | pad_w = self.calc_same_pad( 50 | i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1] 51 | ) 52 | 53 | if pad_h > 0 or pad_w > 0: 54 | x = F.pad( 55 | x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] 56 | ) 57 | return super().forward(x) 58 | 59 | 60 | class ResnetBlock(nn.Module): 61 | def __init__( 62 | self, 63 | in_channels: int, 64 | out_channels: int = None, 65 | dropout_prob: float = 0.0, 66 | ): 67 | super().__init__() 68 | 69 | self.in_channels = in_channels 70 | self.out_channels = out_channels 71 | self.out_channels_ = ( 72 | self.in_channels if self.out_channels is None else self.out_channels 73 | ) 74 | 75 | self.norm1 = nn.GroupNorm( 76 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 77 | ) 78 | self.conv1 = Conv2dSame( 79 | self.in_channels, self.out_channels_, kernel_size=3, bias=False 80 | ) 81 | 82 | self.norm2 = nn.GroupNorm( 83 | num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True 84 | ) 85 | self.dropout = nn.Dropout(dropout_prob) 86 | self.conv2 = Conv2dSame( 87 | self.out_channels_, self.out_channels_, kernel_size=3, bias=False 88 | ) 89 | 90 | if self.in_channels != self.out_channels_: 91 | self.nin_shortcut = Conv2dSame( 92 | self.out_channels_, self.out_channels_, kernel_size=1, bias=False 93 | ) 94 | 95 | def forward(self, hidden_states): 96 | residual = hidden_states 97 | hidden_states = self.norm1(hidden_states) 98 | hidden_states = F.silu(hidden_states) 99 | hidden_states = self.conv1(hidden_states) 100 | 101 | hidden_states = self.norm2(hidden_states) 102 | hidden_states = F.silu(hidden_states) 103 | hidden_states = self.dropout(hidden_states) 104 | hidden_states = self.conv2(hidden_states) 105 | 106 | if self.in_channels != self.out_channels_: 107 | residual = self.nin_shortcut(hidden_states) 108 | 109 | return hidden_states + residual 110 | 111 | 112 | class DownsamplingBlock(nn.Module): 113 | def __init__(self, config, block_idx: int): 114 | super().__init__() 115 | 116 | self.config = config 117 | self.block_idx = block_idx 118 | 119 | in_channel_mult = (1,) + tuple(self.config.channel_mult) 120 | block_in = self.config.hidden_channels * in_channel_mult[self.block_idx] 121 | block_out = ( 122 | self.config.hidden_channels * self.config.channel_mult[self.block_idx] 123 | ) 124 | 125 | res_blocks = nn.ModuleList() 126 | for _ in range(self.config.num_res_blocks): 127 | res_blocks.append( 128 | ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout) 129 | ) 130 | block_in = block_out 131 | self.block = res_blocks 132 | 133 | self.downsample = self.block_idx != self.config.num_resolutions - 1 134 | 135 | def forward(self, hidden_states): 136 | for res_block in self.block: 137 | hidden_states = res_block(hidden_states) 138 | 139 | if self.downsample: 140 | hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2) 141 | 142 | return hidden_states 143 | 144 | 145 | class UpsamplingBlock(nn.Module): 146 | def __init__(self, config, block_idx: 
int): 147 | super().__init__() 148 | 149 | self.config = config 150 | self.block_idx = block_idx 151 | 152 | if self.block_idx == self.config.num_resolutions - 1: 153 | block_in = self.config.hidden_channels * self.config.channel_mult[-1] 154 | else: 155 | block_in = ( 156 | self.config.hidden_channels 157 | * self.config.channel_mult[self.block_idx + 1] 158 | ) 159 | 160 | block_out = ( 161 | self.config.hidden_channels * self.config.channel_mult[self.block_idx] 162 | ) 163 | 164 | res_blocks = [] 165 | for _ in range(self.config.num_res_blocks): 166 | res_blocks.append( 167 | ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout) 168 | ) 169 | block_in = block_out 170 | self.block = nn.ModuleList(res_blocks) 171 | 172 | self.add_upsample = self.block_idx != 0 173 | if self.add_upsample: 174 | self.upsample_conv = Conv2dSame(block_out, block_out, kernel_size=3) 175 | 176 | def forward(self, hidden_states): 177 | for res_block in self.block: 178 | hidden_states = res_block(hidden_states) 179 | 180 | if self.add_upsample: 181 | hidden_states = F.interpolate( 182 | hidden_states, scale_factor=2.0, mode="nearest" 183 | ) 184 | hidden_states = self.upsample_conv(hidden_states) 185 | 186 | return hidden_states 187 | 188 | 189 | class Encoder(nn.Module): 190 | def __init__(self, config): 191 | super().__init__() 192 | self.config = config 193 | # downsampling 194 | self.conv_in = Conv2dSame( 195 | self.config.num_channels, 196 | self.config.hidden_channels, 197 | kernel_size=3, 198 | bias=False, 199 | ) 200 | 201 | downsample_blocks = [] 202 | for i_level in range(self.config.num_resolutions): 203 | downsample_blocks.append(DownsamplingBlock(self.config, block_idx=i_level)) 204 | self.down = nn.ModuleList(downsample_blocks) 205 | 206 | # middle 207 | mid_channels = self.config.hidden_channels * self.config.channel_mult[-1] 208 | res_blocks = nn.ModuleList() 209 | for _ in range(self.config.num_res_blocks): 210 | res_blocks.append( 211 | ResnetBlock( 212 | mid_channels, mid_channels, dropout_prob=self.config.dropout 213 | ) 214 | ) 215 | self.mid = res_blocks 216 | 217 | # end 218 | self.norm_out = nn.GroupNorm( 219 | num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True 220 | ) 221 | self.conv_out = Conv2dSame(mid_channels, self.config.z_channels, kernel_size=1) 222 | 223 | def forward(self, pixel_values): 224 | # downsampling 225 | hidden_states = self.conv_in(pixel_values) 226 | for block in self.down: 227 | hidden_states = block(hidden_states) 228 | 229 | # middle 230 | for block in self.mid: 231 | hidden_states = block(hidden_states) 232 | 233 | # end 234 | hidden_states = self.norm_out(hidden_states) 235 | hidden_states = F.silu(hidden_states) 236 | hidden_states = self.conv_out(hidden_states) 237 | return hidden_states 238 | 239 | 240 | class Decoder(nn.Module): 241 | def __init__(self, config): 242 | super().__init__() 243 | 244 | self.config = config 245 | 246 | # compute in_channel_mult, block_in and curr_res at lowest res 247 | block_in = ( 248 | self.config.hidden_channels 249 | * self.config.channel_mult[self.config.num_resolutions - 1] 250 | ) 251 | curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1) 252 | self.z_shape = (1, self.config.z_channels, curr_res, curr_res) 253 | 254 | # z to block_in 255 | self.conv_in = Conv2dSame(self.config.z_channels, block_in, kernel_size=3) 256 | 257 | # middle 258 | res_blocks = nn.ModuleList() 259 | for _ in range(self.config.num_res_blocks): 260 | res_blocks.append( 261 | ResnetBlock(block_in, 
block_in, dropout_prob=self.config.dropout) 262 | ) 263 | self.mid = res_blocks 264 | 265 | # upsampling 266 | upsample_blocks = [] 267 | for i_level in reversed(range(self.config.num_resolutions)): 268 | upsample_blocks.append(UpsamplingBlock(self.config, block_idx=i_level)) 269 | self.up = nn.ModuleList( 270 | list(reversed(upsample_blocks)) 271 | ) # reverse to get consistent order 272 | 273 | # end 274 | block_out = self.config.hidden_channels * self.config.channel_mult[0] 275 | self.norm_out = nn.GroupNorm( 276 | num_groups=32, num_channels=block_out, eps=1e-6, affine=True 277 | ) 278 | self.conv_out = Conv2dSame(block_out, self.config.num_channels, kernel_size=3) 279 | 280 | def forward(self, hidden_states): 281 | # z to block_in 282 | hidden_states = self.conv_in(hidden_states) 283 | 284 | # middle 285 | for block in self.mid: 286 | hidden_states = block(hidden_states) 287 | 288 | # upsampling 289 | for block in reversed(self.up): 290 | hidden_states = block(hidden_states) 291 | 292 | # end 293 | hidden_states = self.norm_out(hidden_states) 294 | hidden_states = F.silu(hidden_states) 295 | hidden_states = self.conv_out(hidden_states) 296 | 297 | return hidden_states 298 | 299 | 300 | class VectorQuantizer(nn.Module): 301 | """ 302 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py 303 | Discretization bottleneck part of the VQ-VAE. 304 | """ 305 | 306 | def __init__(self, num_embeddings, embedding_dim, commitment_cost): 307 | r""" 308 | Args: 309 | num_embeddings: number of vectors in the quantized space. 310 | embedding_dim: dimensionality of the tensors in the quantized space. 311 | Inputs to the modules must be in this format as well. 312 | commitment_cost: scalar which controls the weighting of the loss terms 313 | (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta). 314 | """ 315 | super().__init__() 316 | 317 | self.num_embeddings = num_embeddings 318 | self.embedding_dim = embedding_dim 319 | self.commitment_cost = commitment_cost 320 | 321 | self.embedding = nn.Embedding(num_embeddings, embedding_dim) 322 | self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings) 323 | 324 | def forward(self, hidden_states, return_loss=False): 325 | """ 326 | Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the 327 | closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width) 328 | quantization pipeline: 329 | 1. get encoder input (B,C,H,W) 330 | 2. 
flatten input to (B*H*W,C) 331 | """ 332 | # reshape z -> (batch, height, width, channel) and flatten 333 | hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() 334 | 335 | distances = self.compute_distances(hidden_states) 336 | min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1) 337 | min_encodings = torch.zeros( 338 | min_encoding_indices.shape[0], self.num_embeddings 339 | ).to(hidden_states) 340 | min_encodings.scatter_(1, min_encoding_indices, 1) 341 | 342 | # get quantized latent vectors 343 | z_q = torch.matmul(min_encodings, self.embedding.weight).view( 344 | hidden_states.shape 345 | ) 346 | 347 | # reshape to (batch, num_tokens) 348 | min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1) 349 | 350 | # compute loss for embedding 351 | loss = None 352 | if return_loss: 353 | loss = torch.mean( 354 | (z_q.detach() - hidden_states) ** 2 355 | ) + self.commitment_cost * torch.mean((z_q - hidden_states.detach()) ** 2) 356 | # preserve gradients 357 | z_q = hidden_states + (z_q - hidden_states).detach() 358 | 359 | # reshape back to match original input shape 360 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 361 | 362 | return z_q, min_encoding_indices, loss 363 | 364 | def compute_distances(self, hidden_states): 365 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 366 | hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim)) 367 | emb_weights = self.embedding.weight.t() 368 | 369 | inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True) 370 | codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True) 371 | distances = torch.addmm( 372 | inputs_norm_sq + codebook_t_norm_sq, 373 | hidden_states_flattended, 374 | emb_weights, 375 | alpha=-2.0, 376 | ) 377 | return distances 378 | 379 | def get_codebook_entry(self, indices): 380 | # indices are expected to be of shape (batch, num_tokens) 381 | # get quantized latent vectors 382 | if len(indices.shape) == 2: 383 | batch, num_tokens = indices.shape 384 | z_q = self.embedding(indices) 385 | z_q = z_q.reshape( 386 | batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1 387 | ).permute(0, 3, 1, 2) 388 | elif len(indices.shape) == 3: 389 | batch, height, width = indices.shape 390 | indices = indices.view(batch, -1) 391 | z_q = self.embedding(indices) 392 | z_q = z_q.reshape(batch, height, width, -1).permute(0, 3, 1, 2) 393 | else: 394 | print(indices.shape) 395 | raise NotImplementedError 396 | return z_q 397 | 398 | # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372 399 | def get_soft_code(self, hidden_states, temp=1.0, stochastic=False): 400 | hidden_states = hidden_states.permute( 401 | 0, 2, 3, 1 402 | ).contiguous() # (batch, height, width, channel) 403 | distances = self.compute_distances( 404 | hidden_states 405 | ) # (batch * height * width, num_embeddings) 406 | 407 | soft_code = F.softmax( 408 | -distances / temp, dim=-1 409 | ) # (batch * height * width, num_embeddings) 410 | if stochastic: 411 | code = torch.multinomial(soft_code, 1) # (batch * height * width, 1) 412 | else: 413 | code = distances.argmin(dim=-1) # (batch * height * width) 414 | 415 | code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width) 416 | batch, num_tokens = code.shape 417 | soft_code = soft_code.reshape( 418 | batch, num_tokens, -1 419 | ) # (batch, height * width, num_embeddings) 420 | return soft_code, code 421 | 422 | def get_code(self, 
hidden_states): 423 | # reshape z -> (batch, height, width, channel) 424 | hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() 425 | distances = self.compute_distances(hidden_states) 426 | indices = torch.argmin(distances, axis=1).unsqueeze(1) 427 | indices = indices.reshape(hidden_states.shape[0], -1) 428 | return indices 429 | -------------------------------------------------------------------------------- /imagetokenizer/model/modules/omni_codebook.py: -------------------------------------------------------------------------------- 1 | from enum import unique 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.distributed as dist 8 | 9 | from imagetokenizer.utils.omnitokenizer_utils import shift_dim 10 | 11 | 12 | class Codebook(nn.Module): 13 | def __init__( 14 | self, 15 | n_codes, 16 | embedding_dim, 17 | no_random_restart=False, 18 | restart_thres=1.0, 19 | usage_sigma=0.99, 20 | fp32_quant=False, 21 | ): 22 | super().__init__() 23 | self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) 24 | self.register_buffer("N", torch.zeros(n_codes)) 25 | self.register_buffer("z_avg", self.embeddings.data.clone()) 26 | self.register_buffer("codebook_usage", torch.zeros(n_codes)) 27 | 28 | self.call_cnt = 0 29 | self.usage_sigma = usage_sigma 30 | 31 | self.n_codes = n_codes 32 | self.embedding_dim = embedding_dim 33 | self._need_init = True 34 | self.no_random_restart = no_random_restart 35 | self.restart_thres = restart_thres 36 | 37 | self.fp32_quant = fp32_quant 38 | 39 | def _tile(self, x): 40 | d, ew = x.shape 41 | if d < self.n_codes: 42 | n_repeats = (self.n_codes + d - 1) // d 43 | std = 0.01 / np.sqrt(ew) 44 | x = x.repeat(n_repeats, 1) 45 | x = x + torch.randn_like(x) * std 46 | return x 47 | 48 | def _init_embeddings(self, z): 49 | # z: [b, c, t, h, w] 50 | self._need_init = False 51 | flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) 52 | y = self._tile(flat_inputs) 53 | 54 | d = y.shape[0] 55 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] 56 | if dist.is_initialized(): 57 | dist.broadcast(_k_rand, 0) 58 | self.embeddings.data.copy_(_k_rand) 59 | self.z_avg.data.copy_(_k_rand) 60 | self.N.data.copy_(torch.ones(self.n_codes)) 61 | 62 | def calculate_batch_codebook_usage_percentage(self, batch_encoding_indices): 63 | # Flatten the batch of encoding indices into a single 1D tensor 64 | all_indices = batch_encoding_indices.flatten() 65 | 66 | # Obtain the total number of encoding indices in the batch to calculate percentages 67 | total_indices = all_indices.numel() 68 | 69 | # Initialize a tensor to store the percentage usage of each code 70 | codebook_usage_percentage = torch.zeros(self.n_codes, device=all_indices.device) 71 | 72 | # Count the number of occurrences of each index and get their frequency as percentages 73 | unique_indices, counts = torch.unique(all_indices, return_counts=True) 74 | # Calculate the percentage 75 | percentages = counts.float() / total_indices 76 | 77 | # Populate the corresponding percentages in the codebook_usage_percentage tensor 78 | codebook_usage_percentage[unique_indices.long()] = percentages 79 | 80 | return codebook_usage_percentage 81 | 82 | def forward(self, z): 83 | # z: [b, c, t, h, w] 84 | if self._need_init and self.training: 85 | self._init_embeddings(z) 86 | flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c] 87 | 88 | distances = ( 89 | (flat_inputs**2).sum(dim=1, keepdim=True) 90 | - 2 * flat_inputs @ 
self.embeddings.t() 91 | + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) 92 | ) # [bthw, c] 93 | 94 | encoding_indices = torch.argmin(distances, dim=1) 95 | encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as( 96 | flat_inputs 97 | ) # [bthw, ncode] 98 | encoding_indices = encoding_indices.view( 99 | z.shape[0], *z.shape[2:] 100 | ) # [b, t, h, w, ncode] 101 | 102 | embeddings = F.embedding(encoding_indices, self.embeddings) # [b, t, h, w, c] 103 | embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w] 104 | 105 | commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) 106 | 107 | # EMA codebook update 108 | if self.training: 109 | n_total = encode_onehot.sum(dim=0) 110 | encode_sum = flat_inputs.t() @ encode_onehot 111 | if dist.is_initialized(): 112 | dist.all_reduce(n_total) 113 | dist.all_reduce(encode_sum) 114 | 115 | self.N.data.mul_(0.99).add_(n_total, alpha=0.01) 116 | self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) 117 | 118 | n = self.N.sum() 119 | weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n 120 | encode_normalized = self.z_avg / weights.unsqueeze(1) 121 | self.embeddings.data.copy_(encode_normalized) 122 | 123 | y = self._tile(flat_inputs) 124 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] 125 | if dist.is_initialized(): 126 | dist.broadcast(_k_rand, 0) 127 | 128 | if not self.no_random_restart: 129 | usage = (self.N.view(self.n_codes, 1) >= self.restart_thres).float() 130 | self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) 131 | 132 | embeddings_st = (embeddings - z).detach() + z 133 | 134 | avg_probs = torch.mean(encode_onehot, dim=0) 135 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 136 | 137 | try: 138 | usage = self.calculate_batch_codebook_usage_percentage(encoding_indices) 139 | except: 140 | usage = torch.zeros(self.n_codes, device=encoding_indices.device) 141 | 142 | # print(usage.shape, torch.zeros(self.n_codes).shape) 143 | 144 | if self.call_cnt == 0: 145 | self.codebook_usage.data = usage 146 | else: 147 | self.codebook_usage.data = ( 148 | self.usage_sigma * self.codebook_usage.data 149 | + (1 - self.usage_sigma) * usage 150 | ) 151 | 152 | self.call_cnt += 1 153 | # avg_distribution = self.codebook_usage.data.sum() / self.n_codes 154 | avg_usage = (self.codebook_usage.data > (1 / self.n_codes)).sum() / self.n_codes 155 | 156 | return dict( 157 | embeddings=embeddings_st, 158 | encodings=encoding_indices, 159 | commitment_loss=commitment_loss, 160 | perplexity=perplexity, 161 | avg_usage=avg_usage, 162 | batch_usage=usage, 163 | ) 164 | 165 | def dictionary_lookup(self, encodings): 166 | embeddings = F.embedding(encodings, self.embeddings) 167 | return embeddings 168 | -------------------------------------------------------------------------------- /imagetokenizer/model/modules/omni_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from beartype import beartype 6 | from typing import Tuple 7 | 8 | from einops import rearrange, repeat 9 | from einops.layers.torch import Rearrange 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | 12 | 13 | def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor: 14 | # Refer to `Unroll` to see how this performs a maxpool-Nd 15 | # B, N, C 16 | return x.view(x.shape[0], stride, -1, x.shape[-1]).max(dim=1).values 17 | 18 | 19 | def exists(val): 20 | 
return val is not None 21 | 22 | 23 | def default(val, d): 24 | return val if exists(val) else d 25 | 26 | 27 | def leaky_relu(p=0.1): 28 | return nn.LeakyReLU(p) 29 | 30 | 31 | def l2norm(t): 32 | return F.normalize(t, dim=-1) 33 | 34 | 35 | def precompute_freqs_cis_2d( 36 | dim: int, end: int, theta: float = 10000.0, scale=1.0, use_cls=False 37 | ): 38 | H = int(end**0.5) 39 | # assert H * H == end 40 | flat_patch_pos = torch.arange(0 if not use_cls else -1, end) # N = end 41 | x_pos = flat_patch_pos % H # N 42 | y_pos = flat_patch_pos // H # N 43 | freqs = 1.0 / ( 44 | theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim) 45 | ) # Hc/4 46 | x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 47 | y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 48 | x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) 49 | y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) 50 | freqs_cis = torch.cat( 51 | [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1 52 | ) # N,Hc/4,2 53 | freqs_cis = freqs_cis.reshape(end if not use_cls else end + 1, -1) 54 | # we need to think how to implement this for multi heads. 55 | # freqs_cis = torch.cat([x_cis, y_cis], dim=-1) # N, Hc/2 56 | return freqs_cis 57 | 58 | 59 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 60 | # x: B N H Hc/2 61 | # freqs_cis: N, H*Hc/2 or N Hc/2 62 | ndim = x.ndim 63 | assert 0 <= 1 < ndim 64 | 65 | if freqs_cis.shape[-1] == x.shape[-1]: 66 | shape = [ 67 | 1 if i == 2 or i == 0 else d for i, d in enumerate(x.shape) 68 | ] # 1, N, 1, Hc/2 69 | else: 70 | shape = [d if i != 0 else 1 for i, d in enumerate(x.shape)] # 1, N, H, Hc/2 71 | # B, N, Hc/2 72 | return freqs_cis.view(*shape) 73 | 74 | 75 | def apply_rotary_emb( 76 | xq: torch.Tensor, 77 | xk: torch.Tensor, 78 | freqs_cis: torch.Tensor, 79 | ) -> Tuple[torch.Tensor, torch.Tensor]: 80 | # xq : B N H Hc 81 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 82 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 83 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 84 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc 85 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 86 | return xq_out.type_as(xq), xk_out.type_as(xk) 87 | 88 | 89 | class LayerNorm(nn.Module): 90 | def __init__(self, dim): 91 | super().__init__() 92 | self.gamma = nn.Parameter(torch.ones(dim)) 93 | self.register_buffer("beta", torch.zeros(dim)) 94 | 95 | def forward(self, x): 96 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 97 | 98 | 99 | class Pooling(nn.Module): 100 | def __init__(self, pool_type, dim): 101 | super().__init__() 102 | if pool_type == "a": 103 | self.pool = nn.AvgPool2d(kernel_size=2) 104 | 105 | elif pool_type == "m": 106 | self.pool = nn.MaxPool2d(kernel_size=2) 107 | 108 | elif pool_type == "l": 109 | self.pool = nn.Linear(4 * dim, dim) 110 | 111 | else: 112 | raise NotImplementedError 113 | 114 | self.pool_type = pool_type 115 | 116 | def forward(self, x): 117 | # B N C 118 | B, N, C = x.shape 119 | if self.pool_type in ["a", "m"]: 120 | H, W = int(math.sqrt(N)), int(math.sqrt(N)) 121 | x = x.view(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 122 | x = self.pool(x) 123 | x = x.view(B, C, -1).transpose(1, 2).contiguous() 124 | 125 | else: 126 | x = x.view(B, N // 4, -1) 127 | x = self.pool(x) 128 | 129 | return x 130 | 131 | 132 | class Up(nn.Module): 133 | def __init__(self, up_type, dim): 134 | super().__init__() 135 | if up_type == "n": 136 | self.up 
= nn.Upsample(scale_factor=2, mode="nearest") 137 | 138 | elif up_type == "r": 139 | self.up = nn.Sequential( 140 | nn.Upsample(scale_factor=2, mode="nearest"), 141 | Rearrange("b c h w -> b (h w) c"), 142 | nn.Linear(dim, dim), 143 | ) 144 | 145 | else: 146 | raise NotImplementedError 147 | 148 | self.up_type = up_type 149 | 150 | def forward(self, x): 151 | # B N C 152 | B, N, C = x.shape 153 | if self.up_type == "n": 154 | H, W = int(math.sqrt(N)), int(math.sqrt(N)) 155 | x = x.view(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 156 | x = self.up(x) 157 | x = x.view(B, C, -1).transpose(1, 2).contiguous() 158 | 159 | else: 160 | # x = self.up(x) # B, N, 4c 161 | # x = x.view(B, N * 4, -1) 162 | H, W = int(math.sqrt(N)), int(math.sqrt(N)) 163 | x = x.view(B, H, W, -1).permute(0, 3, 1, 2).contiguous() # B, C, H, W 164 | x = self.up(x) # B, (2H 2W), C 165 | 166 | return x 167 | 168 | 169 | class GEGLU(nn.Module): 170 | def forward(self, x): 171 | x, gate = x.chunk(2, dim=-1) 172 | return F.gelu(gate) * x 173 | 174 | 175 | def FeedForward(dim, mult=4, dropout=0.0): 176 | """Check this paper to understand the computation: https://arxiv.org/pdf/2002.05202.pdf""" 177 | inner_dim = int(mult * (2 / 3) * dim) 178 | return nn.Sequential( 179 | nn.LayerNorm(dim), 180 | nn.Linear(dim, inner_dim * 2, bias=False), 181 | GEGLU(), 182 | nn.Dropout(dropout), 183 | nn.Linear(inner_dim, dim, bias=False), 184 | ) 185 | 186 | 187 | def window_partition(x, window_size): 188 | """ 189 | Args: 190 | x: (B, H, W, C) 191 | window_size (int): window size 192 | 193 | Returns: 194 | windows: (num_windows*B, window_size, window_size, C) 195 | """ 196 | B, H, W, C = x.shape 197 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 198 | windows = ( 199 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 200 | ) 201 | return windows 202 | 203 | 204 | def window_reverse(windows, window_size, H, W): 205 | """ 206 | Args: 207 | windows: (num_windows*B, window_size, window_size, C) 208 | window_size (int): Window size 209 | H (int): Height of image 210 | W (int): Width of image 211 | 212 | Returns: 213 | x: (B, H, W, C) 214 | """ 215 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 216 | x = windows.view( 217 | B, H // window_size, W // window_size, window_size, window_size, -1 218 | ) 219 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 220 | return x 221 | 222 | 223 | class WindowAttention(nn.Module): 224 | r"""Window based multi-head self attention (W-MSA) module with relative position bias. 225 | It supports both of shifted and non-shifted window. 226 | 227 | Args: 228 | dim (int): Number of input channels. 229 | window_size (tuple[int]): The height and width of the window. 230 | num_heads (int): Number of attention heads. 231 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 232 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 233 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 234 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 235 | """ 236 | 237 | def __init__( 238 | self, 239 | dim, 240 | window_size, 241 | num_heads, 242 | qkv_bias=False, 243 | qk_scale=None, 244 | attn_drop=0.0, 245 | proj_drop=0.0, 246 | ): 247 | 248 | super().__init__() 249 | self.dim = dim 250 | if isinstance(window_size, int): 251 | window_size = (window_size, window_size) 252 | 253 | self.norm = LayerNorm(dim) 254 | self.window_size = window_size # Wh, Ww 255 | self.num_heads = num_heads 256 | head_dim = dim // num_heads 257 | self.scale = qk_scale or head_dim**-0.5 258 | 259 | # define a parameter table of relative position bias 260 | self.relative_position_bias_table = nn.Parameter( 261 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) 262 | ) # 2*Wh-1 * 2*Ww-1, nH 263 | 264 | # get pair-wise relative position index for each token inside the window 265 | coords_h = torch.arange(self.window_size[0]) 266 | coords_w = torch.arange(self.window_size[1]) 267 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 268 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 269 | relative_coords = ( 270 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 271 | ) # 2, Wh*Ww, Wh*Ww 272 | relative_coords = relative_coords.permute( 273 | 1, 2, 0 274 | ).contiguous() # Wh*Ww, Wh*Ww, 2 275 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 276 | relative_coords[:, :, 1] += self.window_size[1] - 1 277 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 278 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 279 | self.register_buffer("relative_position_index", relative_position_index) 280 | 281 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 282 | self.attn_drop = nn.Dropout(attn_drop) 283 | self.proj = nn.Linear(dim, dim) 284 | self.proj_drop = nn.Dropout(proj_drop) 285 | 286 | trunc_normal_(self.relative_position_bias_table, std=0.02) 287 | self.softmax = nn.Softmax(dim=-1) 288 | 289 | def forward(self, x): 290 | """ 291 | Args: 292 | x: input features with shape of (num_windows*B, N, C) 293 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 294 | """ 295 | B_, N, C = x.shape 296 | H, W = int(math.sqrt(N)), int(math.sqrt(N)) 297 | x = self.norm(x) 298 | 299 | x = x.view(B_, H, W, -1) 300 | # partition windows 301 | x_windows = window_partition( 302 | x, self.window_size[0] 303 | ) # nW*B, window_size, window_size, C 304 | x_windows = x_windows.view( 305 | -1, self.window_size[0] * self.window_size[1], C 306 | ) # nW*B, window_size*window_size, C 307 | 308 | BW, NW = x_windows.shape[:2] 309 | 310 | qkv = ( 311 | self.qkv(x_windows) 312 | .reshape(BW, NW, 3, self.num_heads, C // self.num_heads) 313 | .permute(2, 0, 3, 1, 4) 314 | ) 315 | q, k, v = ( 316 | qkv[0], 317 | qkv[1], 318 | qkv[2], 319 | ) # make torchscript happy (cannot use tensor as tuple) 320 | 321 | q = q * self.scale 322 | attn = q @ k.transpose(-2, -1) 323 | 324 | relative_position_bias = self.relative_position_bias_table[ 325 | self.relative_position_index.view(-1) 326 | ].view( 327 | self.window_size[0] * self.window_size[1], 328 | self.window_size[0] * self.window_size[1], 329 | -1, 330 | ) # Wh*Ww,Wh*Ww,nH 331 | relative_position_bias = relative_position_bias.permute( 332 | 2, 0, 1 333 | ).contiguous() # nH, Wh*Ww, Wh*Ww 334 | 335 | attn = attn + relative_position_bias.unsqueeze(0) 336 | attn = self.softmax(attn) 337 | 338 | attn = self.attn_drop(attn) 339 | 340 | x_windows = (attn @ v).transpose(1, 2).reshape(BW, NW, C) 341 | x_windows = 
self.proj(x_windows) 342 | x_windows = self.proj_drop(x_windows) 343 | 344 | x = window_reverse(x_windows, self.window_size[0], H, W) # B H' W' C 345 | x = x.view(B_, H * W, C) 346 | 347 | return x 348 | 349 | 350 | class PEG(nn.Module): 351 | def __init__(self, dim, causal=False): 352 | super().__init__() 353 | self.causal = causal 354 | self.dsconv = nn.Conv3d(dim, dim, 3, groups=dim) 355 | 356 | @beartype 357 | def forward(self, x, shape: Tuple[int, int, int, int] = None): 358 | needs_shape = x.ndim == 3 359 | assert not (needs_shape and not exists(shape)) 360 | 361 | orig_shape = x.shape 362 | if needs_shape: 363 | x = x.reshape(*shape, -1) 364 | 365 | x = rearrange(x, "b ... d -> b d ...") 366 | 367 | frame_padding = (2, 0) if self.causal else (1, 1) 368 | 369 | x = F.pad(x, (1, 1, 1, 1, *frame_padding), value=0.0) 370 | x = self.dsconv(x) 371 | 372 | x = rearrange(x, "b d ... -> b ... d") 373 | 374 | if needs_shape: 375 | x = rearrange(x, "b ... d -> b (...) d") 376 | 377 | return x.reshape(orig_shape) 378 | 379 | 380 | # attention 381 | 382 | 383 | class Attention(nn.Module): 384 | def __init__( 385 | self, 386 | dim, 387 | dim_context=None, 388 | dim_head=64, 389 | heads=8, 390 | causal=False, 391 | num_null_kv=0, 392 | norm_context=True, 393 | dropout=0.0, 394 | scale=8, 395 | spatial_pos="rel", 396 | ): 397 | super().__init__() 398 | self.heads = heads 399 | self.causal = causal 400 | self.scale = scale 401 | inner_dim = dim_head * heads 402 | dim_context = default(dim_context, dim) 403 | 404 | if spatial_pos == "rel": 405 | self.spatial_rel_pos_bias = ContinuousPositionBias( 406 | dim=dim, heads=heads 407 | ) # HACK this: whether shared pos encoding is better or on the contrary 408 | 409 | self.spatial_pos = spatial_pos 410 | self.freqs_cis = None 411 | 412 | if causal: 413 | self.rel_pos_bias = AlibiPositionalBias(heads=heads) 414 | 415 | self.p_dropout = dropout 416 | self.attn_dropout = nn.Dropout(dropout) 417 | 418 | self.norm = LayerNorm(dim) 419 | self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity() 420 | 421 | self.num_null_kv = num_null_kv 422 | if self.num_null_kv > 0: 423 | self.null_kv = nn.Parameter(torch.randn(heads, 2 * num_null_kv, dim_head)) 424 | else: 425 | self.null_kv = None 426 | 427 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 428 | self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias=False) 429 | self.dim = inner_dim 430 | 431 | self.q_scale = nn.Parameter(torch.ones(dim_head)) 432 | self.k_scale = nn.Parameter(torch.ones(dim_head)) 433 | 434 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 435 | 436 | def forward( 437 | self, 438 | x, 439 | mask=None, 440 | context=None, 441 | is_spatial=True, 442 | q_stride=1, 443 | ): 444 | batch, device, dtype = x.shape[0], x.device, x.dtype 445 | 446 | if exists(context): 447 | context = self.context_norm(context) 448 | 449 | kv_input = default(context, x) 450 | 451 | x = self.norm(x) 452 | N = x.shape[1] 453 | 454 | q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1) 455 | q, k, v = map( 456 | lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (q, k, v) 457 | ) 458 | 459 | if self.spatial_pos == "rope" and is_spatial: 460 | if self.freqs_cis is None or self.freqs_cis.shape[0] != N: 461 | self.freqs_cis = precompute_freqs_cis_2d(self.dim // self.heads, N).to( 462 | x.device 463 | ) 464 | 465 | q, k = apply_rotary_emb(q, k, freqs_cis=self.freqs_cis) 466 | 467 | q, k, v = map( 468 | lambda t: rearrange(t, "b n h d -> b h n d", h=self.heads), (q, k, v) 469 | ) 
470 | 471 | B, H, _, D = q.shape 472 | if q_stride > 1: 473 | # Refer to Unroll to see how this performs a maxpool-Nd 474 | q = q.view(B, H, q_stride, -1, D).max(dim=2).values 475 | 476 | if self.num_null_kv > 0: 477 | nk, nv = repeat( 478 | self.null_kv, "h (n r) d -> b h n r d", b=batch, r=2 479 | ).unbind(dim=-2) 480 | 481 | k = torch.cat((nk, k), dim=-2) 482 | v = torch.cat((nv, v), dim=-2) 483 | 484 | q, k = map(l2norm, (q, k)) 485 | q = q * self.q_scale 486 | k = k * self.k_scale 487 | 488 | if hasattr(F, "scaled_dot_product_attention") and torch.__version__ >= "2.1.0": 489 | # Note: the original paper did *not* use SDPA, it's a free boost! 490 | if exists(mask): 491 | mask = F.pad(mask, (self.num_null_kv, 0), value=True) 492 | mask = rearrange(mask, "b j -> b 1 1 j") 493 | 494 | if self.spatial_pos == "rel" and is_spatial: 495 | h, w = int(math.sqrt(N)), int(math.sqrt(N)) 496 | attn_bias = self.spatial_rel_pos_bias(h, w, device=x.device) 497 | attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value=0.0) 498 | 499 | # query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None 500 | out = F.scaled_dot_product_attention( 501 | q, 502 | k, 503 | v, 504 | attn_mask=mask, 505 | dropout_p=self.p_dropout, 506 | is_causal=self.causal, 507 | scale=self.scale, 508 | ) 509 | 510 | else: 511 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale 512 | i, j = sim.shape[-2:] 513 | if self.spatial_pos == "rel" and is_spatial: 514 | h, w = int(math.sqrt(N)), int(math.sqrt(N)) 515 | attn_bias = self.spatial_rel_pos_bias(h, w, device=x.device) 516 | attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value=0.0) 517 | 518 | if sim.shape[2] != attn_bias.shape[1]: 519 | # handle q_pooling here 520 | q_len = sim.shape[2] 521 | kv_len = sim.shape[3] 522 | q_stride = kv_len // q_len 523 | attn_bias = attn_bias[:, ::q_stride] 524 | 525 | sim = sim + attn_bias 526 | 527 | if exists(mask): 528 | mask = F.pad(mask, (self.num_null_kv, 0), value=True) 529 | mask = rearrange(mask, "b j -> b 1 1 j") 530 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 531 | 532 | if self.causal: 533 | sim = sim + self.rel_pos_bias(sim) 534 | 535 | causal_mask = torch.ones((i, j), device=device, dtype=torch.bool).triu( 536 | j - i + 1 537 | ) 538 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 539 | 540 | attn = sim.softmax(dim=-1) 541 | attn = self.attn_dropout(attn) 542 | 543 | out = einsum("b h i j, b h j d -> b h i d", attn, v) 544 | 545 | out = rearrange(out, "b h n d -> b n (h d)") 546 | return self.to_out(out) 547 | 548 | 549 | # alibi positional bias for extrapolation 550 | class AlibiPositionalBias(nn.Module): 551 | def __init__(self, heads): 552 | super().__init__() 553 | self.heads = heads 554 | slopes = torch.Tensor(self._get_slopes(heads)) 555 | slopes = rearrange(slopes, "h -> h 1 1") 556 | self.register_buffer("slopes", slopes, persistent=False) 557 | self.register_buffer("bias", None, persistent=False) 558 | 559 | def get_bias(self, i, j, device): 560 | i_arange = torch.arange(j - i, j, device=device) 561 | j_arange = torch.arange(j, device=device) 562 | bias = -torch.abs( 563 | rearrange(j_arange, "j -> 1 1 j") - rearrange(i_arange, "i -> 1 i 1") 564 | ) 565 | return bias 566 | 567 | @staticmethod 568 | def _get_slopes(heads): 569 | def get_slopes_power_of_2(n): 570 | start = 2 ** (-(2 ** -(math.log2(n) - 3))) 571 | ratio = start 572 | return [start * ratio**i for i in range(n)] 573 | 574 | if math.log2(heads).is_integer(): 575 | return 
get_slopes_power_of_2(heads) 576 | 577 | closest_power_of_2 = 2 ** math.floor(math.log2(heads)) 578 | return ( 579 | get_slopes_power_of_2(closest_power_of_2) 580 | + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ 581 | : heads - closest_power_of_2 582 | ] 583 | ) 584 | 585 | def forward(self, sim): 586 | h, i, j, device = *sim.shape[-3:], sim.device 587 | 588 | if exists(self.bias) and self.bias.shape[-1] >= j: 589 | return self.bias[..., :i, :j] 590 | 591 | bias = self.get_bias(i, j, device) 592 | bias = bias * self.slopes 593 | 594 | num_heads_unalibied = h - bias.shape[0] 595 | bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied)) 596 | self.register_buffer("bias", bias, persistent=False) 597 | 598 | return self.bias 599 | 600 | 601 | class ContinuousPositionBias(nn.Module): 602 | """from https://arxiv.org/abs/2111.09883""" 603 | 604 | def __init__( 605 | self, 606 | *, 607 | dim, 608 | heads, 609 | num_dims=2, # 2 for images, 3 for video 610 | layers=2, 611 | log_dist=True, 612 | cache_rel_pos=False 613 | ): 614 | super().__init__() 615 | self.num_dims = num_dims 616 | self.log_dist = log_dist 617 | 618 | self.net = nn.ModuleList([]) 619 | self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), leaky_relu())) 620 | 621 | for _ in range(layers - 1): 622 | self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu())) 623 | 624 | self.net.append(nn.Linear(dim, heads)) 625 | 626 | self.cache_rel_pos = cache_rel_pos 627 | self.register_buffer("rel_pos", None, persistent=False) 628 | 629 | def forward(self, *dimensions, device=torch.device("cpu")): 630 | 631 | if not exists(self.rel_pos) or not self.cache_rel_pos: 632 | positions = [torch.arange(d, device=device) for d in dimensions] 633 | grid = torch.stack(torch.meshgrid(*positions, indexing="ij")) 634 | grid = rearrange(grid, "c ... -> (...) 
c") 635 | rel_pos = rearrange(grid, "i c -> i 1 c") - rearrange(grid, "j c -> 1 j c") 636 | 637 | if self.log_dist: 638 | rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1) 639 | 640 | self.register_buffer("rel_pos", rel_pos, persistent=False) 641 | 642 | rel_pos = self.rel_pos.float() 643 | 644 | for layer in self.net: 645 | rel_pos = layer(rel_pos) 646 | 647 | return rearrange(rel_pos, "i j h -> h i j") 648 | 649 | 650 | # transformer 651 | 652 | 653 | class Transformer(nn.Module): 654 | def __init__( 655 | self, 656 | dim, 657 | *, 658 | depth, 659 | block, 660 | dim_context=None, 661 | causal=False, 662 | dim_head=64, 663 | heads=8, 664 | ff_mult=4, 665 | peg=False, 666 | peg_causal=False, 667 | attn_num_null_kv=2, 668 | has_cross_attn=False, 669 | attn_dropout=0.0, 670 | ff_dropout=0.0, 671 | window_size=4, 672 | spatial_pos="rel" 673 | ): 674 | super().__init__() 675 | assert len(block) == depth 676 | self.layers = nn.ModuleList([]) 677 | for i in range(depth): 678 | if block[i] == "t": 679 | self.layers.append( 680 | nn.ModuleList( 681 | [ 682 | PEG(dim=dim, causal=peg_causal) if peg else None, 683 | Attention( 684 | dim=dim, 685 | dim_head=dim_head, 686 | heads=heads, 687 | causal=causal, 688 | dropout=attn_dropout, 689 | spatial_pos=spatial_pos, 690 | ), 691 | ( 692 | Attention( 693 | dim=dim, 694 | dim_head=dim_head, 695 | dim_context=dim_context, 696 | heads=heads, 697 | causal=False, 698 | num_null_kv=attn_num_null_kv, 699 | dropout=attn_dropout, 700 | ) 701 | if has_cross_attn 702 | else None 703 | ), 704 | FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout), 705 | ] 706 | ) 707 | ) 708 | 709 | elif block[i] == "w": 710 | self.layers.append( 711 | nn.ModuleList( 712 | [ 713 | None, 714 | WindowAttention( 715 | dim=dim, 716 | window_size=window_size, 717 | num_heads=heads, 718 | attn_drop=attn_dropout, 719 | ), 720 | None, 721 | FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout), 722 | ] 723 | ) 724 | ) 725 | 726 | # various pooling methods: B, N, C 727 | elif block[i] in ["a", "m", "l"]: 728 | self.layers.append( 729 | nn.ModuleList( 730 | [ 731 | None, 732 | Pooling(block[i], dim), 733 | None, 734 | FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout), 735 | ] 736 | ) 737 | ) 738 | 739 | elif block[i] in ["n", "r"]: 740 | self.layers.append( 741 | nn.ModuleList( 742 | [ 743 | None, 744 | Up(block[i], dim), 745 | None, 746 | FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout), 747 | ] 748 | ) 749 | ) 750 | 751 | else: 752 | raise NotImplementedError 753 | 754 | self.block = block 755 | self.norm_out = LayerNorm(dim) 756 | 757 | @beartype 758 | def forward( 759 | self, 760 | x, 761 | video_shape: Tuple[int, int, int, int] = None, 762 | context=None, 763 | self_attn_mask=None, 764 | cross_attn_context_mask=None, 765 | q_strides=None, 766 | is_spatial=True, 767 | ): 768 | 769 | if q_strides is None: 770 | q_strides = "1" * len(self.layers) 771 | 772 | for blk, q_stride, (peg, self_attn, cross_attn, ff) in zip( 773 | self.block, q_strides, self.layers 774 | ): 775 | if exists(peg): 776 | x = peg(x, shape=video_shape) + x 777 | 778 | if isinstance(self_attn, Attention): 779 | x = self_attn( 780 | x, 781 | mask=self_attn_mask, 782 | q_stride=int(q_stride), 783 | is_spatial=is_spatial, 784 | ) + do_pool(x, int(q_stride)) 785 | # x = checkpoint.checkpoint(self_attn, x, self_attn_mask, None, attn_bias, int(q_stride)) 786 | 787 | elif isinstance(self_attn, WindowAttention): 788 | x = self_attn(x) + x 789 | else: 790 | x = self_attn(x) 791 | 792 | if 
exists(cross_attn) and exists(context): 793 | x = cross_attn(x, context=context, mask=cross_attn_context_mask) + x 794 | 795 | x = ff(x) + x 796 | 797 | # deal with downsampling: 798 | if blk in ["a", "m", "l"]: 799 | video_shape = ( 800 | video_shape[0], 801 | video_shape[1], 802 | video_shape[2] // 2, 803 | video_shape[3] // 2, 804 | ) # video_shape: B, T, H, W 805 | 806 | elif blk in ["n", "r"]: 807 | video_shape = ( 808 | video_shape[0], 809 | video_shape[1], 810 | int(video_shape[2] * 2), 811 | int(video_shape[3] * 2), 812 | ) 813 | 814 | if q_stride != "1": 815 | down_ratio = int(math.sqrt(int(q_stride))) 816 | video_shape = ( 817 | video_shape[0], 818 | video_shape[1], 819 | video_shape[2] // down_ratio, 820 | video_shape[3] // down_ratio, 821 | ) 822 | 823 | return self.norm_out(x) 824 | -------------------------------------------------------------------------------- /imagetokenizer/model/modules/titok_transformer.py: -------------------------------------------------------------------------------- 1 | """Building blocks for TiTok. 2 | 3 | Copyright (2024) Bytedance Ltd. and/or its affiliates 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | 17 | Reference: 18 | https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py 19 | """ 20 | 21 | import torch 22 | import torch.nn as nn 23 | from collections import OrderedDict 24 | 25 | 26 | class ResidualAttentionBlock(nn.Module): 27 | def __init__( 28 | self, d_model, n_head, mlp_ratio=4.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm 29 | ): 30 | super().__init__() 31 | 32 | self.ln_1 = norm_layer(d_model) 33 | self.attn = nn.MultiheadAttention(d_model, n_head) 34 | self.mlp_ratio = mlp_ratio 35 | # optionally we can disable the FFN 36 | if mlp_ratio > 0: 37 | self.ln_2 = norm_layer(d_model) 38 | mlp_width = int(d_model * mlp_ratio) 39 | self.mlp = nn.Sequential( 40 | OrderedDict( 41 | [ 42 | ("c_fc", nn.Linear(d_model, mlp_width)), 43 | ("gelu", act_layer()), 44 | ("c_proj", nn.Linear(mlp_width, d_model)), 45 | ] 46 | ) 47 | ) 48 | 49 | def attention(self, x: torch.Tensor): 50 | return self.attn(x, x, x, need_weights=False)[0] 51 | 52 | def forward( 53 | self, 54 | x: torch.Tensor, 55 | ): 56 | attn_output = self.attention(x=self.ln_1(x)) 57 | x = x + attn_output 58 | if self.mlp_ratio > 0: 59 | x = x + self.mlp(self.ln_2(x)) 60 | return x 61 | 62 | 63 | def _expand_token(token, batch_size: int): 64 | return token.unsqueeze(0).expand(batch_size, -1, -1) 65 | 66 | 67 | class TiTokEncoder(nn.Module): 68 | def __init__(self, config): 69 | super().__init__() 70 | self.config = config 71 | self.image_size = config.dataset.preprocessing.crop_size 72 | self.patch_size = config.model.vq_model.vit_enc_patch_size 73 | self.grid_size = self.image_size // self.patch_size 74 | self.model_size = config.model.vq_model.vit_enc_model_size 75 | self.num_latent_tokens = config.model.vq_model.num_latent_tokens 76 | self.token_size = config.model.vq_model.token_size 77 | 78 | self.width = { 79 | "small": 512, 80 | 
"base": 768, 81 | "large": 1024, 82 | }[self.model_size] 83 | self.num_layers = { 84 | "small": 8, 85 | "base": 12, 86 | "large": 24, 87 | }[self.model_size] 88 | self.num_heads = { 89 | "small": 8, 90 | "base": 12, 91 | "large": 16, 92 | }[self.model_size] 93 | 94 | self.patch_embed = nn.Conv2d( 95 | in_channels=3, 96 | out_channels=self.width, 97 | kernel_size=self.patch_size, 98 | stride=self.patch_size, 99 | bias=True, 100 | ) 101 | 102 | scale = self.width**-0.5 103 | self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) 104 | self.positional_embedding = nn.Parameter( 105 | scale * torch.randn(self.grid_size**2 + 1, self.width) 106 | ) 107 | self.latent_token_positional_embedding = nn.Parameter( 108 | scale * torch.randn(self.num_latent_tokens, self.width) 109 | ) 110 | self.ln_pre = nn.LayerNorm(self.width) 111 | self.transformer = nn.ModuleList() 112 | for i in range(self.num_layers): 113 | self.transformer.append( 114 | ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0) 115 | ) 116 | self.ln_post = nn.LayerNorm(self.width) 117 | self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True) 118 | 119 | def forward(self, pixel_values, latent_tokens): 120 | batch_size = pixel_values.shape[0] 121 | x = pixel_values 122 | x = self.patch_embed(x) 123 | x = x.reshape(x.shape[0], x.shape[1], -1) 124 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 125 | # class embeddings and positional embeddings 126 | x = torch.cat( 127 | [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1 128 | ) 129 | x = x + self.positional_embedding.to( 130 | x.dtype 131 | ) # shape = [*, grid ** 2 + 1, width] 132 | 133 | latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype) 134 | latent_tokens = latent_tokens + self.latent_token_positional_embedding.to( 135 | x.dtype 136 | ) 137 | x = torch.cat([x, latent_tokens], dim=1) 138 | 139 | x = self.ln_pre(x) 140 | x = x.permute(1, 0, 2) # NLD -> LND 141 | for i in range(self.num_layers): 142 | x = self.transformer[i](x) 143 | x = x.permute(1, 0, 2) # LND -> NLD 144 | 145 | latent_tokens = x[:, 1 + self.grid_size**2 :] 146 | latent_tokens = self.ln_post(latent_tokens) 147 | # fake 2D shape 148 | latent_tokens = latent_tokens.reshape( 149 | batch_size, self.width, self.num_latent_tokens, 1 150 | ) 151 | latent_tokens = self.conv_out(latent_tokens) 152 | latent_tokens = latent_tokens.reshape( 153 | batch_size, self.token_size, 1, self.num_latent_tokens 154 | ) 155 | return latent_tokens 156 | 157 | 158 | class TiTokDecoder(nn.Module): 159 | def __init__(self, config): 160 | super().__init__() 161 | self.config = config 162 | self.image_size = config.dataset.preprocessing.crop_size 163 | self.patch_size = config.model.vq_model.vit_dec_patch_size 164 | self.grid_size = self.image_size // self.patch_size 165 | self.model_size = config.model.vq_model.vit_dec_model_size 166 | self.num_latent_tokens = config.model.vq_model.num_latent_tokens 167 | self.token_size = config.model.vq_model.token_size 168 | self.width = { 169 | "small": 512, 170 | "base": 768, 171 | "large": 1024, 172 | }[self.model_size] 173 | self.num_layers = { 174 | "small": 8, 175 | "base": 12, 176 | "large": 24, 177 | }[self.model_size] 178 | self.num_heads = { 179 | "small": 8, 180 | "base": 12, 181 | "large": 16, 182 | }[self.model_size] 183 | 184 | self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True) 185 | scale = self.width**-0.5 186 | self.class_embedding = nn.Parameter(scale * torch.randn(1, 
self.width)) 187 | self.positional_embedding = nn.Parameter( 188 | scale * torch.randn(self.grid_size**2 + 1, self.width) 189 | ) 190 | # add mask token and query pos embed 191 | self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width)) 192 | self.latent_token_positional_embedding = nn.Parameter( 193 | scale * torch.randn(self.num_latent_tokens, self.width) 194 | ) 195 | self.ln_pre = nn.LayerNorm(self.width) 196 | self.transformer = nn.ModuleList() 197 | for i in range(self.num_layers): 198 | self.transformer.append( 199 | ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0) 200 | ) 201 | self.ln_post = nn.LayerNorm(self.width) 202 | 203 | self.ffn = nn.Sequential( 204 | nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True), 205 | nn.Tanh(), 206 | nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True), 207 | ) 208 | self.conv_out = nn.Identity() 209 | 210 | def forward(self, z_quantized): 211 | N, C, H, W = z_quantized.shape 212 | assert ( 213 | H == 1 and W == self.num_latent_tokens 214 | ), f"{H}, {W}, {self.num_latent_tokens}" 215 | x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1) # NLD 216 | x = self.decoder_embed(x) 217 | 218 | batchsize, seq_len, _ = x.shape 219 | 220 | mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to( 221 | x.dtype 222 | ) 223 | mask_tokens = torch.cat( 224 | [ 225 | _expand_token(self.class_embedding, mask_tokens.shape[0]).to( 226 | mask_tokens.dtype 227 | ), 228 | mask_tokens, 229 | ], 230 | dim=1, 231 | ) 232 | mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype) 233 | x = x + self.latent_token_positional_embedding[:seq_len] 234 | x = torch.cat([mask_tokens, x], dim=1) 235 | 236 | x = self.ln_pre(x) 237 | x = x.permute(1, 0, 2) # NLD -> LND 238 | for i in range(self.num_layers): 239 | x = self.transformer[i](x) 240 | x = x.permute(1, 0, 2) # LND -> NLD 241 | x = x[:, 1 : 1 + self.grid_size**2] # remove cls embed 242 | x = self.ln_post(x) 243 | # N L D -> N D H W 244 | x = x.permute(0, 2, 1).reshape( 245 | batchsize, self.width, self.grid_size, self.grid_size 246 | ) 247 | x = self.ffn(x.contiguous()) 248 | x = self.conv_out(x) 249 | return x 250 | -------------------------------------------------------------------------------- /imagetokenizer/model/modules/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class DiagonalGaussianDistribution(object): 6 | def __init__(self, parameters, deterministic=False): 7 | self.parameters = parameters 8 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 9 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 10 | self.deterministic = deterministic 11 | self.std = torch.exp(0.5 * self.logvar) 12 | self.var = torch.exp(self.logvar) 13 | if self.deterministic: 14 | self.var = self.std = torch.zeros_like(self.mean).to( 15 | device=self.parameters.device 16 | ) 17 | 18 | def sample(self): 19 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 20 | device=self.parameters.device 21 | ) 22 | return x 23 | 24 | def kl(self, other=None): 25 | if self.deterministic: 26 | return torch.Tensor([0.0]) 27 | else: 28 | if other is None: 29 | return 0.5 * torch.sum( 30 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 31 | dim=[1, 2, 3], 32 | ) 33 | else: 34 | return 0.5 * torch.sum( 35 | torch.pow(self.mean - other.mean, 2) / other.var 36 | + self.var / other.var 37 | - 1.0 38 | - self.logvar 39 | + other.logvar, 40 | dim=[1, 2, 3], 41 | ) 
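        # Usage sketch (illustrative only; tensor sizes are hypothetical): an encoder output
        # with 2*C channels is split into mean/logvar, and kl() returns the usual closed-form
        # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2) per sample
        # when `other` is None:
        #
        #     params = torch.randn(2, 8, 16, 16)              # 2*C = 8 latent channels
        #     posterior = DiagonalGaussianDistribution(params)
        #     z = posterior.sample()                          # (2, 4, 16, 16)
        #     kl = posterior.kl()                             # (2,)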
42 | 43 | def nll(self, sample, dims=[1, 2, 3]): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | logtwopi = np.log(2.0 * np.pi) 47 | return 0.5 * torch.sum( 48 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 49 | dim=dims, 50 | ) 51 | 52 | def mode(self): 53 | return self.mean 54 | 55 | 56 | def normal_kl(mean1, logvar1, mean2, logvar2): 57 | """ 58 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 59 | Compute the KL divergence between two gaussians. 60 | Shapes are automatically broadcasted, so batches can be compared to 61 | scalars, among other use cases. 62 | """ 63 | tensor = None 64 | for obj in (mean1, logvar1, mean2, logvar2): 65 | if isinstance(obj, torch.Tensor): 66 | tensor = obj 67 | break 68 | assert tensor is not None, "at least one argument must be a Tensor" 69 | 70 | # Force variances to be Tensors. Broadcasting helps convert scalars to 71 | # Tensors, but it does not work for torch.exp(). 72 | logvar1, logvar2 = [ 73 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 74 | for x in (logvar1, logvar2) 75 | ] 76 | 77 | return 0.5 * ( 78 | -1.0 79 | + logvar2 80 | - logvar1 81 | + torch.exp(logvar1 - logvar2) 82 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 83 | ) 84 | -------------------------------------------------------------------------------- /imagetokenizer/model/titok.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from typing import Mapping, Text, Tuple 5 | import os 6 | 7 | import torch 8 | from einops import rearrange 9 | from torch.cuda.amp import autocast 10 | 11 | from .modules.titok_transformer import TiTokEncoder, TiTokDecoder 12 | from .modules.maskgit_vqgan import Decoder as Pixel_Decoder 13 | from .modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer 14 | from omegaconf import OmegaConf 15 | from easydict import EasyDict as edict 16 | 17 | 18 | class TiTok(nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | config = { 22 | "experiment": { 23 | "tokenizer_checkpoint": "tokenizer_titok_l32.bin", 24 | "generator_checkpoint": "generator_titok_l32.bin", 25 | }, 26 | "model": { 27 | "vq_model": { 28 | "codebook_size": 4096, 29 | "token_size": 12, 30 | "use_l2_norm": True, 31 | "commitment_cost": 0.25, 32 | "vit_enc_model_size": "large", 33 | "vit_dec_model_size": "large", 34 | "vit_enc_patch_size": 16, 35 | "vit_dec_patch_size": 16, 36 | "num_latent_tokens": 32, 37 | }, 38 | "generator": { 39 | "dropout": 0.1, 40 | "attn_drop": 0.1, 41 | "num_steps": 8, 42 | "mask_schedule_strategy": "arccos", 43 | "class_label_dropout": 0.1, 44 | "image_seq_len": 32, 45 | "condition_num_classes": 1000, 46 | }, 47 | }, 48 | "dataset": {"preprocessing": {"crop_size": 256}}, 49 | } 50 | config = edict(config) 51 | self.config = config 52 | self.encoder = TiTokEncoder(config) 53 | self.decoder = TiTokDecoder(config) 54 | 55 | self.num_latent_tokens = config.model.vq_model.num_latent_tokens 56 | scale = self.encoder.width**-0.5 57 | self.latent_tokens = nn.Parameter( 58 | scale * torch.randn(self.num_latent_tokens, self.encoder.width) 59 | ) 60 | 61 | self.apply(self._init_weights) 62 | 63 | self.quantize = VectorQuantizer( 64 | codebook_size=config.model.vq_model.codebook_size, 65 | token_size=config.model.vq_model.token_size, 66 | commitment_cost=config.model.vq_model.commitment_cost, 67 | 
use_l2_norm=config.model.vq_model.use_l2_norm, 68 | ) 69 | 70 | self.pixel_quantize = Pixel_Quantizer( 71 | num_embeddings=1024, embedding_dim=256, commitment_cost=0.25 72 | ) 73 | self.pixel_decoder = Pixel_Decoder( 74 | OmegaConf.create( 75 | { 76 | "channel_mult": [1, 1, 2, 2, 4], 77 | "num_resolutions": 5, 78 | "dropout": 0.0, 79 | "hidden_channels": 128, 80 | "num_channels": 3, 81 | "num_res_blocks": 2, 82 | "resolution": 256, 83 | "z_channels": 256, 84 | } 85 | ) 86 | ) 87 | 88 | def load_weights(self, model_path): 89 | g_p = os.path.join(model_path, 'generator_titok_l32.bin') 90 | t_p = os.path.join(model_path, 'tokenizer_titok_l32.bin') 91 | sd_g = torch.load(g_p, map_location="cpu") 92 | sd_t = torch.load(t_p, map_location="cpu") 93 | missing, unexpected = self.load_state_dict(sd_g, strict=False) 94 | missing, unexpected = self.load_state_dict(sd_t, strict=False) 95 | 96 | def _init_weights(self, module): 97 | """Initialize the weights. 98 | :param: 99 | module -> torch.nn.Module: module to initialize 100 | """ 101 | if ( 102 | isinstance(module, nn.Linear) 103 | or isinstance(module, nn.Conv1d) 104 | or isinstance(module, nn.Conv2d) 105 | ): 106 | module.weight.data = nn.init.trunc_normal_( 107 | module.weight.data, mean=0.0, std=0.02 108 | ) 109 | if module.bias is not None: 110 | module.bias.data.zero_() 111 | elif isinstance(module, nn.Embedding): 112 | module.weight.data = nn.init.trunc_normal_( 113 | module.weight.data, mean=0.0, std=0.02 114 | ) 115 | elif isinstance(module, nn.LayerNorm): 116 | module.bias.data.zero_() 117 | module.weight.data.fill_(1.0) 118 | 119 | def encode(self, x): 120 | if x.shape[-1] != self.config.dataset.preprocessing.crop_size: 121 | x = torch.nn.functional.interpolate( 122 | x, 123 | size=( 124 | self.config.dataset.preprocessing.crop_size, 125 | self.config.dataset.preprocessing.crop_size, 126 | ), 127 | mode="bilinear", 128 | align_corners=False, 129 | ) 130 | print(x.shape) 131 | z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens) 132 | z_quantized, result_dict = self.quantize(z) 133 | return z_quantized, z, result_dict['min_encoding_indices'] 134 | 135 | def decode(self, z_quantized): 136 | decoded_latent = self.decoder(z_quantized) 137 | quantized_states = torch.einsum( 138 | "nchw,cd->ndhw", 139 | decoded_latent.softmax(1), 140 | self.pixel_quantize.embedding.weight, 141 | ) 142 | decoded = self.pixel_decoder(quantized_states) 143 | return decoded 144 | 145 | def decode_tokens(self, tokens): 146 | tokens = tokens.squeeze(1) 147 | batch, seq_len = tokens.shape # B x N 148 | z_quantized = self.quantize.get_codebook_entry(tokens.reshape(-1)).reshape( 149 | batch, 1, seq_len, -1 150 | ) 151 | if self.quantize.use_l2_norm: 152 | z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1) 153 | z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous() 154 | decoded = self.decode(z_quantized) 155 | return decoded 156 | 157 | 158 | class VectorQuantizer(torch.nn.Module): 159 | def __init__( 160 | self, 161 | codebook_size: int = 1024, 162 | token_size: int = 256, 163 | commitment_cost: float = 0.25, 164 | use_l2_norm: bool = False, 165 | ): 166 | super().__init__() 167 | self.commitment_cost = commitment_cost 168 | 169 | self.embedding = torch.nn.Embedding(codebook_size, token_size) 170 | self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) 171 | self.use_l2_norm = use_l2_norm 172 | 173 | # Ensure quantization is performed using f32 174 | @autocast(enabled=False) 175 | def forward( 176 | 
self, z: torch.Tensor 177 | ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: 178 | z = z.float() 179 | z = rearrange(z, "b c h w -> b h w c").contiguous() 180 | z_flattened = rearrange(z, "b h w c -> (b h w) c") 181 | 182 | if self.use_l2_norm: 183 | z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1) 184 | embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1) 185 | else: 186 | embedding = self.embedding.weight 187 | d = ( 188 | torch.sum(z_flattened**2, dim=1, keepdim=True) 189 | + torch.sum(embedding**2, dim=1) 190 | - 2 * torch.einsum("bd,dn->bn", z_flattened, embedding.T) 191 | ) 192 | 193 | min_encoding_indices = torch.argmin(d, dim=1) # num_ele 194 | z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape) 195 | 196 | if self.use_l2_norm: 197 | z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1) 198 | z = torch.nn.functional.normalize(z, dim=-1) 199 | 200 | # compute loss for embedding 201 | commitment_loss = self.commitment_cost * torch.mean( 202 | (z_quantized.detach() - z) ** 2 203 | ) 204 | codebook_loss = torch.mean((z_quantized - z.detach()) ** 2) 205 | 206 | loss = commitment_loss + codebook_loss 207 | 208 | # preserve gradients 209 | z_quantized = z + (z_quantized - z).detach() 210 | 211 | # reshape back to match original input shape 212 | z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous() 213 | 214 | result_dict = dict( 215 | quantizer_loss=loss, 216 | commitment_loss=commitment_loss, 217 | codebook_loss=codebook_loss, 218 | min_encoding_indices=min_encoding_indices.view( 219 | z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3] 220 | ), 221 | ) 222 | 223 | return z_quantized, result_dict 224 | 225 | def get_codebook_entry(self, indices): 226 | if len(indices.shape) == 1: 227 | z_quantized = self.embedding(indices) 228 | elif len(indices.shape) == 2: 229 | z_quantized = torch.einsum("bd,dn->bn", indices, self.embedding.weight) 230 | else: 231 | raise NotImplementedError 232 | return z_quantized 233 | -------------------------------------------------------------------------------- /imagetokenizer/quantize/lookup_free_quantize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lookup Free Quantization 3 | Proposed in https://arxiv.org/abs/2310.05737 4 | 5 | In the simplest setup, each dimension is quantized into {-1, 1}. 6 | An entropy penalty is used to encourage utilization. 
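A minimal, self-contained sketch of that setup (illustrative only; the shapes, names, and the
4-bit code below are assumptions, not part of this module):

    import torch
    z = torch.randn(2, 5, 4)                                           # (batch, tokens, codebook_dim = log2(codebook_size))
    q = torch.where(z > 0, torch.ones_like(z), -torch.ones_like(z))    # lookup-free: each dim quantized to {-1, 1}
    weights = 2 ** torch.arange(3, -1, -1)                             # big-endian bit weights, analogous to self.mask below
    indices = ((q > 0).long() * weights).sum(-1)                       # integer token id per position, in [0, 16)
    q = z + (q - z).detach()                                           # straight-through estimator: gradients flow back to z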
7 | 8 | Refer to 9 | https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py 10 | https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py 11 | """ 12 | 13 | from math import log2, ceil 14 | from collections import namedtuple 15 | 16 | import torch 17 | from torch import nn, einsum 18 | import torch.nn.functional as F 19 | from torch.nn import Module 20 | 21 | from einops import rearrange, reduce, pack, unpack 22 | 23 | # constants 24 | 25 | LossBreakdown = namedtuple( 26 | "LossBreakdown", 27 | ["per_sample_entropy", "codebook_entropy", "commitment", "avg_probs"], 28 | ) 29 | 30 | # helper functions 31 | 32 | 33 | def exists(v): 34 | return v is not None 35 | 36 | 37 | def default(*args): 38 | for arg in args: 39 | if exists(arg): 40 | return arg() if callable(arg) else arg 41 | return None 42 | 43 | 44 | def pack_one(t, pattern): 45 | return pack([t], pattern) 46 | 47 | 48 | def unpack_one(t, ps, pattern): 49 | return unpack(t, ps, pattern)[0] 50 | 51 | 52 | # entropy 53 | 54 | 55 | def entropy(prob): 56 | return (-prob * torch.log(prob + 1e-5)).sum(dim=-1) 57 | 58 | 59 | # class 60 | 61 | 62 | def mult_along_first_dims(x, y): 63 | """ 64 | returns x * y elementwise along the leading dimensions of y 65 | """ 66 | ndim_to_expand = x.ndim - y.ndim 67 | for _ in range(ndim_to_expand): 68 | y = y.unsqueeze(-1) 69 | return x * y 70 | 71 | 72 | def masked_mean(x, m): 73 | """ 74 | takes the mean of the elements of x that are not masked 75 | the mean is taken along the shared leading dims of m 76 | equivalent to: x[m].mean(tuple(range(m.ndim))) 77 | 78 | The benefit of using masked_mean rather than using 79 | tensor indexing is that masked_mean is much faster 80 | for torch-compile on batches. 81 | 82 | The drawback is larger floating point errors 83 | """ 84 | x = mult_along_first_dims(x, m) 85 | x = x / m.sum() 86 | return x.sum(tuple(range(m.ndim))) 87 | 88 | 89 | def entropy_loss( 90 | logits, 91 | mask=None, 92 | temperature=0.01, 93 | sample_minimization_weight=1.0, 94 | batch_maximization_weight=1.0, 95 | eps=1e-5, 96 | ): 97 | """ 98 | Entropy loss of unnormalized logits 99 | 100 | logits: Affinities are over the last dimension 101 | 102 | https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279 103 | LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024) 104 | """ 105 | probs = F.softmax(logits / temperature, -1) 106 | log_probs = F.log_softmax(logits / temperature + eps, -1) 107 | 108 | if mask is not None: 109 | avg_probs = masked_mean(probs, mask) 110 | else: 111 | avg_probs = reduce(probs, "... 
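        For example (illustrative values; torch is already imported at module level):

            idx = torch.tensor([6])                                    # 6 == 0b0110
            bits = (idx.unsqueeze(-1) & (2 ** torch.arange(4))) != 0   # -> [[False, True, True, False]] for codebook_dim = 4
            code = bits.float() * 2.0 - 1.0                            # -> [[-1., 1., 1., -1.]], matching the codebook built in __init__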
D -> D", "mean") 112 | 113 | avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps)) 114 | 115 | sample_entropy = -torch.sum(probs * log_probs, -1) 116 | if mask is not None: 117 | sample_entropy = masked_mean(sample_entropy, mask).mean() 118 | else: 119 | sample_entropy = torch.mean(sample_entropy) 120 | 121 | loss = (sample_minimization_weight * sample_entropy) - ( 122 | batch_maximization_weight * avg_entropy 123 | ) 124 | 125 | return sample_entropy, avg_entropy, loss 126 | 127 | 128 | class LFQ(Module): 129 | def __init__( 130 | self, 131 | *, 132 | dim=None, 133 | codebook_size=None, 134 | num_codebooks=1, 135 | sample_minimization_weight=1.0, 136 | batch_maximization_weight=1.0, 137 | token_factorization=False, 138 | ): 139 | super().__init__() 140 | 141 | # some assert validations 142 | 143 | assert exists(dim) or exists( 144 | codebook_size 145 | ), "either dim or codebook_size must be specified for LFQ" 146 | assert ( 147 | not exists(codebook_size) or log2(codebook_size).is_integer() 148 | ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})" 149 | 150 | self.codebook_size = default(codebook_size, lambda: 2**dim) 151 | self.codebook_dim = int(log2(codebook_size)) 152 | 153 | codebook_dims = self.codebook_dim * num_codebooks 154 | dim = default(dim, codebook_dims) 155 | 156 | has_projections = dim != codebook_dims 157 | self.has_projections = has_projections 158 | 159 | self.dim = dim 160 | self.codebook_dim = self.codebook_dim 161 | self.num_codebooks = num_codebooks 162 | 163 | # for entropy loss 164 | self.sample_minimization_weight = sample_minimization_weight 165 | self.batch_maximization_weight = batch_maximization_weight 166 | 167 | # for no auxiliary loss, during inference 168 | self.token_factorization = token_factorization ## only utilized in second stage 169 | if not self.token_factorization: # for first stage model 170 | self.register_buffer( 171 | "mask", 172 | 2 ** torch.arange(self.codebook_dim - 1, -1, -1), 173 | persistent=False, 174 | ) 175 | else: 176 | k = self.codebook_dim // 2 177 | self.register_buffer( 178 | "mask", 2 ** torch.arange(k - 1, -1, -1), persistent=False 179 | ) 180 | 181 | self.register_buffer("zero", torch.tensor(0.0), persistent=False) 182 | 183 | # codes 184 | all_codes = torch.arange(codebook_size) 185 | bits = self.indices_to_bits(all_codes) 186 | codebook = bits * 2.0 - 1.0 187 | 188 | self.register_buffer("codebook", codebook, persistent=False) 189 | 190 | @property 191 | def dtype(self): 192 | return self.codebook.dtype 193 | 194 | def indices_to_bits(self, x): 195 | """ 196 | x: long tensor of indices for constructing codebook, but actually not utilized in all the experiments. 
197 | 198 | returns big endian bits 199 | """ 200 | mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) 201 | # x is now big endian bits, the last dimension being the bits 202 | x = (x.unsqueeze(-1) & mask) != 0 203 | return x 204 | 205 | def get_codebook_entry(self, x, bhwc): 206 | if self.token_factorization: 207 | k = self.codebook_dim // 2 208 | mask = 2 ** torch.arange(k - 1, -1, -1, device=x.device, dtype=torch.long) 209 | else: 210 | mask = 2 ** torch.arange( 211 | self.codebook_dim - 1, -1, -1, device=x.device, dtype=torch.long 212 | ) 213 | 214 | x = (x.unsqueeze(-1) & mask) != 0 215 | x = x * 2.0 - 1.0 # back to the float 216 | ## scale back to the desired shape 217 | b, h, w, c = bhwc 218 | x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c) 219 | x = rearrange(x, "b h w c -> b c h w") 220 | return x 221 | 222 | def bits_to_indices(self, bits): 223 | """ 224 | bits: bool tensor of big endian bits, where the last dimension is the bit dimension 225 | 226 | returns indices, which are long integers from 0 to self.codebook_size 227 | """ 228 | assert bits.shape[-1] == self.codebook_dim 229 | indices = 2 ** torch.arange( 230 | 0, 231 | self.codebook_dim, 232 | 1, 233 | dtype=torch.long, 234 | device=bits.device, 235 | ) 236 | return (bits * indices).sum(-1) 237 | 238 | def decode(self, x): 239 | """ 240 | x: ... NH 241 | where NH is number of codebook heads 242 | A longtensor of codebook indices, containing values from 243 | 0 to self.codebook_size 244 | """ 245 | x = self.indices_to_bits(x) 246 | # to some sort of float 247 | x = x.to(self.dtype) 248 | # -1 or 1 249 | x = x * 2 - 1 250 | x = rearrange(x, "... NC Z-> ... (NC Z)") 251 | return x 252 | 253 | def forward( 254 | self, 255 | x, 256 | return_loss_breakdown=False, 257 | mask=None, 258 | return_loss=True, 259 | ): 260 | """ 261 | einstein notation 262 | b - batch 263 | n - sequence (or flattened spatial dimensions) 264 | d - feature dimension, which is also log2(codebook size) 265 | c - number of codebook dim 266 | """ 267 | 268 | x = rearrange(x, "b d ... -> b ... d") 269 | x, ps = pack_one(x, "b * d") 270 | # split out number of codebooks 271 | 272 | x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks) 273 | 274 | codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype) 275 | quantized = torch.where( 276 | x > 0, codebook_value, -codebook_value 277 | ) # higher than 0 filled 278 | 279 | # calculate indices 280 | if self.token_factorization: 281 | k = self.codebook_dim // 2 282 | indices_pre = reduce( 283 | (quantized[..., :k] > 0).int() * self.mask.int(), 284 | "b n c d -> b n c", 285 | "sum", 286 | ) 287 | indices_post = reduce( 288 | (quantized[..., k:] > 0).int() * self.mask.int(), 289 | "b n c d -> b n c", 290 | "sum", 291 | ) 292 | # indices_post = 2**k + indices_post #shifter to the 1024 293 | else: 294 | indices = reduce( 295 | (quantized > 0).int() * self.mask.int(), "b n c d -> b n c", "sum" 296 | ) 297 | 298 | # entropy aux loss 299 | 300 | if self.training and return_loss: 301 | logits = 2 * einsum("... i d, j d -> ... 
i j", x, self.codebook) 302 | # the same as euclidean distance up to a constant 303 | per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss( 304 | logits=logits, 305 | sample_minimization_weight=self.sample_minimization_weight, 306 | batch_maximization_weight=self.batch_maximization_weight, 307 | ) 308 | 309 | avg_probs = self.zero 310 | else: 311 | ## calculate the codebook_entropy needed for one batch evaluation 312 | # ------------------------------------------------------------------ 313 | # logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook) 314 | # probs = F.softmax(logits / 0.01, -1) 315 | # avg_probs = reduce(probs, "b n c d -> b d", "mean") 316 | # avg_probs = torch.sum(avg_probs, 0) #batch dimension 317 | # ------------------------------------------------------------------- 318 | # if not training, just return dummy 0 319 | per_sample_entropy = codebook_entropy = self.zero 320 | entropy_aux_loss = self.zero 321 | avg_probs = self.zero 322 | 323 | # commit loss 324 | 325 | if self.training: 326 | commit_loss = F.mse_loss(x, quantized.detach(), reduction="none") 327 | 328 | if exists(mask): 329 | commit_loss = commit_loss[mask] 330 | 331 | commit_loss = commit_loss.mean() 332 | else: 333 | commit_loss = self.zero 334 | 335 | # use straight-through gradients (optionally with custom activation fn) if training 336 | 337 | quantized = x + (quantized - x).detach() # transfer to quantized 338 | 339 | # merge back codebook dim 340 | 341 | quantized = rearrange(quantized, "b n c d -> b n (c d)") 342 | 343 | # reconstitute image or video dimensions 344 | 345 | quantized = unpack_one(quantized, ps, "b * d") 346 | quantized = rearrange(quantized, "b ... d -> b d ...") 347 | 348 | if self.token_factorization: 349 | indices_pre = unpack_one(indices_pre, ps, "b * c") 350 | indices_post = unpack_one(indices_post, ps, "b * c") 351 | indices_pre = indices_pre.flatten() 352 | indices_post = indices_post.flatten() 353 | indices = (indices_pre, indices_post) 354 | else: 355 | indices = unpack_one(indices, ps, "b * c") 356 | indices = indices.flatten() 357 | 358 | ret = (quantized, entropy_aux_loss, indices) 359 | 360 | if not return_loss_breakdown: 361 | return ret 362 | 363 | return ret, LossBreakdown( 364 | per_sample_entropy, codebook_entropy, commit_loss, avg_probs 365 | ) 366 | -------------------------------------------------------------------------------- /imagetokenizer/quantize/vector_quantize.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torch import nn, einsum 5 | import torch.nn.functional as F 6 | import torch.distributed as distributed 7 | from torch.optim import Optimizer 8 | from torch.cuda.amp import autocast 9 | 10 | from einops import rearrange, repeat, reduce, pack, unpack 11 | 12 | from typing import Callable 13 | 14 | 15 | def exists(val): 16 | return val is not None 17 | 18 | 19 | def default(val, d): 20 | return val if exists(val) else d 21 | 22 | 23 | def noop(*args, **kwargs): 24 | pass 25 | 26 | 27 | def identity(t): 28 | return t 29 | 30 | 31 | def l2norm(t): 32 | return F.normalize(t, p=2, dim=-1) 33 | 34 | 35 | def cdist(x, y): 36 | x2 = reduce(x**2, "b n d -> b n", "sum") 37 | y2 = reduce(y**2, "b n d -> b n", "sum") 38 | xy = einsum("b i d, b j d -> b i j", x, y) * -2 39 | return ( 40 | (rearrange(x2, "b i -> b i 1") + rearrange(y2, "b j -> b 1 j") + xy) 41 | .clamp(min=0) 42 | .sqrt() 43 | ) 44 | 45 | 46 | def log(t, eps=1e-20): 47 | 
return torch.log(t.clamp(min=eps)) 48 | 49 | 50 | def ema_inplace(old, new, decay): 51 | is_mps = str(old.device).startswith("mps:") 52 | 53 | if not is_mps: 54 | old.lerp_(new, 1 - decay) 55 | else: 56 | old.mul_(decay).add_(new * (1 - decay)) 57 | 58 | 59 | def pack_one(t, pattern): 60 | return pack([t], pattern) 61 | 62 | 63 | def unpack_one(t, ps, pattern): 64 | return unpack(t, ps, pattern)[0] 65 | 66 | 67 | def uniform_init(*shape): 68 | t = torch.empty(shape) 69 | nn.init.kaiming_uniform_(t) 70 | return t 71 | 72 | 73 | def gumbel_noise(t): 74 | noise = torch.zeros_like(t).uniform_(0, 1) 75 | return -log(-log(noise)) 76 | 77 | 78 | def gumbel_sample( 79 | logits, 80 | temperature=1.0, 81 | stochastic=False, 82 | straight_through=False, 83 | reinmax=False, 84 | dim=-1, 85 | training=True, 86 | ): 87 | dtype, size = logits.dtype, logits.shape[dim] 88 | 89 | if training and stochastic and temperature > 0: 90 | sampling_logits = (logits / temperature) + gumbel_noise(logits) 91 | else: 92 | sampling_logits = logits 93 | 94 | ind = sampling_logits.argmax(dim=dim) 95 | one_hot = F.one_hot(ind, size).type(dtype) 96 | 97 | assert not ( 98 | reinmax and not straight_through 99 | ), "reinmax can only be turned on if using straight through gumbel softmax" 100 | 101 | if not straight_through or temperature <= 0.0 or not training: 102 | return ind, one_hot 103 | 104 | # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612 105 | # algorithm 2 106 | 107 | if reinmax: 108 | π0 = logits.softmax(dim=dim) 109 | π1 = (one_hot + (logits / temperature).softmax(dim=dim)) / 2 110 | π1 = ((log(π1) - logits).detach() + logits).softmax(dim=1) 111 | π2 = 2 * π1 - 0.5 * π0 112 | one_hot = π2 - π2.detach() + one_hot 113 | else: 114 | π1 = (logits / temperature).softmax(dim=dim) 115 | one_hot = one_hot + π1 - π1.detach() 116 | 117 | return ind, one_hot 118 | 119 | 120 | def laplace_smoothing(x, n_categories, eps=1e-5, dim=-1): 121 | denom = x.sum(dim=dim, keepdim=True) 122 | return (x + eps) / (denom + n_categories * eps) 123 | 124 | 125 | def sample_vectors(samples, num): 126 | num_samples, device = samples.shape[0], samples.device 127 | if num_samples >= num: 128 | indices = torch.randperm(num_samples, device=device)[:num] 129 | else: 130 | indices = torch.randint(0, num_samples, (num,), device=device) 131 | 132 | return samples[indices] 133 | 134 | 135 | def batched_sample_vectors(samples, num): 136 | return torch.stack( 137 | [sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0 138 | ) 139 | 140 | 141 | def pad_shape(shape, size, dim=0): 142 | return [size if i == dim else s for i, s in enumerate(shape)] 143 | 144 | 145 | def sample_multinomial(total_count, probs): 146 | device = probs.device 147 | probs = probs.cpu() 148 | 149 | total_count = probs.new_full((), total_count) 150 | remainder = probs.new_ones(()) 151 | sample = torch.empty_like(probs, dtype=torch.long) 152 | 153 | for i, p in enumerate(probs): 154 | s = torch.binomial(total_count, p / remainder) 155 | sample[i] = s 156 | total_count -= s 157 | remainder -= p 158 | 159 | return sample.to(device) 160 | 161 | 162 | def all_gather_sizes(x, dim): 163 | size = torch.tensor(x.shape[dim], dtype=torch.long, device=x.device) 164 | all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] 165 | distributed.all_gather(all_sizes, size) 166 | return torch.stack(all_sizes) 167 | 168 | 169 | def all_gather_variably_sized(x, sizes, dim=0): 170 | rank = distributed.get_rank() 171 | all_x = [] 
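    # Each rank takes a turn as the broadcast source in the loop below: rank i sends its
    # own chunk, every other rank receives into a buffer sized from the per-rank sizes
    # gathered by all_gather_sizes(); the async broadcasts are synchronized by the barrier
    # before the list is returned.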
172 | 173 | for i, size in enumerate(sizes): 174 | t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) 175 | distributed.broadcast(t, src=i, async_op=True) 176 | all_x.append(t) 177 | 178 | distributed.barrier() 179 | return all_x 180 | 181 | 182 | def sample_vectors_distributed(local_samples, num): 183 | local_samples = rearrange(local_samples, "1 ... -> ...") 184 | 185 | rank = distributed.get_rank() 186 | all_num_samples = all_gather_sizes(local_samples, dim=0) 187 | 188 | if rank == 0: 189 | samples_per_rank = sample_multinomial( 190 | num, all_num_samples / all_num_samples.sum() 191 | ) 192 | else: 193 | samples_per_rank = torch.empty_like(all_num_samples) 194 | 195 | distributed.broadcast(samples_per_rank, src=0) 196 | samples_per_rank = samples_per_rank.tolist() 197 | 198 | local_samples = sample_vectors(local_samples, samples_per_rank[rank]) 199 | all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim=0) 200 | out = torch.cat(all_samples, dim=0) 201 | 202 | return rearrange(out, "... -> 1 ...") 203 | 204 | 205 | def batched_bincount(x, *, minlength): 206 | batch, dtype, device = x.shape[0], x.dtype, x.device 207 | target = torch.zeros(batch, minlength, dtype=dtype, device=device) 208 | values = torch.ones_like(x) 209 | target.scatter_add_(-1, x, values) 210 | return target 211 | 212 | 213 | def kmeans( 214 | samples, 215 | num_clusters, 216 | num_iters=10, 217 | use_cosine_sim=False, 218 | sample_fn=batched_sample_vectors, 219 | all_reduce_fn=noop, 220 | ): 221 | num_codebooks, dim, dtype, device = ( 222 | samples.shape[0], 223 | samples.shape[-1], 224 | samples.dtype, 225 | samples.device, 226 | ) 227 | 228 | means = sample_fn(samples, num_clusters) 229 | 230 | for _ in range(num_iters): 231 | if use_cosine_sim: 232 | dists = samples @ rearrange(means, "h n d -> h d n") 233 | else: 234 | dists = -cdist(samples, means) 235 | 236 | buckets = torch.argmax(dists, dim=-1) 237 | bins = batched_bincount(buckets, minlength=num_clusters) 238 | all_reduce_fn(bins) 239 | 240 | zero_mask = bins == 0 241 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 242 | 243 | new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype) 244 | 245 | new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples) 246 | new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1") 247 | all_reduce_fn(new_means) 248 | 249 | if use_cosine_sim: 250 | new_means = l2norm(new_means) 251 | 252 | means = torch.where(rearrange(zero_mask, "... -> ... 
1"), means, new_means) 253 | 254 | return means, bins 255 | 256 | 257 | def batched_embedding(indices, embeds): 258 | batch, dim = indices.shape[1], embeds.shape[-1] 259 | indices = repeat(indices, "h b n -> h b n d", d=dim) 260 | embeds = repeat(embeds, "h c d -> h b c d", b=batch) 261 | return embeds.gather(2, indices) 262 | 263 | 264 | # regularization losses 265 | 266 | 267 | def orthogonal_loss_fn(t): 268 | # eq (2) from https://arxiv.org/abs/2112.00384 269 | h, n = t.shape[:2] 270 | normed_codes = l2norm(t) 271 | cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes) 272 | return (cosine_sim**2).sum() / (h * n**2) - (1 / n) 273 | 274 | 275 | # distance types 276 | 277 | 278 | class EuclideanCodebook(nn.Module): 279 | def __init__( 280 | self, 281 | dim, 282 | codebook_size, 283 | num_codebooks=1, 284 | kmeans_init=False, 285 | kmeans_iters=10, 286 | sync_kmeans=True, 287 | decay=0.8, 288 | eps=1e-5, 289 | threshold_ema_dead_code=2, 290 | reset_cluster_size=None, 291 | use_ddp=False, 292 | learnable_codebook=False, 293 | gumbel_sample=gumbel_sample, 294 | sample_codebook_temp=1.0, 295 | ema_update=True, 296 | affine_param=False, 297 | sync_affine_param=False, 298 | affine_param_batch_decay=0.99, 299 | affine_param_codebook_decay=0.9, 300 | ): 301 | super().__init__() 302 | self.transform_input = identity 303 | 304 | self.decay = decay 305 | self.ema_update = ema_update 306 | 307 | init_fn = uniform_init if not kmeans_init else torch.zeros 308 | embed = init_fn(num_codebooks, codebook_size, dim) 309 | 310 | self.codebook_size = codebook_size 311 | self.num_codebooks = num_codebooks 312 | 313 | self.kmeans_iters = kmeans_iters 314 | self.eps = eps 315 | self.threshold_ema_dead_code = threshold_ema_dead_code 316 | self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code) 317 | 318 | assert callable(gumbel_sample) 319 | self.gumbel_sample = gumbel_sample 320 | self.sample_codebook_temp = sample_codebook_temp 321 | 322 | assert not ( 323 | use_ddp and num_codebooks > 1 and kmeans_init 324 | ), "kmeans init is not compatible with multiple codebooks in distributed environment for now" 325 | 326 | self.sample_fn = ( 327 | sample_vectors_distributed 328 | if use_ddp and sync_kmeans 329 | else batched_sample_vectors 330 | ) 331 | self.kmeans_all_reduce_fn = ( 332 | distributed.all_reduce if use_ddp and sync_kmeans else noop 333 | ) 334 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop 335 | 336 | self.register_buffer("initted", torch.Tensor([not kmeans_init])) 337 | self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size)) 338 | self.register_buffer("embed_avg", embed.clone()) 339 | 340 | self.learnable_codebook = learnable_codebook 341 | if learnable_codebook: 342 | self.embed = nn.Parameter(embed) 343 | else: 344 | self.register_buffer("embed", embed) 345 | 346 | # affine related params 347 | 348 | self.affine_param = affine_param 349 | self.sync_affine_param = sync_affine_param 350 | 351 | if not affine_param: 352 | return 353 | 354 | self.affine_param_batch_decay = affine_param_batch_decay 355 | self.affine_param_codebook_decay = affine_param_codebook_decay 356 | 357 | self.register_buffer("batch_mean", None) 358 | self.register_buffer("batch_variance", None) 359 | 360 | self.register_buffer("codebook_mean_needs_init", torch.Tensor([True])) 361 | self.register_buffer("codebook_mean", torch.empty(num_codebooks, 1, dim)) 362 | self.register_buffer("codebook_variance_needs_init", torch.Tensor([True])) 363 | 
self.register_buffer("codebook_variance", torch.empty(num_codebooks, 1, dim)) 364 | 365 | @torch.jit.ignore 366 | def init_embed_(self, data, mask=None): 367 | if self.initted: 368 | return 369 | 370 | if exists(mask): 371 | c = data.shape[0] 372 | data = rearrange(data[mask], "(c n) d -> c n d", c=c) 373 | 374 | embed, cluster_size = kmeans( 375 | data, 376 | self.codebook_size, 377 | self.kmeans_iters, 378 | sample_fn=self.sample_fn, 379 | all_reduce_fn=self.kmeans_all_reduce_fn, 380 | ) 381 | 382 | embed_sum = embed * rearrange(cluster_size, "... -> ... 1") 383 | 384 | self.embed.data.copy_(embed) 385 | self.embed_avg.data.copy_(embed_sum) 386 | self.cluster_size.data.copy_(cluster_size) 387 | self.initted.data.copy_(torch.Tensor([True])) 388 | 389 | @torch.jit.ignore 390 | def update_with_decay(self, buffer_name, new_value, decay): 391 | old_value = getattr(self, buffer_name) 392 | 393 | needs_init = getattr(self, buffer_name + "_needs_init", False) 394 | 395 | if needs_init: 396 | self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False])) 397 | 398 | if not exists(old_value) or needs_init: 399 | self.register_buffer(buffer_name, new_value.detach()) 400 | 401 | return 402 | 403 | value = old_value * decay + new_value.detach() * (1 - decay) 404 | self.register_buffer(buffer_name, value) 405 | 406 | @torch.jit.ignore 407 | def update_affine(self, data, embed, mask=None): 408 | assert self.affine_param 409 | 410 | var_fn = partial(torch.var, unbiased=False) 411 | 412 | # calculate codebook mean and variance 413 | 414 | embed = rearrange(embed, "h ... d -> h (...) d") 415 | 416 | if self.training: 417 | self.update_with_decay( 418 | "codebook_mean", 419 | reduce(embed, "h n d -> h 1 d", "mean"), 420 | self.affine_param_codebook_decay, 421 | ) 422 | self.update_with_decay( 423 | "codebook_variance", 424 | reduce(embed, "h n d -> h 1 d", var_fn), 425 | self.affine_param_codebook_decay, 426 | ) 427 | 428 | # prepare batch data, which depends on whether it has masking 429 | 430 | data = rearrange(data, "h ... d -> h (...) 
d") 431 | 432 | if exists(mask): 433 | c = data.shape[0] 434 | data = rearrange(data[mask], "(c n) d -> c n d", c=c) 435 | 436 | # calculate batch mean and variance 437 | 438 | if not self.sync_affine_param: 439 | self.update_with_decay( 440 | "batch_mean", 441 | reduce(data, "h n d -> h 1 d", "mean"), 442 | self.affine_param_batch_decay, 443 | ) 444 | self.update_with_decay( 445 | "batch_variance", 446 | reduce(data, "h n d -> h 1 d", var_fn), 447 | self.affine_param_batch_decay, 448 | ) 449 | return 450 | 451 | num_vectors, device, dtype = data.shape[-2], data.device, data.dtype 452 | 453 | # number of vectors, for denominator 454 | 455 | num_vectors = torch.tensor([num_vectors], device=device, dtype=dtype) 456 | distributed.all_reduce(num_vectors) 457 | 458 | # calculate distributed mean 459 | 460 | batch_sum = reduce(data, "h n d -> h 1 d", "sum") 461 | distributed.all_reduce(batch_sum) 462 | batch_mean = batch_sum / num_vectors 463 | 464 | self.update_with_decay("batch_mean", batch_mean, self.affine_param_batch_decay) 465 | 466 | # calculate distributed variance 467 | 468 | variance_numer = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum") 469 | distributed.all_reduce(variance_numer) 470 | batch_variance = variance_numer / num_vectors 471 | 472 | self.update_with_decay( 473 | "batch_variance", batch_variance, self.affine_param_batch_decay 474 | ) 475 | 476 | def replace(self, batch_samples, batch_mask): 477 | for ind, (samples, mask) in enumerate( 478 | zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0)) 479 | ): 480 | if not torch.any(mask): 481 | continue 482 | 483 | sampled = self.sample_fn( 484 | rearrange(samples, "... -> 1 ..."), mask.sum().item() 485 | ) 486 | sampled = rearrange(sampled, "1 ... -> ...") 487 | 488 | self.embed.data[ind][mask] = sampled 489 | 490 | self.cluster_size.data[ind][mask] = self.reset_cluster_size 491 | self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size 492 | 493 | def expire_codes_(self, batch_samples): 494 | if self.threshold_ema_dead_code == 0: 495 | return 496 | 497 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 498 | 499 | if not torch.any(expired_codes): 500 | return 501 | 502 | batch_samples = rearrange(batch_samples, "h ... d -> h (...) d") 503 | self.replace(batch_samples, batch_mask=expired_codes) 504 | 505 | @autocast(enabled=False) 506 | def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): 507 | needs_codebook_dim = x.ndim < 4 508 | sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp) 509 | 510 | x = x.float() 511 | 512 | if needs_codebook_dim: 513 | x = rearrange(x, "... 
-> 1 ...") 514 | 515 | dtype = x.dtype 516 | flatten, ps = pack_one(x, "h * d") 517 | 518 | if exists(mask): 519 | mask = repeat( 520 | mask, 521 | "b n -> c (b h n)", 522 | c=flatten.shape[0], 523 | h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]), 524 | ) 525 | 526 | self.init_embed_(flatten, mask=mask) 527 | 528 | if self.affine_param: 529 | self.update_affine(flatten, self.embed, mask=mask) 530 | 531 | embed = self.embed if self.learnable_codebook else self.embed.detach() 532 | 533 | if self.affine_param: 534 | codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt() 535 | batch_std = self.batch_variance.clamp(min=1e-5).sqrt() 536 | embed = (embed - self.codebook_mean) * ( 537 | batch_std / codebook_std 538 | ) + self.batch_mean 539 | 540 | dist = -cdist(flatten, embed) 541 | 542 | embed_ind, embed_onehot = self.gumbel_sample( 543 | dist, dim=-1, temperature=sample_codebook_temp, training=self.training 544 | ) 545 | 546 | embed_ind = unpack_one(embed_ind, ps, "h *") 547 | 548 | if self.training: 549 | unpacked_onehot = unpack_one(embed_onehot, ps, "h * c") 550 | quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed) 551 | else: 552 | quantize = batched_embedding(embed_ind, embed) 553 | 554 | if self.training and self.ema_update and not freeze_codebook: 555 | 556 | if self.affine_param: 557 | flatten = (flatten - self.batch_mean) * ( 558 | codebook_std / batch_std 559 | ) + self.codebook_mean 560 | 561 | if exists(mask): 562 | embed_onehot[~mask] = 0.0 563 | 564 | cluster_size = embed_onehot.sum(dim=1) 565 | 566 | self.all_reduce_fn(cluster_size) 567 | ema_inplace(self.cluster_size.data, cluster_size, self.decay) 568 | 569 | embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot) 570 | embed_sum = embed_sum.contiguous() 571 | self.all_reduce_fn(embed_sum) 572 | 573 | ema_inplace(self.embed_avg.data, embed_sum, self.decay) 574 | 575 | cluster_size = laplace_smoothing( 576 | self.cluster_size, self.codebook_size, self.eps 577 | ) * self.cluster_size.sum(dim=-1, keepdim=True) 578 | 579 | embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1") 580 | self.embed.data.copy_(embed_normalized) 581 | self.expire_codes_(x) 582 | 583 | if needs_codebook_dim: 584 | quantize, embed_ind = map( 585 | lambda t: rearrange(t, "1 ... 
-> ..."), (quantize, embed_ind) 586 | ) 587 | 588 | dist = unpack_one(dist, ps, "h * d") 589 | 590 | return quantize, embed_ind, dist 591 | 592 | 593 | class CosineSimCodebook(nn.Module): 594 | def __init__( 595 | self, 596 | dim, 597 | codebook_size, 598 | num_codebooks=1, 599 | kmeans_init=False, 600 | kmeans_iters=10, 601 | sync_kmeans=True, 602 | decay=0.8, 603 | eps=1e-5, 604 | threshold_ema_dead_code=2, 605 | reset_cluster_size=None, 606 | use_ddp=False, 607 | learnable_codebook=False, 608 | gumbel_sample=gumbel_sample, 609 | sample_codebook_temp=1.0, 610 | ema_update=True, 611 | ): 612 | super().__init__() 613 | self.transform_input = l2norm 614 | 615 | self.ema_update = ema_update 616 | self.decay = decay 617 | 618 | if not kmeans_init: 619 | embed = l2norm(uniform_init(num_codebooks, codebook_size, dim)) 620 | else: 621 | embed = torch.zeros(num_codebooks, codebook_size, dim) 622 | 623 | self.codebook_size = codebook_size 624 | self.num_codebooks = num_codebooks 625 | 626 | self.kmeans_iters = kmeans_iters 627 | self.eps = eps 628 | self.threshold_ema_dead_code = threshold_ema_dead_code 629 | self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code) 630 | 631 | assert callable(gumbel_sample) 632 | self.gumbel_sample = gumbel_sample 633 | self.sample_codebook_temp = sample_codebook_temp 634 | 635 | self.sample_fn = ( 636 | sample_vectors_distributed 637 | if use_ddp and sync_kmeans 638 | else batched_sample_vectors 639 | ) 640 | self.kmeans_all_reduce_fn = ( 641 | distributed.all_reduce if use_ddp and sync_kmeans else noop 642 | ) 643 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop 644 | 645 | self.register_buffer("initted", torch.Tensor([not kmeans_init])) 646 | self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size)) 647 | self.register_buffer("embed_avg", embed.clone()) 648 | 649 | self.learnable_codebook = learnable_codebook 650 | if learnable_codebook: 651 | self.embed = nn.Parameter(embed) 652 | else: 653 | self.register_buffer("embed", embed) 654 | 655 | @torch.jit.ignore 656 | def init_embed_(self, data, mask=None): 657 | if self.initted: 658 | return 659 | 660 | if exists(mask): 661 | c = data.shape[0] 662 | data = rearrange(data[mask], "(c n) d -> c n d", c=c) 663 | 664 | embed, cluster_size = kmeans( 665 | data, 666 | self.codebook_size, 667 | self.kmeans_iters, 668 | use_cosine_sim=True, 669 | sample_fn=self.sample_fn, 670 | all_reduce_fn=self.kmeans_all_reduce_fn, 671 | ) 672 | 673 | embed_sum = embed * rearrange(cluster_size, "... -> ... 1") 674 | 675 | self.embed.data.copy_(embed) 676 | self.embed_avg.data.copy_(embed_sum) 677 | self.cluster_size.data.copy_(cluster_size) 678 | self.initted.data.copy_(torch.Tensor([True])) 679 | 680 | def replace(self, batch_samples, batch_mask): 681 | batch_samples = l2norm(batch_samples) 682 | 683 | for ind, (samples, mask) in enumerate( 684 | zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0)) 685 | ): 686 | if not torch.any(mask): 687 | continue 688 | 689 | sampled = self.sample_fn( 690 | rearrange(samples, "... -> 1 ..."), mask.sum().item() 691 | ) 692 | sampled = rearrange(sampled, "1 ... 
-> ...") 693 | 694 | self.embed.data[ind][mask] = sampled 695 | self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size 696 | self.cluster_size.data[ind][mask] = self.reset_cluster_size 697 | 698 | def expire_codes_(self, batch_samples): 699 | if self.threshold_ema_dead_code == 0: 700 | return 701 | 702 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 703 | 704 | if not torch.any(expired_codes): 705 | return 706 | 707 | batch_samples = rearrange(batch_samples, "h ... d -> h (...) d") 708 | self.replace(batch_samples, batch_mask=expired_codes) 709 | 710 | @autocast(enabled=False) 711 | def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): 712 | needs_codebook_dim = x.ndim < 4 713 | sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp) 714 | 715 | x = x.float() 716 | 717 | if needs_codebook_dim: 718 | x = rearrange(x, "... -> 1 ...") 719 | 720 | dtype = x.dtype 721 | 722 | flatten, ps = pack_one(x, "h * d") 723 | 724 | if exists(mask): 725 | mask = repeat( 726 | mask, 727 | "b n -> c (b h n)", 728 | c=flatten.shape[0], 729 | h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]), 730 | ) 731 | 732 | self.init_embed_(flatten, mask=mask) 733 | 734 | embed = self.embed if self.learnable_codebook else self.embed.detach() 735 | 736 | dist = einsum("h n d, h c d -> h n c", flatten, embed) 737 | 738 | embed_ind, embed_onehot = self.gumbel_sample( 739 | dist, dim=-1, temperature=sample_codebook_temp, training=self.training 740 | ) 741 | embed_ind = unpack_one(embed_ind, ps, "h *") 742 | 743 | if self.training: 744 | unpacked_onehot = unpack_one(embed_onehot, ps, "h * c") 745 | quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed) 746 | else: 747 | quantize = batched_embedding(embed_ind, embed) 748 | 749 | if self.training and self.ema_update and not freeze_codebook: 750 | if exists(mask): 751 | embed_onehot[~mask] = 0.0 752 | 753 | bins = embed_onehot.sum(dim=1) 754 | self.all_reduce_fn(bins) 755 | 756 | ema_inplace(self.cluster_size.data, bins, self.decay) 757 | 758 | embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot) 759 | embed_sum = embed_sum.contiguous() 760 | self.all_reduce_fn(embed_sum) 761 | 762 | ema_inplace(self.embed_avg.data, embed_sum, self.decay) 763 | 764 | cluster_size = laplace_smoothing( 765 | self.cluster_size, self.codebook_size, self.eps 766 | ) * self.cluster_size.sum(dim=-1, keepdim=True) 767 | 768 | embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1") 769 | embed_normalized = l2norm(embed_normalized) 770 | 771 | self.embed.data.copy_(l2norm(embed_normalized)) 772 | self.expire_codes_(x) 773 | 774 | if needs_codebook_dim: 775 | quantize, embed_ind = map( 776 | lambda t: rearrange(t, "1 ... 
-> ..."), (quantize, embed_ind) 777 | ) 778 | 779 | dist = unpack_one(dist, ps, "h * d") 780 | return quantize, embed_ind, dist 781 | 782 | 783 | # main class 784 | 785 | 786 | class VectorQuantize(nn.Module): 787 | def __init__( 788 | self, 789 | dim, 790 | codebook_size, 791 | codebook_dim=None, 792 | heads=1, 793 | separate_codebook_per_head=False, 794 | decay=0.8, 795 | eps=1e-5, 796 | freeze_codebook=False, 797 | kmeans_init=False, 798 | kmeans_iters=10, 799 | sync_kmeans=True, 800 | use_cosine_sim=False, 801 | threshold_ema_dead_code=0, 802 | channel_last=True, 803 | accept_image_fmap=False, 804 | commitment_weight=1.0, 805 | commitment_use_cross_entropy_loss=False, 806 | orthogonal_reg_weight=0.0, 807 | orthogonal_reg_active_codes_only=False, 808 | orthogonal_reg_max_codes=None, 809 | stochastic_sample_codes=False, 810 | sample_codebook_temp=1.0, 811 | straight_through=False, 812 | reinmax=False, # using reinmax for improved straight-through, assuming straight through helps at all 813 | sync_codebook=None, 814 | sync_affine_param=False, 815 | ema_update=True, 816 | learnable_codebook=False, 817 | in_place_codebook_optimizer: Callable[ 818 | ..., Optimizer 819 | ] = None, # Optimizer used to update the codebook embedding if using learnable_codebook 820 | affine_param=False, 821 | affine_param_batch_decay=0.99, 822 | affine_param_codebook_decay=0.9, 823 | sync_update_v=0.0, # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf 824 | ): 825 | super().__init__() 826 | self.dim = dim 827 | self.heads = heads 828 | self.separate_codebook_per_head = separate_codebook_per_head 829 | 830 | codebook_dim = default(codebook_dim, dim) 831 | codebook_input_dim = codebook_dim * heads 832 | 833 | requires_projection = codebook_input_dim != dim 834 | self.project_in = ( 835 | nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() 836 | ) 837 | self.project_out = ( 838 | nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() 839 | ) 840 | 841 | self.has_projections = requires_projection 842 | 843 | self.eps = eps 844 | self.commitment_weight = commitment_weight 845 | self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss 846 | 847 | self.learnable_codebook = learnable_codebook 848 | 849 | has_codebook_orthogonal_loss = orthogonal_reg_weight > 0 850 | self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss 851 | self.orthogonal_reg_weight = orthogonal_reg_weight 852 | self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only 853 | self.orthogonal_reg_max_codes = orthogonal_reg_max_codes 854 | 855 | assert not ( 856 | ema_update and learnable_codebook 857 | ), "learnable codebook not compatible with EMA update" 858 | 859 | assert 0 <= sync_update_v <= 1.0 860 | assert not ( 861 | sync_update_v > 0.0 and not learnable_codebook 862 | ), "learnable codebook must be turned on" 863 | 864 | self.sync_update_v = sync_update_v 865 | 866 | codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook 867 | 868 | gumbel_sample_fn = partial( 869 | gumbel_sample, 870 | stochastic=stochastic_sample_codes, 871 | reinmax=reinmax, 872 | straight_through=straight_through, 873 | ) 874 | 875 | if not exists(sync_codebook): 876 | sync_codebook = ( 877 | distributed.is_initialized() and distributed.get_world_size() > 1 878 | ) 879 | 880 | codebook_kwargs = 
dict( 881 | dim=codebook_dim, 882 | num_codebooks=heads if separate_codebook_per_head else 1, 883 | codebook_size=codebook_size, 884 | kmeans_init=kmeans_init, 885 | kmeans_iters=kmeans_iters, 886 | sync_kmeans=sync_kmeans, 887 | decay=decay, 888 | eps=eps, 889 | threshold_ema_dead_code=threshold_ema_dead_code, 890 | use_ddp=sync_codebook, 891 | learnable_codebook=has_codebook_orthogonal_loss or learnable_codebook, 892 | sample_codebook_temp=sample_codebook_temp, 893 | gumbel_sample=gumbel_sample_fn, 894 | ema_update=ema_update, 895 | ) 896 | 897 | if affine_param: 898 | assert ( 899 | not use_cosine_sim 900 | ), "affine param is only compatible with euclidean codebook" 901 | codebook_kwargs = dict( 902 | **codebook_kwargs, 903 | affine_param=True, 904 | sync_affine_param=sync_affine_param, 905 | affine_param_batch_decay=affine_param_batch_decay, 906 | affine_param_codebook_decay=affine_param_codebook_decay, 907 | ) 908 | 909 | self._codebook = codebook_class(**codebook_kwargs) 910 | 911 | self.in_place_codebook_optimizer = ( 912 | in_place_codebook_optimizer(self._codebook.parameters()) 913 | if exists(in_place_codebook_optimizer) 914 | else None 915 | ) 916 | 917 | self.codebook_size = codebook_size 918 | self.register_buffer("codebook_usage", torch.zeros(codebook_size)) 919 | self.call_cnt = 0 920 | 921 | self.accept_image_fmap = accept_image_fmap 922 | self.channel_last = channel_last 923 | 924 | @property 925 | def codebook(self): 926 | codebook = self._codebook.embed 927 | 928 | if self.separate_codebook_per_head: 929 | return codebook 930 | 931 | return rearrange(codebook, "1 ... -> ...") 932 | 933 | @codebook.setter 934 | def codebook(self, codes): 935 | if not self.separate_codebook_per_head: 936 | codes = rearrange(codes, "... -> 1 ...") 937 | 938 | self._codebook.embed.copy_(codes) 939 | 940 | def get_codes_from_indices(self, indices): 941 | codebook = self.codebook 942 | is_multiheaded = codebook.ndim > 2 943 | 944 | if not is_multiheaded: 945 | codes = codebook[indices] 946 | return rearrange(codes, "... h d -> ... 
(h d)") 947 | 948 | indices, ps = pack_one(indices, "b * h") 949 | indices = rearrange(indices, "b n h -> b h n") 950 | 951 | indices = repeat(indices, "b h n -> b h n d", d=codebook.shape[-1]) 952 | codebook = repeat(codebook, "h n d -> b h n d", b=indices.shape[0]) 953 | 954 | codes = codebook.gather(2, indices) 955 | codes = rearrange(codes, "b h n d -> b n (h d)") 956 | codes = unpack_one(codes, ps, "b * d") 957 | return codes 958 | 959 | def get_output_from_indices(self, indices): 960 | codes = self.get_codes_from_indices(indices) 961 | return self.project_out(codes) 962 | 963 | def get_perplexity(self, encoding_indices, x): 964 | encode_onehot = F.one_hot(encoding_indices, self.codebook_size).type_as( 965 | x 966 | ) # [bthw, ncode] 967 | avg_probs = torch.mean(encode_onehot, dim=0) 968 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 969 | return perplexity 970 | 971 | def get_usage(self, encoding_indices): 972 | # Flatten the batch of encoding indices into a single 1D tensor 973 | all_indices = encoding_indices.flatten() 974 | 975 | # Obtain the total number of encoding indices in the batch to calculate percentages 976 | total_indices = all_indices.numel() 977 | 978 | # Initialize a tensor to store the percentage usage of each code 979 | codebook_usage_percentage = torch.zeros( 980 | self.codebook_size, device=all_indices.device 981 | ) 982 | 983 | # Count the number of occurrences of each index and get their frequency as percentages 984 | unique_indices, counts = torch.unique(all_indices, return_counts=True) 985 | 986 | # Calculate the percentage 987 | percentages = counts.float() / total_indices 988 | 989 | # Populate the corresponding percentages in the codebook_usage_percentage tensor 990 | codebook_usage_percentage[unique_indices.long()] = percentages 991 | 992 | return codebook_usage_percentage 993 | 994 | def forward( 995 | self, 996 | x, 997 | indices=None, 998 | mask=None, 999 | sample_codebook_temp=None, 1000 | freeze_codebook=False, 1001 | ): 1002 | orig_input = x 1003 | 1004 | only_one = x.ndim == 2 1005 | 1006 | if only_one: 1007 | assert not exists(mask) 1008 | x = rearrange(x, "b d -> b 1 d") 1009 | 1010 | shape, device, heads, is_multiheaded, codebook_size, return_loss = ( 1011 | x.shape, 1012 | x.device, 1013 | self.heads, 1014 | self.heads > 1, 1015 | self.codebook_size, 1016 | exists(indices), 1017 | ) 1018 | 1019 | need_transpose = not self.channel_last and not self.accept_image_fmap 1020 | should_inplace_optimize = exists(self.in_place_codebook_optimizer) 1021 | 1022 | # rearrange inputs 1023 | 1024 | if self.accept_image_fmap: 1025 | nframes, height, width = x.shape[-3:] 1026 | x = rearrange(x, "b c t h w -> b (t h w) c") 1027 | 1028 | if need_transpose: 1029 | x = rearrange(x, "b d n -> b n d") 1030 | 1031 | # project input 1032 | 1033 | x = self.project_in(x) 1034 | 1035 | # handle multi-headed separate codebooks 1036 | 1037 | if is_multiheaded: 1038 | ein_rhs_eq = "h b n d" if self.separate_codebook_per_head else "1 (b h) n d" 1039 | x = rearrange(x, f"b n (h d) -> {ein_rhs_eq}", h=heads) 1040 | 1041 | # l2norm for cosine sim, otherwise identity 1042 | 1043 | x = self._codebook.transform_input(x) 1044 | 1045 | # codebook forward kwargs 1046 | 1047 | codebook_forward_kwargs = dict( 1048 | sample_codebook_temp=sample_codebook_temp, 1049 | mask=mask, 1050 | freeze_codebook=freeze_codebook, 1051 | ) 1052 | 1053 | # quantize 1054 | 1055 | quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) 1056 | 1057 | # 
one step in-place update 1058 | 1059 | if should_inplace_optimize and self.training and not freeze_codebook: 1060 | 1061 | if exists(mask): 1062 | loss = F.mse_loss(quantize, x.detach(), reduction="none") 1063 | 1064 | loss_mask = mask 1065 | if is_multiheaded: 1066 | loss_mask = repeat( 1067 | mask, 1068 | "b n -> c (b h) n", 1069 | c=loss.shape[0], 1070 | h=loss.shape[1] // mask.shape[0], 1071 | ) 1072 | 1073 | loss = loss[loss_mask].mean() 1074 | 1075 | else: 1076 | loss = F.mse_loss(quantize, x.detach()) 1077 | 1078 | loss.backward() 1079 | self.in_place_codebook_optimizer.step() 1080 | self.in_place_codebook_optimizer.zero_grad() 1081 | 1082 | # quantize again 1083 | 1084 | quantize, embed_ind, distances = self._codebook( 1085 | x, **codebook_forward_kwargs 1086 | ) 1087 | 1088 | if self.training: 1089 | # determine code to use for commitment loss 1090 | maybe_detach = ( 1091 | torch.detach 1092 | if not self.learnable_codebook or freeze_codebook 1093 | else identity 1094 | ) 1095 | 1096 | commit_quantize = maybe_detach(quantize) 1097 | 1098 | # straight through 1099 | 1100 | quantize = x + (quantize - x).detach() 1101 | 1102 | if self.sync_update_v > 0.0: 1103 | # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf 1104 | quantize = quantize + self.sync_update_v * ( 1105 | quantize - quantize.detach() 1106 | ) 1107 | 1108 | # function for calculating cross entropy loss to distance matrix 1109 | # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss 1110 | 1111 | def calculate_ce_loss(codes): 1112 | if not is_multiheaded: 1113 | dist_einops_eq = "1 b n l -> b l n" 1114 | elif self.separate_codebook_per_head: 1115 | dist_einops_eq = "c b n l -> b l n c" 1116 | else: 1117 | dist_einops_eq = "1 (b h) n l -> b l n h" 1118 | 1119 | ce_loss = F.cross_entropy( 1120 | rearrange(distances, dist_einops_eq, b=shape[0]), codes, ignore_index=-1 1121 | ) 1122 | 1123 | return ce_loss 1124 | 1125 | # if returning cross entropy loss on codes that were passed in 1126 | 1127 | if return_loss: 1128 | print(indices) 1129 | return quantize, calculate_ce_loss(indices) 1130 | 1131 | # transform embedding indices 1132 | 1133 | if is_multiheaded: 1134 | if self.separate_codebook_per_head: 1135 | embed_ind = rearrange(embed_ind, "h b n -> b n h", h=heads) 1136 | else: 1137 | embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads) 1138 | 1139 | if self.accept_image_fmap: 1140 | embed_ind = rearrange( 1141 | embed_ind, "b (t h w) ... -> b t h w ...", t=nframes, h=height, w=width 1142 | ) 1143 | 1144 | if only_one: 1145 | embed_ind = rearrange(embed_ind, "b 1 ... 
-> b ...") 1146 | 1147 | # aggregate loss 1148 | 1149 | loss = torch.tensor([0.0], device=device, requires_grad=self.training) 1150 | 1151 | if self.training: 1152 | if self.commitment_weight > 0: 1153 | if self.commitment_use_cross_entropy_loss: 1154 | if exists(mask): 1155 | ce_loss_mask = mask 1156 | if is_multiheaded: 1157 | ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads) 1158 | 1159 | embed_ind.masked_fill_(~ce_loss_mask, -1) 1160 | 1161 | print(embed_ind.shape, embed_ind) 1162 | commit_loss = calculate_ce_loss(embed_ind) 1163 | else: 1164 | if exists(mask): 1165 | # with variable lengthed sequences 1166 | commit_loss = F.mse_loss(commit_quantize, x, reduction="none") 1167 | 1168 | loss_mask = mask 1169 | if is_multiheaded: 1170 | loss_mask = repeat( 1171 | loss_mask, 1172 | "b n -> c (b h) n", 1173 | c=commit_loss.shape[0], 1174 | h=commit_loss.shape[1] // mask.shape[0], 1175 | ) 1176 | 1177 | commit_loss = commit_loss[loss_mask].mean() 1178 | else: 1179 | commit_loss = F.mse_loss(commit_quantize, x) 1180 | 1181 | loss = loss + commit_loss * self.commitment_weight 1182 | 1183 | if self.has_codebook_orthogonal_loss: 1184 | codebook = self._codebook.embed 1185 | 1186 | # only calculate orthogonal loss for the activated codes for this batch 1187 | 1188 | if self.orthogonal_reg_active_codes_only: 1189 | assert not ( 1190 | is_multiheaded and self.separate_codebook_per_head 1191 | ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet" 1192 | unique_code_ids = torch.unique(embed_ind) 1193 | codebook = codebook[:, unique_code_ids] 1194 | 1195 | num_codes = codebook.shape[-2] 1196 | 1197 | if ( 1198 | exists(self.orthogonal_reg_max_codes) 1199 | and num_codes > self.orthogonal_reg_max_codes 1200 | ): 1201 | rand_ids = torch.randperm(num_codes, device=device)[ 1202 | : self.orthogonal_reg_max_codes 1203 | ] 1204 | codebook = codebook[:, rand_ids] 1205 | 1206 | orthogonal_reg_loss = orthogonal_loss_fn(codebook) 1207 | loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight 1208 | 1209 | # handle multi-headed quantized embeddings 1210 | 1211 | if is_multiheaded: 1212 | if self.separate_codebook_per_head: 1213 | quantize = rearrange(quantize, "h b n d -> b n (h d)", h=heads) 1214 | else: 1215 | quantize = rearrange(quantize, "1 (b h) n d -> b n (h d)", h=heads) 1216 | 1217 | # project out 1218 | 1219 | quantize = self.project_out(quantize) 1220 | 1221 | # rearrange quantized embeddings 1222 | 1223 | if need_transpose: 1224 | quantize = rearrange(quantize, "b n d -> b d n") 1225 | 1226 | if self.accept_image_fmap: 1227 | quantize = rearrange( 1228 | quantize, "b (t h w) c -> b c t h w", t=nframes, h=height, w=width 1229 | ) 1230 | 1231 | if only_one: 1232 | quantize = rearrange(quantize, "b 1 d -> b d") 1233 | 1234 | # if masking, only return quantized for where mask has True 1235 | 1236 | if exists(mask): 1237 | quantize = torch.where( 1238 | rearrange(mask, "... -> ... 
1"), quantize, orig_input 1239 | ) 1240 | 1241 | # return quantize, embed_ind, loss 1242 | perplexity = self.get_perplexity(embed_ind, x) 1243 | usage = self.get_usage(embed_ind) 1244 | 1245 | if self.call_cnt == 0: 1246 | self.codebook_usage.data = usage 1247 | else: 1248 | self.codebook_usage.data = ( 1249 | 0.99 * self.codebook_usage.data + (1 - 0.99) * usage 1250 | ) 1251 | 1252 | self.call_cnt += 1 1253 | # avg_distribution = self.codebook_usage.data.sum() / self.codebook_size 1254 | avg_usage = ( 1255 | self.codebook_usage.data > (1 / self.codebook_size) 1256 | ).sum() / self.codebook_size 1257 | 1258 | return dict( 1259 | embeddings=quantize, 1260 | encodings=embed_ind, 1261 | commitment_loss=loss, 1262 | perplexity=perplexity, 1263 | avg_usage=avg_usage, 1264 | batch_usage=usage, 1265 | ) 1266 | -------------------------------------------------------------------------------- /imagetokenizer/utils/omnitokenizer_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | 6 | # Shifts src_tf dim to dest dim 7 | # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) 8 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): 9 | n_dims = len(x.shape) 10 | if src_dim < 0: 11 | src_dim = n_dims + src_dim 12 | if dest_dim < 0: 13 | dest_dim = n_dims + dest_dim 14 | 15 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims 16 | 17 | dims = list(range(n_dims)) 18 | del dims[src_dim] 19 | 20 | permutation = [] 21 | ctr = 0 22 | for i in range(n_dims): 23 | if i == dest_dim: 24 | permutation.append(src_dim) 25 | else: 26 | permutation.append(dims[ctr]) 27 | ctr += 1 28 | x = x.permute(permutation) 29 | if make_contiguous: 30 | x = x.contiguous() 31 | return x 32 | 33 | 34 | def Normalize(in_channels, norm_type="group"): 35 | assert norm_type in ["group", "batch"] 36 | if norm_type == "group": 37 | return torch.nn.GroupNorm( 38 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 39 | ) 40 | elif norm_type == "batch": 41 | return torch.nn.SyncBatchNorm(in_channels) 42 | 43 | 44 | def logits_laplace(x, x_recons, logit_laplace_eps=0.1): 45 | # [-0.5, 0.5] -> [0, 1] 46 | x += 0.5 47 | x_recons += 0.5 48 | # [0, 1] -> [eps, 1-eps] 49 | x_laplace = (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps 50 | x_recons_laplace = (1 - 2 * logit_laplace_eps) * x_recons + logit_laplace_eps 51 | return F.l1_loss(x_laplace, x_recons_laplace) 52 | 53 | 54 | def divisible_by(numer, denom): 55 | return (numer % denom) == 0 56 | 57 | 58 | def pair(val): 59 | ret = (val, val) if not isinstance(val, tuple) else val 60 | assert len(ret) == 2 61 | return ret 62 | 63 | 64 | def silu(x): 65 | return x * torch.sigmoid(x) 66 | 67 | 68 | class SiLU(nn.Module): 69 | def __init__(self): 70 | super(SiLU, self).__init__() 71 | 72 | def forward(self, x): 73 | return silu(x) 74 | -------------------------------------------------------------------------------- /imagetokenizer/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Lucas Jin. All rights reserved. 
2 | from datetime import datetime 3 | 4 | major_num = 2 5 | 6 | __version__ = "0.0.2" 7 | short_version = __version__ 8 | 9 | 10 | def parse_version_info(version_str): 11 | version_info = [] 12 | for x in version_str.split("."): 13 | if x.isdigit(): 14 | version_info.append(int(x)) 15 | elif x.find("rc") != -1: 16 | patch_version = x.split("rc") 17 | version_info.append(int(patch_version[0])) 18 | version_info.append(f"rc{patch_version[1]}") 19 | return tuple(version_info) 20 | 21 | 22 | version_info = parse_version_info(__version__) 23 | -------------------------------------------------------------------------------- /ps.sh: -------------------------------------------------------------------------------- 1 | # autopep8 -r ./minigemini/ -i 2 | 3 | git add . 4 | git commit -am 'add' 5 | git push origin main 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (c) 2020 JinTian. 4 | # 5 | # This file is part of alfred 6 | # (see http://jinfagang.github.io). 7 | # 8 | # Licensed to the Apache Software Foundation (ASF) under one 9 | # or more contributor license agreements. See the NOTICE file 10 | # distributed with this work for additional information 11 | # regarding copyright ownership. The ASF licenses this file 12 | # to you under the Apache License, Version 2.0 (the 13 | # "License"); you may not use this file except in compliance 14 | # with the License. You may obtain a copy of the License at 15 | # 16 | # http://www.apache.org/licenses/LICENSE-2.0 17 | # 18 | # Unless required by applicable law or agreed to in writing, 19 | # software distributed under the License is distributed on an 20 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 21 | # KIND, either express or implied. See the License for the 22 | # specific language governing permissions and limitations 23 | # under the License. 24 | # 25 | """ 26 | install alfred into local bin dir. 27 | """ 28 | from setuptools import setup, find_packages 29 | from setuptools import setup, Extension 30 | import io 31 | from os import path 32 | 33 | this_directory = path.abspath(path.dirname(__file__)) 34 | with io.open(path.join(this_directory, "README.md"), encoding="utf-8") as f: 35 | long_description = f.read() 36 | 37 | 38 | version_file = "imagetokenizer/version.py" 39 | 40 | 41 | def get_version(): 42 | with open(version_file, "r") as f: 43 | exec(compile(f.read(), version_file, "exec")) 44 | return locals()["__version__"] 45 | 46 | 47 | setup( 48 | name="imagetokenizer", 49 | version=get_version(), 50 | keywords=["deep learning", "script helper", "tools"], 51 | description="Image Tokenizer encode visuals.", 52 | long_description=long_description, 53 | long_description_content_type="text/markdown", 54 | license="GPL-3.0", 55 | classifiers=[ 56 | # Operation system 57 | "Operating System :: OS Independent", 58 | # How mature is this project? 
Common values are 59 | # 3 - Alpha 60 | # 4 - Beta 61 | # 5 - Production/Stable 62 | "Development Status :: 4 - Beta", 63 | # Indicate who your project is intended for 64 | "Intended Audience :: Developers", 65 | # Topics 66 | "Topic :: Education", 67 | "Topic :: Scientific/Engineering", 68 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 69 | "Topic :: Scientific/Engineering :: Image Recognition", 70 | # Pick your license as you wish 71 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 72 | # Specify the Python versions you support here. In particular, ensure 73 | # that you indicate whether you support Python 2, Python 3 or both. 74 | "Programming Language :: Python :: 3", 75 | "Programming Language :: Python :: 3.6", 76 | "Programming Language :: Python :: 3.7", 77 | "Programming Language :: Python :: 3.8", 78 | "Programming Language :: Python :: 3.9", 79 | ], 80 | packages=find_packages(), 81 | # entry_points={"console_scripts": ["alfred = alfred.alfred:main"]}, 82 | include_package_data=True, 83 | author="Lucas Jin", 84 | author_email="jinfagang19@163.com", 85 | url="https://github.com/lucasjinreal/ImageTokenizer", 86 | platforms="any", 87 | install_requires=["beartype"], 88 | ) 89 | -------------------------------------------------------------------------------- /test_image_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Send an image, encode it into a [1, 16, h, w] token map, 3 | then decode it back to the original image. 4 | """ 5 | 6 | """ 7 | We provide Tokenizer Inference code here. 8 | """ 9 | import os 10 | import sys 11 | import torch 12 | import importlib 13 | import numpy as np 14 | from PIL import Image 15 | import argparse 16 | import torchvision.transforms as T 17 | from imagetokenizer.model import Magvit2Tokenizer, OmniTokenizer, TiTok 18 | 19 | 20 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | def load_vqgan_new(num_down, ckpt_path=None, is_gumbel=False): 24 | if "magvit2" in ckpt_path.lower(): 25 | model = Magvit2Tokenizer(num_down=num_down, use_ema=True) 26 | if ckpt_path is not None: 27 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 28 | missing, unexpected = model.load_state_dict(sd, strict=False) 29 | elif "omni" in ckpt_path.lower(): 30 | model = OmniTokenizer() 31 | if ckpt_path is not None: 32 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 33 | missing, unexpected = model.load_state_dict(sd, strict=False) 34 | elif "titok" in ckpt_path.lower(): 35 | model = TiTok() 36 | if ckpt_path is not None: 37 | model.load_weights(ckpt_path) 38 | return model.eval() 39 | 40 | 41 | def get_obj_from_str(string, reload=False): 42 | print(string) 43 | module, cls = string.rsplit(".", 1) 44 | if reload: 45 | module_imp = importlib.import_module(module) 46 | importlib.reload(module_imp) 47 | return getattr(importlib.import_module(module, package=None), cls) 48 | 49 | 50 | def instantiate_from_config(config): 51 | if "class_path" not in config: 52 | raise KeyError("Expected key `class_path` to instantiate.") 53 | return get_obj_from_str(config["class_path"])(**config.get("init_args", dict())) 54 | 55 | 56 | def custom_to_pil(x): 57 | x = x.detach().cpu() 58 | x = torch.clamp(x, -1.0, 1.0) 59 | x = (x + 1.0) / 2.0 60 | x = x.permute(1, 2, 0).numpy() 61 | x = (255 * x).astype(np.uint8) 62 | x = Image.fromarray(x) 63 | if x.mode != "RGB": 64 | x = x.convert("RGB") 65 | return x 66 | 67 | 68 | def
get_image_tensor_for_encoder(image): 69 | image = image / 127.5 - 1.0 70 | image = T.ToTensor()(image).unsqueeze(0) 71 | # resize the image to the closest multiple-of-8 size 72 | height, width = image.shape[2], image.shape[3] 73 | new_height = ((height + 7) // 8) * 8 74 | new_width = ((width + 7) // 8) * 8 # adjust the image size 75 | image = torch.nn.functional.interpolate( 76 | image, size=(new_height, new_width), mode="bilinear", align_corners=False 77 | ) 78 | return image 79 | 80 | 81 | def main(args): 82 | model = load_vqgan_new(args.num_down, args.ckpt_path).to(DEVICE) 83 | 84 | visualize_dir = "results/" 85 | visualize_version = "v0" 86 | visualize_original = os.path.join( 87 | visualize_dir, visualize_version, "original_{}".format(args.num_down) 88 | ) 89 | visualize_rec = os.path.join( 90 | visualize_dir, visualize_version, "rec_{}".format(args.num_down) 91 | ) 92 | if not os.path.exists(visualize_original): 93 | os.makedirs(visualize_original, exist_ok=True) 94 | 95 | if not os.path.exists(visualize_rec): 96 | os.makedirs(visualize_rec, exist_ok=True) 97 | 98 | img_f = args.image_file 99 | idx = os.path.basename(img_f)[:-4] + "_constructed" 100 | image_raw = Image.open(img_f) 101 | image = np.array(image_raw) 102 | print(f"original image size: {image.shape}") 103 | images_tensor = get_image_tensor_for_encoder(image) 104 | images_tensor = images_tensor.float().to(DEVICE) 105 | print(f"images: {images_tensor.shape}") 106 | 107 | quant, embedding, codebook_indices = model.encode(images_tensor) 108 | print(f"quant: {quant.shape}") 109 | print(f"embedding: {embedding.shape}") 110 | print(f"codebook_indices: {codebook_indices.shape}") 111 | reconstructed_images = model.decode(quant) 112 | 113 | image = images_tensor[0] 114 | reconstructed_image = reconstructed_images[0] 115 | 116 | image = custom_to_pil(image) 117 | reconstructed_image = custom_to_pil(reconstructed_image) 118 | reconstructed_image = reconstructed_image.resize((image_raw.width, image_raw.height)) 119 | 120 | image.save(os.path.join(visualize_original, "{}.png".format(idx))) 121 | reconstructed_image.save(os.path.join(visualize_rec, "{}.png".format(idx))) 122 | 123 | 124 | def get_args(): 125 | parser = argparse.ArgumentParser(description="inference parameters") 126 | parser.add_argument("--ckpt_path", required=True, type=str) 127 | parser.add_argument("--num_down", default=3, type=int) 128 | parser.add_argument("--batch_size", default=1, type=int) 129 | parser.add_argument("--image_file", default="images/a.jpg", type=str) 130 | parser.add_argument("--subset", default=None) 131 | parser.add_argument("--tokenizer", default="magvit2") 132 | 133 | return parser.parse_args() 134 | 135 | 136 | if __name__ == "__main__": 137 | args = get_args() 138 | main(args) 139 | -------------------------------------------------------------------------------- /upload_pypi.sh: -------------------------------------------------------------------------------- 1 | ## 2 | ## Copyright (c) 2020 JinTian. 3 | ## 4 | ## This file is part of alfred 5 | ## (see http://jinfagang.github.io). 6 | ## 7 | ## Licensed to the Apache Software Foundation (ASF) under one 8 | ## or more contributor license agreements. See the NOTICE file 9 | ## distributed with this work for additional information 10 | ## regarding copyright ownership. The ASF licenses this file 11 | ## to you under the Apache License, Version 2.0 (the 12 | ## "License"); you may not use this file except in compliance 13 | ## with the License.
You may obtain a copy of the License at 14 | ## 15 | ## http://www.apache.org/licenses/LICENSE-2.0 16 | ## 17 | ## Unless required by applicable law or agreed to in writing, 18 | ## software distributed under the License is distributed on an 19 | ## "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 20 | ## KIND, either express or implied. See the License for the 21 | ## specific language governing permissions and limitations 22 | ## under the License. 23 | ## 24 | # check that the setup configuration is correct 25 | python3 setup.py check 26 | 27 | # bumpver update --patch 28 | 29 | sudo rm -r build/ 30 | sudo rm -r dist/ 31 | 32 | # the legacy `setup.py upload` interface is no longer supported by PyPI 33 | # python3 setup.py sdist 34 | # python3 setup.py sdist upload -r pypi 35 | 36 | # use twine instead 37 | python3 setup.py sdist 38 | twine upload dist/* 39 | 40 | --------------------------------------------------------------------------------
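
As a quick illustration of the codebook-usage bookkeeping at the top of this section (the 0.99-decay EMA of self.codebook_usage and the avg_usage ratio), here is a minimal, self-contained sketch. The codebook_size, the random indices, and the batch_usage helper are assumptions made for the example; batch_usage merely stands in for the quantizer's get_usage method, which is not shown in this excerpt.

import torch

codebook_size = 8  # hypothetical size, for illustration only
codebook_usage = torch.zeros(codebook_size)

def batch_usage(embed_ind, codebook_size):
    # fraction of the batch assigned to each code (stand-in for get_usage)
    counts = torch.bincount(embed_ind.flatten(), minlength=codebook_size).float()
    return counts / counts.sum()

for step in range(3):
    embed_ind = torch.randint(0, codebook_size, (64,))
    usage = batch_usage(embed_ind, codebook_size)
    if step == 0:
        codebook_usage = usage
    else:
        # same 0.99-decay EMA as in the quantizer above
        codebook_usage = 0.99 * codebook_usage + (1 - 0.99) * usage
    # a code counts as "used" once its EMA usage exceeds the uniform rate 1 / codebook_size
    avg_usage = (codebook_usage > (1 / codebook_size)).sum() / codebook_size
    print(step, avg_usage.item())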
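
The helpers in omnitokenizer_utils.py are generic tensor and shape utilities. A small usage sketch of shift_dim, pair, and divisible_by, assuming the package is importable under the module path shown in the tree above:

import torch
from imagetokenizer.utils.omnitokenizer_utils import shift_dim, pair, divisible_by

x = torch.randn(2, 3, 4, 5, 6)   # (b, c, t, h, w)
y = shift_dim(x, 1, -1)          # move the channel dim to the end
print(y.shape)                   # torch.Size([2, 4, 5, 6, 3])

print(pair(7))                   # (7, 7)
print(divisible_by(16, 8))       # True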
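
test_image_tokenizer.py above drives the full encode/decode round trip behind an argument parser. Below is a trimmed sketch of the same flow, using only the calls that appear in that script; the checkpoint path is hypothetical (see download_tokenizer_weights.sh in the repository root for obtaining weights).

import numpy as np
import torch
from PIL import Image
from imagetokenizer.model import Magvit2Tokenizer

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# hypothetical checkpoint location; load weights as the test script does
model = Magvit2Tokenizer(num_down=3, use_ema=True)
sd = torch.load("checkpoints/magvit2.ckpt", map_location="cpu")["state_dict"]
model.load_state_dict(sd, strict=False)
model = model.eval().to(device)

# normalize to [-1, 1] and resize H/W to multiples of 8, as get_image_tensor_for_encoder does
image = np.array(Image.open("images/a.jpg"))
x = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0) / 127.5 - 1.0
h, w = x.shape[2], x.shape[3]
x = torch.nn.functional.interpolate(
    x, size=(((h + 7) // 8) * 8, ((w + 7) // 8) * 8),
    mode="bilinear", align_corners=False,
).to(device)

with torch.no_grad():
    quant, embedding, codebook_indices = model.encode(x)
    recon = model.decode(quant)
print(quant.shape, codebook_indices.shape, recon.shape)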