├── .gitignore
├── LICENSE
├── README.md
├── config.yaml
├── datautils.py
├── get-embeddings-unsup.py
├── get-embeddings.py
├── logger.py
├── models
│   ├── __init__.py
│   ├── coupling.py
│   ├── distributions.py
│   ├── flows.py
│   ├── invertconv.py
│   ├── normalization.py
│   ├── realnvp.py
│   └── utils.py
├── myexman
│   ├── __init__.py
│   ├── index.py
│   └── parser.py
├── pretrained
│   └── model.torch
├── train-discriminator.py
├── train-flow-ssl.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 | 
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 | 
8 | Preamble
9 | 
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 
74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. 
However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 
196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 
309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 
360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 
476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 | 
589 | 15. Disclaimer of Warranty.
590 | 
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 | 
600 | 16. Limitation of Liability.
601 | 
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 | 
612 | 17. Interpretation of Sections 15 and 16.
613 | 
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 | 
621 | END OF TERMS AND CONDITIONS
622 | 
623 | How to Apply These Terms to Your New Programs
624 | 
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 | 
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 | 
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year> <name of author>
636 | 
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 | 
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 | 
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 | 
650 | Also add information on how to contact you by electronic and paper mail.
651 | 
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 | 
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 | 
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 | 
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 | 
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semi-Supervised Flows PyTorch
2 | Authors: [Andrei Atanov](https://andrewatanov.github.io/), [Alexandra Volokhova](https://scholar.google.com/citations?user=23LOcyMAAAAJ&hl=en), [Arsenii Ashukha](https://senya-ashukha.github.io/), [Ivan Sosnovik](https://scholar.google.at/citations?user=brUsNccAAAAJ&hl=en), [Dmitry Vetrov](https://scholar.google.ca/citations?user=7HU0UoUAAAAJ&hl=en)
3 | 
4 | This repo contains code for our INNF workshop paper [Semi-Conditional Normalizing Flows for Semi-Supervised Learning](https://arxiv.org/abs/1905.00505).
5 | 
6 | __Abstract:__
7 | This paper proposes a semi-conditional normalizing flow model for semi-supervised learning. The model uses both labelled and unlabelled data to learn an explicit model of the joint distribution over objects and labels. The semi-conditional architecture of the model allows us to efficiently compute the value and gradients of the marginal likelihood for unlabelled objects. The conditional part of the model is based on a proposed conditional coupling layer. We demonstrate the performance of the model on the semi-supervised classification problem on different datasets. The model outperforms the baseline approach based on variational auto-encoders on the MNIST dataset.
8 | 
9 | [__Poster__](https://docs.google.com/presentation/d/1wSA6RKG4ko2zI9XuVAsJq0dqd-XTPLOiWlJ17XJ1SoQ/edit?usp=sharing)
10 | 
11 | # Semi-Supervised MNIST classification
12 | 
13 | Train a Semi-Conditional Normalizing Flow on MNIST with 100 labeled examples:
14 | 
15 | `python train-flow-ssl.py --config config.yaml`
16 | 
17 | You can then find logs at `/logs/exman-train-flow-ssl.py/runs/`
18 | 
19 | For convenience, we also provide pretrained weights at `pretrained/model.torch`; use the `--pretrained` flag to load them (see the example below).
20 | 
21 | # Credits
22 | 
23 | * Credits to https://github.com/ferrine/exman for the exman parser.
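
# Loading the pretrained model

A minimal example command (assuming the `--pretrained` flag of `train-flow-ssl.py` takes the path to the checkpoint file; see the script for the exact semantics of the flag):

`python train-flow-ssl.py --config config.yaml --pretrained pretrained/model.torch`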
24 | 
25 | # Citation
26 | 
27 | If you find this code useful, please cite our paper:
28 | 
29 | ```
30 | @article{atanov2019semi,
31 |   title={Semi-conditional normalizing flows for semi-supervised learning},
32 |   author={Atanov, Andrei and Volokhova, Alexandra and Ashukha, Arsenii and Sosnovik, Ivan and Vetrov, Dmitry},
33 |   journal={arXiv preprint arXiv:1905.00505},
34 |   year={2019}
35 | }
36 | ```
37 | 
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | cl_weight: 0.0
2 | clip_gn: 100.0
3 | config_file: ''
4 | conv: full
5 | data: mnist
6 | data_seed: 0
7 | epochs: 500
8 | hh_factors: 2
9 | hid_dim: []
10 | k: 4
11 | l: 2
12 | log_each: 1
13 | logits: true
14 | lr: 0.001
15 | lr_gamma: 0.5
16 | lr_schedule: linear
17 | lr_steps: []
18 | lr_warmup: 10
19 | model: mnist-masked
20 | name: em vs mle
21 | num_examples: -1
22 | pretrained: ''
23 | root: ''
24 | seed: 0
25 | ssl_conv: full
26 | ssl_dim: 196
27 | ssl_hd: 256
28 | ssl_hh: 2
29 | ssl_k: 4
30 | ssl_l: 2
31 | ssl_model: cond-flow
32 | ssl_nclasses: 10
33 | status: done
34 | sup_ohe: true
35 | sup_sample_weight: 0.5
36 | sup_weight: 1.0
37 | supervised: 100
38 | test_bs: 512
39 | tmp: false
40 | train_bs: 256
41 | weight_decay: 1.0e-05
42 | 
43 | time: '2019-04-20T13:33:10'
44 | id: 118
--------------------------------------------------------------------------------
/datautils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn import datasets
4 | import os
5 | import torchvision
6 | from torchvision import transforms
7 | from models.distributions import GMM
8 | import torch.nn.functional as F
9 | from sklearn.model_selection import train_test_split
10 | from torch.utils.data import Dataset
11 | import torch.utils.data
12 | import PIL
13 | from torchvision.datasets import ImageFolder
14 | import warnings
15 | import utils
16 | 
17 | DATA_ROOT = './'
18 | 
19 | 
20 | mean = {
21 |     'mnist': (0.1307,),
22 |     'cifar10': (0.4914, 0.4822, 0.4465)
23 | }
24 | 
25 | std = {
26 |     'mnist': (0.3081,),
27 |     'cifar10': (0.2470, 0.2435, 0.2616)
28 | }
29 | 
30 | 
31 | class UniformNoise(object):
32 |     def __init__(self, bits=256):
33 |         self.bits = bits
34 | 
35 |     def __call__(self, x):
36 |         with torch.no_grad():
37 |             noise = torch.rand_like(x)
38 |             # TODO: generalize.
x assumed to be normalized to [0, 1] 39 | return (x * (self.bits - 1) + noise) / self.bits 40 | 41 | def __repr__(self): 42 | return "UniformNoise" 43 | 44 | 45 | def load_dataset(data, train_bs, test_bs, num_examples=None, data_root=DATA_ROOT, shuffle=True, 46 | seed=42, supervised=-1, logs_root='', sup_sample_weight=-1, sup_only=False, device=None): 47 | bits = None 48 | sampler = None 49 | if data in ['moons', 'circles']: 50 | if data == 'moons': 51 | x, y = datasets.make_moons(n_samples=int(num_examples * 1.5), noise=0.1, random_state=seed) 52 | train_x, test_x, train_y, test_y = train_test_split(x, y, train_size=num_examples, random_state=seed) 53 | elif data == 'circles': 54 | x, y = datasets.make_circles(n_samples=int(num_examples * 1.5), noise=0.1, factor=0.2, random_state=seed) 55 | train_x, test_x, train_y, test_y = train_test_split(x, y, train_size=num_examples, random_state=seed) 56 | 57 | if supervised not in [-1, len(train_y), 0]: 58 | unsupervised_idxs, _ = train_test_split(np.arange(len(train_y)), test_size=supervised, stratify=train_y) 59 | train_y[unsupervised_idxs] = -1 60 | elif supervised == 0: 61 | train_y[:] = -1 62 | 63 | torch.save({ 64 | 'train_x': train_x, 65 | 'train_y': train_y, 66 | 'test_x': test_x, 67 | 'test_y': test_y, 68 | }, os.path.join(logs_root, 'data.torch')) 69 | 70 | trainset = torch.utils.data.TensorDataset(torch.FloatTensor(train_x[..., None, None]), 71 | torch.LongTensor(train_y)) 72 | testset = torch.utils.data.TensorDataset(torch.FloatTensor(test_x[..., None, None]), 73 | torch.LongTensor(test_y)) 74 | data_shape = [2, 1, 1] 75 | bits = np.nan 76 | elif data == 'mnist': 77 | train_transform = transforms.Compose([ 78 | transforms.ToTensor(), 79 | UniformNoise(), 80 | ]) 81 | 82 | test_transform = transforms.Compose([ 83 | transforms.ToTensor(), 84 | UniformNoise(), 85 | ]) 86 | trainset = torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform=train_transform) 87 | testset = torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform=test_transform) 88 | 89 | if num_examples != -1 and num_examples != len(trainset) and num_examples is not None: 90 | idxs, _ = train_test_split(np.arange(len(trainset)), train_size=num_examples, random_state=seed, 91 | stratify=utils.tonp(trainset.targets)) 92 | trainset.data = trainset.data[idxs] 93 | trainset.targets = trainset.targets[idxs] 94 | 95 | if supervised == 0: 96 | trainset.targets[:] = -1 97 | elif supervised != -1: 98 | unsupervised_idxs, _ = train_test_split(np.arange(len(trainset.targets)), 99 | test_size=supervised, stratify=trainset.targets) 100 | trainset.targets[unsupervised_idxs] = -1 101 | 102 | if sup_only: 103 | mask = trainset.targets != -1 104 | trainset.targets = trainset.targets[mask] 105 | trainset.data = trainset.data[mask] 106 | 107 | data_shape = (1, 28, 28) 108 | bits = 256 109 | else: 110 | raise NotImplementedError 111 | 112 | nw = 2 113 | if sup_sample_weight == -1: 114 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=shuffle, 115 | num_workers=nw, pin_memory=True) 116 | else: 117 | sampler = ImbalancedDatasetSampler(trainset, sup_weight=sup_sample_weight) 118 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, sampler=sampler, 119 | num_workers=nw, pin_memory=True) 120 | 121 | testloader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False, 122 | num_workers=nw, pin_memory=True) 123 | return trainloader, testloader, data_shape, bits 124 | 125 | 126 | class 
ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): 127 | """Samples elements randomly from a given list of indices for imbalanced dataset 128 | https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/sampler.py 129 | Arguments: 130 | indices (list, optional): a list of indices 131 | num_samples (int, optional): number of samples to draw 132 | """ 133 | 134 | def __init__(self, dataset, indices=None, num_samples=None, sup_weight=1.): 135 | 136 | # if indices is not provided, 137 | # all elements in the dataset will be considered 138 | self.indices = list(range(len(dataset))) \ 139 | if indices is None else indices 140 | 141 | # if num_samples is not provided, 142 | # draw `len(indices)` samples in each iteration 143 | self.num_samples = len(self.indices) \ 144 | if num_samples is None else num_samples 145 | 146 | # distribution of classes in the dataset 147 | label_to_count = { 148 | -1: 0 149 | } 150 | sup = 0 151 | for idx in self.indices: 152 | label = self._get_label(dataset, idx) 153 | if label == -1: 154 | label_to_count[-1] += 1 155 | else: 156 | sup += 1 157 | label_to_count[label] = sup 158 | for k in label_to_count: 159 | if k != -1: 160 | label_to_count[k] = sup 161 | 162 | # weight for each sample 163 | weights = [] 164 | for idx in self.indices: 165 | label = self._get_label(dataset, idx) 166 | w = 1 if label == -1 else sup_weight 167 | weights.append(w / label_to_count[label]) 168 | 169 | self.weights = torch.DoubleTensor(weights) 170 | 171 | def _get_label(self, dataset, idx): 172 | return dataset.targets[idx].item() 173 | 174 | def __iter__(self): 175 | return (self.indices[i] for i in torch.multinomial( 176 | self.weights, self.num_samples, replacement=True)) 177 | 178 | def __len__(self): 179 | return self.num_samples 180 | 181 | 182 | class FastMNIST(torchvision.datasets.MNIST): 183 | def __init__(self, device, *args, **kwargs): 184 | super().__init__(*args, **kwargs) 185 | 186 | # Scale data to [0,1] 187 | self.data = self.data.unsqueeze(1).float().div(255) 188 | 189 | # Put both data and targets on GPU in advance 190 | self.data, self.targets = self.data.to(device), self.targets.to(device) 191 | 192 | def __getitem__(self, index): 193 | """ 194 | Args: 195 | index (int): Index 196 | 197 | Returns: 198 | tuple: (image, target) where target is index of the target class. 
199 | """ 200 | img, target = self.data[index], self.targets[index] 201 | 202 | return img, target 203 | -------------------------------------------------------------------------------- /get-embeddings-unsup.py: -------------------------------------------------------------------------------- 1 | import myexman 2 | import torch 3 | import utils 4 | import datautils 5 | import os 6 | from logger import Logger 7 | import time 8 | import numpy as np 9 | from models import flows 10 | import matplotlib.pyplot as plt 11 | from models import distributions 12 | import sys 13 | from tqdm import tqdm 14 | 15 | 16 | def get_logp(model, loader): 17 | logp = [] 18 | for x, _ in loader: 19 | x = x.to(device) 20 | logp.append(utils.tonp(model.log_prob(x))) 21 | return np.concatenate(logp) 22 | 23 | 24 | parser = myexman.ExParser(file=os.path.basename(__file__)) 25 | parser.add_argument('--name', default='') 26 | parser.add_argument('--save_dir', default='') 27 | # Data 28 | parser.add_argument('--data', default='mnist') 29 | parser.add_argument('--data_seed', default=0, type=int) 30 | parser.add_argument('--aug', dest='aug', action='store_true') 31 | parser.add_argument('--no_aug', dest='aug', action='store_false') 32 | parser.set_defaults(aug=False) 33 | # Optimization 34 | parser.add_argument('--epochs', default=100, type=int) 35 | parser.add_argument('--train_bs', default=256, type=int) 36 | parser.add_argument('--test_bs', default=512, type=int) 37 | parser.add_argument('--lr', default=1e-3, type=float) 38 | parser.add_argument('--lr_schedule', default='linear') 39 | parser.add_argument('--lr_warmup', default=10, type=int) 40 | parser.add_argument('--lr_gamma', default=0.5, type=float) 41 | parser.add_argument('--lr_steps', type=int, nargs='*', default=[]) 42 | parser.add_argument('--log_each', default=1, type=int) 43 | parser.add_argument('--seed', default=0, type=int) 44 | parser.add_argument('--pretrained', default='') 45 | parser.add_argument('--weight_decay', default=0., type=float) 46 | parser.add_argument('--clip_gv', default=1e9, type=float) 47 | parser.add_argument('--clip_gn', default=100., type=float) 48 | # Model 49 | parser.add_argument('--model', default='flow') 50 | parser.add_argument('--logits', dest='logits', action='store_true') 51 | parser.add_argument('--no-logits', dest='logits', action='store_false') 52 | parser.set_defaults(logits=True) 53 | parser.add_argument('--conv', default='full') 54 | parser.add_argument('--hh_factors', default=2, type=int) 55 | parser.add_argument('--k', default=4, type=int) 56 | parser.add_argument('--l', default=2, type=int) 57 | parser.add_argument('--hid_dim', type=int, nargs='*', default=[]) 58 | args = parser.parse_args() 59 | 60 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 61 | 62 | fmt = { 63 | 'time': '.3f', 64 | 'lr': '.1e', 65 | } 66 | logger = Logger('logs', base=args.root, fmt=fmt) 67 | 68 | # Load data 69 | np.random.seed(args.data_seed) 70 | torch.manual_seed(args.data_seed) 71 | torch.cuda.manual_seed_all(args.data_seed) 72 | trainloader, testloader, data_shape, bits = datautils.load_dataset(args.data, args.train_bs, args.test_bs, 73 | seed=args.data_seed, shuffle=False) 74 | 75 | # Seed for training process 76 | np.random.seed(args.seed) 77 | torch.manual_seed(args.seed) 78 | torch.cuda.manual_seed_all(args.seed) 79 | 80 | # Create model 81 | dim = int(np.prod(data_shape)) 82 | prior = distributions.GaussianDiag(dim).to(device) 83 | 84 | flow = utils.create_flow(args, data_shape) 85 | flow = 
torch.nn.DataParallel(flow.to(device))
86 | model = flows.FlowPDF(flow, prior).to(device)
87 | 
88 | if args.pretrained is not None and args.pretrained != '':
89 |     model.load_state_dict(torch.load(args.pretrained, map_location=device))  # map_location keeps the checkpoint loadable on CPU-only machines
90 | 
91 | 
92 | def get_embeddings(loader, model):
93 |     zf, labels = [], []
94 |     for x, y in tqdm(loader):
95 |         z_ = model.flow(x.to(device))[1]  # the flow returns (log_det, z); keep the embedding z
96 |         zf.append(utils.tonp(z_))
97 |         labels.append(utils.tonp(y))
98 |     return np.concatenate(zf), np.concatenate(labels)
99 | 
100 | 
101 | with torch.no_grad():
102 |     zf_train, y_train = get_embeddings(trainloader, model)
103 |     zf_test, y_test = get_embeddings(testloader, model)
104 | 
105 | 
106 | np.save(os.path.join(args.save_dir, 'zf_train'), zf_train)
107 | np.save(os.path.join(args.save_dir, 'y_train'), y_train)
108 | 
109 | np.save(os.path.join(args.save_dir, 'zf_test'), zf_test)
110 | np.save(os.path.join(args.save_dir, 'y_test'), y_test)
111 | 
--------------------------------------------------------------------------------
/get-embeddings.py:
--------------------------------------------------------------------------------
1 | import myexman
2 | import torch
3 | import utils
4 | import datautils
5 | import os
6 | from logger import Logger
7 | import time
8 | import numpy as np
9 | from models import flows, distributions
10 | import matplotlib.pyplot as plt
11 | from algo.em import init_kmeans2plus_mu
12 | import warnings
13 | from sklearn.mixture import GaussianMixture
14 | import torch.nn.functional as F
15 | from tqdm import tqdm
16 | import sys
17 | 
18 | 
19 | def get_metrics(model, loader):
20 |     logp, acc = [], []
21 |     for x, y in loader:
22 |         x = x.to(device)
23 |         log_det, z = model.flow(x)
24 |         log_prior_full = model.prior.log_prob_full(z)
25 |         pred = torch.softmax(log_prior_full, dim=1).argmax(1)
26 |         logp.append(utils.tonp(log_det + model.prior.log_prob(z)))
27 |         acc.append(utils.tonp(pred) == utils.tonp(y))
28 |     return np.mean(np.concatenate(logp)), np.mean(np.concatenate(acc))
29 | 
30 | 
31 | parser = myexman.ExParser(file=os.path.basename(__file__))
32 | parser.add_argument('--name', default='')
33 | parser.add_argument('--verbose', default=0, type=int)
34 | parser.add_argument('--save_dir', default='')
35 | parser.add_argument('--test_mode', default='')
36 | # Data
37 | parser.add_argument('--data', default='mnist')
38 | parser.add_argument('--num_examples', default=-1, type=int)
39 | parser.add_argument('--data_seed', default=0, type=int)
40 | parser.add_argument('--sup_sample_weight', default=-1, type=float)
41 | # parser.add_argument('--aug', dest='aug', action='store_true')
42 | # parser.add_argument('--no_aug', dest='aug', action='store_false')
43 | # parser.set_defaults(aug=True)
44 | # Optimization
45 | parser.add_argument('--opt', default='adam')
46 | parser.add_argument('--ssl_alg', default='em')
47 | parser.add_argument('--lr', default=1e-3, type=float)
48 | parser.add_argument('--epochs', default=100, type=int)
49 | parser.add_argument('--train_bs', default=256, type=int)
50 | parser.add_argument('--test_bs', default=512, type=int)
51 | parser.add_argument('--lr_schedule', default='linear')
52 | parser.add_argument('--lr_warmup', default=10, type=int)
53 | parser.add_argument('--lr_gamma', default=0.5, type=float)
54 | parser.add_argument('--lr_steps', type=int, nargs='*', default=[])
55 | parser.add_argument('--log_each', default=1, type=int)
56 | parser.add_argument('--seed', default=0, type=int)
57 | parser.add_argument('--pretrained', default='')
58 | parser.add_argument('--weight_decay', default=0., type=float)
59 | parser.add_argument('--sup_ohe', dest='sup_ohe', action='store_true') 60 | parser.add_argument('--no_sup_ohe', dest='sup_ohe', action='store_false') 61 | parser.set_defaults(sup_ohe=True) 62 | parser.add_argument('--clip_gn', default=100., type=float) 63 | # Model 64 | parser.add_argument('--model', default='flow') 65 | parser.add_argument('--logits', dest='logits', action='store_true') 66 | parser.add_argument('--no_logits', dest='logits', action='store_false') 67 | parser.set_defaults(logits=True) 68 | parser.add_argument('--conv', default='full') 69 | parser.add_argument('--hh_factors', default=2, type=int) 70 | parser.add_argument('--k', default=4, type=int) 71 | parser.add_argument('--l', default=2, type=int) 72 | parser.add_argument('--hid_dim', type=int, nargs='*', default=[]) 73 | # Prior 74 | parser.add_argument('--ssl_model', default='cond-flow') 75 | parser.add_argument('--ssl_dim', default=-1, type=int) 76 | parser.add_argument('--ssl_l', default=2, type=int) 77 | parser.add_argument('--ssl_k', default=4, type=int) 78 | parser.add_argument('--ssl_hd', default=256, type=int) 79 | parser.add_argument('--ssl_conv', default='full') 80 | parser.add_argument('--ssl_hh', default=2, type=int) 81 | parser.add_argument('--ssl_nclasses', default=10, type=int) 82 | # SSL 83 | parser.add_argument('--supervised', default=0, type=int) 84 | parser.add_argument('--sup_weight', default=1., type=float) 85 | parser.add_argument('--cl_weight', default=0, type=float) 86 | args = parser.parse_args() 87 | 88 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 89 | 90 | # Load data 91 | np.random.seed(args.data_seed) 92 | torch.manual_seed(args.data_seed) 93 | torch.cuda.manual_seed_all(args.data_seed) 94 | trainloader, testloader, data_shape, bits = datautils.load_dataset(args.data, args.train_bs, args.test_bs, 95 | seed=args.data_seed, shuffle=False) 96 | 97 | # Seed for training process 98 | np.random.seed(args.seed) 99 | torch.manual_seed(args.seed) 100 | torch.cuda.manual_seed_all(args.seed) 101 | 102 | # Create model 103 | dim = int(np.prod(data_shape)) 104 | if args.ssl_dim == -1: 105 | args.ssl_dim = dim 106 | deep_prior = distributions.GaussianDiag(args.ssl_dim) 107 | shallow_prior = distributions.GaussianDiag(dim - args.ssl_dim) 108 | yprior = torch.distributions.Categorical(logits=torch.zeros((args.ssl_nclasses,)).to(device)) 109 | ssl_flow = flows.get_flow_cond(args.ssl_l, args.ssl_k, in_channels=args.ssl_dim, hid_dim=args.ssl_hd, 110 | conv=args.ssl_conv, hh_factors=args.ssl_hh, num_cat=args.ssl_nclasses) 111 | ssl_flow = torch.nn.DataParallel(ssl_flow.to(device)) 112 | prior = flows.DiscreteConditionalFlowPDF(ssl_flow, deep_prior, yprior, deep_dim=args.ssl_dim, 113 | shallow_prior=shallow_prior) 114 | 115 | flow = utils.create_flow(args, data_shape) 116 | flow = torch.nn.DataParallel(flow.to(device)) 117 | 118 | model = flows.FlowPDF(flow, prior).to(device) 119 | 120 | if args.pretrained != '': 121 | model.load_state_dict(torch.load(os.path.join(args.pretrained, 'model.torch'), map_location=device)) 122 | 123 | 124 | # def get_embeddings(loader, model): 125 | # zf, zh, labels = [], [], [] 126 | # for x, y in loader: 127 | # x = x.to(device) 128 | # print(model.log_prob(x).mean()) 129 | # z_ = flow(x)[1] 130 | # zf.append(utils.tonp(z_).mean()) 131 | # # z_ = z_[:, -args.ssl_dim:, None, None] 132 | # print(model.prior.log_prob(z_)) 133 | # print(torch.zeros((z_.shape[0],)).to(z_.device)) 134 | # print(model.prior.flow(z_, y)) 135 | # # print(x.device, z_.device) 
136 | # # print(torch.zeros((z_.shape[0],)).to(z_.device)) 137 | # # print(z_.shape) 138 | # sys.exit(0) 139 | # log_det_jac = torch.zeros((x.shape[0],)).to(x.device) 140 | # ssl_flow.module.f(z_, y.to(device)) 141 | # zh.append(utils.tonp()) 142 | # labels.append(utils.tonp(y)) 143 | # return np.concatenate(zf), np.concatenate(zh), np.concatenate(labels) 144 | 145 | def get_embeddings(loader, model): 146 | zf, zh, labels = [], [], [] 147 | for x, y in tqdm(loader): 148 | z_ = model.flow(x)[1] 149 | zf.append(utils.tonp(z_)) 150 | zh.append(utils.tonp(model.prior.flow(z_[:, -args.ssl_dim:, None, None], y)[1])) 151 | labels.append(utils.tonp(y)) 152 | return np.concatenate(zf), np.concatenate(zh), np.concatenate(labels) 153 | 154 | 155 | y_test = np.array(testloader.dataset.targets) 156 | y_train = np.array(trainloader.dataset.targets) 157 | 158 | if args.test_mode == 'perm': 159 | idxs = np.random.permutation(10000)[:5000] 160 | testloader.dataset.data[idxs] = 255 - testloader.dataset.data[idxs] 161 | testloader.dataset.targets[idxs] = 1 - testloader.dataset.targets[idxs] 162 | elif args.test_mode == '': 163 | pass 164 | elif args.test_mode == 'inv': 165 | testloader.dataset.data = 255 - testloader.dataset.data 166 | testloader.dataset.targets = 1 - testloader.dataset.targets 167 | else: 168 | raise NotImplementedError 169 | 170 | with torch.no_grad(): 171 | zf_train, zh_train, _ = get_embeddings(trainloader, model) 172 | zf_test, zh_test, _ = get_embeddings(testloader, model) 173 | 174 | 175 | np.save(os.path.join(args.save_dir, 'zf_train'), zf_train) 176 | np.save(os.path.join(args.save_dir, 'zh_train'), zh_train) 177 | np.save(os.path.join(args.save_dir, 'y_train'), y_train) 178 | 179 | np.save(os.path.join(args.save_dir, 'zf_test'), zf_test) 180 | np.save(os.path.join(args.save_dir, 'zh_test'), zh_test) 181 | np.save(os.path.join(args.save_dir, 'y_test'), y_test) 182 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import numpy as np 5 | 6 | from collections import OrderedDict 7 | from tabulate import tabulate 8 | from pandas import DataFrame 9 | from time import gmtime, strftime 10 | import time 11 | 12 | 13 | class Logger: 14 | def __init__(self, name='name', fmt=None, base='./logs'): 15 | self.handler = True 16 | self.scalar_metrics = OrderedDict() 17 | self.fmt = fmt if fmt else dict() 18 | 19 | if not os.path.exists(base): 20 | os.makedirs(base) 21 | 22 | time = gmtime() 23 | hash = ''.join([chr(random.randint(97, 122)) for _ in range(3)]) 24 | fname = '-'.join(sys.argv[0].split('/')[-3:]) 25 | # self.path = '%s/%s-%s-%s-%s' % (base, fname, name, hash, strftime('%m-%d-%H:%M', time)) 26 | self.path = '%s/%s-%s' % (base, fname, name) 27 | 28 | self.logs = self.path + '.csv' 29 | self.output = self.path + '.out' 30 | 31 | def prin(*args): 32 | str_to_write = ' '.join(map(str, args)) 33 | with open(self.output, 'a') as f: 34 | f.write(str_to_write + '\n') 35 | f.flush() 36 | 37 | print(str_to_write) 38 | sys.stdout.flush() 39 | 40 | self.print = prin 41 | 42 | def add_scalar(self, t, key, value): 43 | if key not in self.scalar_metrics: 44 | self.scalar_metrics[key] = [] 45 | self.scalar_metrics[key] += [(t, value)] 46 | 47 | def iter_info(self, order=None): 48 | names = list(self.scalar_metrics.keys()) 49 | if order: 50 | names = order 51 | values = [self.scalar_metrics[name][-1][1] for name in names] 52 | t = 
int(np.max([self.scalar_metrics[name][-1][0] for name in names])) 53 | fmt = ['%s'] + [self.fmt[name] if name in self.fmt else '.3f' for name in names] 54 | 55 | if self.handler: 56 | self.handler = False 57 | self.print(tabulate([[t] + values], ['epoch'] + names, floatfmt=fmt)) 58 | else: 59 | self.print(tabulate([[t] + values], ['epoch'] + names, tablefmt='plain', floatfmt=fmt).split('\n')[1]) 60 | 61 | def save(self): 62 | result = None 63 | for key in self.scalar_metrics.keys(): 64 | if result is None: 65 | result = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t') 66 | else: 67 | df = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t') 68 | result = result.join(df, how='outer') 69 | result.to_csv(self.logs) 70 | # self.print('The log/output/model have been saved to: ' + self.path + ' + .csv/.out/.cpt') 71 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndrewAtanov/semi-supervised-flow-pytorch/d1748decaccbc59e6bce014e1cb84527173c6b54/models/__init__.py -------------------------------------------------------------------------------- /models/coupling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | def coupling(x1, x2, net, log_det_jac=None, eps=1e-6): 8 | scale, shift = net(x2).split(x1.size(1), dim=1) 9 | # TODO: deal with scale 10 | # scale = torch.tanh(scale) 11 | # scale = torch.exp(scale) 12 | scale = torch.sigmoid(scale + 2.) + eps 13 | x1 = (x1 + shift) * scale 14 | if log_det_jac is not None: 15 | if scale.dim() == 4: 16 | log_det_jac += torch.log(scale).sum((1, 2, 3)) 17 | else: 18 | log_det_jac += torch.log(scale).sum(1) 19 | return x1 20 | 21 | 22 | def coupling_inv(x1, x2, net, eps=1e-6): 23 | scale, shift = net(x2).split(x1.size(1), dim=1) 24 | # TODO: deal with scale 25 | # scale = torch.tanh(scale) 26 | # scale = torch.exp(scale) 27 | scale = torch.sigmoid(scale + 2.) + eps 28 | x1 = x1 / scale - shift 29 | return x1 30 | 31 | 32 | def get_mask(xs, mask_type): 33 | if 'checkerboard' in mask_type: 34 | unit0 = np.array([[0.0, 1.0], [1.0, 0.0]]) 35 | unit1 = -unit0 + 1.0 36 | unit = unit0 if mask_type == 'checkerboard0' else unit1 37 | unit = np.reshape(unit, [1, 2, 2]) 38 | b = np.tile(unit, [xs[0], xs[1]//2, xs[2]//2]) 39 | elif 'channel' in mask_type: 40 | white = np.ones([xs[0]//2, xs[1], xs[2]]) 41 | black = np.zeros([xs[0]//2, xs[1], xs[2]]) 42 | if mask_type == 'channel0': 43 | b = np.concatenate([white, black], 0) 44 | else: 45 | b = np.concatenate([black, white], 0) 46 | 47 | if list(b.shape) != list(xs): 48 | b = np.tile(b, (1, 2, 2))[:, :xs[1], :xs[2]] 49 | 50 | return b 51 | 52 | 53 | class MaskedCouplingLayer(nn.Module): 54 | def __init__(self, shape, mask_type, net): 55 | super().__init__() 56 | mask = torch.FloatTensor(get_mask(shape, mask_type)) 57 | self.mask = nn.Parameter(mask[None], requires_grad=False) 58 | self.net = net 59 | self.eps = 1e-6 60 | 61 | def extra_repr(self): 62 | return 'MaskedCouplingLayer(mask=checkerboard)' 63 | 64 | def forward(self, x, log_det_jac, z): 65 | return self.f(x, log_det_jac, z) 66 | 67 | def f(self, x, log_det_jac, z): 68 | x1 = self.mask * x 69 | s, t = self.net(x1).split(x1.size(1), dim=1) 70 | logs = -F.softplus(-s-2.) 
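        # -F.softplus(-a) equals log(sigmoid(a)), so `logs` holds the elementwise
        # log-scales log(s); the masking below zeroes the entries that pass through
        # unchanged before the Jacobian term is accumulated.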
71 | logs *= (1 - self.mask) 72 | s = torch.sigmoid(s + 2.) 73 | s = (1 - self.mask) * s 74 | s += self.mask 75 | t = (1 - self.mask) * t 76 | x = x1 + (1 - self.mask) * (x * s + t) 77 | log_det_jac += torch.sum(logs, dim=(1, 2, 3)) 78 | return x, log_det_jac, z 79 | 80 | def g(self, x, z): 81 | x1 = self.mask * x 82 | s, t = self.net(x1).split(x1.size(1), dim=1) 83 | s = torch.sigmoid(s + 2.) + self.eps 84 | x = x1 + (1 - self.mask) * (x - t) / s 85 | return x, z 86 | 87 | 88 | class ConditionalMaskedCouplingLayer(nn.Module): 89 | def __init__(self, shape, mask_type, net): 90 | super().__init__() 91 | mask = torch.FloatTensor(get_mask(shape, mask_type)) 92 | self.mask = nn.Parameter(mask[None], requires_grad=False) 93 | self.net = net 94 | self.eps = 1e-6 95 | 96 | def forward(self, x, y, log_det_jac, z): 97 | return self.f(x, y, log_det_jac, z) 98 | 99 | def f(self, x, y, log_det_jac, z): 100 | x1 = self.mask * x 101 | 102 | assert y.dim() == 2 103 | y = y[..., None, None].repeat((1, 1, x1.shape[2], x1.shape[3])) 104 | 105 | s, t = self.net(torch.cat([x1, y], dim=1)).split(x1.size(1), dim=1) 106 | logs = -F.softplus(-s-2.) 107 | logs *= (1 - self.mask) 108 | s = torch.sigmoid(s + 2.) 109 | s = (1 - self.mask) * s 110 | s += self.mask 111 | t = (1 - self.mask) * t 112 | x = x1 + (1 - self.mask) * (x * s + t) 113 | log_det_jac += torch.sum(logs, dim=(1, 2, 3)) 114 | return x, log_det_jac, z 115 | 116 | def g(self, x, y, z): 117 | x1 = self.mask * x 118 | 119 | assert y.dim() == 2 120 | y = y[..., None, None].repeat((1, 1, x1.shape[2], x1.shape[3])) 121 | 122 | s, t = self.net(torch.cat([x1, y], dim=1)).split(x1.size(1), dim=1) 123 | s = torch.sigmoid(s + 2.) 124 | x = x1 + (1 - self.mask) * (x - t) / s 125 | return x, z 126 | 127 | 128 | class CouplingLayer(nn.Module): 129 | """ 130 | Coupling layer with channelwise mask applied twice. 131 | (e.g. 
see RealNVP https://arxiv.org/pdf/1605.08803.pdf for details) 132 | """ 133 | def __init__(self, netfunc): 134 | super().__init__() 135 | self.net1 = netfunc() 136 | self.net2 = netfunc() 137 | 138 | def extra_repr(self): 139 | return 'CouplingLayer(mask=channel)' 140 | 141 | def forward(self, x, log_det_jac, z): 142 | return self.f(x, log_det_jac, z) 143 | 144 | def f(self, x, log_det_jac, z): 145 | C = x.size(1) // 2 146 | x1, x2 = x.split(C, dim=1) 147 | x1, x2 = x1.contiguous(), x2.contiguous() 148 | 149 | x1 = coupling(x1, x2, self.net1, log_det_jac) 150 | x2 = coupling(x2, x1, self.net2, log_det_jac) 151 | 152 | return torch.cat([x1, x2], dim=1), log_det_jac, z 153 | 154 | def g(self, x, z): 155 | C = x.size(1) // 2 156 | x1, x2 = x.split(C, dim=1) 157 | x1, x2 = x1.contiguous(), x2.contiguous() 158 | 159 | x2 = coupling_inv(x2, x1, self.net2) 160 | x1 = coupling_inv(x1, x2, self.net1) 161 | 162 | return torch.cat([x1, x2], dim=1), z 163 | 164 | 165 | class ConditionalCouplingLayer(CouplingLayer): 166 | def forward(self, x, y, log_det_jac, z): 167 | return self.f(x, y, log_det_jac, z) 168 | 169 | def extra_repr(self): 170 | return 'ConditionalCouplingLayer(mask=channel)' 171 | 172 | def f(self, x, y, log_det_jac, z): 173 | C = x.size(1) // 2 174 | x1, x2 = x.split(C, dim=1) 175 | x1, x2 = x1.contiguous(), x2.contiguous() 176 | 177 | assert y.dim() == 2 178 | y = y[..., None, None].repeat((1, 1, x1.shape[2], x1.shape[3])) 179 | 180 | x1 = coupling(x1, torch.cat([x2, y], dim=1), self.net1, log_det_jac) 181 | x2 = coupling(x2, torch.cat([x1, y], dim=1), self.net2, log_det_jac) 182 | 183 | return torch.cat([x1, x2], dim=1), log_det_jac, z 184 | 185 | def g(self, x, y, z): 186 | C = x.size(1) // 2 187 | x1, x2 = x.split(C, dim=1) 188 | x1, x2 = x1.contiguous(), x2.contiguous() 189 | 190 | assert y.dim() == 2 191 | y = y[..., None, None].repeat((1, 1, x1.shape[2], x1.shape[3])) 192 | 193 | x2 = coupling_inv(x2, torch.cat([x1, y], dim=1), self.net2) 194 | x1 = coupling_inv(x1, torch.cat([x2, y], dim=1), self.net1) 195 | 196 | return torch.cat([x1, x2], dim=1), z 197 | 198 | 199 | class ConditionalShift(nn.Module): 200 | def __init__(self, channels, nfactors): 201 | super().__init__() 202 | self.factors = nn.Embedding(nfactors, channels) 203 | 204 | def forward(self, x, y, log_det_jac, z): 205 | return self.f(x, y, log_det_jac, z) 206 | 207 | def f(self, x, y, log_det_jac, z): 208 | shift = self.factors(y) 209 | return x - shift.view((x.size(0), -1, 1, 1)), log_det_jac, z 210 | 211 | def g(self, x, y, z): 212 | shift = self.factors(y) 213 | return x + shift.view((x.size(0), -1, 1, 1)), z 214 | -------------------------------------------------------------------------------- /models/distributions.py: -------------------------------------------------------------------------------- 1 | import torch.distributions as dist 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from numbers import Number 6 | import numpy as np 7 | import utils 8 | 9 | 10 | class Mixture(dist.Distribution): 11 | def __init__(self, base_ditributions, weights=None): 12 | super(Mixture, self).__init__(batch_shape=base_ditributions[0].batch_shape, 13 | event_shape=base_ditributions[0].event_shape) 14 | self.base_ditributions = base_ditributions 15 | self.weights = weights 16 | if self.weights is None: 17 | k = len(self.base_ditributions) 18 | self.weights = torch.ones((k,)) / float(k) 19 | 20 | def log_prob(self, x, detach=False): 21 | ens = None 22 | for prob, w in 
zip(self.base_ditributions, self.weights):
23 |             logp = prob.log_prob(x) + torch.log(w)
24 |             if detach:
25 |                 logp = logp.detach()
26 |             if ens is None:
27 |                 ens = logp
28 |             else:
29 |                 _t = torch.stack([ens, logp])
30 |                 ens = torch.logsumexp(_t, dim=0)
31 |         return ens
32 | 
33 |     def one_sample(self, labels=False):
34 |         k = np.random.choice(len(self.weights), p=utils.tonp(self.weights))
35 |         if labels:
36 |             return self.base_ditributions[k].sample(), k
37 |         return self.base_ditributions[k].sample()
38 | 
39 |     def sample(self, sample_shape=torch.Size(), labels=False):
40 |         if len(sample_shape) == 0:
41 |             return self.one_sample(labels=labels)
42 |         elif len(sample_shape) == 1:
43 |             res, ys = [], []
44 |             for i in range(sample_shape[0]):
45 |                 if labels:
46 |                     samples, y = self.one_sample(labels=labels)
47 |                     res.append(samples)
48 |                     ys.append(y)
49 |                 else:
50 |                     res.append(self.one_sample())
51 |             if labels:
52 |                 return torch.stack(res), np.stack(ys)
53 |             else:
54 |                 return torch.stack(res)
55 |         elif len(sample_shape) == 2:
56 |             res, ys = [], []
57 |             for _ in range(sample_shape[0]):
58 |                 res.append([])
59 |                 ys.append([])
60 |                 for _ in range(sample_shape[1]):
61 |                     if labels:
62 |                         samples, y = self.one_sample(labels=labels)
63 |                         res[-1].append(samples)
64 |                         ys[-1].append(y)
65 |                     else:
66 |                         res[-1].append(self.one_sample())
67 |                 res[-1] = torch.stack(res[-1])
68 |             if labels:
69 |                 return torch.stack(res), np.stack(ys)
70 |             else:
71 |                 return torch.stack(res)
72 |         else:
73 |             raise NotImplementedError
74 | 
75 | 
76 | class GMM(nn.Module):
77 |     def __init__(self, k=None, dim=None, means=None, covariances=None, weights=None, normalize=False):
78 |         super(GMM, self).__init__()
79 |         if k is None and means is None:
80 |             raise NotImplementedError
81 | 
82 |         if means is None:
83 |             covars = torch.rand(k, dim, dim)
84 |             covars = torch.matmul(covars, covars.transpose(1, 2))
85 |             self.means = nn.ParameterList([nn.Parameter(m) for m in torch.randn(k, dim)])
86 |             self.cov_factors = nn.ParameterList([nn.Parameter(torch.cholesky(cov)) for cov in covars])
87 |             self.weights = nn.Parameter(torch.FloatTensor([1./k] * k))
88 |             self.k = k
89 |         else:
90 |             self.means = nn.ParameterList([nn.Parameter(m) for m in means])
91 |             self.cov_factors = nn.ParameterList([nn.Parameter(torch.cholesky(cov)) for cov in covariances])
92 |             self.weights = nn.Parameter(weights)
93 |             self.k = weights.shape[0]
94 | 
95 |         self.normalize = normalize
96 | 
97 |     def get_weights(self):
98 |         if self.normalize:
99 |             return F.softmax(self.weights, dim=0)
100 |         return self.weights
101 | 
102 |     def get_dist(self):
103 |         base_distributions = []
104 |         for m, covf in zip(self.means, self.cov_factors):
105 |             if covf.dim() == 1:
106 |                 covf = torch.diag(covf)
107 |             cov = torch.mm(covf, covf.t())
108 |             base_distributions.append(dist.MultivariateNormal(m, covariance_matrix=cov))
109 |         return Mixture(base_distributions, weights=self.get_weights())
110 | 
111 |     def set_covariance(self, k, cov):
112 |         self.cov_factors[k].data = torch.cholesky(cov)
113 | 
114 |     def set_params(self, means=None, covars=None, pi=None):
115 |         if pi is not None:
116 |             self.weights.data = torch.log(pi) if self.normalize else pi
117 |         for k in range(self.k):
118 |             if means is not None:
119 |                 self.means[k].data = means[k]
120 |             if covars is not None:
121 |                 self.set_covariance(k, covars[k])
122 | 
123 |     @property
124 |     def covariances(self):
125 |         return torch.stack([torch.mm(covf, covf.t()) for covf in self.cov_factors])
126 | 
127 |     def log_prob(self, x, k=None):
128 |         if k is None:
129 |             p = self.get_dist()
130 |             return 
p.log_prob(x) 131 | else: 132 | p = self.get_dist() 133 | return p.base_ditributions[k].log_prob(x) + torch.log(p.weights[k]) 134 | 135 | def sample(self, sample_shape=torch.Size(), labels=False): 136 | p = self.get_dist() 137 | return p.sample(sample_shape, labels=labels) 138 | 139 | 140 | class MultivariateNormalDiag(torch.distributions.Normal): 141 | def log_prob(self, x): 142 | logp = super().log_prob(x) 143 | return logp.sum(1) 144 | 145 | 146 | class GaussianDiag(nn.Module): 147 | def __init__(self, dim): 148 | super().__init__() 149 | self.dim = dim 150 | self.logsigma = nn.Parameter(torch.zeros((dim,))) 151 | self.mean = nn.Parameter(torch.zeros((dim,)), requires_grad=False) 152 | 153 | def _get_dist(self): 154 | scale = F.softplus(self.logsigma) 155 | return MultivariateNormalDiag(self.mean, scale) 156 | 157 | def log_prob(self, x): 158 | p = self._get_dist() 159 | return p.log_prob(x) 160 | 161 | def log_prob_full(self, x): 162 | p = self._get_dist() 163 | return p.log_prob(x)[:, None] 164 | 165 | 166 | class GmmPrior(nn.Module): 167 | def __init__(self, k=None, dim=None, full_dim=None, means=None, covariances=None, weights=None, cov_type='diag'): 168 | super().__init__() 169 | if k is None and means is None: 170 | raise NotImplementedError 171 | 172 | if cov_type not in ['diag']: 173 | raise NotImplementedError 174 | self.cov_type = cov_type 175 | 176 | if full_dim is None: 177 | full_dim = dim 178 | 179 | if means is None: 180 | means = torch.randn(k, dim) * np.sqrt(2) 181 | if cov_type == 'diag': 182 | # covariances = torch.log(torch.rand(k, dim) * 0.5) 183 | covariances = torch.zeros((k, dim)) 184 | weights = torch.FloatTensor([1./k] * k) 185 | 186 | self.means = nn.Parameter(means) 187 | self.cov_factors = nn.Parameter(covariances) 188 | self.weights = nn.Parameter(weights) 189 | self.k = self.weights.shape[0] 190 | self.dim = self.means.shape[1] 191 | self.full_dim = full_dim 192 | self.sn_dim = self.full_dim - self.dim 193 | 194 | def get_logpi(self): 195 | return F.log_softmax(self.weights, dim=0) 196 | 197 | def get_dist(self): 198 | base_distributions = [] 199 | for m, covf in zip(self.means, self.cov_factors): 200 | m = torch.cat([torch.zeros((self.sn_dim,)).to(m.device), m]) 201 | if self.cov_type == 'diag': 202 | covf = torch.cat([torch.zeros((self.sn_dim,)).to(covf.device), covf]) 203 | # TODO: softplus seems to be more stable 204 | scale = torch.exp(covf * 0.5) 205 | base_distributions.append(MultivariateNormalDiag(m, scale)) 206 | 207 | pi = torch.exp(self.get_logpi()) 208 | return Mixture(base_distributions, weights=pi) 209 | 210 | def log_prob(self, x, k=None): 211 | logpi = self.get_logpi() 212 | if k is None: 213 | p = self.get_dist() 214 | return p.log_prob(x) 215 | else: 216 | p = self.get_dist() 217 | return p.base_ditributions[k].log_prob(x) + logpi[k] 218 | 219 | def log_prob_full(self, x): 220 | return torch.stack([self.log_prob(x, k=k) for k in range(self.k)]).transpose(0, 1) 221 | 222 | def log_prob_full_fast(self, x): 223 | if self.cov_type != 'diag': 224 | raise NotImplementedError 225 | var = torch.exp(self.cov_factors) 226 | logp = -(x[:, None] - self.means[None])**2 / (2. 
* var[None]) - 0.5 * self.cov_factors
227 |         return logp.sum(2) + self.get_logpi()[None] - 0.5 * np.log(2 * np.pi) * self.dim
228 | 
229 |     def sample(self, sample_shape=torch.Size(), labels=False):
230 |         p = self.get_dist()
231 |         return p.sample(sample_shape, labels=labels)
232 | 
233 |     def set_params(self, means=None, covars=None, pi=None):
234 |         if pi is not None:
235 |             self.weights.data = torch.log(pi)
236 |         if means is not None:
237 |             self.means.data = torch.tensor(means)
238 |         if covars is not None:
239 |             self.cov_factors.data = torch.log(covars)
240 | 
--------------------------------------------------------------------------------
/models/flows.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from models.normalization import ActNorm, DummyCondActNorm
6 | from models.invertconv import InvertibleConv2d, HHConv2d_1x1, QRInvertibleConv2d, DummyCondInvertibleConv2d
7 | from models.coupling import CouplingLayer, MaskedCouplingLayer, ConditionalCouplingLayer, ConditionalMaskedCouplingLayer
8 | from models.utils import Conv2dZeros, SpaceToDepth, FactorOut, ToLogits, CondToLogits, CondFactorOut, CondSpaceToDepth
9 | from models.utils import DummyCond, IdFunction
10 | import warnings
11 | 
12 | 
13 | class Flow(nn.Module):
14 |     def __init__(self, modules):
15 |         super().__init__()
16 |         self.modules_ = nn.ModuleList(modules)
17 |         self.latent_len = -1
18 |         self.x_shape = -1
19 | 
20 |     def f(self, x):
21 |         z = None
22 |         log_det_jac = torch.zeros((x.shape[0],)).to(x.device)
23 |         for m in self.modules_:
24 |             x, log_det_jac, z = m(x, log_det_jac, z)
25 |         if z is None:
26 |             z = torch.zeros((x.shape[0], 1))[:, :0].to(x.device)
27 |         self.x_shape = list(x.shape)[1:]
28 |         self.latent_len = z.shape[1]
29 |         z = torch.cat([z, x.reshape((x.shape[0], -1))], dim=1)
30 |         return log_det_jac, z
31 | 
32 |     def forward(self, x):
33 |         return self.f(x)
34 | 
35 |     def g(self, z):
36 |         x = z[:, self.latent_len:].view([z.shape[0]] + self.x_shape)
37 |         z = z[:, :self.latent_len]
38 |         for m in reversed(self.modules_):
39 |             x, z = m.g(x, z)
40 |         return x
41 | 
42 | 
43 | class ConditionalFlow(Flow):
44 |     def f(self, x, y):
45 |         z = None
46 |         log_det_jac = torch.zeros((x.shape[0],)).to(x.device)
47 |         for m in self.modules_:
48 |             x, log_det_jac, z = m(x, y, log_det_jac, z)
49 |         if z is None:
50 |             z = torch.zeros((x.shape[0], 1))[:, :0].to(x.device)
51 |         self.x_shape = list(x.shape)[1:]
52 |         self.latent_len = z.shape[1]
53 |         z = torch.cat([z, x.reshape((x.shape[0], -1))], dim=1)
54 |         return log_det_jac, z
55 | 
56 |     def g(self, z, y):
57 |         x = z[:, self.latent_len:].view([z.shape[0]] + self.x_shape)
58 |         z = z[:, :self.latent_len]
59 |         for m in reversed(self.modules_):
60 |             x, z = m.g(x, y, z)
61 |         return x
62 | 
63 |     def forward(self, x, y):
64 |         return self.f(x, y)
65 | 
66 | 
67 | class DiscreteConditionalFlow(ConditionalFlow):
68 |     def __init__(self, modules, num_cat, emb_dim):
69 |         super().__init__(modules)
70 |         self.embeddings = nn.Embedding(num_cat, emb_dim)
71 | 
72 |         self.embeddings.weight.data.zero_()
73 |         l = torch.arange(self.embeddings.weight.data.shape[0])
74 |         self.embeddings.weight.data[l, l] = 1.
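        # The two lines above write an identity block into the embedding table, so
        # class k is initially embedded as the one-hot vector e_k (this assumes
        # num_cat <= emb_dim, which holds for the defaults num_cat = emb_dim = 10).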
75 | 76 | def f(self, x, y): 77 | return super().f(x, self.embeddings(y)) 78 | 79 | def g(self, z, y): 80 | return super().g(z, self.embeddings(y)) 81 | 82 | 83 | class FlowPDF(nn.Module): 84 | def __init__(self, flow, prior): 85 | super().__init__() 86 | self.flow = flow 87 | self.prior = prior 88 | 89 | def log_prob(self, x): 90 | log_det, z = self.flow(x) 91 | return log_det + self.prior.log_prob(z) 92 | 93 | 94 | class DeepConditionalFlowPDF(nn.Module): 95 | def __init__(self, flow, deep_prior, yprior, deep_dim, shallow_prior=None): 96 | super().__init__() 97 | self.flow = flow 98 | self.shallow_prior = shallow_prior 99 | self.deep_prior = deep_prior 100 | self.yprior = yprior 101 | self.deep_dim = deep_dim 102 | 103 | def log_prob(self, x, y): 104 | if x.dim() == 2: 105 | x = x[..., None, None] 106 | if self.deep_dim == x.shape[1]: 107 | log_det, z = self.flow(x, y) 108 | return log_det + self.deep_prior.log_prob(z) 109 | else: 110 | log_det, z = self.flow(x[:, -self.deep_dim:], y) 111 | return log_det + self.deep_prior.log_prob(z) + self.shallow_prior.log_prob(x[:, :-self.deep_dim].squeeze()) 112 | 113 | def log_prob_joint(self, x, y): 114 | return self.log_prob(x, y) + self.yprior.log_prob(y) 115 | 116 | 117 | class ConditionalFlowPDF(nn.Module): 118 | def __init__(self, flow, prior, emb=True): 119 | super().__init__() 120 | self.flow = flow 121 | self.prior = prior 122 | 123 | def log_prob(self, x, y): 124 | log_det, z = self.flow(x, y) 125 | return log_det + self.prior.log_prob(z) 126 | 127 | 128 | class DiscreteConditionalFlowPDF(DeepConditionalFlowPDF): 129 | def log_prob_full(self, x): 130 | sup = self.yprior.enumerate_support().to(x.device) 131 | logp = [] 132 | 133 | n_uniq = sup.size(0) 134 | y = sup.repeat((x.size(0), 1)).t().reshape((1, -1)).t()[:, 0] 135 | logp = self.log_prob(x.repeat([n_uniq] + [1]*(len(x.shape)-1)), y) 136 | return logp.reshape((n_uniq, x.size(0))).t() + self.yprior.log_prob(sup[None]) 137 | 138 | def log_prob(self, x, y=None): 139 | if y is not None: 140 | return super().log_prob(x, y) 141 | else: 142 | logp_joint = self.log_prob_full(x) 143 | return torch.logsumexp(logp_joint, dim=1) 144 | 145 | def log_prob_posterior(self, x): 146 | logp_joint = self.log_prob_full(x) 147 | return logp_joint - torch.logsumexp(logp_joint, dim=1)[:, None] 148 | 149 | 150 | class ResNetBlock(nn.Module): 151 | def __init__(self, channels, use_bn=False): 152 | super().__init__() 153 | modules = [] 154 | if use_bn: 155 | # modules.append(nn.BatchNorm2d(channels)) 156 | ActNorm(channels, flow=False) 157 | modules += [ 158 | nn.ReLU(), 159 | nn.ReflectionPad2d(1), 160 | nn.Conv2d(channels, channels, 3)] 161 | if use_bn: 162 | # modules.append(nn.BatchNorm2d(channels)) 163 | ActNorm(channels, flow=False) 164 | modules += [ 165 | nn.ReLU(), 166 | nn.ReflectionPad2d(1), 167 | nn.Conv2d(channels, channels, 3)] 168 | 169 | self.net = nn.Sequential(*modules) 170 | 171 | def forward(self, x): 172 | return self.net(x) + x 173 | 174 | 175 | class ResNetBlock1x1(nn.Module): 176 | def __init__(self, channels, use_bn=False): 177 | super().__init__() 178 | modules = [] 179 | if use_bn: 180 | ActNorm(channels, flow=False) 181 | modules += [ 182 | nn.ReLU(), 183 | nn.Conv2d(channels, channels, 1)] 184 | if use_bn: 185 | ActNorm(channels, flow=False) 186 | modules += [ 187 | nn.ReLU(), 188 | nn.Conv2d(channels, channels, 1)] 189 | 190 | self.net = nn.Sequential(*modules) 191 | 192 | def forward(self, x): 193 | return self.net(x) + x 194 | 195 | 196 | def get_resnet1x1(in_channels, channels, 
out_channels=None): 197 | if out_channels is None: 198 | out_channels = in_channels * 2 199 | net = nn.Sequential( 200 | nn.Conv2d(in_channels, channels, 1, padding=0), 201 | ResNetBlock1x1(channels, use_bn=True), 202 | ResNetBlock1x1(channels, use_bn=True), 203 | ResNetBlock1x1(channels, use_bn=True), 204 | ResNetBlock1x1(channels, use_bn=True), 205 | ActNorm(channels, flow=False), 206 | nn.ReLU(), 207 | Conv2dZeros(channels, out_channels, 1, padding=0), 208 | ) 209 | return net 210 | 211 | 212 | def get_resnet(in_channels, channels, out_channels=None): 213 | if out_channels is None: 214 | out_channels = in_channels * 2 215 | net = nn.Sequential( 216 | nn.ReflectionPad2d(1), 217 | nn.Conv2d(in_channels, channels, 3), 218 | ResNetBlock(channels, use_bn=True), 219 | ResNetBlock(channels, use_bn=True), 220 | ResNetBlock(channels, use_bn=True), 221 | ResNetBlock(channels, use_bn=True), 222 | ActNorm(channels, flow=False), 223 | nn.ReLU(), 224 | nn.ReflectionPad2d(1), 225 | Conv2dZeros(channels, out_channels, 3, 0), 226 | ) 227 | return net 228 | 229 | 230 | def get_resnet8(in_channels, channels, out_channels=None): 231 | if out_channels is None: 232 | out_channels = in_channels * 2 233 | net = nn.Sequential( 234 | nn.ReflectionPad2d(1), 235 | nn.Conv2d(in_channels, channels, 3), 236 | ResNetBlock(channels, use_bn=True), 237 | ResNetBlock(channels, use_bn=True), 238 | ResNetBlock(channels, use_bn=True), 239 | ResNetBlock(channels, use_bn=True), 240 | ResNetBlock(channels, use_bn=True), 241 | ResNetBlock(channels, use_bn=True), 242 | ResNetBlock(channels, use_bn=True), 243 | ResNetBlock(channels, use_bn=True), 244 | ActNorm(channels, flow=False), 245 | nn.ReLU(), 246 | nn.ReflectionPad2d(1), 247 | Conv2dZeros(channels, out_channels, 3, 0), 248 | ) 249 | return net 250 | 251 | 252 | def netfunc_for_coupling(in_channels, hidden_channels, out_channels, k=3): 253 | def foo(): 254 | return nn.Sequential( 255 | nn.Conv2d(in_channels, hidden_channels, k, padding=int(k == 3)), 256 | nn.ReLU(False), 257 | nn.Conv2d(hidden_channels, hidden_channels, 1), 258 | nn.ReLU(False), 259 | Conv2dZeros(hidden_channels, out_channels, k, padding=int(k == 3)) 260 | ) 261 | 262 | return foo 263 | 264 | 265 | def get_flow(num_layers, k_factor, in_channels=1, hid_dim=[256], conv='full', hh_factors=2, 266 | cond=False, emb_dim=10, n_cat=10, net='shallow'): 267 | modules = [ 268 | DummyCond(ToLogits()) if cond else ToLogits(), 269 | ] 270 | channels = in_channels 271 | 272 | if conv == 'full': 273 | convf = lambda x: InvertibleConv2d(x) 274 | elif conv == 'hh': 275 | convf = lambda x: HHConv2d_1x1(x, factors=[x]*hh_factors) 276 | elif conv == 'qr': 277 | convf = lambda x: QRInvertibleConv2d(x, factors=[x]*hh_factors) 278 | elif conv == 'qr-abs': 279 | convf = lambda x: QRInvertibleConv2d(x, factors=[x]*hh_factors, act='no') 280 | elif conv == 'no': 281 | convf = lambda x: IdFunction() 282 | else: 283 | raise NotImplementedError 284 | 285 | if net == 'shallow': 286 | couplingnetf = lambda x, y: netfunc_for_coupling(x, hid_dim[0], y) 287 | elif net == 'resnet': 288 | couplingnetf = lambda x, y: lambda: get_resnet(x, hid_dim[0], out_channels=y) 289 | else: 290 | raise NotImplementedError 291 | 292 | for l in range(num_layers): 293 | # TODO: FIX 294 | warnings.warn('==== "get_flow" reduce spatial dimensions only 4 times!!! 
====')
295 |         if l != 4:
296 |             if cond:
297 |                 modules.append(DummyCond(SpaceToDepth(2)))
298 |             else:
299 |                 modules.append(SpaceToDepth(2))
300 |             channels *= 4
301 |         for k in range(k_factor):
302 |             if cond:
303 |                 modules.append(DummyCond(ActNorm(channels)))
304 |                 modules.append(DummyCond(convf(channels)))
305 |                 modules.append(ConditionalCouplingLayer(couplingnetf(channels // 2 + emb_dim, channels)))
306 |             else:
307 |                 modules.append(ActNorm(channels))
308 |                 modules.append(convf(channels))
309 |                 modules.append(CouplingLayer(couplingnetf(channels // 2, channels)))
310 | 
311 |         if l != num_layers - 1:
312 |             if cond:
313 |                 modules.append(DummyCond(FactorOut()))
314 |             else:
315 |                 modules.append(FactorOut())
316 | 
317 |             channels //= 2
318 |             channels -= channels % 2
319 | 
320 |     return DiscreteConditionalFlow(modules, n_cat, emb_dim) if cond else Flow(modules)
321 | 
322 | 
323 | def get_flow_cond(num_layers, k_factor, in_channels=1, hid_dim=256, conv='full', hh_factors=2, num_cat=10, emb_dim=10):
324 |     modules = []
325 |     channels = in_channels
326 |     for l in range(num_layers):
327 |         for k in range(k_factor):
328 |             modules.append(DummyCondActNorm(channels))
329 |             if conv == 'full':
330 |                 modules.append(DummyCondInvertibleConv2d(channels))
331 |             elif conv == 'hh':
332 |                 modules.append(DummyCond(HHConv2d_1x1(channels, factors=[channels]*hh_factors)))
333 |             elif conv == 'qr':
334 |                 modules.append(DummyCond(QRInvertibleConv2d(channels, factors=[channels]*hh_factors)))
335 |             elif conv == 'qr-abs':
336 |                 modules.append(DummyCond(QRInvertibleConv2d(channels, factors=[channels]*hh_factors, act='no')))
337 |             else:
338 |                 raise NotImplementedError
339 | 
340 |             netf = lambda: get_resnet1x1(channels//2 + emb_dim, hid_dim, channels)
341 |             modules.append(ConditionalCouplingLayer(netf))
342 | 
343 |         if l != num_layers - 1:
344 |             modules.append(CondFactorOut())
345 |             channels //= 2
346 |             channels -= channels % 2
347 | 
348 |     return DiscreteConditionalFlow(modules, num_cat, emb_dim)
349 | 
350 | 
351 | def mnist_flow(num_layers=5, k_factor=4, logits=True, conv='full', hh_factors=2, hid_dim=[32, 784]):
352 |     modules = []
353 |     if logits:
354 |         modules.append(ToLogits())
355 | 
356 |     channels = 1
357 |     hd = hid_dim[0]
358 |     kernel = 3
359 |     for l in range(num_layers):
360 |         if l < 2:
361 |             modules.append(SpaceToDepth(2))
362 |             channels *= 4
363 |         elif l == 2:
364 |             modules.append(SpaceToDepth(7))
365 |             channels *= 49
366 |             hd = hid_dim[1]
367 |             kernel = 1
368 | 
369 |         for k in range(k_factor):
370 |             modules.append(ActNorm(channels))
371 |             if conv == 'full':
372 |                 modules.append(InvertibleConv2d(channels))
373 |             elif conv == 'hh':
374 |                 modules.append(HHConv2d_1x1(channels, factors=[channels]*hh_factors))
375 |             elif conv == 'qr':
376 |                 modules.append(QRInvertibleConv2d(channels, factors=[channels]*hh_factors))
377 |             elif conv == 'qr-abs':
378 |                 modules.append(QRInvertibleConv2d(channels, factors=[channels]*hh_factors, act='no'))
379 |             else:
380 |                 raise NotImplementedError
381 |             modules.append(CouplingLayer(netfunc_for_coupling(channels // 2, hd, channels, k=kernel)))
382 | 
383 |         if l != num_layers - 1:
384 |             modules.append(FactorOut())
385 |             channels //= 2
386 |             channels -= channels % 2
387 | 
388 |     return Flow(modules)
389 | 
390 | 
391 | def mnist_masked_glow(conv='full', hh_factors=2):
392 |     def get_net(in_channels, channels):
393 |         net = nn.Sequential(
394 |             nn.ReflectionPad2d(1),
395 |             nn.Conv2d(in_channels, channels, 3),
396 |             ResNetBlock(channels, use_bn=True),
397 |             ResNetBlock(channels, use_bn=True),
398 |             ResNetBlock(channels, 
use_bn=True), 399 | ResNetBlock(channels, use_bn=True), 400 | ActNorm(channels, flow=False), 401 | nn.ReLU(), 402 | nn.ReflectionPad2d(1), 403 | Conv2dZeros(channels, in_channels * 2, 3, 0), 404 | ) 405 | return net 406 | 407 | if conv == 'full': 408 | convf = lambda x: InvertibleConv2d(x) 409 | elif conv == 'qr': 410 | convf = lambda x: QRInvertibleConv2d(x, [x]*hh_factors) 411 | elif conv == 'hh': 412 | convf = lambda x: HHConv2d_1x1(x, [x]*hh_factors) 413 | else: 414 | raise NotImplementedError 415 | 416 | modules = [ 417 | ToLogits(), 418 | convf(1), 419 | MaskedCouplingLayer([1, 28, 28], 'checkerboard0', get_net(1, 64)), 420 | ActNorm(1), 421 | convf(1), 422 | MaskedCouplingLayer([1, 28, 28], 'checkerboard1', get_net(1, 64)), 423 | ActNorm(1), 424 | convf(1), 425 | MaskedCouplingLayer([1, 28, 28], 'checkerboard0', get_net(1, 64)), 426 | ActNorm(1), 427 | SpaceToDepth(2), 428 | convf(4), 429 | CouplingLayer(lambda: get_net(2, 64)), 430 | ActNorm(4), 431 | convf(4), 432 | CouplingLayer(lambda: get_net(2, 64)), 433 | ActNorm(4), 434 | 435 | FactorOut(), 436 | 437 | convf(2), 438 | MaskedCouplingLayer([2, 14, 14], 'checkerboard0', get_net(2, 64)), 439 | ActNorm(2), 440 | convf(2), 441 | MaskedCouplingLayer([2, 14, 14], 'checkerboard1', get_net(2, 64)), 442 | ActNorm(2), 443 | convf(2), 444 | MaskedCouplingLayer([2, 14, 14], 'checkerboard0', get_net(2, 64)), 445 | ActNorm(2), 446 | SpaceToDepth(2), 447 | convf(8), 448 | CouplingLayer(lambda: get_net(4, 64)), 449 | ActNorm(8), 450 | convf(8), 451 | CouplingLayer(lambda: get_net(4, 64)), 452 | ActNorm(8), 453 | 454 | FactorOut(), 455 | 456 | convf(4), 457 | MaskedCouplingLayer([4, 7, 7], 'checkerboard0', get_net(4, 64)), 458 | ActNorm(4), 459 | convf(4), 460 | MaskedCouplingLayer([4, 7, 7], 'checkerboard1', get_net(4, 64)), 461 | ActNorm(4), 462 | convf(4), 463 | MaskedCouplingLayer([4, 7, 7], 'checkerboard0', get_net(4, 64)), 464 | ActNorm(4), 465 | convf(4), 466 | CouplingLayer(lambda: get_net(2, 64)), 467 | ActNorm(4), 468 | convf(4), 469 | CouplingLayer(lambda: get_net(2, 64)), 470 | ActNorm(4), 471 | ] 472 | 473 | return Flow(modules) 474 | 475 | 476 | def toy2d_flow(conv='full', hh_factors=2, l=5): 477 | def netf(): 478 | return nn.Sequential( 479 | nn.Conv2d(1, 64, 1), 480 | nn.LeakyReLU(), 481 | nn.Conv2d(64, 64, 1), 482 | nn.LeakyReLU(), 483 | nn.Conv2d(64, 2, 1) 484 | ) 485 | 486 | if conv == 'full': 487 | convf = lambda x: InvertibleConv2d(x) 488 | elif conv == 'qr': 489 | convf = lambda x: QRInvertibleConv2d(x, [x]*hh_factors) 490 | elif conv == 'hh': 491 | convf = lambda x: HHConv2d_1x1(x, [x]*hh_factors) 492 | else: 493 | raise NotImplementedError 494 | 495 | modules = [] 496 | for _ in range(l): 497 | modules.append(convf(2)) 498 | modules.append(CouplingLayer(netf)) 499 | modules.append(ActNorm(2)) 500 | return Flow(modules) 501 | -------------------------------------------------------------------------------- /models/invertconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | def num_pixels(tensor): 8 | assert tensor.dim() == 4 9 | return tensor.size(2) * tensor.size(3) 10 | 11 | 12 | class _BaseInvertibleConv2d(nn.Module): 13 | def _get_w(self): 14 | raise NotImplementedError 15 | 16 | def _get_w_inv(self): 17 | raise NotImplementedError 18 | 19 | def _log_det(self, x=None): 20 | raise NotImplementedError 21 | 22 | def forward(self, x, log_det_jac, z): 23 | log_det_jac += 
self._log_det(x) 24 | W = self._get_w().to(x) 25 | W = W.unsqueeze(-1).unsqueeze(-1) 26 | return F.conv2d(x, W), log_det_jac, z 27 | 28 | def g(self, x, z): 29 | W = self._get_w_inv() 30 | W = W.unsqueeze(-1).unsqueeze(-1) 31 | return F.conv2d(x, W), z 32 | 33 | 34 | class InvertibleConv2d(_BaseInvertibleConv2d): 35 | ''' 36 | Diederik P. Kingma, Prafulla Dhariwal 37 | "Glow: Generative Flow with Invertible 1x1 Convolutions" 38 | https://arxiv.org/pdf/1807.03039.pdf 39 | ''' 40 | def __init__(self, features): 41 | super().__init__() 42 | self.features = features 43 | self.W = nn.Parameter(torch.Tensor(features, features)) 44 | self.reset_parameters() 45 | 46 | def reset_parameters(self): 47 | # nn.init.orthogonal_(self.W) 48 | self.W.data = torch.eye(self.features).to(self.W.device) 49 | 50 | def _get_w(self): 51 | return self.W 52 | 53 | def _get_w_inv(self): 54 | return torch.inverse(self.W.double()).float() 55 | 56 | def _log_det(self, x): 57 | return torch.slogdet(self.W.double())[1].float() * num_pixels(x) 58 | 59 | def extra_repr(self): 60 | return 'InvertibleConv2d({:d})'.format(self.features) 61 | 62 | 63 | def householder_matrix(v, size=None): 64 | """ 65 | householder_matrix(Tensor, size=None) -> Tensor 66 | 67 | Arguments 68 | v: Tensor of size [Any,] 69 | size: `int` or `None`. The size of the resulting matrix. 70 | size >= v.size(0) 71 | Output 72 | I - 2 v^T * v / v*v^T: Tensor of size [size, size] 73 | """ 74 | size = size or v.size(0) 75 | assert size >= v.size(0) 76 | v = torch.cat([torch.ones(size - v.size(0), device=v.device), v]) 77 | I = torch.eye(size, device=v.device) 78 | outer = torch.ger(v, v) 79 | inner = torch.dot(v, v) + 1e-16 80 | return I - 2 * outer / inner 81 | 82 | 83 | def naive_cascade(vectors, size=None): 84 | """ 85 | naive_cascade([Tensor, Tensor, ...], size=None) -> Tensor 86 | naive implementation 87 | 88 | Arguments 89 | vectors: list of Tensors of size [Any,] 90 | size: `int` or `None`. The size of the resulting matrix. 91 | size >= max(v.size(0) for v in vectors) 92 | Output 93 | Q: `torch.Tensor` of size [size, size] 94 | """ 95 | size = size or max(v.size(0) for v in vectors) 96 | assert size >= max(v.size(0) for v in vectors) 97 | device = vectors[0].device 98 | Q = torch.eye(size, device=device) 99 | for v in vectors: 100 | Q = torch.mm(Q, householder_matrix(v, size=size)) 101 | return Q 102 | 103 | 104 | class HHConv2d_1x1(_BaseInvertibleConv2d): 105 | def __init__(self, features, factors=None): 106 | super().__init__() 107 | self.features = features 108 | self.factors = factors or range(2, features + 1) 109 | 110 | # init vectors 111 | self.vectors = [] 112 | for i, f in enumerate(self.factors): 113 | vec = nn.Parameter(torch.Tensor(f)) 114 | self.register_parameter('vec_{}'.format(i), vec) 115 | self.vectors.append(vec) 116 | 117 | self.reset_parameters() 118 | self.cascade = naive_cascade 119 | 120 | def reset_parameters(self): 121 | for v in self.vectors: 122 | v.data.uniform_(-1, 1) 123 | with torch.no_grad(): 124 | v /= (torch.norm(v) + 1e-16) 125 | 126 | def _get_w(self): 127 | return self.cascade(self.vectors, self.features) 128 | 129 | def _get_w_inv(self): 130 | return self._get_w().t() 131 | 132 | def _log_det(self, x): 133 | return 0. 
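# A product of Householder reflections is orthogonal, so |det W| = 1 and the 1x1
# convolution above contributes zero log-determinant. A quick numerical sanity
# check (illustrative sketch, not part of the training code):
#   layer = HHConv2d_1x1(4)
#   W = layer._get_w()
#   assert torch.allclose(W @ W.t(), torch.eye(4), atol=1e-5)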
134 | 135 | 136 | class QRInvertibleConv2d(HHConv2d_1x1): 137 | """ 138 | Hoogeboom, Emiel and Berg, Rianne van den and Welling, Max 139 | "Emerging Convolutions for Generative Normalizing Flows" 140 | https://arxiv.org/pdf/1901.11137.pdf 141 | """ 142 | def __init__(self, features, factors=None, act='softplus'): 143 | super().__init__(features, factors=factors) 144 | self.act = act 145 | if act == 'softplus': 146 | self.s_factor = nn.Parameter(torch.zeros((features,))) 147 | elif act == 'no': 148 | self.s_factor = nn.Parameter(torch.ones((features,))) 149 | else: 150 | raise NotImplementedError 151 | 152 | self.r = nn.Parameter(torch.zeros((features, features))) 153 | 154 | def _get_w(self): 155 | Q = super()._get_w() 156 | if self.act == 'softplus': 157 | R = torch.diag(F.softplus(self.s_factor)) 158 | elif self.act == 'no': 159 | R = torch.diag(self.s_factor) 160 | 161 | R += torch.triu(self.r, diagonal=1) 162 | return Q.to(R) @ R 163 | 164 | def _log_det(self, x=None): 165 | if self.act == 'softplus': 166 | return torch.log(F.softplus(self.s_factor)).sum() * num_pixels(x) 167 | elif self.act == 'no': 168 | return torch.log(torch.abs(self.s_factor)).sum() * num_pixels(x) 169 | 170 | def _get_w_inv(self): 171 | Q = super()._get_w().to(self.s_factor) 172 | if self.act == 'softplus': 173 | R = torch.diag(F.softplus(self.s_factor)) 174 | elif self.act == 'no': 175 | R = torch.diag(self.s_factor) 176 | 177 | R += torch.triu(self.r, diagonal=1) 178 | return torch.inverse(R.double()).float() @ Q.t() 179 | 180 | 181 | class DummyCondInvertibleConv2d(InvertibleConv2d): 182 | def forward(self, x, y, log_det_jac, z): 183 | return super().forward(x, log_det_jac, z) 184 | 185 | def g(self, x, y, z): 186 | return super().g(x, z) 187 | -------------------------------------------------------------------------------- /models/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def num_pixels(tensor): 7 | assert tensor.dim() == 4 8 | return tensor.size(2) * tensor.size(3) 9 | 10 | 11 | class ActNorm(nn.Module): 12 | def __init__(self, num_features, eps=1e-5, flow=True): 13 | super().__init__() 14 | self.num_features = num_features 15 | self.eps = eps 16 | self.logs = nn.Parameter(torch.Tensor(num_features)) 17 | self.bias = nn.Parameter(torch.Tensor(num_features)) 18 | self.requires_init = nn.Parameter(torch.ByteTensor(1), requires_grad=False) 19 | self.reset_parameters() 20 | self.flow = flow 21 | 22 | def reset_parameters(self): 23 | self.logs.data.zero_() 24 | self.bias.data.zero_() 25 | self.requires_init.data.fill_(True) 26 | 27 | def init_data_dependent(self, x): 28 | with torch.no_grad(): 29 | x_ = x.transpose(0, 1).contiguous().view(self.num_features, -1) 30 | mean = x_.mean(1) 31 | var = x_.var(1) 32 | logs = -torch.log(torch.sqrt(var) + 1e-6) 33 | self.logs.data.copy_(logs.data) 34 | self.bias.data.copy_(mean.data) 35 | 36 | def forward(self, x, log_det_jac=None, z=None): 37 | assert x.size(1) == self.num_features 38 | if self.requires_init: 39 | self.requires_init.data.fill_(False) 40 | self.init_data_dependent(x) 41 | 42 | size = [1] * x.ndimension() 43 | size[1] = self.num_features 44 | x = (x - self.bias.view(*size)) * torch.exp(self.logs.view(*size)) 45 | if not self.flow: 46 | return x 47 | log_det_jac += self.logs.sum() * num_pixels(x) 48 | return x, log_det_jac, z 49 | 50 | def g(self, x, z): 51 | size = [1] * x.ndimension() 52 | size[1] = 
self.num_features
53 |         x = x * torch.exp(-self.logs.view(*size)) + self.bias.view(*size)
54 |         return x, z
55 | 
56 |     def inverse(self, x):
57 |         size = [1] * x.ndimension()
58 |         size[1] = self.num_features
59 |         x = x * torch.exp(-self.logs.view(*size)) + self.bias.view(*size)
60 |         return x
61 | 
62 |     def log_det(self):
63 |         return self._log_det
64 | 
65 |     def extra_repr(self):
66 |         return 'ActNorm({}, requires_init={})'.format(self.num_features, bool(self.requires_init.item()))
67 | 
68 | 
69 | class DummyCondActNorm(ActNorm):
70 |     def forward(self, x, y, log_det_jac=None, z=None):
71 |         return super().forward(x, log_det_jac, z)
72 | 
73 |     def g(self, x, y, z):
74 |         return super().g(x, z)
75 | 
--------------------------------------------------------------------------------
/models/realnvp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import distributions
4 | from torch.nn.parameter import Parameter
5 | import numpy as np
6 | from torch.nn import functional as F
7 | from models.normalization import ActNorm
8 | from models.utils import SpaceToDepth, FactorOut
9 | from models.distributions import GMM
10 | class RealNVPold(nn.Module):
11 |     def __init__(self, nets, nett, masks, prior, device=None):
12 |         super().__init__()
13 | 
14 |         self.prior = prior
15 |         self.mask = nn.Parameter(masks, requires_grad=False)
16 |         self.t = torch.nn.ModuleList([nett() for _ in range(len(masks))])
17 |         self.s = torch.nn.ModuleList([nets() for _ in range(len(masks))])
18 | 
19 |         self.to(device)
20 |         self.device = device
21 | 
22 |     def g(self, z):
23 |         x = z
24 |         for i in range(len(self.t)):
25 |             x_ = x*self.mask[i]
26 |             s = self.s[i](x_)*(1 - self.mask[i])
27 |             t = self.t[i](x_)*(1 - self.mask[i])
28 |             x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
29 |         return x
30 | 
31 |     def f(self, x):
32 |         log_det_J, z = x.new_zeros(x.shape[0]), x
33 |         for i in reversed(range(len(self.t))):
34 |             z_ = self.mask[i] * z
35 |             s = self.s[i](z_) * (1-self.mask[i])
36 |             t = self.t[i](z_) * (1-self.mask[i])
37 |             z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
38 |             if x.dim() == 2:
39 |                 log_det_J -= s.sum(dim=1)
40 |             else:
41 |                 log_det_J -= s.sum(dim=(1, 2, 3))
42 | 
43 |         return z, log_det_J
44 | 
45 |     def log_prob(self, x):
46 |         z, logp = self.f(x)
47 |         return self.prior.log_prob(z) + logp
48 | 
49 |     def sample(self, batchSize):
50 |         z = self.prior.sample((batchSize, 1))
51 |         logp = self.prior.log_prob(z)
52 |         x = self.g(z)
53 |         return x
54 | 
55 | 
56 | def get_toy_nvp(prior=None, device=None):
57 |     def nets():
58 |         return nn.Sequential(nn.Linear(2, 32),
59 |                              nn.LeakyReLU(),
60 |                              nn.Linear(32, 2),
61 |                              nn.Tanh()
62 |                              )
63 | 
64 |     def nett():
65 |         return nn.Sequential(nn.Linear(2, 32),
66 |                              nn.LeakyReLU(),
67 |                              nn.Linear(32, 2)
68 |                              )
69 | 
70 |     if prior is None:
71 |         prior = distributions.MultivariateNormal(torch.zeros(2).to(device),
72 |                                                  torch.eye(2).to(device))
73 | 
74 |     masks = torch.from_numpy(np.array([[0, 1], [1, 0]] * 3).astype(np.float32))
75 |     return RealNVPold(nets, nett, masks, prior, device=device)
76 | 
77 | 
78 | class NFGMM(RealNVPold):
79 |     def log_prob(self, x, k=None):
80 |         if k is None:
81 |             z, logp = self.f(x)
82 |             return self.prior.log_prob(z) + logp
83 |         else:
84 |             z, logp = self.f(x)
85 |             return self.prior.log_prob(z, k=k) + logp
86 | 
87 | 
88 | def gmm_prior(k, dim=2, normalize=True):
89 |     covars = torch.rand(k, dim, dim)
90 |     covars = torch.matmul(covars, covars.transpose(1, 2))
91 |     return GMM(means=torch.randn(k, dim), covariances=covars,
92 |                weights=torch.FloatTensor([1. / k] * k), normalize=normalize)
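# Minimal usage sketch for get_toy_nvp above (illustrative; the data tensor and
# hyperparameters are placeholders):
#   model = get_toy_nvp(device=torch.device('cpu'))
#   opt = torch.optim.Adam(model.parameters(), lr=1e-3)
#   x = torch.randn(128, 2)           # stand-in for a 2-d toy dataset
#   loss = -model.log_prob(x).mean()  # maximum-likelihood objective
#   loss.backward()
#   opt.step()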
93 | 
94 | 
95 | class WNConv2d(nn.Conv2d):
96 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
97 |                  padding=0, dilation=1, groups=1, bias=True):
98 |         super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
99 |                          dilation=dilation, groups=groups, bias=bias)
100 |         self.scale = nn.Parameter(torch.ones((1,)))
101 |         self.scale.reg = True
102 |         self.eps = 1e-6
103 | 
104 |     def forward(self, input):
105 |         w = self.weight / (torch.norm(self.weight) + self.eps) * self.scale
106 |         return F.conv2d(input, w, self.bias, self.stride,
107 |                         self.padding, self.dilation, self.groups)
108 | 
109 | 
110 | class MABatchNorm2d(nn.BatchNorm2d):
111 |     def forward(self, input):
112 |         # Code from PyTorch repo:
113 |         self._check_input_dim(input)
114 | 
115 |         exponential_average_factor = 0.0
116 | 
117 |         if self.training and self.track_running_stats:
118 |             # TODO: if statement only here to tell the jit to skip emitting this when it is None
119 |             if self.num_batches_tracked is not None:
120 |                 self.num_batches_tracked += 1
121 |                 if self.momentum is None:  # use cumulative moving average
122 |                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)
123 |                 else:  # use exponential moving average
124 |                     exponential_average_factor = self.momentum
125 | 
126 |         # --- My code ---
127 | 
128 |         if self.training:
129 |             mean = torch.mean(input, (0, 2, 3), keepdim=True)
130 |             var = torch.mean((input - mean)**2, (0, 2, 3), keepdim=True)
131 |             self.running_mean.data = self.running_mean.data * (1 - self.momentum) + self.momentum * mean.data.squeeze()
132 |             self.running_var.data = self.running_var.data * (1 - self.momentum) + self.momentum * var.data.squeeze()
133 |             mean = mean * self.momentum + (1 - self.momentum) * self.running_mean[None, :, None, None]
134 |             var = var * self.momentum + (1 - self.momentum) * self.running_var[None, :, None, None]
135 |         else:
136 |             mean = self.running_mean[None, :, None, None]
137 |             var = self.running_var[None, :, None, None]
138 | 
139 |         input = (input - mean) / torch.sqrt(var + self.eps)
140 |         input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
141 | 
142 |         return input
143 | 
144 | 
145 | class ResNetBlock(nn.Module):
146 |     def __init__(self, channels, use_bn=True):
147 |         super().__init__()
148 |         modules = []
149 |         if use_bn:
150 |             modules.append(nn.BatchNorm2d(channels))
151 |         modules += [
152 |             nn.ReLU(),
153 |             nn.ReflectionPad2d(1),
154 |             WNConv2d(channels, channels, 3)]
155 |         if use_bn:
156 |             modules.append(nn.BatchNorm2d(channels))
157 |         modules += [
158 |             nn.ReLU(),
159 |             nn.ReflectionPad2d(1),
160 |             WNConv2d(channels, channels, 3)]
161 | 
162 |         self.net = nn.Sequential(*modules)
163 | 
164 |     def forward(self, x):
165 |         return self.net(x) + x
166 | 
167 | 
168 | class SplitAndNorm(nn.Module):
169 |     def __init__(self):
170 |         super().__init__()
171 |         self.scale = nn.Parameter(data=torch.FloatTensor([1.]))
172 |         self.scale.reg = True
173 | 
174 |     def forward(self, x):
175 |         k = x.shape[1] // 2
176 |         s, t = x[:, :k], x[:, k:]
177 |         return torch.tanh(s) * self.scale, t
178 | 
179 | 
180 | def get_mask(xs, mask_type):
181 |     if 'checkerboard' in mask_type:
182 |         unit0 = np.array([[0.0, 1.0], [1.0, 0.0]])
183 |         unit1 = -unit0 + 1.0
184 |         unit = unit0 if mask_type == 'checkerboard0' else unit1
185 |         unit = np.reshape(unit, [1, 2, 2])
186 |         b = np.tile(unit, [xs[0], xs[1]//2, xs[2]//2])
187 |     elif 'channel' in mask_type:
188 |         white = np.ones([xs[0]//2, xs[1], xs[2]])
189 |         black = np.zeros([xs[0]//2, xs[1], xs[2]])
190 |         if mask_type == 
'channel0': 191 | b = np.concatenate([white, black], 0) 192 | else: 193 | b = np.concatenate([black, white], 0) 194 | 195 | assert list(b.shape) == list(xs) 196 | 197 | return b 198 | 199 | 200 | class CouplingLayer(nn.Module): 201 | def __init__(self, mask_type, shape, net): 202 | super().__init__() 203 | mask = torch.FloatTensor(get_mask(shape, mask_type)) 204 | self.mask = nn.Parameter(mask[None], requires_grad=False) 205 | self.net = net 206 | 207 | def forward(self, x, log_det_jac, z): 208 | return self.f(x, log_det_jac, z) 209 | 210 | def f(self, x, log_det_jac, z): 211 | x1 = self.mask * x 212 | s, t = self.net(x1) 213 | s = (1 - self.mask) * s 214 | t = (1 - self.mask) * t 215 | x = x1 + (1 - self.mask) * (x * torch.exp(s) + t) 216 | log_det_jac += torch.sum(s, dim=(1, 2, 3)) 217 | return x, log_det_jac, z 218 | 219 | def g(self, x, z): 220 | x1 = self.mask * x 221 | s, t = self.net(x1) 222 | x = x1 + (1 - self.mask) * (x - t) * torch.exp(-s) 223 | return x, z 224 | 225 | 226 | class Invertable1x1Conv(nn.Module): 227 | # reference https://github.com/openai/glow/blob/eaff2177693a5d84a1cf8ae19e8e0441715b82f8/model.py#L438 228 | def __init__(self, channels): 229 | super().__init__() 230 | # Sample a random orthogonal matrix 231 | w_init = np.linalg.qr(np.random.randn(channels, channels))[0] 232 | self.weight = nn.Parameter(torch.FloatTensor(w_init)) 233 | 234 | def forward(self, x, log_det_jac, z): 235 | x = F.conv2d(x, self.weight[:, :, None, None]) 236 | log_det_jac += torch.logdet(self.weight) * np.prod(x.shape[2:]) 237 | return x, log_det_jac, z 238 | 239 | def g(self, x, z): 240 | x = F.conv2d(x, torch.inverse(self.weight)[:, :, None, None]) 241 | return x, z 242 | 243 | 244 | class Housholder1x1Conv(nn.Module): 245 | def __init__(self, channels): 246 | super().__init__() 247 | self.v = nn.Parameter(torch.ones((channels,))) 248 | self.id = nn.Parameter(torch.eye(channels), requires_grad=False) 249 | self.channels = channels 250 | 251 | def forward(self, x, log_det_jac, z): 252 | v = self.v 253 | w = self.id - 2 * v[:, None] @ v[None] / (v @ v) 254 | x = F.conv2d(x, w[..., None, None]) 255 | # w is unitary so log_det = 0 256 | return x, log_det_jac, z 257 | 258 | def g(self, x, z): 259 | v = self.v 260 | w = self.id - 2 * v[:, None] @ v[None] / (v @ v) 261 | x = F.conv2d(x, w[..., None, None]) 262 | return x, z 263 | 264 | 265 | class Prior(nn.Module): 266 | def __init__(self, dim): 267 | super().__init__() 268 | self.mean = nn.Parameter(torch.zeros((dim,)), requires_grad=False) 269 | self.cov = nn.Parameter(torch.eye(dim), requires_grad=False) 270 | 271 | def log_prob(self, x): 272 | p = torch.distributions.MultivariateNormal(self.mean, self.cov) 273 | return p.log_prob(x) 274 | 275 | 276 | class RealNVP(nn.Module): 277 | def __init__(self, modules, dim): 278 | super().__init__() 279 | self.modules_ = nn.ModuleList(modules) 280 | self.latent_len = -1 281 | self.x_shape = -1 282 | self.prior = Prior(dim) 283 | self.alpha = 0.05 284 | 285 | def f(self, x): 286 | x = x * (1 - self.alpha) + self.alpha * 0.5 287 | log_det_jac = torch.sum(-torch.log(x) - torch.log(1-x) + np.log(1 - self.alpha), dim=[1, 2, 3]) 288 | x = torch.log(x) - torch.log(1-x) 289 | 290 | z = None 291 | for m in self.modules_: 292 | x, log_det_jac, z = m(x, log_det_jac, z) 293 | if z is None: 294 | z = torch.zeros((x.shape[0], 1))[:, :0].to(x.device) 295 | self.x_shape = list(x.shape)[1:] 296 | self.latent_len = z.shape[1] 297 | z = torch.cat([z, x.reshape((x.shape[0], -1))], dim=1) 298 | return x, log_det_jac, z 
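    # The preprocessing above maps x in (0, 1) to logit(u) with
    # u = alpha/2 + (1 - alpha) * x; per element this contributes
    # log(1 - alpha) - log(u) - log(1 - u) to the log-determinant, which is the
    # log_det_jac initialization summed over (C, H, W) before the flow modules run.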
299 | 300 | def forward(self, x): 301 | return self.log_prob(x) 302 | 303 | def g(self, z): 304 | x = z[:, self.latent_len:].view([z.shape[0]] + self.x_shape) 305 | z = z[:, :self.latent_len] 306 | for m in reversed(self.modules_): 307 | x, z = m.g(x, z) 308 | x = torch.sigmoid(x) 309 | x = (x - self.alpha * 0.5) / (1. - self.alpha) 310 | return x 311 | 312 | def log_prob(self, x): 313 | x, log_det_jac, z = self.f(x) 314 | logp = self.prior.log_prob(z) + log_det_jac 315 | return logp 316 | 317 | 318 | def get_cifar_realnvp(): 319 | dim = 32**2 * 3 320 | channels = 64 321 | 322 | def get_net(in_channels, channels): 323 | net = nn.Sequential( 324 | nn.ReflectionPad2d(1), 325 | WNConv2d(in_channels, channels, 3), 326 | ResNetBlock(channels), 327 | ResNetBlock(channels), 328 | ResNetBlock(channels), 329 | ResNetBlock(channels), 330 | ResNetBlock(channels), 331 | ResNetBlock(channels), 332 | ResNetBlock(channels), 333 | ResNetBlock(channels), 334 | nn.BatchNorm2d(channels), 335 | nn.ReLU(), 336 | nn.ReflectionPad2d(1), 337 | WNConv2d(channels, in_channels * 2, 3), 338 | SplitAndNorm() 339 | ) 340 | for m in net.modules(): 341 | if isinstance(m, nn.Conv2d): 342 | m.weight.data.normal_(std=1e-6) 343 | m.scale.data.fill_(1e-5) 344 | return net 345 | 346 | model = [ 347 | CouplingLayer('checkerboard0', [3, 32, 32], get_net(3, channels)), 348 | CouplingLayer('checkerboard1', [3, 32, 32], get_net(3, channels)), 349 | CouplingLayer('checkerboard0', [3, 32, 32], get_net(3, channels)), 350 | SpaceToDepth(2), 351 | CouplingLayer('channel0', [12, 16, 16], get_net(12, channels)), 352 | CouplingLayer('channel1', [12, 16, 16], get_net(12, channels)), 353 | CouplingLayer('channel0', [12, 16, 16], get_net(12, channels)), 354 | FactorOut([12, 16, 16]), 355 | CouplingLayer('checkerboard0', [6, 16, 16], get_net(6, channels)), 356 | CouplingLayer('checkerboard1', [6, 16, 16], get_net(6, channels)), 357 | CouplingLayer('checkerboard0', [6, 16, 16], get_net(6, channels)), 358 | CouplingLayer('checkerboard1', [6, 16, 16], get_net(6, channels)), 359 | ] 360 | realnvp = RealNVP(model, dim) 361 | return realnvp 362 | 363 | 364 | def get_mnist_realnvp(): 365 | dim = 28**2 366 | channels = 32 367 | 368 | def get_net(in_channels, channels): 369 | net = nn.Sequential( 370 | nn.ReflectionPad2d(1), 371 | WNConv2d(in_channels, channels, 3), 372 | ResNetBlock(channels), 373 | ResNetBlock(channels), 374 | ResNetBlock(channels), 375 | nn.BatchNorm2d(channels), 376 | nn.ReLU(), 377 | nn.ReflectionPad2d(1), 378 | WNConv2d(channels, in_channels * 2, 3), 379 | SplitAndNorm() 380 | ) 381 | for m in net.modules(): 382 | if isinstance(m, nn.Conv2d): 383 | m.weight.data.normal_(std=1e-6) 384 | m.scale.data.fill_(1e-5) 385 | return net 386 | 387 | model = [ 388 | CouplingLayer('checkerboard0', [1, 28, 28], get_net(1, channels)), 389 | CouplingLayer('checkerboard1', [1, 28, 28], get_net(1, channels)), 390 | CouplingLayer('checkerboard0', [1, 28, 28], get_net(1, channels)), 391 | SpaceToDepth(2), 392 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)), 393 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)), 394 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)), 395 | FactorOut([4, 14, 14]), 396 | CouplingLayer('checkerboard0', [2, 14, 14], get_net(2, channels)), 397 | CouplingLayer('checkerboard1', [2, 14, 14], get_net(2, channels)), 398 | CouplingLayer('checkerboard0', [2, 14, 14], get_net(2, channels)), 399 | CouplingLayer('checkerboard1', [2, 14, 14], get_net(2, channels)), 400 | ] 401 | realnvp = 
RealNVP(model, dim) 402 | return realnvp 403 | 404 | 405 | class ConcatNet(nn.Module): 406 | def __init__(self, in_channels, channels): 407 | super().__init__() 408 | self.net1 = nn.Sequential( 409 | nn.Conv2d(in_channels, channels, 3, padding=1), 410 | nn.ReLU(True), 411 | nn.Conv2d(channels, channels, 1), 412 | nn.ReLU(True), 413 | nn.Conv2d(channels, in_channels, 3, padding=1), 414 | ) 415 | self.net2 = nn.Sequential( 416 | nn.Conv2d(in_channels, channels, 3, padding=1), 417 | nn.ReLU(True), 418 | nn.Conv2d(channels, channels, 1), 419 | nn.ReLU(True), 420 | nn.Conv2d(channels, in_channels, 3, padding=1), 421 | ) 422 | self.split = SplitAndNorm() 423 | 424 | def forward(self, x): 425 | x = torch.cat([self.net1(x), self.net2(x)], dim=1) 426 | return self.split(x) 427 | 428 | 429 | def get_pie(channels=32): 430 | def get_net(in_channels, channels): 431 | net = ConcatNet(in_channels, channels) 432 | for m in net.modules(): 433 | if isinstance(m, nn.Conv2d): 434 | m.weight.data.normal_(std=1e-6) 435 | m.bias.data.fill_(0.) 436 | return net 437 | 438 | model = [ 439 | SpaceToDepth(2), 440 | Housholder1x1Conv(4), 441 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)), 442 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)), 443 | ActNorm(4), 444 | Housholder1x1Conv(4), 445 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)), 446 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)), 447 | ActNorm(4), 448 | Housholder1x1Conv(4), 449 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)), 450 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)), 451 | ActNorm(4), 452 | Housholder1x1Conv(4), 453 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)), 454 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)), 455 | ActNorm(4), 456 | Housholder1x1Conv(4), 457 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)), 458 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)), 459 | ActNorm(4), 460 | Housholder1x1Conv(4), 461 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)), 462 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)), 463 | ActNorm(4), 464 | Housholder1x1Conv(4), 465 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)), 466 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)), 467 | ActNorm(4), 468 | Housholder1x1Conv(4), 469 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)), 470 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)), 471 | ActNorm(4), 472 | Housholder1x1Conv(4), 473 | SpaceToDepth(14), 474 | Housholder1x1Conv(784), 475 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)), 476 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)), 477 | ActNorm(784), 478 | Housholder1x1Conv(784), 479 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)), 480 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)), 481 | ActNorm(784), 482 | Housholder1x1Conv(784), 483 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)), 484 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)), 485 | ActNorm(784), 486 | Housholder1x1Conv(784), 487 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)), 488 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)), 489 | ActNorm(784), 490 | Housholder1x1Conv(784), 491 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)), 492 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)), 493 | 
ActNorm(784),
494 |         Housholder1x1Conv(784),
495 |         CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
496 |         CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
497 |         ActNorm(784),
498 |         Housholder1x1Conv(784),
499 |         CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
500 |         CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
501 |         ActNorm(784),
502 |         Housholder1x1Conv(784),
503 |         CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
504 |         CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
505 |         ActNorm(784),
506 |         Housholder1x1Conv(784),
507 |     ]
508 | 
509 |     dim = 784
510 |     realnvp = RealNVP(model, dim)
511 |     return realnvp
512 | 
513 | 
514 | def get_realnvp(k, l, in_shape, channels, use_bn=False):
515 |     dim = int(np.prod(in_shape))
516 | 
517 |     def get_net(in_channels, channels):
518 |         net = nn.Sequential(
519 |             nn.ReflectionPad2d(1),
520 |             WNConv2d(in_channels, channels, 3),
521 |             ResNetBlock(channels, use_bn),
522 |             ResNetBlock(channels, use_bn),
523 |             ResNetBlock(channels, use_bn),
524 |             nn.BatchNorm2d(channels),
525 |             nn.ReLU(),
526 |             nn.ReflectionPad2d(1),
527 |             WNConv2d(channels, in_channels * 2, 3),
528 |             SplitAndNorm()
529 |         )
530 |         for m in net.modules():
531 |             if isinstance(m, nn.Conv2d):
532 |                 m.weight.data.normal_(std=1e-6)
533 |                 m.scale.data.fill_(1e-5)
534 |         return net
535 | 
536 |     shape = tuple(in_shape)
537 |     model = []
538 |     for _ in range(l):
539 |         for i in range(k):
540 |             model.append(Housholder1x1Conv(shape[0]))
541 |             model.append(CouplingLayer('checkerboard{}'.format(i % 2), shape, get_net(shape[0], channels)))
542 |         model += [SpaceToDepth(2)]
543 |         shape = (shape[0] * 4, shape[1] // 2, shape[2] // 2)
544 |         for i in range(k):
545 |             model.append(Housholder1x1Conv(shape[0]))
546 |             model.append(CouplingLayer('channel{}'.format(i % 2), shape, get_net(shape[0], channels)))
547 |         model += [FactorOut(list(shape))]
548 |         shape = (shape[0] // 2, shape[1], shape[2])
549 | 
550 |     model += [
551 |         CouplingLayer('checkerboard0', shape, get_net(shape[0], channels)),
552 |         CouplingLayer('checkerboard1', shape, get_net(shape[0], channels)),
553 |         CouplingLayer('checkerboard0', shape, get_net(shape[0], channels)),
554 |         CouplingLayer('checkerboard1', shape, get_net(shape[0], channels)),
555 |     ]
556 |     realnvp = RealNVP(model, dim)
557 |     return realnvp
558 | 
559 | 
560 | class MyModel(nn.Module):
561 |     def __init__(self, flow, prior):
562 |         super().__init__()
563 |         self.flow = flow
564 |         self.prior = prior
565 | 
566 |     def _flow_term(self, x):
567 |         _, log_det, z = self.flow([x, None, None])
568 | 
569 |         # TODO: get rid of this
570 |         if z.numel() != x.numel():
571 |             log_det += self.flow.pie.residual()
572 | 
573 |         logp = log_det  # flow log-determinant only; the prior term is added by the callers below
574 |         return logp, z
575 | 
576 |     def log_prob(self, x):
577 |         logp, z = self._flow_term(x)
578 |         return logp + self.prior.log_prob(z)
579 | 
580 |     def log_prob_full(self, x):
581 |         logp, z = self._flow_term(x)
582 |         log_prior = torch.stack([self.prior.log_prob(z, k=k) for k in range(self.prior.k)])
583 |         return logp[:, None] + log_prior.transpose(0, 1)
584 | 
585 | 
586 | class MyPie(nn.Module):
587 |     def __init__(self, pie):
588 |         super().__init__()
589 |         self.pie = pie
590 | 
591 |     def forward(self, x):
592 |         x, _, _ = x
593 |         z = self.pie(x)
594 |         return None, self.pie.log_det(), z
595 | 
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.distributions as dist
5 | import torch.nn.functional as F
6 | 
7 | 
8 | class Conv2dZeros(nn.Conv2d):
9 |     def __init__(self, in_channels, out_channels, kernel_size, padding):
10 |         super().__init__(in_channels, out_channels, kernel_size, padding=padding)
11 |         self.weight.data.zero_()
12 |         self.bias.data.zero_()
13 | 
14 | 
15 | class SpaceToDepth(nn.Module):
16 |     def __init__(self, block_size):
17 |         super().__init__()
18 |         self.block_size = block_size
19 |         self.block_size_sq = block_size*block_size
20 | 
21 |     def forward(self, input, *inputs):
22 |         output = input.permute(0, 2, 3, 1)
23 |         (batch_size, s_height, s_width, s_depth) = output.size()
24 |         d_depth = s_depth * self.block_size_sq
25 |         d_width = int(s_width / self.block_size)
26 |         d_height = int(s_height / self.block_size)
27 |         t_1 = output.split(self.block_size, 2)
28 |         stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1]
29 |         output = torch.stack(stack, 1)
30 |         output = output.permute(0, 2, 1, 3)
31 |         output = output.permute(0, 3, 1, 2)
32 |         return [output] + list(inputs)
33 | 
34 |     def g(self, input, *inputs):
35 |         output = input.permute(0, 2, 3, 1)
36 |         (batch_size, input_height, input_width, input_depth) = output.size()
37 |         output_depth = int(input_depth / self.block_size_sq)
38 |         output_width = int(input_width * self.block_size)
39 |         output_height = int(input_height * self.block_size)
40 |         t_1 = output.reshape(batch_size, input_height, input_width, self.block_size_sq, output_depth)
41 |         spl = t_1.split(self.block_size, 3)
42 |         stacks = [t_t.reshape(batch_size, input_height, output_width, output_depth) for t_t in spl]
43 |         output = torch.stack(stacks, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).reshape(batch_size,
44 |                                                                                        output_height,
45 |                                                                                        output_width,
46 |                                                                                        output_depth)
47 |         output = output.permute(0, 3, 1, 2)
48 |         return [output] + list(inputs)
49 | 
50 |     def extra_repr(self):
51 |         return 'SpaceToDepth({0:d}x{0:d})'.format(self.block_size)
52 | 
53 | 
54 | class CondSpaceToDepth(SpaceToDepth):
55 |     def forward(self, x, y, log_det, z):
56 |         return super().forward(x, log_det, z)
57 | 
58 |     def g(self, x, y, z):
59 |         return super().g(x, z)
60 | 
61 | 
62 | class FactorOut(nn.Module):
63 |     def __init__(self):
64 |         super().__init__()
65 |         self.out_shape = None
66 | 
67 |     def forward(self, x, log_det_jac, z):
68 |         self.out_shape = list(x.shape)[1:]
69 |         self.inp_shape = list(x.shape)[1:]
70 |         self.out_shape[0] = self.out_shape[0] // 2
71 |         self.out_shape[0] += self.out_shape[0] % 2
72 | 
73 |         k = self.out_shape[0]
74 |         if z is None:
75 |             return x[:, k:], log_det_jac, x[:, :k].reshape((x.shape[0], -1))
76 |         z = torch.cat([z, x[:, :k].view((x.shape[0], -1))], dim=1)
77 |         return x[:, k:], log_det_jac, z
78 | 
79 |     def g(self, x, z):
80 |         k = np.prod(self.out_shape)
81 |         x = torch.cat([z[:, -k:].view([x.shape[0]] + self.out_shape), x], dim=1)
82 |         z = z[:, :-k]
83 |         return x, z
84 | 
85 |     def extra_repr(self):
86 |         return 'FactorOut({:s} -> {:s})'.format(str(self.inp_shape), str(self.out_shape))
87 | 
88 | 
89 | class CondFactorOut(nn.Module):
90 |     def __init__(self):
91 |         super().__init__()
92 |         self.out_shape = None
93 | 
94 |     def extra_repr(self):
95 |         return 'FactorOut({:s} -> {:s})'.format(str(self.inp_shape), str(self.out_shape))
96 | 
97 |     def forward(self, x, y, log_det_jac, z):
98 |         self.out_shape = list(x.shape)[1:]
99 |         self.inp_shape = list(x.shape)[1:]
100 |         self.out_shape[0] = self.out_shape[0] // 2
101 |         self.out_shape[0] += self.out_shape[0] % 2
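        # Roughly half of the channel maps are factored out below: the first
        # k channels are flattened into the running latent z, while the rest
        # keep flowing through subsequent layers (same scheme as FactorOut above).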
102 | 103 | k = self.out_shape[0] 104 | if z is None: 105 | return x[:, k:], log_det_jac, x[:, :k].reshape((x.shape[0], -1)) 106 | z = torch.cat([z, x[:, :k].view((x.shape[0], -1))], dim=1) 107 | return x[:, k:], log_det_jac, z 108 | 109 | def g(self, x, y, z): 110 | k = np.prod(self.out_shape) 111 | x = torch.cat([z[:, -k:].view([x.shape[0]] + self.out_shape), x], dim=1) 112 | z = z[:, :-k] 113 | return x, z 114 | 115 | 116 | class ToLogits(nn.Module): 117 | ''' 118 | Maps interval [0, 1] to (-inf, +inf) via inversion of sigmoid 119 | ''' 120 | alpha = 0.05 121 | 122 | def forward(self, x, log_det_jac, z): 123 | # [0, 1] -> [alpha/2, 1 - alpha/2] 124 | x = x * (1 - self.alpha) + self.alpha * 0.5 125 | log_det_jac += torch.sum(-torch.log(x) - torch.log(1-x) + np.log(1 - self.alpha), dim=[1, 2, 3]) 126 | x = torch.log(x) - torch.log(1-x) 127 | return x, log_det_jac, z 128 | 129 | def g(self, x, z): 130 | x = torch.sigmoid(x) 131 | x = (x - self.alpha * 0.5) / (1. - self.alpha) 132 | return x, z 133 | 134 | def extra_repr(self): 135 | return 'ToLogits()' 136 | 137 | 138 | class InverseLogits(nn.Module): 139 | def forward(self, x, log_det_jac, z): 140 | log_det_jac += torch.sum(-F.softplus(-x) - F.softplus(x), dim=[1, 2, 3]) 141 | x = torch.sigmoid(x) 142 | return x, log_det_jac, z 143 | 144 | def g(self, x, z): 145 | x = torch.log(x) - torch.log(1 - x) 146 | return x, z 147 | 148 | def extra_repr(self): 149 | return 'InverseLogits()' 150 | 151 | 152 | class CondToLogits(nn.Module): 153 | ''' 154 | Maps interval [0, 1] to (-inf, +inf) via inversion of sigmoid 155 | ''' 156 | alpha = 0.05 157 | 158 | def forward(self, x, y, log_det_jac, z): 159 | # [0, 1] -> [alpha/2, 1 - alpha/2] 160 | x = x * (1 - self.alpha) + self.alpha * 0.5 161 | log_det_jac += torch.sum(-torch.log(x) - torch.log(1-x) + np.log(1 - self.alpha), dim=[1, 2, 3]) 162 | x = torch.log(x) - torch.log(1-x) 163 | return x, log_det_jac, z 164 | 165 | def g(self, x, y, z): 166 | x = torch.sigmoid(x) 167 | x = (x - self.alpha * 0.5) / (1. - self.alpha) 168 | return x, z 169 | 170 | 171 | class DummyCond(nn.Module): 172 | def __init__(self, module): 173 | super().__init__() 174 | self.module = module 175 | 176 | def forward(self, x, y, log_det_jac, z): 177 | return self.module.forward(x, log_det_jac, z) 178 | 179 | def g(self, x, y, z): 180 | return self.module.g(x, z) 181 | 182 | 183 | class IdFunction(nn.Module): 184 | def forward(self, *inputs): 185 | return inputs 186 | 187 | def g(self, *inputs): 188 | return inputs 189 | 190 | 191 | class UniformWithLogits(dist.Distribution): 192 | def __init__(self, dim): 193 | super().__init__() 194 | self.dim = dim 195 | 196 | def log_prob(self, x): 197 | return torch.sum(-F.softplus(-x) - F.softplus(x), dim=1) 198 | 199 | def sample(self, shape): 200 | x = torch.rand(list(shape) + [self.dim]) 201 | return torch.log(x) - torch.log(1 - x) 202 | -------------------------------------------------------------------------------- /myexman/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser import ( 2 | ExParser, 3 | simpleroot 4 | ) 5 | from .index import ( 6 | Index 7 | ) 8 | from . import index 9 | from . 
import parser
10 | __version__ = '0.0.2'
11 | 
--------------------------------------------------------------------------------
/myexman/index.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | import pandas as pd
3 | import pathlib
4 | import strconv
5 | import json
6 | import functools
7 | import datetime
8 | from . import parser
9 | import yaml
10 | from argparse import Namespace
11 | __all__ = [
12 |     'Index'
13 | ]
14 | 
15 | 
16 | def only_value_error(conv):
17 |     @functools.wraps(conv)
18 |     def new_conv(value):
19 |         try:
20 |             return conv(value)
21 |         except Exception as e:
22 |             raise ValueError from e
23 |     return new_conv
24 | 
25 | 
26 | def none2none(none):
27 |     if none is None:
28 |         return None
29 |     else:
30 |         raise ValueError
31 | 
32 | 
33 | converter = strconv.Strconv(converters=[
34 |     ('int', strconv.convert_int),
35 |     ('float', strconv.convert_float),
36 |     ('bool', only_value_error(parser.str2bool)),
37 |     ('time', strconv.convert_time),
38 |     ('datetime', strconv.convert_datetime),
39 |     ('datetime1', lambda time: datetime.datetime.strptime(time, parser.TIME_FORMAT)),
40 |     ('date', strconv.convert_date),
41 |     ('json', only_value_error(json.loads)),
42 | ])
43 | 
44 | 
45 | def get_args(path):
46 |     with open(path, 'rb') as f:
47 |         return Namespace(**yaml.safe_load(f))  # params.yaml contains only plain YAML
48 | 
49 | 
50 | class Index(object):
51 |     def __init__(self, root):
52 |         self.root = pathlib.Path(root)
53 | 
54 |     @property
55 |     def index(self):
56 |         return self.root / 'index'
57 | 
58 |     @property
59 |     def marked(self):
60 |         return self.root / 'marked'
61 | 
62 |     def info(self, source=None):
63 |         if source is None:
64 |             source = self.index
65 |             files = source.iterdir()
66 |         else:
67 |             source = self.marked / source
68 |             files = source.glob('**/*/'+parser.PARAMS_FILE)
69 | 
70 |         def get_dict(cfg):
71 |             return configargparse.YAMLConfigFileParser().parse(cfg.open('r'))
72 | 
73 |         def convert_column(col):
74 |             if any(isinstance(v, str) for v in converter.convert_series(col)):
75 |                 return col
76 |             else:
77 |                 return pd.Series(converter.convert_series(col), name=col.name, index=col.index)
78 |         try:
79 |             df = (pd.DataFrame
80 |                   .from_records((get_dict(c) for c in files))
81 |                   .apply(lambda s: convert_column(s))
82 |                   .sort_values('id')
83 |                   .assign(root=lambda _: _.root.apply(self.root.__truediv__))
84 |                   .reset_index(drop=True))
85 |             cols = df.columns.tolist()
86 |             cols.insert(0, cols.pop(cols.index('id')))
87 |             return df.reindex(columns=cols)
88 |         except FileNotFoundError as e:
89 |             raise KeyError(source.name) from e
90 | 
--------------------------------------------------------------------------------
/myexman/parser.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | import argparse
3 | import pathlib
4 | import datetime
5 | import yaml
6 | import yaml.representer
7 | import os
8 | import functools
9 | import itertools
10 | from filelock import FileLock
11 | __all__ = [
12 |     'ExParser',
13 |     'simpleroot',
14 | ]
15 | 
16 | 
17 | TIME_FORMAT_DIR = '%Y-%m-%d-%H-%M-%S'
18 | TIME_FORMAT = '%Y-%m-%dT%H:%M:%S'
19 | DIR_FORMAT = '{num}-{time}'
20 | EXT = 'yaml'
21 | PARAMS_FILE = 'params.'+EXT
22 | FOLDER_DEFAULT = 'exman'
23 | RESERVED_DIRECTORIES = {
24 |     'runs', 'index',
25 |     'tmp', 'marked'
26 | }
27 | 
28 | 
29 | def yaml_file(name):
30 |     return name + '.' + EXT
31 | 
32 | 
33 | def simpleroot(__file__):
34 |     return pathlib.Path(os.path.dirname(os.path.abspath(__file__)))/FOLDER_DEFAULT
35 | 
36 | 
37 | def represent_as_str(self, data, tostr=str):
38 |     return yaml.representer.Representer.represent_str(self, tostr(data))
39 | 
40 | 
41 | def register_str_converter(*types, tostr=str):
42 |     for T in types:
43 |         yaml.add_representer(T, functools.partial(represent_as_str, tostr=tostr))
44 | 
45 | 
46 | register_str_converter(pathlib.PosixPath, pathlib.WindowsPath)
47 | 
48 | 
49 | def str2bool(s):
50 |     true = ('true', 't', 'yes', 'y', 'on', '1')
51 |     false = ('false', 'f', 'no', 'n', 'off', '0')
52 | 
53 |     if s.lower() in true:
54 |         return True
55 |     elif s.lower() in false:
56 |         return False
57 |     else:
58 |         raise argparse.ArgumentTypeError('{!r} is not a valid bool; expected one of {}'.format(s, true + false))
59 | 
60 | 
61 | class ParserWithRoot(configargparse.ArgumentParser):
62 |     def __init__(self, *args, root=None, zfill=6,
63 |                  **kwargs):
64 |         super().__init__(*args, **kwargs)
65 |         if root is None:
66 |             raise ValueError('Root directory is not specified')
67 |         root = pathlib.Path(root)
68 |         if not root.is_absolute():
69 |             raise ValueError(root, 'Root directory is not absolute path')
70 |         if not root.exists():
71 |             raise ValueError(root, 'Root directory does not exist')
72 |         self.root = pathlib.Path(root)
73 |         self.zfill = zfill
74 |         self.register('type', bool, str2bool)
75 |         for directory in RESERVED_DIRECTORIES:
76 |             getattr(self, directory).mkdir(exist_ok=True)
77 |         self.lock = FileLock(str(self.root/'lock'))
78 | 
79 |     @property
80 |     def runs(self):
81 |         return self.root / 'runs'
82 | 
83 |     @property
84 |     def marked(self):
85 |         return self.root / 'marked'
86 | 
87 |     @property
88 |     def index(self):
89 |         return self.root / 'index'
90 | 
91 |     @property
92 |     def tmp(self):
93 |         return self.root / 'tmp'
94 | 
95 |     def max_ex(self):
96 |         max_num = 0
97 |         for directory in itertools.chain(self.runs.iterdir(), self.tmp.iterdir()):
98 |             num = int(directory.name.split('-', 1)[0])
99 |             if num > max_num:
100 |                 max_num = num
101 |         return max_num
102 | 
103 |     def num_ex(self):
104 |         return len(list(self.runs.iterdir()))
105 | 
106 |     def next_ex(self):
107 |         return self.max_ex() + 1
108 | 
109 |     def next_ex_str(self):
110 |         return str(self.next_ex()).zfill(self.zfill)
111 | 
112 | 
113 | class ExParser(ParserWithRoot):
114 |     """
115 |     Parser responsible for creating the following structure of experiments
116 |     ```
117 |     root
118 |     |-- runs
119 |     |   `-- xxxxxx-YYYY-mm-dd-HH-MM-SS
120 |     |       |-- params.yaml
121 |     |       `-- ...
122 |     |-- index
123 |     |   `-- xxxxxx-YYYY-mm-dd-HH-MM-SS.yaml (symlink)
124 |     |-- marked
125 |     |   `-- 
126 |     |       `-- xxxxxx-YYYY-mm-dd-HH-MM-SS (symlink)
127 |     |           |-- params.yaml
128 |     |           `-- ...
129 |     `-- tmp
130 |         `-- xxxxxx-YYYY-mm-dd-HH-MM-SS
131 |             |-- params.yaml
132 |             `-- ...
133 |     ```
134 |     """
135 |     def __init__(self, *args, zfill=6, file=None,
136 |                  args_for_setting_config_path=('--config', ),
137 |                  automark=(),
138 |                  **kwargs):
139 |         root = os.path.join(os.getcwd(), 'logs', ('exman-' + str(file)))
140 |         if not os.path.exists(root):
141 |             os.makedirs(root)
142 |         super().__init__(*args, root=root, zfill=zfill,
143 |                          args_for_setting_config_path=args_for_setting_config_path,
144 |                          config_file_parser_class=configargparse.YAMLConfigFileParser,
145 |                          ignore_unknown_config_file_keys=True,
146 |                          **kwargs)
147 |         self.automark = automark
148 |         self.add_argument('--tmp', action='store_true')
149 | 
150 |     def _initialize_dir(self, tmp):
151 |         try:
152 |             # with self.lock:  # several processes may create a run at the same time; the FileExistsError retry below resolves the collision
153 |             time = datetime.datetime.now()
154 |             num = self.next_ex_str()
155 |             name = DIR_FORMAT.format(num=num, time=time.strftime(TIME_FORMAT_DIR))
156 |             if tmp:
157 |                 absroot = self.tmp / name
158 |                 relroot = pathlib.Path('tmp') / name
159 |             else:
160 |                 absroot = self.runs / name
161 |                 relroot = pathlib.Path('runs') / name
162 |             # this process now safely owns root directory
163 |             # raises FileExistsError on fail
164 |             absroot.mkdir()
165 |         except FileExistsError:  # a collision can still happen; retry with a fresh number
166 |             return self._initialize_dir(tmp)
167 |         return absroot, relroot, name, time, num
168 | 
169 |     def parse_known_args(self, *args, **kwargs):
170 |         args, argv = super().parse_known_args(*args, **kwargs)
171 |         absroot, relroot, name, time, num = self._initialize_dir(args.tmp)
172 |         args.root = absroot
173 |         self.yaml_params_path = args.root / PARAMS_FILE
174 |         rel_yaml_params_path = pathlib.Path('..', 'runs', name, PARAMS_FILE)
175 |         with self.yaml_params_path.open('a') as f:
176 |             self.dumpd = args.__dict__.copy()
177 |             # dumpd['root'] = relroot
178 |             yaml.dump(self.dumpd, f, default_flow_style=False)
179 |             print("\ntime: '{}'".format(time.strftime(TIME_FORMAT)), file=f)
180 |             print("id:", int(num), file=f)
181 |         print(self.yaml_params_path.read_text())
182 |         symlink = self.index / yaml_file(name)
183 |         if not args.tmp:
184 |             symlink.symlink_to(rel_yaml_params_path)
185 |             print('Created symlink from', symlink, '->', rel_yaml_params_path)
186 |         if self.automark and not args.tmp:
187 |             automark_path_part = pathlib.Path(*itertools.chain.from_iterable(
188 |                 (mark, str(getattr(args, mark, '')))
189 |                 for mark in self.automark))
190 |             markpath = pathlib.Path(self.marked, automark_path_part)
191 |             markpath.mkdir(exist_ok=True, parents=True)
192 |             relpathmark = pathlib.Path('..', *(['..']*len(automark_path_part.parts))) / 'runs' / name
193 |             (markpath / name).symlink_to(relpathmark, target_is_directory=True)
194 |             print('Created symlink from', markpath / name, '->', relpathmark)
195 |         return args, argv
196 | 
--------------------------------------------------------------------------------
/pretrained/model.torch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndrewAtanov/semi-supervised-flow-pytorch/d1748decaccbc59e6bce014e1cb84527173c6b54/pretrained/model.torch
--------------------------------------------------------------------------------
/train-discriminator.py:
--------------------------------------------------------------------------------
1 | import myexman
2 | import torch
3 | from logger import Logger
4 | import torchvision
5 | import os
6 | from torch import nn
7 | import torch.nn.functional as F
8 | import utils
9 | import warnings
10 | import numpy as np
11 | 
12 | 
13 | parser = myexman.ExParser(file=os.path.basename(__file__))
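# ExParser (see myexman/parser.py above) creates a fresh numbered run directory
# under ./logs/exman-train-discriminator.py/runs/ and dumps all parsed arguments
# to params.yaml there; args.root then points at that directory. A typical
# (hypothetical) invocation: python train-discriminator.py --data <emb_dir> --emb f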
14 | parser.add_argument('--name', default='')
15 | # Data
16 | parser.add_argument('--data', default='')
17 | parser.add_argument('--emb')
18 | parser.add_argument('--dim', default=196, type=int)
19 | # Optimization
20 | parser.add_argument('--epochs', default=100, type=int)
21 | parser.add_argument('--lr', default=1e-3, type=float)
22 | args = parser.parse_args()
23 | 
24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25 | 
26 | logger = Logger('logs', base=args.root)
27 | 
28 | if args.emb == 'f':
29 |     emb_train = np.load(os.path.join(args.data, 'zf_train.npy'))[:, -args.dim:]
30 |     emb_test = np.load(os.path.join(args.data, 'zf_test.npy'))[:, -args.dim:]
31 | elif args.emb == 'h':
32 |     emb_train = np.load(os.path.join(args.data, 'zh_train.npy'))
33 |     emb_test = np.load(os.path.join(args.data, 'zh_test.npy'))
34 | else:
35 |     raise NotImplementedError
36 | 
37 | y_train = np.load(os.path.join(args.data, 'y_train.npy'))
38 | y_test = np.load(os.path.join(args.data, 'y_test.npy'))
39 | 
40 | 
41 | trainset = torch.utils.data.TensorDataset(torch.FloatTensor(emb_train - emb_train.mean(0)[None]), torch.LongTensor(y_train))
42 | testset = torch.utils.data.TensorDataset(torch.FloatTensor(emb_test - emb_test.mean(0)[None]), torch.LongTensor(y_test))
43 | 
44 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=256)
45 | testloader = torch.utils.data.DataLoader(testset, batch_size=256)
46 | 
47 | 
48 | net = nn.Sequential(
49 |     nn.Linear(args.dim, 256),
50 |     nn.LeakyReLU(),
51 | 
52 |     nn.Linear(256, 256),
53 |     nn.Dropout(0.3),
54 |     nn.LeakyReLU(),
55 | 
56 |     nn.Linear(256, 256),
57 |     nn.LeakyReLU(),
58 | 
59 |     nn.Linear(256, 256),
60 |     nn.LeakyReLU(),
61 | 
62 |     nn.Linear(256, len(np.unique(y_train))),
63 | ).to(device)
64 | 
65 | opt = torch.optim.Adam(net.parameters(), lr=args.lr)
66 | lr_schedule = utils.LinearLR(opt, args.epochs)
67 | 
68 | for e in range(1, 1 + args.epochs):
69 |     net.train()
70 |     train_loss = 0.
71 |     train_acc = 0.
72 |     for x, y in trainloader:
73 |         x, y = x.to(device), y.to(device)
74 |         p = net(x)
75 |         loss = F.cross_entropy(p, y)
76 |         opt.zero_grad()
77 |         loss.backward()
78 |         opt.step()
79 |         train_loss += loss.item() * x.size(0)
80 |         train_acc += (p.argmax(1) == y).sum().item()
81 | 
82 |     train_loss /= len(trainloader.dataset)
83 |     train_acc /= len(trainloader.dataset)
84 | 
85 |     net.eval()
86 |     x, y = map(lambda a: a.to(device), next(iter(testloader)))  # note: scores only the first test batch (256 examples)
87 |     p = net(x)
88 |     test_acc = (p.argmax(1) == y).float().mean().item()
89 | 
90 |     logger.add_scalar(e, 'train.loss', train_loss)
91 |     logger.add_scalar(e, 'train.acc', train_acc)
92 |     logger.add_scalar(e, 'test.acc', test_acc)
93 |     logger.iter_info()
94 |     logger.save()
95 | 
--------------------------------------------------------------------------------
/train-flow-ssl.py:
--------------------------------------------------------------------------------
1 | import myexman
2 | import torch
3 | import utils
4 | import datautils
5 | import os
6 | from logger import Logger
7 | import time
8 | import numpy as np
9 | from models import flows, distributions
10 | import warnings
11 | import torch.nn.functional as F
12 | import argparse
13 | 
14 | 
15 | def get_metrics(model, loader):
16 |     logp, acc = [], []
17 |     for x, y in loader:
18 |         x = x.to(device)
19 |         log_det, z = model.flow(x)
20 |         log_prior_full = model.prior.log_prob_full(z)
21 |         pred = torch.softmax(log_prior_full, dim=1).argmax(1)
22 |         logp.append(utils.tonp(log_det + model.prior.log_prob(z)))
23 |         acc.append(utils.tonp(pred) == utils.tonp(y))
24 |     return np.mean(np.concatenate(logp)), np.mean(np.concatenate(acc))
25 | 
26 | 
27 | parser = myexman.ExParser(file=os.path.basename(__file__))
28 | parser.add_argument('--name', default='')
29 | parser.add_argument('--seed', default=0, type=int)
30 | # Data
31 | parser.add_argument('--data', default='mnist')
32 | parser.add_argument('--num_examples', default=-1, type=int)
33 | parser.add_argument('--data_seed', default=0, type=int)
34 | parser.add_argument('--sup_sample_weight', default=-1, type=float)
35 | # Optimization
36 | parser.add_argument('--lr', default=1e-3, type=float)
37 | parser.add_argument('--epochs', default=500, type=int)
38 | parser.add_argument('--train_bs', default=256, type=int)
39 | parser.add_argument('--test_bs', default=512, type=int)
40 | parser.add_argument('--lr_schedule', default='hat')
41 | parser.add_argument('--lr_warmup', default=10, type=int)
42 | parser.add_argument('--log_each', default=1, type=int)
43 | parser.add_argument('--pretrained', default='')
44 | parser.add_argument('--weight_decay', default=0., type=float)
45 | # Model
46 | parser.add_argument('--model', default='mnist-masked')
47 | parser.add_argument('--conv', default='full')
48 | parser.add_argument('--hh_factors', default=2, type=int)
49 | parser.add_argument('--k', default=4, type=int)
50 | parser.add_argument('--l', default=2, type=int)
51 | parser.add_argument('--hid_dim', type=int, nargs='*', default=[])
52 | # Prior
53 | parser.add_argument('--ssl_model', default='cond-flow')
54 | parser.add_argument('--ssl_dim', default=-1, type=int)
55 | parser.add_argument('--ssl_l', default=2, type=int)
56 | parser.add_argument('--ssl_k', default=3, type=int)
57 | parser.add_argument('--ssl_hd', default=256, type=int)
58 | parser.add_argument('--ssl_conv', default='full')
59 | parser.add_argument('--ssl_hh', default=2, type=int)
60 | parser.add_argument('--ssl_nclasses', default=10, type=int)
61 | # SSL
62 | parser.add_argument('--supervised', default=0, type=int)
63 | parser.add_argument('--sup_weight', default=1., type=float)
64 | parser.add_argument('--cl_weight', default=0, type=float)
65 | args = parser.parse_args()
66 | 
67 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
68 | 
69 | # TODO: make it changeable
70 | torch.set_num_threads(1)
71 | 
72 | fmt = {
73 |     'time': '.3f',
74 | }
75 | logger = Logger('logs', base=args.root, fmt=fmt)
76 | 
77 | # Load data
78 | np.random.seed(args.data_seed)
79 | torch.manual_seed(args.data_seed)
80 | torch.cuda.manual_seed_all(args.data_seed)
81 | trainloader, testloader, data_shape, bits = datautils.load_dataset(args.data, args.train_bs, args.test_bs,
82 |                                                                    seed=args.data_seed, num_examples=args.num_examples,
83 |                                                                    supervised=args.supervised, logs_root=args.root,
84 |                                                                    sup_sample_weight=args.sup_sample_weight)
85 | # Seed for training process
86 | np.random.seed(args.seed)
87 | torch.manual_seed(args.seed)
88 | torch.cuda.manual_seed_all(args.seed)
89 | 
90 | # Create model
91 | dim = int(np.prod(data_shape))
92 | if args.ssl_dim == -1:
93 |     args.ssl_dim = dim
94 | deep_prior = distributions.GaussianDiag(args.ssl_dim)
95 | shallow_prior = distributions.GaussianDiag(dim - args.ssl_dim)
96 | 
97 | _, c = np.unique(trainloader.dataset.targets[trainloader.dataset.targets != -1], return_counts=True)
98 | yprior = torch.distributions.Categorical(probs=torch.FloatTensor(c/c.sum()).to(device))
99 | ssl_flow = utils.create_cond_flow(args)
100 | # ssl_flow = torch.nn.DataParallel(ssl_flow.to(device))
101 | ssl_flow.to(device)
102 | prior = flows.DiscreteConditionalFlowPDF(ssl_flow, deep_prior, yprior, deep_dim=args.ssl_dim,
103 |                                          shallow_prior=shallow_prior)
104 | 
105 | flow = utils.create_flow(args, data_shape)
106 | flow.to(device)
107 | flow = torch.nn.DataParallel(flow.to(device))
108 | 
109 | model = flows.FlowPDF(flow, prior).to(device)
110 | 
111 | torch.save(model.state_dict(), os.path.join(args.root, 'model_init.torch'))
112 | 
113 | parameters = [
114 |     {'params': [p for p in model.parameters() if p.requires_grad], 'weight_decay': args.weight_decay},
115 | ]
116 | optimizer = torch.optim.Adamax(parameters, lr=args.lr)
117 | if args.lr_schedule == 'no':
118 |     lr_scheduler = utils.BaseLR(optimizer)
119 | elif args.lr_schedule == 'linear':
120 |     lr_scheduler = utils.LinearLR(optimizer, args.epochs)
121 | elif args.lr_schedule == 'hat':
122 |     lr_scheduler = utils.HatLR(optimizer, args.lr_warmup, args.epochs)
123 | else:
124 |     raise NotImplementedError
125 | 
126 | if args.pretrained != '':
127 |     model.load_state_dict(torch.load(args.pretrained))
128 |     # model.load_state_dict(torch.load(os.path.join(args.pretrained, 'model.torch')))
129 |     # optimizer.load_state_dict(torch.load(os.path.join(args.pretrained, 'optimizer.torch')))
130 | 
131 | t0 = time.time()
132 | for epoch in range(1, args.epochs + 1):
133 |     train_loss = 0.
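    # Each batch mixes labelled (y != -1) and unlabelled (y == -1) examples:
    # unlabelled ones contribute the marginal prior term log p(z), labelled
    # ones the class-conditional log p(z | y), plus an optional cross-entropy
    # classification term scaled by --cl_weight.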
134 |     train_acc = utils.MovingMetric()
135 |     train_elbo = utils.MovingMetric()
136 |     train_cl = utils.MovingMetric()
137 | 
138 |     for x, y in trainloader:
139 |         x = x.to(device)
140 |         n_sup = (y != -1).sum().item()
141 | 
142 |         log_det, z = model.flow(x)
143 | 
144 |         log_prior = torch.ones((x.size(0),)).to(x.device)
145 |         if n_sup != z.shape[0]:
146 |             log_prior[y == -1] = model.prior.log_prob(z[y == -1])
147 |         if n_sup != 0:
148 |             log_prior[y != -1] = model.prior.log_prob(z[y != -1], y=y[y != -1].to(x.device))
149 |         elbo = log_det + log_prior
150 | 
151 |         weights = torch.ones((elbo.size(0),)).to(elbo)
152 |         weights[y != -1] = args.sup_weight
153 |         weights /= weights.sum()
154 | 
155 |         gen_loss = -(elbo * weights.detach()).sum()
156 | 
157 |         cl_loss = 0
158 |         if n_sup != 0:
159 |             logp_full = model.prior.log_prob_full(z[y != -1])
160 |             prediction = logp_full
161 |             train_acc.add(utils.tonp(prediction.argmax(1).to(y) == y[y != -1]))
162 |             if args.cl_weight != 0:
163 |                 cl_loss = F.cross_entropy(prediction, y[y != -1].to(prediction.device), reduction='none')
164 |                 train_cl.add(utils.tonp(cl_loss))
165 |                 cl_loss = cl_loss.mean()
166 | 
167 |         loss = gen_loss + args.cl_weight * cl_loss
168 | 
169 |         optimizer.zero_grad()
170 |         loss.backward()
171 |         optimizer.step()
172 | 
173 |         train_elbo.add(utils.tonp(elbo))
174 |         train_loss += loss.item() * x.size(0)
175 | 
176 |     train_loss /= len(trainloader.dataset)
177 |     lr_scheduler.step()
178 | 
179 |     if epoch % args.log_each == 0 or epoch == 1:
180 |         with torch.no_grad():
181 |             test_logp, test_acc = get_metrics(model, testloader)
182 |         logger.add_scalar(epoch, 'train.loss', train_loss)
183 |         logger.add_scalar(epoch, 'train.elbo', train_elbo.avg())
184 |         logger.add_scalar(epoch, 'train.cl', train_cl.avg())
185 |         logger.add_scalar(epoch, 'train.acc', train_acc.avg())
186 |         logger.add_scalar(epoch, 'test.logp', test_logp)
187 |         logger.add_scalar(epoch, 'test.acc', test_acc)
188 |         logger.add_scalar(epoch, 'test.bits/dim', utils.bits_dim(test_logp, dim, bits))
189 |         logger.add_scalar(epoch, 'time', time.time() - t0)
190 |         t0 = time.time()
191 |         logger.iter_info()
192 |         logger.save()
193 | 
194 | torch.save(model.state_dict(), os.path.join(args.root, 'model.torch'))
195 | torch.save(optimizer.state_dict(), os.path.join(args.root, 'optimizer.torch'))
196 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn import datasets
4 | import os
5 | import torchvision
6 | from torchvision import transforms
7 | from sklearn.cluster import MiniBatchKMeans
8 | import matplotlib.pyplot as plt
9 | import warnings
10 | from models import flows, coupling
11 | 
12 | 
13 | def viz_array_grid(array, rows, cols, padding=0, channels_last=False, normalize=False, **kwargs):
14 |     # normalization
15 |     '''
16 |     Args:
17 |         array: (N_images, N_channels, H, W) or (N_images, H, W, N_channels)
18 |         rows, cols: rows and columns of the plot. rows * cols == array.shape[0]
19 |         padding: padding between cells of plot
20 |         channels_last: for Tensorflow = True, for PyTorch = False
21 |         normalize: `False`, `mean_std`, or `min_max`
22 |     Kwargs:
23 |         if normalize == 'mean_std':
24 |             mean: mean of the distribution. Default 0.5
25 |             std: std of the distribution. Default 0.5
26 |         if normalize == 'min_max':
27 |             min: min of the distribution. Default array.min()
28 |             max: max of the distribution. Default array.max()
29 |     '''
30 |     if not channels_last:
31 |         array = np.transpose(array, (0, 2, 3, 1))
32 | 
33 |     array = array.astype('float32')
34 | 
35 |     if normalize:
36 |         if normalize == 'mean_std':
37 |             mean = kwargs.get('mean', 0.5)
38 |             mean = np.array(mean).reshape((1, 1, 1, -1))
39 |             std = kwargs.get('std', 0.5)
40 |             std = np.array(std).reshape((1, 1, 1, -1))
41 |             array = array * std + mean
42 |         elif normalize == 'min_max':
43 |             min_ = kwargs.get('min', array.min())
44 |             min_ = np.array(min_).reshape((1, 1, 1, -1))
45 |             max_ = kwargs.get('max', array.max())
46 |             max_ = np.array(max_).reshape((1, 1, 1, -1))
47 |             array -= min_
48 |             array /= max_ - min_ + 1e-9  # scale by the full value range
49 | 
50 |     batch_size, H, W, channels = array.shape
51 |     assert rows * cols == batch_size
52 | 
53 |     if channels == 1:
54 |         canvas = np.ones((H * rows + padding * (rows - 1),
55 |                           W * cols + padding * (cols - 1)))
56 |         array = array[:, :, :, 0]
57 |     elif channels == 3:
58 |         canvas = np.ones((H * rows + padding * (rows - 1),
59 |                           W * cols + padding * (cols - 1),
60 |                           3))
61 |     else:
62 |         raise TypeError('number of channels is either 1 or 3')
63 | 
64 |     for i in range(rows):
65 |         for j in range(cols):
66 |             img = array[i * cols + j]
67 |             start_h = i * padding + i * H
68 |             start_w = j * padding + j * W
69 |             canvas[start_h: start_h + H, start_w: start_w + W] = img
70 | 
71 |     canvas = np.clip(canvas, 0, 1)
72 |     canvas *= 255.0
73 |     canvas = canvas.astype('uint8')
74 |     return canvas
75 | 
76 | 
77 | def params_norm(parameters):
78 |     sq = 0.
79 |     n = 0
80 |     for p in parameters:
81 |         sq += (p**2).sum()
82 |         n += torch.numel(p)
83 |     return np.sqrt(sq.item() / float(n))
84 | 
85 | 
86 | def tonp(x):
87 |     if isinstance(x, np.ndarray):
88 |         return x
89 |     return x.detach().cpu().numpy()
90 | 
91 | 
92 | def batch_eval(f, loader):
93 |     res = []
94 |     for x in loader:
95 |         res.append(f(x))
96 |     return res
97 | 
98 | 
99 | def bits_dim(ll, dim, bits=256):
100 |     return np.log2(bits) - ll / dim / np.log(2)
101 | 
102 | 
103 | class LinearLR(torch.optim.lr_scheduler._LRScheduler):
104 |     def __init__(self, optimizer, num_epochs, last_epoch=-1):
105 |         self.num_epochs = max(num_epochs, 1)
106 |         super(LinearLR, self).__init__(optimizer, last_epoch)
107 | 
108 |     def get_lr(self):
109 |         res = []
110 |         for lr in self.base_lrs:
111 |             res.append(np.maximum(lr * np.minimum(-(self.last_epoch + 1) * 1. / self.num_epochs + 1., 1.), 0.))
112 |         return res
113 | 
114 | 
115 | class HatLR(torch.optim.lr_scheduler._LRScheduler):
116 |     def __init__(self, optimizer, warm_up, num_epochs, last_epoch=-1):
117 |         if warm_up == 0:
118 |             warnings.warn('====> HatLR with warm_up=0 !!! <====')
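        # "Hat" shape: get_lr/step below ramp the LR up linearly for warm_up
        # epochs, then anneal it linearly back towards zero over the rest.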
119 | 
120 |         self.num_epochs = max(num_epochs, 1)
121 |         self.warm_up = warm_up
122 |         self.warm_schedule = LinearLR(optimizer, warm_up + 1)
123 |         self.warm_schedule.step()
124 |         self.anneal_schedule = LinearLR(optimizer, num_epochs - warm_up)
125 |         super().__init__(optimizer, last_epoch)
126 | 
127 |     def get_lr(self):
128 |         if self.last_epoch + 1 < self.warm_up:
129 |             return [lr - x for lr, x in zip(self.base_lrs, self.warm_schedule.get_lr())]
130 |         return self.anneal_schedule.get_lr()
131 | 
132 |     def step(self, epoch=None):
133 |         super().step(epoch=epoch)
134 |         if self.last_epoch + 1 < self.warm_up:
135 |             self.warm_schedule.step()
136 |         else:
137 |             self.anneal_schedule.step()
138 | 
139 |         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
140 |             param_group['lr'] = lr
141 | 
142 | 
143 | class BaseLR(torch.optim.lr_scheduler._LRScheduler):
144 |     def get_lr(self):
145 |         return [group['lr'] for group in self.optimizer.param_groups]
146 | 
147 | 
148 | def init_kmeans(k, dataloader, model=None, epochs=1, device=None):
149 |     kmeans = MiniBatchKMeans(k, batch_size=dataloader.batch_size)
150 |     for _ in range(epochs):
151 |         for x, _ in dataloader:
152 |             if model:
153 |                 x = model(x.to(device))
154 |             x = tonp(x)
155 |             kmeans.partial_fit(x)
156 | 
157 |     mu = kmeans.cluster_centers_
158 |     dim = mu.shape[1]
159 |     cov = np.zeros((k, dim, dim))
160 |     n = np.zeros((k,))
161 |     for x, _ in dataloader:
162 |         if model:
163 |             x = model(x.to(device))
164 |         x = tonp(x)
165 |         labels = kmeans.predict(x)
166 |         for j in range(k):  # j indexes clusters; k stays the cluster count
167 |             c = labels == j
168 |             n[j] += np.sum(c)
169 |             d = x[c] - mu[None, j]
170 |             cov[j] += np.matmul(d[..., None], d[:, None]).sum(0)
171 |     cov /= n[:, None, None]
172 |     pi = n / n.sum()
173 |     return mu, cov, pi
174 | 
175 | 
176 | def create_flow(args, data_shape):
177 |     if args.model == 'toy':
178 |         flow = flows.toy2d_flow(args.conv, args.hh_factors, args.l)
179 |     elif args.model == 'id':
180 |         flow = flows.Flow([])
181 |     elif args.model == 'mnist':
182 |         flow = flows.mnist_flow(num_layers=args.l, k_factor=args.k, logits=args.logits,
183 |                                 conv=args.conv, hh_factors=args.hh_factors, hid_dim=args.hid_dim)
184 |     elif args.model == 'mnist-masked':
185 |         flow = flows.mnist_masked_glow(conv=args.conv, hh_factors=args.hh_factors)
186 |     elif args.model == 'ffjord':
187 |         # TODO: add FFJORD model
188 |         raise NotImplementedError
189 |     else:
190 |         raise NotImplementedError
191 | 
192 |     return flow
193 | 
194 | 
195 | def create_cond_flow(args):
196 |     if args.ssl_model == 'cond-flow':
197 |         flow = flows.get_flow_cond(args.ssl_l, args.ssl_k, in_channels=args.ssl_dim, hid_dim=args.ssl_hd,
198 |                                    conv=args.ssl_conv, hh_factors=args.ssl_hh, num_cat=args.ssl_nclasses)
199 |     elif args.ssl_model == 'cond-shift':
200 |         flow = flows.ConditionalFlow([
201 |             coupling.ConditionalShift(args.ssl_dim, args.ssl_nclasses)
202 |         ])
203 |     return flow
204 | 
205 | 
206 | class MovingMetric(object):
207 |     def __init__(self):
208 |         self.n = 0
209 |         self.sum = 0.
210 | 
211 |     def add(self, x):
212 |         assert np.ndim(x) == 1
213 |         self.n += len(x)
214 |         self.sum += np.sum(x)
215 | 
216 |     def avg(self):
217 |         return self.sum / self.n if self.n != 0 else np.nan
218 | 
--------------------------------------------------------------------------------