├── .dir-locals.el ├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── _imgs ├── cub.png ├── iwae.png ├── mnist-svhn.png ├── obj.png ├── schematic.png └── simple.png ├── bin └── make-mnist-svhn-idx.py ├── data └── cub │ ├── text_testclasses.txt │ └── text_trainvalclasses.txt ├── requirements.txt └── src ├── datasets.py ├── main.py ├── models ├── __init__.py ├── mmvae.py ├── mmvae_cub_images_sentences.py ├── mmvae_cub_images_sentences_ft.py ├── mmvae_mnist_svhn.py ├── vae.py ├── vae_cub_image.py ├── vae_cub_image_ft.py ├── vae_cub_sent.py ├── vae_cub_sent_ft.py ├── vae_mnist.py └── vae_svhn.py ├── objectives.py ├── report ├── analyse_cub.py ├── analyse_ms.py ├── calculate_likelihoods.py └── helper.py ├── utils.py └── vis.py /.dir-locals.el: -------------------------------------------------------------------------------- 1 | ((lua-mode . ((lua-indent-level . 2))) 2 | (python-mode . ((tab-width . 2)))) 3 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | # https://github.com/pytorch/pytorch/blob/d0db624e02951c4dd6eb6b21d051f7ccf8133707/setup.cfg 2 | [flake8] 3 | max-line-length = 120 4 | ignore = E302,E305,E402,E721,E731,F401,F403,F405,F811,F812,F821,F841,W503 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .* 2 | **/*~ 3 | **/_* 4 | **/auto 5 | **/*.aux 6 | **/*.bbl 7 | **/*.blg 8 | **/*.log 9 | **/*.pdf 10 | **/*.out 11 | **/*.old 12 | **/*.run.xml 13 | **/images/ 14 | *.pyc 15 | **/__pycache__ 16 | !__init__.py 17 | !_imgs 18 | 19 | data/ 20 | experiments/**/ 21 | /.bash_history 22 | 23 | bin/*.sh 24 | bin/*.png 25 | bin/face_extract_vgg/ 26 | 27 | doc/ 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 
28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 
91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 
150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 
216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. 
You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. 
Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. 
If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. 
If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 
633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Multimodal Mixture-of-Experts VAE 4 | This repository contains the code for the framework in **Variational Mixture-of-Experts Autoencodersfor Multi-Modal Deep Generative Models** (see [paper](https://arxiv.org/pdf/1911.03393.pdf)). 5 | 6 | ## Requirements 7 | List of packages we used and the version we tested the model on (see also `requirements.txt`) 8 | 9 | ``` 10 | python == 3.6.8 11 | gensim == 3.8.1 12 | matplotlib == 3.1.1 13 | nltk == 3.4.5 14 | numpy == 1.16.4 15 | pandas == 0.25.3 16 | scipy == 1.3.2 17 | seaborn == 0.9.0 18 | scikit-image == 0.15.0 19 | torch == 1.3.1 20 | torchnet == 0.0.4 21 | torchvision == 0.4.2 22 | umap-learn == 0.1.1 23 | ``` 24 | 25 | ## Downloads 26 | ### MNIST-SVHN Dataset 27 | 28 |
![MNIST-SVHN](_imgs/mnist-svhn.png)
29 | 30 | We construct a dataset of pairs of MNIST and SVHN such that each pair depicts the same digit class. Each instance of a digit class in either dataset is randomly paired with 20 instances of the same digit class from the other dataset. 31 | 32 | **Usage**: To prepare this dataset, run `bin/make-mnist-svhn-idx.py` (from inside `bin/`, since the script uses paths relative to that directory) -- this should automatically handle the download and pairing; a sketch of how the resulting index files can be consumed is given at the end of this README. 33 | 34 | ### CUB Image-Caption 35 | 36 |
![CUB Image-Caption](_imgs/cub.png)
37 | 38 | We use the Caltech-UCSD Birds (CUB) dataset, with the bird images and their captions serving as two modalities. 39 | 40 | **Usage**: We offer a cleaned-up version of the CUB dataset. Download the dataset [here](http://www.robots.ox.ac.uk/~yshi/mmdgm/datasets/cub.zip). First, create a `data` folder under the project directory; then unzip the downloaded content into `data`. After finishing these steps, the structure of the `data/cub` folder should look like: 41 | 42 | ``` 43 | data/cub 44 | │───text_testclasses.txt 45 | │───text_trainvalclasses.txt 46 | │───train 47 | │ │───002.Laysan_Albatross 48 | │ │ └───...jpg 49 | │ │───003.Sooty_Albatross 50 | │ │ └───...jpg 51 | │ │───... 52 | │ └───200.Common_Yellowthroat 53 | │ └───...jpg 54 | └───test 55 | │───001.Black_footed_Albatross 56 | │ └───...jpg 57 | │───004.Groove_billed_Ani 58 | │ └───...jpg 59 | │───... 60 | └───197.Marsh_Wren 61 | └───...jpg 62 | ``` 63 | 64 | 65 | ### Pretrained network 66 | Pretrained models are also available if you want to play around with them. Download from the following links: 67 | - [MNIST-SVHN](http://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/mnist-svhn.zip) 68 | - [CUB Image-Caption (feature)](http://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/cubISft.zip) 69 | - [CUB Image-Caption (raw images)](http://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/cubIS.zip) 70 | 71 | ## Usage 72 | 73 | ### Training 74 | 75 | Make sure the [requirements](#requirements) are satisfied in your environment, and the relevant [datasets](#downloads) are downloaded. `cd` into `src`, and, for MNIST-SVHN experiments, run 76 | 77 | ```bash 78 | python main.py --model mnist_svhn 79 | 80 | ``` 81 | 82 | For CUB Image-Caption with image feature search (see Figure 7 in our [paper](https://arxiv.org/pdf/1911.03393.pdf)), run 83 | ```bash 84 | python main.py --model cubISft 85 | 86 | ``` 87 | 88 | For CUB Image-Caption with raw image generation, run 89 | ```bash 90 | python main.py --model cubIS 91 | 92 | ``` 93 | 94 | You can also play with the hyperparameters using arguments. Some of the more interesting ones are listed as follows: 95 | - **`--obj`**: Objective function; offers three choices: importance-sampled ELBO (`elbo`), IWAE (`iwae`), and DReG (`dreg`, used in the paper). Including the `--looser` flag when using IWAE or DReG removes the unbalanced weighting of modalities, which we find to perform better empirically; 96 | - **`--K`**: Number of particles; controls the number of particles `K` in the IWAE/DReG estimator, as specified in the following equation: 97 | 98 |
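For reference, the `K`-particle importance-weighted (IWAE-style) bound controlled by `--K` has the general form sketched below in LaTeX; the exact multimodal objective used here, which averages a bound of this form over the individual modality encoders of the mixture-of-experts posterior, is given in the paper.

```latex
\mathcal{L}_{K}(x) \;=\;
\mathbb{E}_{z^{1},\dots,z^{K} \sim q_{\Phi}(z \mid x)}
\left[ \log \frac{1}{K} \sum_{k=1}^{K}
      \frac{p_{\theta}(x, z^{k})}{q_{\Phi}(z^{k} \mid x)} \right]
```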

99 | 100 | - **`--learn-prior`**: Prior variance learning; controls whether the prior variance is learnt. Results in our paper are produced with this enabled; excluding this flag disables the option; 101 | - **`--llik_scaling`**: Likelihood scaling; specifies the likelihood scaling of one of the two modalities, so that the likelihoods of the two modalities contribute similarly to the lower bound. The default values are: 102 | - _MNIST-SVHN_: MNIST scaling factor (32*32*3)/(28*28*1) ≈ 3.92 103 | - _CUB Image-Caption_: Image scaling factor 32/(64*64*3) ≈ 0.0026 104 | - **`--latent-dim`**: Latent dimensionality 105 | 106 | You can also load from pre-trained models by specifying the path to the model folder, for example `python main.py --model mnist_svhn --pre-trained path/to/model/folder/`. See the following for the flags we used for these pretrained models: 107 | - **MNIST-SVHN**: `--model mnist_svhn --obj dreg --K 30 --learn-prior --looser --epochs 30 --batch-size 128 --latent-dim 20` 108 | - **CUB Image-Caption (feature)**: `--model cubISft --learn-prior --K 50 --obj dreg --looser --epochs 50 --batch-size 64 --latent-dim 64 --llik_scaling 0.002` 109 | - **CUB Image-Caption (raw images)**: `--model cubIS --learn-prior --K 50 --obj dreg --looser --epochs 50 --batch-size 64 --latent-dim 64` 110 | 111 | ### Analysing 112 | We offer tools to reproduce the quantitative results in our paper in `src/report`. To run any of the provided scripts, `cd` into `src`, and 113 | 114 | - for likelihood estimation of data using a trained model, run `python calculate_likelihoods.py --save-dir path/to/trained/model/folder/ --iwae-samples 1000`; 115 | - for coherence analysis and latent digit classification accuracy on the MNIST-SVHN dataset, run `python analyse_ms.py --save-dir path/to/trained/model/folder/`; 116 | - for coherence analysis on the CUB image-caption dataset, run `python analyse_cub.py --save-dir path/to/trained/model/folder/`. 117 | - _**Note**_: The learnt CCA projection matrix and FastText embeddings can vary quite a bit due to the limited dataset size; therefore, re-computing them as part of the analyses can result in different numeric values, including for the baseline. **The relative performance of our model against the baseline remains the same; only the absolute numbers can differ.** 118 | To produce results similar to those reported in our paper, download the zip file [here](http://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/CCA_emb.zip) and do the following: 119 | 1. Move `cub.all`, `cub.emb`, `cub.pc` under `data/cub/oc:3_sl:32_s:300_w:3/`; 120 | 2. Move the rest of the files, i.e. `emb_mean.pt`, `emb_proj.pt`, `images_mean.pt`, `im_proj.pt`, to `path/to/trained/model/folder/`; 121 | 3. Set the `RESET` variable in `src/report/analyse_cub.py` to `False`. 122 | 123 | 124 | ## Contact 125 | If you have any questions, feel free to create an issue or email Yuge Shi at yshi@robots.ox.ac.uk.
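**Example**: As a rough illustration of how the paired index files written by `bin/make-mnist-svhn-idx.py` (`train-ms-mnist-idx.pt`, `train-ms-svhn-idx.pt` and their `test-` counterparts) can be consumed, here is a minimal sketch of a paired dataset. This is a hypothetical helper for illustration only, not the loader the code in `src` actually uses; the class name, default paths and returned tuple are our own choices.

```python
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms


class PairedMNISTSVHN(Dataset):
    """Sketch: pairs MNIST and SVHN images of the same digit class via the saved index files."""

    def __init__(self, data_dir='../data', split='train'):
        tx = transforms.ToTensor()
        self.mnist = datasets.MNIST(data_dir, train=(split == 'train'), download=True, transform=tx)
        self.svhn = datasets.SVHN(data_dir, split=split, download=True, transform=tx)
        # index files produced by bin/make-mnist-svhn-idx.py
        self.mnist_idx = torch.load('{}/{}-ms-mnist-idx.pt'.format(data_dir, split))
        self.svhn_idx = torch.load('{}/{}-ms-svhn-idx.pt'.format(data_dir, split))

    def __len__(self):
        return len(self.mnist_idx)

    def __getitem__(self, i):
        mnist_img, label = self.mnist[int(self.mnist_idx[i])]
        svhn_img, _ = self.svhn[int(self.svhn_idx[i])]
        # by construction, both images in a pair depict the same digit class
        return mnist_img, svhn_img, label
```

A `torch.utils.data.DataLoader` over such a dataset then yields aligned MNIST/SVHN batches.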
126 | -------------------------------------------------------------------------------- /_imgs/cub.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/cub.png -------------------------------------------------------------------------------- /_imgs/iwae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/iwae.png -------------------------------------------------------------------------------- /_imgs/mnist-svhn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/mnist-svhn.png -------------------------------------------------------------------------------- /_imgs/obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/obj.png -------------------------------------------------------------------------------- /_imgs/schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/schematic.png -------------------------------------------------------------------------------- /_imgs/simple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffsid/mmvae/54398d75c144e9b6c06ef5c33a0ee9a162638b7d/_imgs/simple.png -------------------------------------------------------------------------------- /bin/make-mnist-svhn-idx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | def rand_match_on_idx(l1, idx1, l2, idx2, max_d=10000, dm=10): 5 | """ 6 | l*: sorted labels 7 | idx*: indices of sorted labels in original list 8 | """ 9 | _idx1, _idx2 = [], [] 10 | for l in l1.unique(): # assuming both have same idxs 11 | l_idx1, l_idx2 = idx1[l1 == l], idx2[l2 == l] 12 | n = min(l_idx1.size(0), l_idx2.size(0), max_d) 13 | l_idx1, l_idx2 = l_idx1[:n], l_idx2[:n] 14 | for _ in range(dm): 15 | _idx1.append(l_idx1[torch.randperm(n)]) 16 | _idx2.append(l_idx2[torch.randperm(n)]) 17 | return torch.cat(_idx1), torch.cat(_idx2) 18 | 19 | if __name__ == '__main__': 20 | max_d = 10000 # maximum number of datapoints per class 21 | dm = 30 # data multiplier: random permutations to match 22 | 23 | # get the individual datasets 24 | tx = transforms.ToTensor() 25 | train_mnist = datasets.MNIST('../data', train=True, download=True, transform=tx) 26 | test_mnist = datasets.MNIST('../data', train=False, download=True, transform=tx) 27 | train_svhn = datasets.SVHN('../data', split='train', download=True, transform=tx) 28 | test_svhn = datasets.SVHN('../data', split='test', download=True, transform=tx) 29 | # svhn labels need extra work 30 | train_svhn.labels = torch.LongTensor(train_svhn.labels.squeeze().astype(int)) % 10 31 | test_svhn.labels = torch.LongTensor(test_svhn.labels.squeeze().astype(int)) % 10 32 | 33 | mnist_l, mnist_li = train_mnist.targets.sort() 34 | svhn_l, svhn_li = train_svhn.labels.sort() 35 | idx1, idx2 = rand_match_on_idx(mnist_l, mnist_li, svhn_l, svhn_li, max_d=max_d, dm=dm) 36 | print('len train idx:', len(idx1), 
len(idx2)) 37 | torch.save(idx1, '../data/train-ms-mnist-idx.pt') 38 | torch.save(idx2, '../data/train-ms-svhn-idx.pt') 39 | 40 | mnist_l, mnist_li = test_mnist.targets.sort() 41 | svhn_l, svhn_li = test_svhn.labels.sort() 42 | idx1, idx2 = rand_match_on_idx(mnist_l, mnist_li, svhn_l, svhn_li, max_d=max_d, dm=dm) 43 | print('len test idx:', len(idx1), len(idx2)) 44 | torch.save(idx1, '../data/test-ms-mnist-idx.pt') 45 | torch.save(idx2, '../data/test-ms-svhn-idx.pt') 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python == 3.6.8 2 | gensim == 3.8.1 3 | matplotlib == 3.1.1 4 | nltk == 3.4.5 5 | numpy == 1.16.4 6 | pandas == 0.25.3 7 | scipy == 1.3.2 8 | seaborn == 0.9.0 9 | scikit-image == 0.15.0 10 | torch == 1.3.1 11 | torchnet == 0.0.4 12 | torchvision == 0.4.2 13 | umap-learn == 0.1.1 14 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import os 4 | import pickle 5 | from collections import Counter, OrderedDict 6 | from collections import defaultdict 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from nltk.tokenize import sent_tokenize, word_tokenize 12 | from torch.utils.data import Dataset 13 | from torchvision import transforms, models, datasets 14 | 15 | 16 | class OrderedCounter(Counter, OrderedDict): 17 | """Counter that remembers the order elements are first encountered.""" 18 | 19 | def __repr__(self): 20 | return '%s(%r)' % (self.__class__.__name__, OrderedDict(self)) 21 | 22 | def __reduce__(self): 23 | return self.__class__, (OrderedDict(self),) 24 | 25 | 26 | class CUBSentences(Dataset): 27 | 28 | def __init__(self, root_data_dir, split, transform=None, **kwargs): 29 | """split: 'trainval' or 'test' """ 30 | 31 | super().__init__() 32 | self.data_dir = os.path.join(root_data_dir, 'cub') 33 | self.split = split 34 | self.max_sequence_length = kwargs.get('max_sequence_length', 32) 35 | self.min_occ = kwargs.get('min_occ', 3) 36 | self.transform = transform 37 | os.makedirs(os.path.join(root_data_dir, "lang_emb"), exist_ok=True) 38 | 39 | self.gen_dir = os.path.join(self.data_dir, "oc:{}_msl:{}". 40 | format(self.min_occ, self.max_sequence_length)) 41 | 42 | if split == 'train': 43 | self.raw_data_path = os.path.join(self.data_dir, 'text_trainvalclasses.txt') 44 | elif split == 'test': 45 | self.raw_data_path = os.path.join(self.data_dir, 'text_testclasses.txt') 46 | else: 47 | raise Exception("Only train or test split is available") 48 | 49 | os.makedirs(self.gen_dir, exist_ok=True) 50 | self.data_file = 'cub.{}.s{}'.format(split, self.max_sequence_length) 51 | self.vocab_file = 'cub.vocab' 52 | 53 | if not os.path.exists(os.path.join(self.gen_dir, self.data_file)): 54 | print("Data file not found for {} split at {}. Creating new... (this may take a while)". 
55 | format(split.upper(), os.path.join(self.gen_dir, self.data_file))) 56 | self._create_data() 57 | 58 | else: 59 | self._load_data() 60 | 61 | def __len__(self): 62 | return len(self.data) 63 | 64 | def __getitem__(self, idx): 65 | sent = self.data[str(idx)]['idx'] 66 | if self.transform is not None: 67 | sent = self.transform(sent) 68 | return sent, self.data[str(idx)]['length'] 69 | 70 | @property 71 | def vocab_size(self): 72 | return len(self.w2i) 73 | 74 | @property 75 | def pad_idx(self): 76 | return self.w2i[''] 77 | 78 | @property 79 | def eos_idx(self): 80 | return self.w2i[''] 81 | 82 | @property 83 | def unk_idx(self): 84 | return self.w2i[''] 85 | 86 | def get_w2i(self): 87 | return self.w2i 88 | 89 | def get_i2w(self): 90 | return self.i2w 91 | 92 | def _load_data(self, vocab=True): 93 | with open(os.path.join(self.gen_dir, self.data_file), 'rb') as file: 94 | self.data = json.load(file) 95 | 96 | if vocab: 97 | self._load_vocab() 98 | 99 | def _load_vocab(self): 100 | if not os.path.exists(os.path.join(self.gen_dir, self.vocab_file)): 101 | self._create_vocab() 102 | with open(os.path.join(self.gen_dir, self.vocab_file), 'r') as vocab_file: 103 | vocab = json.load(vocab_file) 104 | self.w2i, self.i2w = vocab['w2i'], vocab['i2w'] 105 | 106 | def _create_data(self): 107 | if self.split == 'train' and not os.path.exists(os.path.join(self.gen_dir, self.vocab_file)): 108 | self._create_vocab() 109 | else: 110 | self._load_vocab() 111 | 112 | with open(self.raw_data_path, 'r') as file: 113 | text = file.read() 114 | sentences = sent_tokenize(text) 115 | 116 | data = defaultdict(dict) 117 | pad_count = 0 118 | 119 | for i, line in enumerate(sentences): 120 | words = word_tokenize(line) 121 | 122 | tok = words[:self.max_sequence_length - 1] 123 | tok = tok + [''] 124 | length = len(tok) 125 | if self.max_sequence_length > length: 126 | tok.extend([''] * (self.max_sequence_length - length)) 127 | pad_count += 1 128 | idx = [self.w2i.get(w, self.w2i['']) for w in tok] 129 | 130 | id = len(data) 131 | data[id]['tok'] = tok 132 | data[id]['idx'] = idx 133 | data[id]['length'] = length 134 | 135 | print("{} out of {} sentences are truncated with max sentence length {}.". 136 | format(len(sentences) - pad_count, len(sentences), self.max_sequence_length)) 137 | with io.open(os.path.join(self.gen_dir, self.data_file), 'wb') as data_file: 138 | data = json.dumps(data, ensure_ascii=False) 139 | data_file.write(data.encode('utf8', 'replace')) 140 | 141 | self._load_data(vocab=False) 142 | 143 | def _create_vocab(self): 144 | 145 | assert self.split == 'train', "Vocablurary can only be created for training file." 146 | 147 | with open(self.raw_data_path, 'r') as file: 148 | text = file.read() 149 | sentences = sent_tokenize(text) 150 | 151 | occ_register = OrderedCounter() 152 | w2i = dict() 153 | i2w = dict() 154 | 155 | special_tokens = ['', '', ''] 156 | for st in special_tokens: 157 | i2w[len(w2i)] = st 158 | w2i[st] = len(w2i) 159 | 160 | texts = [] 161 | unq_words = [] 162 | 163 | for i, line in enumerate(sentences): 164 | words = word_tokenize(line) 165 | occ_register.update(words) 166 | texts.append(words) 167 | 168 | for w, occ in occ_register.items(): 169 | if occ > self.min_occ and w not in special_tokens: 170 | i2w[len(w2i)] = w 171 | w2i[w] = len(w2i) 172 | else: 173 | unq_words.append(w) 174 | 175 | assert len(w2i) == len(i2w) 176 | 177 | print("Vocablurary of {} keys created, {} words are excluded (occurrence <= {})." 
178 | .format(len(w2i), len(unq_words), self.min_occ)) 179 | 180 | vocab = dict(w2i=w2i, i2w=i2w) 181 | with io.open(os.path.join(self.gen_dir, self.vocab_file), 'wb') as vocab_file: 182 | data = json.dumps(vocab, ensure_ascii=False) 183 | vocab_file.write(data.encode('utf8', 'replace')) 184 | 185 | with open(os.path.join(self.gen_dir, 'cub.unique'), 'wb') as unq_file: 186 | pickle.dump(np.array(unq_words), unq_file) 187 | 188 | with open(os.path.join(self.gen_dir, 'cub.all'), 'wb') as a_file: 189 | pickle.dump(occ_register, a_file) 190 | 191 | self._load_vocab() 192 | 193 | 194 | class CUBImageFt(Dataset): 195 | def __init__(self, root_data_dir, split, device): 196 | """split: 'trainval' or 'test' """ 197 | 198 | super().__init__() 199 | self.data_dir = os.path.join(root_data_dir, 'cub') 200 | self.data_file = os.path.join(self.data_dir, split) 201 | self.gen_dir = os.path.join(self.data_dir, 'resnet101_2048') 202 | self.gen_ft_file = os.path.join(self.gen_dir, '{}.ft'.format(split)) 203 | self.gen_data_file = os.path.join(self.gen_dir, '{}.data'.format(split)) 204 | self.split = split 205 | 206 | tx = transforms.Compose([ 207 | transforms.Resize(224), 208 | transforms.ToTensor() 209 | ]) 210 | self.dataset = datasets.ImageFolder(self.data_file, transform=tx) 211 | 212 | os.makedirs(self.gen_dir, exist_ok=True) 213 | if not os.path.exists(self.gen_ft_file): 214 | print("Data file not found for CUB image features at `{}`. " 215 | "Extracting resnet101 features from CUB image dataset... " 216 | "(this may take a while)".format(self.gen_ft_file)) 217 | self._create_ft_mat(device) 218 | 219 | else: 220 | self._load_ft_mat() 221 | 222 | def __len__(self): 223 | return len(self.ft_mat) 224 | 225 | def __getitem__(self, idx): 226 | return self.ft_mat[idx] 227 | 228 | def _load_ft_mat(self): 229 | self.ft_mat = torch.load(self.gen_ft_file) 230 | 231 | def _load_data(self): 232 | self.data_mat = torch.load(self.gen_data_file) 233 | 234 | def _create_ft_mat(self, device): 235 | resnet = models.resnet101(pretrained=True) 236 | modules = list(resnet.children())[:-1] 237 | self.model = nn.Sequential(*modules) 238 | self.model.eval() 239 | 240 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {} 241 | 242 | loader = torch.utils.data.DataLoader(self.dataset, batch_size=256, 243 | shuffle=False, **kwargs) 244 | with torch.no_grad(): 245 | ft_mat = torch.cat([self.model(data[0]).squeeze() for data in loader]) 246 | 247 | torch.save(ft_mat, self.gen_ft_file) 248 | del ft_mat 249 | 250 | data_mat = torch.cat([data[0].squeeze() for data in loader]) 251 | torch.save(data_mat, self.gen_data_file) 252 | 253 | self._load_ft_mat() 254 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import sys 4 | import json 5 | from collections import defaultdict 6 | from pathlib import Path 7 | from tempfile import mkdtemp 8 | 9 | import numpy as np 10 | import torch 11 | from torch import optim 12 | 13 | import models 14 | import objectives 15 | from utils import Logger, Timer, save_model, save_vars, unpack_data 16 | 17 | parser = argparse.ArgumentParser(description='Multi-Modal VAEs') 18 | parser.add_argument('--experiment', type=str, default='', metavar='E', 19 | help='experiment name') 20 | parser.add_argument('--model', type=str, default='mnist_svhn', metavar='M', 21 | choices=[s[4:] for s in dir(models) if 'VAE_' in s], 22 | 
help='model name (default: mnist_svhn)') 23 | parser.add_argument('--obj', type=str, default='elbo', metavar='O', 24 | choices=['elbo', 'iwae', 'dreg'], 25 | help='objective to use (default: elbo)') 26 | parser.add_argument('--K', type=int, default=20, metavar='K', 27 | help='number of particles to use for iwae/dreg (default: 10)') 28 | parser.add_argument('--looser', action='store_true', default=False, 29 | help='use the looser version of IWAE/DREG') 30 | parser.add_argument('--llik_scaling', type=float, default=0., 31 | help='likelihood scaling for cub images/svhn modality when running in' 32 | 'multimodal setting, set as 0 to use default value') 33 | parser.add_argument('--batch-size', type=int, default=256, metavar='N', 34 | help='batch size for data (default: 256)') 35 | parser.add_argument('--epochs', type=int, default=10, metavar='E', 36 | help='number of epochs to train (default: 10)') 37 | parser.add_argument('--latent-dim', type=int, default=20, metavar='L', 38 | help='latent dimensionality (default: 20)') 39 | parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H', 40 | help='number of hidden layers in enc and dec (default: 1)') 41 | parser.add_argument('--pre-trained', type=str, default="", 42 | help='path to pre-trained model (train from scratch if empty)') 43 | parser.add_argument('--learn-prior', action='store_true', default=False, 44 | help='learn model prior parameters') 45 | parser.add_argument('--logp', action='store_true', default=False, 46 | help='estimate tight marginal likelihood on completion') 47 | parser.add_argument('--print-freq', type=int, default=0, metavar='f', 48 | help='frequency with which to print stats (default: 0)') 49 | parser.add_argument('--no-analytics', action='store_true', default=False, 50 | help='disable plotting analytics') 51 | parser.add_argument('--no-cuda', action='store_true', default=False, 52 | help='disable CUDA use') 53 | parser.add_argument('--seed', type=int, default=1, metavar='S', 54 | help='random seed (default: 1)') 55 | 56 | # args 57 | args = parser.parse_args() 58 | 59 | # random seed 60 | # https://pytorch.org/docs/stable/notes/randomness.html 61 | torch.backends.cudnn.benchmark = True 62 | torch.manual_seed(args.seed) 63 | np.random.seed(args.seed) 64 | 65 | # load args from disk if pretrained model path is given 66 | pretrained_path = "" 67 | if args.pre_trained: 68 | pretrained_path = args.pre_trained 69 | args = torch.load(args.pre_trained + '/args.rar') 70 | 71 | args.cuda = not args.no_cuda and torch.cuda.is_available() 72 | device = torch.device("cuda" if args.cuda else "cpu") 73 | 74 | # load model 75 | modelC = getattr(models, 'VAE_{}'.format(args.model)) 76 | model = modelC(args).to(device) 77 | 78 | if pretrained_path: 79 | print('Loading model {} from {}'.format(model.modelName, pretrained_path)) 80 | model.load_state_dict(torch.load(pretrained_path + '/model.rar')) 81 | model._pz_params = model._pz_params 82 | 83 | if not args.experiment: 84 | args.experiment = model.modelName 85 | 86 | # set up run path 87 | runId = datetime.datetime.now().isoformat() 88 | experiment_dir = Path('../experiments/' + args.experiment) 89 | experiment_dir.mkdir(parents=True, exist_ok=True) 90 | runPath = mkdtemp(prefix=runId, dir=str(experiment_dir)) 91 | sys.stdout = Logger('{}/run.log'.format(runPath)) 92 | print('Expt:', runPath) 93 | print('RunID:', runId) 94 | 95 | # save args to run 96 | with open('{}/args.json'.format(runPath), 'w') as fp: 97 | json.dump(args.__dict__, fp) 98 | # -- also save object 
because we want to recover these for other things 99 | torch.save(args, '{}/args.rar'.format(runPath)) 100 | 101 | # preparation for training 102 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 103 | lr=1e-3, amsgrad=True) 104 | train_loader, test_loader = model.getDataLoaders(args.batch_size, device=device) 105 | objective = getattr(objectives, 106 | ('m_' if hasattr(model, 'vaes') else '') 107 | + args.obj 108 | + ('_looser' if (args.looser and args.obj != 'elbo') else '')) 109 | t_objective = getattr(objectives, ('m_' if hasattr(model, 'vaes') else '') + 'iwae') 110 | 111 | 112 | def train(epoch, agg): 113 | model.train() 114 | b_loss = 0 115 | for i, dataT in enumerate(train_loader): 116 | data = unpack_data(dataT, device=device) 117 | optimizer.zero_grad() 118 | loss = -objective(model, data, K=args.K) 119 | loss.backward() 120 | optimizer.step() 121 | b_loss += loss.item() 122 | if args.print_freq > 0 and i % args.print_freq == 0: 123 | print("iteration {:04d}: loss: {:6.3f}".format(i, loss.item() / args.batch_size)) 124 | agg['train_loss'].append(b_loss / len(train_loader.dataset)) 125 | print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, agg['train_loss'][-1])) 126 | 127 | 128 | def test(epoch, agg): 129 | model.eval() 130 | b_loss = 0 131 | with torch.no_grad(): 132 | for i, dataT in enumerate(test_loader): 133 | data = unpack_data(dataT, device=device) 134 | loss = -t_objective(model, data, K=args.K) 135 | b_loss += loss.item() 136 | if i == 0: 137 | model.reconstruct(data, runPath, epoch) 138 | if not args.no_analytics: 139 | model.analyse(data, runPath, epoch) 140 | agg['test_loss'].append(b_loss / len(test_loader.dataset)) 141 | print('====> Test loss: {:.4f}'.format(agg['test_loss'][-1])) 142 | 143 | 144 | def estimate_log_marginal(K): 145 | """Compute an IWAE estimate of the log-marginal likelihood of test data.""" 146 | model.eval() 147 | marginal_loglik = 0 148 | with torch.no_grad(): 149 | for dataT in test_loader: 150 | data = unpack_data(dataT, device=device) 151 | marginal_loglik += -t_objective(model, data, K).item() 152 | 153 | marginal_loglik /= len(test_loader.dataset) 154 | print('Marginal Log Likelihood (IWAE, K = {}): {:.4f}'.format(K, marginal_loglik)) 155 | 156 | 157 | if __name__ == '__main__': 158 | with Timer('MM-VAE') as t: 159 | agg = defaultdict(list) 160 | for epoch in range(1, args.epochs + 1): 161 | train(epoch, agg) 162 | test(epoch, agg) 163 | save_model(model, runPath + '/model.rar') 164 | save_vars(agg, runPath + '/losses.rar') 165 | model.generate(runPath, epoch) 166 | if args.logp: # compute as tight a marginal likelihood as possible 167 | estimate_log_marginal(5000) 168 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mmvae_cub_images_sentences import CUB_Image_Sentence as VAE_cubIS 2 | from .mmvae_cub_images_sentences_ft import CUB_Image_Sentence_ft as VAE_cubISft 3 | from .mmvae_mnist_svhn import MNIST_SVHN as VAE_mnist_svhn 4 | from .vae_cub_image import CUB_Image as VAE_cubI 5 | from .vae_cub_image_ft import CUB_Image_ft as VAE_cubIft 6 | from .vae_cub_sent import CUB_Sentence as VAE_cubS 7 | from .vae_mnist import MNIST as VAE_mnist 8 | from .vae_svhn import SVHN as VAE_svhn 9 | 10 | __all__ = [VAE_mnist_svhn, VAE_mnist, VAE_svhn, VAE_cubIS, VAE_cubS, 11 | VAE_cubI, VAE_cubISft, VAE_cubIft] 12 | 
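Note: the VAE_* aliases exported above are exactly what src/main.py enumerates for its --model choices ([s[4:] for s in dir(models) if 'VAE_' in s]) and then resolves with getattr(models, 'VAE_{}'.format(args.model)). Purely as an illustration of that lookup, here is a minimal sketch; the argparse.Namespace below is a hypothetical stand-in listing only the fields the MNIST/SVHN constructors appear to read, whereas main.py passes its full parsed argument set:

import argparse
import models  # assumes the working directory is src/, as when running main.py

# Hypothetical minimal stand-in for the parsed command-line arguments.
args = argparse.Namespace(latent_dim=20, num_hidden_layers=1,
                          learn_prior=False, llik_scaling=0.)

modelC = getattr(models, 'VAE_mnist_svhn')   # the same lookup main.py performs
model = modelC(args)                         # MNIST_SVHN wrapping a MNIST VAE and an SVHN VAE
print(model.modelName)                       # 'mnist-svhn'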
-------------------------------------------------------------------------------- /src/models/mmvae.py: -------------------------------------------------------------------------------- 1 | # Base MMVAE class definition 2 | 3 | from itertools import combinations 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from utils import get_mean, kl_divergence 9 | from vis import embed_umap, tensors_to_df 10 | 11 | 12 | class MMVAE(nn.Module): 13 | def __init__(self, prior_dist, params, *vaes): 14 | super(MMVAE, self).__init__() 15 | self.pz = prior_dist 16 | self.vaes = nn.ModuleList([vae(params) for vae in vaes]) 17 | self.modelName = None # filled-in per sub-class 18 | self.params = params 19 | self._pz_params = None # defined in subclass 20 | 21 | @property 22 | def pz_params(self): 23 | return self._pz_params 24 | 25 | @staticmethod 26 | def getDataLoaders(batch_size, shuffle=True, device="cuda"): 27 | # handle merging individual datasets appropriately in sub-class 28 | raise NotImplementedError 29 | 30 | def forward(self, x, K=1): 31 | qz_xs, zss = [], [] 32 | # initialise cross-modal matrix 33 | px_zs = [[None for _ in range(len(self.vaes))] for _ in range(len(self.vaes))] 34 | for m, vae in enumerate(self.vaes): 35 | qz_x, px_z, zs = vae(x[m], K=K) 36 | qz_xs.append(qz_x) 37 | zss.append(zs) 38 | px_zs[m][m] = px_z # fill-in diagonal 39 | for e, zs in enumerate(zss): 40 | for d, vae in enumerate(self.vaes): 41 | if e != d: # fill-in off-diagonal 42 | px_zs[e][d] = vae.px_z(*vae.dec(zs)) 43 | return qz_xs, px_zs, zss 44 | 45 | def generate(self, N): 46 | self.eval() 47 | with torch.no_grad(): 48 | data = [] 49 | pz = self.pz(*self.pz_params) 50 | latents = pz.rsample(torch.Size([N])) 51 | for d, vae in enumerate(self.vaes): 52 | px_z = vae.px_z(*vae.dec(latents)) 53 | data.append(px_z.mean.view(-1, *px_z.mean.size()[2:])) 54 | return data # list of generations---one for each modality 55 | 56 | def reconstruct(self, data): 57 | self.eval() 58 | with torch.no_grad(): 59 | _, px_zs, _ = self.forward(data) 60 | # cross-modal matrix of reconstructions 61 | recons = [[get_mean(px_z) for px_z in r] for r in px_zs] 62 | return recons 63 | 64 | def analyse(self, data, K): 65 | self.eval() 66 | with torch.no_grad(): 67 | qz_xs, _, zss = self.forward(data, K=K) 68 | pz = self.pz(*self.pz_params) 69 | zss = [pz.sample(torch.Size([K, data[0].size(0)])).view(-1, pz.batch_shape[-1]), 70 | *[zs.view(-1, zs.size(-1)) for zs in zss]] 71 | zsl = [torch.zeros(zs.size(0)).fill_(i) for i, zs in enumerate(zss)] 72 | kls_df = tensors_to_df( 73 | [*[kl_divergence(qz_x, pz).cpu().numpy() for qz_x in qz_xs], 74 | *[0.5 * (kl_divergence(p, q) + kl_divergence(q, p)).cpu().numpy() 75 | for p, q in combinations(qz_xs, 2)]], 76 | head='KL', 77 | keys=[*[r'KL$(q(z|x_{})\,||\,p(z))$'.format(i) for i in range(len(qz_xs))], 78 | *[r'J$(q(z|x_{})\,||\,q(z|x_{}))$'.format(i, j) 79 | for i, j in combinations(range(len(qz_xs)), 2)]], 80 | ax_names=['Dimensions', r'KL$(q\,||\,p)$'] 81 | ) 82 | return embed_umap(torch.cat(zss, 0).cpu().numpy()), \ 83 | torch.cat(zsl, 0).cpu().numpy(), \ 84 | kls_df 85 | -------------------------------------------------------------------------------- /src/models/mmvae_cub_images_sentences.py: -------------------------------------------------------------------------------- 1 | # cub multi-modal model specification 2 | import matplotlib.pyplot as plt 3 | import torch.distributions as dist 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.utils.data 7 | from numpy import 
sqrt, prod 8 | from torch.utils.data import DataLoader 9 | from torchnet.dataset import TensorDataset, ResampleDataset 10 | from torchvision.utils import save_image, make_grid 11 | 12 | from utils import Constants 13 | from vis import plot_embeddings, plot_kls_df 14 | from .mmvae import MMVAE 15 | from .vae_cub_image import CUB_Image 16 | from .vae_cub_sent import CUB_Sentence 17 | 18 | # Constants 19 | maxSentLen = 32 20 | minOccur = 3 21 | 22 | 23 | # This is required because there are 10 captions per image. 24 | # Allows easier reuse of the same image for the corresponding set of captions. 25 | def resampler(dataset, idx): 26 | return idx // 10 27 | 28 | 29 | class CUB_Image_Sentence(MMVAE): 30 | 31 | def __init__(self, params): 32 | super(CUB_Image_Sentence, self).__init__(dist.Laplace, params, CUB_Image, CUB_Sentence) 33 | grad = {'requires_grad': params.learn_prior} 34 | self._pz_params = nn.ParameterList([ 35 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu 36 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar 37 | ]) 38 | self.vaes[0].llik_scaling = self.vaes[1].maxSentLen / prod(self.vaes[0].dataSize) \ 39 | if params.llik_scaling == 0 else params.llik_scaling 40 | 41 | for vae in self.vaes: 42 | vae._pz_params = self._pz_params 43 | self.modelName = 'cubIS' 44 | 45 | self.i2w = self.vaes[1].load_vocab() 46 | 47 | @property 48 | def pz_params(self): 49 | return self._pz_params[0], \ 50 | F.softmax(self._pz_params[1], dim=1) * self._pz_params[1].size(1) + Constants.eta 51 | 52 | def getDataLoaders(self, batch_size, shuffle=True, device='cuda'): 53 | # load base datasets 54 | t1, s1 = self.vaes[0].getDataLoaders(batch_size, shuffle, device) 55 | t2, s2 = self.vaes[1].getDataLoaders(batch_size, shuffle, device) 56 | 57 | kwargs = {'num_workers': 2, 'pin_memory': True} if device == 'cuda' else {} 58 | train_loader = DataLoader(TensorDataset([ 59 | ResampleDataset(t1.dataset, resampler, size=len(t1.dataset) * 10), 60 | t2.dataset]), batch_size=batch_size, shuffle=shuffle, **kwargs) 61 | test_loader = DataLoader(TensorDataset([ 62 | ResampleDataset(s1.dataset, resampler, size=len(s1.dataset) * 10), 63 | s2.dataset]), batch_size=batch_size, shuffle=shuffle, **kwargs) 64 | return train_loader, test_loader 65 | 66 | def generate(self, runPath, epoch): 67 | N = 8 68 | samples = super(CUB_Image_Sentence, self).generate(N) 69 | images, captions = [sample.data.cpu() for sample in samples] 70 | captions = self._sent_preprocess(captions) 71 | fig = plt.figure(figsize=(8, 6)) 72 | for i, (image, caption) in enumerate(zip(images, captions)): 73 | fig = self._imshow(image, caption, i, fig, N) 74 | 75 | plt.savefig('{}/gen_samples_{:03d}.png'.format(runPath, epoch)) 76 | plt.close() 77 | 78 | def reconstruct(self, raw_data, runPath, epoch): 79 | N = 8 80 | recons_mat = super(CUB_Image_Sentence, self).reconstruct([d[:N] for d in raw_data]) 81 | fns = [lambda images: images.data.cpu(), lambda sentences: self._sent_preprocess(sentences)] 82 | for r, recons_list in enumerate(recons_mat): 83 | for o, recon in enumerate(recons_list): 84 | data = fns[r](raw_data[r][:N]) 85 | recon = fns[o](recon.squeeze()) 86 | if r != o: 87 | fig = plt.figure(figsize=(8, 6)) 88 | for i, (_data, _recon) in enumerate(zip(data, recon)): 89 | image, caption = (_data, _recon) if r == 0 else (_recon, _data) 90 | fig = self._imshow(image, caption, i, fig, N) 91 | plt.savefig('{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch)) 92 | plt.close() 93 | else: 94 | if r == 0: 95 | comp 
= torch.cat([data, recon]) 96 | save_image(comp, '{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch)) 97 | else: 98 | with open('{}/recon_{}x{}_{:03d}.txt'.format(runPath, r, o, epoch), "w+") as txt_file: 99 | for r_sent, d_sent in zip(recon, data): 100 | txt_file.write('[DATA] ==> {}\n'.format(' '.join(self.i2w[str(i)] for i in d_sent))) 101 | txt_file.write('[RECON] ==> {}\n\n'.format(' '.join(self.i2w[str(i)] for i in r_sent))) 102 | 103 | def analyse(self, data, runPath, epoch): 104 | zemb, zsl, kls_df = super(CUB_Image_Sentence, self).analyse(data, K=10) 105 | labels = ['Prior', *[vae.modelName.lower() for vae in self.vaes]] 106 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch)) 107 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch)) 108 | 109 | def _sent_preprocess(self, sentences): 110 | """make sure raw data is always passed as dim=2 to avoid argmax. 111 | last dimension must always be word embedding.""" 112 | if len(sentences.shape) > 2: 113 | sentences = sentences.argmax(-1).squeeze() 114 | return [self.vaes[1].fn_trun(s) for s in self.vaes[1].fn_2i(sentences)] 115 | 116 | def _imshow(self, image, caption, i, fig, N): 117 | """Imshow for Tensor.""" 118 | ax = fig.add_subplot(N // 2, 4, i * 2 + 1) 119 | ax.axis('off') 120 | image = image.numpy().transpose((1, 2, 0)) # 121 | plt.imshow(image) 122 | ax = fig.add_subplot(N // 2, 4, i * 2 + 2) 123 | pos = ax.get_position() 124 | ax.axis('off') 125 | plt.text( 126 | x=0.5 * (pos.x0 + pos.x1), 127 | y=0.5 * (pos.y0 + pos.y1), 128 | ha='left', 129 | s='{}'.format( 130 | ' '.join(self.i2w[str(i)] + '\n' if (n + 1) % 5 == 0 131 | else self.i2w[str(i)] for n, i in enumerate(caption))), 132 | fontsize=6, 133 | verticalalignment='center', 134 | horizontalalignment='center' 135 | ) 136 | return fig 137 | -------------------------------------------------------------------------------- /src/models/mmvae_cub_images_sentences_ft.py: -------------------------------------------------------------------------------- 1 | # cub multi-modal model specification 2 | import matplotlib.pyplot as plt 3 | import torch.distributions as dist 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.utils.data 7 | from numpy import sqrt, prod 8 | from torch.utils.data import DataLoader 9 | from torchnet.dataset import TensorDataset, ResampleDataset 10 | from torchvision.utils import save_image, make_grid 11 | 12 | from utils import Constants 13 | from vis import plot_embeddings, plot_kls_df 14 | from .mmvae import MMVAE 15 | from .vae_cub_image_ft import CUB_Image_ft 16 | from .vae_cub_sent_ft import CUB_Sentence_ft 17 | 18 | # Constants 19 | maxSentLen = 32 20 | minOccur = 3 21 | 22 | 23 | # This is required because there are 10 captions per image. 24 | # Allows easier reuse of the same image for the corresponding set of captions. 
25 | def resampler(dataset, idx): 26 | return idx // 10 27 | 28 | 29 | class CUB_Image_Sentence_ft(MMVAE): 30 | 31 | def __init__(self, params): 32 | super(CUB_Image_Sentence_ft, self).__init__(dist.Normal, params, CUB_Image_ft, CUB_Sentence_ft) 33 | grad = {'requires_grad': params.learn_prior} 34 | self._pz_params = nn.ParameterList([ 35 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu 36 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar 37 | ]) 38 | self.vaes[0].llik_scaling = self.vaes[1].maxSentLen / prod(self.vaes[0].dataSize) \ 39 | if params.llik_scaling == 0 else params.llik_scaling 40 | 41 | for vae in self.vaes: 42 | vae._pz_params = self._pz_params 43 | self.modelName = 'cubISft' 44 | 45 | self.i2w = self.vaes[1].load_vocab() 46 | 47 | @property 48 | def pz_params(self): 49 | return self._pz_params[0], \ 50 | F.softplus(self._pz_params[1]) + Constants.eta 51 | 52 | def getDataLoaders(self, batch_size, shuffle=True, device='cuda'): 53 | # load base datasets 54 | t1, s1 = self.vaes[0].getDataLoaders(batch_size, shuffle, device) 55 | t2, s2 = self.vaes[1].getDataLoaders(batch_size, shuffle, device) 56 | 57 | kwargs = {'num_workers': 2, 'pin_memory': True} if device == 'cuda' else {} 58 | train_loader = DataLoader(TensorDataset([ 59 | ResampleDataset(t1.dataset, resampler, size=len(t1.dataset) * 10), 60 | t2.dataset]), batch_size=batch_size, shuffle=shuffle, **kwargs) 61 | test_loader = DataLoader(TensorDataset([ 62 | ResampleDataset(s1.dataset, resampler, size=len(s1.dataset) * 10), 63 | s2.dataset]), batch_size=batch_size, shuffle=shuffle, **kwargs) 64 | return train_loader, test_loader 65 | 66 | def generate(self, runPath, epoch): 67 | N = 8 68 | samples = super(CUB_Image_Sentence_ft, self).generate(N) 69 | samples[0] = self.vaes[0].unproject(samples[0], search_split='train') 70 | images, captions = [sample.data.cpu() for sample in samples] 71 | captions = self._sent_preprocess(captions) 72 | fig = plt.figure(figsize=(8, 6)) 73 | for i, (image, caption) in enumerate(zip(images, captions)): 74 | fig = self._imshow(image, caption, i, fig, N) 75 | 76 | plt.savefig('{}/gen_samples_{:03d}.png'.format(runPath, epoch)) 77 | plt.close() 78 | 79 | def reconstruct(self, raw_data, runPath, epoch): 80 | N = 8 81 | recons_mat = super(CUB_Image_Sentence_ft, self).reconstruct([d[:N] for d in raw_data]) 82 | fns = [lambda images: images.data.cpu(), lambda sentences: self._sent_preprocess(sentences)] 83 | for r, recons_list in enumerate(recons_mat): 84 | for o, recon in enumerate(recons_list): 85 | data = fns[r](raw_data[r][:N]) 86 | recon = fns[o](recon.squeeze()) 87 | if r != o: 88 | fig = plt.figure(figsize=(8, 6)) 89 | for i, (_data, _recon) in enumerate(zip(data, recon)): 90 | image, caption = (_data, _recon) if r == 0 else (_recon, _data) 91 | search_split = 'test' if r == 0 else 'train' 92 | image = self.vaes[0].unproject(image.unsqueeze(0), search_split=search_split) 93 | fig = self._imshow(image, caption, i, fig, N) 94 | plt.savefig('{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch)) 95 | plt.close() 96 | else: 97 | if r == 0: 98 | data_ = self.vaes[0].unproject(data, search_split='test') 99 | recon_ = self.vaes[0].unproject(recon, search_split='train') 100 | comp = torch.cat([data_, recon_]) 101 | save_image(comp, '{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch)) 102 | else: 103 | with open('{}/recon_{}x{}_{:03d}.txt'.format(runPath, r, o, epoch), "w+") as txt_file: 104 | for r_sent, d_sent in zip(recon, data): 105 | 
txt_file.write('[DATA] ==> {}\n'.format(' '.join(self.i2w[str(i)] for i in d_sent))) 106 | txt_file.write('[RECON] ==> {}\n\n'.format(' '.join(self.i2w[str(i)] for i in r_sent))) 107 | 108 | def analyse(self, data, runPath, epoch): 109 | zemb, zsl, kls_df = super(CUB_Image_Sentence_ft, self).analyse(data, K=10) 110 | labels = ['Prior', *[vae.modelName.lower() for vae in self.vaes]] 111 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch)) 112 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch)) 113 | 114 | def _sent_preprocess(self, sentences): 115 | """make sure raw data is always passed as dim=2 to avoid argmax. 116 | last dimension must always be word embedding.""" 117 | if len(sentences.shape) > 2: 118 | sentences = sentences.argmax(-1).squeeze() 119 | return [self.vaes[1].fn_trun(s) for s in self.vaes[1].fn_2i(sentences)] 120 | 121 | def _imshow(self, image, caption, i, fig, N): 122 | """Imshow for Tensor.""" 123 | ax = fig.add_subplot(N // 2, 4, i * 2 + 1) 124 | ax.axis('off') 125 | image = image.numpy().transpose((1, 2, 0)) # 126 | plt.imshow(image) 127 | ax = fig.add_subplot(N // 2, 4, i * 2 + 2) 128 | pos = ax.get_position() 129 | ax.axis('off') 130 | plt.text( 131 | x=0.5 * (pos.x0 + pos.x1), 132 | y=0.5 * (pos.y0 + pos.y1), 133 | ha='left', 134 | s='{}'.format( 135 | ' '.join(self.i2w[str(i)] + '\n' if (n + 1) % 5 == 0 136 | else self.i2w[str(i)] for n, i in enumerate(caption))), 137 | fontsize=6, 138 | verticalalignment='center', 139 | horizontalalignment='center' 140 | ) 141 | return fig 142 | -------------------------------------------------------------------------------- /src/models/mmvae_mnist_svhn.py: -------------------------------------------------------------------------------- 1 | # MNIST-SVHN multi-modal model specification 2 | import os 3 | 4 | import torch 5 | import torch.distributions as dist 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from numpy import sqrt, prod 9 | from torch.utils.data import DataLoader 10 | from torchnet.dataset import TensorDataset, ResampleDataset 11 | from torchvision.utils import save_image, make_grid 12 | 13 | from vis import plot_embeddings, plot_kls_df 14 | from .mmvae import MMVAE 15 | from .vae_mnist import MNIST 16 | from .vae_svhn import SVHN 17 | 18 | 19 | class MNIST_SVHN(MMVAE): 20 | def __init__(self, params): 21 | super(MNIST_SVHN, self).__init__(dist.Laplace, params, MNIST, SVHN) 22 | grad = {'requires_grad': params.learn_prior} 23 | self._pz_params = nn.ParameterList([ 24 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu 25 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar 26 | ]) 27 | self.vaes[0].llik_scaling = prod(self.vaes[1].dataSize) / prod(self.vaes[0].dataSize) \ 28 | if params.llik_scaling == 0 else params.llik_scaling 29 | self.modelName = 'mnist-svhn' 30 | 31 | @property 32 | def pz_params(self): 33 | return self._pz_params[0], F.softmax(self._pz_params[1], dim=1) * self._pz_params[1].size(-1) 34 | 35 | def getDataLoaders(self, batch_size, shuffle=True, device='cuda'): 36 | if not (os.path.exists('../data/train-ms-mnist-idx.pt') 37 | and os.path.exists('../data/train-ms-svhn-idx.pt') 38 | and os.path.exists('../data/test-ms-mnist-idx.pt') 39 | and os.path.exists('../data/test-ms-svhn-idx.pt')): 40 | raise RuntimeError('Generate transformed indices with the script in bin') 41 | # get transformed indices 42 | t_mnist = torch.load('../data/train-ms-mnist-idx.pt') 43 | t_svhn = 
torch.load('../data/train-ms-svhn-idx.pt') 44 | s_mnist = torch.load('../data/test-ms-mnist-idx.pt') 45 | s_svhn = torch.load('../data/test-ms-svhn-idx.pt') 46 | 47 | # load base datasets 48 | t1, s1 = self.vaes[0].getDataLoaders(batch_size, shuffle, device) 49 | t2, s2 = self.vaes[1].getDataLoaders(batch_size, shuffle, device) 50 | 51 | train_mnist_svhn = TensorDataset([ 52 | ResampleDataset(t1.dataset, lambda d, i: t_mnist[i], size=len(t_mnist)), 53 | ResampleDataset(t2.dataset, lambda d, i: t_svhn[i], size=len(t_svhn)) 54 | ]) 55 | test_mnist_svhn = TensorDataset([ 56 | ResampleDataset(s1.dataset, lambda d, i: s_mnist[i], size=len(s_mnist)), 57 | ResampleDataset(s2.dataset, lambda d, i: s_svhn[i], size=len(s_svhn)) 58 | ]) 59 | 60 | kwargs = {'num_workers': 2, 'pin_memory': True} if device == 'cuda' else {} 61 | train = DataLoader(train_mnist_svhn, batch_size=batch_size, shuffle=shuffle, **kwargs) 62 | test = DataLoader(test_mnist_svhn, batch_size=batch_size, shuffle=shuffle, **kwargs) 63 | return train, test 64 | 65 | def generate(self, runPath, epoch): 66 | N = 64 67 | samples_list = super(MNIST_SVHN, self).generate(N) 68 | for i, samples in enumerate(samples_list): 69 | samples = samples.data.cpu() 70 | # wrangle things so they come out tiled 71 | samples = samples.view(N, *samples.size()[1:]) 72 | save_image(samples, 73 | '{}/gen_samples_{}_{:03d}.png'.format(runPath, i, epoch), 74 | nrow=int(sqrt(N))) 75 | 76 | def reconstruct(self, data, runPath, epoch): 77 | recons_mat = super(MNIST_SVHN, self).reconstruct([d[:8] for d in data]) 78 | for r, recons_list in enumerate(recons_mat): 79 | for o, recon in enumerate(recons_list): 80 | _data = data[r][:8].cpu() 81 | recon = recon.squeeze(0).cpu() 82 | # resize mnist to 32 and colour. 0 => mnist, 1 => svhn 83 | _data = _data if r == 1 else resize_img(_data, self.vaes[1].dataSize) 84 | recon = recon if o == 1 else resize_img(recon, self.vaes[1].dataSize) 85 | comp = torch.cat([_data, recon]) 86 | save_image(comp, '{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch)) 87 | 88 | def analyse(self, data, runPath, epoch): 89 | zemb, zsl, kls_df = super(MNIST_SVHN, self).analyse(data, K=10) 90 | labels = ['Prior', *[vae.modelName.lower() for vae in self.vaes]] 91 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch)) 92 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch)) 93 | 94 | 95 | def resize_img(img, refsize): 96 | return F.pad(img, (2, 2, 2, 2)).expand(img.size(0), *refsize) 97 | -------------------------------------------------------------------------------- /src/models/vae.py: -------------------------------------------------------------------------------- 1 | # Base VAE class definition 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from utils import get_mean, kl_divergence 7 | from vis import embed_umap, tensors_to_df 8 | 9 | 10 | class VAE(nn.Module): 11 | def __init__(self, prior_dist, likelihood_dist, post_dist, enc, dec, params): 12 | super(VAE, self).__init__() 13 | self.pz = prior_dist 14 | self.px_z = likelihood_dist 15 | self.qz_x = post_dist 16 | self.enc = enc 17 | self.dec = dec 18 | self.modelName = None 19 | self.params = params 20 | self._pz_params = None # defined in subclass 21 | self._qz_x_params = None # populated in `forward` 22 | self.llik_scaling = 1.0 23 | 24 | @property 25 | def pz_params(self): 26 | return self._pz_params 27 | 28 | @property 29 | def qz_x_params(self): 30 | if self._qz_x_params is None: 31 | raise NameError("qz_x params not 
initalised yet!") 32 | return self._qz_x_params 33 | 34 | @staticmethod 35 | def getDataLoaders(batch_size, shuffle=True, device="cuda"): 36 | # handle merging individual datasets appropriately in sub-class 37 | raise NotImplementedError 38 | 39 | def forward(self, x, K=1): 40 | self._qz_x_params = self.enc(x) 41 | qz_x = self.qz_x(*self._qz_x_params) 42 | zs = qz_x.rsample(torch.Size([K])) 43 | px_z = self.px_z(*self.dec(zs)) 44 | return qz_x, px_z, zs 45 | 46 | def generate(self, N, K): 47 | self.eval() 48 | with torch.no_grad(): 49 | pz = self.pz(*self.pz_params) 50 | latents = pz.rsample(torch.Size([N])) 51 | px_z = self.px_z(*self.dec(latents)) 52 | data = px_z.sample(torch.Size([K])) 53 | return data.view(-1, *data.size()[3:]) 54 | 55 | def reconstruct(self, data): 56 | self.eval() 57 | with torch.no_grad(): 58 | qz_x = self.qz_x(*self.enc(data)) 59 | latents = qz_x.rsample() # no dim expansion 60 | px_z = self.px_z(*self.dec(latents)) 61 | recon = get_mean(px_z) 62 | return recon 63 | 64 | def analyse(self, data, K): 65 | self.eval() 66 | with torch.no_grad(): 67 | qz_x, _, zs = self.forward(data, K=K) 68 | pz = self.pz(*self.pz_params) 69 | zss = [pz.sample(torch.Size([K, data.size(0)])).view(-1, pz.batch_shape[-1]), 70 | zs.view(-1, zs.size(-1))] 71 | zsl = [torch.zeros(zs.size(0)).fill_(i) for i, zs in enumerate(zss)] 72 | kls_df = tensors_to_df( 73 | [kl_divergence(qz_x, pz).cpu().numpy()], 74 | head='KL', 75 | keys=[r'KL$(q(z|x)\,||\,p(z))$'], 76 | ax_names=['Dimensions', r'KL$(q\,||\,p)$'] 77 | ) 78 | return embed_umap(torch.cat(zss, 0).cpu().numpy()), \ 79 | torch.cat(zsl, 0).cpu().numpy(), \ 80 | kls_df 81 | -------------------------------------------------------------------------------- /src/models/vae_cub_image.py: -------------------------------------------------------------------------------- 1 | # CUB Image model specification 2 | 3 | import torch 4 | import torch.distributions as dist 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | from numpy import sqrt 9 | from torchvision import datasets, transforms 10 | from torchvision.utils import make_grid, save_image 11 | 12 | from utils import Constants 13 | from vis import plot_embeddings, plot_kls_df 14 | from .vae import VAE 15 | 16 | # Constants 17 | imgChans = 3 18 | fBase = 64 19 | 20 | 21 | # Classes 22 | class Enc(nn.Module): 23 | """ Generate latent parameters for CUB image data. """ 24 | 25 | def __init__(self, latentDim): 26 | super(Enc, self).__init__() 27 | modules = [ 28 | # input size: 3 x 128 x 128 29 | nn.Conv2d(imgChans, fBase, 4, 2, 1, bias=True), 30 | nn.ReLU(True), 31 | # input size: 1 x 64 x 64 32 | nn.Conv2d(fBase, fBase * 2, 4, 2, 1, bias=True), 33 | nn.ReLU(True), 34 | # size: (fBase * 2) x 32 x 32 35 | nn.Conv2d(fBase * 2, fBase * 4, 4, 2, 1, bias=True), 36 | nn.ReLU(True), 37 | # size: (fBase * 4) x 16 x 16 38 | nn.Conv2d(fBase * 4, fBase * 8, 4, 2, 1, bias=True), 39 | nn.ReLU(True)] 40 | # size: (fBase * 8) x 4 x 4 41 | 42 | self.enc = nn.Sequential(*modules) 43 | self.c1 = nn.Conv2d(fBase * 8, latentDim, 4, 1, 0, bias=True) 44 | self.c2 = nn.Conv2d(fBase * 8, latentDim, 4, 1, 0, bias=True) 45 | # c1, c2 size: latentDim x 1 x 1 46 | 47 | def forward(self, x): 48 | e = self.enc(x) 49 | return self.c1(e).squeeze(), F.softplus(self.c2(e)).squeeze() + Constants.eta 50 | 51 | 52 | class Dec(nn.Module): 53 | """ Generate an image given a sample from the latent space. 
""" 54 | 55 | def __init__(self, latentDim): 56 | super(Dec, self).__init__() 57 | modules = [nn.ConvTranspose2d(latentDim, fBase * 8, 4, 1, 0, bias=True), 58 | nn.ReLU(True), ] 59 | 60 | modules.extend([ 61 | nn.ConvTranspose2d(fBase * 8, fBase * 4, 4, 2, 1, bias=True), 62 | nn.ReLU(True), 63 | # size: (fBase * 4) x 16 x 16 64 | nn.ConvTranspose2d(fBase * 4, fBase * 2, 4, 2, 1, bias=True), 65 | nn.ReLU(True), 66 | # size: (fBase * 2) x 32 x 32 67 | nn.ConvTranspose2d(fBase * 2, fBase, 4, 2, 1, bias=True), 68 | nn.ReLU(True), 69 | # size: (fBase) x 64 x 64 70 | nn.ConvTranspose2d(fBase, imgChans, 4, 2, 1, bias=True), 71 | nn.Sigmoid() 72 | # Output size: 3 x 128 x 128 73 | ]) 74 | self.dec = nn.Sequential(*modules) 75 | 76 | def forward(self, z): 77 | z = z.unsqueeze(-1).unsqueeze(-1) # fit deconv layers 78 | out = self.dec(z.view(-1, *z.size()[-3:])) 79 | out = out.view(*z.size()[:-3], *out.size()[1:]) 80 | return out, torch.tensor(0.01).to(z.device) 81 | 82 | 83 | class CUB_Image(VAE): 84 | """ Derive a specific sub-class of a VAE for a CNN sentence model. """ 85 | 86 | def __init__(self, params): 87 | super(CUB_Image, self).__init__( 88 | dist.Laplace, # prior 89 | dist.Laplace, # likelihood 90 | dist.Laplace, # posterior 91 | Enc(params.latent_dim), 92 | Dec(params.latent_dim), 93 | params 94 | ) 95 | grad = {'requires_grad': params.learn_prior} 96 | self._pz_params = nn.ParameterList([ 97 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu 98 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar 99 | ]) 100 | self.modelName = 'cubI' 101 | self.dataSize = torch.Size([3, 64, 64]) 102 | self.llik_scaling = 1. 103 | 104 | @property 105 | def pz_params(self): 106 | return self._pz_params[0], F.softplus(self._pz_params[1]) + Constants.eta 107 | 108 | # remember that when combining with captions, this should be x10 109 | def getDataLoaders(self, batch_size, shuffle=True, device="cuda"): 110 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {} 111 | tx = transforms.Compose([transforms.Resize([64, 64]), transforms.ToTensor()]) 112 | train_loader = torch.utils.data.DataLoader( 113 | datasets.ImageFolder('../data/cub/train', transform=tx), 114 | batch_size=batch_size, shuffle=shuffle, **kwargs) 115 | test_loader = torch.utils.data.DataLoader( 116 | datasets.ImageFolder('../data/cub/test', transform=tx), 117 | batch_size=batch_size, shuffle=shuffle, **kwargs) 118 | return train_loader, test_loader 119 | 120 | def generate(self, runPath, epoch): 121 | N, K = 64, 9 122 | samples = super(CUB_Image, self).generate(N, K).data.cpu() 123 | # wrangle things so they come out tiled 124 | samples = samples.view(K, N, *samples.size()[1:]).transpose(0, 1) 125 | s = [make_grid(t, nrow=int(sqrt(K)), padding=0) for t in samples.data.cpu()] 126 | save_image(torch.stack(s), 127 | '{}/gen_samples_{:03d}.png'.format(runPath, epoch), 128 | nrow=int(sqrt(N))) 129 | 130 | def reconstruct(self, data, runPath, epoch): 131 | recon = super(CUB_Image, self).reconstruct(data[:8]) 132 | comp = torch.cat([data[:8], recon]) 133 | save_image(comp.data.cpu(), '{}/recon_{:03d}.png'.format(runPath, epoch)) 134 | 135 | def analyse(self, data, runPath, epoch): 136 | zemb, zsl, kls_df = super(CUB_Image, self).analyse(data, K=10) 137 | labels = ['Prior', self.modelName.lower()] 138 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch)) 139 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch)) 140 | 
-------------------------------------------------------------------------------- /src/models/vae_cub_image_ft.py: -------------------------------------------------------------------------------- 1 | # CUB Image feature model specification 2 | 3 | import torch 4 | import torch.distributions as dist 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | from numpy import sqrt 9 | from torchvision.utils import make_grid, save_image 10 | 11 | from datasets import CUBImageFt 12 | from utils import Constants, NN_lookup 13 | from vis import plot_embeddings, plot_kls_df 14 | from .vae import VAE 15 | 16 | # Constants 17 | imgChans = 3 18 | fBase = 64 19 | 20 | 21 | class Enc(nn.Module): 22 | """ Generate latent parameters for CUB image feature. """ 23 | 24 | def __init__(self, latent_dim, n_c): 25 | super(Enc, self).__init__() 26 | dim_hidden = 256 27 | self.enc = nn.Sequential() 28 | for i in range(int(torch.tensor(n_c / dim_hidden).log2())): 29 | self.enc.add_module("layer" + str(i), nn.Sequential( 30 | nn.Linear(n_c // (2 ** i), n_c // (2 ** (i + 1))), 31 | nn.ELU(inplace=True), 32 | )) 33 | # relies on above terminating at dim_hidden 34 | self.fc21 = nn.Linear(dim_hidden, latent_dim) 35 | self.fc22 = nn.Linear(dim_hidden, latent_dim) 36 | 37 | def forward(self, x): 38 | e = self.enc(x) 39 | return self.fc21(e), F.softplus(self.fc22(e)) + Constants.eta 40 | 41 | 42 | class Dec(nn.Module): 43 | """ Generate a CUB image feature given a sample from the latent space. """ 44 | 45 | def __init__(self, latent_dim, n_c): 46 | super(Dec, self).__init__() 47 | self.n_c = n_c 48 | dim_hidden = 256 49 | self.dec = nn.Sequential() 50 | for i in range(int(torch.tensor(n_c / dim_hidden).log2())): 51 | indim = latent_dim if i == 0 else dim_hidden * i 52 | outdim = dim_hidden if i == 0 else dim_hidden * (2 * i) 53 | self.dec.add_module("out_t" if i == 0 else "layer" + str(i) + "_t", nn.Sequential( 54 | nn.Linear(indim, outdim), 55 | nn.ELU(inplace=True), 56 | )) 57 | # relies on above terminating at n_c // 2 58 | self.fc31 = nn.Linear(n_c // 2, n_c) 59 | 60 | def forward(self, z): 61 | p = self.dec(z.view(-1, z.size(-1))) 62 | mean = self.fc31(p).view(*z.size()[:-1], -1) 63 | return mean, torch.tensor([0.01]).to(mean.device) 64 | 65 | 66 | class CUB_Image_ft(VAE): 67 | """ Derive a specific sub-class of a VAE for a CNN sentence model. """ 68 | 69 | def __init__(self, params): 70 | super(CUB_Image_ft, self).__init__( 71 | dist.Normal, # prior 72 | dist.Laplace, # likelihood 73 | dist.Normal, # posterior 74 | Enc(params.latent_dim, 2048), 75 | Dec(params.latent_dim, 2048), 76 | params 77 | ) 78 | grad = {'requires_grad': params.learn_prior} 79 | self._pz_params = nn.ParameterList([ 80 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu 81 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar 82 | ]) 83 | self.modelName = 'cubIft' 84 | self.dataSize = torch.Size([2048]) 85 | 86 | self.llik_scaling = 1. 
87 | 88 | @property 89 | def pz_params(self): 90 | return self._pz_params[0], \ 91 | F.softplus(self._pz_params[1]) + Constants.eta 92 | 93 | # remember that when combining with captions, this should be x10 94 | def getDataLoaders(self, batch_size, shuffle=True, device="cuda"): 95 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {} 96 | 97 | train_dataset = CUBImageFt('../data', 'train', device) 98 | test_dataset = CUBImageFt('../data', 'test', device) 99 | train_loader = torch.utils.data.DataLoader(train_dataset, 100 | batch_size=batch_size, shuffle=shuffle, **kwargs) 101 | test_loader = torch.utils.data.DataLoader(test_dataset, 102 | batch_size=batch_size, shuffle=shuffle, **kwargs) 103 | 104 | train_dataset._load_data() 105 | test_dataset._load_data() 106 | self.unproject = lambda emb_h, search_split='train', \ 107 | te=train_dataset.ft_mat, td=train_dataset.data_mat, \ 108 | se=test_dataset.ft_mat, sd=test_dataset.data_mat: \ 109 | NN_lookup(emb_h, te, td) if search_split == 'train' else NN_lookup(emb_h, se, sd) 110 | 111 | return train_loader, test_loader 112 | 113 | def generate(self, runPath, epoch): 114 | N, K = 64, 9 115 | samples = super(CUB_Image_ft, self).generate(N, K).data.cpu() 116 | samples = self.unproject(samples, search_split='train') 117 | samples = samples.view(K, N, *samples.size()[1:]).transpose(0, 1) 118 | s = [make_grid(t, nrow=int(sqrt(K)), padding=0) for t in samples.data.cpu()] 119 | save_image(torch.stack(s), 120 | '{}/gen_samples_{:03d}.png'.format(runPath, epoch), 121 | nrow=int(sqrt(N))) 122 | 123 | def reconstruct(self, data, runPath, epoch): 124 | recon = super(CUB_Image_ft, self).reconstruct(data[:8]) 125 | data_ = self.unproject(data[:8], search_split='test') 126 | recon_ = self.unproject(recon, search_split='train') 127 | comp = torch.cat([data_, recon_]) 128 | save_image(comp.data.cpu(), '{}/recon_{:03d}.png'.format(runPath, epoch)) 129 | 130 | def analyse(self, data, runPath, epoch): 131 | zemb, zsl, kls_df = super(CUB_Image_ft, self).analyse(data, K=10) 132 | labels = ['Prior', self.modelName.lower()] 133 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch)) 134 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch)) 135 | -------------------------------------------------------------------------------- /src/models/vae_cub_sent.py: -------------------------------------------------------------------------------- 1 | # Sentence model specification - real CUB image version 2 | import os 3 | import json 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributions as dist 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.data 11 | from torch.utils.data import DataLoader 12 | 13 | from datasets import CUBSentences 14 | from utils import Constants, FakeCategorical 15 | from .vae import VAE 16 | 17 | # Constants 18 | maxSentLen = 32 # max length of any description for birds dataset 19 | minOccur = 3 20 | embeddingDim = 128 21 | lenWindow = 3 22 | fBase = 32 23 | vocabSize = 1590 24 | vocab_path = '../data/cub/oc:{}_sl:{}_s:{}_w:{}/cub.vocab'.format(minOccur, maxSentLen, 300, lenWindow) 25 | 26 | 27 | # Classes 28 | class Enc(nn.Module): 29 | """ Generate latent parameters for sentence data. 
""" 30 | 31 | def __init__(self, latentDim): 32 | super(Enc, self).__init__() 33 | self.embedding = nn.Embedding(vocabSize, embeddingDim, padding_idx=0) 34 | self.enc = nn.Sequential( 35 | # input size: 1 x 32 x 128 36 | nn.Conv2d(1, fBase, 4, 2, 1, bias=False), 37 | nn.BatchNorm2d(fBase), 38 | nn.ReLU(True), 39 | # size: (fBase) x 16 x 64 40 | nn.Conv2d(fBase, fBase * 2, 4, 2, 1, bias=False), 41 | nn.BatchNorm2d(fBase * 2), 42 | nn.ReLU(True), 43 | # size: (fBase * 2) x 8 x 32 44 | nn.Conv2d(fBase * 2, fBase * 4, 4, 2, 1, bias=False), 45 | nn.BatchNorm2d(fBase * 4), 46 | nn.ReLU(True), 47 | # # size: (fBase * 4) x 4 x 16 48 | nn.Conv2d(fBase * 4, fBase * 4, (1, 4), (1, 2), (0, 1), bias=False), 49 | nn.BatchNorm2d(fBase * 4), 50 | nn.ReLU(True), 51 | # size: (fBase * 8) x 4 x 8 52 | nn.Conv2d(fBase * 4, fBase * 4, (1, 4), (1, 2), (0, 1), bias=False), 53 | nn.BatchNorm2d(fBase * 4), 54 | nn.ReLU(True), 55 | # size: (fBase * 8) x 4 x 4 56 | ) 57 | self.c1 = nn.Conv2d(fBase * 4, latentDim, 4, 1, 0, bias=False) 58 | self.c2 = nn.Conv2d(fBase * 4, latentDim, 4, 1, 0, bias=False) 59 | # c1, c2 size: latentDim x 1 x 1 60 | 61 | def forward(self, x): 62 | e = self.enc(self.embedding(x.long()).unsqueeze(1)) 63 | mu, logvar = self.c1(e).squeeze(), self.c2(e).squeeze() 64 | return mu, F.softplus(logvar) + Constants.eta 65 | 66 | 67 | class Dec(nn.Module): 68 | """ Generate a sentence given a sample from the latent space. """ 69 | 70 | def __init__(self, latentDim): 71 | super(Dec, self).__init__() 72 | self.dec = nn.Sequential( 73 | nn.ConvTranspose2d(latentDim, fBase * 4, 4, 1, 0, bias=False), 74 | nn.BatchNorm2d(fBase * 4), 75 | nn.ReLU(True), 76 | # size: (fBase * 8) x 4 x 4 77 | nn.ConvTranspose2d(fBase * 4, fBase * 4, (1, 4), (1, 2), (0, 1), bias=False), 78 | nn.BatchNorm2d(fBase * 4), 79 | nn.ReLU(True), 80 | # size: (fBase * 8) x 4 x 8 81 | nn.ConvTranspose2d(fBase * 4, fBase * 4, (1, 4), (1, 2), (0, 1), bias=False), 82 | nn.BatchNorm2d(fBase * 4), 83 | nn.ReLU(True), 84 | # size: (fBase * 4) x 8 x 32 85 | nn.ConvTranspose2d(fBase * 4, fBase * 2, 4, 2, 1, bias=False), 86 | nn.BatchNorm2d(fBase * 2), 87 | nn.ReLU(True), 88 | # size: (fBase * 2) x 16 x 64 89 | nn.ConvTranspose2d(fBase * 2, fBase, 4, 2, 1, bias=False), 90 | nn.BatchNorm2d(fBase), 91 | nn.ReLU(True), 92 | # size: (fBase) x 32 x 128 93 | nn.ConvTranspose2d(fBase, 1, 4, 2, 1, bias=False), 94 | nn.ReLU(True) 95 | # Output size: 1 x 64 x 256 96 | ) 97 | # inverts the 'embedding' module upto one-hotness 98 | self.toVocabSize = nn.Linear(embeddingDim, vocabSize) 99 | 100 | def forward(self, z): 101 | z = z.unsqueeze(-1).unsqueeze(-1) # fit deconv layers 102 | out = self.dec(z.view(-1, *z.size()[-3:])).view(-1, embeddingDim) 103 | 104 | return self.toVocabSize(out).view(*z.size()[:-3], maxSentLen, vocabSize), 105 | 106 | 107 | class CUB_Sentence(VAE): 108 | """ Derive a specific sub-class of a VAE for a sentence model. """ 109 | 110 | def __init__(self, params): 111 | super(CUB_Sentence, self).__init__( 112 | prior_dist=dist.Normal, 113 | likelihood_dist=FakeCategorical, 114 | post_dist=dist.Normal, 115 | enc=Enc(params.latent_dim), 116 | dec=Dec(params.latent_dim), 117 | params=params) 118 | grad = {'requires_grad': params.learn_prior} 119 | self._pz_params = nn.ParameterList([ 120 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu 121 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar 122 | ]) 123 | self.modelName = 'cubS' 124 | self.llik_scaling = 1. 
125 | 126 | self.tie_modules() 127 | 128 | self.fn_2i = lambda t: t.cpu().numpy().astype(int) 129 | self.fn_trun = lambda s: s[:np.where(s == 2)[0][0] + 1] if 2 in s else s 130 | self.vocab_file = vocab_path 131 | 132 | self.maxSentLen = maxSentLen 133 | self.vocabSize = vocabSize 134 | 135 | def tie_modules(self): 136 | # This looks dumb, but is actually dumber than you might realise. 137 | # A linear(a, b) module has a [b x a] weight matrix, but an embedding(a, b) 138 | # module has a [a x b] weight matrix. So when we want the transpose at 139 | # decoding time, we just use the weight matrix as is. 140 | self.dec.toVocabSize.weight = self.enc.embedding.weight 141 | 142 | @property 143 | def pz_params(self): 144 | return self._pz_params[0], F.softplus(self._pz_params[1]) + Constants.eta 145 | 146 | @staticmethod 147 | def getDataLoaders(batch_size, shuffle=True, device="cuda"): 148 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {} 149 | tx = lambda data: torch.Tensor(data) 150 | t_data = CUBSentences('../data', split='train', transform=tx, max_sequence_length=maxSentLen) 151 | s_data = CUBSentences('../data', split='test', transform=tx, max_sequence_length=maxSentLen) 152 | 153 | train_loader = DataLoader(t_data, batch_size=batch_size, shuffle=shuffle, **kwargs) 154 | test_loader = DataLoader(s_data, batch_size=batch_size, shuffle=shuffle, **kwargs) 155 | 156 | return train_loader, test_loader 157 | 158 | def reconstruct(self, data, runPath, epoch): 159 | recon = super(CUB_Sentence, self).reconstruct(data[:8]).argmax(dim=-1).squeeze() 160 | recon, data = self.fn_2i(recon), self.fn_2i(data[:8]) 161 | recon, data = [self.fn_trun(r) for r in recon], [self.fn_trun(d) for d in data] 162 | i2w = self.load_vocab() 163 | print("\n Reconstruction examples (excluding ):") 164 | for r_sent, d_sent in zip(recon[:3], data[:3]): 165 | print('[DATA] ==> {}'.format(' '.join(i2w[str(i)] for i in d_sent))) 166 | print('[RECON] ==> {}\n'.format(' '.join(i2w[str(i)] for i in r_sent))) 167 | 168 | with open('{}/recon_{:03d}.txt'.format(runPath, epoch), "w+") as txt_file: 169 | for r_sent, d_sent in zip(recon, data): 170 | txt_file.write('[DATA] ==> {}\n'.format(' '.join(i2w[str(i)] for i in d_sent))) 171 | txt_file.write('[RECON] ==> {}\n\n'.format(' '.join(i2w[str(i)] for i in r_sent))) 172 | 173 | def generate(self, runPath, epoch): 174 | N, K = 5, 4 175 | i2w = self.load_vocab() 176 | samples = super(CUB_Sentence, self).generate(N, K).argmax(dim=-1).squeeze() 177 | samples = samples.view(K, N, samples.size(-1)).transpose(0, 1) # N x K x 64 178 | samples = [[self.fn_trun(s) for s in ss] for ss in self.fn_2i(samples)] 179 | # samples = [self.fn_trun(s) for s in samples] 180 | print("\n Generated examples (excluding ):") 181 | for s_sent in samples[0][:3]: 182 | print('[GEN] ==> {}'.format(' '.join(i2w[str(i)] for i in s_sent if i != 0))) 183 | 184 | with open('{}/gen_samples_{:03d}.txt'.format(runPath, epoch), "w+") as txt_file: 185 | for s_sents in samples: 186 | for s_sent in s_sents: 187 | txt_file.write('{}\n'.format(' '.join(i2w[str(i)] for i in s_sent))) 188 | txt_file.write('\n') 189 | 190 | def analyse(self, data, runPath, epoch): 191 | pass 192 | 193 | def load_vocab(self): 194 | # call dataloader function to create vocab file 195 | if not os.path.exists(self.vocab_file): 196 | _, _ = self.getDataLoaders(256) 197 | with open(self.vocab_file, 'r') as vocab_file: 198 | vocab = json.load(vocab_file) 199 | return vocab['i2w'] 200 | 
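Note: the tie_modules trick above works because nn.Linear(in_features, out_features) stores its weight as [out_features, in_features] while nn.Embedding(num_embeddings, embedding_dim) stores [num_embeddings, embedding_dim]; with in_features = embeddingDim and out_features = vocabSize, both matrices are already [vocabSize, embeddingDim], so the embedding weight can be assigned to the vocabulary projection without a transpose. A small stand-alone check of that shape argument (illustrative only, using the constants defined in this file):

import torch.nn as nn

emb = nn.Embedding(1590, 128)   # vocabSize x embeddingDim, as above
proj = nn.Linear(128, 1590)     # weight stored as [out_features, in_features] = [1590, 128]
assert emb.weight.shape == proj.weight.shape == (1590, 128)
proj.weight = emb.weight        # what tie_modules() does: shared storage, no transpose

Because the storage is shared, the reconstruction loss updates the same matrix through both the encoder's embedding lookup and the decoder's projection back to the vocabulary.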
-------------------------------------------------------------------------------- /src/models/vae_cub_sent_ft.py: -------------------------------------------------------------------------------- 1 | # Sentence model specification - CUB image feature version 2 | import json 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributions as dist 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.data 11 | from torch.utils.data import DataLoader 12 | 13 | from datasets import CUBSentences 14 | from utils import Constants, FakeCategorical 15 | from .vae import VAE 16 | 17 | maxSentLen = 32 # max length of any description for birds dataset 18 | minOccur = 3 19 | embeddingDim = 128 20 | lenWindow = 3 21 | fBase = 32 22 | vocabSize = 1590 23 | vocab_path = '../data/cub/oc:{}_sl:{}_s:{}_w:{}/cub.vocab'.format(minOccur, maxSentLen, 300, lenWindow) 24 | 25 | 26 | # Classes 27 | class Enc(nn.Module): 28 | """ Generate latent parameters for sentence data. """ 29 | 30 | def __init__(self, latentDim): 31 | super(Enc, self).__init__() 32 | self.embedding = nn.Embedding(vocabSize, embeddingDim, padding_idx=0) 33 | self.enc = nn.Sequential( 34 | # input size: 1 x 32 x 128 35 | nn.Conv2d(1, fBase, 4, 2, 1, bias=True), 36 | nn.BatchNorm2d(fBase), 37 | nn.ReLU(True), 38 | # size: (fBase) x 16 x 64 39 | nn.Conv2d(fBase, fBase * 2, 4, 2, 1, bias=True), 40 | nn.BatchNorm2d(fBase * 2), 41 | nn.ReLU(True), 42 | # size: (fBase * 2) x 8 x 32 43 | nn.Conv2d(fBase * 2, fBase * 4, 4, 2, 1, bias=True), 44 | nn.BatchNorm2d(fBase * 4), 45 | nn.ReLU(True), 46 | # # size: (fBase * 4) x 4 x 16 47 | nn.Conv2d(fBase * 4, fBase * 8, (1, 4), (1, 2), (0, 1), bias=True), 48 | nn.BatchNorm2d(fBase * 8), 49 | nn.ReLU(True), 50 | # size: (fBase * 8) x 4 x 8 51 | nn.Conv2d(fBase * 8, fBase * 16, (1, 4), (1, 2), (0, 1), bias=True), 52 | nn.BatchNorm2d(fBase * 16), 53 | nn.ReLU(True), 54 | # size: (fBase * 8) x 4 x 4 55 | ) 56 | self.c1 = nn.Conv2d(fBase * 16, latentDim, 4, 1, 0, bias=True) 57 | self.c2 = nn.Conv2d(fBase * 16, latentDim, 4, 1, 0, bias=True) 58 | # c1, c2 size: latentDim x 1 x 1 59 | 60 | def forward(self, x): 61 | e = self.enc(self.embedding(x.long()).unsqueeze(1)) 62 | mu, logvar = self.c1(e).squeeze(), self.c2(e).squeeze() 63 | return mu, F.softplus(logvar) + Constants.eta 64 | 65 | 66 | class Dec(nn.Module): 67 | """ Generate a sentence given a sample from the latent space. 
""" 68 | 69 | def __init__(self, latentDim): 70 | super(Dec, self).__init__() 71 | self.dec = nn.Sequential( 72 | nn.ConvTranspose2d(latentDim, fBase * 16, 4, 1, 0, bias=True), 73 | nn.BatchNorm2d(fBase * 16), 74 | nn.ReLU(True), 75 | # size: (fBase * 8) x 4 x 4 76 | nn.ConvTranspose2d(fBase * 16, fBase * 8, (1, 4), (1, 2), (0, 1), bias=True), 77 | nn.BatchNorm2d(fBase * 8), 78 | nn.ReLU(True), 79 | # size: (fBase * 8) x 4 x 8 80 | nn.ConvTranspose2d(fBase * 8, fBase * 4, (1, 4), (1, 2), (0, 1), bias=True), 81 | nn.BatchNorm2d(fBase * 4), 82 | nn.ReLU(True), 83 | # size: (fBase * 4) x 8 x 32 84 | nn.ConvTranspose2d(fBase * 4, fBase * 2, 4, 2, 1, bias=True), 85 | nn.BatchNorm2d(fBase * 2), 86 | nn.ReLU(True), 87 | # size: (fBase * 2) x 16 x 64 88 | nn.ConvTranspose2d(fBase * 2, fBase, 4, 2, 1, bias=True), 89 | nn.BatchNorm2d(fBase), 90 | nn.ReLU(True), 91 | # size: (fBase) x 32 x 128 92 | nn.ConvTranspose2d(fBase, 1, 4, 2, 1, bias=True), 93 | nn.ReLU(True) 94 | # Output size: 1 x 64 x 256 95 | ) 96 | # inverts the 'embedding' module upto one-hotness 97 | self.toVocabSize = nn.Linear(embeddingDim, vocabSize) 98 | 99 | def forward(self, z): 100 | z = z.unsqueeze(-1).unsqueeze(-1) # fit deconv layers 101 | out = self.dec(z.view(-1, *z.size()[-3:])).view(-1, embeddingDim) 102 | 103 | return self.toVocabSize(out).view(*z.size()[:-3], maxSentLen, vocabSize), 104 | 105 | 106 | class CUB_Sentence_ft(VAE): 107 | """ Derive a specific sub-class of a VAE for a sentence model. """ 108 | 109 | def __init__(self, params): 110 | super(CUB_Sentence_ft, self).__init__( 111 | prior_dist=dist.Normal, 112 | likelihood_dist=FakeCategorical, 113 | post_dist=dist.Normal, 114 | enc=Enc(params.latent_dim), 115 | dec=Dec(params.latent_dim), 116 | params=params) 117 | grad = {'requires_grad': params.learn_prior} 118 | self._pz_params = nn.ParameterList([ 119 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu 120 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar 121 | ]) 122 | self.modelName = 'cubSft' 123 | self.llik_scaling = 1. 124 | 125 | self.tie_modules() 126 | 127 | self.fn_2i = lambda t: t.cpu().numpy().astype(int) 128 | self.fn_trun = lambda s: s[:np.where(s == 2)[0][0] + 1] if 2 in s else s 129 | self.vocab_file = vocab_path 130 | 131 | self.maxSentLen = maxSentLen 132 | self.vocabSize = vocabSize 133 | 134 | def tie_modules(self): 135 | # This looks dumb, but is actually dumber than you might realise. 136 | # A linear(a, b) module has a [b x a] weight matrix, but an embedding(a, b) 137 | # module has a [a x b] weight matrix. So when we want the transpose at 138 | # decoding time, we just use the weight matrix as is. 
139 | self.dec.toVocabSize.weight = self.enc.embedding.weight 140 | 141 | @property 142 | def pz_params(self): 143 | return self._pz_params[0], F.softplus(self._pz_params[1]) + Constants.eta 144 | 145 | @staticmethod 146 | def getDataLoaders(batch_size, shuffle=True, device="cuda"): 147 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {} 148 | tx = lambda data: torch.Tensor(data) 149 | t_data = CUBSentences('../data', split='train', transform=tx, max_sequence_length=maxSentLen) 150 | s_data = CUBSentences('../data', split='test', transform=tx, max_sequence_length=maxSentLen) 151 | 152 | train_loader = DataLoader(t_data, batch_size=batch_size, shuffle=shuffle, **kwargs) 153 | test_loader = DataLoader(s_data, batch_size=batch_size, shuffle=shuffle, **kwargs) 154 | 155 | return train_loader, test_loader 156 | 157 | def reconstruct(self, data, runPath, epoch): 158 | recon = super(CUB_Sentence_ft, self).reconstruct(data[:8]).argmax(dim=-1).squeeze() 159 | recon, data = self.fn_2i(recon), self.fn_2i(data[:8]) 160 | recon, data = [self.fn_trun(r) for r in recon], [self.fn_trun(d) for d in data] 161 | i2w = self.load_vocab() 162 | print("\n Reconstruction examples (excluding ):") 163 | for r_sent, d_sent in zip(recon[:3], data[:3]): 164 | print('[DATA] ==> {}'.format(' '.join(i2w[str(i)] for i in d_sent))) 165 | print('[RECON] ==> {}\n'.format(' '.join(i2w[str(i)] for i in r_sent))) 166 | 167 | with open('{}/recon_{:03d}.txt'.format(runPath, epoch), "w+") as txt_file: 168 | for r_sent, d_sent in zip(recon, data): 169 | txt_file.write('[DATA] ==> {}\n'.format(' '.join(i2w[str(i)] for i in d_sent))) 170 | txt_file.write('[RECON] ==> {}\n\n'.format(' '.join(i2w[str(i)] for i in r_sent))) 171 | 172 | def generate(self, runPath, epoch): 173 | N, K = 5, 4 174 | i2w = self.load_vocab() 175 | samples = super(CUB_Sentence_ft, self).generate(N, K).argmax(dim=-1).squeeze() 176 | samples = samples.view(K, N, samples.size(-1)).transpose(0, 1) # N x K x 64 177 | samples = [[self.fn_trun(s) for s in ss] for ss in self.fn_2i(samples)] 178 | # samples = [self.fn_trun(s) for s in samples] 179 | print("\n Generated examples (excluding ):") 180 | for s_sent in samples[0][:3]: 181 | print('[GEN] ==> {}'.format(' '.join(i2w[str(i)] for i in s_sent if i != 0))) 182 | 183 | with open('{}/gen_samples_{:03d}.txt'.format(runPath, epoch), "w+") as txt_file: 184 | for s_sents in samples: 185 | for s_sent in s_sents: 186 | txt_file.write('{}\n'.format(' '.join(i2w[str(i)] for i in s_sent))) 187 | txt_file.write('\n') 188 | 189 | def analyse(self, data, runPath, epoch): 190 | pass 191 | 192 | def load_vocab(self): 193 | # call dataloader function to create vocab file 194 | if not os.path.exists(self.vocab_file): 195 | _, _ = self.getDataLoaders(256) 196 | with open(self.vocab_file, 'r') as vocab_file: 197 | vocab = json.load(vocab_file) 198 | return vocab['i2w'] 199 | -------------------------------------------------------------------------------- /src/models/vae_mnist.py: -------------------------------------------------------------------------------- 1 | # MNIST model specification 2 | 3 | import torch 4 | import torch.distributions as dist 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from numpy import prod, sqrt 8 | from torch.utils.data import DataLoader 9 | from torchvision import datasets, transforms 10 | from torchvision.utils import save_image, make_grid 11 | 12 | from utils import Constants 13 | from vis import plot_embeddings, plot_kls_df 14 | from .vae import VAE 15 | 16 
| # Constants 17 | dataSize = torch.Size([1, 28, 28]) 18 | data_dim = int(prod(dataSize)) 19 | hidden_dim = 400 20 | 21 | 22 | def extra_hidden_layer(): 23 | return nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(True)) 24 | 25 | 26 | # Classes 27 | class Enc(nn.Module): 28 | """ Generate latent parameters for MNIST image data. """ 29 | 30 | def __init__(self, latent_dim, num_hidden_layers=1): 31 | super(Enc, self).__init__() 32 | modules = [] 33 | modules.append(nn.Sequential(nn.Linear(data_dim, hidden_dim), nn.ReLU(True))) 34 | modules.extend([extra_hidden_layer() for _ in range(num_hidden_layers - 1)]) 35 | self.enc = nn.Sequential(*modules) 36 | self.fc21 = nn.Linear(hidden_dim, latent_dim) 37 | self.fc22 = nn.Linear(hidden_dim, latent_dim) 38 | 39 | def forward(self, x): 40 | e = self.enc(x.view(*x.size()[:-3], -1)) # flatten data 41 | lv = self.fc22(e) 42 | return self.fc21(e), F.softmax(lv, dim=-1) * lv.size(-1) + Constants.eta 43 | 44 | 45 | class Dec(nn.Module): 46 | """ Generate an MNIST image given a sample from the latent space. """ 47 | 48 | def __init__(self, latent_dim, num_hidden_layers=1): 49 | super(Dec, self).__init__() 50 | modules = [] 51 | modules.append(nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.ReLU(True))) 52 | modules.extend([extra_hidden_layer() for _ in range(num_hidden_layers - 1)]) 53 | self.dec = nn.Sequential(*modules) 54 | self.fc3 = nn.Linear(hidden_dim, data_dim) 55 | 56 | def forward(self, z): 57 | p = self.fc3(self.dec(z)) 58 | d = torch.sigmoid(p.view(*z.size()[:-1], *dataSize)) # reshape data 59 | d = d.clamp(Constants.eta, 1 - Constants.eta) 60 | 61 | return d, torch.tensor(0.75).to(z.device) # mean, length scale 62 | 63 | 64 | class MNIST(VAE): 65 | """ Derive a specific sub-class of a VAE for MNIST. """ 66 | 67 | def __init__(self, params): 68 | super(MNIST, self).__init__( 69 | dist.Laplace, # prior 70 | dist.Laplace, # likelihood 71 | dist.Laplace, # posterior 72 | Enc(params.latent_dim, params.num_hidden_layers), 73 | Dec(params.latent_dim, params.num_hidden_layers), 74 | params 75 | ) 76 | grad = {'requires_grad': params.learn_prior} 77 | self._pz_params = nn.ParameterList([ 78 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu 79 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar 80 | ]) 81 | self.modelName = 'mnist' 82 | self.dataSize = dataSize 83 | self.llik_scaling = 1. 
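# A note on the prior parameterisation used below: the two entries of _pz_params hold the
# location and a pre-activation scale for the Laplace prior selected in __init__; only the
# second entry is a free parameter, and only when params.learn_prior is set. The pz_params
# property does not exponentiate it -- the scale is computed as softmax(param) * latent_dim,
# giving positive values that average to one across dimensions (an all-zero parameter yields
# a unit scale in every dimension). The CUB sentence model earlier in this listing achieves
# positivity with softplus(param) + Constants.eta instead.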
84 | 85 | @property 86 | def pz_params(self): 87 | return self._pz_params[0], F.softmax(self._pz_params[1], dim=1) * self._pz_params[1].size(-1) 88 | 89 | @staticmethod 90 | def getDataLoaders(batch_size, shuffle=True, device="cuda"): 91 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {} 92 | tx = transforms.ToTensor() 93 | train = DataLoader(datasets.MNIST('../data', train=True, download=True, transform=tx), 94 | batch_size=batch_size, shuffle=shuffle, **kwargs) 95 | test = DataLoader(datasets.MNIST('../data', train=False, download=True, transform=tx), 96 | batch_size=batch_size, shuffle=shuffle, **kwargs) 97 | return train, test 98 | 99 | def generate(self, runPath, epoch): 100 | N, K = 64, 9 101 | samples = super(MNIST, self).generate(N, K).cpu() 102 | # wrangle things so they come out tiled 103 | samples = samples.view(K, N, *samples.size()[1:]).transpose(0, 1) # N x K x 1 x 28 x 28 104 | s = [make_grid(t, nrow=int(sqrt(K)), padding=0) for t in samples] 105 | save_image(torch.stack(s), 106 | '{}/gen_samples_{:03d}.png'.format(runPath, epoch), 107 | nrow=int(sqrt(N))) 108 | 109 | def reconstruct(self, data, runPath, epoch): 110 | recon = super(MNIST, self).reconstruct(data[:8]) 111 | comp = torch.cat([data[:8], recon]).data.cpu() 112 | save_image(comp, '{}/recon_{:03d}.png'.format(runPath, epoch)) 113 | 114 | def analyse(self, data, runPath, epoch): 115 | zemb, zsl, kls_df = super(MNIST, self).analyse(data, K=10) 116 | labels = ['Prior', self.modelName.lower()] 117 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch)) 118 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch)) 119 | -------------------------------------------------------------------------------- /src/models/vae_svhn.py: -------------------------------------------------------------------------------- 1 | # SVHN model specification 2 | 3 | import torch 4 | import torch.distributions as dist 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from numpy import sqrt 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms, datasets 10 | from torchvision.utils import save_image, make_grid 11 | 12 | from utils import Constants 13 | from vis import plot_embeddings, plot_kls_df 14 | from .vae import VAE 15 | 16 | # Constants 17 | dataSize = torch.Size([3, 32, 32]) 18 | imgChans = dataSize[0] 19 | fBase = 32 # base size of filter channels 20 | 21 | 22 | # Classes 23 | class Enc(nn.Module): 24 | """ Generate latent parameters for SVHN image data. """ 25 | 26 | def __init__(self, latent_dim): 27 | super(Enc, self).__init__() 28 | self.enc = nn.Sequential( 29 | # input size: 3 x 32 x 32 30 | nn.Conv2d(imgChans, fBase, 4, 2, 1, bias=True), 31 | nn.ReLU(True), 32 | # size: (fBase) x 16 x 16 33 | nn.Conv2d(fBase, fBase * 2, 4, 2, 1, bias=True), 34 | nn.ReLU(True), 35 | # size: (fBase * 2) x 8 x 8 36 | nn.Conv2d(fBase * 2, fBase * 4, 4, 2, 1, bias=True), 37 | nn.ReLU(True), 38 | # size: (fBase * 4) x 4 x 4 39 | ) 40 | self.c1 = nn.Conv2d(fBase * 4, latent_dim, 4, 1, 0, bias=True) 41 | self.c2 = nn.Conv2d(fBase * 4, latent_dim, 4, 1, 0, bias=True) 42 | # c1, c2 size: latent_dim x 1 x 1 43 | 44 | def forward(self, x): 45 | e = self.enc(x) 46 | lv = self.c2(e).squeeze() 47 | return self.c1(e).squeeze(), F.softmax(lv, dim=-1) * lv.size(-1) + Constants.eta 48 | 49 | 50 | class Dec(nn.Module): 51 | """ Generate a SVHN image given a sample from the latent space. 
""" 52 | 53 | def __init__(self, latent_dim): 54 | super(Dec, self).__init__() 55 | self.dec = nn.Sequential( 56 | nn.ConvTranspose2d(latent_dim, fBase * 4, 4, 1, 0, bias=True), 57 | nn.ReLU(True), 58 | # size: (fBase * 4) x 4 x 4 59 | nn.ConvTranspose2d(fBase * 4, fBase * 2, 4, 2, 1, bias=True), 60 | nn.ReLU(True), 61 | # size: (fBase * 2) x 8 x 8 62 | nn.ConvTranspose2d(fBase * 2, fBase, 4, 2, 1, bias=True), 63 | nn.ReLU(True), 64 | # size: (fBase) x 16 x 16 65 | nn.ConvTranspose2d(fBase, imgChans, 4, 2, 1, bias=True), 66 | nn.Sigmoid() 67 | # Output size: 3 x 32 x 32 68 | ) 69 | 70 | def forward(self, z): 71 | z = z.unsqueeze(-1).unsqueeze(-1) # fit deconv layers 72 | out = self.dec(z.view(-1, *z.size()[-3:])) 73 | out = out.view(*z.size()[:-3], *out.size()[1:]) 74 | # consider also predicting the length scale 75 | return out, torch.tensor(0.75).to(z.device) # mean, length scale 76 | 77 | 78 | class SVHN(VAE): 79 | """ Derive a specific sub-class of a VAE for SVHN """ 80 | 81 | def __init__(self, params): 82 | super(SVHN, self).__init__( 83 | dist.Laplace, # prior 84 | dist.Laplace, # likelihood 85 | dist.Laplace, # posterior 86 | Enc(params.latent_dim), 87 | Dec(params.latent_dim), 88 | params 89 | ) 90 | grad = {'requires_grad': params.learn_prior} 91 | self._pz_params = nn.ParameterList([ 92 | nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False), # mu 93 | nn.Parameter(torch.zeros(1, params.latent_dim), **grad) # logvar 94 | ]) 95 | self.modelName = 'svhn' 96 | self.dataSize = dataSize 97 | self.llik_scaling = 1. 98 | 99 | @property 100 | def pz_params(self): 101 | return self._pz_params[0], F.softmax(self._pz_params[1], dim=1) * self._pz_params[1].size(-1) 102 | 103 | @staticmethod 104 | def getDataLoaders(batch_size, shuffle=True, device='cuda'): 105 | kwargs = {'num_workers': 1, 'pin_memory': True} if device == 'cuda' else {} 106 | tx = transforms.ToTensor() 107 | train = DataLoader(datasets.SVHN('../data', split='train', download=True, transform=tx), 108 | batch_size=batch_size, shuffle=shuffle, **kwargs) 109 | test = DataLoader(datasets.SVHN('../data', split='test', download=True, transform=tx), 110 | batch_size=batch_size, shuffle=shuffle, **kwargs) 111 | return train, test 112 | 113 | def generate(self, runPath, epoch): 114 | N, K = 64, 9 115 | samples = super(SVHN, self).generate(N, K).cpu() 116 | # wrangle things so they come out tiled 117 | samples = samples.view(K, N, *samples.size()[1:]).transpose(0, 1) 118 | s = [make_grid(t, nrow=int(sqrt(K)), padding=0) for t in samples] 119 | save_image(torch.stack(s), 120 | '{}/gen_samples_{:03d}.png'.format(runPath, epoch), 121 | nrow=int(sqrt(N))) 122 | 123 | def reconstruct(self, data, runPath, epoch): 124 | recon = super(SVHN, self).reconstruct(data[:8]) 125 | comp = torch.cat([data[:8], recon]).data.cpu() 126 | save_image(comp, '{}/recon_{:03d}.png'.format(runPath, epoch)) 127 | 128 | def analyse(self, data, runPath, epoch): 129 | zemb, zsl, kls_df = super(SVHN, self).analyse(data, K=10) 130 | labels = ['Prior', self.modelName.lower()] 131 | plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch)) 132 | plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch)) 133 | -------------------------------------------------------------------------------- /src/objectives.py: -------------------------------------------------------------------------------- 1 | # objectives of choice 2 | import torch 3 | from numpy import prod 4 | 5 | from utils import log_mean_exp, is_multidata, 
kl_divergence 6 | 7 | 8 | # helper to vectorise computation 9 | def compute_microbatch_split(x, K): 10 | """ Checks if batch needs to be broken down further to fit in memory. """ 11 | B = x[0].size(0) if is_multidata(x) else x.size(0) 12 | S = sum([1.0 / (K * prod(_x.size()[1:])) for _x in x]) if is_multidata(x) \ 13 | else 1.0 / (K * prod(x.size()[1:])) 14 | S = int(1e8 * S) # float heuristic for 12Gb cuda memory 15 | assert (S > 0), "Cannot fit individual data in memory, consider smaller K" 16 | return min(B, S) 17 | 18 | 19 | def elbo(model, x, K=1): 20 | """Computes E_{p(x)}[ELBO] """ 21 | qz_x, px_z, _ = model(x) 22 | lpx_z = px_z.log_prob(x).view(*px_z.batch_shape[:2], -1) * model.llik_scaling 23 | kld = kl_divergence(qz_x, model.pz(*model.pz_params)) 24 | return (lpx_z.sum(-1) - kld.sum(-1)).mean(0).sum() 25 | 26 | 27 | def _iwae(model, x, K): 28 | """IWAE estimate for log p_\theta(x) -- fully vectorised.""" 29 | qz_x, px_z, zs = model(x, K) 30 | lpz = model.pz(*model.pz_params).log_prob(zs).sum(-1) 31 | lpx_z = px_z.log_prob(x).view(*px_z.batch_shape[:2], -1) * model.llik_scaling 32 | lqz_x = qz_x.log_prob(zs).sum(-1) 33 | return lpz + lpx_z.sum(-1) - lqz_x 34 | 35 | 36 | def iwae(model, x, K): 37 | """Computes an importance-weighted ELBO estimate for log p_\theta(x) 38 | Iterates over the batch as necessary. 39 | """ 40 | S = compute_microbatch_split(x, K) 41 | lw = torch.cat([_iwae(model, _x, K) for _x in x.split(S)], 1) # concat on batch 42 | return log_mean_exp(lw).sum() 43 | 44 | 45 | def _dreg(model, x, K): 46 | """DREG estimate for log p_\theta(x) -- fully vectorised.""" 47 | _, px_z, zs = model(x, K) 48 | lpz = model.pz(*model.pz_params).log_prob(zs).sum(-1) 49 | lpx_z = px_z.log_prob(x).view(*px_z.batch_shape[:2], -1) * model.llik_scaling 50 | qz_x = model.qz_x(*[p.detach() for p in model.qz_x_params]) # stop-grad for \phi 51 | lqz_x = qz_x.log_prob(zs).sum(-1) 52 | lw = lpz + lpx_z.sum(-1) - lqz_x 53 | return lw, zs 54 | 55 | 56 | def dreg(model, x, K, regs=None): 57 | """Computes a doubly-reparameterised importance-weighted ELBO estimate for log p_\theta(x) 58 | Iterates over the batch as necessary. 
59 | """ 60 | S = compute_microbatch_split(x, K) 61 | lw, zs = zip(*[_dreg(model, _x, K) for _x in x.split(S)]) 62 | lw = torch.cat(lw, 1) # concat on batch 63 | zs = torch.cat(zs, 1) # concat on batch 64 | with torch.no_grad(): 65 | grad_wt = (lw - torch.logsumexp(lw, 0, keepdim=True)).exp() 66 | if zs.requires_grad: 67 | zs.register_hook(lambda grad: grad_wt.unsqueeze(-1) * grad) 68 | return (grad_wt * lw).sum() 69 | 70 | 71 | # multi-modal variants 72 | def m_elbo_naive(model, x, K=1): 73 | """Computes E_{p(x)}[ELBO] for multi-modal vae --- NOT EXPOSED""" 74 | qz_xs, px_zs, zss = model(x) 75 | lpx_zs, klds = [], [] 76 | for r, qz_x in enumerate(qz_xs): 77 | kld = kl_divergence(qz_x, model.pz(*model.pz_params)) 78 | klds.append(kld.sum(-1)) 79 | for d, px_z in enumerate(px_zs[r]): 80 | lpx_z = px_z.log_prob(x[d]) * model.vaes[d].llik_scaling 81 | lpx_zs.append(lpx_z.view(*px_z.batch_shape[:2], -1).sum(-1)) 82 | obj = (1 / len(model.vaes)) * (torch.stack(lpx_zs).sum(0) - torch.stack(klds).sum(0)) 83 | return obj.mean(0).sum() 84 | 85 | 86 | def m_elbo(model, x, K=1): 87 | """Computes importance-sampled m_elbo (in notes3) for multi-modal vae """ 88 | qz_xs, px_zs, zss = model(x) 89 | lpx_zs, klds = [], [] 90 | for r, qz_x in enumerate(qz_xs): 91 | kld = kl_divergence(qz_x, model.pz(*model.pz_params)) 92 | klds.append(kld.sum(-1)) 93 | for d in range(len(px_zs)): 94 | lpx_z = px_zs[d][d].log_prob(x[d]).view(*px_zs[d][d].batch_shape[:2], -1) 95 | lpx_z = (lpx_z * model.vaes[d].llik_scaling).sum(-1) 96 | if d == r: 97 | lwt = torch.tensor(0.0) 98 | else: 99 | zs = zss[d].detach() 100 | lwt = (qz_x.log_prob(zs) - qz_xs[d].log_prob(zs).detach()).sum(-1) 101 | lpx_zs.append(lwt.exp() * lpx_z) 102 | obj = (1 / len(model.vaes)) * (torch.stack(lpx_zs).sum(0) - torch.stack(klds).sum(0)) 103 | return obj.mean(0).sum() 104 | 105 | 106 | def _m_iwae(model, x, K=1): 107 | """IWAE estimate for log p_\theta(x) for multi-modal vae -- fully vectorised""" 108 | qz_xs, px_zs, zss = model(x, K) 109 | lws = [] 110 | for r, qz_x in enumerate(qz_xs): 111 | lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1) 112 | lqz_x = log_mean_exp(torch.stack([qz_x.log_prob(zss[r]).sum(-1) for qz_x in qz_xs])) 113 | lpx_z = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1) 114 | .mul(model.vaes[d].llik_scaling).sum(-1) 115 | for d, px_z in enumerate(px_zs[r])] 116 | lpx_z = torch.stack(lpx_z).sum(0) 117 | lw = lpz + lpx_z - lqz_x 118 | lws.append(lw) 119 | return torch.cat(lws) # (n_modality * n_samples) x batch_size, batch_size 120 | 121 | 122 | def m_iwae(model, x, K=1): 123 | """Computes iwae estimate for log p_\theta(x) for multi-modal vae """ 124 | S = compute_microbatch_split(x, K) 125 | x_split = zip(*[_x.split(S) for _x in x]) 126 | lw = [_m_iwae(model, _x, K) for _x in x_split] 127 | lw = torch.cat(lw, 1) # concat on batch 128 | return log_mean_exp(lw).sum() 129 | 130 | 131 | def _m_iwae_looser(model, x, K=1): 132 | """IWAE estimate for log p_\theta(x) for multi-modal vae -- fully vectorised 133 | This version is the looser bound---with the average over modalities outside the log 134 | """ 135 | qz_xs, px_zs, zss = model(x, K) 136 | lws = [] 137 | for r, qz_x in enumerate(qz_xs): 138 | lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1) 139 | lqz_x = log_mean_exp(torch.stack([qz_x.log_prob(zss[r]).sum(-1) for qz_x in qz_xs])) 140 | lpx_z = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1) 141 | .mul(model.vaes[d].llik_scaling).sum(-1) 142 | for d, px_z in enumerate(px_zs[r])] 143 | lpx_z = 
torch.stack(lpx_z).sum(0) 144 | lw = lpz + lpx_z - lqz_x 145 | lws.append(lw) 146 | return torch.stack(lws) # n_modality x n_samples x batch_size 147 | 148 | 149 | def m_iwae_looser(model, x, K=1): 150 | """Computes iwae estimate for log p_\theta(x) for multi-modal vae 151 | This version is the looser bound---with the average over modalities outside the log 152 | """ 153 | S = compute_microbatch_split(x, K) 154 | x_split = zip(*[_x.split(S) for _x in x]) 155 | lw = [_m_iwae_looser(model, _x, K) for _x in x_split] 156 | lw = torch.cat(lw, 2) # concat on batch 157 | return log_mean_exp(lw, dim=1).mean(0).sum() 158 | 159 | 160 | def _m_dreg(model, x, K=1): 161 | """DREG estimate for log p_\theta(x) for multi-modal vae -- fully vectorised""" 162 | qz_xs, px_zs, zss = model(x, K) 163 | qz_xs_ = [vae.qz_x(*[p.detach() for p in vae.qz_x_params]) for vae in model.vaes] 164 | lws = [] 165 | for r, vae in enumerate(model.vaes): 166 | lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1) 167 | lqz_x = log_mean_exp(torch.stack([qz_x_.log_prob(zss[r]).sum(-1) for qz_x_ in qz_xs_])) 168 | lpx_z = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1) 169 | .mul(model.vaes[d].llik_scaling).sum(-1) 170 | for d, px_z in enumerate(px_zs[r])] 171 | lpx_z = torch.stack(lpx_z).sum(0) 172 | lw = lpz + lpx_z - lqz_x 173 | lws.append(lw) 174 | return torch.cat(lws), torch.cat(zss) 175 | 176 | 177 | def m_dreg(model, x, K=1): 178 | """Computes dreg estimate for log p_\theta(x) for multi-modal vae """ 179 | S = compute_microbatch_split(x, K) 180 | x_split = zip(*[_x.split(S) for _x in x]) 181 | lw, zss = zip(*[_m_dreg(model, _x, K) for _x in x_split]) 182 | lw = torch.cat(lw, 1) # concat on batch 183 | zss = torch.cat(zss, 1) # concat on batch 184 | with torch.no_grad(): 185 | grad_wt = (lw - torch.logsumexp(lw, 0, keepdim=True)).exp() 186 | if zss.requires_grad: 187 | zss.register_hook(lambda grad: grad_wt.unsqueeze(-1) * grad) 188 | return (grad_wt * lw).sum() 189 | 190 | 191 | def _m_dreg_looser(model, x, K=1): 192 | """DREG estimate for log p_\theta(x) for multi-modal vae -- fully vectorised 193 | This version is the looser bound---with the average over modalities outside the log 194 | """ 195 | qz_xs, px_zs, zss = model(x, K) 196 | qz_xs_ = [vae.qz_x(*[p.detach() for p in vae.qz_x_params]) for vae in model.vaes] 197 | lws = [] 198 | for r, vae in enumerate(model.vaes): 199 | lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1) 200 | lqz_x = log_mean_exp(torch.stack([qz_x_.log_prob(zss[r]).sum(-1) for qz_x_ in qz_xs_])) 201 | lpx_z = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1) 202 | .mul(model.vaes[d].llik_scaling).sum(-1) 203 | for d, px_z in enumerate(px_zs[r])] 204 | lpx_z = torch.stack(lpx_z).sum(0) 205 | lw = lpz + lpx_z - lqz_x 206 | lws.append(lw) 207 | return torch.stack(lws), torch.stack(zss) 208 | 209 | 210 | def m_dreg_looser(model, x, K=1): 211 | """Computes dreg estimate for log p_\theta(x) for multi-modal vae 212 | This version is the looser bound---with the average over modalities outside the log 213 | """ 214 | S = compute_microbatch_split(x, K) 215 | x_split = zip(*[_x.split(S) for _x in x]) 216 | lw, zss = zip(*[_m_dreg_looser(model, _x, K) for _x in x_split]) 217 | lw = torch.cat(lw, 2) # concat on batch 218 | zss = torch.cat(zss, 2) # concat on batch 219 | with torch.no_grad(): 220 | grad_wt = (lw - torch.logsumexp(lw, 1, keepdim=True)).exp() 221 | if zss.requires_grad: 222 | zss.register_hook(lambda grad: grad_wt.unsqueeze(-1) * grad) 223 | return (grad_wt *
lw).mean(0).sum() 224 | -------------------------------------------------------------------------------- /src/report/analyse_cub.py: -------------------------------------------------------------------------------- 1 | """Calculate cross and joint coherence of language and image generation on CUB dataset using CCA.""" 2 | import argparse 3 | import os 4 | import sys 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | # relative import hack (sorry) 10 | import inspect 11 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 12 | parentdir = os.path.dirname(currentdir) 13 | sys.path.insert(0, parentdir) # for system user 14 | os.chdir(parentdir) # for pycharm user 15 | 16 | import models 17 | from utils import Logger, Timer, unpack_data 18 | from helper import cca, fetch_emb, fetch_weights, fetch_pc, apply_weights, apply_pc 19 | 20 | # variables 21 | RESET = True 22 | USE_PCA = True 23 | maxSentLen = 32 24 | minOccur = 3 25 | lenEmbedding = 300 26 | lenWindow = 3 27 | fBase = 96 28 | vocab_dir = '../data/cub/oc:{}_sl:{}_s:{}_w:{}'.format(minOccur, maxSentLen, lenEmbedding, lenWindow) 29 | batch_size = 256 30 | 31 | # args 32 | torch.backends.cudnn.benchmark = True 33 | parser = argparse.ArgumentParser(description='Analysing MM-DGM results') 34 | parser.add_argument('--save-dir', type=str, default=".", 35 | metavar='N', help='save directory of results') 36 | parser.add_argument('--no-cuda', action='store_true', default=True, 37 | help='disables CUDA use') 38 | cmds = parser.parse_args() 39 | runPath = cmds.save_dir 40 | sys.stdout = Logger('{}/analyse.log'.format(runPath)) 41 | args = torch.load(runPath + '/args.rar') 42 | 43 | # cuda stuff 44 | needs_conversion = cmds.no_cuda and args.cuda 45 | conversion_kwargs = {'map_location': lambda st, loc: st} if needs_conversion else {} 46 | args.cuda = not cmds.no_cuda and torch.cuda.is_available() 47 | device = torch.device("cuda" if args.cuda else "cpu") 48 | torch.manual_seed(args.seed) 49 | 50 | forward_args = {'drop_modality': True} if args.model == 'mcubISft' else {} 51 | 52 | # load trained model 53 | modelC = getattr(models, 'VAE_{}'.format(args.model)) 54 | model = modelC(args) 55 | if args.cuda: 56 | model.cuda() 57 | model.load_state_dict(torch.load(runPath + '/model.rar', **conversion_kwargs), strict=False) 58 | train_loader, test_loader = model.getDataLoaders(batch_size, device=device) 59 | N = len(test_loader.dataset) 60 | 61 | # generate word embeddings and sentence weighting 62 | emb_path = os.path.join(vocab_dir, 'cub.emb') 63 | weights_path = os.path.join(vocab_dir, 'cub.weights') 64 | vocab_path = os.path.join(vocab_dir, 'cub.vocab') 65 | pc_path = os.path.join(vocab_dir, 'cub.pc') 66 | 67 | emb = fetch_emb(lenWindow, minOccur, emb_path, vocab_path, RESET) 68 | weights = fetch_weights(weights_path, vocab_path, RESET, a=1e-3) 69 | emb = torch.from_numpy(emb).to(device) 70 | weights = torch.from_numpy(weights).to(device).type(emb.dtype) 71 | u = fetch_pc(emb, weights, train_loader, pc_path, RESET) 72 | 73 | # set up word to sentence functions 74 | fn_to_emb = lambda data, emb=emb, weights=weights, u=u: \ 75 | apply_pc(apply_weights(emb, weights, data), u) 76 | 77 | 78 | def calculate_corr(images, embeddings): 79 | global RESET 80 | if not os.path.exists(runPath + '/images_mean.pt') or RESET: 81 | generate_cca_projection() 82 | RESET = False 83 | im_mean = torch.load(runPath + '/images_mean.pt') 84 | emb_mean = torch.load(runPath + '/emb_mean.pt') 85 | im_proj = torch.load(runPath + 
'/im_proj.pt') 86 | emb_proj = torch.load(runPath + '/emb_proj.pt') 87 | with torch.no_grad(): 88 | corr = F.cosine_similarity((images - im_mean) @ im_proj, 89 | (embeddings - emb_mean) @ emb_proj).mean() 90 | return corr 91 | 92 | 93 | def generate_cca_projection(): 94 | images, sentences = [torch.cat(l) for l in zip(*[(d[0], d[1][0]) for d in train_loader])] 95 | emb = fn_to_emb(sentences.int()) 96 | corr, (im_proj, emb_proj) = cca([images, emb], k=40) 97 | print("Largest eigen value from CCA: {:.3f}".format(corr[0])) 98 | torch.save(images.mean(dim=0), runPath + '/images_mean.pt') 99 | torch.save(emb.mean(dim=0), runPath + '/emb_mean.pt') 100 | torch.save(im_proj, runPath + '/im_proj.pt') 101 | torch.save(emb_proj, runPath + '/emb_proj.pt') 102 | 103 | 104 | def cross_coherence(): 105 | model.eval() 106 | with torch.no_grad(): 107 | i2t = [] 108 | s2i = [] 109 | gt = [] 110 | for i, dataT in enumerate(test_loader): 111 | # get the inputs 112 | images, sentences = unpack_data(dataT, device=device) 113 | if images.shape[0] != batch_size: 114 | break 115 | _, px_zs, _ = model([images, sentences], K=1, **forward_args) 116 | cross_sentences = px_zs[0][1].mean.argmax(dim=-1).squeeze(0) 117 | cross_images = px_zs[1][0].mean.squeeze(0) 118 | # calculate correlation with CCA: 119 | i2t.append(calculate_corr(images, fn_to_emb(cross_sentences))) 120 | s2i.append(calculate_corr(cross_images, fn_to_emb(sentences.int()))) 121 | gt.append(calculate_corr(images, fn_to_emb(sentences.int()))) 122 | print("Coherence score: \nground truth {:10.9f}, \nimage to sentence {:10.9f}, " 123 | "\nsentence to image {:10.9f}".format(sum(gt) / len(gt), 124 | sum(i2t) / len(gt), 125 | sum(s2i) / len(gt))) 126 | 127 | 128 | def joint_coherence(): 129 | model.eval() 130 | with torch.no_grad(): 131 | pzs = model.pz(*model.pz_params).sample([1000]) 132 | gen_images = model.vaes[0].dec(pzs)[0].squeeze(1) 133 | gen_sentences = model.vaes[1].dec(pzs)[0].argmax(dim=-1).squeeze(1) 134 | score = calculate_corr(gen_images, fn_to_emb(gen_sentences)) 135 | print("joint generation {:10.9f}".format(score)) 136 | 137 | 138 | if __name__ == '__main__': 139 | with Timer('MM-VAE analysis') as t: 140 | print('-' * 89) 141 | cross_coherence() 142 | print('-' * 89) 143 | joint_coherence() 144 | -------------------------------------------------------------------------------- /src/report/analyse_ms.py: -------------------------------------------------------------------------------- 1 | """Calculate cross and joint coherence of trained model on MNIST-SVHN dataset. 
2 | Train and evaluate a linear model for latent space digit classification.""" 3 | 4 | import argparse 5 | import os 6 | import sys 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | 12 | # relative import hacks (sorry) 13 | import inspect 14 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 15 | parentdir = os.path.dirname(currentdir) 16 | sys.path.insert(0, parentdir) # for bash user 17 | os.chdir(parentdir) # for pycharm user 18 | 19 | import models 20 | from helper import Latent_Classifier, SVHN_Classifier, MNIST_Classifier 21 | from utils import Logger, Timer 22 | 23 | 24 | torch.backends.cudnn.benchmark = True 25 | parser = argparse.ArgumentParser(description='Analysing MM-DGM results') 26 | parser.add_argument('--save-dir', type=str, default="", 27 | metavar='N', help='save directory of results') 28 | parser.add_argument('--no-cuda', action='store_true', default=False, 29 | help='disables CUDA use') 30 | cmds = parser.parse_args() 31 | runPath = cmds.save_dir 32 | 33 | sys.stdout = Logger('{}/ms_acc.log'.format(runPath)) 34 | args = torch.load(runPath + '/args.rar') 35 | 36 | # cuda stuff 37 | needs_conversion = cmds.no_cuda and args.cuda 38 | conversion_kwargs = {'map_location': lambda st, loc: st} if needs_conversion else {} 39 | args.cuda = not cmds.no_cuda and torch.cuda.is_available() 40 | device = torch.device("cuda" if args.cuda else "cpu") 41 | torch.manual_seed(args.seed) 42 | 43 | modelC = getattr(models, 'VAE_{}'.format(args.model)) 44 | model = modelC(args) 45 | if args.cuda: 46 | model.cuda() 47 | 48 | model.load_state_dict(torch.load(runPath + '/model.rar', **conversion_kwargs), strict=False) 49 | B = 256 # rough batch size heuristic 50 | train_loader, test_loader = model.getDataLoaders(B, device=device) 51 | N = len(test_loader.dataset) 52 | 53 | 54 | def classify_latents(epochs, option): 55 | model.eval() 56 | vae = unpack_model(option) 57 | if '_' not in args.model: 58 | epochs *= 10 # account for the fact the mnist-svhn has more examples (roughly x10) 59 | classifier = Latent_Classifier(args.latent_dim, 10).to(device) 60 | criterion = nn.CrossEntropyLoss() 61 | optimizer = optim.Adam(classifier.parameters(), lr=0.001) 62 | 63 | for epoch in range(epochs): # loop over the dataset multiple times 64 | running_loss = 0.0 65 | total_iters = len(train_loader) 66 | print('\n====> Epoch: {:03d} '.format(epoch)) 67 | for i, data in enumerate(train_loader): 68 | # get the inputs 69 | x, targets = unpack_data_mlp(data, option) 70 | x, targets = x.to(device), targets.to(device) 71 | with torch.no_grad(): 72 | qz_x_params = vae.enc(x) 73 | zs = vae.qz_x(*qz_x_params).rsample() 74 | optimizer.zero_grad() 75 | outputs = classifier(zs) 76 | loss = criterion(outputs, targets) 77 | loss.backward() 78 | optimizer.step() 79 | # print statistics 80 | running_loss += loss.item() 81 | if (i + 1) % 1000 == 0: 82 | print('iteration {:04d}/{:d}: loss: {:6.3f}'.format(i + 1, total_iters, running_loss / 1000)) 83 | running_loss = 0.0 84 | print('Finished Training, calculating test loss...') 85 | 86 | classifier.eval() 87 | total = 0 88 | correct = 0 89 | with torch.no_grad(): 90 | for i, data in enumerate(test_loader): 91 | x, targets = unpack_data_mlp(data, option) 92 | x, targets = x.to(device), targets.to(device) 93 | qz_x_params = vae.enc(x) 94 | zs = vae.qz_x(*qz_x_params).rsample() 95 | outputs = classifier(zs) 96 | _, predicted = torch.max(outputs.data, 1) 97 | total += targets.size(0) 98 | correct += (predicted 
== targets).sum().item() 99 | print('The classifier correctly classified {} out of {} examples. Accuracy: ' 100 | '{:.2f}%'.format(correct, total, correct / total * 100)) 101 | 102 | 103 | def _maybe_train_or_load_digit_classifier_img(path, epochs): 104 | 105 | options = [o for o in ['mnist', 'svhn'] if not os.path.exists(path.format(o))] 106 | 107 | for option in options: 108 | print("Cannot find trained {} digit classifier in {}, training...". 109 | format(option, path.format(option))) 110 | classifier = globals()['{}_Classifier'.format(option.upper())]().to(device) 111 | criterion = nn.CrossEntropyLoss() 112 | optimizer = optim.Adam(classifier.parameters(), lr=0.001) 113 | for epoch in range(epochs): # loop over the dataset multiple times 114 | running_loss = 0.0 115 | total_iters = len(train_loader) 116 | print('\n====> Epoch: {:03d} '.format(epoch)) 117 | for i, data in enumerate(train_loader): 118 | # get the inputs 119 | x, targets = unpack_data_mlp(data, option) 120 | x, targets = x.to(device), targets.to(device) 121 | 122 | optimizer.zero_grad() 123 | outputs = classifier(x) 124 | loss = criterion(outputs, targets) 125 | loss.backward() 126 | optimizer.step() 127 | # print statistics 128 | running_loss += loss.item() 129 | if (i + 1) % 1000 == 0: 130 | print('iteration {:04d}/{:d}: loss: {:6.3f}'.format(i + 1, total_iters, running_loss / 1000)) 131 | running_loss = 0.0 132 | print('Finished Training, calculating test loss...') 133 | 134 | classifier.eval() 135 | total = 0 136 | correct = 0 137 | with torch.no_grad(): 138 | for i, data in enumerate(test_loader): 139 | x, targets = unpack_data_mlp(data, option) 140 | x, targets = x.to(device), targets.to(device) 141 | outputs = classifier(x) 142 | _, predicted = torch.max(outputs.data, 1) 143 | total += targets.size(0) 144 | correct += (predicted == targets).sum().item() 145 | print('The classifier correctly classified {} out of {} examples. 
Accuracy: ' 146 | '{:.2f}%'.format(correct, total, correct / total * 100)) 147 | 148 | torch.save(classifier.state_dict(), path.format(option)) 149 | 150 | mnist_net, svhn_net = MNIST_Classifier().to(device), SVHN_Classifier().to(device) 151 | mnist_net.load_state_dict(torch.load(path.format('mnist'))) 152 | svhn_net.load_state_dict(torch.load(path.format('svhn'))) 153 | return mnist_net, svhn_net 154 | 155 | def cross_coherence(epochs): 156 | model.eval() 157 | 158 | mnist_net, svhn_net = _maybe_train_or_load_digit_classifier_img("../data/{}_model.pt", epochs=epochs) 159 | mnist_net.eval() 160 | svhn_net.eval() 161 | 162 | total = 0 163 | corr_m = 0 164 | corr_s = 0 165 | with torch.no_grad(): 166 | for i, data in enumerate(test_loader): 167 | mnist, svhn, targets = unpack_data_mlp(data, option='both') 168 | mnist, svhn, targets = mnist.to(device), svhn.to(device), targets.to(device) 169 | _, px_zs, _ = model([mnist, svhn], 1) 170 | mnist_mnist = mnist_net(px_zs[1][0].mean.squeeze(0)) 171 | svhn_svhn = svhn_net(px_zs[0][1].mean.squeeze(0)) 172 | 173 | _, pred_m = torch.max(mnist_mnist.data, 1) 174 | _, pred_s = torch.max(svhn_svhn.data, 1) 175 | total += targets.size(0) 176 | corr_m += (pred_m == targets).sum().item() 177 | corr_s += (pred_s == targets).sum().item() 178 | 179 | print('Cross coherence: \n SVHN -> MNIST {:.2f}% \n MNIST -> SVHN {:.2f}%'.format( 180 | corr_m / total * 100, corr_s / total * 100)) 181 | 182 | 183 | def joint_coherence(): 184 | model.eval() 185 | mnist_net, svhn_net = MNIST_Classifier().to(device), SVHN_Classifier().to(device) 186 | mnist_net.load_state_dict(torch.load('../data/mnist_model.pt')) 187 | svhn_net.load_state_dict(torch.load('../data/svhn_model.pt')) 188 | 189 | mnist_net.eval() 190 | svhn_net.eval() 191 | 192 | total = 0 193 | corr = 0 194 | with torch.no_grad(): 195 | pzs = model.pz(*model.pz_params).sample([10000]) 196 | mnist = model.vaes[0].dec(pzs) 197 | svhn = model.vaes[1].dec(pzs) 198 | 199 | mnist_mnist = mnist_net(mnist[0].squeeze(1)) 200 | svhn_svhn = svhn_net(svhn[0].squeeze(1)) 201 | 202 | _, pred_m = torch.max(mnist_mnist.data, 1) 203 | _, pred_s = torch.max(svhn_svhn.data, 1) 204 | total += pred_m.size(0) 205 | corr += (pred_m == pred_s).sum().item() 206 | 207 | print('Joint coherence: {:.2f}%'.format(corr / total * 100)) 208 | 209 | 210 | def unpack_data_mlp(dataB, option='both'): 211 | if len(dataB[0]) == 2: 212 | if option == 'both': 213 | return dataB[0][0], dataB[1][0], dataB[1][1] 214 | elif option == 'svhn': 215 | return dataB[1][0], dataB[1][1] 216 | elif option == 'mnist': 217 | return dataB[0][0], dataB[0][1] 218 | else: 219 | return dataB 220 | 221 | 222 | def unpack_model(option='svhn'): 223 | if 'mnist_svhn' in args.model: 224 | return model.vaes[1] if option == 'svhn' else model.vaes[0] 225 | else: 226 | return model 227 | 228 | 229 | if __name__ == '__main__': 230 | with Timer('MM-VAE analysis') as t: 231 | print('-' * 25 + 'latent classification accuracy' + '-' * 25) 232 | print("Calculating latent classification accuracy for single MNIST VAE...") 233 | classify_latents(epochs=30, option='mnist') 234 | # # 235 | print("\n Calculating latent classification accuracy for single SVHN VAE...") 236 | classify_latents(epochs=30, option='svhn') 237 | # 238 | print('\n' + '-' * 45 + 'cross coherence' + '-' * 45) 239 | cross_coherence(epochs=30) 240 | # 241 | print('\n' + '-' * 45 + 'joint coherence' + '-' * 45) 242 | joint_coherence() 243 | -------------------------------------------------------------------------------- 
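All of the report scripts in src/report reload a trained run from --save-dir in the same way before doing any analysis. The condensed sketch below shows just that pattern, assuming it is run from inside src/ like the scripts themselves; the run directory name and batch size are illustrative, and the model is simply kept on CPU rather than reproducing the scripts' conversion_kwargs logic.

import torch
import models  # the src/models package, as imported by the report scripts

runPath = 'experiments/my-run'              # illustrative: wherever the trained run was saved
args = torch.load(runPath + '/args.rar')    # hyperparameters stored at training time
model = getattr(models, 'VAE_{}'.format(args.model))(args)
model.load_state_dict(torch.load(runPath + '/model.rar', map_location='cpu'), strict=False)
model.eval()
train_loader, test_loader = model.getDataLoaders(256, device='cpu')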
/src/report/calculate_likelihoods.py: -------------------------------------------------------------------------------- 1 | """Calculate data marginal likelihood p(x) evaluated on the trained generative model.""" 2 | import os 3 | import sys 4 | import argparse 5 | 6 | import numpy as np 7 | import torch 8 | from torchvision.utils import save_image 9 | 10 | # relative import hacks (sorry) 11 | import inspect 12 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 13 | parentdir = os.path.dirname(currentdir) 14 | sys.path.insert(0, parentdir) # for bash user 15 | os.chdir(parentdir) # for pycharm user 16 | 17 | import models 18 | from utils import Logger, Timer, unpack_data, log_mean_exp 19 | 20 | torch.backends.cudnn.benchmark = True 21 | parser = argparse.ArgumentParser(description='Analysing MM-DGM results') 22 | parser.add_argument('--save-dir', type=str, default="", 23 | metavar='N', help='save directory of results') 24 | parser.add_argument('--iwae-samples', type=int, default=1000, metavar='I', 25 | help='number of samples to estimate marginal log likelihood (default: 1000)') 26 | parser.add_argument('--no-cuda', action='store_true', default=False, 27 | help='disables CUDA use') 28 | cmds = parser.parse_args() 29 | runPath = cmds.save_dir 30 | 31 | sys.stdout = Logger('{}/llik.log'.format(runPath)) 32 | args = torch.load(runPath + '/args.rar') 33 | 34 | # cuda stuff 35 | needs_conversion = cmds.no_cuda and args.cuda 36 | conversion_kwargs = {'map_location': lambda st, loc: st} if needs_conversion else {} 37 | args.cuda = not cmds.no_cuda and torch.cuda.is_available() 38 | device = torch.device("cuda" if args.cuda else "cpu") 39 | torch.manual_seed(args.seed) 40 | 41 | modelC = getattr(models, 'VAE_{}'.format(args.model)) 42 | model = modelC(args) 43 | if args.cuda: 44 | model.cuda() 45 | 46 | model.load_state_dict(torch.load(runPath + '/model.rar', **conversion_kwargs), strict=False) 47 | B = 12000 // cmds.iwae_samples # rough batch size heuristic 48 | train_loader, test_loader = model.getDataLoaders(B, device=device) 49 | N = len(test_loader.dataset) 50 | 51 | 52 | def m_iwae(qz_xs, px_zs, zss, x): 53 | """IWAE estimate for log p_\theta(x) for multi-modal vae -- fully vectorised""" 54 | lws = [] 55 | for r, qz_x in enumerate(qz_xs): 56 | lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1) 57 | lqz_x = log_mean_exp(torch.stack([qz_x.log_prob(zss[r]).sum(-1) for qz_x in qz_xs])) 58 | lpx_z = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1) 59 | .mul(model.vaes[d].llik_scaling).sum(-1) 60 | for d, px_z in enumerate(px_zs[r])] 61 | lpx_z = torch.stack(lpx_z).sum(0) 62 | lw = lpz + lpx_z - lqz_x 63 | lws.append(lw) 64 | return log_mean_exp(torch.cat(lws)).sum() 65 | 66 | 67 | def iwae(qz_x, px_z, zs, x): 68 | """IWAE estimate for log p_\theta(x) -- fully vectorised.""" 69 | lpz = model.pz(*model.pz_params).log_prob(zs).sum(-1) 70 | lpx_z = px_z.log_prob(x).view(*px_z.batch_shape[:2], -1) * model.llik_scaling 71 | lqz_x = qz_x.log_prob(zs).sum(-1) 72 | return log_mean_exp(lpz + lpx_z.sum(-1) - lqz_x).sum() 73 | 74 | 75 | @torch.no_grad() 76 | def joint_elbo(K): 77 | model.eval() 78 | llik = 0 79 | obj = locals()[('m_' if hasattr(model, 'vaes') else '') + 'iwae']() 80 | for dataT in test_loader: 81 | data = unpack_data(dataT, device=device) 82 | llik += obj(model, data, K).item() 83 | print('Marginal Log Likelihood of joint {} (IWAE, K = {}): {:.4f}' 84 | .format(model.modelName, K, llik / N)) 85 | 86 | 87 | def cross_iwaes(qz_xs, px_zs, zss, x): 
88 | lws = [] 89 | for e, _px_zs in enumerate(px_zs): # rows are encoders 90 | lpz = model.pz(*model.pz_params).log_prob(zss[e]).sum(-1) 91 | lqz_x = qz_xs[e].log_prob(zss[e]).sum(-1) 92 | _lpx_zs = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1).sum(-1) 93 | for d, px_z in enumerate(_px_zs)] 94 | lws.append([log_mean_exp(_lpx_z + lpz - lqz_x).sum() for _lpx_z in _lpx_zs]) 95 | return lws 96 | 97 | 98 | def individual_iwaes(qz_xs, px_zs, zss, x): 99 | lws = [] 100 | for d, _px_zs in enumerate(np.array(px_zs).T): # rows are decoders now 101 | lw = [px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1).sum(-1) 102 | + model.pz(*model.pz_params).log_prob(zss[e]).sum(-1) 103 | - log_mean_exp(torch.stack([qz_x.log_prob(zss[e]).sum(-1) for qz_x in qz_xs])) 104 | for e, px_z in enumerate(_px_zs)] 105 | lw = torch.cat(lw) 106 | lws.append(log_mean_exp(lw).sum()) 107 | return lws 108 | 109 | 110 | @torch.no_grad() 111 | def m_llik_eval(K): 112 | model.eval() 113 | llik_joint = 0 114 | llik_synergy = np.array([0 for _ in model.vaes]) 115 | lliks_cross = np.array([[0 for _ in model.vaes] for _ in model.vaes]) 116 | for dataT in test_loader: 117 | data = unpack_data(dataT, device=device) 118 | qz_xs, px_zs, zss = model(data, K) 119 | objs = individual_iwaes(qz_xs, px_zs, zss, data) 120 | objs_cross = cross_iwaes(qz_xs, px_zs, zss, data) 121 | llik_joint += m_iwae(qz_xs, px_zs, zss, data) 122 | llik_synergy = llik_synergy + np.array(objs) 123 | lliks_cross = lliks_cross + np.array(objs_cross) 124 | 125 | print('Marginal Log Likelihood of joint {} (IWAE, K = {}): {:.4f}' 126 | .format(model.modelName, K, llik_joint / N)) 127 | print('-' * 89) 128 | 129 | for i, llik in enumerate(llik_synergy): 130 | print('Marginal Log Likelihood of {} from {} (IWAE, K = {}): {:.4f}' 131 | .format(model.vaes[i].modelName, model.modelName, K, (llik / N).item())) 132 | print('-' * 89) 133 | 134 | for e, _lliks_cross in enumerate(lliks_cross): 135 | for d, llik_cross in enumerate(_lliks_cross): 136 | print('Marginal Log Likelihood of {} from {} (IWAE, K = {}): {:.4f}' 137 | .format(model.vaes[d].modelName, model.vaes[e].modelName, K, (llik_cross / N).item())) 138 | print('-' * 89) 139 | 140 | 141 | @torch.no_grad() 142 | def llik_eval(K): 143 | model.eval() 144 | llik_joint = 0 145 | for dataT in test_loader: 146 | data = unpack_data(dataT, device=device) 147 | qz_xs, px_zs, zss = model(data, K) 148 | llik_joint += iwae(qz_xs, px_zs, zss, data) 149 | print('Marginal Log Likelihood of joint {} (IWAE, K = {}): {:.4f}' 150 | .format(model.modelName, K, llik_joint / N)) 151 | 152 | 153 | @torch.no_grad() 154 | def generate_sparse(D, steps, J): 155 | """generate `steps` perturbations for all `D` latent dimensions on `J` datapoints. 
""" 156 | model.eval() 157 | for i, dataT in enumerate(test_loader): 158 | data = unpack_data(dataT, require_length=(args.projection == 'Sft'), device=device) 159 | qz_xs, _, zss = model(data, args.K) 160 | for i, (qz_x, zs) in enumerate(zip(qz_xs, zss)): 161 | embs = [] 162 | # for delta in torch.linspace(0.01, 0.99, steps=steps): 163 | for delta in torch.linspace(-5, 5, steps=steps): 164 | for d in range(D): 165 | mod_emb = qz_x.mean + torch.zeros_like(qz_x.mean) 166 | mod_emb[:, d] += model.vaes[i].pz(*model.vaes[i].pz_params).stddev[:, d] * delta 167 | embs.append(mod_emb) 168 | embs = torch.stack(embs).transpose(0, 1).contiguous() 169 | for r in range(2): 170 | samples = model.vaes[r].px_z(*model.vaes[r].dec(embs.view(-1, D)[:((J) * steps * D)])).mean 171 | save_image(samples.cpu(), os.path.join(runPath, 'latent-traversals-{}x{}.png'.format(i, r)), nrow=D) 172 | break 173 | 174 | 175 | if __name__ == '__main__': 176 | with Timer('MM-VAE analysis') as t: 177 | # likelihood evaluation 178 | print('-' * 89) 179 | eval = locals()[('m_' if hasattr(model, 'vaes') else '') + 'llik_eval'] 180 | eval(cmds.iwae_samples) 181 | print('-' * 89) 182 | -------------------------------------------------------------------------------- /src/report/helper.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | from collections import Counter, OrderedDict 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from gensim.models import FastText 11 | from nltk.tokenize import sent_tokenize, word_tokenize 12 | from scipy.linalg import eig 13 | from skimage.filters import threshold_yen as threshold 14 | 15 | 16 | class OrderedCounter(Counter, OrderedDict): 17 | """Counter that remembers the order elements are first encountered.""" 18 | 19 | def __repr__(self): 20 | return '%s(%r)' % (self.__class__.__name__, OrderedDict(self)) 21 | 22 | def __reduce__(self): 23 | return self.__class__, (OrderedDict(self),) 24 | 25 | 26 | def cca(views, k=None, eps=1e-12): 27 | """Compute (multi-view) CCA 28 | 29 | Args: 30 | views (list): list of views where each view `v_i` is of size `N x o_i` 31 | k (int): joint projection dimension | if None, find using Otsu 32 | eps (float): regulariser [default: 1e-12] 33 | 34 | Returns: 35 | correlations: correlations along each of the k dimensions 36 | projections: projection matrices for each view 37 | """ 38 | V = len(views) # number of views 39 | N = views[0].size(0) # number of observations (same across views) 40 | os = [v.size(1) for v in views] 41 | kmax = np.min(os) 42 | ocum = np.cumsum([0] + os) 43 | os_sum = sum(os) 44 | A, B = np.zeros([os_sum, os_sum]), np.zeros([os_sum, os_sum]) 45 | 46 | for i in range(V): 47 | v_i = views[i] 48 | v_i_bar = v_i - v_i.mean(0).expand_as(v_i) # centered, N x o_i 49 | C_ij = (1.0 / (N - 1)) * torch.mm(v_i_bar.t(), v_i_bar) 50 | # A[ocum[i]:ocum[i + 1], ocum[i]:ocum[i + 1]] = C_ij 51 | B[ocum[i]:ocum[i + 1], ocum[i]:ocum[i + 1]] = C_ij 52 | for j in range(i + 1, V): 53 | v_j = views[j] # N x o_j 54 | v_j_bar = v_j - v_j.mean(0).expand_as(v_j) # centered 55 | C_ij = (1.0 / (N - 1)) * torch.mm(v_i_bar.t(), v_j_bar) 56 | A[ocum[i]:ocum[i + 1], ocum[j]:ocum[j + 1]] = C_ij 57 | A[ocum[j]:ocum[j + 1], ocum[i]:ocum[i + 1]] = C_ij.t() 58 | 59 | A[np.diag_indices_from(A)] += eps 60 | B[np.diag_indices_from(B)] += eps 61 | 62 | eigenvalues, eigenvectors = eig(A, B) 63 | # TODO: sanity check to see that all eigenvalues 
are e+0i 64 | idx = eigenvalues.argsort()[::-1] # sort descending 65 | eigenvalues = eigenvalues[idx] # arrange in descending order 66 | 67 | if k is None: 68 | t = threshold(eigenvalues.real[:kmax]) 69 | k = np.abs(np.asarray(eigenvalues.real[0::10]) - t).argmin() * 10 # closest k % 10 == 0 idx 70 | print('k unspecified, (auto-)choosing:', k) 71 | 72 | eigenvalues = eigenvalues[idx[:k]] 73 | eigenvectors = eigenvectors[:, idx[:k]] 74 | 75 | correlations = torch.from_numpy(eigenvalues.real).type_as(views[0]) 76 | proj_matrices = torch.split(torch.from_numpy(eigenvectors.real).type_as(views[0]), os) 77 | 78 | return correlations, proj_matrices 79 | 80 | 81 | def fetch_emb(lenWindow, minOccur, emb_path, vocab_path, RESET): 82 | if not os.path.exists(emb_path) or RESET: 83 | with open('../data/cub/text_trainvalclasses.txt', 'r') as file: 84 | text = file.read() 85 | sentences = sent_tokenize(text) 86 | 87 | texts = [] 88 | for i, line in enumerate(sentences): 89 | words = word_tokenize(line) 90 | texts.append(words) 91 | 92 | model = FastText(size=300, window=lenWindow, min_count=minOccur) 93 | model.build_vocab(sentences=texts) 94 | model.train(sentences=texts, total_examples=len(texts), epochs=10) 95 | 96 | with open(vocab_path, 'rb') as file: 97 | vocab = json.load(file) 98 | 99 | i2w = vocab['i2w'] 100 | base = np.ones((300,), dtype=np.float32) 101 | emb = [base * (i - 1) for i in range(3)] 102 | for word in list(i2w.values())[3:]: 103 | emb.append(model[word]) 104 | 105 | emb = np.array(emb) 106 | with open(emb_path, 'wb') as file: 107 | pickle.dump(emb, file) 108 | 109 | else: 110 | with open(emb_path, 'rb') as file: 111 | emb = pickle.load(file) 112 | 113 | return emb 114 | 115 | 116 | def fetch_weights(weights_path, vocab_path, RESET, a=1e-3): 117 | if not os.path.exists(weights_path) or RESET: 118 | with open('../data/cub/text_trainvalclasses.txt', 'r') as file: 119 | text = file.read() 120 | sentences = sent_tokenize(text) 121 | occ_register = OrderedCounter() 122 | 123 | for i, line in enumerate(sentences): 124 | words = word_tokenize(line) 125 | occ_register.update(words) 126 | 127 | with open(vocab_path, 'r') as file: 128 | vocab = json.load(file) 129 | w2i = vocab['w2i'] 130 | weights = np.zeros(len(w2i)) 131 | total_occ = sum(list(occ_register.values())) 132 | exc_occ = 0 133 | for w, occ in occ_register.items(): 134 | if w in w2i.keys(): 135 | weights[w2i[w]] = a / (a + occ / total_occ) 136 | else: 137 | exc_occ += occ 138 | weights[0] = a / (a + exc_occ / total_occ) 139 | 140 | with open(weights_path, 'wb') as file: 141 | pickle.dump(weights, file) 142 | else: 143 | with open(weights_path, 'rb') as file: 144 | weights = pickle.load(file) 145 | 146 | return weights 147 | 148 | 149 | def fetch_pc(emb, weights, train_loader, pc_path, RESET): 150 | sentences = torch.cat([d[1][0] for d in train_loader]).int() 151 | emb_dataset = apply_weights(emb, weights, sentences) 152 | 153 | if not os.path.exists(pc_path) or RESET: 154 | _, _, V = torch.svd(emb_dataset - emb_dataset.mean(dim=0), some=True) 155 | v = V[:, 0].unsqueeze(-1) 156 | u = v.mm(v.t()) 157 | with open(pc_path, 'wb') as file: 158 | pickle.dump(u, file) 159 | else: 160 | with open(pc_path, 'rb') as file: 161 | u = pickle.load(file) 162 | return u 163 | 164 | 165 | def apply_weights(emb, weights, data): 166 | fn_trun = lambda s: s[:np.where(s == 2)[0][0] + 1] if 2 in s else s 167 | batch_emb = [] 168 | for sent_i in data: 169 | emb_stacked = torch.stack([emb[idx] for idx in fn_trun(sent_i)]) 170 | weights_stacked = 
torch.stack([weights[idx] for idx in fn_trun(sent_i)]) 171 | batch_emb.append(torch.sum(emb_stacked * weights_stacked.unsqueeze(-1), dim=0) / emb_stacked.shape[0]) 172 | 173 | return torch.stack(batch_emb, dim=0) 174 | 175 | 176 | def apply_pc(weighted_emb, u): 177 | return torch.cat([e - torch.matmul(u, e.unsqueeze(-1)).squeeze() for e in weighted_emb.split(2048, 0)]) 178 | 179 | 180 | class Latent_Classifier(nn.Module): 181 | """ Generate latent parameters for SVHN image data. """ 182 | 183 | def __init__(self, in_n, out_n): 184 | super(Latent_Classifier, self).__init__() 185 | self.mlp = nn.Linear(in_n, out_n) 186 | 187 | def forward(self, x): 188 | return self.mlp(x) 189 | 190 | 191 | class SVHN_Classifier(nn.Module): 192 | def __init__(self): 193 | super(SVHN_Classifier, self).__init__() 194 | self.conv1 = nn.Conv2d(3, 10, kernel_size=5) 195 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 196 | self.conv2_drop = nn.Dropout2d() 197 | self.fc1 = nn.Linear(500, 50) 198 | self.fc2 = nn.Linear(50, 10) 199 | 200 | def forward(self, x): 201 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 202 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 203 | x = x.view(-1, 500) 204 | x = F.relu(self.fc1(x)) 205 | x = F.dropout(x, training=self.training) 206 | x = self.fc2(x) 207 | return F.log_softmax(x, dim=-1) 208 | 209 | 210 | class MNIST_Classifier(nn.Module): 211 | def __init__(self): 212 | super(MNIST_Classifier, self).__init__() 213 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 214 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 215 | self.conv2_drop = nn.Dropout2d() 216 | self.fc1 = nn.Linear(320, 50) 217 | self.fc2 = nn.Linear(50, 10) 218 | 219 | def forward(self, x): 220 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 221 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 222 | x = x.view(-1, 320) 223 | x = F.relu(self.fc1(x)) 224 | x = F.dropout(x, training=self.training) 225 | x = self.fc2(x) 226 | return F.log_softmax(x, dim=-1) 227 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import shutil 4 | import sys 5 | import time 6 | 7 | import torch 8 | import torch.distributions as dist 9 | import torch.nn.functional as F 10 | 11 | from datasets import CUBImageFt 12 | 13 | 14 | # Classes 15 | class Constants(object): 16 | eta = 1e-6 17 | log2 = math.log(2) 18 | log2pi = math.log(2 * math.pi) 19 | logceilc = 88 # largest cuda v s.t. exp(v) < inf 20 | logfloorc = -104 # smallest cuda v s.t. exp(v) > 0 21 | 22 | 23 | # https://stackoverflow.com/questions/14906764/how-to-redirect-stdout-to-both-file-and-console-with-scripting 24 | class Logger(object): 25 | def __init__(self, filename, mode="a"): 26 | self.terminal = sys.stdout 27 | self.log = open(filename, mode) 28 | 29 | def write(self, message): 30 | self.terminal.write(message) 31 | self.log.write(message) 32 | 33 | def flush(self): 34 | # this flush method is needed for python 3 compatibility. 35 | # this handles the flush command by doing nothing. 36 | # you might want to specify some extra behavior here. 
37 | pass 38 | 39 | 40 | class Timer: 41 | def __init__(self, name): 42 | self.name = name 43 | 44 | def __enter__(self): 45 | self.begin = time.time() 46 | return self 47 | 48 | def __exit__(self, *args): 49 | self.end = time.time() 50 | self.elapsed = self.end - self.begin 51 | self.elapsedH = time.gmtime(self.elapsed) 52 | print('====> [{}] Time: {:7.3f}s or {}' 53 | .format(self.name, 54 | self.elapsed, 55 | time.strftime("%H:%M:%S", self.elapsedH))) 56 | 57 | 58 | # Functions 59 | def save_vars(vs, filepath): 60 | """ 61 | Saves variables to the given filepath in a safe manner. 62 | """ 63 | if os.path.exists(filepath): 64 | shutil.copyfile(filepath, '{}.old'.format(filepath)) 65 | torch.save(vs, filepath) 66 | 67 | 68 | def save_model(model, filepath): 69 | """ 70 | To load a saved model, simply use 71 | `model.load_state_dict(torch.load('path-to-saved-model'))`. 72 | """ 73 | save_vars(model.state_dict(), filepath) 74 | if hasattr(model, 'vaes'): 75 | for vae in model.vaes: 76 | fdir, fext = os.path.splitext(filepath) 77 | save_vars(vae.state_dict(), fdir + '_' + vae.modelName + fext) 78 | 79 | 80 | def is_multidata(dataB): 81 | return isinstance(dataB, list) or isinstance(dataB, tuple) 82 | 83 | 84 | def unpack_data(dataB, device='cuda'): 85 | # dataB :: (Tensor, Idx) | [(Tensor, Idx)] 86 | """ Unpacks the data batch object in an appropriate manner to extract data """ 87 | if is_multidata(dataB): 88 | if torch.is_tensor(dataB[0]): 89 | if torch.is_tensor(dataB[1]): 90 | return dataB[0].to(device) # mnist, svhn, cubI 91 | elif is_multidata(dataB[1]): 92 | return dataB[0].to(device), dataB[1][0].to(device) # cubISft 93 | else: 94 | raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[1]))) 95 | 96 | elif is_multidata(dataB[0]): 97 | return [d.to(device) for d in list(zip(*dataB))[0]] # mnist-svhn, cubIS 98 | else: 99 | raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[0]))) 100 | elif torch.is_tensor(dataB): 101 | return dataB.to(device) 102 | else: 103 | raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB))) 104 | 105 | 106 | def get_mean(d, K=100): 107 | """ 108 | Extract the `mean` parameter for given distribution. 109 | If attribute not available, estimate from samples. 110 | """ 111 | try: 112 | mean = d.mean 113 | except NotImplementedError: 114 | samples = d.rsample(torch.Size([K])) 115 | mean = samples.mean(0) 116 | return mean 117 | 118 | 119 | def log_mean_exp(value, dim=0, keepdim=False): 120 | return torch.logsumexp(value, dim, keepdim=keepdim) - math.log(value.size(dim)) 121 | 122 | 123 | def kl_divergence(d1, d2, K=100): 124 | """Computes closed-form KL if available, else computes a MC estimate.""" 125 | if (type(d1), type(d2)) in torch.distributions.kl._KL_REGISTRY: 126 | return torch.distributions.kl_divergence(d1, d2) 127 | else: 128 | samples = d1.rsample(torch.Size([K])) 129 | return (d1.log_prob(samples) - d2.log_prob(samples)).mean(0) 130 | 131 | 132 | def pdist(sample_1, sample_2, eps=1e-5): 133 | """Compute the matrix of all squared pairwise distances. Code 134 | adapted from the torch-two-sample library (added batching). 135 | You can find the original implementation of this function here: 136 | https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/util.py 137 | 138 | Arguments 139 | --------- 140 | sample_1 : torch.Tensor or Variable 141 | The first sample, should be of shape ``(batch_size, n_1, d)``. 
142 | sample_2 : torch.Tensor or Variable 143 | The second sample, should be of shape ``(batch_size, n_2, d)``. 144 | norm : float 145 | The l_p norm to be used. 146 | batched : bool 147 | whether data is batched 148 | 149 | Returns 150 | ------- 151 | torch.Tensor or Variable 152 | Matrix of shape (batch_size, n_1, n_2). The [i, j]-th entry is equal to 153 | ``|| sample_1[i, :] - sample_2[j, :] ||_p``.""" 154 | if len(sample_1.shape) == 2: 155 | sample_1, sample_2 = sample_1.unsqueeze(0), sample_2.unsqueeze(0) 156 | B, n_1, n_2 = sample_1.size(0), sample_1.size(1), sample_2.size(1) 157 | norms_1 = torch.sum(sample_1 ** 2, dim=-1, keepdim=True) 158 | norms_2 = torch.sum(sample_2 ** 2, dim=-1, keepdim=True) 159 | norms = (norms_1.expand(B, n_1, n_2) 160 | + norms_2.transpose(1, 2).expand(B, n_1, n_2)) 161 | distances_squared = norms - 2 * sample_1.matmul(sample_2.transpose(1, 2)) 162 | return torch.sqrt(eps + torch.abs(distances_squared)).squeeze() # batch x K x latent 163 | 164 | 165 | def NN_lookup(emb_h, emb, data): 166 | indices = pdist(emb.to(emb_h.device), emb_h).argmin(dim=0) 167 | # indices = torch.tensor(cosine_similarity(emb, emb_h.cpu().numpy()).argmax(0)).to(emb_h.device).squeeze() 168 | return data[indices] 169 | 170 | 171 | class FakeCategorical(dist.Distribution): 172 | support = dist.constraints.real 173 | has_rsample = True 174 | 175 | def __init__(self, locs): 176 | self.logits = locs 177 | self._batch_shape = self.logits.shape 178 | 179 | @property 180 | def mean(self): 181 | return self.logits 182 | 183 | def sample(self, sample_shape=torch.Size()): 184 | with torch.no_grad(): 185 | return self.rsample(sample_shape) 186 | 187 | def rsample(self, sample_shape=torch.Size()): 188 | return self.logits.expand([*sample_shape, *self.logits.shape]).contiguous() 189 | 190 | def log_prob(self, value): 191 | # value of shape (K, B, D) 192 | lpx_z = -F.cross_entropy(input=self.logits.view(-1, self.logits.size(-1)), 193 | target=value.expand(self.logits.size()[:-1]).long().view(-1), 194 | reduction='none', 195 | ignore_index=0) 196 | 197 | return lpx_z.view(*self.logits.shape[:-1]) 198 | # it is inevitable to have the word embedding dimension summed up in 199 | # cross-entropy loss ($\sum -gt_i \log(p_i)$ with most gt_i = 0, We adopt the 200 | # operationally equivalence here, which is summing up the sentence dimension 201 | # in objective. 202 | -------------------------------------------------------------------------------- /src/vis.py: -------------------------------------------------------------------------------- 1 | # visualisation related functions 2 | 3 | import matplotlib.colors as colors 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pandas as pd 7 | import seaborn as sns 8 | import torch 9 | from matplotlib.lines import Line2D 10 | from umap import UMAP 11 | 12 | 13 | def custom_cmap(n): 14 | """Create customised colormap for scattered latent plot of n categories. 15 | Returns colormap object and colormap array that contains the RGB value of the colors. 
16 | See official matplotlib document for colormap reference: 17 | https://matplotlib.org/examples/color/colormaps_reference.html 18 | """ 19 | # first color is grey from Set1, rest other sensible categorical colourmap 20 | cmap_array = sns.color_palette("Set1", 9)[-1:] + sns.husl_palette(n - 1, h=.6, s=0.7) 21 | cmap = colors.LinearSegmentedColormap.from_list('mmdgm_cmap', cmap_array) 22 | return cmap, cmap_array 23 | 24 | 25 | def embed_umap(data): 26 | """data should be on cpu, numpy""" 27 | embedding = UMAP(metric='euclidean', 28 | n_neighbors=40, 29 | # angular_rp_forest=True, 30 | # random_state=torch.initial_seed(), 31 | transform_seed=torch.initial_seed()) 32 | return embedding.fit_transform(data) 33 | 34 | 35 | def plot_embeddings(emb, emb_l, labels, filepath): 36 | cmap_obj, cmap_arr = custom_cmap(n=len(labels)) 37 | plt.figure() 38 | plt.scatter(emb[:, 0], emb[:, 1], c=emb_l, cmap=cmap_obj, s=25, alpha=0.2, edgecolors='none') 39 | l_elems = [Line2D([0], [0], marker='o', color=cm, label=l, alpha=0.5, linestyle='None') 40 | for (cm, l) in zip(cmap_arr, labels)] 41 | plt.legend(frameon=False, loc=2, handles=l_elems) 42 | plt.savefig(filepath, bbox_inches='tight') 43 | plt.close() 44 | 45 | 46 | def tensor_to_df(tensor, ax_names=None): 47 | assert tensor.ndim == 2, "Can only currently convert 2D tensors to dataframes" 48 | df = pd.DataFrame(data=tensor, columns=np.arange(tensor.shape[1])) 49 | return df.melt(value_vars=df.columns, 50 | var_name=('variable' if ax_names is None else ax_names[0]), 51 | value_name=('value' if ax_names is None else ax_names[1])) 52 | 53 | 54 | def tensors_to_df(tensors, head=None, keys=None, ax_names=None): 55 | dfs = [tensor_to_df(tensor, ax_names=ax_names) for tensor in tensors] 56 | df = pd.concat(dfs, keys=(np.arange(len(tensors)) if keys is None else keys)) 57 | df.reset_index(level=0, inplace=True) 58 | if head is not None: 59 | df.rename(columns={'level_0': head}, inplace=True) 60 | return df 61 | 62 | 63 | def plot_kls_df(df, filepath): 64 | _, cmap_arr = custom_cmap(df[df.columns[0]].nunique() + 1) 65 | with sns.plotting_context("notebook", font_scale=2.0): 66 | g = sns.FacetGrid(df, height=12, aspect=2) 67 | g = g.map(sns.boxplot, df.columns[1], df.columns[2], df.columns[0], palette=cmap_arr[1:], 68 | order=None, hue_order=None) 69 | g = g.set(yscale='log').despine(offset=10) 70 | plt.legend(loc='best', fontsize='22') 71 | plt.savefig(filepath, bbox_inches='tight') 72 | plt.close() 73 | --------------------------------------------------------------------------------
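As a closing usage sketch for the two plotting helpers above (the inputs here are hypothetical: the arrays are random stand-ins for per-dimension KL values, and the import assumes the snippet is run from inside src/): tensors_to_df melts one 2D tensor per model into a single long-form DataFrame, which plot_kls_df then renders as grouped box plots -- the same DataFrame format the analyse() methods above hand to plot_kls_df.

import torch
from vis import tensors_to_df, plot_kls_df

# stand-ins for per-dimension KL(q(z|x) || p(z)) values from two models:
# one row per datapoint, one column per latent dimension
kl_a = torch.rand(128, 20).numpy()
kl_b = torch.rand(128, 20).numpy()

kls_df = tensors_to_df([kl_a, kl_b],
                       head='model',                               # grouping column
                       keys=['vae_a', 'vae_b'],                    # one key per tensor
                       ax_names=['Dimensions', r'KL$(q\,||\,p)$'])
plot_kls_df(kls_df, 'kl_distance_example.png')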