├── .gitignore
├── LICENSE
├── README.md
├── download_tokenizer_weights.sh
├── images
│   ├── .gitignore
│   ├── a.jpg
│   ├── b.jpg
│   └── c.jpg
├── imagetokenizer
│   ├── model
│   │   ├── __init__.py
│   │   ├── magvit2.py
│   │   ├── modules
│   │   │   ├── maskgit_vqgan.py
│   │   │   ├── omni_codebook.py
│   │   │   ├── omni_transformer.py
│   │   │   ├── titok_transformer.py
│   │   │   └── vae.py
│   │   ├── omnitokenizer.py
│   │   └── titok.py
│   ├── quantize
│   │   ├── lookup_free_quantize.py
│   │   └── vector_quantize.py
│   ├── utils
│   │   └── omnitokenizer_utils.py
│   └── version.py
├── ps.sh
├── setup.py
├── test_image_tokenizer.py
└── upload_pypi.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | build/
3 | alfred_py.egg-info/
4 | alfred.egg-info/
5 | dist/
6 | build/
7 | .vscode/
8 | vendor/
9 |
10 | *.pyc
11 | a.py
12 | __pycache__/vendor/
13 | upload_tpi.sh
14 | __pycache__/
15 | *.egg-info/
16 | checkpoints/
17 | results/
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | alfred Copyright (C) 2021 Lucas Jin
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ImageTokenizer: Unified Image and Video Tokenization
2 |
3 | Welcome to the **ImageTokenizer** repository! 🎉 This Python package is designed to simplify the process of image and video tokenization, a crucial step for various applications such as image/video generation and understanding. We provide a variety of popular tokenizers with a simple and unified interface, making your coding experience seamless and efficient. 🛠️
4 |
5 | > ⚠️💡 Note that this project is still in its early stages of development. We welcome any contributions from the community to help us improve and expand the package. Please make sure to **star** and **fork** the repository if you find it useful. We are working on some awesome applications built on `imagetokenizer`, such as image/video generation and understanding. Stay tuned!
6 |
7 |
8 | ## Features
9 |
10 | - **Unified Interface**: A consistent API for all supported tokenizers.
11 | - **Extensive Support**: Covers a range of popular image and video tokenizers.
12 | - **Easy Integration**: Quick setup and integration with your projects.
13 | - **Multiple Tokenizers**: Supports Magvit2, OmniTokenizer, TiTok, and more.
14 |
15 |
16 | ## Updates
17 |
18 | - 🔥**2024.06.22**: TiTok is now supported! **The tokenizer with the fewest tokens per image so far**;
19 | - 🔥**2024.06.22**: OmniTokenizer is now supported!
20 |
21 |
22 | ## Supported Tokenizers
23 |
24 | Here's a list of the currently supported image tokenizers:
25 | 
26 | - **OmniTokenizer**: A versatile tokenizer capable of handling both images and videos.
27 | - **OpenMagvit2**: An open-source reproduction of Magvit2, renowned for its excellent results.
28 | - **TiTok**: A compact tokenizer that represents an image with a very small number of tokens.
28 |
29 | ## Getting Started
30 |
31 | To get started with ImageTokenizer, follow these simple steps:
32 |
33 | ### Installation
34 |
35 | You can install ImageTokenizer using pip:
36 |
37 | ```bash
38 | pip install imagetokenizer
39 | ```
40 |
41 | ### Usage
42 |
43 | Here's a quick example of how to use `Magvit2Tokenizer`:
44 | 
45 | ```python
46 | import torch
47 | from imagetokenizer.model import Magvit2Tokenizer
48 | 
49 | # Initialize the tokenizer
50 | image_tokenizer = Magvit2Tokenizer()
51 | 
52 | # Encode an image tensor of shape [B, 3, H, W]
53 | image = torch.randn(1, 3, 128, 128)  # replace with your own preprocessed image (see the loading sketch below)
54 | quants, embedding, codebook_indices = image_tokenizer.encode(image)
55 | print(codebook_indices)  # discrete token indices
56 | 
57 | image_reconstructed = image_tokenizer.decode(quants)
58 | ```
59 |
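60 | If your image lives on disk, you first need to turn it into a float tensor. Below is a minimal loading sketch using `PIL` and `torchvision`; the exact resolution and normalization each tokenizer expects may differ, so treat these values as assumptions to adapt:
61 | 
62 | ```python
63 | from PIL import Image
64 | from torchvision import transforms
65 | 
66 | # Hypothetical preprocessing: resize to the tokenizer resolution and scale to [-1, 1].
67 | preprocess = transforms.Compose([
68 |     transforms.Resize((128, 128)),
69 |     transforms.ToTensor(),
70 |     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
71 | ])
72 | image = preprocess(Image.open("path_to_your_image.jpg").convert("RGB")).unsqueeze(0)
73 | ```
74 | 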
60 | ### Documentation
61 |
62 | For more detailed information and examples, please refer to our [official documentation](#).
63 |
64 | ## Contributing
65 |
66 | We welcome contributions! If you have an idea for a new tokenizer or want to improve existing ones, feel free to submit a pull request or create an issue. 🔧
67 |
68 | ## License
69 |
70 | ImageTokenizer is open-source and available under the [GPL-3.0 License](LICENSE).
71 |
72 | ## Community
73 |
74 | - Join our [Slack Channel](#) to discuss and collaborate.
75 | - Follow us on [Twitter](#) for updates and news.
76 |
77 | ## Acknowledgements
78 |
79 | We would like to thank all the contributors and the community for their support and feedback. 🙏
80 |
--------------------------------------------------------------------------------
/download_tokenizer_weights.sh:
--------------------------------------------------------------------------------
1 | export HF_ENDPOINT=https://hf-mirror.com
2 |
3 | mkdir -p checkpoints
4 | cd checkpoints
5 |
6 | # download tokenizer weights
7 | huggingface-cli download TencentARC/Open-MAGVIT2 --local-dir magvit2
8 | huggingface-cli download fun-research/TiTok --local-dir titok
9 |
10 | wget $HF_ENDPOINT/Daniel0724/OmniTokenizer/resolve/main/imagenet_sthv2.ckpt -O omni_imagenet_sthv2.ckpt
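11 | 
12 | # Expected layout after the downloads finish:
13 | #   checkpoints/magvit2/                  <- Open-MAGVIT2 weights
14 | #   checkpoints/titok/                    <- TiTok weights
15 | #   checkpoints/omni_imagenet_sthv2.ckpt  <- OmniTokenizer weights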
--------------------------------------------------------------------------------
/images/.gitignore:
--------------------------------------------------------------------------------
1 | *_constructed*.png
2 |
--------------------------------------------------------------------------------
/images/a.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/ImageTokenizer/c9b0193e2d1e21988ed4bbc6fe96b98298b050ad/images/a.jpg
--------------------------------------------------------------------------------
/images/b.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/ImageTokenizer/c9b0193e2d1e21988ed4bbc6fe96b98298b050ad/images/b.jpg
--------------------------------------------------------------------------------
/images/c.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/ImageTokenizer/c9b0193e2d1e21988ed4bbc6fe96b98298b050ad/images/c.jpg
--------------------------------------------------------------------------------
/imagetokenizer/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .magvit2 import Magvit2Tokenizer
2 | from .omnitokenizer import OmniTokenizer
3 | from .titok import TiTok
4 |
--------------------------------------------------------------------------------
/imagetokenizer/model/magvit2.py:
--------------------------------------------------------------------------------
1 | """
2 | Magvit2 image tokenizer: encoder, decoder and LFQ quantizer (inference only).
3 | """
4 | 
5 | from collections import OrderedDict
6 | 
7 | import torch
8 | from torch import nn
9 | 
10 | from ..quantize.lookup_free_quantize import LFQ
12 |
13 | class Magvit2Tokenizer(nn.Module):
14 |
15 | def __init__(
16 | self,
17 | resolution=128,
18 | num_down=4,
19 | ### Quantize Related
20 | n_embed=262144,
21 | embed_dim=18,
22 | sample_minimization_weight=1.0,
23 | batch_maximization_weight=1.0,
24 | ckpt_path=None,
25 | ignore_keys=[],
26 | use_ema=False,
27 | token_factorization=False,
28 | ):
29 | super().__init__()
30 | ddconfig = {
31 | "double_z": False,
32 | "z_channels": 18,
33 | "resolution": resolution,
34 | "in_channels": 3,
35 | "out_ch": 3,
36 | "ch": 128,
37 | "ch_mult": [1, 2, 2, 4], # num_down = len(ch_mult)-1
38 | "num_res_blocks": 2,
39 | }
40 | if num_down == 4:
41 | ddconfig["ch_mult"] = [1, 1, 2, 2, 4] # num_down = len(ch_mult)-1
42 | elif num_down == 3:
43 | ddconfig["ch_mult"] = [1, 2, 2, 4] # num_down = len(ch_mult)-1
44 | if ckpt_path and "256" in ckpt_path:
45 | ddconfig["resolution"] = 256
46 | self.use_ema = use_ema
47 | self.encoder = Encoder(**ddconfig)
48 | self.decoder = Decoder(**ddconfig)
49 | self.quantize = LFQ(
50 | dim=embed_dim,
51 | codebook_size=n_embed,
52 | sample_minimization_weight=sample_minimization_weight,
53 | batch_maximization_weight=batch_maximization_weight,
54 | token_factorization=token_factorization,
55 | )
56 |
57 | if ckpt_path is not None:
58 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, stage=None)
59 |
60 | def init_from_ckpt(self, path, ignore_keys=list(), stage=None):
61 | sd = torch.load(path, map_location="cpu")["state_dict"]
62 | ema_mapping = {}
63 | new_params = OrderedDict()
64 | if stage == "transformer": ### directly use ema encoder and decoder parameter
65 | if self.use_ema:
66 | for k, v in sd.items():
67 | if "encoder" in k:
68 | if "model_ema" in k:
69 | k = k.replace(
70 | "model_ema.", ""
71 | ) # load EMA Encoder or Decoder
72 | new_k = ema_mapping[k]
73 | new_params[new_k] = v
74 | s_name = k.replace(".", "")
75 | ema_mapping.update({s_name: k})
76 | continue
77 | if "decoder" in k:
78 | if "model_ema" in k:
79 | k = k.replace(
80 | "model_ema.", ""
81 | ) # load EMA Encoder or Decoder
82 | new_k = ema_mapping[k]
83 | new_params[new_k] = v
84 | s_name = k.replace(".", "")
85 | ema_mapping.update({s_name: k})
86 | continue
87 | else: # also only load the Generator
88 | for k, v in sd.items():
89 | if "encoder" in k:
90 | new_params[k] = v
91 | elif "decoder" in k:
92 | new_params[k] = v
93 | missing_keys, unexpected_keys = self.load_state_dict(
94 | new_params, strict=False
95 | )
96 | else: ## simple resume
97 | missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
98 | print(f"Restored from {path}")
99 |
100 | def encode(self, x, return_embed_fea=True):
101 | h = self.encoder(x)
102 | # print(f'h {h} {h.shape}')
103 | (quant, emb_loss, info) = self.quantize(
104 | h, return_loss_breakdown=False, return_loss=False
105 | )
106 | # print(info)
107 | ### using token factorization the info is a tuple (each for embedding)
108 | if return_embed_fea:
109 | return quant, h, info
110 | else:
111 | return quant, emb_loss, info
112 |
113 | def decode(self, quant):
114 | dec = self.decoder(quant)
115 | return dec
116 |
117 | def forward(self, input):
118 | (
119 | quant,
120 | diff,
121 | _,
122 | ) = self.encode(input)
123 | # print(quant)
124 | # print(f'quant: {quant.shape}, diff: {diff.shape}')
125 | dec = self.decode(quant)
126 | return dec
127 |
128 |
129 | def swish(x):
130 | # swish
131 | return x * torch.sigmoid(x)
132 |
133 |
134 | class ResBlock(nn.Module):
135 | def __init__(self, in_filters, out_filters, use_conv_shortcut=False) -> None:
136 | super().__init__()
137 |
138 | self.in_filters = in_filters
139 | self.out_filters = out_filters
140 | self.use_conv_shortcut = use_conv_shortcut
141 |
142 | self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6)
143 | self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6)
144 |
145 | self.conv1 = nn.Conv2d(
146 | in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False
147 | )
148 | self.conv2 = nn.Conv2d(
149 | out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False
150 | )
151 |
152 | if in_filters != out_filters:
153 | if self.use_conv_shortcut:
154 | self.conv_shortcut = nn.Conv2d(
155 | in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False
156 | )
157 | else:
158 | self.nin_shortcut = nn.Conv2d(
159 | in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False
160 | )
161 |
162 | def forward(self, x, **kwargs):
163 | residual = x
164 |
165 | x = self.norm1(x)
166 | x = swish(x)
167 | x = self.conv1(x)
168 | x = self.norm2(x)
169 | x = swish(x)
170 | x = self.conv2(x)
171 | if self.in_filters != self.out_filters:
172 | if self.use_conv_shortcut:
173 | residual = self.conv_shortcut(residual)
174 | else:
175 | residual = self.nin_shortcut(residual)
176 |
177 | return x + residual
178 |
179 |
180 | class Encoder(nn.Module):
181 | def __init__(
182 | self,
183 | *,
184 | ch,
185 | out_ch,
186 | in_channels,
187 | num_res_blocks,
188 | z_channels,
189 | ch_mult=(1, 2, 2, 4),
190 | resolution,
191 | double_z=False,
192 | ):
193 | super().__init__()
194 |
195 | self.in_channels = in_channels
196 | self.z_channels = z_channels
197 | self.resolution = resolution
198 |
199 | self.num_res_blocks = num_res_blocks
200 | self.num_blocks = len(ch_mult)
201 |
202 | self.conv_in = nn.Conv2d(
203 | in_channels, ch, kernel_size=(3, 3), padding=1, bias=False
204 | )
205 |
206 | ## construct the model
207 | self.down = nn.ModuleList()
208 |
209 | in_ch_mult = (1,) + tuple(ch_mult)
210 | for i_level in range(self.num_blocks):
211 | block = nn.ModuleList()
212 | block_in = ch * in_ch_mult[i_level] # [1, 1, 2, 2, 4]
213 | block_out = ch * ch_mult[i_level] # [1, 2, 2, 4]
214 | for _ in range(self.num_res_blocks):
215 | block.append(ResBlock(block_in, block_out))
216 | block_in = block_out
217 |
218 | down = nn.Module()
219 | down.block = block
220 | if i_level < self.num_blocks - 1:
221 | down.downsample = nn.Conv2d(
222 | block_out, block_out, kernel_size=(3, 3), stride=(2, 2), padding=1
223 | )
224 |
225 | self.down.append(down)
226 |
227 | ### mid
228 | self.mid_block = nn.ModuleList()
229 | for res_idx in range(self.num_res_blocks):
230 | self.mid_block.append(ResBlock(block_in, block_in))
231 |
232 | ### end
233 | self.norm_out = nn.GroupNorm(32, block_out, eps=1e-6)
234 | self.conv_out = nn.Conv2d(block_out, z_channels, kernel_size=(1, 1))
235 |
236 | def forward(self, x):
237 |
238 | ## down
239 | x = self.conv_in(x)
240 | for i_level in range(self.num_blocks):
241 | for i_block in range(self.num_res_blocks):
242 | x = self.down[i_level].block[i_block](x)
243 |
244 | if i_level < self.num_blocks - 1:
245 | x = self.down[i_level].downsample(x)
246 |
247 | ## mid
248 | for res in range(self.num_res_blocks):
249 | x = self.mid_block[res](x)
250 |
251 | x = self.norm_out(x)
252 | x = swish(x)
253 | x = self.conv_out(x)
254 |
255 | return x
256 |
257 |
258 | class Decoder(nn.Module):
259 | def __init__(
260 | self,
261 | *,
262 | ch,
263 | out_ch,
264 | in_channels,
265 | num_res_blocks,
266 | z_channels,
267 | ch_mult=(1, 2, 2, 4),
268 | resolution,
269 | double_z=False,
270 | ) -> None:
271 | super().__init__()
272 |
273 | self.ch = ch
274 | self.num_blocks = len(ch_mult)
275 | self.num_res_blocks = num_res_blocks
276 | self.resolution = resolution
277 | self.in_channels = in_channels
278 |
279 | block_in = ch * ch_mult[self.num_blocks - 1]
280 |
281 | self.conv_in = nn.Conv2d(
282 | z_channels, block_in, kernel_size=(3, 3), padding=1, bias=True
283 | )
284 |
285 | self.mid_block = nn.ModuleList()
286 | for res_idx in range(self.num_res_blocks):
287 | self.mid_block.append(ResBlock(block_in, block_in))
288 |
289 | self.up = nn.ModuleList()
290 |
291 | for i_level in reversed(range(self.num_blocks)):
292 | block = nn.ModuleList()
293 | block_out = ch * ch_mult[i_level]
294 | for i_block in range(self.num_res_blocks):
295 | block.append(ResBlock(block_in, block_out))
296 | block_in = block_out
297 |
298 | up = nn.Module()
299 | up.block = block
300 | if i_level > 0:
301 | up.upsample = Upsampler(block_in)
302 | self.up.insert(0, up)
303 |
304 | self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6)
305 |
306 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=(3, 3), padding=1)
307 |
308 | def forward(self, z):
309 |
310 | z = self.conv_in(z)
311 |
312 | ## mid
313 | for res in range(self.num_res_blocks):
314 | z = self.mid_block[res](z)
315 |
316 | ## upsample
317 | for i_level in reversed(range(self.num_blocks)):
318 | for i_block in range(self.num_res_blocks):
319 | z = self.up[i_level].block[i_block](z)
320 |
321 | if i_level > 0:
322 | z = self.up[i_level].upsample(z)
323 |
324 | z = self.norm_out(z)
325 | z = swish(z)
326 | z = self.conv_out(z)
327 |
328 | return z
329 |
330 |
331 | def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor:
332 | """Depth-to-Space DCR mode (depth-column-row) core implementation.
333 |
334 | Args:
335 | x (torch.Tensor): input tensor. The channels-first (*CHW) layout is supported.
336 | block_size (int): block side size
337 | """
338 | # check inputs
339 | if x.dim() < 3:
340 | raise ValueError(
341 | f"Expecting a channels-first (*CHW) tensor of at least 3 dimensions"
342 | )
343 | c, h, w = x.shape[-3:]
344 |
345 | s = block_size**2
346 | if c % s != 0:
347 | raise ValueError(
348 | f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels"
349 | )
350 |
351 | outer_dims = x.shape[:-3]
352 |
353 | # splitting two additional dimensions from the channel dimension
354 | x = x.view(-1, block_size, block_size, c // s, h, w)
355 |
356 | # putting the two new dimensions along H and W
357 | x = x.permute(0, 3, 4, 1, 5, 2)
358 |
359 | # merging the two new dimensions with H and W
360 | x = x.contiguous().view(*outer_dims, c // s, h * block_size, w * block_size)
361 |
362 | return x
363 |
364 |
365 | class Upsampler(nn.Module):
366 | def __init__(self, dim, dim_out=None):
367 | super().__init__()
368 | dim_out = dim * 4
369 | self.conv1 = nn.Conv2d(dim, dim_out, (3, 3), padding=1)
370 | self.depth2space = depth_to_space
371 |
372 | def forward(self, x):
373 | """
374 | input_image: [B C H W]
375 | """
376 | out = self.conv1(x)
377 | out = self.depth2space(out, block_size=2)
378 | return out
379 |
380 |
381 | if __name__ == "__main__":
382 | x = torch.randn(size=(2, 3, 128, 128))
383 | encoder = Encoder(
384 | ch=128, in_channels=3, num_res_blocks=2, z_channels=18, out_ch=3, resolution=128
385 | )
386 | decoder = Decoder(
387 | out_ch=3, z_channels=18, num_res_blocks=2, ch=128, in_channels=3, resolution=128
388 | )
389 | z = encoder(x)
390 | out = decoder(z)
391 |
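392 | # A minimal round-trip sketch through the full tokenizer. Weights are randomly
393 | # initialized here; pass ckpt_path= to load pretrained Open-MAGVIT2 weights.
394 | tokenizer = Magvit2Tokenizer()
395 | quants, features, indices = tokenizer.encode(x)
396 | reconstruction = tokenizer.decode(quants)
397 | print(quants.shape, reconstruction.shape)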
--------------------------------------------------------------------------------
/imagetokenizer/model/modules/maskgit_vqgan.py:
--------------------------------------------------------------------------------
1 | """This file contains code for MaskGIT-VQGAN.
2 |
3 | This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
4 | All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
5 |
6 | Reference:
7 | https://github.com/huggingface/open-muse/blob/main/muse/modeling_maskgit_vqgan.py
8 | """
9 |
10 | # Copyright 2023 Google LLC and The HuggingFace Inc. team.
11 | #
12 | # Licensed under the Apache License, Version 2.0 (the "License");
13 | # you may not use this file except in compliance with the License.
14 | # You may obtain a copy of the License at
15 | #
16 | # http://www.apache.org/licenses/LICENSE-2.0
17 | #
18 | # Unless required by applicable law or agreed to in writing, software
19 | # distributed under the License is distributed on an "AS IS" BASIS,
20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 | # See the License for the specific language governing permissions and
22 | # limitations under the License.
23 |
24 | r"""MaskGIT Tokenizer based on VQGAN.
25 |
26 | This tokenizer is a reimplementation of VQGAN [https://arxiv.org/abs/2012.09841]
27 | with several modifications. The non-local layers are removed from VQGAN for
28 | faster speed.
29 | """
30 |
31 | import math
32 |
33 | import torch
34 | import torch.nn.functional as F
35 | from torch import nn
36 |
37 |
38 | # Conv2D with same padding
39 | class Conv2dSame(nn.Conv2d):
40 | def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
41 | return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
42 |
43 | def forward(self, x: torch.Tensor) -> torch.Tensor:
44 | ih, iw = x.size()[-2:]
45 |
46 | pad_h = self.calc_same_pad(
47 | i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]
48 | )
49 | pad_w = self.calc_same_pad(
50 | i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]
51 | )
52 |
53 | if pad_h > 0 or pad_w > 0:
54 | x = F.pad(
55 | x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
56 | )
57 | return super().forward(x)
58 |
59 |
60 | class ResnetBlock(nn.Module):
61 | def __init__(
62 | self,
63 | in_channels: int,
64 | out_channels: int = None,
65 | dropout_prob: float = 0.0,
66 | ):
67 | super().__init__()
68 |
69 | self.in_channels = in_channels
70 | self.out_channels = out_channels
71 | self.out_channels_ = (
72 | self.in_channels if self.out_channels is None else self.out_channels
73 | )
74 |
75 | self.norm1 = nn.GroupNorm(
76 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
77 | )
78 | self.conv1 = Conv2dSame(
79 | self.in_channels, self.out_channels_, kernel_size=3, bias=False
80 | )
81 |
82 | self.norm2 = nn.GroupNorm(
83 | num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True
84 | )
85 | self.dropout = nn.Dropout(dropout_prob)
86 | self.conv2 = Conv2dSame(
87 | self.out_channels_, self.out_channels_, kernel_size=3, bias=False
88 | )
89 |
90 | if self.in_channels != self.out_channels_:
91 | self.nin_shortcut = Conv2dSame(
92 | self.out_channels_, self.out_channels_, kernel_size=1, bias=False
93 | )
94 |
95 | def forward(self, hidden_states):
96 | residual = hidden_states
97 | hidden_states = self.norm1(hidden_states)
98 | hidden_states = F.silu(hidden_states)
99 | hidden_states = self.conv1(hidden_states)
100 |
101 | hidden_states = self.norm2(hidden_states)
102 | hidden_states = F.silu(hidden_states)
103 | hidden_states = self.dropout(hidden_states)
104 | hidden_states = self.conv2(hidden_states)
105 |
106 | if self.in_channels != self.out_channels_:
107 | residual = self.nin_shortcut(hidden_states)
108 |
109 | return hidden_states + residual
110 |
111 |
112 | class DownsamplingBlock(nn.Module):
113 | def __init__(self, config, block_idx: int):
114 | super().__init__()
115 |
116 | self.config = config
117 | self.block_idx = block_idx
118 |
119 | in_channel_mult = (1,) + tuple(self.config.channel_mult)
120 | block_in = self.config.hidden_channels * in_channel_mult[self.block_idx]
121 | block_out = (
122 | self.config.hidden_channels * self.config.channel_mult[self.block_idx]
123 | )
124 |
125 | res_blocks = nn.ModuleList()
126 | for _ in range(self.config.num_res_blocks):
127 | res_blocks.append(
128 | ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout)
129 | )
130 | block_in = block_out
131 | self.block = res_blocks
132 |
133 | self.downsample = self.block_idx != self.config.num_resolutions - 1
134 |
135 | def forward(self, hidden_states):
136 | for res_block in self.block:
137 | hidden_states = res_block(hidden_states)
138 |
139 | if self.downsample:
140 | hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2)
141 |
142 | return hidden_states
143 |
144 |
145 | class UpsamplingBlock(nn.Module):
146 | def __init__(self, config, block_idx: int):
147 | super().__init__()
148 |
149 | self.config = config
150 | self.block_idx = block_idx
151 |
152 | if self.block_idx == self.config.num_resolutions - 1:
153 | block_in = self.config.hidden_channels * self.config.channel_mult[-1]
154 | else:
155 | block_in = (
156 | self.config.hidden_channels
157 | * self.config.channel_mult[self.block_idx + 1]
158 | )
159 |
160 | block_out = (
161 | self.config.hidden_channels * self.config.channel_mult[self.block_idx]
162 | )
163 |
164 | res_blocks = []
165 | for _ in range(self.config.num_res_blocks):
166 | res_blocks.append(
167 | ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout)
168 | )
169 | block_in = block_out
170 | self.block = nn.ModuleList(res_blocks)
171 |
172 | self.add_upsample = self.block_idx != 0
173 | if self.add_upsample:
174 | self.upsample_conv = Conv2dSame(block_out, block_out, kernel_size=3)
175 |
176 | def forward(self, hidden_states):
177 | for res_block in self.block:
178 | hidden_states = res_block(hidden_states)
179 |
180 | if self.add_upsample:
181 | hidden_states = F.interpolate(
182 | hidden_states, scale_factor=2.0, mode="nearest"
183 | )
184 | hidden_states = self.upsample_conv(hidden_states)
185 |
186 | return hidden_states
187 |
188 |
189 | class Encoder(nn.Module):
190 | def __init__(self, config):
191 | super().__init__()
192 | self.config = config
193 | # downsampling
194 | self.conv_in = Conv2dSame(
195 | self.config.num_channels,
196 | self.config.hidden_channels,
197 | kernel_size=3,
198 | bias=False,
199 | )
200 |
201 | downsample_blocks = []
202 | for i_level in range(self.config.num_resolutions):
203 | downsample_blocks.append(DownsamplingBlock(self.config, block_idx=i_level))
204 | self.down = nn.ModuleList(downsample_blocks)
205 |
206 | # middle
207 | mid_channels = self.config.hidden_channels * self.config.channel_mult[-1]
208 | res_blocks = nn.ModuleList()
209 | for _ in range(self.config.num_res_blocks):
210 | res_blocks.append(
211 | ResnetBlock(
212 | mid_channels, mid_channels, dropout_prob=self.config.dropout
213 | )
214 | )
215 | self.mid = res_blocks
216 |
217 | # end
218 | self.norm_out = nn.GroupNorm(
219 | num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True
220 | )
221 | self.conv_out = Conv2dSame(mid_channels, self.config.z_channels, kernel_size=1)
222 |
223 | def forward(self, pixel_values):
224 | # downsampling
225 | hidden_states = self.conv_in(pixel_values)
226 | for block in self.down:
227 | hidden_states = block(hidden_states)
228 |
229 | # middle
230 | for block in self.mid:
231 | hidden_states = block(hidden_states)
232 |
233 | # end
234 | hidden_states = self.norm_out(hidden_states)
235 | hidden_states = F.silu(hidden_states)
236 | hidden_states = self.conv_out(hidden_states)
237 | return hidden_states
238 |
239 |
240 | class Decoder(nn.Module):
241 | def __init__(self, config):
242 | super().__init__()
243 |
244 | self.config = config
245 |
246 | # compute in_channel_mult, block_in and curr_res at lowest res
247 | block_in = (
248 | self.config.hidden_channels
249 | * self.config.channel_mult[self.config.num_resolutions - 1]
250 | )
251 | curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
252 | self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
253 |
254 | # z to block_in
255 | self.conv_in = Conv2dSame(self.config.z_channels, block_in, kernel_size=3)
256 |
257 | # middle
258 | res_blocks = nn.ModuleList()
259 | for _ in range(self.config.num_res_blocks):
260 | res_blocks.append(
261 | ResnetBlock(block_in, block_in, dropout_prob=self.config.dropout)
262 | )
263 | self.mid = res_blocks
264 |
265 | # upsampling
266 | upsample_blocks = []
267 | for i_level in reversed(range(self.config.num_resolutions)):
268 | upsample_blocks.append(UpsamplingBlock(self.config, block_idx=i_level))
269 | self.up = nn.ModuleList(
270 | list(reversed(upsample_blocks))
271 | ) # reverse to get consistent order
272 |
273 | # end
274 | block_out = self.config.hidden_channels * self.config.channel_mult[0]
275 | self.norm_out = nn.GroupNorm(
276 | num_groups=32, num_channels=block_out, eps=1e-6, affine=True
277 | )
278 | self.conv_out = Conv2dSame(block_out, self.config.num_channels, kernel_size=3)
279 |
280 | def forward(self, hidden_states):
281 | # z to block_in
282 | hidden_states = self.conv_in(hidden_states)
283 |
284 | # middle
285 | for block in self.mid:
286 | hidden_states = block(hidden_states)
287 |
288 | # upsampling
289 | for block in reversed(self.up):
290 | hidden_states = block(hidden_states)
291 |
292 | # end
293 | hidden_states = self.norm_out(hidden_states)
294 | hidden_states = F.silu(hidden_states)
295 | hidden_states = self.conv_out(hidden_states)
296 |
297 | return hidden_states
298 |
299 |
300 | class VectorQuantizer(nn.Module):
301 | """
302 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
303 | Discretization bottleneck part of the VQ-VAE.
304 | """
305 |
306 | def __init__(self, num_embeddings, embedding_dim, commitment_cost):
307 | r"""
308 | Args:
309 | num_embeddings: number of vectors in the quantized space.
310 | embedding_dim: dimensionality of the tensors in the quantized space.
311 | Inputs to the modules must be in this format as well.
312 | commitment_cost: scalar which controls the weighting of the loss terms
313 | (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta).
314 | """
315 | super().__init__()
316 |
317 | self.num_embeddings = num_embeddings
318 | self.embedding_dim = embedding_dim
319 | self.commitment_cost = commitment_cost
320 |
321 | self.embedding = nn.Embedding(num_embeddings, embedding_dim)
322 | self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
323 |
324 | def forward(self, hidden_states, return_loss=False):
325 | """
326 |         Maps the encoder output z to a discrete one-hot vector that is the index of the
327 |         closest embedding vector e_j: z (continuous) -> z_q (discrete), with z.shape = (batch, channel, height, width).
328 | quantization pipeline:
329 | 1. get encoder input (B,C,H,W)
330 | 2. flatten input to (B*H*W,C)
331 | """
332 | # reshape z -> (batch, height, width, channel) and flatten
333 | hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
334 |
335 | distances = self.compute_distances(hidden_states)
336 | min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1)
337 | min_encodings = torch.zeros(
338 | min_encoding_indices.shape[0], self.num_embeddings
339 | ).to(hidden_states)
340 | min_encodings.scatter_(1, min_encoding_indices, 1)
341 |
342 | # get quantized latent vectors
343 | z_q = torch.matmul(min_encodings, self.embedding.weight).view(
344 | hidden_states.shape
345 | )
346 |
347 | # reshape to (batch, num_tokens)
348 | min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
349 |
350 | # compute loss for embedding
351 | loss = None
352 | if return_loss:
353 | loss = torch.mean(
354 | (z_q.detach() - hidden_states) ** 2
355 | ) + self.commitment_cost * torch.mean((z_q - hidden_states.detach()) ** 2)
356 | # preserve gradients
357 | z_q = hidden_states + (z_q - hidden_states).detach()
358 |
359 | # reshape back to match original input shape
360 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
361 |
362 | return z_q, min_encoding_indices, loss
363 |
364 | def compute_distances(self, hidden_states):
365 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
366 |         hidden_states_flattened = hidden_states.reshape((-1, self.embedding_dim))
367 |         emb_weights = self.embedding.weight.t()
368 |
369 |         inputs_norm_sq = hidden_states_flattened.pow(2.0).sum(dim=1, keepdim=True)
370 |         codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True)
371 |         distances = torch.addmm(
372 |             inputs_norm_sq + codebook_t_norm_sq,
373 |             hidden_states_flattened,
374 | emb_weights,
375 | alpha=-2.0,
376 | )
377 | return distances
378 |
379 | def get_codebook_entry(self, indices):
380 | # indices are expected to be of shape (batch, num_tokens)
381 | # get quantized latent vectors
382 | if len(indices.shape) == 2:
383 | batch, num_tokens = indices.shape
384 | z_q = self.embedding(indices)
385 | z_q = z_q.reshape(
386 | batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1
387 | ).permute(0, 3, 1, 2)
388 | elif len(indices.shape) == 3:
389 | batch, height, width = indices.shape
390 | indices = indices.view(batch, -1)
391 | z_q = self.embedding(indices)
392 | z_q = z_q.reshape(batch, height, width, -1).permute(0, 3, 1, 2)
393 | else:
394 | print(indices.shape)
395 | raise NotImplementedError
396 | return z_q
397 |
398 | # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372
399 | def get_soft_code(self, hidden_states, temp=1.0, stochastic=False):
400 | hidden_states = hidden_states.permute(
401 | 0, 2, 3, 1
402 | ).contiguous() # (batch, height, width, channel)
403 | distances = self.compute_distances(
404 | hidden_states
405 | ) # (batch * height * width, num_embeddings)
406 |
407 | soft_code = F.softmax(
408 | -distances / temp, dim=-1
409 | ) # (batch * height * width, num_embeddings)
410 | if stochastic:
411 | code = torch.multinomial(soft_code, 1) # (batch * height * width, 1)
412 | else:
413 | code = distances.argmin(dim=-1) # (batch * height * width)
414 |
415 | code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width)
416 | batch, num_tokens = code.shape
417 | soft_code = soft_code.reshape(
418 | batch, num_tokens, -1
419 | ) # (batch, height * width, num_embeddings)
420 | return soft_code, code
421 |
422 | def get_code(self, hidden_states):
423 | # reshape z -> (batch, height, width, channel)
424 | hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
425 | distances = self.compute_distances(hidden_states)
426 | indices = torch.argmin(distances, axis=1).unsqueeze(1)
427 | indices = indices.reshape(hidden_states.shape[0], -1)
428 | return indices
429 |
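
A round-trip sketch (not from the repository) showing how the Encoder, VectorQuantizer and Decoder above compose; the config mirrors the pixel-decoder settings used in imagetokenizer/model/titok.py, and the batch and shape values are illustrative.

import torch
from omegaconf import OmegaConf

config = OmegaConf.create({
    "channel_mult": [1, 1, 2, 2, 4], "num_resolutions": 5, "dropout": 0.0,
    "hidden_channels": 128, "num_channels": 3, "num_res_blocks": 2,
    "resolution": 256, "z_channels": 256,
})
encoder, decoder = Encoder(config), Decoder(config)
quantizer = VectorQuantizer(num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)

x = torch.randn(1, 3, 256, 256)
z = encoder(x)                                   # (1, 256, 16, 16): four 2x downsamples
z_q, indices, loss = quantizer(z, return_loss=True)
recon = decoder(z_q)                             # (1, 3, 256, 256)
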
--------------------------------------------------------------------------------
/imagetokenizer/model/modules/omni_codebook.py:
--------------------------------------------------------------------------------
1 | from enum import unique
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.distributed as dist
8 |
9 | from imagetokenizer.utils.omnitokenizer_utils import shift_dim
10 |
11 |
12 | class Codebook(nn.Module):
13 | def __init__(
14 | self,
15 | n_codes,
16 | embedding_dim,
17 | no_random_restart=False,
18 | restart_thres=1.0,
19 | usage_sigma=0.99,
20 | fp32_quant=False,
21 | ):
22 | super().__init__()
23 | self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
24 | self.register_buffer("N", torch.zeros(n_codes))
25 | self.register_buffer("z_avg", self.embeddings.data.clone())
26 | self.register_buffer("codebook_usage", torch.zeros(n_codes))
27 |
28 | self.call_cnt = 0
29 | self.usage_sigma = usage_sigma
30 |
31 | self.n_codes = n_codes
32 | self.embedding_dim = embedding_dim
33 | self._need_init = True
34 | self.no_random_restart = no_random_restart
35 | self.restart_thres = restart_thres
36 |
37 | self.fp32_quant = fp32_quant
38 |
39 | def _tile(self, x):
40 | d, ew = x.shape
41 | if d < self.n_codes:
42 | n_repeats = (self.n_codes + d - 1) // d
43 | std = 0.01 / np.sqrt(ew)
44 | x = x.repeat(n_repeats, 1)
45 | x = x + torch.randn_like(x) * std
46 | return x
47 |
48 | def _init_embeddings(self, z):
49 | # z: [b, c, t, h, w]
50 | self._need_init = False
51 | flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
52 | y = self._tile(flat_inputs)
53 |
54 | d = y.shape[0]
55 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
56 | if dist.is_initialized():
57 | dist.broadcast(_k_rand, 0)
58 | self.embeddings.data.copy_(_k_rand)
59 | self.z_avg.data.copy_(_k_rand)
60 | self.N.data.copy_(torch.ones(self.n_codes))
61 |
62 | def calculate_batch_codebook_usage_percentage(self, batch_encoding_indices):
63 | # Flatten the batch of encoding indices into a single 1D tensor
64 | all_indices = batch_encoding_indices.flatten()
65 |
66 | # Obtain the total number of encoding indices in the batch to calculate percentages
67 | total_indices = all_indices.numel()
68 |
69 | # Initialize a tensor to store the percentage usage of each code
70 | codebook_usage_percentage = torch.zeros(self.n_codes, device=all_indices.device)
71 |
72 | # Count the number of occurrences of each index and get their frequency as percentages
73 | unique_indices, counts = torch.unique(all_indices, return_counts=True)
74 | # Calculate the percentage
75 | percentages = counts.float() / total_indices
76 |
77 | # Populate the corresponding percentages in the codebook_usage_percentage tensor
78 | codebook_usage_percentage[unique_indices.long()] = percentages
79 |
80 | return codebook_usage_percentage
81 |
82 | def forward(self, z):
83 | # z: [b, c, t, h, w]
84 | if self._need_init and self.training:
85 | self._init_embeddings(z)
86 | flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c]
87 |
88 | distances = (
89 | (flat_inputs**2).sum(dim=1, keepdim=True)
90 | - 2 * flat_inputs @ self.embeddings.t()
91 | + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
92 |         )  # [bthw, n_codes]
93 |
94 | encoding_indices = torch.argmin(distances, dim=1)
95 | encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(
96 | flat_inputs
97 | ) # [bthw, ncode]
98 | encoding_indices = encoding_indices.view(
99 | z.shape[0], *z.shape[2:]
100 |         )  # [b, t, h, w]
101 |
102 | embeddings = F.embedding(encoding_indices, self.embeddings) # [b, t, h, w, c]
103 | embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w]
104 |
105 | commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
106 |
107 | # EMA codebook update
108 | if self.training:
109 | n_total = encode_onehot.sum(dim=0)
110 | encode_sum = flat_inputs.t() @ encode_onehot
111 | if dist.is_initialized():
112 | dist.all_reduce(n_total)
113 | dist.all_reduce(encode_sum)
114 |
115 | self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
116 | self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
117 |
118 | n = self.N.sum()
119 | weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
120 | encode_normalized = self.z_avg / weights.unsqueeze(1)
121 | self.embeddings.data.copy_(encode_normalized)
122 |
123 | y = self._tile(flat_inputs)
124 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
125 | if dist.is_initialized():
126 | dist.broadcast(_k_rand, 0)
127 |
128 | if not self.no_random_restart:
129 | usage = (self.N.view(self.n_codes, 1) >= self.restart_thres).float()
130 | self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
131 |
132 | embeddings_st = (embeddings - z).detach() + z
133 |
134 | avg_probs = torch.mean(encode_onehot, dim=0)
135 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
136 |
137 | try:
138 | usage = self.calculate_batch_codebook_usage_percentage(encoding_indices)
139 |         except Exception:
140 | usage = torch.zeros(self.n_codes, device=encoding_indices.device)
141 |
142 | # print(usage.shape, torch.zeros(self.n_codes).shape)
143 |
144 | if self.call_cnt == 0:
145 | self.codebook_usage.data = usage
146 | else:
147 | self.codebook_usage.data = (
148 | self.usage_sigma * self.codebook_usage.data
149 | + (1 - self.usage_sigma) * usage
150 | )
151 |
152 | self.call_cnt += 1
153 | # avg_distribution = self.codebook_usage.data.sum() / self.n_codes
154 | avg_usage = (self.codebook_usage.data > (1 / self.n_codes)).sum() / self.n_codes
155 |
156 | return dict(
157 | embeddings=embeddings_st,
158 | encodings=encoding_indices,
159 | commitment_loss=commitment_loss,
160 | perplexity=perplexity,
161 | avg_usage=avg_usage,
162 | batch_usage=usage,
163 | )
164 |
165 | def dictionary_lookup(self, encodings):
166 | embeddings = F.embedding(encodings, self.embeddings)
167 | return embeddings
168 |
--------------------------------------------------------------------------------
/imagetokenizer/model/modules/omni_transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn, einsum
5 | from beartype import beartype
6 | from typing import Tuple
7 |
8 | from einops import rearrange, repeat
9 | from einops.layers.torch import Rearrange
10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
11 |
12 |
13 | def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor:
14 | # Refer to `Unroll` to see how this performs a maxpool-Nd
15 | # B, N, C
16 | return x.view(x.shape[0], stride, -1, x.shape[-1]).max(dim=1).values
17 |
18 |
19 | def exists(val):
20 | return val is not None
21 |
22 |
23 | def default(val, d):
24 | return val if exists(val) else d
25 |
26 |
27 | def leaky_relu(p=0.1):
28 | return nn.LeakyReLU(p)
29 |
30 |
31 | def l2norm(t):
32 | return F.normalize(t, dim=-1)
33 |
34 |
35 | def precompute_freqs_cis_2d(
36 | dim: int, end: int, theta: float = 10000.0, scale=1.0, use_cls=False
37 | ):
38 | H = int(end**0.5)
39 | # assert H * H == end
40 | flat_patch_pos = torch.arange(0 if not use_cls else -1, end) # N = end
41 | x_pos = flat_patch_pos % H # N
42 | y_pos = flat_patch_pos // H # N
43 | freqs = 1.0 / (
44 | theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)
45 | ) # Hc/4
46 | x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
47 | y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
48 | x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
49 | y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
50 | freqs_cis = torch.cat(
51 | [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
52 | ) # N,Hc/4,2
53 | freqs_cis = freqs_cis.reshape(end if not use_cls else end + 1, -1)
54 | # we need to think how to implement this for multi heads.
55 | # freqs_cis = torch.cat([x_cis, y_cis], dim=-1) # N, Hc/2
56 | return freqs_cis
57 |
58 |
59 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
60 | # x: B N H Hc/2
61 | # freqs_cis: N, H*Hc/2 or N Hc/2
62 | ndim = x.ndim
63 | assert 0 <= 1 < ndim
64 |
65 | if freqs_cis.shape[-1] == x.shape[-1]:
66 | shape = [
67 | 1 if i == 2 or i == 0 else d for i, d in enumerate(x.shape)
68 | ] # 1, N, 1, Hc/2
69 | else:
70 | shape = [d if i != 0 else 1 for i, d in enumerate(x.shape)] # 1, N, H, Hc/2
71 | # B, N, Hc/2
72 | return freqs_cis.view(*shape)
73 |
74 |
75 | def apply_rotary_emb(
76 | xq: torch.Tensor,
77 | xk: torch.Tensor,
78 | freqs_cis: torch.Tensor,
79 | ) -> Tuple[torch.Tensor, torch.Tensor]:
80 | # xq : B N H Hc
81 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
82 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
83 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
84 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
85 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
86 | return xq_out.type_as(xq), xk_out.type_as(xk)
87 |
88 |
89 | class LayerNorm(nn.Module):
90 | def __init__(self, dim):
91 | super().__init__()
92 | self.gamma = nn.Parameter(torch.ones(dim))
93 | self.register_buffer("beta", torch.zeros(dim))
94 |
95 | def forward(self, x):
96 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
97 |
98 |
99 | class Pooling(nn.Module):
100 | def __init__(self, pool_type, dim):
101 | super().__init__()
102 | if pool_type == "a":
103 | self.pool = nn.AvgPool2d(kernel_size=2)
104 |
105 | elif pool_type == "m":
106 | self.pool = nn.MaxPool2d(kernel_size=2)
107 |
108 | elif pool_type == "l":
109 | self.pool = nn.Linear(4 * dim, dim)
110 |
111 | else:
112 | raise NotImplementedError
113 |
114 | self.pool_type = pool_type
115 |
116 | def forward(self, x):
117 | # B N C
118 | B, N, C = x.shape
119 | if self.pool_type in ["a", "m"]:
120 | H, W = int(math.sqrt(N)), int(math.sqrt(N))
121 | x = x.view(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
122 | x = self.pool(x)
123 | x = x.view(B, C, -1).transpose(1, 2).contiguous()
124 |
125 | else:
126 | x = x.view(B, N // 4, -1)
127 | x = self.pool(x)
128 |
129 | return x
130 |
131 |
132 | class Up(nn.Module):
133 | def __init__(self, up_type, dim):
134 | super().__init__()
135 | if up_type == "n":
136 | self.up = nn.Upsample(scale_factor=2, mode="nearest")
137 |
138 | elif up_type == "r":
139 | self.up = nn.Sequential(
140 | nn.Upsample(scale_factor=2, mode="nearest"),
141 | Rearrange("b c h w -> b (h w) c"),
142 | nn.Linear(dim, dim),
143 | )
144 |
145 | else:
146 | raise NotImplementedError
147 |
148 | self.up_type = up_type
149 |
150 | def forward(self, x):
151 | # B N C
152 | B, N, C = x.shape
153 | if self.up_type == "n":
154 | H, W = int(math.sqrt(N)), int(math.sqrt(N))
155 | x = x.view(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
156 | x = self.up(x)
157 | x = x.view(B, C, -1).transpose(1, 2).contiguous()
158 |
159 | else:
160 | # x = self.up(x) # B, N, 4c
161 | # x = x.view(B, N * 4, -1)
162 | H, W = int(math.sqrt(N)), int(math.sqrt(N))
163 | x = x.view(B, H, W, -1).permute(0, 3, 1, 2).contiguous() # B, C, H, W
164 | x = self.up(x) # B, (2H 2W), C
165 |
166 | return x
167 |
168 |
169 | class GEGLU(nn.Module):
170 | def forward(self, x):
171 | x, gate = x.chunk(2, dim=-1)
172 | return F.gelu(gate) * x
173 |
174 |
175 | def FeedForward(dim, mult=4, dropout=0.0):
176 | """Check this paper to understand the computation: https://arxiv.org/pdf/2002.05202.pdf"""
177 | inner_dim = int(mult * (2 / 3) * dim)
178 | return nn.Sequential(
179 | nn.LayerNorm(dim),
180 | nn.Linear(dim, inner_dim * 2, bias=False),
181 | GEGLU(),
182 | nn.Dropout(dropout),
183 | nn.Linear(inner_dim, dim, bias=False),
184 | )
185 |
186 |
187 | def window_partition(x, window_size):
188 | """
189 | Args:
190 | x: (B, H, W, C)
191 | window_size (int): window size
192 |
193 | Returns:
194 | windows: (num_windows*B, window_size, window_size, C)
195 | """
196 | B, H, W, C = x.shape
197 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
198 | windows = (
199 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
200 | )
201 | return windows
202 |
203 |
204 | def window_reverse(windows, window_size, H, W):
205 | """
206 | Args:
207 | windows: (num_windows*B, window_size, window_size, C)
208 | window_size (int): Window size
209 | H (int): Height of image
210 | W (int): Width of image
211 |
212 | Returns:
213 | x: (B, H, W, C)
214 | """
215 | B = int(windows.shape[0] / (H * W / window_size / window_size))
216 | x = windows.view(
217 | B, H // window_size, W // window_size, window_size, window_size, -1
218 | )
219 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
220 | return x
221 |
222 |
223 | class WindowAttention(nn.Module):
224 | r"""Window based multi-head self attention (W-MSA) module with relative position bias.
225 | It supports both of shifted and non-shifted window.
226 |
227 | Args:
228 | dim (int): Number of input channels.
229 | window_size (tuple[int]): The height and width of the window.
230 | num_heads (int): Number of attention heads.
231 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
232 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
233 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
234 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
235 | """
236 |
237 | def __init__(
238 | self,
239 | dim,
240 | window_size,
241 | num_heads,
242 | qkv_bias=False,
243 | qk_scale=None,
244 | attn_drop=0.0,
245 | proj_drop=0.0,
246 | ):
247 |
248 | super().__init__()
249 | self.dim = dim
250 | if isinstance(window_size, int):
251 | window_size = (window_size, window_size)
252 |
253 | self.norm = LayerNorm(dim)
254 | self.window_size = window_size # Wh, Ww
255 | self.num_heads = num_heads
256 | head_dim = dim // num_heads
257 | self.scale = qk_scale or head_dim**-0.5
258 |
259 | # define a parameter table of relative position bias
260 | self.relative_position_bias_table = nn.Parameter(
261 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
262 | ) # 2*Wh-1 * 2*Ww-1, nH
263 |
264 | # get pair-wise relative position index for each token inside the window
265 | coords_h = torch.arange(self.window_size[0])
266 | coords_w = torch.arange(self.window_size[1])
267 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
268 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
269 | relative_coords = (
270 | coords_flatten[:, :, None] - coords_flatten[:, None, :]
271 | ) # 2, Wh*Ww, Wh*Ww
272 | relative_coords = relative_coords.permute(
273 | 1, 2, 0
274 | ).contiguous() # Wh*Ww, Wh*Ww, 2
275 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
276 | relative_coords[:, :, 1] += self.window_size[1] - 1
277 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
278 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
279 | self.register_buffer("relative_position_index", relative_position_index)
280 |
281 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
282 | self.attn_drop = nn.Dropout(attn_drop)
283 | self.proj = nn.Linear(dim, dim)
284 | self.proj_drop = nn.Dropout(proj_drop)
285 |
286 | trunc_normal_(self.relative_position_bias_table, std=0.02)
287 | self.softmax = nn.Softmax(dim=-1)
288 |
289 | def forward(self, x):
290 | """
291 | Args:
292 | x: input features with shape of (num_windows*B, N, C)
293 |         Note: window partitioning is handled inside forward; unlike Swin, no attention mask argument is taken.
294 | """
295 | B_, N, C = x.shape
296 | H, W = int(math.sqrt(N)), int(math.sqrt(N))
297 | x = self.norm(x)
298 |
299 | x = x.view(B_, H, W, -1)
300 | # partition windows
301 | x_windows = window_partition(
302 | x, self.window_size[0]
303 | ) # nW*B, window_size, window_size, C
304 | x_windows = x_windows.view(
305 | -1, self.window_size[0] * self.window_size[1], C
306 | ) # nW*B, window_size*window_size, C
307 |
308 | BW, NW = x_windows.shape[:2]
309 |
310 | qkv = (
311 | self.qkv(x_windows)
312 | .reshape(BW, NW, 3, self.num_heads, C // self.num_heads)
313 | .permute(2, 0, 3, 1, 4)
314 | )
315 | q, k, v = (
316 | qkv[0],
317 | qkv[1],
318 | qkv[2],
319 | ) # make torchscript happy (cannot use tensor as tuple)
320 |
321 | q = q * self.scale
322 | attn = q @ k.transpose(-2, -1)
323 |
324 | relative_position_bias = self.relative_position_bias_table[
325 | self.relative_position_index.view(-1)
326 | ].view(
327 | self.window_size[0] * self.window_size[1],
328 | self.window_size[0] * self.window_size[1],
329 | -1,
330 | ) # Wh*Ww,Wh*Ww,nH
331 | relative_position_bias = relative_position_bias.permute(
332 | 2, 0, 1
333 | ).contiguous() # nH, Wh*Ww, Wh*Ww
334 |
335 | attn = attn + relative_position_bias.unsqueeze(0)
336 | attn = self.softmax(attn)
337 |
338 | attn = self.attn_drop(attn)
339 |
340 | x_windows = (attn @ v).transpose(1, 2).reshape(BW, NW, C)
341 | x_windows = self.proj(x_windows)
342 | x_windows = self.proj_drop(x_windows)
343 |
344 | x = window_reverse(x_windows, self.window_size[0], H, W) # B H' W' C
345 | x = x.view(B_, H * W, C)
346 |
347 | return x
348 |
349 |
350 | class PEG(nn.Module):
351 | def __init__(self, dim, causal=False):
352 | super().__init__()
353 | self.causal = causal
354 | self.dsconv = nn.Conv3d(dim, dim, 3, groups=dim)
355 |
356 | @beartype
357 | def forward(self, x, shape: Tuple[int, int, int, int] = None):
358 | needs_shape = x.ndim == 3
359 | assert not (needs_shape and not exists(shape))
360 |
361 | orig_shape = x.shape
362 | if needs_shape:
363 | x = x.reshape(*shape, -1)
364 |
365 | x = rearrange(x, "b ... d -> b d ...")
366 |
367 | frame_padding = (2, 0) if self.causal else (1, 1)
368 |
369 | x = F.pad(x, (1, 1, 1, 1, *frame_padding), value=0.0)
370 | x = self.dsconv(x)
371 |
372 | x = rearrange(x, "b d ... -> b ... d")
373 |
374 | if needs_shape:
375 | x = rearrange(x, "b ... d -> b (...) d")
376 |
377 | return x.reshape(orig_shape)
378 |
379 |
380 | # attention
381 |
382 |
383 | class Attention(nn.Module):
384 | def __init__(
385 | self,
386 | dim,
387 | dim_context=None,
388 | dim_head=64,
389 | heads=8,
390 | causal=False,
391 | num_null_kv=0,
392 | norm_context=True,
393 | dropout=0.0,
394 | scale=8,
395 | spatial_pos="rel",
396 | ):
397 | super().__init__()
398 | self.heads = heads
399 | self.causal = causal
400 | self.scale = scale
401 | inner_dim = dim_head * heads
402 | dim_context = default(dim_context, dim)
403 |
404 | if spatial_pos == "rel":
405 | self.spatial_rel_pos_bias = ContinuousPositionBias(
406 | dim=dim, heads=heads
407 |             )  # HACK: check whether a shared positional encoding works better here
408 |
409 | self.spatial_pos = spatial_pos
410 | self.freqs_cis = None
411 |
412 | if causal:
413 | self.rel_pos_bias = AlibiPositionalBias(heads=heads)
414 |
415 | self.p_dropout = dropout
416 | self.attn_dropout = nn.Dropout(dropout)
417 |
418 | self.norm = LayerNorm(dim)
419 | self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity()
420 |
421 | self.num_null_kv = num_null_kv
422 | if self.num_null_kv > 0:
423 | self.null_kv = nn.Parameter(torch.randn(heads, 2 * num_null_kv, dim_head))
424 | else:
425 | self.null_kv = None
426 |
427 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
428 | self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias=False)
429 | self.dim = inner_dim
430 |
431 | self.q_scale = nn.Parameter(torch.ones(dim_head))
432 | self.k_scale = nn.Parameter(torch.ones(dim_head))
433 |
434 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
435 |
436 | def forward(
437 | self,
438 | x,
439 | mask=None,
440 | context=None,
441 | is_spatial=True,
442 | q_stride=1,
443 | ):
444 | batch, device, dtype = x.shape[0], x.device, x.dtype
445 |
446 | if exists(context):
447 | context = self.context_norm(context)
448 |
449 | kv_input = default(context, x)
450 |
451 | x = self.norm(x)
452 | N = x.shape[1]
453 |
454 | q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)
455 | q, k, v = map(
456 | lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (q, k, v)
457 | )
458 |
459 | if self.spatial_pos == "rope" and is_spatial:
460 | if self.freqs_cis is None or self.freqs_cis.shape[0] != N:
461 | self.freqs_cis = precompute_freqs_cis_2d(self.dim // self.heads, N).to(
462 | x.device
463 | )
464 |
465 | q, k = apply_rotary_emb(q, k, freqs_cis=self.freqs_cis)
466 |
467 | q, k, v = map(
468 | lambda t: rearrange(t, "b n h d -> b h n d", h=self.heads), (q, k, v)
469 | )
470 |
471 | B, H, _, D = q.shape
472 | if q_stride > 1:
473 | # Refer to Unroll to see how this performs a maxpool-Nd
474 | q = q.view(B, H, q_stride, -1, D).max(dim=2).values
475 |
476 | if self.num_null_kv > 0:
477 | nk, nv = repeat(
478 | self.null_kv, "h (n r) d -> b h n r d", b=batch, r=2
479 | ).unbind(dim=-2)
480 |
481 | k = torch.cat((nk, k), dim=-2)
482 | v = torch.cat((nv, v), dim=-2)
483 |
484 | q, k = map(l2norm, (q, k))
485 | q = q * self.q_scale
486 | k = k * self.k_scale
487 |
488 | if hasattr(F, "scaled_dot_product_attention") and torch.__version__ >= "2.1.0":
489 | # Note: the original paper did *not* use SDPA, it's a free boost!
490 | if exists(mask):
491 | mask = F.pad(mask, (self.num_null_kv, 0), value=True)
492 | mask = rearrange(mask, "b j -> b 1 1 j")
493 |
494 | if self.spatial_pos == "rel" and is_spatial:
495 | h, w = int(math.sqrt(N)), int(math.sqrt(N))
496 | attn_bias = self.spatial_rel_pos_bias(h, w, device=x.device)
497 | attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value=0.0)
498 |
499 | # query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
500 | out = F.scaled_dot_product_attention(
501 | q,
502 | k,
503 | v,
504 | attn_mask=mask,
505 | dropout_p=self.p_dropout,
506 | is_causal=self.causal,
507 | scale=self.scale,
508 | )
509 |
510 | else:
511 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
512 | i, j = sim.shape[-2:]
513 | if self.spatial_pos == "rel" and is_spatial:
514 | h, w = int(math.sqrt(N)), int(math.sqrt(N))
515 | attn_bias = self.spatial_rel_pos_bias(h, w, device=x.device)
516 | attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value=0.0)
517 |
518 | if sim.shape[2] != attn_bias.shape[1]:
519 | # handle q_pooling here
520 | q_len = sim.shape[2]
521 | kv_len = sim.shape[3]
522 | q_stride = kv_len // q_len
523 | attn_bias = attn_bias[:, ::q_stride]
524 |
525 | sim = sim + attn_bias
526 |
527 | if exists(mask):
528 | mask = F.pad(mask, (self.num_null_kv, 0), value=True)
529 | mask = rearrange(mask, "b j -> b 1 1 j")
530 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
531 |
532 | if self.causal:
533 | sim = sim + self.rel_pos_bias(sim)
534 |
535 | causal_mask = torch.ones((i, j), device=device, dtype=torch.bool).triu(
536 | j - i + 1
537 | )
538 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
539 |
540 | attn = sim.softmax(dim=-1)
541 | attn = self.attn_dropout(attn)
542 |
543 | out = einsum("b h i j, b h j d -> b h i d", attn, v)
544 |
545 | out = rearrange(out, "b h n d -> b n (h d)")
546 | return self.to_out(out)
547 |
548 |
549 | # alibi positional bias for extrapolation
550 | class AlibiPositionalBias(nn.Module):
551 | def __init__(self, heads):
552 | super().__init__()
553 | self.heads = heads
554 | slopes = torch.Tensor(self._get_slopes(heads))
555 | slopes = rearrange(slopes, "h -> h 1 1")
556 | self.register_buffer("slopes", slopes, persistent=False)
557 | self.register_buffer("bias", None, persistent=False)
558 |
559 | def get_bias(self, i, j, device):
560 | i_arange = torch.arange(j - i, j, device=device)
561 | j_arange = torch.arange(j, device=device)
562 | bias = -torch.abs(
563 | rearrange(j_arange, "j -> 1 1 j") - rearrange(i_arange, "i -> 1 i 1")
564 | )
565 | return bias
566 |
567 | @staticmethod
568 | def _get_slopes(heads):
569 | def get_slopes_power_of_2(n):
570 | start = 2 ** (-(2 ** -(math.log2(n) - 3)))
571 | ratio = start
572 | return [start * ratio**i for i in range(n)]
573 |
574 | if math.log2(heads).is_integer():
575 | return get_slopes_power_of_2(heads)
576 |
577 | closest_power_of_2 = 2 ** math.floor(math.log2(heads))
578 | return (
579 | get_slopes_power_of_2(closest_power_of_2)
580 | + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
581 | : heads - closest_power_of_2
582 | ]
583 | )
584 |
585 | def forward(self, sim):
586 | h, i, j, device = *sim.shape[-3:], sim.device
587 |
588 | if exists(self.bias) and self.bias.shape[-1] >= j:
589 | return self.bias[..., :i, :j]
590 |
591 | bias = self.get_bias(i, j, device)
592 | bias = bias * self.slopes
593 |
594 | num_heads_unalibied = h - bias.shape[0]
595 | bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
596 | self.register_buffer("bias", bias, persistent=False)
597 |
598 | return self.bias
599 |
600 |
601 | class ContinuousPositionBias(nn.Module):
602 | """from https://arxiv.org/abs/2111.09883"""
603 |
604 | def __init__(
605 | self,
606 | *,
607 | dim,
608 | heads,
609 | num_dims=2, # 2 for images, 3 for video
610 | layers=2,
611 | log_dist=True,
612 | cache_rel_pos=False
613 | ):
614 | super().__init__()
615 | self.num_dims = num_dims
616 | self.log_dist = log_dist
617 |
618 | self.net = nn.ModuleList([])
619 | self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), leaky_relu()))
620 |
621 | for _ in range(layers - 1):
622 | self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
623 |
624 | self.net.append(nn.Linear(dim, heads))
625 |
626 | self.cache_rel_pos = cache_rel_pos
627 | self.register_buffer("rel_pos", None, persistent=False)
628 |
629 | def forward(self, *dimensions, device=torch.device("cpu")):
630 |
631 | if not exists(self.rel_pos) or not self.cache_rel_pos:
632 | positions = [torch.arange(d, device=device) for d in dimensions]
633 | grid = torch.stack(torch.meshgrid(*positions, indexing="ij"))
634 | grid = rearrange(grid, "c ... -> (...) c")
635 | rel_pos = rearrange(grid, "i c -> i 1 c") - rearrange(grid, "j c -> 1 j c")
636 |
637 | if self.log_dist:
638 | rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
639 |
640 | self.register_buffer("rel_pos", rel_pos, persistent=False)
641 |
642 | rel_pos = self.rel_pos.float()
643 |
644 | for layer in self.net:
645 | rel_pos = layer(rel_pos)
646 |
647 | return rearrange(rel_pos, "i j h -> h i j")
648 |
649 |
650 | # transformer
651 |
652 |
653 | class Transformer(nn.Module):
654 | def __init__(
655 | self,
656 | dim,
657 | *,
658 | depth,
659 | block,
660 | dim_context=None,
661 | causal=False,
662 | dim_head=64,
663 | heads=8,
664 | ff_mult=4,
665 | peg=False,
666 | peg_causal=False,
667 | attn_num_null_kv=2,
668 | has_cross_attn=False,
669 | attn_dropout=0.0,
670 | ff_dropout=0.0,
671 | window_size=4,
672 | spatial_pos="rel"
673 | ):
674 | super().__init__()
675 | assert len(block) == depth
676 | self.layers = nn.ModuleList([])
677 | for i in range(depth):
678 | if block[i] == "t":
679 | self.layers.append(
680 | nn.ModuleList(
681 | [
682 | PEG(dim=dim, causal=peg_causal) if peg else None,
683 | Attention(
684 | dim=dim,
685 | dim_head=dim_head,
686 | heads=heads,
687 | causal=causal,
688 | dropout=attn_dropout,
689 | spatial_pos=spatial_pos,
690 | ),
691 | (
692 | Attention(
693 | dim=dim,
694 | dim_head=dim_head,
695 | dim_context=dim_context,
696 | heads=heads,
697 | causal=False,
698 | num_null_kv=attn_num_null_kv,
699 | dropout=attn_dropout,
700 | )
701 | if has_cross_attn
702 | else None
703 | ),
704 | FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout),
705 | ]
706 | )
707 | )
708 |
709 | elif block[i] == "w":
710 | self.layers.append(
711 | nn.ModuleList(
712 | [
713 | None,
714 | WindowAttention(
715 | dim=dim,
716 | window_size=window_size,
717 | num_heads=heads,
718 | attn_drop=attn_dropout,
719 | ),
720 | None,
721 | FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout),
722 | ]
723 | )
724 | )
725 |
726 | # various pooling methods: B, N, C
727 | elif block[i] in ["a", "m", "l"]:
728 | self.layers.append(
729 | nn.ModuleList(
730 | [
731 | None,
732 | Pooling(block[i], dim),
733 | None,
734 | FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout),
735 | ]
736 | )
737 | )
738 |
739 | elif block[i] in ["n", "r"]:
740 | self.layers.append(
741 | nn.ModuleList(
742 | [
743 | None,
744 | Up(block[i], dim),
745 | None,
746 | FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout),
747 | ]
748 | )
749 | )
750 |
751 | else:
752 | raise NotImplementedError
753 |
754 | self.block = block
755 | self.norm_out = LayerNorm(dim)
756 |
757 | @beartype
758 | def forward(
759 | self,
760 | x,
761 | video_shape: Tuple[int, int, int, int] = None,
762 | context=None,
763 | self_attn_mask=None,
764 | cross_attn_context_mask=None,
765 | q_strides=None,
766 | is_spatial=True,
767 | ):
768 |
769 | if q_strides is None:
770 | q_strides = "1" * len(self.layers)
771 |
772 | for blk, q_stride, (peg, self_attn, cross_attn, ff) in zip(
773 | self.block, q_strides, self.layers
774 | ):
775 | if exists(peg):
776 | x = peg(x, shape=video_shape) + x
777 |
778 | if isinstance(self_attn, Attention):
779 | x = self_attn(
780 | x,
781 | mask=self_attn_mask,
782 | q_stride=int(q_stride),
783 | is_spatial=is_spatial,
784 | ) + do_pool(x, int(q_stride))
785 | # x = checkpoint.checkpoint(self_attn, x, self_attn_mask, None, attn_bias, int(q_stride))
786 |
787 | elif isinstance(self_attn, WindowAttention):
788 | x = self_attn(x) + x
789 | else:
790 | x = self_attn(x)
791 |
792 | if exists(cross_attn) and exists(context):
793 | x = cross_attn(x, context=context, mask=cross_attn_context_mask) + x
794 |
795 | x = ff(x) + x
796 |
797 | # deal with downsampling:
798 | if blk in ["a", "m", "l"]:
799 | video_shape = (
800 | video_shape[0],
801 | video_shape[1],
802 | video_shape[2] // 2,
803 | video_shape[3] // 2,
804 | ) # video_shape: B, T, H, W
805 |
806 | elif blk in ["n", "r"]:
807 | video_shape = (
808 | video_shape[0],
809 | video_shape[1],
810 | int(video_shape[2] * 2),
811 | int(video_shape[3] * 2),
812 | )
813 |
814 | if q_stride != "1":
815 | down_ratio = int(math.sqrt(int(q_stride)))
816 | video_shape = (
817 | video_shape[0],
818 | video_shape[1],
819 | video_shape[2] // down_ratio,
820 | video_shape[3] // down_ratio,
821 | )
822 |
823 | return self.norm_out(x)
824 |
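
The Transformer above is driven by a per-layer block string: "t" is global attention, "w" is window attention, "a"/"m"/"l" are pooling layers and "n"/"r" are upsampling layers. A toy sketch (not from the repository); the sizes are illustrative and the token count must be a perfect square for the 2D relative position bias.

import torch

model = Transformer(dim=64, depth=2, block="tw", dim_head=16, heads=4, window_size=4)
x = torch.randn(2, 16, 64)              # (B, N, C) with N = 4 x 4 tokens
# video_shape is (B, T, H, W); it is only consumed by PEG and pooling/upsampling layers
y = model(x, video_shape=(2, 1, 4, 4))  # layer 1: global attention, layer 2: 4x4 window attention
print(y.shape)                          # torch.Size([2, 16, 64])
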
--------------------------------------------------------------------------------
/imagetokenizer/model/modules/titok_transformer.py:
--------------------------------------------------------------------------------
1 | """Building blocks for TiTok.
2 |
3 | Copyright (2024) Bytedance Ltd. and/or its affiliates
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 |
17 | Reference:
18 | https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
19 | """
20 |
21 | import torch
22 | import torch.nn as nn
23 | from collections import OrderedDict
24 |
25 |
26 | class ResidualAttentionBlock(nn.Module):
27 | def __init__(
28 | self, d_model, n_head, mlp_ratio=4.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm
29 | ):
30 | super().__init__()
31 |
32 | self.ln_1 = norm_layer(d_model)
33 | self.attn = nn.MultiheadAttention(d_model, n_head)
34 | self.mlp_ratio = mlp_ratio
35 | # optionally we can disable the FFN
36 | if mlp_ratio > 0:
37 | self.ln_2 = norm_layer(d_model)
38 | mlp_width = int(d_model * mlp_ratio)
39 | self.mlp = nn.Sequential(
40 | OrderedDict(
41 | [
42 | ("c_fc", nn.Linear(d_model, mlp_width)),
43 | ("gelu", act_layer()),
44 | ("c_proj", nn.Linear(mlp_width, d_model)),
45 | ]
46 | )
47 | )
48 |
49 | def attention(self, x: torch.Tensor):
50 | return self.attn(x, x, x, need_weights=False)[0]
51 |
52 | def forward(
53 | self,
54 | x: torch.Tensor,
55 | ):
56 | attn_output = self.attention(x=self.ln_1(x))
57 | x = x + attn_output
58 | if self.mlp_ratio > 0:
59 | x = x + self.mlp(self.ln_2(x))
60 | return x
61 |
62 |
63 | def _expand_token(token, batch_size: int):
64 | return token.unsqueeze(0).expand(batch_size, -1, -1)
65 |
66 |
67 | class TiTokEncoder(nn.Module):
68 | def __init__(self, config):
69 | super().__init__()
70 | self.config = config
71 | self.image_size = config.dataset.preprocessing.crop_size
72 | self.patch_size = config.model.vq_model.vit_enc_patch_size
73 | self.grid_size = self.image_size // self.patch_size
74 | self.model_size = config.model.vq_model.vit_enc_model_size
75 | self.num_latent_tokens = config.model.vq_model.num_latent_tokens
76 | self.token_size = config.model.vq_model.token_size
77 |
78 | self.width = {
79 | "small": 512,
80 | "base": 768,
81 | "large": 1024,
82 | }[self.model_size]
83 | self.num_layers = {
84 | "small": 8,
85 | "base": 12,
86 | "large": 24,
87 | }[self.model_size]
88 | self.num_heads = {
89 | "small": 8,
90 | "base": 12,
91 | "large": 16,
92 | }[self.model_size]
93 |
94 | self.patch_embed = nn.Conv2d(
95 | in_channels=3,
96 | out_channels=self.width,
97 | kernel_size=self.patch_size,
98 | stride=self.patch_size,
99 | bias=True,
100 | )
101 |
102 | scale = self.width**-0.5
103 | self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
104 | self.positional_embedding = nn.Parameter(
105 | scale * torch.randn(self.grid_size**2 + 1, self.width)
106 | )
107 | self.latent_token_positional_embedding = nn.Parameter(
108 | scale * torch.randn(self.num_latent_tokens, self.width)
109 | )
110 | self.ln_pre = nn.LayerNorm(self.width)
111 | self.transformer = nn.ModuleList()
112 | for i in range(self.num_layers):
113 | self.transformer.append(
114 | ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
115 | )
116 | self.ln_post = nn.LayerNorm(self.width)
117 | self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)
118 |
119 | def forward(self, pixel_values, latent_tokens):
120 | batch_size = pixel_values.shape[0]
121 | x = pixel_values
122 | x = self.patch_embed(x)
123 | x = x.reshape(x.shape[0], x.shape[1], -1)
124 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
125 | # class embeddings and positional embeddings
126 | x = torch.cat(
127 | [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1
128 | )
129 | x = x + self.positional_embedding.to(
130 | x.dtype
131 | ) # shape = [*, grid ** 2 + 1, width]
132 |
133 | latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
134 | latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(
135 | x.dtype
136 | )
137 | x = torch.cat([x, latent_tokens], dim=1)
138 |
139 | x = self.ln_pre(x)
140 | x = x.permute(1, 0, 2) # NLD -> LND
141 | for i in range(self.num_layers):
142 | x = self.transformer[i](x)
143 | x = x.permute(1, 0, 2) # LND -> NLD
144 |
145 | latent_tokens = x[:, 1 + self.grid_size**2 :]
146 | latent_tokens = self.ln_post(latent_tokens)
147 | # fake 2D shape
148 | latent_tokens = latent_tokens.reshape(
149 | batch_size, self.width, self.num_latent_tokens, 1
150 | )
151 | latent_tokens = self.conv_out(latent_tokens)
152 | latent_tokens = latent_tokens.reshape(
153 | batch_size, self.token_size, 1, self.num_latent_tokens
154 | )
155 | return latent_tokens
156 |
157 |
158 | class TiTokDecoder(nn.Module):
159 | def __init__(self, config):
160 | super().__init__()
161 | self.config = config
162 | self.image_size = config.dataset.preprocessing.crop_size
163 | self.patch_size = config.model.vq_model.vit_dec_patch_size
164 | self.grid_size = self.image_size // self.patch_size
165 | self.model_size = config.model.vq_model.vit_dec_model_size
166 | self.num_latent_tokens = config.model.vq_model.num_latent_tokens
167 | self.token_size = config.model.vq_model.token_size
168 | self.width = {
169 | "small": 512,
170 | "base": 768,
171 | "large": 1024,
172 | }[self.model_size]
173 | self.num_layers = {
174 | "small": 8,
175 | "base": 12,
176 | "large": 24,
177 | }[self.model_size]
178 | self.num_heads = {
179 | "small": 8,
180 | "base": 12,
181 | "large": 16,
182 | }[self.model_size]
183 |
184 | self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True)
185 | scale = self.width**-0.5
186 | self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
187 | self.positional_embedding = nn.Parameter(
188 | scale * torch.randn(self.grid_size**2 + 1, self.width)
189 | )
190 | # add mask token and query pos embed
191 | self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
192 | self.latent_token_positional_embedding = nn.Parameter(
193 | scale * torch.randn(self.num_latent_tokens, self.width)
194 | )
195 | self.ln_pre = nn.LayerNorm(self.width)
196 | self.transformer = nn.ModuleList()
197 | for i in range(self.num_layers):
198 | self.transformer.append(
199 | ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
200 | )
201 | self.ln_post = nn.LayerNorm(self.width)
202 |
203 | self.ffn = nn.Sequential(
204 | nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
205 | nn.Tanh(),
206 | nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
207 | )
208 | self.conv_out = nn.Identity()
209 |
210 | def forward(self, z_quantized):
211 | N, C, H, W = z_quantized.shape
212 | assert (
213 | H == 1 and W == self.num_latent_tokens
214 | ), f"{H}, {W}, {self.num_latent_tokens}"
215 | x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1) # NLD
216 | x = self.decoder_embed(x)
217 |
218 | batchsize, seq_len, _ = x.shape
219 |
220 | mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(
221 | x.dtype
222 | )
223 | mask_tokens = torch.cat(
224 | [
225 | _expand_token(self.class_embedding, mask_tokens.shape[0]).to(
226 | mask_tokens.dtype
227 | ),
228 | mask_tokens,
229 | ],
230 | dim=1,
231 | )
232 | mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)
233 | x = x + self.latent_token_positional_embedding[:seq_len]
234 | x = torch.cat([mask_tokens, x], dim=1)
235 |
236 | x = self.ln_pre(x)
237 | x = x.permute(1, 0, 2) # NLD -> LND
238 | for i in range(self.num_layers):
239 | x = self.transformer[i](x)
240 | x = x.permute(1, 0, 2) # LND -> NLD
241 | x = x[:, 1 : 1 + self.grid_size**2] # remove cls embed
242 | x = self.ln_post(x)
243 | # N L D -> N D H W
244 | x = x.permute(0, 2, 1).reshape(
245 | batchsize, self.width, self.grid_size, self.grid_size
246 | )
247 | x = self.ffn(x.contiguous())
248 | x = self.conv_out(x)
249 | return x
250 |
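
A shape sketch (not part of the repository) for TiTokEncoder and TiTokDecoder above. The nested config mirrors the structure built in imagetokenizer/model/titok.py; the "small" model size and the random latent queries are illustrative (in TiTok itself the latent tokens are a learned nn.Parameter).

import torch
from easydict import EasyDict as edict

config = edict({
    "dataset": {"preprocessing": {"crop_size": 256}},
    "model": {"vq_model": {
        "vit_enc_patch_size": 16, "vit_dec_patch_size": 16,
        "vit_enc_model_size": "small", "vit_dec_model_size": "small",
        "num_latent_tokens": 32, "token_size": 12,
    }},
})
encoder, decoder = TiTokEncoder(config), TiTokDecoder(config)

latent_queries = torch.randn(32, encoder.width)                # (num_latent_tokens, width)
tokens = encoder(torch.randn(1, 3, 256, 256), latent_queries)
print(tokens.shape)                                            # torch.Size([1, 12, 1, 32])
print(decoder(tokens).shape)                                   # torch.Size([1, 1024, 16, 16]), logits over the pixel codebook
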
--------------------------------------------------------------------------------
/imagetokenizer/model/modules/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class DiagonalGaussianDistribution(object):
6 | def __init__(self, parameters, deterministic=False):
7 | self.parameters = parameters
8 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
9 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
10 | self.deterministic = deterministic
11 | self.std = torch.exp(0.5 * self.logvar)
12 | self.var = torch.exp(self.logvar)
13 | if self.deterministic:
14 | self.var = self.std = torch.zeros_like(self.mean).to(
15 | device=self.parameters.device
16 | )
17 |
18 | def sample(self):
19 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
20 | device=self.parameters.device
21 | )
22 | return x
23 |
24 | def kl(self, other=None):
25 | if self.deterministic:
26 | return torch.Tensor([0.0])
27 | else:
28 | if other is None:
29 | return 0.5 * torch.sum(
30 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
31 | dim=[1, 2, 3],
32 | )
33 | else:
34 | return 0.5 * torch.sum(
35 | torch.pow(self.mean - other.mean, 2) / other.var
36 | + self.var / other.var
37 | - 1.0
38 | - self.logvar
39 | + other.logvar,
40 | dim=[1, 2, 3],
41 | )
42 |
43 | def nll(self, sample, dims=[1, 2, 3]):
44 | if self.deterministic:
45 | return torch.Tensor([0.0])
46 | logtwopi = np.log(2.0 * np.pi)
47 | return 0.5 * torch.sum(
48 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
49 | dim=dims,
50 | )
51 |
52 | def mode(self):
53 | return self.mean
54 |
55 |
56 | def normal_kl(mean1, logvar1, mean2, logvar2):
57 | """
58 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
59 | Compute the KL divergence between two gaussians.
60 | Shapes are automatically broadcasted, so batches can be compared to
61 | scalars, among other use cases.
62 | """
63 | tensor = None
64 | for obj in (mean1, logvar1, mean2, logvar2):
65 | if isinstance(obj, torch.Tensor):
66 | tensor = obj
67 | break
68 | assert tensor is not None, "at least one argument must be a Tensor"
69 |
70 | # Force variances to be Tensors. Broadcasting helps convert scalars to
71 | # Tensors, but it does not work for torch.exp().
72 | logvar1, logvar2 = [
73 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
74 | for x in (logvar1, logvar2)
75 | ]
76 |
77 | return 0.5 * (
78 | -1.0
79 | + logvar2
80 | - logvar1
81 | + torch.exp(logvar1 - logvar2)
82 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
83 | )
84 |
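
A small sketch (not in the repository) of how the diagonal Gaussian posterior above is typically used: the first half of the channel dimension is read as the mean and the second half as the log-variance.

import torch

moments = torch.randn(2, 8, 16, 16)      # e.g. a VAE encoder output with 2 * C = 8 channels
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()                   # (2, 4, 16, 16), reparameterized draw
kl = posterior.kl()                      # (2,), KL against a standard normal, summed over C, H, W
nll = posterior.nll(z)                   # (2,), Gaussian negative log-likelihood of z
print(z.shape, kl.shape, nll.shape)
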
--------------------------------------------------------------------------------
/imagetokenizer/model/titok.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 | from typing import Mapping, Text, Tuple
5 | import os
6 |
7 | import torch
8 | from einops import rearrange
9 | from torch.cuda.amp import autocast
10 |
11 | from .modules.titok_transformer import TiTokEncoder, TiTokDecoder
12 | from .modules.maskgit_vqgan import Decoder as Pixel_Decoder
13 | from .modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
14 | from omegaconf import OmegaConf
15 | from easydict import EasyDict as edict
16 |
17 |
18 | class TiTok(nn.Module):
19 | def __init__(self):
20 | super().__init__()
21 | config = {
22 | "experiment": {
23 | "tokenizer_checkpoint": "tokenizer_titok_l32.bin",
24 | "generator_checkpoint": "generator_titok_l32.bin",
25 | },
26 | "model": {
27 | "vq_model": {
28 | "codebook_size": 4096,
29 | "token_size": 12,
30 | "use_l2_norm": True,
31 | "commitment_cost": 0.25,
32 | "vit_enc_model_size": "large",
33 | "vit_dec_model_size": "large",
34 | "vit_enc_patch_size": 16,
35 | "vit_dec_patch_size": 16,
36 | "num_latent_tokens": 32,
37 | },
38 | "generator": {
39 | "dropout": 0.1,
40 | "attn_drop": 0.1,
41 | "num_steps": 8,
42 | "mask_schedule_strategy": "arccos",
43 | "class_label_dropout": 0.1,
44 | "image_seq_len": 32,
45 | "condition_num_classes": 1000,
46 | },
47 | },
48 | "dataset": {"preprocessing": {"crop_size": 256}},
49 | }
50 | config = edict(config)
51 | self.config = config
52 | self.encoder = TiTokEncoder(config)
53 | self.decoder = TiTokDecoder(config)
54 |
55 | self.num_latent_tokens = config.model.vq_model.num_latent_tokens
56 | scale = self.encoder.width**-0.5
57 | self.latent_tokens = nn.Parameter(
58 | scale * torch.randn(self.num_latent_tokens, self.encoder.width)
59 | )
60 |
61 | self.apply(self._init_weights)
62 |
63 | self.quantize = VectorQuantizer(
64 | codebook_size=config.model.vq_model.codebook_size,
65 | token_size=config.model.vq_model.token_size,
66 | commitment_cost=config.model.vq_model.commitment_cost,
67 | use_l2_norm=config.model.vq_model.use_l2_norm,
68 | )
69 |
70 | self.pixel_quantize = Pixel_Quantizer(
71 | num_embeddings=1024, embedding_dim=256, commitment_cost=0.25
72 | )
73 | self.pixel_decoder = Pixel_Decoder(
74 | OmegaConf.create(
75 | {
76 | "channel_mult": [1, 1, 2, 2, 4],
77 | "num_resolutions": 5,
78 | "dropout": 0.0,
79 | "hidden_channels": 128,
80 | "num_channels": 3,
81 | "num_res_blocks": 2,
82 | "resolution": 256,
83 | "z_channels": 256,
84 | }
85 | )
86 | )
87 |
88 | def load_weights(self, model_path):
89 | g_p = os.path.join(model_path, 'generator_titok_l32.bin')
90 | t_p = os.path.join(model_path, 'tokenizer_titok_l32.bin')
91 | sd_g = torch.load(g_p, map_location="cpu")
92 | sd_t = torch.load(t_p, map_location="cpu")
93 | missing, unexpected = self.load_state_dict(sd_g, strict=False)
94 | missing, unexpected = self.load_state_dict(sd_t, strict=False)
95 |
96 | def _init_weights(self, module):
97 | """Initialize the weights.
 98 | 
 99 |         :param module: the torch.nn.Module to initialize
100 | """
101 | if (
102 | isinstance(module, nn.Linear)
103 | or isinstance(module, nn.Conv1d)
104 | or isinstance(module, nn.Conv2d)
105 | ):
106 | module.weight.data = nn.init.trunc_normal_(
107 | module.weight.data, mean=0.0, std=0.02
108 | )
109 | if module.bias is not None:
110 | module.bias.data.zero_()
111 | elif isinstance(module, nn.Embedding):
112 | module.weight.data = nn.init.trunc_normal_(
113 | module.weight.data, mean=0.0, std=0.02
114 | )
115 | elif isinstance(module, nn.LayerNorm):
116 | module.bias.data.zero_()
117 | module.weight.data.fill_(1.0)
118 |
119 | def encode(self, x):
120 | if x.shape[-1] != self.config.dataset.preprocessing.crop_size:
121 | x = torch.nn.functional.interpolate(
122 | x,
123 | size=(
124 | self.config.dataset.preprocessing.crop_size,
125 | self.config.dataset.preprocessing.crop_size,
126 | ),
127 | mode="bilinear",
128 | align_corners=False,
129 | )
130 |             # print(x.shape)
131 | z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens)
132 | z_quantized, result_dict = self.quantize(z)
133 | return z_quantized, z, result_dict['min_encoding_indices']
134 |
135 | def decode(self, z_quantized):
136 | decoded_latent = self.decoder(z_quantized)
137 | quantized_states = torch.einsum(
138 | "nchw,cd->ndhw",
139 | decoded_latent.softmax(1),
140 | self.pixel_quantize.embedding.weight,
141 | )
142 | decoded = self.pixel_decoder(quantized_states)
143 | return decoded
144 |
145 | def decode_tokens(self, tokens):
146 | tokens = tokens.squeeze(1)
147 | batch, seq_len = tokens.shape # B x N
148 | z_quantized = self.quantize.get_codebook_entry(tokens.reshape(-1)).reshape(
149 | batch, 1, seq_len, -1
150 | )
151 | if self.quantize.use_l2_norm:
152 | z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
153 | z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous()
154 | decoded = self.decode(z_quantized)
155 | return decoded
156 |
157 |
158 | class VectorQuantizer(torch.nn.Module):
159 | def __init__(
160 | self,
161 | codebook_size: int = 1024,
162 | token_size: int = 256,
163 | commitment_cost: float = 0.25,
164 | use_l2_norm: bool = False,
165 | ):
166 | super().__init__()
167 | self.commitment_cost = commitment_cost
168 |
169 | self.embedding = torch.nn.Embedding(codebook_size, token_size)
170 | self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
171 | self.use_l2_norm = use_l2_norm
172 |
173 | # Ensure quantization is performed using f32
174 | @autocast(enabled=False)
175 | def forward(
176 | self, z: torch.Tensor
177 | ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
178 | z = z.float()
179 | z = rearrange(z, "b c h w -> b h w c").contiguous()
180 | z_flattened = rearrange(z, "b h w c -> (b h w) c")
181 |
182 | if self.use_l2_norm:
183 | z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1)
184 | embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
185 | else:
186 | embedding = self.embedding.weight
187 | d = (
188 | torch.sum(z_flattened**2, dim=1, keepdim=True)
189 | + torch.sum(embedding**2, dim=1)
190 | - 2 * torch.einsum("bd,dn->bn", z_flattened, embedding.T)
191 | )
192 |
193 | min_encoding_indices = torch.argmin(d, dim=1) # num_ele
194 | z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)
195 |
196 | if self.use_l2_norm:
197 | z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
198 | z = torch.nn.functional.normalize(z, dim=-1)
199 |
200 | # compute loss for embedding
201 | commitment_loss = self.commitment_cost * torch.mean(
202 | (z_quantized.detach() - z) ** 2
203 | )
204 | codebook_loss = torch.mean((z_quantized - z.detach()) ** 2)
205 |
206 | loss = commitment_loss + codebook_loss
207 |
208 | # preserve gradients
209 | z_quantized = z + (z_quantized - z).detach()
210 |
211 | # reshape back to match original input shape
212 | z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous()
213 |
214 | result_dict = dict(
215 | quantizer_loss=loss,
216 | commitment_loss=commitment_loss,
217 | codebook_loss=codebook_loss,
218 | min_encoding_indices=min_encoding_indices.view(
219 | z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3]
220 | ),
221 | )
222 |
223 | return z_quantized, result_dict
224 |
225 | def get_codebook_entry(self, indices):
226 | if len(indices.shape) == 1:
227 | z_quantized = self.embedding(indices)
228 | elif len(indices.shape) == 2:
229 | z_quantized = torch.einsum("bd,dn->bn", indices, self.embedding.weight)
230 | else:
231 | raise NotImplementedError
232 | return z_quantized
233 |
--------------------------------------------------------------------------------
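
A usage sketch for the TiTok wrapper above: encode an image into 32 discrete latent tokens, then reconstruct it from those ids. The checkpoint directory name and the [0, 1] input range are assumptions; adjust them to however download_tokenizer_weights.sh lays out the files.

import torch
from imagetokenizer.model.titok import TiTok

model = TiTok().eval()
model.load_weights("checkpoints")   # folder holding tokenizer_titok_l32.bin and generator_titok_l32.bin

image = torch.rand(1, 3, 256, 256)  # assumed [0, 1] RGB; crop_size in the config is 256
with torch.no_grad():
    z_quantized, z, token_ids = model.encode(image)  # token_ids: 32 codebook indices per image
    recon = model.decode_tokens(token_ids)           # should come back as 1 x 3 x 256 x 256
print(token_ids.shape, recon.shape)
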
/imagetokenizer/quantize/lookup_free_quantize.py:
--------------------------------------------------------------------------------
1 | """
2 | Lookup Free Quantization
3 | Proposed in https://arxiv.org/abs/2310.05737
4 |
5 | In the simplest setup, each dimension is quantized into {-1, 1}.
6 | An entropy penalty is used to encourage utilization.
7 |
8 | Refer to
9 | https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py
10 | https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py
11 | """
12 |
13 | from math import log2, ceil
14 | from collections import namedtuple
15 |
16 | import torch
17 | from torch import nn, einsum
18 | import torch.nn.functional as F
19 | from torch.nn import Module
20 |
21 | from einops import rearrange, reduce, pack, unpack
22 |
23 | # constants
24 |
25 | LossBreakdown = namedtuple(
26 | "LossBreakdown",
27 | ["per_sample_entropy", "codebook_entropy", "commitment", "avg_probs"],
28 | )
29 |
30 | # helper functions
31 |
32 |
33 | def exists(v):
34 | return v is not None
35 |
36 |
37 | def default(*args):
38 | for arg in args:
39 | if exists(arg):
40 | return arg() if callable(arg) else arg
41 | return None
42 |
43 |
44 | def pack_one(t, pattern):
45 | return pack([t], pattern)
46 |
47 |
48 | def unpack_one(t, ps, pattern):
49 | return unpack(t, ps, pattern)[0]
50 |
51 |
52 | # entropy
53 |
54 |
55 | def entropy(prob):
56 | return (-prob * torch.log(prob + 1e-5)).sum(dim=-1)
57 |
58 |
59 | # class
60 |
61 |
62 | def mult_along_first_dims(x, y):
63 | """
64 | returns x * y elementwise along the leading dimensions of y
65 | """
66 | ndim_to_expand = x.ndim - y.ndim
67 | for _ in range(ndim_to_expand):
68 | y = y.unsqueeze(-1)
69 | return x * y
70 |
71 |
72 | def masked_mean(x, m):
73 | """
74 | takes the mean of the elements of x that are not masked
75 | the mean is taken along the shared leading dims of m
76 | equivalent to: x[m].mean(tuple(range(m.ndim)))
77 |
78 | The benefit of using masked_mean rather than using
79 | tensor indexing is that masked_mean is much faster
80 | for torch-compile on batches.
81 |
82 | The drawback is larger floating point errors
83 | """
84 | x = mult_along_first_dims(x, m)
85 | x = x / m.sum()
86 | return x.sum(tuple(range(m.ndim)))
87 |
88 |
89 | def entropy_loss(
90 | logits,
91 | mask=None,
92 | temperature=0.01,
93 | sample_minimization_weight=1.0,
94 | batch_maximization_weight=1.0,
95 | eps=1e-5,
96 | ):
97 | """
98 | Entropy loss of unnormalized logits
99 |
100 | logits: Affinities are over the last dimension
101 |
102 | https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
103 | LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024)
104 | """
105 | probs = F.softmax(logits / temperature, -1)
106 | log_probs = F.log_softmax(logits / temperature + eps, -1)
107 |
108 | if mask is not None:
109 | avg_probs = masked_mean(probs, mask)
110 | else:
111 | avg_probs = reduce(probs, "... D -> D", "mean")
112 |
113 | avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))
114 |
115 | sample_entropy = -torch.sum(probs * log_probs, -1)
116 | if mask is not None:
117 | sample_entropy = masked_mean(sample_entropy, mask).mean()
118 | else:
119 | sample_entropy = torch.mean(sample_entropy)
120 |
121 | loss = (sample_minimization_weight * sample_entropy) - (
122 | batch_maximization_weight * avg_entropy
123 | )
124 |
125 | return sample_entropy, avg_entropy, loss
126 |
127 |
128 | class LFQ(Module):
129 | def __init__(
130 | self,
131 | *,
132 | dim=None,
133 | codebook_size=None,
134 | num_codebooks=1,
135 | sample_minimization_weight=1.0,
136 | batch_maximization_weight=1.0,
137 | token_factorization=False,
138 | ):
139 | super().__init__()
140 |
141 | # some assert validations
142 |
143 | assert exists(dim) or exists(
144 | codebook_size
145 | ), "either dim or codebook_size must be specified for LFQ"
146 | assert (
147 | not exists(codebook_size) or log2(codebook_size).is_integer()
148 | ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
149 |
150 | self.codebook_size = default(codebook_size, lambda: 2**dim)
151 |         self.codebook_dim = int(log2(self.codebook_size))
152 |
153 | codebook_dims = self.codebook_dim * num_codebooks
154 | dim = default(dim, codebook_dims)
155 |
156 | has_projections = dim != codebook_dims
157 | self.has_projections = has_projections
158 |
159 | self.dim = dim
160 | self.codebook_dim = self.codebook_dim
161 | self.num_codebooks = num_codebooks
162 |
163 | # for entropy loss
164 | self.sample_minimization_weight = sample_minimization_weight
165 | self.batch_maximization_weight = batch_maximization_weight
166 |
167 | # for no auxiliary loss, during inference
168 | self.token_factorization = token_factorization ## only utilized in second stage
169 | if not self.token_factorization: # for first stage model
170 | self.register_buffer(
171 | "mask",
172 | 2 ** torch.arange(self.codebook_dim - 1, -1, -1),
173 | persistent=False,
174 | )
175 | else:
176 | k = self.codebook_dim // 2
177 | self.register_buffer(
178 | "mask", 2 ** torch.arange(k - 1, -1, -1), persistent=False
179 | )
180 |
181 | self.register_buffer("zero", torch.tensor(0.0), persistent=False)
182 |
183 | # codes
184 |         all_codes = torch.arange(self.codebook_size)
185 | bits = self.indices_to_bits(all_codes)
186 | codebook = bits * 2.0 - 1.0
187 |
188 | self.register_buffer("codebook", codebook, persistent=False)
189 |
190 | @property
191 | def dtype(self):
192 | return self.codebook.dtype
193 |
194 | def indices_to_bits(self, x):
195 | """
196 |         x: long tensor of indices for constructing the codebook (not actually used in the experiments).
197 |
198 | returns big endian bits
199 | """
200 | mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)
201 | # x is now big endian bits, the last dimension being the bits
202 | x = (x.unsqueeze(-1) & mask) != 0
203 | return x
204 |
205 | def get_codebook_entry(self, x, bhwc):
206 | if self.token_factorization:
207 | k = self.codebook_dim // 2
208 | mask = 2 ** torch.arange(k - 1, -1, -1, device=x.device, dtype=torch.long)
209 | else:
210 | mask = 2 ** torch.arange(
211 | self.codebook_dim - 1, -1, -1, device=x.device, dtype=torch.long
212 | )
213 |
214 | x = (x.unsqueeze(-1) & mask) != 0
215 | x = x * 2.0 - 1.0 # back to the float
216 | ## scale back to the desired shape
217 | b, h, w, c = bhwc
218 | x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c)
219 | x = rearrange(x, "b h w c -> b c h w")
220 | return x
221 |
222 | def bits_to_indices(self, bits):
223 | """
224 | bits: bool tensor of big endian bits, where the last dimension is the bit dimension
225 |
226 |         returns indices, which are long integers from 0 to self.codebook_size - 1
227 | """
228 | assert bits.shape[-1] == self.codebook_dim
229 | indices = 2 ** torch.arange(
230 | 0,
231 | self.codebook_dim,
232 | 1,
233 | dtype=torch.long,
234 | device=bits.device,
235 | )
236 | return (bits * indices).sum(-1)
237 |
238 | def decode(self, x):
239 | """
240 | x: ... NH
241 | where NH is number of codebook heads
242 | A longtensor of codebook indices, containing values from
243 | 0 to self.codebook_size
244 | """
245 | x = self.indices_to_bits(x)
246 | # to some sort of float
247 | x = x.to(self.dtype)
248 | # -1 or 1
249 | x = x * 2 - 1
250 |         x = rearrange(x, "... NC Z -> ... (NC Z)")
251 | return x
252 |
253 | def forward(
254 | self,
255 | x,
256 | return_loss_breakdown=False,
257 | mask=None,
258 | return_loss=True,
259 | ):
260 | """
261 | einstein notation
262 | b - batch
263 | n - sequence (or flattened spatial dimensions)
264 | d - feature dimension, which is also log2(codebook size)
265 | c - number of codebook dim
266 | """
267 |
268 | x = rearrange(x, "b d ... -> b ... d")
269 | x, ps = pack_one(x, "b * d")
270 | # split out number of codebooks
271 |
272 | x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks)
273 |
274 | codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype)
275 | quantized = torch.where(
276 | x > 0, codebook_value, -codebook_value
277 | ) # higher than 0 filled
278 |
279 | # calculate indices
280 | if self.token_factorization:
281 | k = self.codebook_dim // 2
282 | indices_pre = reduce(
283 | (quantized[..., :k] > 0).int() * self.mask.int(),
284 | "b n c d -> b n c",
285 | "sum",
286 | )
287 | indices_post = reduce(
288 | (quantized[..., k:] > 0).int() * self.mask.int(),
289 | "b n c d -> b n c",
290 | "sum",
291 | )
292 | # indices_post = 2**k + indices_post #shifter to the 1024
293 | else:
294 | indices = reduce(
295 | (quantized > 0).int() * self.mask.int(), "b n c d -> b n c", "sum"
296 | )
297 |
298 | # entropy aux loss
299 |
300 | if self.training and return_loss:
301 | logits = 2 * einsum("... i d, j d -> ... i j", x, self.codebook)
302 | # the same as euclidean distance up to a constant
303 | per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss(
304 | logits=logits,
305 | sample_minimization_weight=self.sample_minimization_weight,
306 | batch_maximization_weight=self.batch_maximization_weight,
307 | )
308 |
309 | avg_probs = self.zero
310 | else:
311 | ## calculate the codebook_entropy needed for one batch evaluation
312 | # ------------------------------------------------------------------
313 | # logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook)
314 | # probs = F.softmax(logits / 0.01, -1)
315 | # avg_probs = reduce(probs, "b n c d -> b d", "mean")
316 | # avg_probs = torch.sum(avg_probs, 0) #batch dimension
317 | # -------------------------------------------------------------------
318 | # if not training, just return dummy 0
319 | per_sample_entropy = codebook_entropy = self.zero
320 | entropy_aux_loss = self.zero
321 | avg_probs = self.zero
322 |
323 | # commit loss
324 |
325 | if self.training:
326 | commit_loss = F.mse_loss(x, quantized.detach(), reduction="none")
327 |
328 | if exists(mask):
329 | commit_loss = commit_loss[mask]
330 |
331 | commit_loss = commit_loss.mean()
332 | else:
333 | commit_loss = self.zero
334 |
335 | # use straight-through gradients (optionally with custom activation fn) if training
336 |
337 | quantized = x + (quantized - x).detach() # transfer to quantized
338 |
339 | # merge back codebook dim
340 |
341 | quantized = rearrange(quantized, "b n c d -> b n (c d)")
342 |
343 | # reconstitute image or video dimensions
344 |
345 | quantized = unpack_one(quantized, ps, "b * d")
346 | quantized = rearrange(quantized, "b ... d -> b d ...")
347 |
348 | if self.token_factorization:
349 | indices_pre = unpack_one(indices_pre, ps, "b * c")
350 | indices_post = unpack_one(indices_post, ps, "b * c")
351 | indices_pre = indices_pre.flatten()
352 | indices_post = indices_post.flatten()
353 | indices = (indices_pre, indices_post)
354 | else:
355 | indices = unpack_one(indices, ps, "b * c")
356 | indices = indices.flatten()
357 |
358 | ret = (quantized, entropy_aux_loss, indices)
359 |
360 | if not return_loss_breakdown:
361 | return ret
362 |
363 | return ret, LossBreakdown(
364 | per_sample_entropy, codebook_entropy, commit_loss, avg_probs
365 | )
366 |
--------------------------------------------------------------------------------
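
The LFQ module above binarizes each channel to {-1, +1}, so a codebook of size 2**10 expects a 10-dimensional input (this version defines no projection layers, so dim should equal log2(codebook_size) when num_codebooks=1). A minimal inference sketch; the feature shapes are illustrative, not taken from the repo:

import torch
from imagetokenizer.quantize.lookup_free_quantize import LFQ

lfq = LFQ(codebook_size=1024, dim=10).eval()
feats = torch.randn(2, 10, 8, 8)          # (batch, dim, height, width)
with torch.no_grad():
    quantized, entropy_aux_loss, indices = lfq(feats)
print(quantized.shape)    # same shape as feats, entries in {-1, +1}
print(indices.shape)      # integer code ids, flattened over batch and spatial positions
print(entropy_aux_loss)   # zero here; the entropy terms are only computed in training mode
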
/imagetokenizer/quantize/vector_quantize.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from torch import nn, einsum
5 | import torch.nn.functional as F
6 | import torch.distributed as distributed
7 | from torch.optim import Optimizer
8 | from torch.cuda.amp import autocast
9 |
10 | from einops import rearrange, repeat, reduce, pack, unpack
11 |
12 | from typing import Callable
13 |
14 |
15 | def exists(val):
16 | return val is not None
17 |
18 |
19 | def default(val, d):
20 | return val if exists(val) else d
21 |
22 |
23 | def noop(*args, **kwargs):
24 | pass
25 |
26 |
27 | def identity(t):
28 | return t
29 |
30 |
31 | def l2norm(t):
32 | return F.normalize(t, p=2, dim=-1)
33 |
34 |
35 | def cdist(x, y):
36 | x2 = reduce(x**2, "b n d -> b n", "sum")
37 | y2 = reduce(y**2, "b n d -> b n", "sum")
38 | xy = einsum("b i d, b j d -> b i j", x, y) * -2
39 | return (
40 | (rearrange(x2, "b i -> b i 1") + rearrange(y2, "b j -> b 1 j") + xy)
41 | .clamp(min=0)
42 | .sqrt()
43 | )
44 |
45 |
46 | def log(t, eps=1e-20):
47 | return torch.log(t.clamp(min=eps))
48 |
49 |
50 | def ema_inplace(old, new, decay):
51 | is_mps = str(old.device).startswith("mps:")
52 |
53 | if not is_mps:
54 | old.lerp_(new, 1 - decay)
55 | else:
56 | old.mul_(decay).add_(new * (1 - decay))
57 |
58 |
59 | def pack_one(t, pattern):
60 | return pack([t], pattern)
61 |
62 |
63 | def unpack_one(t, ps, pattern):
64 | return unpack(t, ps, pattern)[0]
65 |
66 |
67 | def uniform_init(*shape):
68 | t = torch.empty(shape)
69 | nn.init.kaiming_uniform_(t)
70 | return t
71 |
72 |
73 | def gumbel_noise(t):
74 | noise = torch.zeros_like(t).uniform_(0, 1)
75 | return -log(-log(noise))
76 |
77 |
78 | def gumbel_sample(
79 | logits,
80 | temperature=1.0,
81 | stochastic=False,
82 | straight_through=False,
83 | reinmax=False,
84 | dim=-1,
85 | training=True,
86 | ):
87 | dtype, size = logits.dtype, logits.shape[dim]
88 |
89 | if training and stochastic and temperature > 0:
90 | sampling_logits = (logits / temperature) + gumbel_noise(logits)
91 | else:
92 | sampling_logits = logits
93 |
94 | ind = sampling_logits.argmax(dim=dim)
95 | one_hot = F.one_hot(ind, size).type(dtype)
96 |
97 | assert not (
98 | reinmax and not straight_through
99 | ), "reinmax can only be turned on if using straight through gumbel softmax"
100 |
101 | if not straight_through or temperature <= 0.0 or not training:
102 | return ind, one_hot
103 |
104 | # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
105 | # algorithm 2
106 |
107 | if reinmax:
108 | π0 = logits.softmax(dim=dim)
109 | π1 = (one_hot + (logits / temperature).softmax(dim=dim)) / 2
110 | π1 = ((log(π1) - logits).detach() + logits).softmax(dim=1)
111 | π2 = 2 * π1 - 0.5 * π0
112 | one_hot = π2 - π2.detach() + one_hot
113 | else:
114 | π1 = (logits / temperature).softmax(dim=dim)
115 | one_hot = one_hot + π1 - π1.detach()
116 |
117 | return ind, one_hot
118 |
119 |
120 | def laplace_smoothing(x, n_categories, eps=1e-5, dim=-1):
121 | denom = x.sum(dim=dim, keepdim=True)
122 | return (x + eps) / (denom + n_categories * eps)
123 |
124 |
125 | def sample_vectors(samples, num):
126 | num_samples, device = samples.shape[0], samples.device
127 | if num_samples >= num:
128 | indices = torch.randperm(num_samples, device=device)[:num]
129 | else:
130 | indices = torch.randint(0, num_samples, (num,), device=device)
131 |
132 | return samples[indices]
133 |
134 |
135 | def batched_sample_vectors(samples, num):
136 | return torch.stack(
137 | [sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
138 | )
139 |
140 |
141 | def pad_shape(shape, size, dim=0):
142 | return [size if i == dim else s for i, s in enumerate(shape)]
143 |
144 |
145 | def sample_multinomial(total_count, probs):
146 | device = probs.device
147 | probs = probs.cpu()
148 |
149 | total_count = probs.new_full((), total_count)
150 | remainder = probs.new_ones(())
151 | sample = torch.empty_like(probs, dtype=torch.long)
152 |
153 | for i, p in enumerate(probs):
154 | s = torch.binomial(total_count, p / remainder)
155 | sample[i] = s
156 | total_count -= s
157 | remainder -= p
158 |
159 | return sample.to(device)
160 |
161 |
162 | def all_gather_sizes(x, dim):
163 | size = torch.tensor(x.shape[dim], dtype=torch.long, device=x.device)
164 | all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())]
165 | distributed.all_gather(all_sizes, size)
166 | return torch.stack(all_sizes)
167 |
168 |
169 | def all_gather_variably_sized(x, sizes, dim=0):
170 | rank = distributed.get_rank()
171 | all_x = []
172 |
173 | for i, size in enumerate(sizes):
174 | t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim))
175 | distributed.broadcast(t, src=i, async_op=True)
176 | all_x.append(t)
177 |
178 | distributed.barrier()
179 | return all_x
180 |
181 |
182 | def sample_vectors_distributed(local_samples, num):
183 | local_samples = rearrange(local_samples, "1 ... -> ...")
184 |
185 | rank = distributed.get_rank()
186 | all_num_samples = all_gather_sizes(local_samples, dim=0)
187 |
188 | if rank == 0:
189 | samples_per_rank = sample_multinomial(
190 | num, all_num_samples / all_num_samples.sum()
191 | )
192 | else:
193 | samples_per_rank = torch.empty_like(all_num_samples)
194 |
195 | distributed.broadcast(samples_per_rank, src=0)
196 | samples_per_rank = samples_per_rank.tolist()
197 |
198 | local_samples = sample_vectors(local_samples, samples_per_rank[rank])
199 | all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim=0)
200 | out = torch.cat(all_samples, dim=0)
201 |
202 | return rearrange(out, "... -> 1 ...")
203 |
204 |
205 | def batched_bincount(x, *, minlength):
206 | batch, dtype, device = x.shape[0], x.dtype, x.device
207 | target = torch.zeros(batch, minlength, dtype=dtype, device=device)
208 | values = torch.ones_like(x)
209 | target.scatter_add_(-1, x, values)
210 | return target
211 |
212 |
213 | def kmeans(
214 | samples,
215 | num_clusters,
216 | num_iters=10,
217 | use_cosine_sim=False,
218 | sample_fn=batched_sample_vectors,
219 | all_reduce_fn=noop,
220 | ):
221 | num_codebooks, dim, dtype, device = (
222 | samples.shape[0],
223 | samples.shape[-1],
224 | samples.dtype,
225 | samples.device,
226 | )
227 |
228 | means = sample_fn(samples, num_clusters)
229 |
230 | for _ in range(num_iters):
231 | if use_cosine_sim:
232 | dists = samples @ rearrange(means, "h n d -> h d n")
233 | else:
234 | dists = -cdist(samples, means)
235 |
236 | buckets = torch.argmax(dists, dim=-1)
237 | bins = batched_bincount(buckets, minlength=num_clusters)
238 | all_reduce_fn(bins)
239 |
240 | zero_mask = bins == 0
241 | bins_min_clamped = bins.masked_fill(zero_mask, 1)
242 |
243 | new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype)
244 |
245 | new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples)
246 | new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1")
247 | all_reduce_fn(new_means)
248 |
249 | if use_cosine_sim:
250 | new_means = l2norm(new_means)
251 |
252 | means = torch.where(rearrange(zero_mask, "... -> ... 1"), means, new_means)
253 |
254 | return means, bins
255 |
256 |
257 | def batched_embedding(indices, embeds):
258 | batch, dim = indices.shape[1], embeds.shape[-1]
259 | indices = repeat(indices, "h b n -> h b n d", d=dim)
260 | embeds = repeat(embeds, "h c d -> h b c d", b=batch)
261 | return embeds.gather(2, indices)
262 |
263 |
264 | # regularization losses
265 |
266 |
267 | def orthogonal_loss_fn(t):
268 | # eq (2) from https://arxiv.org/abs/2112.00384
269 | h, n = t.shape[:2]
270 | normed_codes = l2norm(t)
271 | cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
272 | return (cosine_sim**2).sum() / (h * n**2) - (1 / n)
273 |
274 |
275 | # distance types
276 |
277 |
278 | class EuclideanCodebook(nn.Module):
279 | def __init__(
280 | self,
281 | dim,
282 | codebook_size,
283 | num_codebooks=1,
284 | kmeans_init=False,
285 | kmeans_iters=10,
286 | sync_kmeans=True,
287 | decay=0.8,
288 | eps=1e-5,
289 | threshold_ema_dead_code=2,
290 | reset_cluster_size=None,
291 | use_ddp=False,
292 | learnable_codebook=False,
293 | gumbel_sample=gumbel_sample,
294 | sample_codebook_temp=1.0,
295 | ema_update=True,
296 | affine_param=False,
297 | sync_affine_param=False,
298 | affine_param_batch_decay=0.99,
299 | affine_param_codebook_decay=0.9,
300 | ):
301 | super().__init__()
302 | self.transform_input = identity
303 |
304 | self.decay = decay
305 | self.ema_update = ema_update
306 |
307 | init_fn = uniform_init if not kmeans_init else torch.zeros
308 | embed = init_fn(num_codebooks, codebook_size, dim)
309 |
310 | self.codebook_size = codebook_size
311 | self.num_codebooks = num_codebooks
312 |
313 | self.kmeans_iters = kmeans_iters
314 | self.eps = eps
315 | self.threshold_ema_dead_code = threshold_ema_dead_code
316 | self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
317 |
318 | assert callable(gumbel_sample)
319 | self.gumbel_sample = gumbel_sample
320 | self.sample_codebook_temp = sample_codebook_temp
321 |
322 | assert not (
323 | use_ddp and num_codebooks > 1 and kmeans_init
324 | ), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
325 |
326 | self.sample_fn = (
327 | sample_vectors_distributed
328 | if use_ddp and sync_kmeans
329 | else batched_sample_vectors
330 | )
331 | self.kmeans_all_reduce_fn = (
332 | distributed.all_reduce if use_ddp and sync_kmeans else noop
333 | )
334 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
335 |
336 | self.register_buffer("initted", torch.Tensor([not kmeans_init]))
337 | self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size))
338 | self.register_buffer("embed_avg", embed.clone())
339 |
340 | self.learnable_codebook = learnable_codebook
341 | if learnable_codebook:
342 | self.embed = nn.Parameter(embed)
343 | else:
344 | self.register_buffer("embed", embed)
345 |
346 | # affine related params
347 |
348 | self.affine_param = affine_param
349 | self.sync_affine_param = sync_affine_param
350 |
351 | if not affine_param:
352 | return
353 |
354 | self.affine_param_batch_decay = affine_param_batch_decay
355 | self.affine_param_codebook_decay = affine_param_codebook_decay
356 |
357 | self.register_buffer("batch_mean", None)
358 | self.register_buffer("batch_variance", None)
359 |
360 | self.register_buffer("codebook_mean_needs_init", torch.Tensor([True]))
361 | self.register_buffer("codebook_mean", torch.empty(num_codebooks, 1, dim))
362 | self.register_buffer("codebook_variance_needs_init", torch.Tensor([True]))
363 | self.register_buffer("codebook_variance", torch.empty(num_codebooks, 1, dim))
364 |
365 | @torch.jit.ignore
366 | def init_embed_(self, data, mask=None):
367 | if self.initted:
368 | return
369 |
370 | if exists(mask):
371 | c = data.shape[0]
372 | data = rearrange(data[mask], "(c n) d -> c n d", c=c)
373 |
374 | embed, cluster_size = kmeans(
375 | data,
376 | self.codebook_size,
377 | self.kmeans_iters,
378 | sample_fn=self.sample_fn,
379 | all_reduce_fn=self.kmeans_all_reduce_fn,
380 | )
381 |
382 | embed_sum = embed * rearrange(cluster_size, "... -> ... 1")
383 |
384 | self.embed.data.copy_(embed)
385 | self.embed_avg.data.copy_(embed_sum)
386 | self.cluster_size.data.copy_(cluster_size)
387 | self.initted.data.copy_(torch.Tensor([True]))
388 |
389 | @torch.jit.ignore
390 | def update_with_decay(self, buffer_name, new_value, decay):
391 | old_value = getattr(self, buffer_name)
392 |
393 | needs_init = getattr(self, buffer_name + "_needs_init", False)
394 |
395 | if needs_init:
396 | self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))
397 |
398 | if not exists(old_value) or needs_init:
399 | self.register_buffer(buffer_name, new_value.detach())
400 |
401 | return
402 |
403 | value = old_value * decay + new_value.detach() * (1 - decay)
404 | self.register_buffer(buffer_name, value)
405 |
406 | @torch.jit.ignore
407 | def update_affine(self, data, embed, mask=None):
408 | assert self.affine_param
409 |
410 | var_fn = partial(torch.var, unbiased=False)
411 |
412 | # calculate codebook mean and variance
413 |
414 | embed = rearrange(embed, "h ... d -> h (...) d")
415 |
416 | if self.training:
417 | self.update_with_decay(
418 | "codebook_mean",
419 | reduce(embed, "h n d -> h 1 d", "mean"),
420 | self.affine_param_codebook_decay,
421 | )
422 | self.update_with_decay(
423 | "codebook_variance",
424 | reduce(embed, "h n d -> h 1 d", var_fn),
425 | self.affine_param_codebook_decay,
426 | )
427 |
428 | # prepare batch data, which depends on whether it has masking
429 |
430 | data = rearrange(data, "h ... d -> h (...) d")
431 |
432 | if exists(mask):
433 | c = data.shape[0]
434 | data = rearrange(data[mask], "(c n) d -> c n d", c=c)
435 |
436 | # calculate batch mean and variance
437 |
438 | if not self.sync_affine_param:
439 | self.update_with_decay(
440 | "batch_mean",
441 | reduce(data, "h n d -> h 1 d", "mean"),
442 | self.affine_param_batch_decay,
443 | )
444 | self.update_with_decay(
445 | "batch_variance",
446 | reduce(data, "h n d -> h 1 d", var_fn),
447 | self.affine_param_batch_decay,
448 | )
449 | return
450 |
451 | num_vectors, device, dtype = data.shape[-2], data.device, data.dtype
452 |
453 | # number of vectors, for denominator
454 |
455 | num_vectors = torch.tensor([num_vectors], device=device, dtype=dtype)
456 | distributed.all_reduce(num_vectors)
457 |
458 | # calculate distributed mean
459 |
460 | batch_sum = reduce(data, "h n d -> h 1 d", "sum")
461 | distributed.all_reduce(batch_sum)
462 | batch_mean = batch_sum / num_vectors
463 |
464 | self.update_with_decay("batch_mean", batch_mean, self.affine_param_batch_decay)
465 |
466 | # calculate distributed variance
467 |
468 | variance_numer = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
469 | distributed.all_reduce(variance_numer)
470 | batch_variance = variance_numer / num_vectors
471 |
472 | self.update_with_decay(
473 | "batch_variance", batch_variance, self.affine_param_batch_decay
474 | )
475 |
476 | def replace(self, batch_samples, batch_mask):
477 | for ind, (samples, mask) in enumerate(
478 | zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0))
479 | ):
480 | if not torch.any(mask):
481 | continue
482 |
483 | sampled = self.sample_fn(
484 | rearrange(samples, "... -> 1 ..."), mask.sum().item()
485 | )
486 | sampled = rearrange(sampled, "1 ... -> ...")
487 |
488 | self.embed.data[ind][mask] = sampled
489 |
490 | self.cluster_size.data[ind][mask] = self.reset_cluster_size
491 | self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
492 |
493 | def expire_codes_(self, batch_samples):
494 | if self.threshold_ema_dead_code == 0:
495 | return
496 |
497 | expired_codes = self.cluster_size < self.threshold_ema_dead_code
498 |
499 | if not torch.any(expired_codes):
500 | return
501 |
502 | batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
503 | self.replace(batch_samples, batch_mask=expired_codes)
504 |
505 | @autocast(enabled=False)
506 | def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
507 | needs_codebook_dim = x.ndim < 4
508 | sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
509 |
510 | x = x.float()
511 |
512 | if needs_codebook_dim:
513 | x = rearrange(x, "... -> 1 ...")
514 |
515 | dtype = x.dtype
516 | flatten, ps = pack_one(x, "h * d")
517 |
518 | if exists(mask):
519 | mask = repeat(
520 | mask,
521 | "b n -> c (b h n)",
522 | c=flatten.shape[0],
523 | h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]),
524 | )
525 |
526 | self.init_embed_(flatten, mask=mask)
527 |
528 | if self.affine_param:
529 | self.update_affine(flatten, self.embed, mask=mask)
530 |
531 | embed = self.embed if self.learnable_codebook else self.embed.detach()
532 |
533 | if self.affine_param:
534 | codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
535 | batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
536 | embed = (embed - self.codebook_mean) * (
537 | batch_std / codebook_std
538 | ) + self.batch_mean
539 |
540 | dist = -cdist(flatten, embed)
541 |
542 | embed_ind, embed_onehot = self.gumbel_sample(
543 | dist, dim=-1, temperature=sample_codebook_temp, training=self.training
544 | )
545 |
546 | embed_ind = unpack_one(embed_ind, ps, "h *")
547 |
548 | if self.training:
549 | unpacked_onehot = unpack_one(embed_onehot, ps, "h * c")
550 | quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed)
551 | else:
552 | quantize = batched_embedding(embed_ind, embed)
553 |
554 | if self.training and self.ema_update and not freeze_codebook:
555 |
556 | if self.affine_param:
557 | flatten = (flatten - self.batch_mean) * (
558 | codebook_std / batch_std
559 | ) + self.codebook_mean
560 |
561 | if exists(mask):
562 | embed_onehot[~mask] = 0.0
563 |
564 | cluster_size = embed_onehot.sum(dim=1)
565 |
566 | self.all_reduce_fn(cluster_size)
567 | ema_inplace(self.cluster_size.data, cluster_size, self.decay)
568 |
569 | embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
570 | embed_sum = embed_sum.contiguous()
571 | self.all_reduce_fn(embed_sum)
572 |
573 | ema_inplace(self.embed_avg.data, embed_sum, self.decay)
574 |
575 | cluster_size = laplace_smoothing(
576 | self.cluster_size, self.codebook_size, self.eps
577 | ) * self.cluster_size.sum(dim=-1, keepdim=True)
578 |
579 | embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
580 | self.embed.data.copy_(embed_normalized)
581 | self.expire_codes_(x)
582 |
583 | if needs_codebook_dim:
584 | quantize, embed_ind = map(
585 | lambda t: rearrange(t, "1 ... -> ..."), (quantize, embed_ind)
586 | )
587 |
588 | dist = unpack_one(dist, ps, "h * d")
589 |
590 | return quantize, embed_ind, dist
591 |
592 |
593 | class CosineSimCodebook(nn.Module):
594 | def __init__(
595 | self,
596 | dim,
597 | codebook_size,
598 | num_codebooks=1,
599 | kmeans_init=False,
600 | kmeans_iters=10,
601 | sync_kmeans=True,
602 | decay=0.8,
603 | eps=1e-5,
604 | threshold_ema_dead_code=2,
605 | reset_cluster_size=None,
606 | use_ddp=False,
607 | learnable_codebook=False,
608 | gumbel_sample=gumbel_sample,
609 | sample_codebook_temp=1.0,
610 | ema_update=True,
611 | ):
612 | super().__init__()
613 | self.transform_input = l2norm
614 |
615 | self.ema_update = ema_update
616 | self.decay = decay
617 |
618 | if not kmeans_init:
619 | embed = l2norm(uniform_init(num_codebooks, codebook_size, dim))
620 | else:
621 | embed = torch.zeros(num_codebooks, codebook_size, dim)
622 |
623 | self.codebook_size = codebook_size
624 | self.num_codebooks = num_codebooks
625 |
626 | self.kmeans_iters = kmeans_iters
627 | self.eps = eps
628 | self.threshold_ema_dead_code = threshold_ema_dead_code
629 | self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
630 |
631 | assert callable(gumbel_sample)
632 | self.gumbel_sample = gumbel_sample
633 | self.sample_codebook_temp = sample_codebook_temp
634 |
635 | self.sample_fn = (
636 | sample_vectors_distributed
637 | if use_ddp and sync_kmeans
638 | else batched_sample_vectors
639 | )
640 | self.kmeans_all_reduce_fn = (
641 | distributed.all_reduce if use_ddp and sync_kmeans else noop
642 | )
643 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
644 |
645 | self.register_buffer("initted", torch.Tensor([not kmeans_init]))
646 | self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size))
647 | self.register_buffer("embed_avg", embed.clone())
648 |
649 | self.learnable_codebook = learnable_codebook
650 | if learnable_codebook:
651 | self.embed = nn.Parameter(embed)
652 | else:
653 | self.register_buffer("embed", embed)
654 |
655 | @torch.jit.ignore
656 | def init_embed_(self, data, mask=None):
657 | if self.initted:
658 | return
659 |
660 | if exists(mask):
661 | c = data.shape[0]
662 | data = rearrange(data[mask], "(c n) d -> c n d", c=c)
663 |
664 | embed, cluster_size = kmeans(
665 | data,
666 | self.codebook_size,
667 | self.kmeans_iters,
668 | use_cosine_sim=True,
669 | sample_fn=self.sample_fn,
670 | all_reduce_fn=self.kmeans_all_reduce_fn,
671 | )
672 |
673 | embed_sum = embed * rearrange(cluster_size, "... -> ... 1")
674 |
675 | self.embed.data.copy_(embed)
676 | self.embed_avg.data.copy_(embed_sum)
677 | self.cluster_size.data.copy_(cluster_size)
678 | self.initted.data.copy_(torch.Tensor([True]))
679 |
680 | def replace(self, batch_samples, batch_mask):
681 | batch_samples = l2norm(batch_samples)
682 |
683 | for ind, (samples, mask) in enumerate(
684 | zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0))
685 | ):
686 | if not torch.any(mask):
687 | continue
688 |
689 | sampled = self.sample_fn(
690 | rearrange(samples, "... -> 1 ..."), mask.sum().item()
691 | )
692 | sampled = rearrange(sampled, "1 ... -> ...")
693 |
694 | self.embed.data[ind][mask] = sampled
695 | self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
696 | self.cluster_size.data[ind][mask] = self.reset_cluster_size
697 |
698 | def expire_codes_(self, batch_samples):
699 | if self.threshold_ema_dead_code == 0:
700 | return
701 |
702 | expired_codes = self.cluster_size < self.threshold_ema_dead_code
703 |
704 | if not torch.any(expired_codes):
705 | return
706 |
707 | batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
708 | self.replace(batch_samples, batch_mask=expired_codes)
709 |
710 | @autocast(enabled=False)
711 | def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
712 | needs_codebook_dim = x.ndim < 4
713 | sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
714 |
715 | x = x.float()
716 |
717 | if needs_codebook_dim:
718 | x = rearrange(x, "... -> 1 ...")
719 |
720 | dtype = x.dtype
721 |
722 | flatten, ps = pack_one(x, "h * d")
723 |
724 | if exists(mask):
725 | mask = repeat(
726 | mask,
727 | "b n -> c (b h n)",
728 | c=flatten.shape[0],
729 | h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]),
730 | )
731 |
732 | self.init_embed_(flatten, mask=mask)
733 |
734 | embed = self.embed if self.learnable_codebook else self.embed.detach()
735 |
736 | dist = einsum("h n d, h c d -> h n c", flatten, embed)
737 |
738 | embed_ind, embed_onehot = self.gumbel_sample(
739 | dist, dim=-1, temperature=sample_codebook_temp, training=self.training
740 | )
741 | embed_ind = unpack_one(embed_ind, ps, "h *")
742 |
743 | if self.training:
744 | unpacked_onehot = unpack_one(embed_onehot, ps, "h * c")
745 | quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed)
746 | else:
747 | quantize = batched_embedding(embed_ind, embed)
748 |
749 | if self.training and self.ema_update and not freeze_codebook:
750 | if exists(mask):
751 | embed_onehot[~mask] = 0.0
752 |
753 | bins = embed_onehot.sum(dim=1)
754 | self.all_reduce_fn(bins)
755 |
756 | ema_inplace(self.cluster_size.data, bins, self.decay)
757 |
758 | embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
759 | embed_sum = embed_sum.contiguous()
760 | self.all_reduce_fn(embed_sum)
761 |
762 | ema_inplace(self.embed_avg.data, embed_sum, self.decay)
763 |
764 | cluster_size = laplace_smoothing(
765 | self.cluster_size, self.codebook_size, self.eps
766 | ) * self.cluster_size.sum(dim=-1, keepdim=True)
767 |
768 | embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
769 | embed_normalized = l2norm(embed_normalized)
770 |
771 | self.embed.data.copy_(l2norm(embed_normalized))
772 | self.expire_codes_(x)
773 |
774 | if needs_codebook_dim:
775 | quantize, embed_ind = map(
776 | lambda t: rearrange(t, "1 ... -> ..."), (quantize, embed_ind)
777 | )
778 |
779 | dist = unpack_one(dist, ps, "h * d")
780 | return quantize, embed_ind, dist
781 |
782 |
783 | # main class
784 |
785 |
786 | class VectorQuantize(nn.Module):
787 | def __init__(
788 | self,
789 | dim,
790 | codebook_size,
791 | codebook_dim=None,
792 | heads=1,
793 | separate_codebook_per_head=False,
794 | decay=0.8,
795 | eps=1e-5,
796 | freeze_codebook=False,
797 | kmeans_init=False,
798 | kmeans_iters=10,
799 | sync_kmeans=True,
800 | use_cosine_sim=False,
801 | threshold_ema_dead_code=0,
802 | channel_last=True,
803 | accept_image_fmap=False,
804 | commitment_weight=1.0,
805 | commitment_use_cross_entropy_loss=False,
806 | orthogonal_reg_weight=0.0,
807 | orthogonal_reg_active_codes_only=False,
808 | orthogonal_reg_max_codes=None,
809 | stochastic_sample_codes=False,
810 | sample_codebook_temp=1.0,
811 | straight_through=False,
812 | reinmax=False, # using reinmax for improved straight-through, assuming straight through helps at all
813 | sync_codebook=None,
814 | sync_affine_param=False,
815 | ema_update=True,
816 | learnable_codebook=False,
817 | in_place_codebook_optimizer: Callable[
818 | ..., Optimizer
819 | ] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
820 | affine_param=False,
821 | affine_param_batch_decay=0.99,
822 | affine_param_codebook_decay=0.9,
823 | sync_update_v=0.0, # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
824 | ):
825 | super().__init__()
826 | self.dim = dim
827 | self.heads = heads
828 | self.separate_codebook_per_head = separate_codebook_per_head
829 |
830 | codebook_dim = default(codebook_dim, dim)
831 | codebook_input_dim = codebook_dim * heads
832 |
833 | requires_projection = codebook_input_dim != dim
834 | self.project_in = (
835 | nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
836 | )
837 | self.project_out = (
838 | nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
839 | )
840 |
841 | self.has_projections = requires_projection
842 |
843 | self.eps = eps
844 | self.commitment_weight = commitment_weight
845 | self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss
846 |
847 | self.learnable_codebook = learnable_codebook
848 |
849 | has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
850 | self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
851 | self.orthogonal_reg_weight = orthogonal_reg_weight
852 | self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
853 | self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
854 |
855 | assert not (
856 | ema_update and learnable_codebook
857 | ), "learnable codebook not compatible with EMA update"
858 |
859 | assert 0 <= sync_update_v <= 1.0
860 | assert not (
861 | sync_update_v > 0.0 and not learnable_codebook
862 | ), "learnable codebook must be turned on"
863 |
864 | self.sync_update_v = sync_update_v
865 |
866 | codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
867 |
868 | gumbel_sample_fn = partial(
869 | gumbel_sample,
870 | stochastic=stochastic_sample_codes,
871 | reinmax=reinmax,
872 | straight_through=straight_through,
873 | )
874 |
875 | if not exists(sync_codebook):
876 | sync_codebook = (
877 | distributed.is_initialized() and distributed.get_world_size() > 1
878 | )
879 |
880 | codebook_kwargs = dict(
881 | dim=codebook_dim,
882 | num_codebooks=heads if separate_codebook_per_head else 1,
883 | codebook_size=codebook_size,
884 | kmeans_init=kmeans_init,
885 | kmeans_iters=kmeans_iters,
886 | sync_kmeans=sync_kmeans,
887 | decay=decay,
888 | eps=eps,
889 | threshold_ema_dead_code=threshold_ema_dead_code,
890 | use_ddp=sync_codebook,
891 | learnable_codebook=has_codebook_orthogonal_loss or learnable_codebook,
892 | sample_codebook_temp=sample_codebook_temp,
893 | gumbel_sample=gumbel_sample_fn,
894 | ema_update=ema_update,
895 | )
896 |
897 | if affine_param:
898 | assert (
899 | not use_cosine_sim
900 | ), "affine param is only compatible with euclidean codebook"
901 | codebook_kwargs = dict(
902 | **codebook_kwargs,
903 | affine_param=True,
904 | sync_affine_param=sync_affine_param,
905 | affine_param_batch_decay=affine_param_batch_decay,
906 | affine_param_codebook_decay=affine_param_codebook_decay,
907 | )
908 |
909 | self._codebook = codebook_class(**codebook_kwargs)
910 |
911 | self.in_place_codebook_optimizer = (
912 | in_place_codebook_optimizer(self._codebook.parameters())
913 | if exists(in_place_codebook_optimizer)
914 | else None
915 | )
916 |
917 | self.codebook_size = codebook_size
918 | self.register_buffer("codebook_usage", torch.zeros(codebook_size))
919 | self.call_cnt = 0
920 |
921 | self.accept_image_fmap = accept_image_fmap
922 | self.channel_last = channel_last
923 |
924 | @property
925 | def codebook(self):
926 | codebook = self._codebook.embed
927 |
928 | if self.separate_codebook_per_head:
929 | return codebook
930 |
931 | return rearrange(codebook, "1 ... -> ...")
932 |
933 | @codebook.setter
934 | def codebook(self, codes):
935 | if not self.separate_codebook_per_head:
936 | codes = rearrange(codes, "... -> 1 ...")
937 |
938 | self._codebook.embed.copy_(codes)
939 |
940 | def get_codes_from_indices(self, indices):
941 | codebook = self.codebook
942 | is_multiheaded = codebook.ndim > 2
943 |
944 | if not is_multiheaded:
945 | codes = codebook[indices]
946 | return rearrange(codes, "... h d -> ... (h d)")
947 |
948 | indices, ps = pack_one(indices, "b * h")
949 | indices = rearrange(indices, "b n h -> b h n")
950 |
951 | indices = repeat(indices, "b h n -> b h n d", d=codebook.shape[-1])
952 | codebook = repeat(codebook, "h n d -> b h n d", b=indices.shape[0])
953 |
954 | codes = codebook.gather(2, indices)
955 | codes = rearrange(codes, "b h n d -> b n (h d)")
956 | codes = unpack_one(codes, ps, "b * d")
957 | return codes
958 |
959 | def get_output_from_indices(self, indices):
960 | codes = self.get_codes_from_indices(indices)
961 | return self.project_out(codes)
962 |
963 | def get_perplexity(self, encoding_indices, x):
964 | encode_onehot = F.one_hot(encoding_indices, self.codebook_size).type_as(
965 | x
966 | ) # [bthw, ncode]
967 | avg_probs = torch.mean(encode_onehot, dim=0)
968 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
969 | return perplexity
970 |
971 | def get_usage(self, encoding_indices):
972 | # Flatten the batch of encoding indices into a single 1D tensor
973 | all_indices = encoding_indices.flatten()
974 |
975 | # Obtain the total number of encoding indices in the batch to calculate percentages
976 | total_indices = all_indices.numel()
977 |
978 | # Initialize a tensor to store the percentage usage of each code
979 | codebook_usage_percentage = torch.zeros(
980 | self.codebook_size, device=all_indices.device
981 | )
982 |
983 | # Count the number of occurrences of each index and get their frequency as percentages
984 | unique_indices, counts = torch.unique(all_indices, return_counts=True)
985 |
986 | # Calculate the percentage
987 | percentages = counts.float() / total_indices
988 |
989 | # Populate the corresponding percentages in the codebook_usage_percentage tensor
990 | codebook_usage_percentage[unique_indices.long()] = percentages
991 |
992 | return codebook_usage_percentage
993 |
994 | def forward(
995 | self,
996 | x,
997 | indices=None,
998 | mask=None,
999 | sample_codebook_temp=None,
1000 | freeze_codebook=False,
1001 | ):
1002 | orig_input = x
1003 |
1004 | only_one = x.ndim == 2
1005 |
1006 | if only_one:
1007 | assert not exists(mask)
1008 | x = rearrange(x, "b d -> b 1 d")
1009 |
1010 | shape, device, heads, is_multiheaded, codebook_size, return_loss = (
1011 | x.shape,
1012 | x.device,
1013 | self.heads,
1014 | self.heads > 1,
1015 | self.codebook_size,
1016 | exists(indices),
1017 | )
1018 |
1019 | need_transpose = not self.channel_last and not self.accept_image_fmap
1020 | should_inplace_optimize = exists(self.in_place_codebook_optimizer)
1021 |
1022 | # rearrange inputs
1023 |
1024 | if self.accept_image_fmap:
1025 | nframes, height, width = x.shape[-3:]
1026 | x = rearrange(x, "b c t h w -> b (t h w) c")
1027 |
1028 | if need_transpose:
1029 | x = rearrange(x, "b d n -> b n d")
1030 |
1031 | # project input
1032 |
1033 | x = self.project_in(x)
1034 |
1035 | # handle multi-headed separate codebooks
1036 |
1037 | if is_multiheaded:
1038 | ein_rhs_eq = "h b n d" if self.separate_codebook_per_head else "1 (b h) n d"
1039 | x = rearrange(x, f"b n (h d) -> {ein_rhs_eq}", h=heads)
1040 |
1041 | # l2norm for cosine sim, otherwise identity
1042 |
1043 | x = self._codebook.transform_input(x)
1044 |
1045 | # codebook forward kwargs
1046 |
1047 | codebook_forward_kwargs = dict(
1048 | sample_codebook_temp=sample_codebook_temp,
1049 | mask=mask,
1050 | freeze_codebook=freeze_codebook,
1051 | )
1052 |
1053 | # quantize
1054 |
1055 | quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
1056 |
1057 | # one step in-place update
1058 |
1059 | if should_inplace_optimize and self.training and not freeze_codebook:
1060 |
1061 | if exists(mask):
1062 | loss = F.mse_loss(quantize, x.detach(), reduction="none")
1063 |
1064 | loss_mask = mask
1065 | if is_multiheaded:
1066 | loss_mask = repeat(
1067 | mask,
1068 | "b n -> c (b h) n",
1069 | c=loss.shape[0],
1070 | h=loss.shape[1] // mask.shape[0],
1071 | )
1072 |
1073 | loss = loss[loss_mask].mean()
1074 |
1075 | else:
1076 | loss = F.mse_loss(quantize, x.detach())
1077 |
1078 | loss.backward()
1079 | self.in_place_codebook_optimizer.step()
1080 | self.in_place_codebook_optimizer.zero_grad()
1081 |
1082 | # quantize again
1083 |
1084 | quantize, embed_ind, distances = self._codebook(
1085 | x, **codebook_forward_kwargs
1086 | )
1087 |
1088 | if self.training:
1089 | # determine code to use for commitment loss
1090 | maybe_detach = (
1091 | torch.detach
1092 | if not self.learnable_codebook or freeze_codebook
1093 | else identity
1094 | )
1095 |
1096 | commit_quantize = maybe_detach(quantize)
1097 |
1098 | # straight through
1099 |
1100 | quantize = x + (quantize - x).detach()
1101 |
1102 | if self.sync_update_v > 0.0:
1103 | # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
1104 | quantize = quantize + self.sync_update_v * (
1105 | quantize - quantize.detach()
1106 | )
1107 |
1108 | # function for calculating cross entropy loss to distance matrix
1109 | # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
1110 |
1111 | def calculate_ce_loss(codes):
1112 | if not is_multiheaded:
1113 | dist_einops_eq = "1 b n l -> b l n"
1114 | elif self.separate_codebook_per_head:
1115 | dist_einops_eq = "c b n l -> b l n c"
1116 | else:
1117 | dist_einops_eq = "1 (b h) n l -> b l n h"
1118 |
1119 | ce_loss = F.cross_entropy(
1120 | rearrange(distances, dist_einops_eq, b=shape[0]), codes, ignore_index=-1
1121 | )
1122 |
1123 | return ce_loss
1124 |
1125 | # if returning cross entropy loss on codes that were passed in
1126 |
1127 | if return_loss:
1128 |             # print(indices)
1129 | return quantize, calculate_ce_loss(indices)
1130 |
1131 | # transform embedding indices
1132 |
1133 | if is_multiheaded:
1134 | if self.separate_codebook_per_head:
1135 | embed_ind = rearrange(embed_ind, "h b n -> b n h", h=heads)
1136 | else:
1137 | embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
1138 |
1139 | if self.accept_image_fmap:
1140 | embed_ind = rearrange(
1141 | embed_ind, "b (t h w) ... -> b t h w ...", t=nframes, h=height, w=width
1142 | )
1143 |
1144 | if only_one:
1145 | embed_ind = rearrange(embed_ind, "b 1 ... -> b ...")
1146 |
1147 | # aggregate loss
1148 |
1149 | loss = torch.tensor([0.0], device=device, requires_grad=self.training)
1150 |
1151 | if self.training:
1152 | if self.commitment_weight > 0:
1153 | if self.commitment_use_cross_entropy_loss:
1154 | if exists(mask):
1155 | ce_loss_mask = mask
1156 | if is_multiheaded:
1157 | ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads)
1158 |
1159 | embed_ind.masked_fill_(~ce_loss_mask, -1)
1160 |
1161 |                     # print(embed_ind.shape, embed_ind)
1162 | commit_loss = calculate_ce_loss(embed_ind)
1163 | else:
1164 | if exists(mask):
1165 | # with variable lengthed sequences
1166 | commit_loss = F.mse_loss(commit_quantize, x, reduction="none")
1167 |
1168 | loss_mask = mask
1169 | if is_multiheaded:
1170 | loss_mask = repeat(
1171 | loss_mask,
1172 | "b n -> c (b h) n",
1173 | c=commit_loss.shape[0],
1174 | h=commit_loss.shape[1] // mask.shape[0],
1175 | )
1176 |
1177 | commit_loss = commit_loss[loss_mask].mean()
1178 | else:
1179 | commit_loss = F.mse_loss(commit_quantize, x)
1180 |
1181 | loss = loss + commit_loss * self.commitment_weight
1182 |
1183 | if self.has_codebook_orthogonal_loss:
1184 | codebook = self._codebook.embed
1185 |
1186 | # only calculate orthogonal loss for the activated codes for this batch
1187 |
1188 | if self.orthogonal_reg_active_codes_only:
1189 | assert not (
1190 | is_multiheaded and self.separate_codebook_per_head
1191 | ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
1192 | unique_code_ids = torch.unique(embed_ind)
1193 | codebook = codebook[:, unique_code_ids]
1194 |
1195 | num_codes = codebook.shape[-2]
1196 |
1197 | if (
1198 | exists(self.orthogonal_reg_max_codes)
1199 | and num_codes > self.orthogonal_reg_max_codes
1200 | ):
1201 | rand_ids = torch.randperm(num_codes, device=device)[
1202 | : self.orthogonal_reg_max_codes
1203 | ]
1204 | codebook = codebook[:, rand_ids]
1205 |
1206 | orthogonal_reg_loss = orthogonal_loss_fn(codebook)
1207 | loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
1208 |
1209 | # handle multi-headed quantized embeddings
1210 |
1211 | if is_multiheaded:
1212 | if self.separate_codebook_per_head:
1213 | quantize = rearrange(quantize, "h b n d -> b n (h d)", h=heads)
1214 | else:
1215 | quantize = rearrange(quantize, "1 (b h) n d -> b n (h d)", h=heads)
1216 |
1217 | # project out
1218 |
1219 | quantize = self.project_out(quantize)
1220 |
1221 | # rearrange quantized embeddings
1222 |
1223 | if need_transpose:
1224 | quantize = rearrange(quantize, "b n d -> b d n")
1225 |
1226 | if self.accept_image_fmap:
1227 | quantize = rearrange(
1228 | quantize, "b (t h w) c -> b c t h w", t=nframes, h=height, w=width
1229 | )
1230 |
1231 | if only_one:
1232 | quantize = rearrange(quantize, "b 1 d -> b d")
1233 |
1234 | # if masking, only return quantized for where mask has True
1235 |
1236 | if exists(mask):
1237 | quantize = torch.where(
1238 | rearrange(mask, "... -> ... 1"), quantize, orig_input
1239 | )
1240 |
1241 | # return quantize, embed_ind, loss
1242 | perplexity = self.get_perplexity(embed_ind, x)
1243 | usage = self.get_usage(embed_ind)
1244 |
1245 | if self.call_cnt == 0:
1246 | self.codebook_usage.data = usage
1247 | else:
1248 | self.codebook_usage.data = (
1249 | 0.99 * self.codebook_usage.data + (1 - 0.99) * usage
1250 | )
1251 |
1252 | self.call_cnt += 1
1253 | # avg_distribution = self.codebook_usage.data.sum() / self.codebook_size
1254 | avg_usage = (
1255 | self.codebook_usage.data > (1 / self.codebook_size)
1256 | ).sum() / self.codebook_size
1257 |
1258 | return dict(
1259 | embeddings=quantize,
1260 | encodings=embed_ind,
1261 | commitment_loss=loss,
1262 | perplexity=perplexity,
1263 | avg_usage=avg_usage,
1264 | batch_usage=usage,
1265 | )
1266 |
--------------------------------------------------------------------------------
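
A minimal sketch of calling the VectorQuantize wrapper above on a channel-last sequence of vectors; the shapes are illustrative. Unlike the upstream lucidrains module, this forward pass returns a dict, with usage statistics accumulated across calls in the codebook_usage buffer:

import torch
from imagetokenizer.quantize.vector_quantize import VectorQuantize

vq = VectorQuantize(dim=64, codebook_size=512)   # EMA-updated euclidean codebook by default
x = torch.randn(2, 100, 64)                      # (batch, sequence, dim), channel_last=True
out = vq(x)                                      # in training mode this also runs the EMA codebook update
print(out["embeddings"].shape)                   # (2, 100, 64) quantized vectors (straight-through)
print(out["encodings"].shape)                    # (2, 100) codebook indices
print(out["commitment_loss"], out["perplexity"], out["avg_usage"])
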
/imagetokenizer/utils/omnitokenizer_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 |
6 | # Shifts src_dim to dest_dim
7 | # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c)
8 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
9 | n_dims = len(x.shape)
10 | if src_dim < 0:
11 | src_dim = n_dims + src_dim
12 | if dest_dim < 0:
13 | dest_dim = n_dims + dest_dim
14 |
15 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
16 |
17 | dims = list(range(n_dims))
18 | del dims[src_dim]
19 |
20 | permutation = []
21 | ctr = 0
22 | for i in range(n_dims):
23 | if i == dest_dim:
24 | permutation.append(src_dim)
25 | else:
26 | permutation.append(dims[ctr])
27 | ctr += 1
28 | x = x.permute(permutation)
29 | if make_contiguous:
30 | x = x.contiguous()
31 | return x
32 |
33 |
34 | def Normalize(in_channels, norm_type="group"):
35 | assert norm_type in ["group", "batch"]
36 | if norm_type == "group":
37 | return torch.nn.GroupNorm(
38 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
39 | )
40 | elif norm_type == "batch":
41 | return torch.nn.SyncBatchNorm(in_channels)
42 |
43 |
44 | def logits_laplace(x, x_recons, logit_laplace_eps=0.1):
45 | # [-0.5, 0.5] -> [0, 1]
46 | x = x + 0.5
47 | x_recons = x_recons + 0.5
48 | # [0, 1] -> [eps, 1-eps]
49 | x_laplace = (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps
50 | x_recons_laplace = (1 - 2 * logit_laplace_eps) * x_recons + logit_laplace_eps
51 | return F.l1_loss(x_laplace, x_recons_laplace)
52 |
53 |
54 | def divisible_by(numer, denom):
55 | return (numer % denom) == 0
56 |
57 |
58 | def pair(val):
59 | ret = (val, val) if not isinstance(val, tuple) else val
60 | assert len(ret) == 2
61 | return ret
62 |
63 |
64 | def silu(x):
65 | return x * torch.sigmoid(x)
66 |
67 |
68 | class SiLU(nn.Module):
69 | def __init__(self):
70 | super(SiLU, self).__init__()
71 |
72 | def forward(self, x):
73 | return silu(x)
74 |
--------------------------------------------------------------------------------
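
The shift_dim helper above moves a single dimension of a tensor; a quick check of the documented
(b, c, t, h, w) -> (b, t, h, w, c) case, assuming the package is importable:

    import torch
    from imagetokenizer.utils.omnitokenizer_utils import shift_dim

    x = torch.zeros(2, 3, 4, 5, 6)      # (b, c, t, h, w)
    y = shift_dim(x, 1, -1)             # move the channel dim to the end
    assert y.shape == (2, 4, 5, 6, 3)   # (b, t, h, w, c)
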
/imagetokenizer/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Lucas Jin. All rights reserved.
2 | from datetime import datetime
3 |
4 | major_num = 2
5 |
6 | __version__ = "0.0.2"
7 | short_version = __version__
8 |
9 |
10 | def parse_version_info(version_str):
11 | version_info = []
12 | for x in version_str.split("."):
13 | if x.isdigit():
14 | version_info.append(int(x))
15 | elif x.find("rc") != -1:
16 | patch_version = x.split("rc")
17 | version_info.append(int(patch_version[0]))
18 | version_info.append(f"rc{patch_version[1]}")
19 | return tuple(version_info)
20 |
21 |
22 | version_info = parse_version_info(__version__)
23 |
--------------------------------------------------------------------------------
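
For reference, parse_version_info splits the version string on dots and keeps any
release-candidate suffix as a string; two illustrative values (the second is hypothetical,
not a published release):

    from imagetokenizer.version import parse_version_info

    assert parse_version_info("0.0.2") == (0, 0, 2)
    assert parse_version_info("1.0rc1") == (1, 0, "rc1")
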
/ps.sh:
--------------------------------------------------------------------------------
1 | # autopep8 -r ./imagetokenizer/ -i
2 |
3 | git add .
4 | git commit -am 'add'
5 | git push origin main
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright (c) 2020 JinTian.
4 | #
5 | # This file is part of alfred
6 | # (see http://jinfagang.github.io).
7 | #
8 | # Licensed to the Apache Software Foundation (ASF) under one
9 | # or more contributor license agreements. See the NOTICE file
10 | # distributed with this work for additional information
11 | # regarding copyright ownership. The ASF licenses this file
12 | # to you under the Apache License, Version 2.0 (the
13 | # "License"); you may not use this file except in compliance
14 | # with the License. You may obtain a copy of the License at
15 | #
16 | # http://www.apache.org/licenses/LICENSE-2.0
17 | #
18 | # Unless required by applicable law or agreed to in writing,
19 | # software distributed under the License is distributed on an
20 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
21 | # KIND, either express or implied. See the License for the
22 | # specific language governing permissions and limitations
23 | # under the License.
24 | #
25 | """
26 | install alfred into local bin dir.
27 | """
28 | from setuptools import setup, find_packages
29 | from setuptools import setup, Extension
30 | import io
31 | from os import path
32 |
33 | this_directory = path.abspath(path.dirname(__file__))
34 | with io.open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
35 | long_description = f.read()
36 |
37 |
38 | version_file = "imagetokenizer/version.py"
39 |
40 |
41 | def get_version():
42 | with open(version_file, "r") as f:
43 | exec(compile(f.read(), version_file, "exec"))
44 | return locals()["__version__"]
45 |
46 |
47 | setup(
48 | name="imagetokenizer",
49 | version=get_version(),
50 | keywords=["deep learning", "script helper", "tools"],
51 | description="Image Tokenizer: encode visuals into discrete tokens.",
52 | long_description=long_description,
53 | long_description_content_type="text/markdown",
54 | license="GPL-3.0",
55 | classifiers=[
56 | # Operating system
57 | "Operating System :: OS Independent",
58 | # How mature is this project? Common values are
59 | # 3 - Alpha
60 | # 4 - Beta
61 | # 5 - Production/Stable
62 | "Development Status :: 4 - Beta",
63 | # Indicate who your project is intended for
64 | "Intended Audience :: Developers",
65 | # Topics
66 | "Topic :: Education",
67 | "Topic :: Scientific/Engineering",
68 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
69 | "Topic :: Scientific/Engineering :: Image Recognition",
70 | # Pick your license as you wish
71 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
72 | # Specify the Python versions you support here. In particular, ensure
73 | # that you indicate whether you support Python 2, Python 3 or both.
74 | "Programming Language :: Python :: 3",
75 | "Programming Language :: Python :: 3.6",
76 | "Programming Language :: Python :: 3.7",
77 | "Programming Language :: Python :: 3.8",
78 | "Programming Language :: Python :: 3.9",
79 | ],
80 | packages=find_packages(),
81 | # entry_points={"console_scripts": ["alfred = alfred.alfred:main"]},
82 | include_package_data=True,
83 | author="Lucas Jin",
84 | author_email="jinfagang19@163.com",
85 | url="https://github.com/lucasjinreal/ImageTokenizer",
86 | platforms="any",
87 | install_requires=["beartype", "torch", "einops"],
88 | )
89 |
--------------------------------------------------------------------------------
/test_image_tokenizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Sending an image, encode it in a [1, 16, h, w] token
3 | then decode it back to original image
4 | """
5 |
6 | """
7 | We provide Tokenizer Inference code here.
8 | """
9 | import os
10 | import sys
11 | import torch
12 | import importlib
13 | import numpy as np
14 | from PIL import Image
15 | import argparse
16 | import torchvision.transforms as T
17 | from imagetokenizer.model import Magvit2Tokenizer, OmniTokenizer, TiTok
18 |
19 |
20 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21 |
22 |
23 | def load_vqgan_new(num_down, ckpt_path=None, is_gumbel=False):
24 | if "magvit2" in ckpt_path.lower():
25 | model = Magvit2Tokenizer(num_down=num_down, use_ema=True)
26 | if ckpt_path is not None:
27 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
28 | missing, unexpected = model.load_state_dict(sd, strict=False)
29 | elif "omni" in ckpt_path.lower():
30 | model = OmniTokenizer()
31 | if ckpt_path is not None:
32 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
33 | missing, unexpected = model.load_state_dict(sd, strict=False)
34 | elif "titok" in ckpt_path.lower():
35 | model = TiTok()
36 | if ckpt_path is not None:
37 | model.load_weights(ckpt_path)
38 | return model.eval()
39 |
40 |
41 | def get_obj_from_str(string, reload=False):
42 | print(string)
43 | module, cls = string.rsplit(".", 1)
44 | if reload:
45 | module_imp = importlib.import_module(module)
46 | importlib.reload(module_imp)
47 | return getattr(importlib.import_module(module, package=None), cls)
48 |
49 |
50 | def instantiate_from_config(config):
51 | if "class_path" not in config:
52 | raise KeyError("Expected key `class_path` to instantiate.")
53 | return get_obj_from_str(config["class_path"])(**config.get("init_args", dict()))
54 |
55 |
56 | def custom_to_pil(x):
57 | x = x.detach().cpu()
58 | x = torch.clamp(x, -1.0, 1.0)
59 | x = (x + 1.0) / 2.0
60 | x = x.permute(1, 2, 0).numpy()
61 | x = (255 * x).astype(np.uint8)
62 | x = Image.fromarray(x)
63 | if x.mode != "RGB":
64 | x = x.convert("RGB")
65 | return x
66 |
67 |
68 | def get_image_tensor_for_encoder(image):
69 | image = image / 127.5 - 1.0
70 | image = T.ToTensor()(image).unsqueeze(0)
71 | # resize the image to the closest multiple-of-8 size
72 | height, width = image.shape[2], image.shape[3]
73 | new_height = ((height + 7) // 8) * 8
74 | new_width = ((width + 7) // 8) * 8  # resize the image
75 | image = torch.nn.functional.interpolate(
76 | image, size=(new_height, new_width), mode="bilinear", align_corners=False
77 | )
78 | return image
79 |
80 |
81 | def main(args):
82 | model = load_vqgan_new(args.num_down, args.ckpt_path).to(DEVICE)
83 |
84 | visualize_dir = "results/"
85 | visualize_version = "v0"
86 | visualize_original = os.path.join(
87 | visualize_dir, visualize_version, "original_{}".format(args.num_down)
88 | )
89 | visualize_rec = os.path.join(
90 | visualize_dir, visualize_version, "rec_{}".format(args.num_down)
91 | )
92 | if not os.path.exists(visualize_original):
93 | os.makedirs(visualize_original, exist_ok=True)
94 |
95 | if not os.path.exists(visualize_rec):
96 | os.makedirs(visualize_rec, exist_ok=True)
97 |
98 | img_f = args.image_file
99 | idx = os.path.basename(img_f)[:-4] + "_constructed"
100 | image_raw = Image.open(img_f)
101 | image = np.array(image_raw)
102 | print(f"original image size: {image.shape}")
103 | images_tensor = get_image_tensor_for_encoder(image)
104 | images_tensor = images_tensor.float().to(DEVICE)
105 | print(f"images: {images_tensor.shape}")
106 |
107 | quant, embedding, codebook_indices = model.encode(images_tensor)
108 | print(f"quant: {quant.shape}")
109 | print(f"embedding: {embedding.shape}")
110 | print(f"codebook_indices: {codebook_indices.shape}")
111 | reconstructed_images = model.decode(quant)
112 |
113 | image = images_tensor[0]
114 | reconstructed_image = reconstructed_images[0]
115 |
116 | image = custom_to_pil(image)
117 | reconstructed_image = custom_to_pil(reconstructed_image)
118 | reconstructed_image = reconstructed_image.resize((image_raw.width, image_raw.height))
119 |
120 | image.save(os.path.join(visualize_original, "{}.png".format(idx)))
121 | reconstructed_image.save(os.path.join(visualize_rec, "{}.png".format(idx)))
122 |
123 |
124 | def get_args():
125 | parser = argparse.ArgumentParser(description="inference parameters")
126 | parser.add_argument("--ckpt_path", required=True, type=str)
127 | parser.add_argument("--num_down", default=3, type=int)
128 | parser.add_argument("--batch_size", default=1, type=int)
129 | parser.add_argument("--image_file", default="images/a.jpg", type=str)
130 | parser.add_argument("--subset", default=None)
131 | parser.add_argument("--tokenizer", default="magvit2")
132 |
133 | return parser.parse_args()
134 |
135 |
136 | if __name__ == "__main__":
137 | args = get_args()
138 | main(args)
139 |
--------------------------------------------------------------------------------
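
The script above can also be driven programmatically instead of through argparse; a minimal
sketch reusing its helpers, assuming a local Magvit2 checkpoint (the checkpoints/magvit2.ckpt
path is hypothetical) and the same encode/decode interface already used in main():

    import numpy as np
    import torch
    from PIL import Image
    from test_image_tokenizer import (
        load_vqgan_new,
        get_image_tensor_for_encoder,
        custom_to_pil,
    )

    model = load_vqgan_new(num_down=3, ckpt_path="checkpoints/magvit2.ckpt")
    image = np.array(Image.open("images/a.jpg"))
    x = get_image_tensor_for_encoder(image).float()

    with torch.no_grad():
        quant, embedding, codebook_indices = model.encode(x)
        recon = model.decode(quant)

    custom_to_pil(recon[0]).save("reconstruction.png")
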
/upload_pypi.sh:
--------------------------------------------------------------------------------
1 | ##
2 | ## Copyright (c) 2020 JinTian.
3 | ##
4 | ## This file is part of alfred
5 | ## (see http://jinfagang.github.io).
6 | ##
7 | ## Licensed to the Apache Software Foundation (ASF) under one
8 | ## or more contributor license agreements. See the NOTICE file
9 | ## distributed with this work for additional information
10 | ## regarding copyright ownership. The ASF licenses this file
11 | ## to you under the Apache License, Version 2.0 (the
12 | ## "License"); you may not use this file except in compliance
13 | ## with the License. You may obtain a copy of the License at
14 | ##
15 | ## http://www.apache.org/licenses/LICENSE-2.0
16 | ##
17 | ## Unless required by applicable law or agreed to in writing,
18 | ## software distributed under the License is distributed on an
19 | ## "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
20 | ## KIND, either express or implied. See the License for the
21 | ## specific language governing permissions and limitations
22 | ## under the License.
23 | ##
24 | # check setup is correct or not
25 | python3 setup.py check
26 |
27 | # bumpver update --patch
28 |
29 | sudo rm -r build/
30 | sudo rm -r dist/
31 |
32 | # the old "setup.py upload" interface is no longer accepted by PyPI
33 | # python3 setup.py sdist
34 | # python3 setup.py sdist upload -r pypi
35 |
36 | # using twine instead
37 | python3 setup.py sdist
38 | twine upload dist/*
39 |
40 |
--------------------------------------------------------------------------------