├── .dir-locals.el
├── .flake8
├── .gitignore
├── LICENSE
├── README.md
├── _imgs
│   ├── cub.png
│   ├── iwae.png
│   ├── mnist-svhn.png
│   ├── obj.png
│   ├── schematic.png
│   └── simple.png
├── bin
│   └── make-mnist-svhn-idx.py
├── data
│   └── cub
│       ├── text_testclasses.txt
│       └── text_trainvalclasses.txt
├── requirements.txt
└── src
    ├── datasets.py
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── mmvae.py
    │   ├── mmvae_cub_images_sentences.py
    │   ├── mmvae_cub_images_sentences_ft.py
    │   ├── mmvae_mnist_svhn.py
    │   ├── vae.py
    │   ├── vae_cub_image.py
    │   ├── vae_cub_image_ft.py
    │   ├── vae_cub_sent.py
    │   ├── vae_cub_sent_ft.py
    │   ├── vae_mnist.py
    │   └── vae_svhn.py
    ├── objectives.py
    ├── report
    │   ├── analyse_cub.py
    │   ├── analyse_ms.py
    │   ├── calculate_likelihoods.py
    │   └── helper.py
    ├── utils.py
    └── vis.py
/.dir-locals.el:
--------------------------------------------------------------------------------
1 | ((lua-mode . ((lua-indent-level . 2)))
2 | (python-mode . ((tab-width . 2))))
3 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | # https://github.com/pytorch/pytorch/blob/d0db624e02951c4dd6eb6b21d051f7ccf8133707/setup.cfg
2 | [flake8]
3 | max-line-length = 120
4 | ignore = E302,E305,E402,E721,E731,F401,F403,F405,F811,F812,F821,F841,W503
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .*
2 | **/*~
3 | **/_*
4 | **/auto
5 | **/*.aux
6 | **/*.bbl
7 | **/*.blg
8 | **/*.log
9 | **/*.pdf
10 | **/*.out
11 | **/*.old
12 | **/*.run.xml
13 | **/images/
14 | *.pyc
15 | **/__pycache__
16 | !__init__.py
17 | !_imgs
18 |
19 | data/
20 | experiments/**/
21 | /.bash_history
22 |
23 | bin/*.sh
24 | bin/*.png
25 | bin/face_extract_vgg/
26 |
27 | doc/
28 |
29 |
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Multimodal Mixture-of-Experts VAE
4 | This repository contains the code for the framework in **Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models** (see [paper](https://arxiv.org/pdf/1911.03393.pdf)).
5 |
6 | ## Requirements
7 | List of packages we used and the versions we tested the model with (see also `requirements.txt`):
8 |
9 | ```
10 | python == 3.6.8
11 | gensim == 3.8.1
12 | matplotlib == 3.1.1
13 | nltk == 3.4.5
14 | numpy == 1.16.4
15 | pandas == 0.25.3
16 | scipy == 1.3.2
17 | seaborn == 0.9.0
18 | scikit-image == 0.15.0
19 | torch == 1.3.1
20 | torchnet == 0.0.4
21 | torchvision == 0.4.2
22 | umap-learn == 0.1.1
23 | ```
24 |
25 | ## Downloads
26 | ### MNIST-SVHN Dataset
27 |
28 | 
29 |
30 | We construct a dataset of pairs of MNIST and SVHN such that each pair depicts the same digit class. Each instance of a digit class in either dataset is randomly paired with 20 instances of the same digit class from the other dataset.
31 |
32 | **Usage**: To prepare this dataset, run `bin/make-mnist-svhn-idx.py` -- this should automatically handle the download and pairing.
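The script saves pairing indices (`train-ms-mnist-idx.pt`, `train-ms-svhn-idx.pt`, and their `test-` counterparts) under the `data` directory. Below is a minimal sketch, not the project's own loader, of how a matched pair can be assembled from those files; adjust the paths to wherever your `data` folder lives.

```python
# Sketch only: pair up one MNIST and one SVHN sample via the saved index files.
import torch
from torchvision import datasets, transforms

tx = transforms.ToTensor()
mnist = datasets.MNIST('data', train=True, download=True, transform=tx)
svhn = datasets.SVHN('data', split='train', download=True, transform=tx)

mnist_idx = torch.load('data/train-ms-mnist-idx.pt')
svhn_idx = torch.load('data/train-ms-svhn-idx.pt')

i = 0  # the i-th matched pair
m_img, m_lbl = mnist[int(mnist_idx[i])]
s_img, s_lbl = svhn[int(svhn_idx[i])]
assert int(m_lbl) == int(s_lbl) % 10  # both depict the same digit class
```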
33 |
34 | ### CUB Image-Caption
35 |
36 | 
37 |
38 | We use the Caltech-UCSD Birds (CUB) dataset, with the bird images and their captions serving as the two modalities.
39 |
40 | **Usage**: We offer a cleaned-up version of the CUB dataset. Download the dataset [here](http://www.robots.ox.ac.uk/~yshi/mmdgm/datasets/cub.zip). First, create a `data` folder under the project directory; then unzip the downloaded content into `data`. After finishing these steps, the structure of the `data/cub` folder should look like:
41 |
42 | ```
43 | data/cub
44 | │───text_testclasses.txt
45 | │───text_trainvalclasses.txt
46 | │───train
47 | │ │───002.Laysan_Albatross
48 | │ │ └───...jpg
49 | │ │───003.Sooty_Albatross
50 | │ │ └───...jpg
51 | │ │───...
52 | │ └───200.Common_Yellowthroat
53 | │ └───...jpg
54 | └───test
55 | │───001.Black_footed_Albatross
56 | │ └───...jpg
57 | │───004.Groove_billed_Ani
58 | │ └───...jpg
59 | │───...
60 | └───197.Marsh_Wren
61 | └───...jpg
62 | ```
63 |
64 |
65 | ### Pretrained network
66 | Pretrained models are also available if you want to play around with them. Download from the following links:
67 | - [MNIST-SVHN](http://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/mnist-svhn.zip)
68 | - [CUB Image-Caption (feature)](http://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/cubISft.zip)
69 | - [CUB Image-Caption (raw images)](http://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/cubIS.zip)
70 |
71 | ## Usage
72 |
73 | ### Training
74 |
75 | Make sure the [requirements](#requirements) are satisfied in your environment, and relevant [datasets](#downloads) are downloaded. `cd` into `src`, and, for MNIST-SVHN experiments, run
76 |
77 | ```bash
78 | python main.py --model mnist_svhn
79 |
80 | ```
81 |
82 | For CUB Image-Caption with image feature search (See Figure 7 in our [paper](https://arxiv.org/pdf/1911.03393.pdf)), run
83 | ```bash
84 | python main.py --model cubISft
85 |
86 | ```
87 |
88 | For CUB Image-Caption with raw image generation, run
89 | ```bash
90 | python main.py --model cubIS
91 |
92 | ```
93 |
94 | You can also play with the hyperparameters using arguments. Some of the more interesting ones are listed as follows:
95 | - **`--obj`**: Objective function; offers three choices: importance-sampled ELBO (`elbo`), IWAE (`iwae`) and DReG (`dreg`, used in the paper). Adding the `--looser` flag when using IWAE or DReG removes the unbalanced weighting of modalities, which we find to perform better empirically;
96 | - **`--K`**: Number of particles; controls the number of particles `K` in the IWAE/DReG estimator, as specified in the following equation:
97 |
98 | 
99 |
100 | - **`--learn-prior`**: Prior variance learning; controls whether to learn the prior variance. Results in our paper are produced with this enabled. Omitting this argument from the command disables this option;
101 | - **`--llik_scaling`**: Likelihood scaling; specifies the likelihood scaling of one of the two modalities, so that the likelihoods of the two modalities contribute comparably to the lower bound (see the short sketch after this list). The default values are:
102 |   - _MNIST-SVHN_: MNIST scaling factor (32*32*3)/(28*28*1) = 3.92
103 |   - _CUB Image-Caption_: image scaling factor 32/(64*64*3) = 0.0026
104 | - **`--latent-dimension`**: Latent dimension
105 |
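For reference, these defaults are simply the ratio of data dimensionalities; a quick sanity check of the arithmetic (32×32×3 and 28×28×1 are the SVHN and MNIST input sizes, and 64×64×3 the CUB image size):

```python
# Default --llik_scaling values, reproduced from the dimensionality ratios above.
mnist_scaling = (32 * 32 * 3) / (28 * 28 * 1)  # ≈ 3.92
cub_image_scaling = 32 / (64 * 64 * 3)         # ≈ 0.0026
print(round(mnist_scaling, 2), round(cub_image_scaling, 4))  # 3.92 0.0026
```
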
106 | You can also load from pre-trained models by specifying the path to the model folder, for example `python main.py --model mnist_svhn --pre-trained path/to/model/folder/`. See the following for the flags we used for these pretrained models:
107 | - **MNIST-SVHN**: `--model mnist_svhn --obj dreg --K 30 --learn-prior --looser --epochs 30 --batch-size 128 --latent-dim 20`
108 | - **CUB Image-Caption (feature)**: `--model cubISft --learn-prior --K 50 --obj dreg --looser --epochs 50 --batch-size 64 --latent-dim 64 --llik_scaling 0.002`
109 | - **CUB Image-Caption (raw images)**: `--model cubIS --learn-prior --K 50 --obj dreg --looser --epochs 50 --batch-size 64 --latent-dim 64`
110 |
111 | ### Analysing
112 | We offer tools to reproduce the quantitative results in our paper in `src/report`. To run any of the provided scripts, `cd` into `src`, and
113 |
114 | - for likelihood estimation of data using a trained model, run `python calculate_likelihoods.py --save-dir path/to/trained/model/folder/ --iwae-samples 1000`;
115 | - for coherence analysis and latent digit classification accuracy on MNIST-SVHN dataset, run `python analyse_ms.py --save-dir path/to/trained/model/folder/`;
116 | - for coherence analysis on CUB image-caption dataset, run `python analyse_cub.py --save-dir path/to/trained/model/folder/`.
117 | - _**Note**_: The learnt CCA projection matrix and FastText embeddings can vary quite a bit due to the limited dataset size; therefore, re-computing them as part of the analyses can result in different numeric values, including for the baseline. **The relative performance of our model against the baseline remains the same; only the numbers can differ.**
118 | To produce results similar to those reported in our paper, download the zip file [here](http://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/CCA_emb.zip) and do the following:
119 | 1. Move `cub.all`, `cub.emb`, `cub.pc` to `data/cub/oc:3_sl:32_s:300_w:3/`;
120 | 2. Move the rest of the files, i.e. `emb_mean.pt`, `emb_proj.pt`, `images_mean.pt`, `im_proj.pt` to `path/to/trained/model/folder/`;
121 | 3. Set the `RESET` variable in `src/report/analyse_cub.py` to `False`.
122 |
123 |
124 | ## Contact
125 | If you have any questions, feel free to create an issue or email Yuge Shi at yshi@robots.ox.ac.uk.
126 |
--------------------------------------------------------------------------------
/_imgs/cub.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/cub.png
--------------------------------------------------------------------------------
/_imgs/iwae.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/iwae.png
--------------------------------------------------------------------------------
/_imgs/mnist-svhn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/mnist-svhn.png
--------------------------------------------------------------------------------
/_imgs/obj.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/obj.png
--------------------------------------------------------------------------------
/_imgs/schematic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/schematic.png
--------------------------------------------------------------------------------
/_imgs/simple.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/simple.png
--------------------------------------------------------------------------------
/bin/make-mnist-svhn-idx.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 |
4 | def rand_match_on_idx(l1, idx1, l2, idx2, max_d=10000, dm=10):
5 |     """Randomly pair same-class examples across two datasets.
6 |     l*: sorted labels; idx*: indices of sorted labels in the original list
7 |     max_d: cap on datapoints per class; dm: data multiplier (number of random pairings per datapoint)
8 |     """
9 | _idx1, _idx2 = [], []
10 | for l in l1.unique(): # assuming both have same idxs
11 | l_idx1, l_idx2 = idx1[l1 == l], idx2[l2 == l]
12 | n = min(l_idx1.size(0), l_idx2.size(0), max_d)
13 | l_idx1, l_idx2 = l_idx1[:n], l_idx2[:n]
14 | for _ in range(dm):
15 | _idx1.append(l_idx1[torch.randperm(n)])
16 | _idx2.append(l_idx2[torch.randperm(n)])
17 | return torch.cat(_idx1), torch.cat(_idx2)
18 |
19 | if __name__ == '__main__':
20 | max_d = 10000 # maximum number of datapoints per class
21 | dm = 30 # data multiplier: random permutations to match
22 |
23 | # get the individual datasets
24 | tx = transforms.ToTensor()
25 | train_mnist = datasets.MNIST('../data', train=True, download=True, transform=tx)
26 | test_mnist = datasets.MNIST('../data', train=False, download=True, transform=tx)
27 | train_svhn = datasets.SVHN('../data', split='train', download=True, transform=tx)
28 | test_svhn = datasets.SVHN('../data', split='test', download=True, transform=tx)
29 | # svhn labels need extra work
30 | train_svhn.labels = torch.LongTensor(train_svhn.labels.squeeze().astype(int)) % 10
31 | test_svhn.labels = torch.LongTensor(test_svhn.labels.squeeze().astype(int)) % 10
32 |
33 | mnist_l, mnist_li = train_mnist.targets.sort()
34 | svhn_l, svhn_li = train_svhn.labels.sort()
35 | idx1, idx2 = rand_match_on_idx(mnist_l, mnist_li, svhn_l, svhn_li, max_d=max_d, dm=dm)
36 | print('len train idx:', len(idx1), len(idx2))
37 | torch.save(idx1, '../data/train-ms-mnist-idx.pt')
38 | torch.save(idx2, '../data/train-ms-svhn-idx.pt')
39 |
40 | mnist_l, mnist_li = test_mnist.targets.sort()
41 | svhn_l, svhn_li = test_svhn.labels.sort()
42 | idx1, idx2 = rand_match_on_idx(mnist_l, mnist_li, svhn_l, svhn_li, max_d=max_d, dm=dm)
43 | print('len test idx:', len(idx1), len(idx2))
44 | torch.save(idx1, '../data/test-ms-mnist-idx.pt')
45 | torch.save(idx2, '../data/test-ms-svhn-idx.pt')
46 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | python == 3.6.8
2 | gensim == 3.8.1
3 | matplotlib == 3.1.1
4 | nltk == 3.4.5
5 | numpy == 1.16.4
6 | pandas == 0.25.3
7 | scipy == 1.3.2
8 | seaborn == 0.9.0
9 | scikit-image == 0.15.0
10 | torch == 1.3.1
11 | torchnet == 0.0.4
12 | torchvision == 0.4.2
13 | umap-learn == 0.1.1
14 |
--------------------------------------------------------------------------------
/src/datasets.py:
--------------------------------------------------------------------------------
1 | import io
2 | import json
3 | import os
4 | import pickle
5 | from collections import Counter, OrderedDict
6 | from collections import defaultdict
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | from nltk.tokenize import sent_tokenize, word_tokenize
12 | from torch.utils.data import Dataset
13 | from torchvision import transforms, models, datasets
14 |
15 |
16 | class OrderedCounter(Counter, OrderedDict):
17 | """Counter that remembers the order elements are first encountered."""
18 |
19 | def __repr__(self):
20 | return '%s(%r)' % (self.__class__.__name__, OrderedDict(self))
21 |
22 | def __reduce__(self):
23 | return self.__class__, (OrderedDict(self),)
24 |
25 |
26 | class CUBSentences(Dataset):
27 |
28 | def __init__(self, root_data_dir, split, transform=None, **kwargs):
29 |         """split: 'train' or 'test' """
30 |
31 | super().__init__()
32 | self.data_dir = os.path.join(root_data_dir, 'cub')
33 | self.split = split
34 | self.max_sequence_length = kwargs.get('max_sequence_length', 32)
35 | self.min_occ = kwargs.get('min_occ', 3)
36 | self.transform = transform
37 | os.makedirs(os.path.join(root_data_dir, "lang_emb"), exist_ok=True)
38 |
39 | self.gen_dir = os.path.join(self.data_dir, "oc:{}_msl:{}".
40 | format(self.min_occ, self.max_sequence_length))
41 |
42 | if split == 'train':
43 | self.raw_data_path = os.path.join(self.data_dir, 'text_trainvalclasses.txt')
44 | elif split == 'test':
45 | self.raw_data_path = os.path.join(self.data_dir, 'text_testclasses.txt')
46 | else:
47 | raise Exception("Only train or test split is available")
48 |
49 | os.makedirs(self.gen_dir, exist_ok=True)
50 | self.data_file = 'cub.{}.s{}'.format(split, self.max_sequence_length)
51 | self.vocab_file = 'cub.vocab'
52 |
53 | if not os.path.exists(os.path.join(self.gen_dir, self.data_file)):
54 | print("Data file not found for {} split at {}. Creating new... (this may take a while)".
55 | format(split.upper(), os.path.join(self.gen_dir, self.data_file)))
56 | self._create_data()
57 |
58 | else:
59 | self._load_data()
60 |
61 | def __len__(self):
62 | return len(self.data)
63 |
64 | def __getitem__(self, idx):
65 | sent = self.data[str(idx)]['idx']
66 | if self.transform is not None:
67 | sent = self.transform(sent)
68 | return sent, self.data[str(idx)]['length']
69 |
70 | @property
71 | def vocab_size(self):
72 | return len(self.w2i)
73 |
74 | @property
75 | def pad_idx(self):
76 |         return self.w2i['<pad>']
77 |
78 | @property
79 | def eos_idx(self):
80 |         return self.w2i['<eos>']
81 |
82 | @property
83 | def unk_idx(self):
84 |         return self.w2i['<unk>']
85 |
86 | def get_w2i(self):
87 | return self.w2i
88 |
89 | def get_i2w(self):
90 | return self.i2w
91 |
92 | def _load_data(self, vocab=True):
93 | with open(os.path.join(self.gen_dir, self.data_file), 'rb') as file:
94 | self.data = json.load(file)
95 |
96 | if vocab:
97 | self._load_vocab()
98 |
99 | def _load_vocab(self):
100 | if not os.path.exists(os.path.join(self.gen_dir, self.vocab_file)):
101 | self._create_vocab()
102 | with open(os.path.join(self.gen_dir, self.vocab_file), 'r') as vocab_file:
103 | vocab = json.load(vocab_file)
104 | self.w2i, self.i2w = vocab['w2i'], vocab['i2w']
105 |
106 | def _create_data(self):
107 | if self.split == 'train' and not os.path.exists(os.path.join(self.gen_dir, self.vocab_file)):
108 | self._create_vocab()
109 | else:
110 | self._load_vocab()
111 |
112 | with open(self.raw_data_path, 'r') as file:
113 | text = file.read()
114 | sentences = sent_tokenize(text)
115 |
116 | data = defaultdict(dict)
117 | pad_count = 0
118 |
119 | for i, line in enumerate(sentences):
120 | words = word_tokenize(line)
121 |
122 | tok = words[:self.max_sequence_length - 1]
123 |             tok = tok + ['<eos>']
124 | length = len(tok)
125 | if self.max_sequence_length > length:
126 |                 tok.extend(['<pad>'] * (self.max_sequence_length - length))
127 | pad_count += 1
128 |             idx = [self.w2i.get(w, self.w2i['<unk>']) for w in tok]
129 |
130 | id = len(data)
131 | data[id]['tok'] = tok
132 | data[id]['idx'] = idx
133 | data[id]['length'] = length
134 |
135 | print("{} out of {} sentences are truncated with max sentence length {}.".
136 | format(len(sentences) - pad_count, len(sentences), self.max_sequence_length))
137 | with io.open(os.path.join(self.gen_dir, self.data_file), 'wb') as data_file:
138 | data = json.dumps(data, ensure_ascii=False)
139 | data_file.write(data.encode('utf8', 'replace'))
140 |
141 | self._load_data(vocab=False)
142 |
143 | def _create_vocab(self):
144 |
145 |         assert self.split == 'train', "Vocabulary can only be created for the training file."
146 |
147 | with open(self.raw_data_path, 'r') as file:
148 | text = file.read()
149 | sentences = sent_tokenize(text)
150 |
151 | occ_register = OrderedCounter()
152 | w2i = dict()
153 | i2w = dict()
154 |
155 |         special_tokens = ['<pad>', '<eos>', '<unk>']  # padding, end-of-sentence and unknown-word tokens
156 | for st in special_tokens:
157 | i2w[len(w2i)] = st
158 | w2i[st] = len(w2i)
159 |
160 | texts = []
161 | unq_words = []
162 |
163 | for i, line in enumerate(sentences):
164 | words = word_tokenize(line)
165 | occ_register.update(words)
166 | texts.append(words)
167 |
168 | for w, occ in occ_register.items():
169 | if occ > self.min_occ and w not in special_tokens:
170 | i2w[len(w2i)] = w
171 | w2i[w] = len(w2i)
172 | else:
173 | unq_words.append(w)
174 |
175 | assert len(w2i) == len(i2w)
176 |
177 |         print("Vocabulary of {} keys created, {} words are excluded (occurrence <= {})."
178 | .format(len(w2i), len(unq_words), self.min_occ))
179 |
180 | vocab = dict(w2i=w2i, i2w=i2w)
181 | with io.open(os.path.join(self.gen_dir, self.vocab_file), 'wb') as vocab_file:
182 | data = json.dumps(vocab, ensure_ascii=False)
183 | vocab_file.write(data.encode('utf8', 'replace'))
184 |
185 | with open(os.path.join(self.gen_dir, 'cub.unique'), 'wb') as unq_file:
186 | pickle.dump(np.array(unq_words), unq_file)
187 |
188 | with open(os.path.join(self.gen_dir, 'cub.all'), 'wb') as a_file:
189 | pickle.dump(occ_register, a_file)
190 |
191 | self._load_vocab()
192 |
193 |
194 | class CUBImageFt(Dataset):
195 | def __init__(self, root_data_dir, split, device):
196 | """split: 'trainval' or 'test' """
197 |
198 | super().__init__()
199 | self.data_dir = os.path.join(root_data_dir, 'cub')
200 | self.data_file = os.path.join(self.data_dir, split)
201 | self.gen_dir = os.path.join(self.data_dir, 'resnet101_2048')
202 | self.gen_ft_file = os.path.join(self.gen_dir, '{}.ft'.format(split))
203 | self.gen_data_file = os.path.join(self.gen_dir, '{}.data'.format(split))
204 | self.split = split
205 |
206 | tx = transforms.Compose([
207 | transforms.Resize(224),
208 | transforms.ToTensor()
209 | ])
210 | self.dataset = datasets.ImageFolder(self.data_file, transform=tx)
211 |
212 | os.makedirs(self.gen_dir, exist_ok=True)
213 | if not os.path.exists(self.gen_ft_file):
214 | print("Data file not found for CUB image features at `{}`. "
215 | "Extracting resnet101 features from CUB image dataset... "
216 | "(this may take a while)".format(self.gen_ft_file))
217 | self._create_ft_mat(device)
218 |
219 | else:
220 | self._load_ft_mat()
221 |
222 | def __len__(self):
223 | return len(self.ft_mat)
224 |
225 | def __getitem__(self, idx):
226 | return self.ft_mat[idx]
227 |
228 | def _load_ft_mat(self):
229 | self.ft_mat = torch.load(self.gen_ft_file)
230 |
231 | def _load_data(self):
232 | self.data_mat = torch.load(self.gen_data_file)
233 |
234 | def _create_ft_mat(self, device):
235 | resnet = models.resnet101(pretrained=True)
236 | modules = list(resnet.children())[:-1]
237 | self.model = nn.Sequential(*modules)
238 | self.model.eval()
239 |
240 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {}
241 |
242 | loader = torch.utils.data.DataLoader(self.dataset, batch_size=256,
243 | shuffle=False, **kwargs)
244 | with torch.no_grad():
245 | ft_mat = torch.cat([self.model(data[0]).squeeze() for data in loader])
246 |
247 | torch.save(ft_mat, self.gen_ft_file)
248 | del ft_mat
249 |
250 | data_mat = torch.cat([data[0].squeeze() for data in loader])
251 | torch.save(data_mat, self.gen_data_file)
252 |
253 | self._load_ft_mat()
254 |
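Usage sketch (not part of datasets.py): both dataset classes above are normally constructed through the per-modality `getDataLoaders` methods further down, but they can also be exercised directly. This assumes the CUB data is laid out under `../data` as the rest of the code expects; the first run triggers the one-off preprocessing printed above.

    import torch
    from datasets import CUBSentences, CUBImageFt

    # captions: fixed-length index sequences plus their true (unpadded) lengths
    sents = CUBSentences('../data', split='train', transform=lambda d: torch.Tensor(d),
                         max_sequence_length=32)
    seq, length = sents[0]
    print(sents.vocab_size, seq.shape, length)   # vocab size, torch.Size([32]), sentence length

    # image features: 2048-d resnet101 embeddings, extracted and cached on first use
    feats = CUBImageFt('../data', 'train', 'cpu')
    print(len(feats), feats[0].shape)            # dataset size, torch.Size([2048])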
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import sys
4 | import json
5 | from collections import defaultdict
6 | from pathlib import Path
7 | from tempfile import mkdtemp
8 |
9 | import numpy as np
10 | import torch
11 | from torch import optim
12 |
13 | import models
14 | import objectives
15 | from utils import Logger, Timer, save_model, save_vars, unpack_data
16 |
17 | parser = argparse.ArgumentParser(description='Multi-Modal VAEs')
18 | parser.add_argument('--experiment', type=str, default='', metavar='E',
19 | help='experiment name')
20 | parser.add_argument('--model', type=str, default='mnist_svhn', metavar='M',
21 | choices=[s[4:] for s in dir(models) if 'VAE_' in s],
22 | help='model name (default: mnist_svhn)')
23 | parser.add_argument('--obj', type=str, default='elbo', metavar='O',
24 | choices=['elbo', 'iwae', 'dreg'],
25 | help='objective to use (default: elbo)')
26 | parser.add_argument('--K', type=int, default=20, metavar='K',
27 |                     help='number of particles to use for iwae/dreg (default: 20)')
28 | parser.add_argument('--looser', action='store_true', default=False,
29 | help='use the looser version of IWAE/DREG')
30 | parser.add_argument('--llik_scaling', type=float, default=0.,
31 |                     help='likelihood scaling for cub images/svhn modality when running in '
32 | 'multimodal setting, set as 0 to use default value')
33 | parser.add_argument('--batch-size', type=int, default=256, metavar='N',
34 | help='batch size for data (default: 256)')
35 | parser.add_argument('--epochs', type=int, default=10, metavar='E',
36 | help='number of epochs to train (default: 10)')
37 | parser.add_argument('--latent-dim', type=int, default=20, metavar='L',
38 | help='latent dimensionality (default: 20)')
39 | parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H',
40 | help='number of hidden layers in enc and dec (default: 1)')
41 | parser.add_argument('--pre-trained', type=str, default="",
42 | help='path to pre-trained model (train from scratch if empty)')
43 | parser.add_argument('--learn-prior', action='store_true', default=False,
44 | help='learn model prior parameters')
45 | parser.add_argument('--logp', action='store_true', default=False,
46 | help='estimate tight marginal likelihood on completion')
47 | parser.add_argument('--print-freq', type=int, default=0, metavar='f',
48 | help='frequency with which to print stats (default: 0)')
49 | parser.add_argument('--no-analytics', action='store_true', default=False,
50 | help='disable plotting analytics')
51 | parser.add_argument('--no-cuda', action='store_true', default=False,
52 | help='disable CUDA use')
53 | parser.add_argument('--seed', type=int, default=1, metavar='S',
54 | help='random seed (default: 1)')
55 |
56 | # args
57 | args = parser.parse_args()
58 |
59 | # random seed
60 | # https://pytorch.org/docs/stable/notes/randomness.html
61 | torch.backends.cudnn.benchmark = True
62 | torch.manual_seed(args.seed)
63 | np.random.seed(args.seed)
64 |
65 | # load args from disk if pretrained model path is given
66 | pretrained_path = ""
67 | if args.pre_trained:
68 | pretrained_path = args.pre_trained
69 | args = torch.load(args.pre_trained + '/args.rar')
70 |
71 | args.cuda = not args.no_cuda and torch.cuda.is_available()
72 | device = torch.device("cuda" if args.cuda else "cpu")
73 |
74 | # load model
75 | modelC = getattr(models, 'VAE_{}'.format(args.model))
76 | model = modelC(args).to(device)
77 |
78 | if pretrained_path:
79 | print('Loading model {} from {}'.format(model.modelName, pretrained_path))
80 | model.load_state_dict(torch.load(pretrained_path + '/model.rar'))
81 | model._pz_params = model._pz_params
82 |
83 | if not args.experiment:
84 | args.experiment = model.modelName
85 |
86 | # set up run path
87 | runId = datetime.datetime.now().isoformat()
88 | experiment_dir = Path('../experiments/' + args.experiment)
89 | experiment_dir.mkdir(parents=True, exist_ok=True)
90 | runPath = mkdtemp(prefix=runId, dir=str(experiment_dir))
91 | sys.stdout = Logger('{}/run.log'.format(runPath))
92 | print('Expt:', runPath)
93 | print('RunID:', runId)
94 |
95 | # save args to run
96 | with open('{}/args.json'.format(runPath), 'w') as fp:
97 | json.dump(args.__dict__, fp)
98 | # -- also save object because we want to recover these for other things
99 | torch.save(args, '{}/args.rar'.format(runPath))
100 |
101 | # preparation for training
102 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
103 | lr=1e-3, amsgrad=True)
104 | train_loader, test_loader = model.getDataLoaders(args.batch_size, device=device)
105 | objective = getattr(objectives,
106 | ('m_' if hasattr(model, 'vaes') else '')
107 | + args.obj
108 | + ('_looser' if (args.looser and args.obj != 'elbo') else ''))
109 | t_objective = getattr(objectives, ('m_' if hasattr(model, 'vaes') else '') + 'iwae')
110 |
111 |
112 | def train(epoch, agg):
113 | model.train()
114 | b_loss = 0
115 | for i, dataT in enumerate(train_loader):
116 | data = unpack_data(dataT, device=device)
117 | optimizer.zero_grad()
118 | loss = -objective(model, data, K=args.K)
119 | loss.backward()
120 | optimizer.step()
121 | b_loss += loss.item()
122 | if args.print_freq > 0 and i % args.print_freq == 0:
123 | print("iteration {:04d}: loss: {:6.3f}".format(i, loss.item() / args.batch_size))
124 | agg['train_loss'].append(b_loss / len(train_loader.dataset))
125 | print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, agg['train_loss'][-1]))
126 |
127 |
128 | def test(epoch, agg):
129 | model.eval()
130 | b_loss = 0
131 | with torch.no_grad():
132 | for i, dataT in enumerate(test_loader):
133 | data = unpack_data(dataT, device=device)
134 | loss = -t_objective(model, data, K=args.K)
135 | b_loss += loss.item()
136 | if i == 0:
137 | model.reconstruct(data, runPath, epoch)
138 | if not args.no_analytics:
139 | model.analyse(data, runPath, epoch)
140 | agg['test_loss'].append(b_loss / len(test_loader.dataset))
141 | print('====> Test loss: {:.4f}'.format(agg['test_loss'][-1]))
142 |
143 |
144 | def estimate_log_marginal(K):
145 | """Compute an IWAE estimate of the log-marginal likelihood of test data."""
146 | model.eval()
147 | marginal_loglik = 0
148 | with torch.no_grad():
149 | for dataT in test_loader:
150 | data = unpack_data(dataT, device=device)
151 | marginal_loglik += -t_objective(model, data, K).item()
152 |
153 | marginal_loglik /= len(test_loader.dataset)
154 | print('Marginal Log Likelihood (IWAE, K = {}): {:.4f}'.format(K, marginal_loglik))
155 |
156 |
157 | if __name__ == '__main__':
158 | with Timer('MM-VAE') as t:
159 | agg = defaultdict(list)
160 | for epoch in range(1, args.epochs + 1):
161 | train(epoch, agg)
162 | test(epoch, agg)
163 | save_model(model, runPath + '/model.rar')
164 | save_vars(agg, runPath + '/losses.rar')
165 | model.generate(runPath, epoch)
166 | if args.logp: # compute as tight a marginal likelihood as possible
167 | estimate_log_marginal(5000)
168 |
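Example invocations (a sketch; it assumes the working directory is `src/` and that the data and index files referred to above have been prepared as described in the README):

    # MNIST-SVHN MMVAE trained with the DReG estimator and 30 particles
    python main.py --model mnist_svhn --obj dreg --K 30 --epochs 30

    # CUB images+captions on pre-extracted resnet101 features, IWAE objective, learnt prior,
    # with a tight marginal-likelihood estimate at the end of training
    python main.py --model cubISft --obj iwae --K 10 --learn-prior --epochs 50 --logp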
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .mmvae_cub_images_sentences import CUB_Image_Sentence as VAE_cubIS
2 | from .mmvae_cub_images_sentences_ft import CUB_Image_Sentence_ft as VAE_cubISft
3 | from .mmvae_mnist_svhn import MNIST_SVHN as VAE_mnist_svhn
4 | from .vae_cub_image import CUB_Image as VAE_cubI
5 | from .vae_cub_image_ft import CUB_Image_ft as VAE_cubIft
6 | from .vae_cub_sent import CUB_Sentence as VAE_cubS
7 | from .vae_mnist import MNIST as VAE_mnist
8 | from .vae_svhn import SVHN as VAE_svhn
9 |
10 | __all__ = ['VAE_mnist_svhn', 'VAE_mnist', 'VAE_svhn', 'VAE_cubIS', 'VAE_cubS',
11 |            'VAE_cubI', 'VAE_cubISft', 'VAE_cubIft']
12 |
--------------------------------------------------------------------------------
/src/models/mmvae.py:
--------------------------------------------------------------------------------
1 | # Base MMVAE class definition
2 |
3 | from itertools import combinations
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from utils import get_mean, kl_divergence
9 | from vis import embed_umap, tensors_to_df
10 |
11 |
12 | class MMVAE(nn.Module):
13 | def __init__(self, prior_dist, params, *vaes):
14 | super(MMVAE, self).__init__()
15 | self.pz = prior_dist
16 | self.vaes = nn.ModuleList([vae(params) for vae in vaes])
17 | self.modelName = None # filled-in per sub-class
18 | self.params = params
19 | self._pz_params = None # defined in subclass
20 |
21 | @property
22 | def pz_params(self):
23 | return self._pz_params
24 |
25 | @staticmethod
26 | def getDataLoaders(batch_size, shuffle=True, device="cuda"):
27 | # handle merging individual datasets appropriately in sub-class
28 | raise NotImplementedError
29 |
30 | def forward(self, x, K=1):
31 | qz_xs, zss = [], []
32 | # initialise cross-modal matrix
33 | px_zs = [[None for _ in range(len(self.vaes))] for _ in range(len(self.vaes))]
34 | for m, vae in enumerate(self.vaes):
35 | qz_x, px_z, zs = vae(x[m], K=K)
36 | qz_xs.append(qz_x)
37 | zss.append(zs)
38 | px_zs[m][m] = px_z # fill-in diagonal
39 | for e, zs in enumerate(zss):
40 | for d, vae in enumerate(self.vaes):
41 | if e != d: # fill-in off-diagonal
42 | px_zs[e][d] = vae.px_z(*vae.dec(zs))
43 | return qz_xs, px_zs, zss
44 |
45 | def generate(self, N):
46 | self.eval()
47 | with torch.no_grad():
48 | data = []
49 | pz = self.pz(*self.pz_params)
50 | latents = pz.rsample(torch.Size([N]))
51 | for d, vae in enumerate(self.vaes):
52 | px_z = vae.px_z(*vae.dec(latents))
53 | data.append(px_z.mean.view(-1, *px_z.mean.size()[2:]))
54 | return data # list of generations---one for each modality
55 |
56 | def reconstruct(self, data):
57 | self.eval()
58 | with torch.no_grad():
59 | _, px_zs, _ = self.forward(data)
60 | # cross-modal matrix of reconstructions
61 | recons = [[get_mean(px_z) for px_z in r] for r in px_zs]
62 | return recons
63 |
64 | def analyse(self, data, K):
65 | self.eval()
66 | with torch.no_grad():
67 | qz_xs, _, zss = self.forward(data, K=K)
68 | pz = self.pz(*self.pz_params)
69 | zss = [pz.sample(torch.Size([K, data[0].size(0)])).view(-1, pz.batch_shape[-1]),
70 | *[zs.view(-1, zs.size(-1)) for zs in zss]]
71 | zsl = [torch.zeros(zs.size(0)).fill_(i) for i, zs in enumerate(zss)]
72 | kls_df = tensors_to_df(
73 | [*[kl_divergence(qz_x, pz).cpu().numpy() for qz_x in qz_xs],
74 | *[0.5 * (kl_divergence(p, q) + kl_divergence(q, p)).cpu().numpy()
75 | for p, q in combinations(qz_xs, 2)]],
76 | head='KL',
77 | keys=[*[r'KL$(q(z|x_{})\,||\,p(z))$'.format(i) for i in range(len(qz_xs))],
78 | *[r'J$(q(z|x_{})\,||\,q(z|x_{}))$'.format(i, j)
79 | for i, j in combinations(range(len(qz_xs)), 2)]],
80 | ax_names=['Dimensions', r'KL$(q\,||\,p)$']
81 | )
82 | return embed_umap(torch.cat(zss, 0).cpu().numpy()), \
83 | torch.cat(zsl, 0).cpu().numpy(), \
84 | kls_df
85 |
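The matrix `px_zs` returned by `forward` is indexed [encoding modality][decoding modality]: the diagonal holds within-modality reconstructions, the off-diagonal holds cross-modal ones. Below is a minimal sketch of consuming it; the estimators actually used for training live in objectives.py, which is not reproduced in this listing, so this is illustrative only.

    def cross_modal_log_liks(model, x, K=1):
        """Sketch: log-likelihood of modality d decoded from modality e's latents."""
        qz_xs, px_zs, zss = model(x, K=K)
        table = {}
        for e in range(len(model.vaes)):           # modality the latents were inferred from
            for d, vae in enumerate(model.vaes):   # modality being reconstructed
                lpx_z = px_zs[e][d].log_prob(x[d]) * vae.llik_scaling
                table[(e, d)] = lpx_z.reshape(K, x[d].size(0), -1).sum(-1).mean()
        return table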
--------------------------------------------------------------------------------
/src/models/mmvae_cub_images_sentences.py:
--------------------------------------------------------------------------------
1 | # cub multi-modal model specification
2 | import matplotlib.pyplot as plt
3 | import torch.distributions as dist
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.utils.data
7 | from numpy import sqrt, prod
8 | from torch.utils.data import DataLoader
9 | from torchnet.dataset import TensorDataset, ResampleDataset
10 | from torchvision.utils import save_image, make_grid
11 |
12 | from utils import Constants
13 | from vis import plot_embeddings, plot_kls_df
14 | from .mmvae import MMVAE
15 | from .vae_cub_image import CUB_Image
16 | from .vae_cub_sent import CUB_Sentence
17 |
18 | # Constants
19 | maxSentLen = 32
20 | minOccur = 3
21 |
22 |
23 | # This is required because there are 10 captions per image.
24 | # Allows easier reuse of the same image for the corresponding set of captions.
25 | def resampler(dataset, idx):
26 | return idx // 10
27 |
28 |
29 | class CUB_Image_Sentence(MMVAE):
30 |
31 | def __init__(self, params):
32 | super(CUB_Image_Sentence, self).__init__(dist.Laplace, params, CUB_Image, CUB_Sentence)
33 | grad = {'requires_grad': params.learn_prior}
34 | self._pz_params = nn.ParameterList([
35 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu
36 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar
37 | ])
38 | self.vaes[0].llik_scaling = self.vaes[1].maxSentLen / prod(self.vaes[0].dataSize) \
39 | if params.llik_scaling == 0 else params.llik_scaling
40 |
41 | for vae in self.vaes:
42 | vae._pz_params = self._pz_params
43 | self.modelName = 'cubIS'
44 |
45 | self.i2w = self.vaes[1].load_vocab()
46 |
47 | @property
48 | def pz_params(self):
49 | return self._pz_params[0], \
50 | F.softmax(self._pz_params[1], dim=1) * self._pz_params[1].size(1) + Constants.eta
51 |
52 | def getDataLoaders(self, batch_size, shuffle=True, device='cuda'):
53 | # load base datasets
54 | t1, s1 = self.vaes[0].getDataLoaders(batch_size, shuffle, device)
55 | t2, s2 = self.vaes[1].getDataLoaders(batch_size, shuffle, device)
56 |
57 | kwargs = {'num_workers': 2, 'pin_memory': True} if device == 'cuda' else {}
58 | train_loader = DataLoader(TensorDataset([
59 | ResampleDataset(t1.dataset, resampler, size=len(t1.dataset) * 10),
60 | t2.dataset]), batch_size=batch_size, shuffle=shuffle, **kwargs)
61 | test_loader = DataLoader(TensorDataset([
62 | ResampleDataset(s1.dataset, resampler, size=len(s1.dataset) * 10),
63 | s2.dataset]), batch_size=batch_size, shuffle=shuffle, **kwargs)
64 | return train_loader, test_loader
65 |
66 | def generate(self, runPath, epoch):
67 | N = 8
68 | samples = super(CUB_Image_Sentence, self).generate(N)
69 | images, captions = [sample.data.cpu() for sample in samples]
70 | captions = self._sent_preprocess(captions)
71 | fig = plt.figure(figsize=(8, 6))
72 | for i, (image, caption) in enumerate(zip(images, captions)):
73 | fig = self._imshow(image, caption, i, fig, N)
74 |
75 | plt.savefig('{}/gen_samples_{:03d}.png'.format(runPath, epoch))
76 | plt.close()
77 |
78 | def reconstruct(self, raw_data, runPath, epoch):
79 | N = 8
80 | recons_mat = super(CUB_Image_Sentence, self).reconstruct([d[:N] for d in raw_data])
81 | fns = [lambda images: images.data.cpu(), lambda sentences: self._sent_preprocess(sentences)]
82 | for r, recons_list in enumerate(recons_mat):
83 | for o, recon in enumerate(recons_list):
84 | data = fns[r](raw_data[r][:N])
85 | recon = fns[o](recon.squeeze())
86 | if r != o:
87 | fig = plt.figure(figsize=(8, 6))
88 | for i, (_data, _recon) in enumerate(zip(data, recon)):
89 | image, caption = (_data, _recon) if r == 0 else (_recon, _data)
90 | fig = self._imshow(image, caption, i, fig, N)
91 | plt.savefig('{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch))
92 | plt.close()
93 | else:
94 | if r == 0:
95 | comp = torch.cat([data, recon])
96 | save_image(comp, '{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch))
97 | else:
98 | with open('{}/recon_{}x{}_{:03d}.txt'.format(runPath, r, o, epoch), "w+") as txt_file:
99 | for r_sent, d_sent in zip(recon, data):
100 | txt_file.write('[DATA] ==> {}\n'.format(' '.join(self.i2w[str(i)] for i in d_sent)))
101 | txt_file.write('[RECON] ==> {}\n\n'.format(' '.join(self.i2w[str(i)] for i in r_sent)))
102 |
103 | def analyse(self, data, runPath, epoch):
104 | zemb, zsl, kls_df = super(CUB_Image_Sentence, self).analyse(data, K=10)
105 | labels = ['Prior', *[vae.modelName.lower() for vae in self.vaes]]
106 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch))
107 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch))
108 |
109 | def _sent_preprocess(self, sentences):
110 | """make sure raw data is always passed as dim=2 to avoid argmax.
111 | last dimension must always be word embedding."""
112 | if len(sentences.shape) > 2:
113 | sentences = sentences.argmax(-1).squeeze()
114 | return [self.vaes[1].fn_trun(s) for s in self.vaes[1].fn_2i(sentences)]
115 |
116 | def _imshow(self, image, caption, i, fig, N):
117 | """Imshow for Tensor."""
118 | ax = fig.add_subplot(N // 2, 4, i * 2 + 1)
119 | ax.axis('off')
120 | image = image.numpy().transpose((1, 2, 0)) #
121 | plt.imshow(image)
122 | ax = fig.add_subplot(N // 2, 4, i * 2 + 2)
123 | pos = ax.get_position()
124 | ax.axis('off')
125 | plt.text(
126 | x=0.5 * (pos.x0 + pos.x1),
127 | y=0.5 * (pos.y0 + pos.y1),
129 | s='{}'.format(
130 | ' '.join(self.i2w[str(i)] + '\n' if (n + 1) % 5 == 0
131 | else self.i2w[str(i)] for n, i in enumerate(caption))),
132 | fontsize=6,
133 | verticalalignment='center',
134 | horizontalalignment='center'
135 | )
136 | return fig
137 |
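The `ResampleDataset` wrapper above serves each image 10 times so it can be zipped against its 10 captions in the `TensorDataset`: caption index i is paired with image i // 10. A quick check of the mapping (sketch):

    # caption indices 0..9 map to image 0, indices 10 and 11 to image 1, and so on
    assert [resampler(None, i) for i in range(12)] == [0] * 10 + [1, 1]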
--------------------------------------------------------------------------------
/src/models/mmvae_cub_images_sentences_ft.py:
--------------------------------------------------------------------------------
1 | # cub multi-modal model specification
2 | import matplotlib.pyplot as plt
3 | import torch.distributions as dist
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.utils.data
7 | from numpy import sqrt, prod
8 | from torch.utils.data import DataLoader
9 | from torchnet.dataset import TensorDataset, ResampleDataset
10 | from torchvision.utils import save_image, make_grid
11 |
12 | from utils import Constants
13 | from vis import plot_embeddings, plot_kls_df
14 | from .mmvae import MMVAE
15 | from .vae_cub_image_ft import CUB_Image_ft
16 | from .vae_cub_sent_ft import CUB_Sentence_ft
17 |
18 | # Constants
19 | maxSentLen = 32
20 | minOccur = 3
21 |
22 |
23 | # This is required because there are 10 captions per image.
24 | # Allows easier reuse of the same image for the corresponding set of captions.
25 | def resampler(dataset, idx):
26 | return idx // 10
27 |
28 |
29 | class CUB_Image_Sentence_ft(MMVAE):
30 |
31 | def __init__(self, params):
32 | super(CUB_Image_Sentence_ft, self).__init__(dist.Normal, params, CUB_Image_ft, CUB_Sentence_ft)
33 | grad = {'requires_grad': params.learn_prior}
34 | self._pz_params = nn.ParameterList([
35 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu
36 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar
37 | ])
38 | self.vaes[0].llik_scaling = self.vaes[1].maxSentLen / prod(self.vaes[0].dataSize) \
39 | if params.llik_scaling == 0 else params.llik_scaling
40 |
41 | for vae in self.vaes:
42 | vae._pz_params = self._pz_params
43 | self.modelName = 'cubISft'
44 |
45 | self.i2w = self.vaes[1].load_vocab()
46 |
47 | @property
48 | def pz_params(self):
49 | return self._pz_params[0], \
50 | F.softplus(self._pz_params[1]) + Constants.eta
51 |
52 | def getDataLoaders(self, batch_size, shuffle=True, device='cuda'):
53 | # load base datasets
54 | t1, s1 = self.vaes[0].getDataLoaders(batch_size, shuffle, device)
55 | t2, s2 = self.vaes[1].getDataLoaders(batch_size, shuffle, device)
56 |
57 | kwargs = {'num_workers': 2, 'pin_memory': True} if device == 'cuda' else {}
58 | train_loader = DataLoader(TensorDataset([
59 | ResampleDataset(t1.dataset, resampler, size=len(t1.dataset) * 10),
60 | t2.dataset]), batch_size=batch_size, shuffle=shuffle, **kwargs)
61 | test_loader = DataLoader(TensorDataset([
62 | ResampleDataset(s1.dataset, resampler, size=len(s1.dataset) * 10),
63 | s2.dataset]), batch_size=batch_size, shuffle=shuffle, **kwargs)
64 | return train_loader, test_loader
65 |
66 | def generate(self, runPath, epoch):
67 | N = 8
68 | samples = super(CUB_Image_Sentence_ft, self).generate(N)
69 | samples[0] = self.vaes[0].unproject(samples[0], search_split='train')
70 | images, captions = [sample.data.cpu() for sample in samples]
71 | captions = self._sent_preprocess(captions)
72 | fig = plt.figure(figsize=(8, 6))
73 | for i, (image, caption) in enumerate(zip(images, captions)):
74 | fig = self._imshow(image, caption, i, fig, N)
75 |
76 | plt.savefig('{}/gen_samples_{:03d}.png'.format(runPath, epoch))
77 | plt.close()
78 |
79 | def reconstruct(self, raw_data, runPath, epoch):
80 | N = 8
81 | recons_mat = super(CUB_Image_Sentence_ft, self).reconstruct([d[:N] for d in raw_data])
82 | fns = [lambda images: images.data.cpu(), lambda sentences: self._sent_preprocess(sentences)]
83 | for r, recons_list in enumerate(recons_mat):
84 | for o, recon in enumerate(recons_list):
85 | data = fns[r](raw_data[r][:N])
86 | recon = fns[o](recon.squeeze())
87 | if r != o:
88 | fig = plt.figure(figsize=(8, 6))
89 | for i, (_data, _recon) in enumerate(zip(data, recon)):
90 | image, caption = (_data, _recon) if r == 0 else (_recon, _data)
91 | search_split = 'test' if r == 0 else 'train'
92 | image = self.vaes[0].unproject(image.unsqueeze(0), search_split=search_split)
93 | fig = self._imshow(image, caption, i, fig, N)
94 | plt.savefig('{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch))
95 | plt.close()
96 | else:
97 | if r == 0:
98 | data_ = self.vaes[0].unproject(data, search_split='test')
99 | recon_ = self.vaes[0].unproject(recon, search_split='train')
100 | comp = torch.cat([data_, recon_])
101 | save_image(comp, '{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch))
102 | else:
103 | with open('{}/recon_{}x{}_{:03d}.txt'.format(runPath, r, o, epoch), "w+") as txt_file:
104 | for r_sent, d_sent in zip(recon, data):
105 | txt_file.write('[DATA] ==> {}\n'.format(' '.join(self.i2w[str(i)] for i in d_sent)))
106 | txt_file.write('[RECON] ==> {}\n\n'.format(' '.join(self.i2w[str(i)] for i in r_sent)))
107 |
108 | def analyse(self, data, runPath, epoch):
109 | zemb, zsl, kls_df = super(CUB_Image_Sentence_ft, self).analyse(data, K=10)
110 | labels = ['Prior', *[vae.modelName.lower() for vae in self.vaes]]
111 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch))
112 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch))
113 |
114 | def _sent_preprocess(self, sentences):
115 | """make sure raw data is always passed as dim=2 to avoid argmax.
116 | last dimension must always be word embedding."""
117 | if len(sentences.shape) > 2:
118 | sentences = sentences.argmax(-1).squeeze()
119 | return [self.vaes[1].fn_trun(s) for s in self.vaes[1].fn_2i(sentences)]
120 |
121 | def _imshow(self, image, caption, i, fig, N):
122 | """Imshow for Tensor."""
123 | ax = fig.add_subplot(N // 2, 4, i * 2 + 1)
124 | ax.axis('off')
125 | image = image.numpy().transpose((1, 2, 0)) #
126 | plt.imshow(image)
127 | ax = fig.add_subplot(N // 2, 4, i * 2 + 2)
128 | pos = ax.get_position()
129 | ax.axis('off')
130 | plt.text(
131 | x=0.5 * (pos.x0 + pos.x1),
132 | y=0.5 * (pos.y0 + pos.y1),
134 | s='{}'.format(
135 | ' '.join(self.i2w[str(i)] + '\n' if (n + 1) % 5 == 0
136 | else self.i2w[str(i)] for n, i in enumerate(caption))),
137 | fontsize=6,
138 | verticalalignment='center',
139 | horizontalalignment='center'
140 | )
141 | return fig
142 |
--------------------------------------------------------------------------------
/src/models/mmvae_mnist_svhn.py:
--------------------------------------------------------------------------------
1 | # MNIST-SVHN multi-modal model specification
2 | import os
3 |
4 | import torch
5 | import torch.distributions as dist
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from numpy import sqrt, prod
9 | from torch.utils.data import DataLoader
10 | from torchnet.dataset import TensorDataset, ResampleDataset
11 | from torchvision.utils import save_image, make_grid
12 |
13 | from vis import plot_embeddings, plot_kls_df
14 | from .mmvae import MMVAE
15 | from .vae_mnist import MNIST
16 | from .vae_svhn import SVHN
17 |
18 |
19 | class MNIST_SVHN(MMVAE):
20 | def __init__(self, params):
21 | super(MNIST_SVHN, self).__init__(dist.Laplace, params, MNIST, SVHN)
22 | grad = {'requires_grad': params.learn_prior}
23 | self._pz_params = nn.ParameterList([
24 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu
25 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar
26 | ])
27 | self.vaes[0].llik_scaling = prod(self.vaes[1].dataSize) / prod(self.vaes[0].dataSize) \
28 | if params.llik_scaling == 0 else params.llik_scaling
29 | self.modelName = 'mnist-svhn'
30 |
31 | @property
32 | def pz_params(self):
33 | return self._pz_params[0], F.softmax(self._pz_params[1], dim=1) * self._pz_params[1].size(-1)
34 |
35 | def getDataLoaders(self, batch_size, shuffle=True, device='cuda'):
36 | if not (os.path.exists('../data/train-ms-mnist-idx.pt')
37 | and os.path.exists('../data/train-ms-svhn-idx.pt')
38 | and os.path.exists('../data/test-ms-mnist-idx.pt')
39 | and os.path.exists('../data/test-ms-svhn-idx.pt')):
40 | raise RuntimeError('Generate transformed indices with the script in bin')
41 | # get transformed indices
42 | t_mnist = torch.load('../data/train-ms-mnist-idx.pt')
43 | t_svhn = torch.load('../data/train-ms-svhn-idx.pt')
44 | s_mnist = torch.load('../data/test-ms-mnist-idx.pt')
45 | s_svhn = torch.load('../data/test-ms-svhn-idx.pt')
46 |
47 | # load base datasets
48 | t1, s1 = self.vaes[0].getDataLoaders(batch_size, shuffle, device)
49 | t2, s2 = self.vaes[1].getDataLoaders(batch_size, shuffle, device)
50 |
51 | train_mnist_svhn = TensorDataset([
52 | ResampleDataset(t1.dataset, lambda d, i: t_mnist[i], size=len(t_mnist)),
53 | ResampleDataset(t2.dataset, lambda d, i: t_svhn[i], size=len(t_svhn))
54 | ])
55 | test_mnist_svhn = TensorDataset([
56 | ResampleDataset(s1.dataset, lambda d, i: s_mnist[i], size=len(s_mnist)),
57 | ResampleDataset(s2.dataset, lambda d, i: s_svhn[i], size=len(s_svhn))
58 | ])
59 |
60 | kwargs = {'num_workers': 2, 'pin_memory': True} if device == 'cuda' else {}
61 | train = DataLoader(train_mnist_svhn, batch_size=batch_size, shuffle=shuffle, **kwargs)
62 | test = DataLoader(test_mnist_svhn, batch_size=batch_size, shuffle=shuffle, **kwargs)
63 | return train, test
64 |
65 | def generate(self, runPath, epoch):
66 | N = 64
67 | samples_list = super(MNIST_SVHN, self).generate(N)
68 | for i, samples in enumerate(samples_list):
69 | samples = samples.data.cpu()
70 | # wrangle things so they come out tiled
71 | samples = samples.view(N, *samples.size()[1:])
72 | save_image(samples,
73 | '{}/gen_samples_{}_{:03d}.png'.format(runPath, i, epoch),
74 | nrow=int(sqrt(N)))
75 |
76 | def reconstruct(self, data, runPath, epoch):
77 | recons_mat = super(MNIST_SVHN, self).reconstruct([d[:8] for d in data])
78 | for r, recons_list in enumerate(recons_mat):
79 | for o, recon in enumerate(recons_list):
80 | _data = data[r][:8].cpu()
81 | recon = recon.squeeze(0).cpu()
82 | # resize mnist to 32 and colour. 0 => mnist, 1 => svhn
83 | _data = _data if r == 1 else resize_img(_data, self.vaes[1].dataSize)
84 | recon = recon if o == 1 else resize_img(recon, self.vaes[1].dataSize)
85 | comp = torch.cat([_data, recon])
86 | save_image(comp, '{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch))
87 |
88 | def analyse(self, data, runPath, epoch):
89 | zemb, zsl, kls_df = super(MNIST_SVHN, self).analyse(data, K=10)
90 | labels = ['Prior', *[vae.modelName.lower() for vae in self.vaes]]
91 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch))
92 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch))
93 |
94 |
95 | def resize_img(img, refsize):
96 | return F.pad(img, (2, 2, 2, 2)).expand(img.size(0), *refsize)
97 |
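A quick shape check of the `resize_img` helper above (sketch): MNIST batches are zero-padded from 28x28 to 32x32 and their single channel broadcast to three, so they can be saved side by side with SVHN reconstructions.

    import torch
    x = torch.rand(8, 1, 28, 28)                # a stand-in MNIST batch
    y = resize_img(x, torch.Size([3, 32, 32]))  # pad to 32x32, expand to 3 channels
    assert y.shape == (8, 3, 32, 32)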
--------------------------------------------------------------------------------
/src/models/vae.py:
--------------------------------------------------------------------------------
1 | # Base VAE class definition
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from utils import get_mean, kl_divergence
7 | from vis import embed_umap, tensors_to_df
8 |
9 |
10 | class VAE(nn.Module):
11 | def __init__(self, prior_dist, likelihood_dist, post_dist, enc, dec, params):
12 | super(VAE, self).__init__()
13 | self.pz = prior_dist
14 | self.px_z = likelihood_dist
15 | self.qz_x = post_dist
16 | self.enc = enc
17 | self.dec = dec
18 | self.modelName = None
19 | self.params = params
20 | self._pz_params = None # defined in subclass
21 | self._qz_x_params = None # populated in `forward`
22 | self.llik_scaling = 1.0
23 |
24 | @property
25 | def pz_params(self):
26 | return self._pz_params
27 |
28 | @property
29 | def qz_x_params(self):
30 | if self._qz_x_params is None:
31 |             raise NameError("qz_x params not initialised yet!")
32 | return self._qz_x_params
33 |
34 | @staticmethod
35 | def getDataLoaders(batch_size, shuffle=True, device="cuda"):
36 | # handle merging individual datasets appropriately in sub-class
37 | raise NotImplementedError
38 |
39 | def forward(self, x, K=1):
40 | self._qz_x_params = self.enc(x)
41 | qz_x = self.qz_x(*self._qz_x_params)
42 | zs = qz_x.rsample(torch.Size([K]))
43 | px_z = self.px_z(*self.dec(zs))
44 | return qz_x, px_z, zs
45 |
46 | def generate(self, N, K):
47 | self.eval()
48 | with torch.no_grad():
49 | pz = self.pz(*self.pz_params)
50 | latents = pz.rsample(torch.Size([N]))
51 | px_z = self.px_z(*self.dec(latents))
52 | data = px_z.sample(torch.Size([K]))
53 | return data.view(-1, *data.size()[3:])
54 |
55 | def reconstruct(self, data):
56 | self.eval()
57 | with torch.no_grad():
58 | qz_x = self.qz_x(*self.enc(data))
59 | latents = qz_x.rsample() # no dim expansion
60 | px_z = self.px_z(*self.dec(latents))
61 | recon = get_mean(px_z)
62 | return recon
63 |
64 | def analyse(self, data, K):
65 | self.eval()
66 | with torch.no_grad():
67 | qz_x, _, zs = self.forward(data, K=K)
68 | pz = self.pz(*self.pz_params)
69 | zss = [pz.sample(torch.Size([K, data.size(0)])).view(-1, pz.batch_shape[-1]),
70 | zs.view(-1, zs.size(-1))]
71 | zsl = [torch.zeros(zs.size(0)).fill_(i) for i, zs in enumerate(zss)]
72 | kls_df = tensors_to_df(
73 | [kl_divergence(qz_x, pz).cpu().numpy()],
74 | head='KL',
75 | keys=[r'KL$(q(z|x)\,||\,p(z))$'],
76 | ax_names=['Dimensions', r'KL$(q\,||\,p)$']
77 | )
78 | return embed_umap(torch.cat(zss, 0).cpu().numpy()), \
79 | torch.cat(zsl, 0).cpu().numpy(), \
80 | kls_df
81 |
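For orientation, a single-modality ELBO estimate can be assembled from the three things `forward` returns. This is only a sketch of the general recipe; the estimators actually used for training (elbo/iwae/dreg) live in objectives.py, which is not reproduced in this listing.

    import torch
    from utils import kl_divergence

    def elbo_sketch(model, x, K=1):
        qz_x, px_z, zs = model(x, K=K)      # posterior, likelihood, latent samples
        lpx_z = px_z.log_prob(x).reshape(K, x.size(0), -1).sum(-1) * model.llik_scaling
        kld = kl_divergence(qz_x, model.pz(*model.pz_params)).sum(-1)
        return (lpx_z - kld).mean(0).sum()  # average over particles, sum over batch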
--------------------------------------------------------------------------------
/src/models/vae_cub_image.py:
--------------------------------------------------------------------------------
1 | # CUB Image model specification
2 |
3 | import torch
4 | import torch.distributions as dist
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | from numpy import sqrt
9 | from torchvision import datasets, transforms
10 | from torchvision.utils import make_grid, save_image
11 |
12 | from utils import Constants
13 | from vis import plot_embeddings, plot_kls_df
14 | from .vae import VAE
15 |
16 | # Constants
17 | imgChans = 3
18 | fBase = 64
19 |
20 |
21 | # Classes
22 | class Enc(nn.Module):
23 | """ Generate latent parameters for CUB image data. """
24 |
25 | def __init__(self, latentDim):
26 | super(Enc, self).__init__()
27 | modules = [
28 |             # input size: 3 x 64 x 64
29 | nn.Conv2d(imgChans, fBase, 4, 2, 1, bias=True),
30 | nn.ReLU(True),
31 |             # size: (fBase) x 32 x 32
32 | nn.Conv2d(fBase, fBase * 2, 4, 2, 1, bias=True),
33 | nn.ReLU(True),
34 |             # size: (fBase * 2) x 16 x 16
35 | nn.Conv2d(fBase * 2, fBase * 4, 4, 2, 1, bias=True),
36 | nn.ReLU(True),
37 |             # size: (fBase * 4) x 8 x 8
38 | nn.Conv2d(fBase * 4, fBase * 8, 4, 2, 1, bias=True),
39 | nn.ReLU(True)]
40 | # size: (fBase * 8) x 4 x 4
41 |
42 | self.enc = nn.Sequential(*modules)
43 | self.c1 = nn.Conv2d(fBase * 8, latentDim, 4, 1, 0, bias=True)
44 | self.c2 = nn.Conv2d(fBase * 8, latentDim, 4, 1, 0, bias=True)
45 | # c1, c2 size: latentDim x 1 x 1
46 |
47 | def forward(self, x):
48 | e = self.enc(x)
49 | return self.c1(e).squeeze(), F.softplus(self.c2(e)).squeeze() + Constants.eta
50 |
51 |
52 | class Dec(nn.Module):
53 | """ Generate an image given a sample from the latent space. """
54 |
55 | def __init__(self, latentDim):
56 | super(Dec, self).__init__()
57 | modules = [nn.ConvTranspose2d(latentDim, fBase * 8, 4, 1, 0, bias=True),
58 | nn.ReLU(True), ]
59 |
60 | modules.extend([
61 | nn.ConvTranspose2d(fBase * 8, fBase * 4, 4, 2, 1, bias=True),
62 | nn.ReLU(True),
63 |             # size: (fBase * 4) x 8 x 8
64 | nn.ConvTranspose2d(fBase * 4, fBase * 2, 4, 2, 1, bias=True),
65 | nn.ReLU(True),
66 |             # size: (fBase * 2) x 16 x 16
67 | nn.ConvTranspose2d(fBase * 2, fBase, 4, 2, 1, bias=True),
68 | nn.ReLU(True),
69 |             # size: (fBase) x 32 x 32
70 | nn.ConvTranspose2d(fBase, imgChans, 4, 2, 1, bias=True),
71 | nn.Sigmoid()
72 |             # Output size: 3 x 64 x 64
73 | ])
74 | self.dec = nn.Sequential(*modules)
75 |
76 | def forward(self, z):
77 | z = z.unsqueeze(-1).unsqueeze(-1) # fit deconv layers
78 | out = self.dec(z.view(-1, *z.size()[-3:]))
79 | out = out.view(*z.size()[:-3], *out.size()[1:])
80 | return out, torch.tensor(0.01).to(z.device)
81 |
82 |
83 | class CUB_Image(VAE):
84 | """ Derive a specific sub-class of a VAE for a CNN sentence model. """
85 |
86 | def __init__(self, params):
87 | super(CUB_Image, self).__init__(
88 | dist.Laplace, # prior
89 | dist.Laplace, # likelihood
90 | dist.Laplace, # posterior
91 | Enc(params.latent_dim),
92 | Dec(params.latent_dim),
93 | params
94 | )
95 | grad = {'requires_grad': params.learn_prior}
96 | self._pz_params = nn.ParameterList([
97 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu
98 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar
99 | ])
100 | self.modelName = 'cubI'
101 | self.dataSize = torch.Size([3, 64, 64])
102 | self.llik_scaling = 1.
103 |
104 | @property
105 | def pz_params(self):
106 | return self._pz_params[0], F.softplus(self._pz_params[1]) + Constants.eta
107 |
108 | # remember that when combining with captions, this should be x10
109 | def getDataLoaders(self, batch_size, shuffle=True, device="cuda"):
110 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {}
111 | tx = transforms.Compose([transforms.Resize([64, 64]), transforms.ToTensor()])
112 | train_loader = torch.utils.data.DataLoader(
113 | datasets.ImageFolder('../data/cub/train', transform=tx),
114 | batch_size=batch_size, shuffle=shuffle, **kwargs)
115 | test_loader = torch.utils.data.DataLoader(
116 | datasets.ImageFolder('../data/cub/test', transform=tx),
117 | batch_size=batch_size, shuffle=shuffle, **kwargs)
118 | return train_loader, test_loader
119 |
120 | def generate(self, runPath, epoch):
121 | N, K = 64, 9
122 | samples = super(CUB_Image, self).generate(N, K).data.cpu()
123 | # wrangle things so they come out tiled
124 | samples = samples.view(K, N, *samples.size()[1:]).transpose(0, 1)
125 | s = [make_grid(t, nrow=int(sqrt(K)), padding=0) for t in samples.data.cpu()]
126 | save_image(torch.stack(s),
127 | '{}/gen_samples_{:03d}.png'.format(runPath, epoch),
128 | nrow=int(sqrt(N)))
129 |
130 | def reconstruct(self, data, runPath, epoch):
131 | recon = super(CUB_Image, self).reconstruct(data[:8])
132 | comp = torch.cat([data[:8], recon])
133 | save_image(comp.data.cpu(), '{}/recon_{:03d}.png'.format(runPath, epoch))
134 |
135 | def analyse(self, data, runPath, epoch):
136 | zemb, zsl, kls_df = super(CUB_Image, self).analyse(data, K=10)
137 | labels = ['Prior', self.modelName.lower()]
138 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch))
139 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch))
140 |
--------------------------------------------------------------------------------
/src/models/vae_cub_image_ft.py:
--------------------------------------------------------------------------------
1 | # CUB Image feature model specification
2 |
3 | import torch
4 | import torch.distributions as dist
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | from numpy import sqrt
9 | from torchvision.utils import make_grid, save_image
10 |
11 | from datasets import CUBImageFt
12 | from utils import Constants, NN_lookup
13 | from vis import plot_embeddings, plot_kls_df
14 | from .vae import VAE
15 |
16 | # Constants
17 | imgChans = 3
18 | fBase = 64
19 |
20 |
21 | class Enc(nn.Module):
22 | """ Generate latent parameters for CUB image feature. """
23 |
24 | def __init__(self, latent_dim, n_c):
25 | super(Enc, self).__init__()
26 | dim_hidden = 256
27 | self.enc = nn.Sequential()
28 | for i in range(int(torch.tensor(n_c / dim_hidden).log2())):
29 | self.enc.add_module("layer" + str(i), nn.Sequential(
30 | nn.Linear(n_c // (2 ** i), n_c // (2 ** (i + 1))),
31 | nn.ELU(inplace=True),
32 | ))
33 | # relies on above terminating at dim_hidden
34 | self.fc21 = nn.Linear(dim_hidden, latent_dim)
35 | self.fc22 = nn.Linear(dim_hidden, latent_dim)
36 |
37 | def forward(self, x):
38 | e = self.enc(x)
39 | return self.fc21(e), F.softplus(self.fc22(e)) + Constants.eta
40 |
41 |
42 | class Dec(nn.Module):
43 | """ Generate a CUB image feature given a sample from the latent space. """
44 |
45 | def __init__(self, latent_dim, n_c):
46 | super(Dec, self).__init__()
47 | self.n_c = n_c
48 | dim_hidden = 256
49 | self.dec = nn.Sequential()
50 | for i in range(int(torch.tensor(n_c / dim_hidden).log2())):
51 | indim = latent_dim if i == 0 else dim_hidden * i
52 | outdim = dim_hidden if i == 0 else dim_hidden * (2 * i)
53 | self.dec.add_module("out_t" if i == 0 else "layer" + str(i) + "_t", nn.Sequential(
54 | nn.Linear(indim, outdim),
55 | nn.ELU(inplace=True),
56 | ))
57 | # relies on above terminating at n_c // 2
58 | self.fc31 = nn.Linear(n_c // 2, n_c)
59 |
60 | def forward(self, z):
61 | p = self.dec(z.view(-1, z.size(-1)))
62 | mean = self.fc31(p).view(*z.size()[:-1], -1)
63 | return mean, torch.tensor([0.01]).to(mean.device)
64 |
65 |
66 | class CUB_Image_ft(VAE):
67 | """ Derive a specific sub-class of a VAE for a CNN sentence model. """
68 |
69 | def __init__(self, params):
70 | super(CUB_Image_ft, self).__init__(
71 | dist.Normal, # prior
72 | dist.Laplace, # likelihood
73 | dist.Normal, # posterior
74 | Enc(params.latent_dim, 2048),
75 | Dec(params.latent_dim, 2048),
76 | params
77 | )
78 | grad = {'requires_grad': params.learn_prior}
79 | self._pz_params = nn.ParameterList([
80 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu
81 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar
82 | ])
83 | self.modelName = 'cubIft'
84 | self.dataSize = torch.Size([2048])
85 |
86 | self.llik_scaling = 1.
87 |
88 | @property
89 | def pz_params(self):
90 | return self._pz_params[0], \
91 | F.softplus(self._pz_params[1]) + Constants.eta
92 |
93 | # remember that when combining with captions, this should be x10
94 | def getDataLoaders(self, batch_size, shuffle=True, device="cuda"):
95 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {}
96 |
97 | train_dataset = CUBImageFt('../data', 'train', device)
98 | test_dataset = CUBImageFt('../data', 'test', device)
99 | train_loader = torch.utils.data.DataLoader(train_dataset,
100 | batch_size=batch_size, shuffle=shuffle, **kwargs)
101 | test_loader = torch.utils.data.DataLoader(test_dataset,
102 | batch_size=batch_size, shuffle=shuffle, **kwargs)
103 |
104 | train_dataset._load_data()
105 | test_dataset._load_data()
106 | self.unproject = lambda emb_h, search_split='train', \
107 | te=train_dataset.ft_mat, td=train_dataset.data_mat, \
108 | se=test_dataset.ft_mat, sd=test_dataset.data_mat: \
109 | NN_lookup(emb_h, te, td) if search_split == 'train' else NN_lookup(emb_h, se, sd)
110 |
111 | return train_loader, test_loader
112 |
113 | def generate(self, runPath, epoch):
114 | N, K = 64, 9
115 | samples = super(CUB_Image_ft, self).generate(N, K).data.cpu()
116 | samples = self.unproject(samples, search_split='train')
117 | samples = samples.view(K, N, *samples.size()[1:]).transpose(0, 1)
118 | s = [make_grid(t, nrow=int(sqrt(K)), padding=0) for t in samples.data.cpu()]
119 | save_image(torch.stack(s),
120 | '{}/gen_samples_{:03d}.png'.format(runPath, epoch),
121 | nrow=int(sqrt(N)))
122 |
123 | def reconstruct(self, data, runPath, epoch):
124 | recon = super(CUB_Image_ft, self).reconstruct(data[:8])
125 | data_ = self.unproject(data[:8], search_split='test')
126 | recon_ = self.unproject(recon, search_split='train')
127 | comp = torch.cat([data_, recon_])
128 | save_image(comp.data.cpu(), '{}/recon_{:03d}.png'.format(runPath, epoch))
129 |
130 | def analyse(self, data, runPath, epoch):
131 | zemb, zsl, kls_df = super(CUB_Image_ft, self).analyse(data, K=10)
132 | labels = ['Prior', self.modelName.lower()]
133 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch))
134 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch))
135 |
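`unproject` maps 2048-d feature vectors back to viewable images by nearest-neighbour search against the cached feature/data matrices, via `NN_lookup` from utils.py (not reproduced in this listing). A hypothetical stand-in, just to make the intent concrete; the real helper may differ:

    import torch

    def nn_lookup_sketch(emb_h, emb_bank, data_bank):
        # for each query embedding, return the raw image whose cached embedding is closest
        nearest = torch.cdist(emb_h, emb_bank).argmin(dim=1)
        return data_bank[nearest]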
--------------------------------------------------------------------------------
/src/models/vae_cub_sent.py:
--------------------------------------------------------------------------------
1 | # Sentence model specification - real CUB image version
2 | import os
3 | import json
4 |
5 | import numpy as np
6 | import torch
7 | import torch.distributions as dist
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.data
11 | from torch.utils.data import DataLoader
12 |
13 | from datasets import CUBSentences
14 | from utils import Constants, FakeCategorical
15 | from .vae import VAE
16 |
17 | # Constants
18 | maxSentLen = 32 # max length of any description for birds dataset
19 | minOccur = 3
20 | embeddingDim = 128
21 | lenWindow = 3
22 | fBase = 32
23 | vocabSize = 1590
24 | vocab_path = '../data/cub/oc:{}_sl:{}_s:{}_w:{}/cub.vocab'.format(minOccur, maxSentLen, 300, lenWindow)
25 |
26 |
27 | # Classes
28 | class Enc(nn.Module):
29 | """ Generate latent parameters for sentence data. """
30 |
31 | def __init__(self, latentDim):
32 | super(Enc, self).__init__()
33 | self.embedding = nn.Embedding(vocabSize, embeddingDim, padding_idx=0)
34 | self.enc = nn.Sequential(
35 | # input size: 1 x 32 x 128
36 | nn.Conv2d(1, fBase, 4, 2, 1, bias=False),
37 | nn.BatchNorm2d(fBase),
38 | nn.ReLU(True),
39 | # size: (fBase) x 16 x 64
40 | nn.Conv2d(fBase, fBase * 2, 4, 2, 1, bias=False),
41 | nn.BatchNorm2d(fBase * 2),
42 | nn.ReLU(True),
43 | # size: (fBase * 2) x 8 x 32
44 | nn.Conv2d(fBase * 2, fBase * 4, 4, 2, 1, bias=False),
45 | nn.BatchNorm2d(fBase * 4),
46 | nn.ReLU(True),
47 |             # size: (fBase * 4) x 4 x 16
48 | nn.Conv2d(fBase * 4, fBase * 4, (1, 4), (1, 2), (0, 1), bias=False),
49 | nn.BatchNorm2d(fBase * 4),
50 | nn.ReLU(True),
51 |             # size: (fBase * 4) x 4 x 8
52 | nn.Conv2d(fBase * 4, fBase * 4, (1, 4), (1, 2), (0, 1), bias=False),
53 | nn.BatchNorm2d(fBase * 4),
54 | nn.ReLU(True),
55 |             # size: (fBase * 4) x 4 x 4
56 | )
57 | self.c1 = nn.Conv2d(fBase * 4, latentDim, 4, 1, 0, bias=False)
58 | self.c2 = nn.Conv2d(fBase * 4, latentDim, 4, 1, 0, bias=False)
59 | # c1, c2 size: latentDim x 1 x 1
60 |
61 | def forward(self, x):
62 | e = self.enc(self.embedding(x.long()).unsqueeze(1))
63 | mu, logvar = self.c1(e).squeeze(), self.c2(e).squeeze()
64 | return mu, F.softplus(logvar) + Constants.eta
65 |
66 |
67 | class Dec(nn.Module):
68 | """ Generate a sentence given a sample from the latent space. """
69 |
70 | def __init__(self, latentDim):
71 | super(Dec, self).__init__()
72 | self.dec = nn.Sequential(
73 | nn.ConvTranspose2d(latentDim, fBase * 4, 4, 1, 0, bias=False),
74 | nn.BatchNorm2d(fBase * 4),
75 | nn.ReLU(True),
76 |             # size: (fBase * 4) x 4 x 4
77 | nn.ConvTranspose2d(fBase * 4, fBase * 4, (1, 4), (1, 2), (0, 1), bias=False),
78 | nn.BatchNorm2d(fBase * 4),
79 | nn.ReLU(True),
80 |             # size: (fBase * 4) x 4 x 8
81 | nn.ConvTranspose2d(fBase * 4, fBase * 4, (1, 4), (1, 2), (0, 1), bias=False),
82 | nn.BatchNorm2d(fBase * 4),
83 | nn.ReLU(True),
84 |             # size: (fBase * 4) x 4 x 16
85 | nn.ConvTranspose2d(fBase * 4, fBase * 2, 4, 2, 1, bias=False),
86 | nn.BatchNorm2d(fBase * 2),
87 | nn.ReLU(True),
88 |             # size: (fBase * 2) x 8 x 32
89 | nn.ConvTranspose2d(fBase * 2, fBase, 4, 2, 1, bias=False),
90 | nn.BatchNorm2d(fBase),
91 | nn.ReLU(True),
92 |             # size: (fBase) x 16 x 64
93 | nn.ConvTranspose2d(fBase, 1, 4, 2, 1, bias=False),
94 | nn.ReLU(True)
95 |             # Output size: 1 x 32 x 128
96 | )
97 | # inverts the 'embedding' module upto one-hotness
98 | self.toVocabSize = nn.Linear(embeddingDim, vocabSize)
99 |
100 | def forward(self, z):
101 | z = z.unsqueeze(-1).unsqueeze(-1) # fit deconv layers
102 | out = self.dec(z.view(-1, *z.size()[-3:])).view(-1, embeddingDim)
103 |
104 | return self.toVocabSize(out).view(*z.size()[:-3], maxSentLen, vocabSize),
105 |
106 |
107 | class CUB_Sentence(VAE):
108 | """ Derive a specific sub-class of a VAE for a sentence model. """
109 |
110 | def __init__(self, params):
111 | super(CUB_Sentence, self).__init__(
112 | prior_dist=dist.Normal,
113 | likelihood_dist=FakeCategorical,
114 | post_dist=dist.Normal,
115 | enc=Enc(params.latent_dim),
116 | dec=Dec(params.latent_dim),
117 | params=params)
118 | grad = {'requires_grad': params.learn_prior}
119 | self._pz_params = nn.ParameterList([
120 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu
121 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar
122 | ])
123 | self.modelName = 'cubS'
124 | self.llik_scaling = 1.
125 |
126 | self.tie_modules()
127 |
128 | self.fn_2i = lambda t: t.cpu().numpy().astype(int)
129 | self.fn_trun = lambda s: s[:np.where(s == 2)[0][0] + 1] if 2 in s else s
130 | self.vocab_file = vocab_path
131 |
132 | self.maxSentLen = maxSentLen
133 | self.vocabSize = vocabSize
134 |
135 | def tie_modules(self):
136 | # This looks dumb, but is actually dumber than you might realise.
137 | # A linear(a, b) module has a [b x a] weight matrix, but an embedding(a, b)
138 | # module has a [a x b] weight matrix. So when we want the transpose at
139 | # decoding time, we just use the weight matrix as is.
140 | self.dec.toVocabSize.weight = self.enc.embedding.weight
141 |
142 | @property
143 | def pz_params(self):
144 | return self._pz_params[0], F.softplus(self._pz_params[1]) + Constants.eta
145 |
146 | @staticmethod
147 | def getDataLoaders(batch_size, shuffle=True, device="cuda"):
148 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {}
149 | tx = lambda data: torch.Tensor(data)
150 | t_data = CUBSentences('../data', split='train', transform=tx, max_sequence_length=maxSentLen)
151 | s_data = CUBSentences('../data', split='test', transform=tx, max_sequence_length=maxSentLen)
152 |
153 | train_loader = DataLoader(t_data, batch_size=batch_size, shuffle=shuffle, **kwargs)
154 | test_loader = DataLoader(s_data, batch_size=batch_size, shuffle=shuffle, **kwargs)
155 |
156 | return train_loader, test_loader
157 |
158 | def reconstruct(self, data, runPath, epoch):
159 | recon = super(CUB_Sentence, self).reconstruct(data[:8]).argmax(dim=-1).squeeze()
160 | recon, data = self.fn_2i(recon), self.fn_2i(data[:8])
161 | recon, data = [self.fn_trun(r) for r in recon], [self.fn_trun(d) for d in data]
162 | i2w = self.load_vocab()
163 |         print("\n Reconstruction examples (excluding <pad>):")
164 | for r_sent, d_sent in zip(recon[:3], data[:3]):
165 | print('[DATA] ==> {}'.format(' '.join(i2w[str(i)] for i in d_sent)))
166 | print('[RECON] ==> {}\n'.format(' '.join(i2w[str(i)] for i in r_sent)))
167 |
168 | with open('{}/recon_{:03d}.txt'.format(runPath, epoch), "w+") as txt_file:
169 | for r_sent, d_sent in zip(recon, data):
170 | txt_file.write('[DATA] ==> {}\n'.format(' '.join(i2w[str(i)] for i in d_sent)))
171 | txt_file.write('[RECON] ==> {}\n\n'.format(' '.join(i2w[str(i)] for i in r_sent)))
172 |
173 | def generate(self, runPath, epoch):
174 | N, K = 5, 4
175 | i2w = self.load_vocab()
176 | samples = super(CUB_Sentence, self).generate(N, K).argmax(dim=-1).squeeze()
177 | samples = samples.view(K, N, samples.size(-1)).transpose(0, 1) # N x K x 64
178 | samples = [[self.fn_trun(s) for s in ss] for ss in self.fn_2i(samples)]
179 | # samples = [self.fn_trun(s) for s in samples]
180 |         print("\n Generated examples (excluding <pad>):")
181 | for s_sent in samples[0][:3]:
182 | print('[GEN] ==> {}'.format(' '.join(i2w[str(i)] for i in s_sent if i != 0)))
183 |
184 | with open('{}/gen_samples_{:03d}.txt'.format(runPath, epoch), "w+") as txt_file:
185 | for s_sents in samples:
186 | for s_sent in s_sents:
187 | txt_file.write('{}\n'.format(' '.join(i2w[str(i)] for i in s_sent)))
188 | txt_file.write('\n')
189 |
190 | def analyse(self, data, runPath, epoch):
191 | pass
192 |
193 | def load_vocab(self):
194 | # call dataloader function to create vocab file
195 | if not os.path.exists(self.vocab_file):
196 | _, _ = self.getDataLoaders(256)
197 | with open(self.vocab_file, 'r') as vocab_file:
198 | vocab = json.load(vocab_file)
199 | return vocab['i2w']
200 |
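The tying in `tie_modules` works because nn.Linear(in_features, out_features) stores its weight as (out_features, in_features) while nn.Embedding(num_embeddings, embedding_dim) stores (num_embeddings, embedding_dim): for Linear(embeddingDim, vocabSize) and Embedding(vocabSize, embeddingDim) both are (vocabSize, embeddingDim), so the same tensor can be shared as-is. A quick check (sketch):

    import torch.nn as nn
    emb = nn.Embedding(1590, 128)   # vocabSize, embeddingDim
    out = nn.Linear(128, 1590)      # embeddingDim -> vocabSize
    assert emb.weight.shape == out.weight.shape == (1590, 128)
    out.weight = emb.weight         # exactly what tie_modules does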
--------------------------------------------------------------------------------
/src/models/vae_cub_sent_ft.py:
--------------------------------------------------------------------------------
1 | # Sentence model specification - CUB image feature version
2 | import json
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | import torch.distributions as dist
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.data
11 | from torch.utils.data import DataLoader
12 |
13 | from datasets import CUBSentences
14 | from utils import Constants, FakeCategorical
15 | from .vae import VAE
16 |
17 | maxSentLen = 32 # max length of any description for birds dataset
18 | minOccur = 3
19 | embeddingDim = 128
20 | lenWindow = 3
21 | fBase = 32
22 | vocabSize = 1590
23 | vocab_path = '../data/cub/oc:{}_sl:{}_s:{}_w:{}/cub.vocab'.format(minOccur, maxSentLen, 300, lenWindow)
24 |
25 |
26 | # Classes
27 | class Enc(nn.Module):
28 | """ Generate latent parameters for sentence data. """
29 |
30 | def __init__(self, latentDim):
31 | super(Enc, self).__init__()
32 | self.embedding = nn.Embedding(vocabSize, embeddingDim, padding_idx=0)
33 | self.enc = nn.Sequential(
34 | # input size: 1 x 32 x 128
35 | nn.Conv2d(1, fBase, 4, 2, 1, bias=True),
36 | nn.BatchNorm2d(fBase),
37 | nn.ReLU(True),
38 | # size: (fBase) x 16 x 64
39 | nn.Conv2d(fBase, fBase * 2, 4, 2, 1, bias=True),
40 | nn.BatchNorm2d(fBase * 2),
41 | nn.ReLU(True),
42 | # size: (fBase * 2) x 8 x 32
43 | nn.Conv2d(fBase * 2, fBase * 4, 4, 2, 1, bias=True),
44 | nn.BatchNorm2d(fBase * 4),
45 | nn.ReLU(True),
46 |             # size: (fBase * 4) x 4 x 16
47 | nn.Conv2d(fBase * 4, fBase * 8, (1, 4), (1, 2), (0, 1), bias=True),
48 | nn.BatchNorm2d(fBase * 8),
49 | nn.ReLU(True),
50 | # size: (fBase * 8) x 4 x 8
51 | nn.Conv2d(fBase * 8, fBase * 16, (1, 4), (1, 2), (0, 1), bias=True),
52 | nn.BatchNorm2d(fBase * 16),
53 | nn.ReLU(True),
54 |             # size: (fBase * 16) x 4 x 4
55 | )
56 | self.c1 = nn.Conv2d(fBase * 16, latentDim, 4, 1, 0, bias=True)
57 | self.c2 = nn.Conv2d(fBase * 16, latentDim, 4, 1, 0, bias=True)
58 | # c1, c2 size: latentDim x 1 x 1
59 |
60 | def forward(self, x):
61 | e = self.enc(self.embedding(x.long()).unsqueeze(1))
62 | mu, logvar = self.c1(e).squeeze(), self.c2(e).squeeze()
63 | return mu, F.softplus(logvar) + Constants.eta
64 |
65 |
66 | class Dec(nn.Module):
67 | """ Generate a sentence given a sample from the latent space. """
68 |
69 | def __init__(self, latentDim):
70 | super(Dec, self).__init__()
71 | self.dec = nn.Sequential(
72 | nn.ConvTranspose2d(latentDim, fBase * 16, 4, 1, 0, bias=True),
73 | nn.BatchNorm2d(fBase * 16),
74 | nn.ReLU(True),
75 |             # size: (fBase * 16) x 4 x 4
76 | nn.ConvTranspose2d(fBase * 16, fBase * 8, (1, 4), (1, 2), (0, 1), bias=True),
77 | nn.BatchNorm2d(fBase * 8),
78 | nn.ReLU(True),
79 | # size: (fBase * 8) x 4 x 8
80 | nn.ConvTranspose2d(fBase * 8, fBase * 4, (1, 4), (1, 2), (0, 1), bias=True),
81 | nn.BatchNorm2d(fBase * 4),
82 | nn.ReLU(True),
83 |             # size: (fBase * 4) x 4 x 16
84 | nn.ConvTranspose2d(fBase * 4, fBase * 2, 4, 2, 1, bias=True),
85 | nn.BatchNorm2d(fBase * 2),
86 | nn.ReLU(True),
87 |             # size: (fBase * 2) x 8 x 32
88 | nn.ConvTranspose2d(fBase * 2, fBase, 4, 2, 1, bias=True),
89 | nn.BatchNorm2d(fBase),
90 | nn.ReLU(True),
91 |             # size: (fBase) x 16 x 64
92 | nn.ConvTranspose2d(fBase, 1, 4, 2, 1, bias=True),
93 | nn.ReLU(True)
94 |             # Output size: 1 x 32 x 128
95 | )
96 | # inverts the 'embedding' module upto one-hotness
97 | self.toVocabSize = nn.Linear(embeddingDim, vocabSize)
98 |
99 | def forward(self, z):
100 | z = z.unsqueeze(-1).unsqueeze(-1) # fit deconv layers
101 | out = self.dec(z.view(-1, *z.size()[-3:])).view(-1, embeddingDim)
102 |
103 | return self.toVocabSize(out).view(*z.size()[:-3], maxSentLen, vocabSize),
104 |
105 |
106 | class CUB_Sentence_ft(VAE):
107 | """ Derive a specific sub-class of a VAE for a sentence model. """
108 |
109 | def __init__(self, params):
110 | super(CUB_Sentence_ft, self).__init__(
111 | prior_dist=dist.Normal,
112 | likelihood_dist=FakeCategorical,
113 | post_dist=dist.Normal,
114 | enc=Enc(params.latent_dim),
115 | dec=Dec(params.latent_dim),
116 | params=params)
117 | grad = {'requires_grad': params.learn_prior}
118 | self._pz_params = nn.ParameterList([
119 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu
120 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar
121 | ])
122 | self.modelName = 'cubSft'
123 | self.llik_scaling = 1.
124 |
125 | self.tie_modules()
126 |
127 | self.fn_2i = lambda t: t.cpu().numpy().astype(int)
128 | self.fn_trun = lambda s: s[:np.where(s == 2)[0][0] + 1] if 2 in s else s
129 | self.vocab_file = vocab_path
130 |
131 | self.maxSentLen = maxSentLen
132 | self.vocabSize = vocabSize
133 |
134 | def tie_modules(self):
135 | # This looks dumb, but is actually dumber than you might realise.
136 | # A linear(a, b) module has a [b x a] weight matrix, but an embedding(a, b)
137 | # module has a [a x b] weight matrix. So when we want the transpose at
138 | # decoding time, we just use the weight matrix as is.
139 | self.dec.toVocabSize.weight = self.enc.embedding.weight
140 |
141 | @property
142 | def pz_params(self):
143 | return self._pz_params[0], F.softplus(self._pz_params[1]) + Constants.eta
144 |
145 | @staticmethod
146 | def getDataLoaders(batch_size, shuffle=True, device="cuda"):
147 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {}
148 | tx = lambda data: torch.Tensor(data)
149 | t_data = CUBSentences('../data', split='train', transform=tx, max_sequence_length=maxSentLen)
150 | s_data = CUBSentences('../data', split='test', transform=tx, max_sequence_length=maxSentLen)
151 |
152 | train_loader = DataLoader(t_data, batch_size=batch_size, shuffle=shuffle, **kwargs)
153 | test_loader = DataLoader(s_data, batch_size=batch_size, shuffle=shuffle, **kwargs)
154 |
155 | return train_loader, test_loader
156 |
157 | def reconstruct(self, data, runPath, epoch):
158 | recon = super(CUB_Sentence_ft, self).reconstruct(data[:8]).argmax(dim=-1).squeeze()
159 | recon, data = self.fn_2i(recon), self.fn_2i(data[:8])
160 | recon, data = [self.fn_trun(r) for r in recon], [self.fn_trun(d) for d in data]
161 | i2w = self.load_vocab()
162 | print("\n Reconstruction examples (excluding ):")
163 | for r_sent, d_sent in zip(recon[:3], data[:3]):
164 | print('[DATA] ==> {}'.format(' '.join(i2w[str(i)] for i in d_sent)))
165 | print('[RECON] ==> {}\n'.format(' '.join(i2w[str(i)] for i in r_sent)))
166 |
167 | with open('{}/recon_{:03d}.txt'.format(runPath, epoch), "w+") as txt_file:
168 | for r_sent, d_sent in zip(recon, data):
169 | txt_file.write('[DATA] ==> {}\n'.format(' '.join(i2w[str(i)] for i in d_sent)))
170 | txt_file.write('[RECON] ==> {}\n\n'.format(' '.join(i2w[str(i)] for i in r_sent)))
171 |
172 | def generate(self, runPath, epoch):
173 | N, K = 5, 4
174 | i2w = self.load_vocab()
175 | samples = super(CUB_Sentence_ft, self).generate(N, K).argmax(dim=-1).squeeze()
176 | samples = samples.view(K, N, samples.size(-1)).transpose(0, 1) # N x K x maxSentLen
177 | samples = [[self.fn_trun(s) for s in ss] for ss in self.fn_2i(samples)]
178 | # samples = [self.fn_trun(s) for s in samples]
179 | print("\n Generated examples (excluding ):")
180 | for s_sent in samples[0][:3]:
181 | print('[GEN] ==> {}'.format(' '.join(i2w[str(i)] for i in s_sent if i != 0)))
182 |
183 | with open('{}/gen_samples_{:03d}.txt'.format(runPath, epoch), "w+") as txt_file:
184 | for s_sents in samples:
185 | for s_sent in s_sents:
186 | txt_file.write('{}\n'.format(' '.join(i2w[str(i)] for i in s_sent)))
187 | txt_file.write('\n')
188 |
189 | def analyse(self, data, runPath, epoch):
190 | pass
191 |
192 | def load_vocab(self):
193 | # call dataloader function to create vocab file
194 | if not os.path.exists(self.vocab_file):
195 | _, _ = self.getDataLoaders(256)
196 | with open(self.vocab_file, 'r') as vocab_file:
197 | vocab = json.load(vocab_file)
198 | return vocab['i2w']
199 |
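200 | # Note on the flow above (a sketch, assuming maxSentLen = 32 and embeddingDim = 128,
201 | # as implied by the conv stacks and the view() calls): the encoder treats a sentence
202 | # of word indices as a 1 x maxSentLen x embeddingDim "image" via the embedding lookup,
203 | # the decoder mirrors this to produce one embeddingDim vector per position, and the
204 | # tied `toVocabSize` linear maps each position back to vocabulary logits, roughly
205 | # (ignoring the bias term):
206 | #   logits = dec_out.view(-1, embeddingDim) @ enc.embedding.weight.t()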
--------------------------------------------------------------------------------
/src/models/vae_mnist.py:
--------------------------------------------------------------------------------
1 | # MNIST model specification
2 |
3 | import torch
4 | import torch.distributions as dist
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from numpy import prod, sqrt
8 | from torch.utils.data import DataLoader
9 | from torchvision import datasets, transforms
10 | from torchvision.utils import save_image, make_grid
11 |
12 | from utils import Constants
13 | from vis import plot_embeddings, plot_kls_df
14 | from .vae import VAE
15 |
16 | # Constants
17 | dataSize = torch.Size([1, 28, 28])
18 | data_dim = int(prod(dataSize))
19 | hidden_dim = 400
20 |
21 |
22 | def extra_hidden_layer():
23 | return nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(True))
24 |
25 |
26 | # Classes
27 | class Enc(nn.Module):
28 | """ Generate latent parameters for MNIST image data. """
29 |
30 | def __init__(self, latent_dim, num_hidden_layers=1):
31 | super(Enc, self).__init__()
32 | modules = []
33 | modules.append(nn.Sequential(nn.Linear(data_dim, hidden_dim), nn.ReLU(True)))
34 | modules.extend([extra_hidden_layer() for _ in range(num_hidden_layers - 1)])
35 | self.enc = nn.Sequential(*modules)
36 | self.fc21 = nn.Linear(hidden_dim, latent_dim)
37 | self.fc22 = nn.Linear(hidden_dim, latent_dim)
38 |
39 | def forward(self, x):
40 | e = self.enc(x.view(*x.size()[:-3], -1)) # flatten data
41 | lv = self.fc22(e)
42 | return self.fc21(e), F.softmax(lv, dim=-1) * lv.size(-1) + Constants.eta
43 |
44 |
45 | class Dec(nn.Module):
46 | """ Generate an MNIST image given a sample from the latent space. """
47 |
48 | def __init__(self, latent_dim, num_hidden_layers=1):
49 | super(Dec, self).__init__()
50 | modules = []
51 | modules.append(nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.ReLU(True)))
52 | modules.extend([extra_hidden_layer() for _ in range(num_hidden_layers - 1)])
53 | self.dec = nn.Sequential(*modules)
54 | self.fc3 = nn.Linear(hidden_dim, data_dim)
55 |
56 | def forward(self, z):
57 | p = self.fc3(self.dec(z))
58 | d = torch.sigmoid(p.view(*z.size()[:-1], *dataSize)) # reshape data
59 | d = d.clamp(Constants.eta, 1 - Constants.eta)
60 |
61 | return d, torch.tensor(0.75).to(z.device) # mean, length scale
62 |
63 |
64 | class MNIST(VAE):
65 | """ Derive a specific sub-class of a VAE for MNIST. """
66 |
67 | def __init__(self, params):
68 | super(MNIST, self).__init__(
69 | dist.Laplace, # prior
70 | dist.Laplace, # likelihood
71 | dist.Laplace, # posterior
72 | Enc(params.latent_dim, params.num_hidden_layers),
73 | Dec(params.latent_dim, params.num_hidden_layers),
74 | params
75 | )
76 | grad = {'requires_grad': params.learn_prior}
77 | self._pz_params = nn.ParameterList([
78 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu
79 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar
80 | ])
81 | self.modelName = 'mnist'
82 | self.dataSize = dataSize
83 | self.llik_scaling = 1.
84 |
85 | @property
86 | def pz_params(self):
87 | return self._pz_params[0], F.softmax(self._pz_params[1], dim=1) * self._pz_params[1].size(-1)
88 |
89 | @staticmethod
90 | def getDataLoaders(batch_size, shuffle=True, device="cuda"):
91 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {}
92 | tx = transforms.ToTensor()
93 | train = DataLoader(datasets.MNIST('../data', train=True, download=True, transform=tx),
94 | batch_size=batch_size, shuffle=shuffle, **kwargs)
95 | test = DataLoader(datasets.MNIST('../data', train=False, download=True, transform=tx),
96 | batch_size=batch_size, shuffle=shuffle, **kwargs)
97 | return train, test
98 |
99 | def generate(self, runPath, epoch):
100 | N, K = 64, 9
101 | samples = super(MNIST, self).generate(N, K).cpu()
102 | # wrangle things so they come out tiled
103 | samples = samples.view(K, N, *samples.size()[1:]).transpose(0, 1) # N x K x 1 x 28 x 28
104 | s = [make_grid(t, nrow=int(sqrt(K)), padding=0) for t in samples]
105 | save_image(torch.stack(s),
106 | '{}/gen_samples_{:03d}.png'.format(runPath, epoch),
107 | nrow=int(sqrt(N)))
108 |
109 | def reconstruct(self, data, runPath, epoch):
110 | recon = super(MNIST, self).reconstruct(data[:8])
111 | comp = torch.cat([data[:8], recon]).data.cpu()
112 | save_image(comp, '{}/recon_{:03d}.png'.format(runPath, epoch))
113 |
114 | def analyse(self, data, runPath, epoch):
115 | zemb, zsl, kls_df = super(MNIST, self).analyse(data, K=10)
116 | labels = ['Prior', self.modelName.lower()]
117 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch))
118 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch))
119 |
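120 | # Construction sketch (hypothetical values; `params` is the argparse namespace built
121 | # in main.py and needs at least the fields referenced above):
122 | #   params = argparse.Namespace(latent_dim=20, num_hidden_layers=1, learn_prior=False)
123 | #   model = MNIST(params)
124 | #   train_loader, test_loader = MNIST.getDataLoaders(batch_size=64, device="cpu")
125 | # Note that the encoder's scale head is passed through a softmax and rescaled by the
126 | # latent dimensionality, so the Laplace scales are positive and average to roughly one.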
--------------------------------------------------------------------------------
/src/models/vae_svhn.py:
--------------------------------------------------------------------------------
1 | # SVHN model specification
2 |
3 | import torch
4 | import torch.distributions as dist
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from numpy import sqrt
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms, datasets
10 | from torchvision.utils import save_image, make_grid
11 |
12 | from utils import Constants
13 | from vis import plot_embeddings, plot_kls_df
14 | from .vae import VAE
15 |
16 | # Constants
17 | dataSize = torch.Size([3, 32, 32])
18 | imgChans = dataSize[0]
19 | fBase = 32 # base size of filter channels
20 |
21 |
22 | # Classes
23 | class Enc(nn.Module):
24 | """ Generate latent parameters for SVHN image data. """
25 |
26 | def __init__(self, latent_dim):
27 | super(Enc, self).__init__()
28 | self.enc = nn.Sequential(
29 | # input size: 3 x 32 x 32
30 | nn.Conv2d(imgChans, fBase, 4, 2, 1, bias=True),
31 | nn.ReLU(True),
32 | # size: (fBase) x 16 x 16
33 | nn.Conv2d(fBase, fBase * 2, 4, 2, 1, bias=True),
34 | nn.ReLU(True),
35 | # size: (fBase * 2) x 8 x 8
36 | nn.Conv2d(fBase * 2, fBase * 4, 4, 2, 1, bias=True),
37 | nn.ReLU(True),
38 | # size: (fBase * 4) x 4 x 4
39 | )
40 | self.c1 = nn.Conv2d(fBase * 4, latent_dim, 4, 1, 0, bias=True)
41 | self.c2 = nn.Conv2d(fBase * 4, latent_dim, 4, 1, 0, bias=True)
42 | # c1, c2 size: latent_dim x 1 x 1
43 |
44 | def forward(self, x):
45 | e = self.enc(x)
46 | lv = self.c2(e).squeeze()
47 | return self.c1(e).squeeze(), F.softmax(lv, dim=-1) * lv.size(-1) + Constants.eta
48 |
49 |
50 | class Dec(nn.Module):
51 | """ Generate a SVHN image given a sample from the latent space. """
52 |
53 | def __init__(self, latent_dim):
54 | super(Dec, self).__init__()
55 | self.dec = nn.Sequential(
56 | nn.ConvTranspose2d(latent_dim, fBase * 4, 4, 1, 0, bias=True),
57 | nn.ReLU(True),
58 | # size: (fBase * 4) x 4 x 4
59 | nn.ConvTranspose2d(fBase * 4, fBase * 2, 4, 2, 1, bias=True),
60 | nn.ReLU(True),
61 | # size: (fBase * 2) x 8 x 8
62 | nn.ConvTranspose2d(fBase * 2, fBase, 4, 2, 1, bias=True),
63 | nn.ReLU(True),
64 | # size: (fBase) x 16 x 16
65 | nn.ConvTranspose2d(fBase, imgChans, 4, 2, 1, bias=True),
66 | nn.Sigmoid()
67 | # Output size: 3 x 32 x 32
68 | )
69 |
70 | def forward(self, z):
71 | z = z.unsqueeze(-1).unsqueeze(-1) # fit deconv layers
72 | out = self.dec(z.view(-1, *z.size()[-3:]))
73 | out = out.view(*z.size()[:-3], *out.size()[1:])
74 | # consider also predicting the length scale
75 | return out, torch.tensor(0.75).to(z.device) # mean, length scale
76 |
77 |
78 | class SVHN(VAE):
79 | """ Derive a specific sub-class of a VAE for SVHN """
80 |
81 | def __init__(self, params):
82 | super(SVHN, self).__init__(
83 | dist.Laplace, # prior
84 | dist.Laplace, # likelihood
85 | dist.Laplace, # posterior
86 | Enc(params.latent_dim),
87 | Dec(params.latent_dim),
88 | params
89 | )
90 | grad = {'requires_grad': params.learn_prior}
91 | self._pz_params = nn.ParameterList([
92 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu
93 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar
94 | ])
95 | self.modelName = 'svhn'
96 | self.dataSize = dataSize
97 | self.llik_scaling = 1.
98 |
99 | @property
100 | def pz_params(self):
101 | return self._pz_params[0], F.softmax(self._pz_params[1], dim=1) * self._pz_params[1].size(-1)
102 |
103 | @staticmethod
104 | def getDataLoaders(batch_size, shuffle=True, device='cuda'):
105 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == 'cuda' else {}
106 | tx = transforms.ToTensor()
107 | train = DataLoader(datasets.SVHN('../data', split='train', download=True, transform=tx),
108 | batch_size=batch_size, shuffle=shuffle, **kwargs)
109 | test = DataLoader(datasets.SVHN('../data', split='test', download=True, transform=tx),
110 | batch_size=batch_size, shuffle=shuffle, **kwargs)
111 | return train, test
112 |
113 | def generate(self, runPath, epoch):
114 | N, K = 64, 9
115 | samples = super(SVHN, self).generate(N, K).cpu()
116 | # wrangle things so they come out tiled
117 | samples = samples.view(K, N, *samples.size()[1:]).transpose(0, 1)
118 | s = [make_grid(t, nrow=int(sqrt(K)), padding=0) for t in samples]
119 | save_image(torch.stack(s),
120 | '{}/gen_samples_{:03d}.png'.format(runPath, epoch),
121 | nrow=int(sqrt(N)))
122 |
123 | def reconstruct(self, data, runPath, epoch):
124 | recon = super(SVHN, self).reconstruct(data[:8])
125 | comp = torch.cat([data[:8], recon]).data.cpu()
126 | save_image(comp, '{}/recon_{:03d}.png'.format(runPath, epoch))
127 |
128 | def analyse(self, data, runPath, epoch):
129 | zemb, zsl, kls_df = super(SVHN, self).analyse(data, K=10)
130 | labels = ['Prior', self.modelName.lower()]
131 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch))
132 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch))
133 |
--------------------------------------------------------------------------------
/src/objectives.py:
--------------------------------------------------------------------------------
1 | # objectives of choice
2 | import torch
3 | from numpy import prod
4 |
5 | from utils import log_mean_exp, is_multidata, kl_divergence
6 |
7 |
8 | # helper to vectorise computation
9 | def compute_microbatch_split(x, K):
10 | """ Checks if batch needs to be broken down further to fit in memory. """
11 | B = x[0].size(0) if is_multidata(x) else x.size(0)
12 | S = sum([1.0 / (K * prod(_x.size()[1:])) for _x in x]) if is_multidata(x) \
13 | else 1.0 / (K * prod(x.size()[1:]))
14 | S = int(1e8 * S) # float heuristic for 12GB CUDA memory
15 | assert (S > 0), "Cannot fit individual data in memory, consider smaller K"
16 | return min(B, S)
17 |
18 |
19 | def elbo(model, x, K=1):
20 | """Computes E_{p(x)}[ELBO] """
21 | qz_x, px_z, _ = model(x)
22 | lpx_z = px_z.log_prob(x).view(*px_z.batch_shape[:2], -1) * model.llik_scaling
23 | kld = kl_divergence(qz_x, model.pz(*model.pz_params))
24 | return (lpx_z.sum(-1) - kld.sum(-1)).mean(0).sum()
25 |
26 |
27 | def _iwae(model, x, K):
28 | """IWAE estimate for log p_\theta(x) -- fully vectorised."""
29 | qz_x, px_z, zs = model(x, K)
30 | lpz = model.pz(*model.pz_params).log_prob(zs).sum(-1)
31 | lpx_z = px_z.log_prob(x).view(*px_z.batch_shape[:2], -1) * model.llik_scaling
32 | lqz_x = qz_x.log_prob(zs).sum(-1)
33 | return lpz + lpx_z.sum(-1) - lqz_x
34 |
35 |
36 | def iwae(model, x, K):
37 | """Computes an importance-weighted ELBO estimate for log p_\theta(x)
38 | Iterates over the batch as necessary.
39 | """
40 | S = compute_microbatch_split(x, K)
41 | lw = torch.cat([_iwae(model, _x, K) for _x in x.split(S)], 1) # concat on batch
42 | return log_mean_exp(lw).sum()
43 |
44 |
45 | def _dreg(model, x, K):
46 | """DREG estimate for log p_\theta(x) -- fully vectorised."""
47 | _, px_z, zs = model(x, K)
48 | lpz = model.pz(*model.pz_params).log_prob(zs).sum(-1)
49 | lpx_z = px_z.log_prob(x).view(*px_z.batch_shape[:2], -1) * model.llik_scaling
50 | qz_x = model.qz_x(*[p.detach() for p in model.qz_x_params]) # stop-grad for \phi
51 | lqz_x = qz_x.log_prob(zs).sum(-1)
52 | lw = lpz + lpx_z.sum(-1) - lqz_x
53 | return lw, zs
54 |
55 |
56 | def dreg(model, x, K, regs=None):
57 | """Computes a doubly-reparameterised importance-weighted ELBO estimate for log p_\theta(x)
58 | Iterates over the batch as necessary.
59 | """
60 | S = compute_microbatch_split(x, K)
61 | lw, zs = zip(*[_dreg(model, _x, K) for _x in x.split(S)])
62 | lw = torch.cat(lw, 1) # concat on batch
63 | zs = torch.cat(zs, 1) # concat on batch
64 | with torch.no_grad():
65 | grad_wt = (lw - torch.logsumexp(lw, 0, keepdim=True)).exp()
66 | if zs.requires_grad:
67 | zs.register_hook(lambda grad: grad_wt.unsqueeze(-1) * grad)
68 | return (grad_wt * lw).sum()
69 |
70 |
71 | # multi-modal variants
72 | def m_elbo_naive(model, x, K=1):
73 | """Computes E_{p(x)}[ELBO] for multi-modal vae --- NOT EXPOSED"""
74 | qz_xs, px_zs, zss = model(x)
75 | lpx_zs, klds = [], []
76 | for r, qz_x in enumerate(qz_xs):
77 | kld = kl_divergence(qz_x, model.pz(*model.pz_params))
78 | klds.append(kld.sum(-1))
79 | for d, px_z in enumerate(px_zs[r]):
80 | lpx_z = px_z.log_prob(x[d]) * model.vaes[d].llik_scaling
81 | lpx_zs.append(lpx_z.view(*px_z.batch_shape[:2], -1).sum(-1))
82 | obj = (1 / len(model.vaes)) * (torch.stack(lpx_zs).sum(0) - torch.stack(klds).sum(0))
83 | return obj.mean(0).sum()
84 |
85 |
86 | def m_elbo(model, x, K=1):
87 | """Computes importance-sampled m_elbo (in notes3) for multi-modal vae """
88 | qz_xs, px_zs, zss = model(x)
89 | lpx_zs, klds = [], []
90 | for r, qz_x in enumerate(qz_xs):
91 | kld = kl_divergence(qz_x, model.pz(*model.pz_params))
92 | klds.append(kld.sum(-1))
93 | for d in range(len(px_zs)):
94 | lpx_z = px_zs[d][d].log_prob(x[d]).view(*px_zs[d][d].batch_shape[:2], -1)
95 | lpx_z = (lpx_z * model.vaes[d].llik_scaling).sum(-1)
96 | if d == r:
97 | lwt = torch.tensor(0.0)
98 | else:
99 | zs = zss[d].detach()
100 | lwt = (qz_x.log_prob(zs) - qz_xs[d].log_prob(zs).detach()).sum(-1)
101 | lpx_zs.append(lwt.exp() * lpx_z)
102 | obj = (1 / len(model.vaes)) * (torch.stack(lpx_zs).sum(0) - torch.stack(klds).sum(0))
103 | return obj.mean(0).sum()
104 |
105 |
106 | def _m_iwae(model, x, K=1):
107 | """IWAE estimate for log p_\theta(x) for multi-modal vae -- fully vectorised"""
108 | qz_xs, px_zs, zss = model(x, K)
109 | lws = []
110 | for r, qz_x in enumerate(qz_xs):
111 | lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1)
112 | lqz_x = log_mean_exp(torch.stack([qz_x.log_prob(zss[r]).sum(-1) for qz_x in qz_xs]))
113 | lpx_z = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1)
114 | .mul(model.vaes[d].llik_scaling).sum(-1)
115 | for d, px_z in enumerate(px_zs[r])]
116 | lpx_z = torch.stack(lpx_z).sum(0)
117 | lw = lpz + lpx_z - lqz_x
118 | lws.append(lw)
119 | return torch.cat(lws) # (n_modality * n_samples) x batch_size
120 |
121 |
122 | def m_iwae(model, x, K=1):
123 | """Computes iwae estimate for log p_\theta(x) for multi-modal vae """
124 | S = compute_microbatch_split(x, K)
125 | x_split = zip(*[_x.split(S) for _x in x])
126 | lw = [_m_iwae(model, _x, K) for _x in x_split]
127 | lw = torch.cat(lw, 1) # concat on batch
128 | return log_mean_exp(lw).sum()
129 |
130 |
131 | def _m_iwae_looser(model, x, K=1):
132 | """IWAE estimate for log p_\theta(x) for multi-modal vae -- fully vectorised
133 | This version is the looser bound---with the average over modalities outside the log
134 | """
135 | qz_xs, px_zs, zss = model(x, K)
136 | lws = []
137 | for r, qz_x in enumerate(qz_xs):
138 | lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1)
139 | lqz_x = log_mean_exp(torch.stack([qz_x.log_prob(zss[r]).sum(-1) for qz_x in qz_xs]))
140 | lpx_z = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1)
141 | .mul(model.vaes[d].llik_scaling).sum(-1)
142 | for d, px_z in enumerate(px_zs[r])]
143 | lpx_z = torch.stack(lpx_z).sum(0)
144 | lw = lpz + lpx_z - lqz_x
145 | lws.append(lw)
146 | return torch.stack(lws) # n_modality x n_samples x batch_size
147 |
148 |
149 | def m_iwae_looser(model, x, K=1):
150 | """Computes iwae estimate for log p_\theta(x) for multi-modal vae
151 | This version is the looser bound---with the average over modalities outside the log
152 | """
153 | S = compute_microbatch_split(x, K)
154 | x_split = zip(*[_x.split(S) for _x in x])
155 | lw = [_m_iwae_looser(model, _x, K) for _x in x_split]
156 | lw = torch.cat(lw, 2) # concat on batch
157 | return log_mean_exp(lw, dim=1).mean(0).sum()
158 |
159 |
160 | def _m_dreg(model, x, K=1):
161 | """DERG estimate for log p_\theta(x) for multi-modal vae -- fully vectorised"""
162 | qz_xs, px_zs, zss = model(x, K)
163 | qz_xs_ = [vae.qz_x(*[p.detach() for p in vae.qz_x_params]) for vae in model.vaes]
164 | lws = []
165 | for r, vae in enumerate(model.vaes):
166 | lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1)
167 | lqz_x = log_mean_exp(torch.stack([qz_x_.log_prob(zss[r]).sum(-1) for qz_x_ in qz_xs_]))
168 | lpx_z = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1)
169 | .mul(model.vaes[d].llik_scaling).sum(-1)
170 | for d, px_z in enumerate(px_zs[r])]
171 | lpx_z = torch.stack(lpx_z).sum(0)
172 | lw = lpz + lpx_z - lqz_x
173 | lws.append(lw)
174 | return torch.cat(lws), torch.cat(zss)
175 |
176 |
177 | def m_dreg(model, x, K=1):
178 | """Computes dreg estimate for log p_\theta(x) for multi-modal vae """
179 | S = compute_microbatch_split(x, K)
180 | x_split = zip(*[_x.split(S) for _x in x])
181 | lw, zss = zip(*[_m_dreg(model, _x, K) for _x in x_split])
182 | lw = torch.cat(lw, 1) # concat on batch
183 | zss = torch.cat(zss, 1) # concat on batch
184 | with torch.no_grad():
185 | grad_wt = (lw - torch.logsumexp(lw, 0, keepdim=True)).exp()
186 | if zss.requires_grad:
187 | zss.register_hook(lambda grad: grad_wt.unsqueeze(-1) * grad)
188 | return (grad_wt * lw).sum()
189 |
190 |
191 | def _m_dreg_looser(model, x, K=1):
192 | """DERG estimate for log p_\theta(x) for multi-modal vae -- fully vectorised
193 | This version is the looser bound---with the average over modalities outside the log
194 | """
195 | qz_xs, px_zs, zss = model(x, K)
196 | qz_xs_ = [vae.qz_x(*[p.detach() for p in vae.qz_x_params]) for vae in model.vaes]
197 | lws = []
198 | for r, vae in enumerate(model.vaes):
199 | lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1)
200 | lqz_x = log_mean_exp(torch.stack([qz_x_.log_prob(zss[r]).sum(-1) for qz_x_ in qz_xs_]))
201 | lpx_z = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1)
202 | .mul(model.vaes[d].llik_scaling).sum(-1)
203 | for d, px_z in enumerate(px_zs[r])]
204 | lpx_z = torch.stack(lpx_z).sum(0)
205 | lw = lpz + lpx_z - lqz_x
206 | lws.append(lw)
207 | return torch.stack(lws), torch.stack(zss)
208 |
209 |
210 | def m_dreg_looser(model, x, K=1):
211 | """Computes dreg estimate for log p_\theta(x) for multi-modal vae
212 | This version is the looser bound---with the average over modalities outside the log
213 | """
214 | S = compute_microbatch_split(x, K)
215 | x_split = zip(*[_x.split(S) for _x in x])
216 | lw, zss = zip(*[_m_dreg_looser(model, _x, K) for _x in x_split])
217 | lw = torch.cat(lw, 2) # concat on batch
218 | zss = torch.cat(zss, 2) # concat on batch
219 | with torch.no_grad():
220 | grad_wt = (lw - torch.logsumexp(lw, 1, keepdim=True)).exp()
221 | if zss.requires_grad:
222 | zss.register_hook(lambda grad: grad_wt.unsqueeze(-1) * grad)
223 | return (grad_wt * lw).mean(0).sum()
224 |
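225 | # For reference, the single-modality estimators above compute (a sketch, with K samples
226 | # z_1..z_K ~ q(z|x) and importance weights w_k = p(z_k) p(x|z_k) / q(z_k|x)):
227 | #   iwae:  log-mean-exp over k of log w_k, summed over the batch,
228 | #   dreg:  the same bound, but with q's parameters detached and the per-sample gradients
229 | #          reweighted by the normalised weights (grad_wt) via the register_hook call,
230 | #          giving the doubly-reparameterised gradient estimator.
231 | # Typical use in a training loop (with a model exposing the interface used here):
232 | #   loss = -elbo(model, x, K=1)    # or -iwae(model, x, K=30), -dreg(model, x, K=30)
233 | #   loss.backward()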
--------------------------------------------------------------------------------
/src/report/analyse_cub.py:
--------------------------------------------------------------------------------
1 | """Calculate cross and joint coherence of language and image generation on CUB dataset using CCA."""
2 | import argparse
3 | import os
4 | import sys
5 |
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | # relative import hack (sorry)
10 | import inspect
11 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
12 | parentdir = os.path.dirname(currentdir)
13 | sys.path.insert(0, parentdir) # for system user
14 | os.chdir(parentdir) # for pycharm user
15 |
16 | import models
17 | from utils import Logger, Timer, unpack_data
18 | from helper import cca, fetch_emb, fetch_weights, fetch_pc, apply_weights, apply_pc
19 |
20 | # variables
21 | RESET = True
22 | USE_PCA = True
23 | maxSentLen = 32
24 | minOccur = 3
25 | lenEmbedding = 300
26 | lenWindow = 3
27 | fBase = 96
28 | vocab_dir = '../data/cub/oc:{}_sl:{}_s:{}_w:{}'.format(minOccur, maxSentLen, lenEmbedding, lenWindow)
29 | batch_size = 256
30 |
31 | # args
32 | torch.backends.cudnn.benchmark = True
33 | parser = argparse.ArgumentParser(description='Analysing MM-DGM results')
34 | parser.add_argument('--save-dir', type=str, default=".",
35 | metavar='N', help='save directory of results')
36 | parser.add_argument('--no-cuda', action='store_true', default=True,
37 | help='disables CUDA use')
38 | cmds = parser.parse_args()
39 | runPath = cmds.save_dir
40 | sys.stdout = Logger('{}/analyse.log'.format(runPath))
41 | args = torch.load(runPath + '/args.rar')
42 |
43 | # cuda stuff
44 | needs_conversion = cmds.no_cuda and args.cuda
45 | conversion_kwargs = {'map_location': lambda st, loc: st} if needs_conversion else {}
46 | args.cuda = not cmds.no_cuda and torch.cuda.is_available()
47 | device = torch.device("cuda" if args.cuda else "cpu")
48 | torch.manual_seed(args.seed)
49 |
50 | forward_args = {'drop_modality': True} if args.model == 'mcubISft' else {}
51 |
52 | # load trained model
53 | modelC = getattr(models, 'VAE_{}'.format(args.model))
54 | model = modelC(args)
55 | if args.cuda:
56 | model.cuda()
57 | model.load_state_dict(torch.load(runPath + '/model.rar', **conversion_kwargs), strict=False)
58 | train_loader, test_loader = model.getDataLoaders(batch_size, device=device)
59 | N = len(test_loader.dataset)
60 |
61 | # generate word embeddings and sentence weighting
62 | emb_path = os.path.join(vocab_dir, 'cub.emb')
63 | weights_path = os.path.join(vocab_dir, 'cub.weights')
64 | vocab_path = os.path.join(vocab_dir, 'cub.vocab')
65 | pc_path = os.path.join(vocab_dir, 'cub.pc')
66 |
67 | emb = fetch_emb(lenWindow, minOccur, emb_path, vocab_path, RESET)
68 | weights = fetch_weights(weights_path, vocab_path, RESET, a=1e-3)
69 | emb = torch.from_numpy(emb).to(device)
70 | weights = torch.from_numpy(weights).to(device).type(emb.dtype)
71 | u = fetch_pc(emb, weights, train_loader, pc_path, RESET)
72 |
73 | # set up word to sentence functions
74 | fn_to_emb = lambda data, emb=emb, weights=weights, u=u: \
75 | apply_pc(apply_weights(emb, weights, data), u)
76 |
77 |
78 | def calculate_corr(images, embeddings):
79 | global RESET
80 | if not os.path.exists(runPath + '/images_mean.pt') or RESET:
81 | generate_cca_projection()
82 | RESET = False
83 | im_mean = torch.load(runPath + '/images_mean.pt')
84 | emb_mean = torch.load(runPath + '/emb_mean.pt')
85 | im_proj = torch.load(runPath + '/im_proj.pt')
86 | emb_proj = torch.load(runPath + '/emb_proj.pt')
87 | with torch.no_grad():
88 | corr = F.cosine_similarity((images - im_mean) @ im_proj,
89 | (embeddings - emb_mean) @ emb_proj).mean()
90 | return corr
91 |
92 |
93 | def generate_cca_projection():
94 | images, sentences = [torch.cat(l) for l in zip(*[(d[0], d[1][0]) for d in train_loader])]
95 | emb = fn_to_emb(sentences.int())
96 | corr, (im_proj, emb_proj) = cca([images, emb], k=40)
97 | print("Largest eigen value from CCA: {:.3f}".format(corr[0]))
98 | torch.save(images.mean(dim=0), runPath + '/images_mean.pt')
99 | torch.save(emb.mean(dim=0), runPath + '/emb_mean.pt')
100 | torch.save(im_proj, runPath + '/im_proj.pt')
101 | torch.save(emb_proj, runPath + '/emb_proj.pt')
102 |
103 |
104 | def cross_coherence():
105 | model.eval()
106 | with torch.no_grad():
107 | i2t = []
108 | s2i = []
109 | gt = []
110 | for i, dataT in enumerate(test_loader):
111 | # get the inputs
112 | images, sentences = unpack_data(dataT, device=device)
113 | if images.shape[0] != batch_size:
114 | break
115 | _, px_zs, _ = model([images, sentences], K=1, **forward_args)
116 | cross_sentences = px_zs[0][1].mean.argmax(dim=-1).squeeze(0)
117 | cross_images = px_zs[1][0].mean.squeeze(0)
118 | # calculate correlation with CCA:
119 | i2t.append(calculate_corr(images, fn_to_emb(cross_sentences)))
120 | s2i.append(calculate_corr(cross_images, fn_to_emb(sentences.int())))
121 | gt.append(calculate_corr(images, fn_to_emb(sentences.int())))
122 | print("Coherence score: \nground truth {:10.9f}, \nimage to sentence {:10.9f}, "
123 | "\nsentence to image {:10.9f}".format(sum(gt) / len(gt),
124 | sum(i2t) / len(gt),
125 | sum(s2i) / len(gt)))
126 |
127 |
128 | def joint_coherence():
129 | model.eval()
130 | with torch.no_grad():
131 | pzs = model.pz(*model.pz_params).sample([1000])
132 | gen_images = model.vaes[0].dec(pzs)[0].squeeze(1)
133 | gen_sentences = model.vaes[1].dec(pzs)[0].argmax(dim=-1).squeeze(1)
134 | score = calculate_corr(gen_images, fn_to_emb(gen_sentences))
135 | print("joint generation {:10.9f}".format(score))
136 |
137 |
138 | if __name__ == '__main__':
139 | with Timer('MM-VAE analysis') as t:
140 | print('-' * 89)
141 | cross_coherence()
142 | print('-' * 89)
143 | joint_coherence()
144 |
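145 | # The coherence scores above are mean cosine similarities in a shared CCA space:
146 | # generate_cca_projection() fits the projections on training (image, sentence) pairs,
147 | # and calculate_corr() applies them to real or cross-generated pairs. Typical
148 | # invocation, assuming a finished run directory containing the args.rar and model.rar
149 | # files loaded above:
150 | #   python analyse_cub.py --save-dir <run-dir>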
--------------------------------------------------------------------------------
/src/report/analyse_ms.py:
--------------------------------------------------------------------------------
1 | """Calculate cross and joint coherence of trained model on MNIST-SVHN dataset.
2 | Train and evaluate a linear model for latent space digit classification."""
3 |
4 | import argparse
5 | import os
6 | import sys
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 |
12 | # relative import hacks (sorry)
13 | import inspect
14 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
15 | parentdir = os.path.dirname(currentdir)
16 | sys.path.insert(0, parentdir) # for bash user
17 | os.chdir(parentdir) # for pycharm user
18 |
19 | import models
20 | from helper import Latent_Classifier, SVHN_Classifier, MNIST_Classifier
21 | from utils import Logger, Timer
22 |
23 |
24 | torch.backends.cudnn.benchmark = True
25 | parser = argparse.ArgumentParser(description='Analysing MM-DGM results')
26 | parser.add_argument('--save-dir', type=str, default="",
27 | metavar='N', help='save directory of results')
28 | parser.add_argument('--no-cuda', action='store_true', default=False,
29 | help='disables CUDA use')
30 | cmds = parser.parse_args()
31 | runPath = cmds.save_dir
32 |
33 | sys.stdout = Logger('{}/ms_acc.log'.format(runPath))
34 | args = torch.load(runPath + '/args.rar')
35 |
36 | # cuda stuff
37 | needs_conversion = cmds.no_cuda and args.cuda
38 | conversion_kwargs = {'map_location': lambda st, loc: st} if needs_conversion else {}
39 | args.cuda = not cmds.no_cuda and torch.cuda.is_available()
40 | device = torch.device("cuda" if args.cuda else "cpu")
41 | torch.manual_seed(args.seed)
42 |
43 | modelC = getattr(models, 'VAE_{}'.format(args.model))
44 | model = modelC(args)
45 | if args.cuda:
46 | model.cuda()
47 |
48 | model.load_state_dict(torch.load(runPath + '/model.rar', **conversion_kwargs), strict=False)
49 | B = 256 # rough batch size heuristic
50 | train_loader, test_loader = model.getDataLoaders(B, device=device)
51 | N = len(test_loader.dataset)
52 |
53 |
54 | def classify_latents(epochs, option):
55 | model.eval()
56 | vae = unpack_model(option)
57 | if '_' not in args.model:
58 | epochs *= 10 # account for the fact that mnist-svhn has more examples (roughly 10x)
59 | classifier = Latent_Classifier(args.latent_dim, 10).to(device)
60 | criterion = nn.CrossEntropyLoss()
61 | optimizer = optim.Adam(classifier.parameters(), lr=0.001)
62 |
63 | for epoch in range(epochs): # loop over the dataset multiple times
64 | running_loss = 0.0
65 | total_iters = len(train_loader)
66 | print('\n====> Epoch: {:03d} '.format(epoch))
67 | for i, data in enumerate(train_loader):
68 | # get the inputs
69 | x, targets = unpack_data_mlp(data, option)
70 | x, targets = x.to(device), targets.to(device)
71 | with torch.no_grad():
72 | qz_x_params = vae.enc(x)
73 | zs = vae.qz_x(*qz_x_params).rsample()
74 | optimizer.zero_grad()
75 | outputs = classifier(zs)
76 | loss = criterion(outputs, targets)
77 | loss.backward()
78 | optimizer.step()
79 | # print statistics
80 | running_loss += loss.item()
81 | if (i + 1) % 1000 == 0:
82 | print('iteration {:04d}/{:d}: loss: {:6.3f}'.format(i + 1, total_iters, running_loss / 1000))
83 | running_loss = 0.0
84 | print('Finished Training, calculating test loss...')
85 |
86 | classifier.eval()
87 | total = 0
88 | correct = 0
89 | with torch.no_grad():
90 | for i, data in enumerate(test_loader):
91 | x, targets = unpack_data_mlp(data, option)
92 | x, targets = x.to(device), targets.to(device)
93 | qz_x_params = vae.enc(x)
94 | zs = vae.qz_x(*qz_x_params).rsample()
95 | outputs = classifier(zs)
96 | _, predicted = torch.max(outputs.data, 1)
97 | total += targets.size(0)
98 | correct += (predicted == targets).sum().item()
99 | print('The classifier correctly classified {} out of {} examples. Accuracy: '
100 | '{:.2f}%'.format(correct, total, correct / total * 100))
101 |
102 |
103 | def _maybe_train_or_load_digit_classifier_img(path, epochs):
104 |
105 | options = [o for o in ['mnist', 'svhn'] if not os.path.exists(path.format(o))]
106 |
107 | for option in options:
108 | print("Cannot find trained {} digit classifier in {}, training...".
109 | format(option, path.format(option)))
110 | classifier = globals()['{}_Classifier'.format(option.upper())]().to(device)
111 | criterion = nn.CrossEntropyLoss()
112 | optimizer = optim.Adam(classifier.parameters(), lr=0.001)
113 | for epoch in range(epochs): # loop over the dataset multiple times
114 | running_loss = 0.0
115 | total_iters = len(train_loader)
116 | print('\n====> Epoch: {:03d} '.format(epoch))
117 | for i, data in enumerate(train_loader):
118 | # get the inputs
119 | x, targets = unpack_data_mlp(data, option)
120 | x, targets = x.to(device), targets.to(device)
121 |
122 | optimizer.zero_grad()
123 | outputs = classifier(x)
124 | loss = criterion(outputs, targets)
125 | loss.backward()
126 | optimizer.step()
127 | # print statistics
128 | running_loss += loss.item()
129 | if (i + 1) % 1000 == 0:
130 | print('iteration {:04d}/{:d}: loss: {:6.3f}'.format(i + 1, total_iters, running_loss / 1000))
131 | running_loss = 0.0
132 | print('Finished Training, calculating test loss...')
133 |
134 | classifier.eval()
135 | total = 0
136 | correct = 0
137 | with torch.no_grad():
138 | for i, data in enumerate(test_loader):
139 | x, targets = unpack_data_mlp(data, option)
140 | x, targets = x.to(device), targets.to(device)
141 | outputs = classifier(x)
142 | _, predicted = torch.max(outputs.data, 1)
143 | total += targets.size(0)
144 | correct += (predicted == targets).sum().item()
145 | print('The classifier correctly classified {} out of {} examples. Accuracy: '
146 | '{:.2f}%'.format(correct, total, correct / total * 100))
147 |
148 | torch.save(classifier.state_dict(), path.format(option))
149 |
150 | mnist_net, svhn_net = MNIST_Classifier().to(device), SVHN_Classifier().to(device)
151 | mnist_net.load_state_dict(torch.load(path.format('mnist')))
152 | svhn_net.load_state_dict(torch.load(path.format('svhn')))
153 | return mnist_net, svhn_net
154 |
155 | def cross_coherence(epochs):
156 | model.eval()
157 |
158 | mnist_net, svhn_net = _maybe_train_or_load_digit_classifier_img("../data/{}_model.pt", epochs=epochs)
159 | mnist_net.eval()
160 | svhn_net.eval()
161 |
162 | total = 0
163 | corr_m = 0
164 | corr_s = 0
165 | with torch.no_grad():
166 | for i, data in enumerate(test_loader):
167 | mnist, svhn, targets = unpack_data_mlp(data, option='both')
168 | mnist, svhn, targets = mnist.to(device), svhn.to(device), targets.to(device)
169 | _, px_zs, _ = model([mnist, svhn], 1)
170 | mnist_mnist = mnist_net(px_zs[1][0].mean.squeeze(0))
171 | svhn_svhn = svhn_net(px_zs[0][1].mean.squeeze(0))
172 |
173 | _, pred_m = torch.max(mnist_mnist.data, 1)
174 | _, pred_s = torch.max(svhn_svhn.data, 1)
175 | total += targets.size(0)
176 | corr_m += (pred_m == targets).sum().item()
177 | corr_s += (pred_s == targets).sum().item()
178 |
179 | print('Cross coherence: \n SVHN -> MNIST {:.2f}% \n MNIST -> SVHN {:.2f}%'.format(
180 | corr_m / total * 100, corr_s / total * 100))
181 |
182 |
183 | def joint_coherence():
184 | model.eval()
185 | mnist_net, svhn_net = MNIST_Classifier().to(device), SVHN_Classifier().to(device)
186 | mnist_net.load_state_dict(torch.load('../data/mnist_model.pt'))
187 | svhn_net.load_state_dict(torch.load('../data/svhn_model.pt'))
188 |
189 | mnist_net.eval()
190 | svhn_net.eval()
191 |
192 | total = 0
193 | corr = 0
194 | with torch.no_grad():
195 | pzs = model.pz(*model.pz_params).sample([10000])
196 | mnist = model.vaes[0].dec(pzs)
197 | svhn = model.vaes[1].dec(pzs)
198 |
199 | mnist_mnist = mnist_net(mnist[0].squeeze(1))
200 | svhn_svhn = svhn_net(svhn[0].squeeze(1))
201 |
202 | _, pred_m = torch.max(mnist_mnist.data, 1)
203 | _, pred_s = torch.max(svhn_svhn.data, 1)
204 | total += pred_m.size(0)
205 | corr += (pred_m == pred_s).sum().item()
206 |
207 | print('Joint coherence: {:.2f}%'.format(corr / total * 100))
208 |
209 |
210 | def unpack_data_mlp(dataB, option='both'):
211 | if len(dataB[0]) == 2:
212 | if option == 'both':
213 | return dataB[0][0], dataB[1][0], dataB[1][1]
214 | elif option == 'svhn':
215 | return dataB[1][0], dataB[1][1]
216 | elif option == 'mnist':
217 | return dataB[0][0], dataB[0][1]
218 | else:
219 | return dataB
220 |
221 |
222 | def unpack_model(option='svhn'):
223 | if 'mnist_svhn' in args.model:
224 | return model.vaes[1] if option == 'svhn' else model.vaes[0]
225 | else:
226 | return model
227 |
228 |
229 | if __name__ == '__main__':
230 | with Timer('MM-VAE analysis') as t:
231 | print('-' * 25 + 'latent classification accuracy' + '-' * 25)
232 | print("Calculating latent classification accuracy for single MNIST VAE...")
233 | classify_latents(epochs=30, option='mnist')
234 | # #
235 | print("\n Calculating latent classification accuracy for single SVHN VAE...")
236 | classify_latents(epochs=30, option='svhn')
237 | #
238 | print('\n' + '-' * 45 + 'cross coherence' + '-' * 45)
239 | cross_coherence(epochs=30)
240 | #
241 | print('\n' + '-' * 45 + 'joint coherence' + '-' * 45)
242 | joint_coherence()
243 |
--------------------------------------------------------------------------------
/src/report/calculate_likelihoods.py:
--------------------------------------------------------------------------------
1 | """Calculate data marginal likelihood p(x) evaluated on the trained generative model."""
2 | import os
3 | import sys
4 | import argparse
5 |
6 | import numpy as np
7 | import torch
8 | from torchvision.utils import save_image
9 |
10 | # relative import hacks (sorry)
11 | import inspect
12 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
13 | parentdir = os.path.dirname(currentdir)
14 | sys.path.insert(0, parentdir) # for bash user
15 | os.chdir(parentdir) # for pycharm user
16 |
17 | import models
18 | from utils import Logger, Timer, unpack_data, log_mean_exp
19 |
20 | torch.backends.cudnn.benchmark = True
21 | parser = argparse.ArgumentParser(description='Analysing MM-DGM results')
22 | parser.add_argument('--save-dir', type=str, default="",
23 | metavar='N', help='save directory of results')
24 | parser.add_argument('--iwae-samples', type=int, default=1000, metavar='I',
25 | help='number of samples to estimate marginal log likelihood (default: 1000)')
26 | parser.add_argument('--no-cuda', action='store_true', default=False,
27 | help='disables CUDA use')
28 | cmds = parser.parse_args()
29 | runPath = cmds.save_dir
30 |
31 | sys.stdout = Logger('{}/llik.log'.format(runPath))
32 | args = torch.load(runPath + '/args.rar')
33 |
34 | # cuda stuff
35 | needs_conversion = cmds.no_cuda and args.cuda
36 | conversion_kwargs = {'map_location': lambda st, loc: st} if needs_conversion else {}
37 | args.cuda = not cmds.no_cuda and torch.cuda.is_available()
38 | device = torch.device("cuda" if args.cuda else "cpu")
39 | torch.manual_seed(args.seed)
40 |
41 | modelC = getattr(models, 'VAE_{}'.format(args.model))
42 | model = modelC(args)
43 | if args.cuda:
44 | model.cuda()
45 |
46 | model.load_state_dict(torch.load(runPath + '/model.rar', **conversion_kwargs), strict=False)
47 | B = 12000 // cmds.iwae_samples # rough batch size heuristic
48 | train_loader, test_loader = model.getDataLoaders(B, device=device)
49 | N = len(test_loader.dataset)
50 |
51 |
52 | def m_iwae(qz_xs, px_zs, zss, x):
53 | """IWAE estimate for log p_\theta(x) for multi-modal vae -- fully vectorised"""
54 | lws = []
55 | for r, qz_x in enumerate(qz_xs):
56 | lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1)
57 | lqz_x = log_mean_exp(torch.stack([qz_x.log_prob(zss[r]).sum(-1) for qz_x in qz_xs]))
58 | lpx_z = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1)
59 | .mul(model.vaes[d].llik_scaling).sum(-1)
60 | for d, px_z in enumerate(px_zs[r])]
61 | lpx_z = torch.stack(lpx_z).sum(0)
62 | lw = lpz + lpx_z - lqz_x
63 | lws.append(lw)
64 | return log_mean_exp(torch.cat(lws)).sum()
65 |
66 |
67 | def iwae(qz_x, px_z, zs, x):
68 | """IWAE estimate for log p_\theta(x) -- fully vectorised."""
69 | lpz = model.pz(*model.pz_params).log_prob(zs).sum(-1)
70 | lpx_z = px_z.log_prob(x).view(*px_z.batch_shape[:2], -1) * model.llik_scaling
71 | lqz_x = qz_x.log_prob(zs).sum(-1)
72 | return log_mean_exp(lpz + lpx_z.sum(-1) - lqz_x).sum()
73 |
74 |
75 | @torch.no_grad()
76 | def joint_elbo(K):
77 | model.eval()
78 | llik = 0
79 | obj = globals()[('m_' if hasattr(model, 'vaes') else '') + 'iwae']
80 | for dataT in test_loader:
81 | data = unpack_data(dataT, device=device)
82 | llik += obj(*model(data, K), data).item()
83 | print('Marginal Log Likelihood of joint {} (IWAE, K = {}): {:.4f}'
84 | .format(model.modelName, K, llik / N))
85 |
86 |
87 | def cross_iwaes(qz_xs, px_zs, zss, x):
88 | lws = []
89 | for e, _px_zs in enumerate(px_zs): # rows are encoders
90 | lpz = model.pz(*model.pz_params).log_prob(zss[e]).sum(-1)
91 | lqz_x = qz_xs[e].log_prob(zss[e]).sum(-1)
92 | _lpx_zs = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1).sum(-1)
93 | for d, px_z in enumerate(_px_zs)]
94 | lws.append([log_mean_exp(_lpx_z + lpz - lqz_x).sum() for _lpx_z in _lpx_zs])
95 | return lws
96 |
97 |
98 | def individual_iwaes(qz_xs, px_zs, zss, x):
99 | lws = []
100 | for d, _px_zs in enumerate(np.array(px_zs).T): # rows are decoders now
101 | lw = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1).sum(-1)
102 | + model.pz(*model.pz_params).log_prob(zss[e]).sum(-1)
103 | - log_mean_exp(torch.stack([qz_x.log_prob(zss[e]).sum(-1) for qz_x in qz_xs]))
104 | for e, px_z in enumerate(_px_zs)]
105 | lw = torch.cat(lw)
106 | lws.append(log_mean_exp(lw).sum())
107 | return lws
108 |
109 |
110 | @torch.no_grad()
111 | def m_llik_eval(K):
112 | model.eval()
113 | llik_joint = 0
114 | llik_synergy = np.array([0 for _ in model.vaes])
115 | lliks_cross = np.array([[0 for _ in model.vaes] for _ in model.vaes])
116 | for dataT in test_loader:
117 | data = unpack_data(dataT, device=device)
118 | qz_xs, px_zs, zss = model(data, K)
119 | objs = individual_iwaes(qz_xs, px_zs, zss, data)
120 | objs_cross = cross_iwaes(qz_xs, px_zs, zss, data)
121 | llik_joint += m_iwae(qz_xs, px_zs, zss, data)
122 | llik_synergy = llik_synergy + np.array(objs)
123 | lliks_cross = lliks_cross + np.array(objs_cross)
124 |
125 | print('Marginal Log Likelihood of joint {} (IWAE, K = {}): {:.4f}'
126 | .format(model.modelName, K, llik_joint / N))
127 | print('-' * 89)
128 |
129 | for i, llik in enumerate(llik_synergy):
130 | print('Marginal Log Likelihood of {} from {} (IWAE, K = {}): {:.4f}'
131 | .format(model.vaes[i].modelName, model.modelName, K, (llik / N).item()))
132 | print('-' * 89)
133 |
134 | for e, _lliks_cross in enumerate(lliks_cross):
135 | for d, llik_cross in enumerate(_lliks_cross):
136 | print('Marginal Log Likelihood of {} from {} (IWAE, K = {}): {:.4f}'
137 | .format(model.vaes[d].modelName, model.vaes[e].modelName, K, (llik_cross / N).item()))
138 | print('-' * 89)
139 |
140 |
141 | @torch.no_grad()
142 | def llik_eval(K):
143 | model.eval()
144 | llik_joint = 0
145 | for dataT in test_loader:
146 | data = unpack_data(dataT, device=device)
147 | qz_xs, px_zs, zss = model(data, K)
148 | llik_joint += iwae(qz_xs, px_zs, zss, data)
149 | print('Marginal Log Likelihood of joint {} (IWAE, K = {}): {:.4f}'
150 | .format(model.modelName, K, llik_joint / N))
151 |
152 |
153 | @torch.no_grad()
154 | def generate_sparse(D, steps, J):
155 | """generate `steps` perturbations for all `D` latent dimensions on `J` datapoints. """
156 | model.eval()
157 | for i, dataT in enumerate(test_loader):
158 | data = unpack_data(dataT, require_length=(args.projection == 'Sft'), device=device)
159 | qz_xs, _, zss = model(data, args.K)
160 | for i, (qz_x, zs) in enumerate(zip(qz_xs, zss)):
161 | embs = []
162 | # for delta in torch.linspace(0.01, 0.99, steps=steps):
163 | for delta in torch.linspace(-5, 5, steps=steps):
164 | for d in range(D):
165 | mod_emb = qz_x.mean + torch.zeros_like(qz_x.mean)
166 | mod_emb[:, d] += model.vaes[i].pz(*model.vaes[i].pz_params).stddev[:, d] * delta
167 | embs.append(mod_emb)
168 | embs = torch.stack(embs).transpose(0, 1).contiguous()
169 | for r in range(2):
170 | samples = model.vaes[r].px_z(*model.vaes[r].dec(embs.view(-1, D)[:((J) * steps * D)])).mean
171 | save_image(samples.cpu(), os.path.join(runPath, 'latent-traversals-{}x{}.png'.format(i, r)), nrow=D)
172 | break
173 |
174 |
175 | if __name__ == '__main__':
176 | with Timer('MM-VAE analysis') as t:
177 | # likelihood evaluation
178 | print('-' * 89)
179 | eval = locals()[('m_' if hasattr(model, 'vaes') else '') + 'llik_eval']
180 | eval(cmds.iwae_samples)
181 | print('-' * 89)
182 |
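183 | # The entry point dispatches on whether the loaded model is multi-modal: models with a
184 | # `vaes` attribute are scored with m_llik_eval / m_iwae, single-modality models with
185 | # llik_eval / iwae. The batch size B = 12000 // K keeps B * K roughly constant so the K
186 | # importance samples fit in memory. Typical invocation (hypothetical run directory):
187 | #   python calculate_likelihoods.py --save-dir <run-dir> --iwae-samples 1000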
--------------------------------------------------------------------------------
/src/report/helper.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 | from collections import Counter, OrderedDict
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from gensim.models import FastText
11 | from nltk.tokenize import sent_tokenize, word_tokenize
12 | from scipy.linalg import eig
13 | from skimage.filters import threshold_yen as threshold
14 |
15 |
16 | class OrderedCounter(Counter, OrderedDict):
17 | """Counter that remembers the order elements are first encountered."""
18 |
19 | def __repr__(self):
20 | return '%s(%r)' % (self.__class__.__name__, OrderedDict(self))
21 |
22 | def __reduce__(self):
23 | return self.__class__, (OrderedDict(self),)
24 |
25 |
26 | def cca(views, k=None, eps=1e-12):
27 | """Compute (multi-view) CCA
28 |
29 | Args:
30 | views (list): list of views where each view `v_i` is of size `N x o_i`
31 | k (int): joint projection dimension | if None, chosen automatically via Yen's threshold
32 | eps (float): regulariser [default: 1e-12]
33 |
34 | Returns:
35 | correlations: correlations along each of the k dimensions
36 | projections: projection matrices for each view
37 | """
38 | V = len(views) # number of views
39 | N = views[0].size(0) # number of observations (same across views)
40 | os = [v.size(1) for v in views]
41 | kmax = np.min(os)
42 | ocum = np.cumsum([0] + os)
43 | os_sum = sum(os)
44 | A, B = np.zeros([os_sum, os_sum]), np.zeros([os_sum, os_sum])
45 |
46 | for i in range(V):
47 | v_i = views[i]
48 | v_i_bar = v_i - v_i.mean(0).expand_as(v_i) # centered, N x o_i
49 | C_ij = (1.0 / (N - 1)) * torch.mm(v_i_bar.t(), v_i_bar)
50 | # A[ocum[i]:ocum[i + 1], ocum[i]:ocum[i + 1]] = C_ij
51 | B[ocum[i]:ocum[i + 1], ocum[i]:ocum[i + 1]] = C_ij
52 | for j in range(i + 1, V):
53 | v_j = views[j] # N x o_j
54 | v_j_bar = v_j - v_j.mean(0).expand_as(v_j) # centered
55 | C_ij = (1.0 / (N - 1)) * torch.mm(v_i_bar.t(), v_j_bar)
56 | A[ocum[i]:ocum[i + 1], ocum[j]:ocum[j + 1]] = C_ij
57 | A[ocum[j]:ocum[j + 1], ocum[i]:ocum[i + 1]] = C_ij.t()
58 |
59 | A[np.diag_indices_from(A)] += eps
60 | B[np.diag_indices_from(B)] += eps
61 |
62 | eigenvalues, eigenvectors = eig(A, B)
63 | # TODO: sanity check that all eigenvalues are (numerically) real, i.e. of the form a + 0i
64 | idx = eigenvalues.argsort()[::-1] # sort descending
65 | eigenvalues = eigenvalues[idx] # arrange in descending order
66 |
67 | if k is None:
68 | t = threshold(eigenvalues.real[:kmax])
69 | k = np.abs(np.asarray(eigenvalues.real[0::10]) - t).argmin() * 10 # closest k % 10 == 0 idx
70 | print('k unspecified, (auto-)choosing:', k)
71 |
72 | eigenvalues = eigenvalues[idx[:k]]
73 | eigenvectors = eigenvectors[:, idx[:k]]
74 |
75 | correlations = torch.from_numpy(eigenvalues.real).type_as(views[0])
76 | proj_matrices = torch.split(torch.from_numpy(eigenvectors.real).type_as(views[0]), os)
77 |
78 | return correlations, proj_matrices
79 |
80 |
81 | def fetch_emb(lenWindow, minOccur, emb_path, vocab_path, RESET):
82 | if not os.path.exists(emb_path) or RESET:
83 | with open('../data/cub/text_trainvalclasses.txt', 'r') as file:
84 | text = file.read()
85 | sentences = sent_tokenize(text)
86 |
87 | texts = []
88 | for i, line in enumerate(sentences):
89 | words = word_tokenize(line)
90 | texts.append(words)
91 |
92 | model = FastText(size=300, window=lenWindow, min_count=minOccur)
93 | model.build_vocab(sentences=texts)
94 | model.train(sentences=texts, total_examples=len(texts), epochs=10)
95 |
96 | with open(vocab_path, 'rb') as file:
97 | vocab = json.load(file)
98 |
99 | i2w = vocab['i2w']
100 | base = np.ones((300,), dtype=np.float32)
101 | emb = [base * (i - 1) for i in range(3)]
102 | for word in list(i2w.values())[3:]:
103 | emb.append(model[word])
104 |
105 | emb = np.array(emb)
106 | with open(emb_path, 'wb') as file:
107 | pickle.dump(emb, file)
108 |
109 | else:
110 | with open(emb_path, 'rb') as file:
111 | emb = pickle.load(file)
112 |
113 | return emb
114 |
115 |
116 | def fetch_weights(weights_path, vocab_path, RESET, a=1e-3):
117 | if not os.path.exists(weights_path) or RESET:
118 | with open('../data/cub/text_trainvalclasses.txt', 'r') as file:
119 | text = file.read()
120 | sentences = sent_tokenize(text)
121 | occ_register = OrderedCounter()
122 |
123 | for i, line in enumerate(sentences):
124 | words = word_tokenize(line)
125 | occ_register.update(words)
126 |
127 | with open(vocab_path, 'r') as file:
128 | vocab = json.load(file)
129 | w2i = vocab['w2i']
130 | weights = np.zeros(len(w2i))
131 | total_occ = sum(list(occ_register.values()))
132 | exc_occ = 0
133 | for w, occ in occ_register.items():
134 | if w in w2i.keys():
135 | weights[w2i[w]] = a / (a + occ / total_occ)
136 | else:
137 | exc_occ += occ
138 | weights[0] = a / (a + exc_occ / total_occ)
139 |
140 | with open(weights_path, 'wb') as file:
141 | pickle.dump(weights, file)
142 | else:
143 | with open(weights_path, 'rb') as file:
144 | weights = pickle.load(file)
145 |
146 | return weights
147 |
148 |
149 | def fetch_pc(emb, weights, train_loader, pc_path, RESET):
150 | sentences = torch.cat([d[1][0] for d in train_loader]).int()
151 | emb_dataset = apply_weights(emb, weights, sentences)
152 |
153 | if not os.path.exists(pc_path) or RESET:
154 | _, _, V = torch.svd(emb_dataset - emb_dataset.mean(dim=0), some=True)
155 | v = V[:, 0].unsqueeze(-1)
156 | u = v.mm(v.t())
157 | with open(pc_path, 'wb') as file:
158 | pickle.dump(u, file)
159 | else:
160 | with open(pc_path, 'rb') as file:
161 | u = pickle.load(file)
162 | return u
163 |
164 |
165 | def apply_weights(emb, weights, data):
166 | fn_trun = lambda s: s[:np.where(s == 2)[0][0] + 1] if 2 in s else s
167 | batch_emb = []
168 | for sent_i in data:
169 | emb_stacked = torch.stack([emb[idx] for idx in fn_trun(sent_i)])
170 | weights_stacked = torch.stack([weights[idx] for idx in fn_trun(sent_i)])
171 | batch_emb.append(torch.sum(emb_stacked * weights_stacked.unsqueeze(-1), dim=0) / emb_stacked.shape[0])
172 |
173 | return torch.stack(batch_emb, dim=0)
174 |
175 |
176 | def apply_pc(weighted_emb, u):
177 | return torch.cat([e - torch.matmul(u, e.unsqueeze(-1)).squeeze() for e in weighted_emb.split(2048, 0)])
178 |
179 |
180 | class Latent_Classifier(nn.Module):
181 | """ Generate latent parameters for SVHN image data. """
182 |
183 | def __init__(self, in_n, out_n):
184 | super(Latent_Classifier, self).__init__()
185 | self.mlp = nn.Linear(in_n, out_n)
186 |
187 | def forward(self, x):
188 | return self.mlp(x)
189 |
190 |
191 | class SVHN_Classifier(nn.Module):
192 | def __init__(self):
193 | super(SVHN_Classifier, self).__init__()
194 | self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
195 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
196 | self.conv2_drop = nn.Dropout2d()
197 | self.fc1 = nn.Linear(500, 50)
198 | self.fc2 = nn.Linear(50, 10)
199 |
200 | def forward(self, x):
201 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
202 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
203 | x = x.view(-1, 500)
204 | x = F.relu(self.fc1(x))
205 | x = F.dropout(x, training=self.training)
206 | x = self.fc2(x)
207 | return F.log_softmax(x, dim=-1)
208 |
209 |
210 | class MNIST_Classifier(nn.Module):
211 | def __init__(self):
212 | super(MNIST_Classifier, self).__init__()
213 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
214 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
215 | self.conv2_drop = nn.Dropout2d()
216 | self.fc1 = nn.Linear(320, 50)
217 | self.fc2 = nn.Linear(50, 10)
218 |
219 | def forward(self, x):
220 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
221 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
222 | x = x.view(-1, 320)
223 | x = F.relu(self.fc1(x))
224 | x = F.dropout(x, training=self.training)
225 | x = self.fc2(x)
226 | return F.log_softmax(x, dim=-1)
227 |
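228 | # Usage sketch for cca() above, mirroring calculate_corr in analyse_cub.py
229 | # (A and B are hypothetical views of shape N x o_a and N x o_b):
230 | #   corr, (proj_a, proj_b) = cca([A, B], k=40)
231 | #   sim = F.cosine_similarity((A - A.mean(0)) @ proj_a,
232 | #                             (B - B.mean(0)) @ proj_b).mean()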
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import shutil
4 | import sys
5 | import time
6 |
7 | import torch
8 | import torch.distributions as dist
9 | import torch.nn.functional as F
10 |
11 | from datasets import CUBImageFt
12 |
13 |
14 | # Classes
15 | class Constants(object):
16 | eta = 1e-6
17 | log2 = math.log(2)
18 | log2pi = math.log(2 * math.pi)
19 | logceilc = 88 # largest cuda v s.t. exp(v) < inf
20 | logfloorc = -104 # smallest cuda v s.t. exp(v) > 0
21 |
22 |
23 | # https://stackoverflow.com/questions/14906764/how-to-redirect-stdout-to-both-file-and-console-with-scripting
24 | class Logger(object):
25 | def __init__(self, filename, mode="a"):
26 | self.terminal = sys.stdout
27 | self.log = open(filename, mode)
28 |
29 | def write(self, message):
30 | self.terminal.write(message)
31 | self.log.write(message)
32 |
33 | def flush(self):
34 | # this flush method is needed for python 3 compatibility.
35 | # this handles the flush command by doing nothing.
36 | # you might want to specify some extra behavior here.
37 | pass
38 |
39 |
40 | class Timer:
41 | def __init__(self, name):
42 | self.name = name
43 |
44 | def __enter__(self):
45 | self.begin = time.time()
46 | return self
47 |
48 | def __exit__(self, *args):
49 | self.end = time.time()
50 | self.elapsed = self.end - self.begin
51 | self.elapsedH = time.gmtime(self.elapsed)
52 | print('====> [{}] Time: {:7.3f}s or {}'
53 | .format(self.name,
54 | self.elapsed,
55 | time.strftime("%H:%M:%S", self.elapsedH)))
56 |
57 |
58 | # Functions
59 | def save_vars(vs, filepath):
60 | """
61 | Saves variables to the given filepath in a safe manner.
62 | """
63 | if os.path.exists(filepath):
64 | shutil.copyfile(filepath, '{}.old'.format(filepath))
65 | torch.save(vs, filepath)
66 |
67 |
68 | def save_model(model, filepath):
69 | """
70 | To load a saved model, simply use
71 | `model.load_state_dict(torch.load('path-to-saved-model'))`.
72 | """
73 | save_vars(model.state_dict(), filepath)
74 | if hasattr(model, 'vaes'):
75 | for vae in model.vaes:
76 | fdir, fext = os.path.splitext(filepath)
77 | save_vars(vae.state_dict(), fdir + '_' + vae.modelName + fext)
78 |
79 |
80 | def is_multidata(dataB):
81 | return isinstance(dataB, list) or isinstance(dataB, tuple)
82 |
83 |
84 | def unpack_data(dataB, device='cuda'):
85 | # dataB :: (Tensor, Idx) | [(Tensor, Idx)]
86 | """ Unpacks the data batch object in an appropriate manner to extract data """
87 | if is_multidata(dataB):
88 | if torch.is_tensor(dataB[0]):
89 | if torch.is_tensor(dataB[1]):
90 | return dataB[0].to(device) # mnist, svhn, cubI
91 | elif is_multidata(dataB[1]):
92 | return dataB[0].to(device), dataB[1][0].to(device) # cubISft
93 | else:
94 | raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[1])))
95 |
96 | elif is_multidata(dataB[0]):
97 | return [d.to(device) for d in list(zip(*dataB))[0]] # mnist-svhn, cubIS
98 | else:
99 | raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[0])))
100 | elif torch.is_tensor(dataB):
101 | return dataB.to(device)
102 | else:
103 | raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB)))
104 |
105 |
106 | def get_mean(d, K=100):
107 | """
108 | Extract the `mean` parameter for given distribution.
109 | If attribute not available, estimate from samples.
110 | """
111 | try:
112 | mean = d.mean
113 | except NotImplementedError:
114 | samples = d.rsample(torch.Size([K]))
115 | mean = samples.mean(0)
116 | return mean
117 |
118 |
119 | def log_mean_exp(value, dim=0, keepdim=False):
120 | return torch.logsumexp(value, dim, keepdim=keepdim) - math.log(value.size(dim))
121 |
122 |
123 | def kl_divergence(d1, d2, K=100):
124 | """Computes closed-form KL if available, else computes a MC estimate."""
125 | if (type(d1), type(d2)) in torch.distributions.kl._KL_REGISTRY:
126 | return torch.distributions.kl_divergence(d1, d2)
127 | else:
128 | samples = d1.rsample(torch.Size([K]))
129 | return (d1.log_prob(samples) - d2.log_prob(samples)).mean(0)
130 |
131 |
132 | def pdist(sample_1, sample_2, eps=1e-5):
133 |     """Compute the matrix of all pairwise Euclidean distances. Code
134 |     adapted from the torch-two-sample library (added batching).
135 |     You can find the original implementation of this function here:
136 |     https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/util.py
137 | 
138 |     Arguments
139 |     ---------
140 |     sample_1 : torch.Tensor or Variable
141 |         The first sample, of shape ``(n_1, d)`` or ``(batch_size, n_1, d)``.
142 |     sample_2 : torch.Tensor or Variable
143 |         The second sample, of shape ``(n_2, d)`` or ``(batch_size, n_2, d)``.
144 |     eps : float
145 |         Small constant added inside the square root for numerical stability.
146 | 
147 |     Returns
148 |     -------
149 |     torch.Tensor or Variable
150 |         Matrix of shape ``(batch_size, n_1, n_2)``, squeezed if the inputs were
151 |         unbatched. The [b, i, j]-th entry is the Euclidean distance
152 |         ``|| sample_1[b, i, :] - sample_2[b, j, :] ||_2``.
153 |     """
154 | if len(sample_1.shape) == 2:
155 | sample_1, sample_2 = sample_1.unsqueeze(0), sample_2.unsqueeze(0)
156 | B, n_1, n_2 = sample_1.size(0), sample_1.size(1), sample_2.size(1)
157 | norms_1 = torch.sum(sample_1 ** 2, dim=-1, keepdim=True)
158 | norms_2 = torch.sum(sample_2 ** 2, dim=-1, keepdim=True)
159 | norms = (norms_1.expand(B, n_1, n_2)
160 | + norms_2.transpose(1, 2).expand(B, n_1, n_2))
161 | distances_squared = norms - 2 * sample_1.matmul(sample_2.transpose(1, 2))
162 |     return torch.sqrt(eps + torch.abs(distances_squared)).squeeze()  # B x n_1 x n_2, squeezed if unbatched
163 |
164 |
165 | def NN_lookup(emb_h, emb, data):
166 | indices = pdist(emb.to(emb_h.device), emb_h).argmin(dim=0)
167 | # indices = torch.tensor(cosine_similarity(emb, emb_h.cpu().numpy()).argmax(0)).to(emb_h.device).squeeze()
168 | return data[indices]
169 |
170 |
171 | class FakeCategorical(dist.Distribution):
172 | support = dist.constraints.real
173 | has_rsample = True
174 |
175 | def __init__(self, locs):
176 | self.logits = locs
177 | self._batch_shape = self.logits.shape
178 |
179 | @property
180 | def mean(self):
181 | return self.logits
182 |
183 | def sample(self, sample_shape=torch.Size()):
184 | with torch.no_grad():
185 | return self.rsample(sample_shape)
186 |
187 | def rsample(self, sample_shape=torch.Size()):
188 | return self.logits.expand([*sample_shape, *self.logits.shape]).contiguous()
189 |
190 | def log_prob(self, value):
191 | # value of shape (K, B, D)
192 | lpx_z = -F.cross_entropy(input=self.logits.view(-1, self.logits.size(-1)),
193 | target=value.expand(self.logits.size()[:-1]).long().view(-1),
194 | reduction='none',
195 | ignore_index=0)
196 |
197 | return lpx_z.view(*self.logits.shape[:-1])
198 | # the cross-entropy loss ($\sum_i -gt_i \log(p_i)$, with most gt_i = 0) necessarily
199 | # sums over the word-embedding (vocabulary) dimension; we adopt the operationally
200 | # equivalent formulation here, which sums over the sentence dimension in the
201 | # objective instead.
202 |
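# A shape sketch with hypothetical sizes: for logits of shape (B, L, V)
# (batch, sentence length, vocabulary) and integer targets of shape (B, L),
# log_prob returns per-token log-likelihoods of shape (B, L); summing over the
# sentence dimension L in the objective then gives the per-sentence log-likelihood.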
--------------------------------------------------------------------------------
/src/vis.py:
--------------------------------------------------------------------------------
1 | # visualisation related functions
2 |
3 | import matplotlib.colors as colors
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import pandas as pd
7 | import seaborn as sns
8 | import torch
9 | from matplotlib.lines import Line2D
10 | from umap import UMAP
11 |
12 |
13 | def custom_cmap(n):
14 |     """Create a customised colormap for a latent scatter plot with n categories.
15 |     Returns the colormap object and a colormap array containing the RGB values of the colours.
16 | See official matplotlib document for colormap reference:
17 | https://matplotlib.org/examples/color/colormaps_reference.html
18 | """
19 |     # first colour is grey from Set1; the rest come from a sensible categorical (husl) palette
20 | cmap_array = sns.color_palette("Set1", 9)[-1:] + sns.husl_palette(n - 1, h=.6, s=0.7)
21 | cmap = colors.LinearSegmentedColormap.from_list('mmdgm_cmap', cmap_array)
22 | return cmap, cmap_array
23 |
24 |
25 | def embed_umap(data):
26 |     """data should be a numpy array on the CPU"""
27 | embedding = UMAP(metric='euclidean',
28 | n_neighbors=40,
29 | # angular_rp_forest=True,
30 | # random_state=torch.initial_seed(),
31 | transform_seed=torch.initial_seed())
32 | return embedding.fit_transform(data)
33 |
34 |
35 | def plot_embeddings(emb, emb_l, labels, filepath):
36 | cmap_obj, cmap_arr = custom_cmap(n=len(labels))
37 | plt.figure()
38 | plt.scatter(emb[:, 0], emb[:, 1], c=emb_l, cmap=cmap_obj, s=25, alpha=0.2, edgecolors='none')
39 | l_elems = [Line2D([0], [0], marker='o', color=cm, label=l, alpha=0.5, linestyle='None')
40 | for (cm, l) in zip(cmap_arr, labels)]
41 | plt.legend(frameon=False, loc=2, handles=l_elems)
42 | plt.savefig(filepath, bbox_inches='tight')
43 | plt.close()
44 |
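# A minimal usage sketch, assuming `zss` is an (N, latent_dim) tensor of latent
# samples, `zsl` holds an integer label per row, and `runPath` is the experiment
# directory (label names below are purely illustrative):
#
#     emb = embed_umap(zss.cpu().numpy())
#     plot_embeddings(emb, zsl.cpu().numpy(), ['prior', 'modality 1', 'modality 2'],
#                     '{}/emb_umap.png'.format(runPath))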
45 |
46 | def tensor_to_df(tensor, ax_names=None):
47 | assert tensor.ndim == 2, "Can only currently convert 2D tensors to dataframes"
48 | df = pd.DataFrame(data=tensor, columns=np.arange(tensor.shape[1]))
49 | return df.melt(value_vars=df.columns,
50 | var_name=('variable' if ax_names is None else ax_names[0]),
51 | value_name=('value' if ax_names is None else ax_names[1]))
52 |
53 |
54 | def tensors_to_df(tensors, head=None, keys=None, ax_names=None):
55 | dfs = [tensor_to_df(tensor, ax_names=ax_names) for tensor in tensors]
56 | df = pd.concat(dfs, keys=(np.arange(len(tensors)) if keys is None else keys))
57 | df.reset_index(level=0, inplace=True)
58 | if head is not None:
59 | df.rename(columns={'level_0': head}, inplace=True)
60 | return df
61 |
62 |
63 | def plot_kls_df(df, filepath):
64 | _, cmap_arr = custom_cmap(df[df.columns[0]].nunique() + 1)
65 | with sns.plotting_context("notebook", font_scale=2.0):
66 | g = sns.FacetGrid(df, height=12, aspect=2)
67 | g = g.map(sns.boxplot, df.columns[1], df.columns[2], df.columns[0], palette=cmap_arr[1:],
68 | order=None, hue_order=None)
69 | g = g.set(yscale='log').despine(offset=10)
70 | plt.legend(loc='best', fontsize='22')
71 | plt.savefig(filepath, bbox_inches='tight')
72 | plt.close()
73 |
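# A minimal usage sketch, assuming `kls_x` and `kls_y` are 2D arrays of
# per-dimension KL values for two modalities and `runPath` is the experiment
# directory:
#
#     df = tensors_to_df([kls_x, kls_y], head='modality', keys=['x', 'y'],
#                        ax_names=['Dimensions', 'KL'])
#     plot_kls_df(df, '{}/kl_distance.png'.format(runPath))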
--------------------------------------------------------------------------------