├── .gitignore
├── LICENSE
├── README.md
├── config.yaml
├── datautils.py
├── get-embeddings-unsup.py
├── get-embeddings.py
├── logger.py
├── models
│   ├── __init__.py
│   ├── coupling.py
│   ├── distributions.py
│   ├── flows.py
│   ├── invertconv.py
│   ├── normalization.py
│   ├── realnvp.py
│   └── utils.py
├── myexman
│   ├── __init__.py
│   ├── index.py
│   └── parser.py
├── pretrained
│   └── model.torch
├── train-discriminator.py
├── train-flow-ssl.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 |  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 |     along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program> Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semi-Supervised Flows PyTorch
2 | Authors: [Andrei Atanov](https://andrewatanov.github.io/), [Alexandra Volokhova](https://scholar.google.com/citations?user=23LOcyMAAAAJ&hl=en), [Arsenii Ashukha](https://senya-ashukha.github.io/), [Ivan Sosnovik](https://scholar.google.at/citations?user=brUsNccAAAAJ&hl=en), [Dmitry Vetrov](https://scholar.google.ca/citations?user=7HU0UoUAAAAJ&hl=en)
3 |
4 | This repo contains code for our INNF workshop paper [Semi-Conditional Normalizing Flows for Semi-Supervised Learning](https://arxiv.org/abs/1905.00505)
5 |
6 | __Abstract:__
7 | This paper proposes a semi-conditional normalizing flow model for semi-supervised learning. The model uses both labeled and unlabeled data to learn an explicit model of the joint distribution over objects and labels. The semi-conditional architecture of the model allows us to efficiently compute the value and gradients of the marginal likelihood for unlabeled objects. The conditional part of the model is based on a proposed conditional coupling layer. We demonstrate the performance of the model on semi-supervised classification problems on different datasets. The model outperforms the baseline approach based on variational auto-encoders on the MNIST dataset.
8 |
9 | [__Poster__](https://docs.google.com/presentation/d/1wSA6RKG4ko2zI9XuVAsJq0dqd-XTPLOiWlJ17XJ1SoQ/edit?usp=sharing)
10 |
11 | # Semi-Supervised MNIST classification
12 |
13 | Train a Semi-Conditional Normalizing Flow on MNIST with 100 labeled examples:
14 |
15 | `python train-flow-ssl.py --config config.yaml`
16 |
17 | You can then find logs under `/logs/exman-train-flow-ssl.py/runs/`.
18 |
19 | For convenience, we also provide pretrained weights at `pretrained/model.torch`; use the `--pretrained` flag to load them.
20 |
21 | # Credits
22 |
23 | * Credits to https://github.com/ferrine/exman for the exman parser.
24 |
25 | # Citation
26 |
27 | If you find this code useful, please cite our paper:
28 |
29 | ```
30 | @article{atanov2019semi,
31 | title={Semi-conditional normalizing flows for semi-supervised learning},
32 | author={Atanov, Andrei and Volokhova, Alexandra and Ashukha, Arsenii and Sosnovik, Ivan and Vetrov, Dmitry},
33 | journal={arXiv preprint arXiv:1905.00505},
34 | year={2019}
35 | }
36 | ```
37 |
--------------------------------------------------------------------------------
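
A minimal sketch of what the `--pretrained` flag amounts to, assuming the unconditional pipeline from `get-embeddings-unsup.py` (`utils.create_flow`, `distributions.GaussianDiag`, `flows.FlowPDF`) and the hyperparameters in `config.yaml` (shown next). The SSL model trained by `train-flow-ssl.py` uses a conditional prior (`ssl_model: cond-flow`), so the prior component of the shipped checkpoint may differ; this is an illustration, not the repo's exact loading code:

```python
import argparse

import numpy as np
import torch
import yaml

import utils
from models import distributions, flows

# Rebuild the model with the training hyperparameters, then restore weights.
with open('config.yaml') as f:
    args = argparse.Namespace(**yaml.safe_load(f))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_shape = (1, 28, 28)  # MNIST
prior = distributions.GaussianDiag(int(np.prod(data_shape))).to(device)
flow = torch.nn.DataParallel(utils.create_flow(args, data_shape).to(device))
model = flows.FlowPDF(flow, prior).to(device)
model.load_state_dict(torch.load('pretrained/model.torch', map_location=device))
```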
/config.yaml:
--------------------------------------------------------------------------------
1 | cl_weight: 0.0
2 | clip_gn: 100.0
3 | config_file: ''
4 | conv: full
5 | data: mnist
6 | data_seed: 0
7 | epochs: 500
8 | hh_factors: 2
9 | hid_dim: []
10 | k: 4
11 | l: 2
12 | log_each: 1
13 | logits: true
14 | lr: 0.001
15 | lr_gamma: 0.5
16 | lr_schedule: linear
17 | lr_steps: []
18 | lr_warmup: 10
19 | model: mnist-masked
20 | name: em vs mle
21 | num_examples: -1
22 | pretrained: ''
23 | root: ''
24 | seed: 0
25 | ssl_conv: full
26 | ssl_dim: 196
27 | ssl_hd: 256
28 | ssl_hh: 2
29 | ssl_k: 4
30 | ssl_l: 2
31 | ssl_model: cond-flow
32 | ssl_nclasses: 10
33 | status: done
34 | sup_ohe: true
35 | sup_sample_weight: 0.5
36 | sup_weight: 1.0
37 | supervised: 100
38 | test_bs: 512
39 | tmp: false
40 | train_bs: 256
41 | weight_decay: 1.0e-05
42 |
43 | time: '2019-04-20T13:33:10'
44 | id: 118
45 |
--------------------------------------------------------------------------------
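
The `lr_*` fields above only declare the schedule; the implementation lives in the training scripts, which are not shown in this section. A hypothetical reading of `lr_schedule: linear` with `lr_warmup: 10` and `epochs: 500`, for illustration only:

```python
# Hypothetical: linear warmup for `lr_warmup` epochs, then linear decay to
# zero by epoch `epochs`. The repo's actual schedule may differ in detail.
def linear_lr(epoch, lr=1e-3, warmup=10, epochs=500):
    if epoch < warmup:
        return lr * (epoch + 1) / warmup
    return lr * max(0.0, 1.0 - (epoch - warmup) / (epochs - warmup))

print([round(linear_lr(e), 6) for e in (0, 9, 10, 250, 499)])
```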
/datautils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn import datasets
4 | import os
5 | import torchvision
6 | from torchvision import transforms
7 | from models.distributions import GMM
8 | import torch.nn.functional as F
9 | from sklearn.model_selection import train_test_split
10 | from torch.utils.data import Dataset
11 | import torch.utils.data
12 | import PIL
13 | from torchvision.datasets import ImageFolder
14 | import warnings
15 | import utils
16 |
17 | DATA_ROOT = './'
18 |
19 |
20 | mean = {
21 | 'mnist': (0.1307,),
22 | 'cifar10': (0.4914, 0.4822, 0.4465)
23 | }
24 |
25 | std = {
26 | 'mnist': (0.3081,),
27 | 'cifar10': (0.2470, 0.2435, 0.2616)
28 | }
29 |
30 |
31 | class UniformNoise(object):
32 | def __init__(self, bits=256):
33 | self.bits = bits
34 |
35 | def __call__(self, x):
36 | with torch.no_grad():
37 | noise = torch.rand_like(x)
38 | # TODO: generalize. x assumed to be normalized to [0, 1]
39 | return (x * (self.bits - 1) + noise) / self.bits
40 |
41 | def __repr__(self):
42 | return "UniformNoise"
43 |
44 |
45 | def load_dataset(data, train_bs, test_bs, num_examples=None, data_root=DATA_ROOT, shuffle=True,
46 | seed=42, supervised=-1, logs_root='', sup_sample_weight=-1, sup_only=False, device=None):
47 | bits = None
48 | sampler = None
49 | if data in ['moons', 'circles']:
50 | if data == 'moons':
51 | x, y = datasets.make_moons(n_samples=int(num_examples * 1.5), noise=0.1, random_state=seed)
52 | train_x, test_x, train_y, test_y = train_test_split(x, y, train_size=num_examples, random_state=seed)
53 | elif data == 'circles':
54 | x, y = datasets.make_circles(n_samples=int(num_examples * 1.5), noise=0.1, factor=0.2, random_state=seed)
55 | train_x, test_x, train_y, test_y = train_test_split(x, y, train_size=num_examples, random_state=seed)
56 |
57 | if supervised not in [-1, len(train_y), 0]:
58 | unsupervised_idxs, _ = train_test_split(np.arange(len(train_y)), test_size=supervised, stratify=train_y)
59 | train_y[unsupervised_idxs] = -1
60 | elif supervised == 0:
61 | train_y[:] = -1
62 |
63 | torch.save({
64 | 'train_x': train_x,
65 | 'train_y': train_y,
66 | 'test_x': test_x,
67 | 'test_y': test_y,
68 | }, os.path.join(logs_root, 'data.torch'))
69 |
70 | trainset = torch.utils.data.TensorDataset(torch.FloatTensor(train_x[..., None, None]),
71 | torch.LongTensor(train_y))
72 | testset = torch.utils.data.TensorDataset(torch.FloatTensor(test_x[..., None, None]),
73 | torch.LongTensor(test_y))
74 | data_shape = [2, 1, 1]
75 | bits = np.nan
76 | elif data == 'mnist':
77 | train_transform = transforms.Compose([
78 | transforms.ToTensor(),
79 | UniformNoise(),
80 | ])
81 |
82 | test_transform = transforms.Compose([
83 | transforms.ToTensor(),
84 | UniformNoise(),
85 | ])
86 | trainset = torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform=train_transform)
87 | testset = torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform=test_transform)
88 |
89 | if num_examples != -1 and num_examples != len(trainset) and num_examples is not None:
90 | idxs, _ = train_test_split(np.arange(len(trainset)), train_size=num_examples, random_state=seed,
91 | stratify=utils.tonp(trainset.targets))
92 | trainset.data = trainset.data[idxs]
93 | trainset.targets = trainset.targets[idxs]
94 |
95 | if supervised == 0:
96 | trainset.targets[:] = -1
97 | elif supervised != -1:
98 | unsupervised_idxs, _ = train_test_split(np.arange(len(trainset.targets)),
99 | test_size=supervised, stratify=trainset.targets)
100 | trainset.targets[unsupervised_idxs] = -1
101 |
102 | if sup_only:
103 | mask = trainset.targets != -1
104 | trainset.targets = trainset.targets[mask]
105 | trainset.data = trainset.data[mask]
106 |
107 | data_shape = (1, 28, 28)
108 | bits = 256
109 | else:
110 | raise NotImplementedError
111 |
112 | nw = 2
113 | if sup_sample_weight == -1:
114 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=shuffle,
115 | num_workers=nw, pin_memory=True)
116 | else:
117 | sampler = ImbalancedDatasetSampler(trainset, sup_weight=sup_sample_weight)
118 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, sampler=sampler,
119 | num_workers=nw, pin_memory=True)
120 |
121 | testloader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False,
122 | num_workers=nw, pin_memory=True)
123 | return trainloader, testloader, data_shape, bits
124 |
125 |
126 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
127 | """Samples elements randomly from a given list of indices for imbalanced dataset
128 | https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/sampler.py
129 | Arguments:
130 | indices (list, optional): a list of indices
131 | num_samples (int, optional): number of samples to draw
132 | """
133 |
134 | def __init__(self, dataset, indices=None, num_samples=None, sup_weight=1.):
135 |
136 | # if indices is not provided,
137 | # all elements in the dataset will be considered
138 | self.indices = list(range(len(dataset))) \
139 | if indices is None else indices
140 |
141 | # if num_samples is not provided,
142 | # draw `len(indices)` samples in each iteration
143 | self.num_samples = len(self.indices) \
144 | if num_samples is None else num_samples
145 |
146 |         # class counts: unlabeled (-1) examples are counted on their own, while every labeled class shares the pooled count of all labeled examples
147 | label_to_count = {
148 | -1: 0
149 | }
150 | sup = 0
151 | for idx in self.indices:
152 | label = self._get_label(dataset, idx)
153 | if label == -1:
154 | label_to_count[-1] += 1
155 | else:
156 | sup += 1
157 | label_to_count[label] = sup
158 | for k in label_to_count:
159 | if k != -1:
160 | label_to_count[k] = sup
161 |
162 |         # weight per sample: sup_weight / #labeled for labeled samples, 1 / #unlabeled for unlabeled ones
163 | weights = []
164 | for idx in self.indices:
165 | label = self._get_label(dataset, idx)
166 | w = 1 if label == -1 else sup_weight
167 | weights.append(w / label_to_count[label])
168 |
169 | self.weights = torch.DoubleTensor(weights)
170 |
171 | def _get_label(self, dataset, idx):
172 | return dataset.targets[idx].item()
173 |
174 | def __iter__(self):
175 | return (self.indices[i] for i in torch.multinomial(
176 | self.weights, self.num_samples, replacement=True))
177 |
178 | def __len__(self):
179 | return self.num_samples
180 |
181 |
182 | class FastMNIST(torchvision.datasets.MNIST):
183 | def __init__(self, device, *args, **kwargs):
184 | super().__init__(*args, **kwargs)
185 |
186 | # Scale data to [0,1]
187 | self.data = self.data.unsqueeze(1).float().div(255)
188 |
189 | # Put both data and targets on GPU in advance
190 | self.data, self.targets = self.data.to(device), self.targets.to(device)
191 |
192 | def __getitem__(self, index):
193 | """
194 | Args:
195 | index (int): Index
196 |
197 | Returns:
198 | tuple: (image, target) where target is index of the target class.
199 | """
200 | img, target = self.data[index], self.targets[index]
201 |
202 | return img, target
203 |
--------------------------------------------------------------------------------
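
A usage sketch for `load_dataset` as defined above: with `supervised=100`, the targets of all but 100 stratified training examples are overwritten with -1 (this code's marker for "unlabeled"), and any `sup_sample_weight != -1` switches the train loader to `ImbalancedDatasetSampler`:

```python
import datautils

trainloader, testloader, data_shape, bits = datautils.load_dataset(
    'mnist', train_bs=256, test_bs=512,
    supervised=100,         # keep 100 stratified labels; set the rest to -1
    sup_sample_weight=0.5)  # != -1 enables ImbalancedDatasetSampler
print(data_shape, bits)     # (1, 28, 28) 256
```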
/get-embeddings-unsup.py:
--------------------------------------------------------------------------------
1 | import myexman
2 | import torch
3 | import utils
4 | import datautils
5 | import os
6 | from logger import Logger
7 | import time
8 | import numpy as np
9 | from models import flows
10 | import matplotlib.pyplot as plt
11 | from models import distributions
12 | import sys
13 | from tqdm import tqdm
14 |
15 |
16 | def get_logp(model, loader):
17 | logp = []
18 | for x, _ in loader:
19 | x = x.to(device)
20 | logp.append(utils.tonp(model.log_prob(x)))
21 | return np.concatenate(logp)
22 |
23 |
24 | parser = myexman.ExParser(file=os.path.basename(__file__))
25 | parser.add_argument('--name', default='')
26 | parser.add_argument('--save_dir', default='')
27 | # Data
28 | parser.add_argument('--data', default='mnist')
29 | parser.add_argument('--data_seed', default=0, type=int)
30 | parser.add_argument('--aug', dest='aug', action='store_true')
31 | parser.add_argument('--no_aug', dest='aug', action='store_false')
32 | parser.set_defaults(aug=False)
33 | # Optimization
34 | parser.add_argument('--epochs', default=100, type=int)
35 | parser.add_argument('--train_bs', default=256, type=int)
36 | parser.add_argument('--test_bs', default=512, type=int)
37 | parser.add_argument('--lr', default=1e-3, type=float)
38 | parser.add_argument('--lr_schedule', default='linear')
39 | parser.add_argument('--lr_warmup', default=10, type=int)
40 | parser.add_argument('--lr_gamma', default=0.5, type=float)
41 | parser.add_argument('--lr_steps', type=int, nargs='*', default=[])
42 | parser.add_argument('--log_each', default=1, type=int)
43 | parser.add_argument('--seed', default=0, type=int)
44 | parser.add_argument('--pretrained', default='')
45 | parser.add_argument('--weight_decay', default=0., type=float)
46 | parser.add_argument('--clip_gv', default=1e9, type=float)
47 | parser.add_argument('--clip_gn', default=100., type=float)
48 | # Model
49 | parser.add_argument('--model', default='flow')
50 | parser.add_argument('--logits', dest='logits', action='store_true')
51 | parser.add_argument('--no-logits', dest='logits', action='store_false')
52 | parser.set_defaults(logits=True)
53 | parser.add_argument('--conv', default='full')
54 | parser.add_argument('--hh_factors', default=2, type=int)
55 | parser.add_argument('--k', default=4, type=int)
56 | parser.add_argument('--l', default=2, type=int)
57 | parser.add_argument('--hid_dim', type=int, nargs='*', default=[])
58 | args = parser.parse_args()
59 |
60 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
61 |
62 | fmt = {
63 | 'time': '.3f',
64 | 'lr': '.1e',
65 | }
66 | logger = Logger('logs', base=args.root, fmt=fmt)
67 |
68 | # Load data
69 | np.random.seed(args.data_seed)
70 | torch.manual_seed(args.data_seed)
71 | torch.cuda.manual_seed_all(args.data_seed)
72 | trainloader, testloader, data_shape, bits = datautils.load_dataset(args.data, args.train_bs, args.test_bs,
73 | seed=args.data_seed, shuffle=False)
74 |
75 | # Seed for training process
76 | np.random.seed(args.seed)
77 | torch.manual_seed(args.seed)
78 | torch.cuda.manual_seed_all(args.seed)
79 |
80 | # Create model
81 | dim = int(np.prod(data_shape))
82 | prior = distributions.GaussianDiag(dim).to(device)
83 |
84 | flow = utils.create_flow(args, data_shape)
85 | flow = torch.nn.DataParallel(flow.to(device))
86 | model = flows.FlowPDF(flow, prior).to(device)
87 |
88 | if args.pretrained is not None and args.pretrained != '':
89 | model.load_state_dict(torch.load(args.pretrained))
90 |
91 |
92 | def get_embeddings(loader, model):
93 |     zf, labels = [], []
94 | for x, y in tqdm(loader):
95 | z_ = model.flow(x)[1]
96 | zf.append(utils.tonp(z_))
97 | labels.append(utils.tonp(y))
98 | return np.concatenate(zf), np.concatenate(labels)
99 |
100 |
101 | with torch.no_grad():
102 | zf_train, y_train = get_embeddings(trainloader, model)
103 | zf_test, y_test = get_embeddings(testloader, model)
104 |
105 |
106 | np.save(os.path.join(args.save_dir, 'zf_train'), zf_train)
107 | np.save(os.path.join(args.save_dir, 'y_train'), y_train)
108 |
109 | np.save(os.path.join(args.save_dir, 'zf_test'), zf_test)
110 | np.save(os.path.join(args.save_dir, 'y_test'), y_test)
111 |
--------------------------------------------------------------------------------
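
The four `.npy` arrays saved above are ready for a downstream classifier. As an illustration only (this is not the repo's `train-discriminator.py`, which is not shown in this section), a linear probe on the flow embeddings:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Assumes the script above was run with --save_dir '.' (the default).
zf_train, y_train = np.load('zf_train.npy'), np.load('y_train.npy')
zf_test, y_test = np.load('zf_test.npy'), np.load('y_test.npy')

clf = LogisticRegression(max_iter=1000)
clf.fit(zf_train.reshape(len(y_train), -1), y_train)
print('linear probe accuracy:', clf.score(zf_test.reshape(len(y_test), -1), y_test))
```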
/get-embeddings.py:
--------------------------------------------------------------------------------
1 | import myexman
2 | import torch
3 | import utils
4 | import datautils
5 | import os
6 | from logger import Logger
7 | import time
8 | import numpy as np
9 | from models import flows, distributions
10 | import matplotlib.pyplot as plt
11 | from algo.em import init_kmeans2plus_mu
12 | import warnings
13 | from sklearn.mixture import GaussianMixture
14 | import torch.nn.functional as F
15 | from tqdm import tqdm
16 | import sys
17 |
18 |
19 | def get_metrics(model, loader):
20 | logp, acc = [], []
21 | for x, y in loader:
22 | x = x.to(device)
23 | log_det, z = model.flow(x)
24 | log_prior_full = model.prior.log_prob_full(z)
25 | pred = torch.softmax(log_prior_full, dim=1).argmax(1)
26 | logp.append(utils.tonp(log_det + model.prior.log_prob(z)))
27 | acc.append(utils.tonp(pred) == utils.tonp(y))
28 | return np.mean(np.concatenate(logp)), np.mean(np.concatenate(acc))
29 |
30 |
31 | parser = myexman.ExParser(file=os.path.basename(__file__))
32 | parser.add_argument('--name', default='')
33 | parser.add_argument('--verbose', default=0, type=int)
34 | parser.add_argument('--save_dir', default='')
35 | parser.add_argument('--test_mode', default='')
36 | # Data
37 | parser.add_argument('--data', default='mnist')
38 | parser.add_argument('--num_examples', default=-1, type=int)
39 | parser.add_argument('--data_seed', default=0, type=int)
40 | parser.add_argument('--sup_sample_weight', default=-1, type=float)
41 | # parser.add_argument('--aug', dest='aug', action='store_true')
42 | # parser.add_argument('--no_aug', dest='aug', action='store_false')
43 | # parser.set_defaults(aug=True)
44 | # Optimization
45 | parser.add_argument('--opt', default='adam')
46 | parser.add_argument('--ssl_alg', default='em')
47 | parser.add_argument('--lr', default=1e-3, type=float)
48 | parser.add_argument('--epochs', default=100, type=int)
49 | parser.add_argument('--train_bs', default=256, type=int)
50 | parser.add_argument('--test_bs', default=512, type=int)
51 | parser.add_argument('--lr_schedule', default='linear')
52 | parser.add_argument('--lr_warmup', default=10, type=int)
53 | parser.add_argument('--lr_gamma', default=0.5, type=float)
54 | parser.add_argument('--lr_steps', type=int, nargs='*', default=[])
55 | parser.add_argument('--log_each', default=1, type=int)
56 | parser.add_argument('--seed', default=0, type=int)
57 | parser.add_argument('--pretrained', default='')
58 | parser.add_argument('--weight_decay', default=0., type=float)
59 | parser.add_argument('--sup_ohe', dest='sup_ohe', action='store_true')
60 | parser.add_argument('--no_sup_ohe', dest='sup_ohe', action='store_false')
61 | parser.set_defaults(sup_ohe=True)
62 | parser.add_argument('--clip_gn', default=100., type=float)
63 | # Model
64 | parser.add_argument('--model', default='flow')
65 | parser.add_argument('--logits', dest='logits', action='store_true')
66 | parser.add_argument('--no_logits', dest='logits', action='store_false')
67 | parser.set_defaults(logits=True)
68 | parser.add_argument('--conv', default='full')
69 | parser.add_argument('--hh_factors', default=2, type=int)
70 | parser.add_argument('--k', default=4, type=int)
71 | parser.add_argument('--l', default=2, type=int)
72 | parser.add_argument('--hid_dim', type=int, nargs='*', default=[])
73 | # Prior
74 | parser.add_argument('--ssl_model', default='cond-flow')
75 | parser.add_argument('--ssl_dim', default=-1, type=int)
76 | parser.add_argument('--ssl_l', default=2, type=int)
77 | parser.add_argument('--ssl_k', default=4, type=int)
78 | parser.add_argument('--ssl_hd', default=256, type=int)
79 | parser.add_argument('--ssl_conv', default='full')
80 | parser.add_argument('--ssl_hh', default=2, type=int)
81 | parser.add_argument('--ssl_nclasses', default=10, type=int)
82 | # SSL
83 | parser.add_argument('--supervised', default=0, type=int)
84 | parser.add_argument('--sup_weight', default=1., type=float)
85 | parser.add_argument('--cl_weight', default=0, type=float)
86 | args = parser.parse_args()
87 |
88 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
89 |
90 | # Load data
91 | np.random.seed(args.data_seed)
92 | torch.manual_seed(args.data_seed)
93 | torch.cuda.manual_seed_all(args.data_seed)
94 | trainloader, testloader, data_shape, bits = datautils.load_dataset(args.data, args.train_bs, args.test_bs,
95 | seed=args.data_seed, shuffle=False)
96 |
97 | # Seed for training process
98 | np.random.seed(args.seed)
99 | torch.manual_seed(args.seed)
100 | torch.cuda.manual_seed_all(args.seed)
101 |
102 | # Create model
103 | dim = int(np.prod(data_shape))
104 | if args.ssl_dim == -1:
105 | args.ssl_dim = dim
106 | deep_prior = distributions.GaussianDiag(args.ssl_dim)
107 | shallow_prior = distributions.GaussianDiag(dim - args.ssl_dim)
108 | yprior = torch.distributions.Categorical(logits=torch.zeros((args.ssl_nclasses,)).to(device))
109 | ssl_flow = flows.get_flow_cond(args.ssl_l, args.ssl_k, in_channels=args.ssl_dim, hid_dim=args.ssl_hd,
110 | conv=args.ssl_conv, hh_factors=args.ssl_hh, num_cat=args.ssl_nclasses)
111 | ssl_flow = torch.nn.DataParallel(ssl_flow.to(device))
112 | prior = flows.DiscreteConditionalFlowPDF(ssl_flow, deep_prior, yprior, deep_dim=args.ssl_dim,
113 | shallow_prior=shallow_prior)
114 |
115 | flow = utils.create_flow(args, data_shape)
116 | flow = torch.nn.DataParallel(flow.to(device))
117 |
118 | model = flows.FlowPDF(flow, prior).to(device)
119 |
120 | if args.pretrained != '':
121 | model.load_state_dict(torch.load(os.path.join(args.pretrained, 'model.torch'), map_location=device))
122 |
123 |
144 |
145 | def get_embeddings(loader, model):
146 | zf, zh, labels = [], [], []
147 | for x, y in tqdm(loader):
148 | z_ = model.flow(x)[1]
149 | zf.append(utils.tonp(z_))
150 | zh.append(utils.tonp(model.prior.flow(z_[:, -args.ssl_dim:, None, None], y)[1]))
151 | labels.append(utils.tonp(y))
152 | return np.concatenate(zf), np.concatenate(zh), np.concatenate(labels)
153 |
154 |
155 | y_test = np.array(testloader.dataset.targets)
156 | y_train = np.array(trainloader.dataset.targets)
157 |
158 | if args.test_mode == 'perm':
159 | idxs = np.random.permutation(10000)[:5000]
160 | testloader.dataset.data[idxs] = 255 - testloader.dataset.data[idxs]
161 | testloader.dataset.targets[idxs] = 1 - testloader.dataset.targets[idxs]
162 | elif args.test_mode == '':
163 | pass
164 | elif args.test_mode == 'inv':
165 | testloader.dataset.data = 255 - testloader.dataset.data
166 | testloader.dataset.targets = 1 - testloader.dataset.targets
167 | else:
168 | raise NotImplementedError
169 |
170 | with torch.no_grad():
171 | zf_train, zh_train, _ = get_embeddings(trainloader, model)
172 | zf_test, zh_test, _ = get_embeddings(testloader, model)
173 |
174 |
175 | np.save(os.path.join(args.save_dir, 'zf_train'), zf_train)
176 | np.save(os.path.join(args.save_dir, 'zh_train'), zh_train)
177 | np.save(os.path.join(args.save_dir, 'y_train'), y_train)
178 |
179 | np.save(os.path.join(args.save_dir, 'zf_test'), zf_test)
180 | np.save(os.path.join(args.save_dir, 'zh_test'), zh_test)
181 | np.save(os.path.join(args.save_dir, 'y_test'), y_test)
182 |
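183 | # Sketch (an assumption, not part of the original pipeline): the saved arrays
184 | # can be used to fit a linear probe on the flow embeddings, e.g. with scikit-learn:
185 | #   from sklearn.linear_model import LogisticRegression
186 | #   clf = LogisticRegression(max_iter=1000).fit(zf_train, y_train)
187 | #   print('test acc:', clf.score(zf_test, y_test))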
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import numpy as np
5 |
6 | from collections import OrderedDict
7 | from tabulate import tabulate
8 | from pandas import DataFrame
9 | from time import gmtime, strftime
10 | import time
11 |
12 |
13 | class Logger:
14 | def __init__(self, name='name', fmt=None, base='./logs'):
15 | self.handler = True
16 | self.scalar_metrics = OrderedDict()
17 | self.fmt = fmt if fmt else dict()
18 |
19 | if not os.path.exists(base):
20 | os.makedirs(base)
21 |
22 | time = gmtime()
23 | hash = ''.join([chr(random.randint(97, 122)) for _ in range(3)])
24 | fname = '-'.join(sys.argv[0].split('/')[-3:])
25 | # self.path = '%s/%s-%s-%s-%s' % (base, fname, name, hash, strftime('%m-%d-%H:%M', time))
26 | self.path = '%s/%s-%s' % (base, fname, name)
27 |
28 | self.logs = self.path + '.csv'
29 | self.output = self.path + '.out'
30 |
31 | def prin(*args):
32 | str_to_write = ' '.join(map(str, args))
33 | with open(self.output, 'a') as f:
34 | f.write(str_to_write + '\n')
35 | f.flush()
36 |
37 | print(str_to_write)
38 | sys.stdout.flush()
39 |
40 | self.print = prin
41 |
42 | def add_scalar(self, t, key, value):
43 | if key not in self.scalar_metrics:
44 | self.scalar_metrics[key] = []
45 | self.scalar_metrics[key] += [(t, value)]
46 |
47 | def iter_info(self, order=None):
48 | names = list(self.scalar_metrics.keys())
49 | if order:
50 | names = order
51 | values = [self.scalar_metrics[name][-1][1] for name in names]
52 | t = int(np.max([self.scalar_metrics[name][-1][0] for name in names]))
53 | fmt = ['%s'] + [self.fmt[name] if name in self.fmt else '.3f' for name in names]
54 |
55 | if self.handler:
56 | self.handler = False
57 | self.print(tabulate([[t] + values], ['epoch'] + names, floatfmt=fmt))
58 | else:
59 | self.print(tabulate([[t] + values], ['epoch'] + names, tablefmt='plain', floatfmt=fmt).split('\n')[1])
60 |
61 | def save(self):
62 | result = None
63 | for key in self.scalar_metrics.keys():
64 | if result is None:
65 | result = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t')
66 | else:
67 | df = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t')
68 | result = result.join(df, how='outer')
69 | result.to_csv(self.logs)
70 | # self.print('The log/output/model have been saved to: ' + self.path + ' + .csv/.out/.cpt')
71 |
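72 | if __name__ == '__main__':
73 |     # Minimal usage sketch (not part of the original file): log two epochs of
74 |     # fake metrics, print one table row per epoch, and dump a CSV to ./logs.
75 |     logger = Logger('demo', fmt={'loss': '.4f'}, base='./logs')
76 |     for epoch in range(2):
77 |         logger.add_scalar(epoch, 'loss', 1. / (epoch + 1))
78 |         logger.add_scalar(epoch, 'acc', 0.5 + 0.1 * epoch)
79 |         logger.iter_info()
80 |     logger.save()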
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndrewAtanov/semi-supervised-flow-pytorch/d1748decaccbc59e6bce014e1cb84527173c6b54/models/__init__.py
--------------------------------------------------------------------------------
/models/coupling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 |
7 | def coupling(x1, x2, net, log_det_jac=None, eps=1e-6):
8 | scale, shift = net(x2).split(x1.size(1), dim=1)
9 | # TODO: deal with scale
10 | # scale = torch.tanh(scale)
11 | # scale = torch.exp(scale)
12 | scale = torch.sigmoid(scale + 2.) + eps
13 | x1 = (x1 + shift) * scale
14 | if log_det_jac is not None:
15 | if scale.dim() == 4:
16 | log_det_jac += torch.log(scale).sum((1, 2, 3))
17 | else:
18 | log_det_jac += torch.log(scale).sum(1)
19 | return x1
20 |
21 |
22 | def coupling_inv(x1, x2, net, eps=1e-6):
23 | scale, shift = net(x2).split(x1.size(1), dim=1)
24 | # TODO: deal with scale
25 | # scale = torch.tanh(scale)
26 | # scale = torch.exp(scale)
27 | scale = torch.sigmoid(scale + 2.) + eps
28 | x1 = x1 / scale - shift
29 | return x1
30 |
31 |
32 | def get_mask(xs, mask_type):
33 | if 'checkerboard' in mask_type:
34 | unit0 = np.array([[0.0, 1.0], [1.0, 0.0]])
35 | unit1 = -unit0 + 1.0
36 | unit = unit0 if mask_type == 'checkerboard0' else unit1
37 | unit = np.reshape(unit, [1, 2, 2])
38 | b = np.tile(unit, [xs[0], xs[1]//2, xs[2]//2])
39 | elif 'channel' in mask_type:
40 | white = np.ones([xs[0]//2, xs[1], xs[2]])
41 | black = np.zeros([xs[0]//2, xs[1], xs[2]])
42 | if mask_type == 'channel0':
43 | b = np.concatenate([white, black], 0)
44 | else:
45 | b = np.concatenate([black, white], 0)
46 |
47 | if list(b.shape) != list(xs):
48 | b = np.tile(b, (1, 2, 2))[:, :xs[1], :xs[2]]
49 |
50 | return b
51 |
52 |
53 | class MaskedCouplingLayer(nn.Module):
54 | def __init__(self, shape, mask_type, net):
55 | super().__init__()
56 | mask = torch.FloatTensor(get_mask(shape, mask_type))
57 | self.mask = nn.Parameter(mask[None], requires_grad=False)
58 | self.net = net
59 | self.eps = 1e-6
60 |
61 | def extra_repr(self):
62 | return 'MaskedCouplingLayer(mask=checkerboard)'
63 |
64 | def forward(self, x, log_det_jac, z):
65 | return self.f(x, log_det_jac, z)
66 |
67 | def f(self, x, log_det_jac, z):
68 | x1 = self.mask * x
69 | s, t = self.net(x1).split(x1.size(1), dim=1)
70 | logs = -F.softplus(-s-2.)
71 | logs *= (1 - self.mask)
72 | s = torch.sigmoid(s + 2.)
73 | s = (1 - self.mask) * s
74 | s += self.mask
75 | t = (1 - self.mask) * t
76 | x = x1 + (1 - self.mask) * (x * s + t)
77 | log_det_jac += torch.sum(logs, dim=(1, 2, 3))
78 | return x, log_det_jac, z
79 |
80 | def g(self, x, z):
81 | x1 = self.mask * x
82 | s, t = self.net(x1).split(x1.size(1), dim=1)
83 | s = torch.sigmoid(s + 2.) + self.eps
84 | x = x1 + (1 - self.mask) * (x - t) / s
85 | return x, z
86 |
87 |
88 | class ConditionalMaskedCouplingLayer(nn.Module):
89 | def __init__(self, shape, mask_type, net):
90 | super().__init__()
91 | mask = torch.FloatTensor(get_mask(shape, mask_type))
92 | self.mask = nn.Parameter(mask[None], requires_grad=False)
93 | self.net = net
94 | self.eps = 1e-6
95 |
96 | def forward(self, x, y, log_det_jac, z):
97 | return self.f(x, y, log_det_jac, z)
98 |
99 | def f(self, x, y, log_det_jac, z):
100 | x1 = self.mask * x
101 |
102 | assert y.dim() == 2
103 | y = y[..., None, None].repeat((1, 1, x1.shape[2], x1.shape[3]))
104 |
105 | s, t = self.net(torch.cat([x1, y], dim=1)).split(x1.size(1), dim=1)
106 | logs = -F.softplus(-s-2.)
107 | logs *= (1 - self.mask)
108 | s = torch.sigmoid(s + 2.)
109 | s = (1 - self.mask) * s
110 | s += self.mask
111 | t = (1 - self.mask) * t
112 | x = x1 + (1 - self.mask) * (x * s + t)
113 | log_det_jac += torch.sum(logs, dim=(1, 2, 3))
114 | return x, log_det_jac, z
115 |
116 | def g(self, x, y, z):
117 | x1 = self.mask * x
118 |
119 | assert y.dim() == 2
120 | y = y[..., None, None].repeat((1, 1, x1.shape[2], x1.shape[3]))
121 |
122 | s, t = self.net(torch.cat([x1, y], dim=1)).split(x1.size(1), dim=1)
123 | s = torch.sigmoid(s + 2.)
124 | x = x1 + (1 - self.mask) * (x - t) / s
125 | return x, z
126 |
127 |
128 | class CouplingLayer(nn.Module):
129 | """
130 | Coupling layer with channelwise mask applied twice.
131 | (e.g. see RealNVP https://arxiv.org/pdf/1605.08803.pdf for details)
132 | """
133 | def __init__(self, netfunc):
134 | super().__init__()
135 | self.net1 = netfunc()
136 | self.net2 = netfunc()
137 |
138 | def extra_repr(self):
139 | return 'CouplingLayer(mask=channel)'
140 |
141 | def forward(self, x, log_det_jac, z):
142 | return self.f(x, log_det_jac, z)
143 |
144 | def f(self, x, log_det_jac, z):
145 | C = x.size(1) // 2
146 | x1, x2 = x.split(C, dim=1)
147 | x1, x2 = x1.contiguous(), x2.contiguous()
148 |
149 | x1 = coupling(x1, x2, self.net1, log_det_jac)
150 | x2 = coupling(x2, x1, self.net2, log_det_jac)
151 |
152 | return torch.cat([x1, x2], dim=1), log_det_jac, z
153 |
154 | def g(self, x, z):
155 | C = x.size(1) // 2
156 | x1, x2 = x.split(C, dim=1)
157 | x1, x2 = x1.contiguous(), x2.contiguous()
158 |
159 | x2 = coupling_inv(x2, x1, self.net2)
160 | x1 = coupling_inv(x1, x2, self.net1)
161 |
162 | return torch.cat([x1, x2], dim=1), z
163 |
164 |
165 | class ConditionalCouplingLayer(CouplingLayer):
166 | def forward(self, x, y, log_det_jac, z):
167 | return self.f(x, y, log_det_jac, z)
168 |
169 | def extra_repr(self):
170 | return 'ConditionalCouplingLayer(mask=channel)'
171 |
172 | def f(self, x, y, log_det_jac, z):
173 | C = x.size(1) // 2
174 | x1, x2 = x.split(C, dim=1)
175 | x1, x2 = x1.contiguous(), x2.contiguous()
176 |
177 | assert y.dim() == 2
178 | y = y[..., None, None].repeat((1, 1, x1.shape[2], x1.shape[3]))
179 |
180 | x1 = coupling(x1, torch.cat([x2, y], dim=1), self.net1, log_det_jac)
181 | x2 = coupling(x2, torch.cat([x1, y], dim=1), self.net2, log_det_jac)
182 |
183 | return torch.cat([x1, x2], dim=1), log_det_jac, z
184 |
185 | def g(self, x, y, z):
186 | C = x.size(1) // 2
187 | x1, x2 = x.split(C, dim=1)
188 | x1, x2 = x1.contiguous(), x2.contiguous()
189 |
190 | assert y.dim() == 2
191 | y = y[..., None, None].repeat((1, 1, x1.shape[2], x1.shape[3]))
192 |
193 | x2 = coupling_inv(x2, torch.cat([x1, y], dim=1), self.net2)
194 | x1 = coupling_inv(x1, torch.cat([x2, y], dim=1), self.net1)
195 |
196 | return torch.cat([x1, x2], dim=1), z
197 |
198 |
199 | class ConditionalShift(nn.Module):
200 | def __init__(self, channels, nfactors):
201 | super().__init__()
202 | self.factors = nn.Embedding(nfactors, channels)
203 |
204 | def forward(self, x, y, log_det_jac, z):
205 | return self.f(x, y, log_det_jac, z)
206 |
207 | def f(self, x, y, log_det_jac, z):
208 | shift = self.factors(y)
209 | return x - shift.view((x.size(0), -1, 1, 1)), log_det_jac, z
210 |
211 | def g(self, x, y, z):
212 | shift = self.factors(y)
213 | return x + shift.view((x.size(0), -1, 1, 1)), z
214 |
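215 | if __name__ == '__main__':
216 |     # Invertibility smoke test (a sketch, not part of the original file):
217 |     # g should undo f up to float error for both coupling variants.
218 |     torch.manual_seed(0)
219 |     layer = CouplingLayer(lambda: nn.Conv2d(2, 4, 3, padding=1))
220 |     x = torch.randn(8, 4, 16, 16)
221 |     y, _, _ = layer.f(x, torch.zeros(8), None)
222 |     x_rec, _ = layer.g(y, None)
223 |     print('CouplingLayer max reconstruction error:', (x - x_rec).abs().max().item())
224 |     masked = MaskedCouplingLayer([1, 8, 8], 'checkerboard0', nn.Conv2d(1, 2, 3, padding=1))
225 |     x = torch.randn(8, 1, 8, 8)
226 |     y, _, _ = masked.f(x, torch.zeros(8), None)
227 |     x_rec, _ = masked.g(y, None)
228 |     print('MaskedCouplingLayer max reconstruction error:', (x - x_rec).abs().max().item())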
--------------------------------------------------------------------------------
/models/distributions.py:
--------------------------------------------------------------------------------
1 | import torch.distributions as dist
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from numbers import Number
6 | import numpy as np
7 | import utils
8 |
9 |
10 | class Mixture(dist.Distribution):
11 |     def __init__(self, base_distributions, weights=None):
12 |         super(Mixture, self).__init__(batch_shape=base_distributions[0].batch_shape,
13 |                                       event_shape=base_distributions[0].event_shape)
14 |         self.base_distributions = base_distributions
15 |         self.weights = weights
16 |         if self.weights is None:
17 |             k = len(self.base_distributions)
18 | self.weights = torch.ones((k,)) / float(k)
19 |
20 | def log_prob(self, x, detach=False):
21 | ens = None
22 |         for prob, w in zip(self.base_distributions, self.weights):
23 | logp = prob.log_prob(x) + torch.log(w)
24 | if detach:
25 | logp = logp.detach()
26 | if ens is None:
27 | ens = logp
28 | else:
29 | _t = torch.stack([ens, logp])
30 | ens = torch.logsumexp(_t, dim=0)
31 | return ens
32 |
33 | def one_sample(self, labels=False):
34 | k = np.random.choice(len(self.weights), p=utils.tonp(self.weights))
35 | if labels:
36 |             return self.base_distributions[k].sample(), k
37 |         return self.base_distributions[k].sample()
38 |
39 | def sample(self, sample_shape=torch.Size(), labels=False):
40 | if len(sample_shape) == 0:
41 | return self.one_sample()
42 | elif len(sample_shape) == 1:
43 | res, ys = [], []
44 | for i in range(sample_shape[0]):
45 | if labels:
46 | samples, y = self.one_sample(labels=labels)
47 | res.append(samples)
48 | ys.append(y)
49 | else:
50 | res.append(self.one_sample())
51 | if labels:
52 | return torch.stack(res), np.stack(ys)
53 | else:
54 | return torch.stack(res)
55 |         elif len(sample_shape) == 2:
56 |             res, ys = [], []
57 |             for _ in range(sample_shape[0]):
58 |                 res.append([])
59 |                 ys.append([])
60 |                 for _ in range(sample_shape[1]):
61 |                     if labels:
62 |                         sample, label = self.one_sample(labels=labels)
63 |                         res[-1].append(sample)
64 |                         ys[-1].append(label)
65 |                     else:
66 |                         res[-1].append(self.one_sample())
67 |                 res[-1] = torch.stack(res[-1])
68 |             if labels:
69 |                 return torch.stack(res), np.stack(ys)
70 |             else:
71 |                 return torch.stack(res)
72 | else:
73 | raise NotImplementedError
74 |
75 |
76 | class GMM(nn.Module):
77 | def __init__(self, k=None, dim=None, means=None, covariances=None, weights=None, normalize=False):
78 | super(GMM, self).__init__()
79 | if k is None and means is None:
80 | raise NotImplementedError
81 |
82 | if means is None:
83 | covars = torch.rand(k, dim, dim)
84 | covars = torch.matmul(covars, covars.transpose(1, 2))
85 | self.means = nn.ParameterList([nn.Parameter(m) for m in torch.randn(k, dim)])
86 | self.cov_factors = nn.ParameterList([nn.Parameter(torch.cholesky(cov)) for cov in covars])
87 | self.weights = nn.Parameter(torch.FloatTensor([1./k] * k))
88 | self.k = k
89 | else:
90 | self.means = nn.ParameterList([nn.Parameter(m) for m in means])
91 | self.cov_factors = nn.ParameterList([nn.Parameter(torch.cholesky(cov)) for cov in covariances])
92 | self.weights = nn.Parameter(weights)
93 | self.k = weights.shape[0]
94 |
95 | self.normalize = normalize
96 |
97 | def get_weights(self):
98 | if self.normalize:
99 | return F.softmax(self.weights, dim=0)
100 | return self.weights
101 |
102 | def get_dist(self):
103 | base_distributions = []
104 | for m, covf in zip(self.means, self.cov_factors):
105 | if covf.dim() == 1:
106 | covf = torch.diag(covf)
107 | cov = torch.mm(covf, covf.t())
108 | base_distributions.append(dist.MultivariateNormal(m, covariance_matrix=cov))
109 | return Mixture(base_distributions, weights=self.get_weights())
110 |
111 | def set_covariance(self, k, cov):
112 | self.cov_factors[k].data = torch.cholesky(cov)
113 |
114 | def set_params(self, means=None, covars=None, pi=None):
115 | if pi is not None:
116 | self.weights.data = torch.log(pi) if self.normalize else pi
117 | for k in range(self.k):
118 | if means is not None:
119 | self.means[k].data = means[k]
120 | if covars is not None:
121 | self.set_covariance(k, covars[k])
122 |
123 | @property
124 | def covariances(self):
125 | return torch.stack([torch.mm(covf, covf.t()) for covf in self.cov_factors])
126 |
127 | def log_prob(self, x, k=None):
128 | if k is None:
129 | p = self.get_dist()
130 | return p.log_prob(x)
131 | else:
132 | p = self.get_dist()
133 |             return p.base_distributions[k].log_prob(x) + torch.log(p.weights[k])
134 |
135 | def sample(self, sample_shape=torch.Size(), labels=False):
136 | p = self.get_dist()
137 | return p.sample(sample_shape, labels=labels)
138 |
139 |
140 | class MultivariateNormalDiag(torch.distributions.Normal):
141 | def log_prob(self, x):
142 | logp = super().log_prob(x)
143 | return logp.sum(1)
144 |
145 |
146 | class GaussianDiag(nn.Module):
147 | def __init__(self, dim):
148 | super().__init__()
149 | self.dim = dim
150 | self.logsigma = nn.Parameter(torch.zeros((dim,)))
151 | self.mean = nn.Parameter(torch.zeros((dim,)), requires_grad=False)
152 |
153 | def _get_dist(self):
154 | scale = F.softplus(self.logsigma)
155 | return MultivariateNormalDiag(self.mean, scale)
156 |
157 | def log_prob(self, x):
158 | p = self._get_dist()
159 | return p.log_prob(x)
160 |
161 | def log_prob_full(self, x):
162 | p = self._get_dist()
163 | return p.log_prob(x)[:, None]
164 |
165 |
166 | class GmmPrior(nn.Module):
167 | def __init__(self, k=None, dim=None, full_dim=None, means=None, covariances=None, weights=None, cov_type='diag'):
168 | super().__init__()
169 | if k is None and means is None:
170 | raise NotImplementedError
171 |
172 | if cov_type not in ['diag']:
173 | raise NotImplementedError
174 | self.cov_type = cov_type
175 |
176 | if full_dim is None:
177 | full_dim = dim
178 |
179 | if means is None:
180 | means = torch.randn(k, dim) * np.sqrt(2)
181 | if cov_type == 'diag':
182 | # covariances = torch.log(torch.rand(k, dim) * 0.5)
183 | covariances = torch.zeros((k, dim))
184 | weights = torch.FloatTensor([1./k] * k)
185 |
186 | self.means = nn.Parameter(means)
187 | self.cov_factors = nn.Parameter(covariances)
188 | self.weights = nn.Parameter(weights)
189 | self.k = self.weights.shape[0]
190 | self.dim = self.means.shape[1]
191 | self.full_dim = full_dim
192 | self.sn_dim = self.full_dim - self.dim
193 |
194 | def get_logpi(self):
195 | return F.log_softmax(self.weights, dim=0)
196 |
197 | def get_dist(self):
198 | base_distributions = []
199 | for m, covf in zip(self.means, self.cov_factors):
200 | m = torch.cat([torch.zeros((self.sn_dim,)).to(m.device), m])
201 | if self.cov_type == 'diag':
202 | covf = torch.cat([torch.zeros((self.sn_dim,)).to(covf.device), covf])
203 | # TODO: softplus seems to be more stable
204 | scale = torch.exp(covf * 0.5)
205 | base_distributions.append(MultivariateNormalDiag(m, scale))
206 |
207 | pi = torch.exp(self.get_logpi())
208 | return Mixture(base_distributions, weights=pi)
209 |
210 | def log_prob(self, x, k=None):
211 | logpi = self.get_logpi()
212 | if k is None:
213 | p = self.get_dist()
214 | return p.log_prob(x)
215 | else:
216 | p = self.get_dist()
217 |             return p.base_distributions[k].log_prob(x) + logpi[k]
218 |
219 | def log_prob_full(self, x):
220 | return torch.stack([self.log_prob(x, k=k) for k in range(self.k)]).transpose(0, 1)
221 |
222 | def log_prob_full_fast(self, x):
223 | if self.cov_type != 'diag':
224 | raise NotImplementedError
225 | var = torch.exp(self.cov_factors)
226 | logp = -(x[:, None] - self.means[None])**2 / (2. * var[None]) - 0.5 * self.cov_factors
227 | return logp.sum(2) + self.get_logpi()[None] - 0.5 * np.log(2 * np.pi) * self.dim
228 |
229 | def sample(self, sample_shape=torch.Size(), labels=False):
230 | p = self.get_dist()
231 | return p.sample(sample_shape, labels=labels)
232 |
233 | def set_params(self, means=None, covars=None, pi=None):
234 |         if pi is not None:
235 |             self.weights.data = torch.log(pi)  # weights are stored as logits (see get_logpi)
236 | if means is not None:
237 | self.means.data = torch.tensor(means)
238 | if covars is not None:
239 | self.cov_factors.data = torch.log(covars)
240 |
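241 | if __name__ == '__main__':
242 |     # Consistency sketch (not part of the original file): the vectorised
243 |     # diagonal-GMM density should match the per-component computation.
244 |     torch.manual_seed(0)
245 |     prior = GmmPrior(k=3, dim=5)
246 |     x = torch.randn(7, 5)
247 |     full = prior.log_prob_full(x)
248 |     fast = prior.log_prob_full_fast(x)
249 |     print('max abs diff:', (full - fast).abs().max().item())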
--------------------------------------------------------------------------------
/models/flows.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from models.normalization import ActNorm, DummyCondActNorm
6 | from models.invertconv import InvertibleConv2d, HHConv2d_1x1, QRInvertibleConv2d, DummyCondInvertibleConv2d
7 | from models.coupling import CouplingLayer, MaskedCouplingLayer, ConditionalCouplingLayer, ConditionalMaskedCouplingLayer
8 | from models.utils import Conv2dZeros, SpaceToDepth, FactorOut, ToLogits, CondToLogits, CondFactorOut, CondSpaceToDepth
9 | from models.utils import DummyCond, IdFunction
10 | import warnings
11 |
12 |
13 | class Flow(nn.Module):
14 | def __init__(self, modules):
15 | super().__init__()
16 | self.modules_ = nn.ModuleList(modules)
17 | self.latent_len = -1
18 | self.x_shape = -1
19 |
20 | def f(self, x):
21 | z = None
22 | log_det_jac = torch.zeros((x.shape[0],)).to(x.device)
23 | for m in self.modules_:
24 | x, log_det_jac, z = m(x, log_det_jac, z)
25 | if z is None:
26 | z = torch.zeros((x.shape[0], 1))[:, :0].to(x.device)
27 | self.x_shape = list(x.shape)[1:]
28 | self.latent_len = z.shape[1]
29 | z = torch.cat([z, x.reshape((x.shape[0], -1))], dim=1)
30 | return log_det_jac, z
31 |
32 | def forward(self, x):
33 | return self.f(x)
34 |
35 | def g(self, z):
36 | x = z[:, self.latent_len:].view([z.shape[0]] + self.x_shape)
37 | z = z[:, :self.latent_len]
38 | for m in reversed(self.modules_):
39 | x, z = m.g(x, z)
40 | return x
41 |
42 |
43 | class ConditionalFlow(Flow):
44 | def f(self, x, y):
45 | z = None
46 | log_det_jac = torch.zeros((x.shape[0],)).to(x.device)
47 | for m in self.modules_:
48 | x, log_det_jac, z = m(x, y, log_det_jac, z)
49 | if z is None:
50 | z = torch.zeros((x.shape[0], 1))[:, :0].to(x.device)
51 | self.x_shape = list(x.shape)[1:]
52 | self.latent_len = z.shape[1]
53 | z = torch.cat([z, x.reshape((x.shape[0], -1))], dim=1)
54 | return log_det_jac, z
55 |
56 | def g(self, z, y):
57 | x = z[:, self.latent_len:].view([z.shape[0]] + self.x_shape)
58 | z = z[:, :self.latent_len]
59 | for m in reversed(self.modules_):
60 | x, z = m.g(x, y, z)
61 | return x
62 |
63 | def forward(self, x, y):
64 | return self.f(x, y)
65 |
66 |
67 | class DiscreteConditionalFlow(ConditionalFlow):
68 | def __init__(self, modules, num_cat, emb_dim):
69 | super().__init__(modules)
70 | self.embeddings = nn.Embedding(num_cat, emb_dim)
71 |
72 | self.embeddings.weight.data.zero_()
73 | l = torch.arange(self.embeddings.weight.data.shape[0])
74 | self.embeddings.weight.data[l, l] = 1.
75 |
76 | def f(self, x, y):
77 | return super().f(x, self.embeddings(y))
78 |
79 | def g(self, z, y):
80 | return super().g(z, self.embeddings(y))
81 |
82 |
83 | class FlowPDF(nn.Module):
84 | def __init__(self, flow, prior):
85 | super().__init__()
86 | self.flow = flow
87 | self.prior = prior
88 |
89 | def log_prob(self, x):
90 | log_det, z = self.flow(x)
91 | return log_det + self.prior.log_prob(z)
92 |
93 |
94 | class DeepConditionalFlowPDF(nn.Module):
95 | def __init__(self, flow, deep_prior, yprior, deep_dim, shallow_prior=None):
96 | super().__init__()
97 | self.flow = flow
98 | self.shallow_prior = shallow_prior
99 | self.deep_prior = deep_prior
100 | self.yprior = yprior
101 | self.deep_dim = deep_dim
102 |
103 | def log_prob(self, x, y):
104 | if x.dim() == 2:
105 | x = x[..., None, None]
106 | if self.deep_dim == x.shape[1]:
107 | log_det, z = self.flow(x, y)
108 | return log_det + self.deep_prior.log_prob(z)
109 | else:
110 | log_det, z = self.flow(x[:, -self.deep_dim:], y)
111 | return log_det + self.deep_prior.log_prob(z) + self.shallow_prior.log_prob(x[:, :-self.deep_dim].squeeze())
112 |
113 | def log_prob_joint(self, x, y):
114 | return self.log_prob(x, y) + self.yprior.log_prob(y)
115 |
116 |
117 | class ConditionalFlowPDF(nn.Module):
118 | def __init__(self, flow, prior, emb=True):
119 | super().__init__()
120 | self.flow = flow
121 | self.prior = prior
122 |
123 | def log_prob(self, x, y):
124 | log_det, z = self.flow(x, y)
125 | return log_det + self.prior.log_prob(z)
126 |
127 |
128 | class DiscreteConditionalFlowPDF(DeepConditionalFlowPDF):
129 | def log_prob_full(self, x):
130 | sup = self.yprior.enumerate_support().to(x.device)
131 | logp = []
132 |
133 | n_uniq = sup.size(0)
134 | y = sup.repeat((x.size(0), 1)).t().reshape((1, -1)).t()[:, 0]
135 | logp = self.log_prob(x.repeat([n_uniq] + [1]*(len(x.shape)-1)), y)
136 | return logp.reshape((n_uniq, x.size(0))).t() + self.yprior.log_prob(sup[None])
137 |
138 | def log_prob(self, x, y=None):
139 | if y is not None:
140 | return super().log_prob(x, y)
141 | else:
142 | logp_joint = self.log_prob_full(x)
143 | return torch.logsumexp(logp_joint, dim=1)
144 |
145 | def log_prob_posterior(self, x):
146 | logp_joint = self.log_prob_full(x)
147 | return logp_joint - torch.logsumexp(logp_joint, dim=1)[:, None]
148 |
149 |
150 | class ResNetBlock(nn.Module):
151 | def __init__(self, channels, use_bn=False):
152 | super().__init__()
153 | modules = []
154 |         if use_bn:
155 |             # ActNorm stands in for BatchNorm here; it must actually be appended
156 |             modules.append(ActNorm(channels, flow=False))
157 | modules += [
158 | nn.ReLU(),
159 | nn.ReflectionPad2d(1),
160 | nn.Conv2d(channels, channels, 3)]
161 |         if use_bn:
162 |             # same fix as above: append the ActNorm instead of discarding it
163 |             modules.append(ActNorm(channels, flow=False))
164 | modules += [
165 | nn.ReLU(),
166 | nn.ReflectionPad2d(1),
167 | nn.Conv2d(channels, channels, 3)]
168 |
169 | self.net = nn.Sequential(*modules)
170 |
171 | def forward(self, x):
172 | return self.net(x) + x
173 |
174 |
175 | class ResNetBlock1x1(nn.Module):
176 | def __init__(self, channels, use_bn=False):
177 | super().__init__()
178 | modules = []
179 |         if use_bn:
180 |             modules.append(ActNorm(channels, flow=False))
181 | modules += [
182 | nn.ReLU(),
183 | nn.Conv2d(channels, channels, 1)]
184 |         if use_bn:
185 |             modules.append(ActNorm(channels, flow=False))
186 | modules += [
187 | nn.ReLU(),
188 | nn.Conv2d(channels, channels, 1)]
189 |
190 | self.net = nn.Sequential(*modules)
191 |
192 | def forward(self, x):
193 | return self.net(x) + x
194 |
195 |
196 | def get_resnet1x1(in_channels, channels, out_channels=None):
197 | if out_channels is None:
198 | out_channels = in_channels * 2
199 | net = nn.Sequential(
200 | nn.Conv2d(in_channels, channels, 1, padding=0),
201 | ResNetBlock1x1(channels, use_bn=True),
202 | ResNetBlock1x1(channels, use_bn=True),
203 | ResNetBlock1x1(channels, use_bn=True),
204 | ResNetBlock1x1(channels, use_bn=True),
205 | ActNorm(channels, flow=False),
206 | nn.ReLU(),
207 | Conv2dZeros(channels, out_channels, 1, padding=0),
208 | )
209 | return net
210 |
211 |
212 | def get_resnet(in_channels, channels, out_channels=None):
213 | if out_channels is None:
214 | out_channels = in_channels * 2
215 | net = nn.Sequential(
216 | nn.ReflectionPad2d(1),
217 | nn.Conv2d(in_channels, channels, 3),
218 | ResNetBlock(channels, use_bn=True),
219 | ResNetBlock(channels, use_bn=True),
220 | ResNetBlock(channels, use_bn=True),
221 | ResNetBlock(channels, use_bn=True),
222 | ActNorm(channels, flow=False),
223 | nn.ReLU(),
224 | nn.ReflectionPad2d(1),
225 | Conv2dZeros(channels, out_channels, 3, 0),
226 | )
227 | return net
228 |
229 |
230 | def get_resnet8(in_channels, channels, out_channels=None):
231 | if out_channels is None:
232 | out_channels = in_channels * 2
233 | net = nn.Sequential(
234 | nn.ReflectionPad2d(1),
235 | nn.Conv2d(in_channels, channels, 3),
236 | ResNetBlock(channels, use_bn=True),
237 | ResNetBlock(channels, use_bn=True),
238 | ResNetBlock(channels, use_bn=True),
239 | ResNetBlock(channels, use_bn=True),
240 | ResNetBlock(channels, use_bn=True),
241 | ResNetBlock(channels, use_bn=True),
242 | ResNetBlock(channels, use_bn=True),
243 | ResNetBlock(channels, use_bn=True),
244 | ActNorm(channels, flow=False),
245 | nn.ReLU(),
246 | nn.ReflectionPad2d(1),
247 | Conv2dZeros(channels, out_channels, 3, 0),
248 | )
249 | return net
250 |
251 |
252 | def netfunc_for_coupling(in_channels, hidden_channels, out_channels, k=3):
253 | def foo():
254 | return nn.Sequential(
255 | nn.Conv2d(in_channels, hidden_channels, k, padding=int(k == 3)),
256 | nn.ReLU(False),
257 | nn.Conv2d(hidden_channels, hidden_channels, 1),
258 | nn.ReLU(False),
259 | Conv2dZeros(hidden_channels, out_channels, k, padding=int(k == 3))
260 | )
261 |
262 | return foo
263 |
264 |
265 | def get_flow(num_layers, k_factor, in_channels=1, hid_dim=[256], conv='full', hh_factors=2,
266 | cond=False, emb_dim=10, n_cat=10, net='shallow'):
267 | modules = [
268 | DummyCond(ToLogits()) if cond else ToLogits(),
269 | ]
270 | channels = in_channels
271 |
272 | if conv == 'full':
273 | convf = lambda x: InvertibleConv2d(x)
274 | elif conv == 'hh':
275 | convf = lambda x: HHConv2d_1x1(x, factors=[x]*hh_factors)
276 | elif conv == 'qr':
277 | convf = lambda x: QRInvertibleConv2d(x, factors=[x]*hh_factors)
278 | elif conv == 'qr-abs':
279 | convf = lambda x: QRInvertibleConv2d(x, factors=[x]*hh_factors, act='no')
280 | elif conv == 'no':
281 | convf = lambda x: IdFunction()
282 | else:
283 | raise NotImplementedError
284 |
285 | if net == 'shallow':
286 | couplingnetf = lambda x, y: netfunc_for_coupling(x, hid_dim[0], y)
287 | elif net == 'resnet':
288 | couplingnetf = lambda x, y: lambda: get_resnet(x, hid_dim[0], out_channels=y)
289 | else:
290 | raise NotImplementedError
291 |
292 | for l in range(num_layers):
293 | # TODO: FIX
294 |         warnings.warn('==== "get_flow" reduces spatial dimensions only 4 times! ====')
295 | if l != 4:
296 | if cond:
297 | modules.append(DummyCond(SpaceToDepth(2)))
298 | else:
299 | modules.append(SpaceToDepth(2))
300 | channels *= 4
301 | for k in range(k_factor):
302 | if cond:
303 | modules.append(DummyCond(ActNorm(channels)))
304 | modules.append(DummyCond(convf(channels)))
305 | modules.append(ConditionalCouplingLayer(couplingnetf(channels // 2 + emb_dim, channels)))
306 | else:
307 | modules.append(ActNorm(channels))
308 | modules.append(convf(channels))
309 | modules.append(CouplingLayer(couplingnetf(channels // 2, channels)))
310 |
311 | if l != num_layers - 1:
312 | if cond:
313 | modules.append(DummyCond(FactorOut()))
314 | else:
315 | modules.append(FactorOut())
316 |
317 | channels //= 2
318 | channels -= channels % 2
319 |
320 | return DiscreteConditionalFlow(modules, n_cat, emb_dim) if cond else Flow(modules)
321 |
322 |
323 | def get_flow_cond(num_layers, k_factor, in_channels=1, hid_dim=256, conv='full', hh_factors=2, num_cat=10, emb_dim=10):
324 | modules = []
325 | channels = in_channels
326 | for l in range(num_layers):
327 | for k in range(k_factor):
328 | modules.append(DummyCondActNorm(channels))
329 | if conv == 'full':
330 | modules.append(DummyCondInvertibleConv2d(channels))
331 | elif conv == 'hh':
332 | modules.append(DummyCond(HHConv2d_1x1(channels, factors=[channels]*hh_factors)))
333 | elif conv == 'qr':
334 | modules.append(DummyCond(QRInvertibleConv2d(channels, factors=[channels]*hh_factors)))
335 | elif conv == 'qr-abs':
336 | modules.append(DummyCond(QRInvertibleConv2d(channels, factors=[channels]*hh_factors, act='no')))
337 | else:
338 | raise NotImplementedError
339 |
340 | netf = lambda: get_resnet1x1(channels//2 + emb_dim, hid_dim, channels)
341 | modules.append(ConditionalCouplingLayer(netf))
342 |
343 | if l != num_layers - 1:
344 | modules.append(CondFactorOut())
345 | channels //= 2
346 | channels -= channels % 2
347 |
348 | return DiscreteConditionalFlow(modules, num_cat, emb_dim)
349 |
350 |
351 | def mnist_flow(num_layers=5, k_factor=4, logits=True, conv='full', hh_factors=2, hid_dim=[32, 784]):
352 | modules = []
353 | if logits:
354 | modules.append(ToLogits())
355 |
356 | channels = 1
357 | hd = hid_dim[0]
358 | kernel = 3
359 | for l in range(num_layers):
360 | if l < 2:
361 | modules.append(SpaceToDepth(2))
362 | channels *= 4
363 | elif l == 2:
364 | modules.append(SpaceToDepth(7))
365 | channels *= 49
366 | hd = hid_dim[1]
367 | kernel = 1
368 |
369 | for k in range(k_factor):
370 | modules.append(ActNorm(channels))
371 | if conv == 'full':
372 | modules.append(InvertibleConv2d(channels))
373 | elif conv == 'hh':
374 | modules.append(HHConv2d_1x1(channels, factors=[channels]*hh_factors))
375 | elif conv == 'qr':
376 | modules.append(QRInvertibleConv2d(channels, factors=[channels]*hh_factors))
377 | elif conv == 'qr-abs':
378 | modules.append(QRInvertibleConv2d(channels, factors=[channels]*hh_factors, act='no'))
379 | else:
380 | raise NotImplementedError
381 |             modules.append(CouplingLayer(netfunc_for_coupling(channels // 2, hd, channels, k=kernel)))
382 |
383 | if l != num_layers - 1:
384 | modules.append(FactorOut())
385 | channels //= 2
386 | channels -= channels % 2
387 |
388 | return Flow(modules)
389 |
390 |
391 | def mnist_masked_glow(conv='full', hh_factors=2):
392 | def get_net(in_channels, channels):
393 | net = nn.Sequential(
394 | nn.ReflectionPad2d(1),
395 | nn.Conv2d(in_channels, channels, 3),
396 | ResNetBlock(channels, use_bn=True),
397 | ResNetBlock(channels, use_bn=True),
398 | ResNetBlock(channels, use_bn=True),
399 | ResNetBlock(channels, use_bn=True),
400 | ActNorm(channels, flow=False),
401 | nn.ReLU(),
402 | nn.ReflectionPad2d(1),
403 | Conv2dZeros(channels, in_channels * 2, 3, 0),
404 | )
405 | return net
406 |
407 | if conv == 'full':
408 | convf = lambda x: InvertibleConv2d(x)
409 | elif conv == 'qr':
410 | convf = lambda x: QRInvertibleConv2d(x, [x]*hh_factors)
411 | elif conv == 'hh':
412 | convf = lambda x: HHConv2d_1x1(x, [x]*hh_factors)
413 | else:
414 | raise NotImplementedError
415 |
416 | modules = [
417 | ToLogits(),
418 | convf(1),
419 | MaskedCouplingLayer([1, 28, 28], 'checkerboard0', get_net(1, 64)),
420 | ActNorm(1),
421 | convf(1),
422 | MaskedCouplingLayer([1, 28, 28], 'checkerboard1', get_net(1, 64)),
423 | ActNorm(1),
424 | convf(1),
425 | MaskedCouplingLayer([1, 28, 28], 'checkerboard0', get_net(1, 64)),
426 | ActNorm(1),
427 | SpaceToDepth(2),
428 | convf(4),
429 | CouplingLayer(lambda: get_net(2, 64)),
430 | ActNorm(4),
431 | convf(4),
432 | CouplingLayer(lambda: get_net(2, 64)),
433 | ActNorm(4),
434 |
435 | FactorOut(),
436 |
437 | convf(2),
438 | MaskedCouplingLayer([2, 14, 14], 'checkerboard0', get_net(2, 64)),
439 | ActNorm(2),
440 | convf(2),
441 | MaskedCouplingLayer([2, 14, 14], 'checkerboard1', get_net(2, 64)),
442 | ActNorm(2),
443 | convf(2),
444 | MaskedCouplingLayer([2, 14, 14], 'checkerboard0', get_net(2, 64)),
445 | ActNorm(2),
446 | SpaceToDepth(2),
447 | convf(8),
448 | CouplingLayer(lambda: get_net(4, 64)),
449 | ActNorm(8),
450 | convf(8),
451 | CouplingLayer(lambda: get_net(4, 64)),
452 | ActNorm(8),
453 |
454 | FactorOut(),
455 |
456 | convf(4),
457 | MaskedCouplingLayer([4, 7, 7], 'checkerboard0', get_net(4, 64)),
458 | ActNorm(4),
459 | convf(4),
460 | MaskedCouplingLayer([4, 7, 7], 'checkerboard1', get_net(4, 64)),
461 | ActNorm(4),
462 | convf(4),
463 | MaskedCouplingLayer([4, 7, 7], 'checkerboard0', get_net(4, 64)),
464 | ActNorm(4),
465 | convf(4),
466 | CouplingLayer(lambda: get_net(2, 64)),
467 | ActNorm(4),
468 | convf(4),
469 | CouplingLayer(lambda: get_net(2, 64)),
470 | ActNorm(4),
471 | ]
472 |
473 | return Flow(modules)
474 |
475 |
476 | def toy2d_flow(conv='full', hh_factors=2, l=5):
477 | def netf():
478 | return nn.Sequential(
479 | nn.Conv2d(1, 64, 1),
480 | nn.LeakyReLU(),
481 | nn.Conv2d(64, 64, 1),
482 | nn.LeakyReLU(),
483 | nn.Conv2d(64, 2, 1)
484 | )
485 |
486 | if conv == 'full':
487 | convf = lambda x: InvertibleConv2d(x)
488 | elif conv == 'qr':
489 | convf = lambda x: QRInvertibleConv2d(x, [x]*hh_factors)
490 | elif conv == 'hh':
491 | convf = lambda x: HHConv2d_1x1(x, [x]*hh_factors)
492 | else:
493 | raise NotImplementedError
494 |
495 | modules = []
496 | for _ in range(l):
497 | modules.append(convf(2))
498 | modules.append(CouplingLayer(netf))
499 | modules.append(ActNorm(2))
500 | return Flow(modules)
501 |
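502 | if __name__ == '__main__':
503 |     # Round-trip sketch (not part of the original file): for the toy 2-d flow,
504 |     # g(f(x)) should reconstruct x up to float error after the ActNorm init.
505 |     torch.manual_seed(0)
506 |     flow = toy2d_flow(l=2)
507 |     x = torch.randn(16, 2, 1, 1)
508 |     with torch.no_grad():
509 |         log_det, z = flow.f(x)
510 |         x_rec = flow.g(z)
511 |     print('max reconstruction error:', (x - x_rec).abs().max().item())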
--------------------------------------------------------------------------------
/models/invertconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 |
7 | def num_pixels(tensor):
8 | assert tensor.dim() == 4
9 | return tensor.size(2) * tensor.size(3)
10 |
11 |
12 | class _BaseInvertibleConv2d(nn.Module):
13 | def _get_w(self):
14 | raise NotImplementedError
15 |
16 | def _get_w_inv(self):
17 | raise NotImplementedError
18 |
19 | def _log_det(self, x=None):
20 | raise NotImplementedError
21 |
22 | def forward(self, x, log_det_jac, z):
23 | log_det_jac += self._log_det(x)
24 | W = self._get_w().to(x)
25 | W = W.unsqueeze(-1).unsqueeze(-1)
26 | return F.conv2d(x, W), log_det_jac, z
27 |
28 | def g(self, x, z):
29 | W = self._get_w_inv()
30 | W = W.unsqueeze(-1).unsqueeze(-1)
31 | return F.conv2d(x, W), z
32 |
33 |
34 | class InvertibleConv2d(_BaseInvertibleConv2d):
35 | '''
36 | Diederik P. Kingma, Prafulla Dhariwal
37 | "Glow: Generative Flow with Invertible 1x1 Convolutions"
38 | https://arxiv.org/pdf/1807.03039.pdf
39 | '''
40 | def __init__(self, features):
41 | super().__init__()
42 | self.features = features
43 | self.W = nn.Parameter(torch.Tensor(features, features))
44 | self.reset_parameters()
45 |
46 | def reset_parameters(self):
47 | # nn.init.orthogonal_(self.W)
48 | self.W.data = torch.eye(self.features).to(self.W.device)
49 |
50 | def _get_w(self):
51 | return self.W
52 |
53 | def _get_w_inv(self):
54 | return torch.inverse(self.W.double()).float()
55 |
56 | def _log_det(self, x):
57 | return torch.slogdet(self.W.double())[1].float() * num_pixels(x)
58 |
59 | def extra_repr(self):
60 | return 'InvertibleConv2d({:d})'.format(self.features)
61 |
62 |
63 | def householder_matrix(v, size=None):
64 | """
65 | householder_matrix(Tensor, size=None) -> Tensor
66 |
67 | Arguments
68 | v: Tensor of size [Any,]
69 | size: `int` or `None`. The size of the resulting matrix.
70 | size >= v.size(0)
71 | Output
72 |         I - 2 * v * v^T / (v^T * v): Tensor of size [size, size]
73 | """
74 | size = size or v.size(0)
75 | assert size >= v.size(0)
76 | v = torch.cat([torch.ones(size - v.size(0), device=v.device), v])
77 | I = torch.eye(size, device=v.device)
78 | outer = torch.ger(v, v)
79 | inner = torch.dot(v, v) + 1e-16
80 | return I - 2 * outer / inner
81 |
82 |
83 | def naive_cascade(vectors, size=None):
84 | """
85 | naive_cascade([Tensor, Tensor, ...], size=None) -> Tensor
86 | naive implementation
87 |
88 | Arguments
89 | vectors: list of Tensors of size [Any,]
90 | size: `int` or `None`. The size of the resulting matrix.
91 | size >= max(v.size(0) for v in vectors)
92 | Output
93 | Q: `torch.Tensor` of size [size, size]
94 | """
95 | size = size or max(v.size(0) for v in vectors)
96 | assert size >= max(v.size(0) for v in vectors)
97 | device = vectors[0].device
98 | Q = torch.eye(size, device=device)
99 | for v in vectors:
100 | Q = torch.mm(Q, householder_matrix(v, size=size))
101 | return Q
102 |
103 |
104 | class HHConv2d_1x1(_BaseInvertibleConv2d):
105 | def __init__(self, features, factors=None):
106 | super().__init__()
107 | self.features = features
108 | self.factors = factors or range(2, features + 1)
109 |
110 | # init vectors
111 | self.vectors = []
112 | for i, f in enumerate(self.factors):
113 | vec = nn.Parameter(torch.Tensor(f))
114 | self.register_parameter('vec_{}'.format(i), vec)
115 | self.vectors.append(vec)
116 |
117 | self.reset_parameters()
118 | self.cascade = naive_cascade
119 |
120 | def reset_parameters(self):
121 | for v in self.vectors:
122 | v.data.uniform_(-1, 1)
123 | with torch.no_grad():
124 | v /= (torch.norm(v) + 1e-16)
125 |
126 | def _get_w(self):
127 | return self.cascade(self.vectors, self.features)
128 |
129 | def _get_w_inv(self):
130 | return self._get_w().t()
131 |
132 | def _log_det(self, x):
133 | return 0.
134 |
135 |
136 | class QRInvertibleConv2d(HHConv2d_1x1):
137 | """
138 | Hoogeboom, Emiel and Berg, Rianne van den and Welling, Max
139 | "Emerging Convolutions for Generative Normalizing Flows"
140 | https://arxiv.org/pdf/1901.11137.pdf
141 | """
142 | def __init__(self, features, factors=None, act='softplus'):
143 | super().__init__(features, factors=factors)
144 | self.act = act
145 | if act == 'softplus':
146 | self.s_factor = nn.Parameter(torch.zeros((features,)))
147 | elif act == 'no':
148 | self.s_factor = nn.Parameter(torch.ones((features,)))
149 | else:
150 | raise NotImplementedError
151 |
152 | self.r = nn.Parameter(torch.zeros((features, features)))
153 |
154 | def _get_w(self):
155 | Q = super()._get_w()
156 | if self.act == 'softplus':
157 | R = torch.diag(F.softplus(self.s_factor))
158 | elif self.act == 'no':
159 | R = torch.diag(self.s_factor)
160 |
161 | R += torch.triu(self.r, diagonal=1)
162 | return Q.to(R) @ R
163 |
164 | def _log_det(self, x=None):
165 | if self.act == 'softplus':
166 | return torch.log(F.softplus(self.s_factor)).sum() * num_pixels(x)
167 | elif self.act == 'no':
168 | return torch.log(torch.abs(self.s_factor)).sum() * num_pixels(x)
169 |
170 | def _get_w_inv(self):
171 | Q = super()._get_w().to(self.s_factor)
172 | if self.act == 'softplus':
173 | R = torch.diag(F.softplus(self.s_factor))
174 | elif self.act == 'no':
175 | R = torch.diag(self.s_factor)
176 |
177 | R += torch.triu(self.r, diagonal=1)
178 | return torch.inverse(R.double()).float() @ Q.t()
179 |
180 |
181 | class DummyCondInvertibleConv2d(InvertibleConv2d):
182 | def forward(self, x, y, log_det_jac, z):
183 | return super().forward(x, log_det_jac, z)
184 |
185 | def g(self, x, y, z):
186 | return super().g(x, z)
187 |
--------------------------------------------------------------------------------
/models/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def num_pixels(tensor):
7 | assert tensor.dim() == 4
8 | return tensor.size(2) * tensor.size(3)
9 |
10 |
11 | class ActNorm(nn.Module):
12 | def __init__(self, num_features, eps=1e-5, flow=True):
13 | super().__init__()
14 | self.num_features = num_features
15 | self.eps = eps
16 | self.logs = nn.Parameter(torch.Tensor(num_features))
17 | self.bias = nn.Parameter(torch.Tensor(num_features))
18 | self.requires_init = nn.Parameter(torch.ByteTensor(1), requires_grad=False)
19 | self.reset_parameters()
20 | self.flow = flow
21 |
22 | def reset_parameters(self):
23 | self.logs.data.zero_()
24 | self.bias.data.zero_()
25 | self.requires_init.data.fill_(True)
26 |
27 | def init_data_dependent(self, x):
28 | with torch.no_grad():
29 | x_ = x.transpose(0, 1).contiguous().view(self.num_features, -1)
30 | mean = x_.mean(1)
31 | var = x_.var(1)
32 | logs = -torch.log(torch.sqrt(var) + 1e-6)
33 | self.logs.data.copy_(logs.data)
34 | self.bias.data.copy_(mean.data)
35 |
36 | def forward(self, x, log_det_jac=None, z=None):
37 | assert x.size(1) == self.num_features
38 | if self.requires_init:
39 | self.requires_init.data.fill_(False)
40 | self.init_data_dependent(x)
41 |
42 | size = [1] * x.ndimension()
43 | size[1] = self.num_features
44 | x = (x - self.bias.view(*size)) * torch.exp(self.logs.view(*size))
45 | if not self.flow:
46 | return x
47 | log_det_jac += self.logs.sum() * num_pixels(x)
48 | return x, log_det_jac, z
49 |
50 | def g(self, x, z):
51 | size = [1] * x.ndimension()
52 | size[1] = self.num_features
53 | x = x * torch.exp(-self.logs.view(*size)) + self.bias.view(*size)
54 | return x, z
55 |
56 | def inverse(self, x):
57 | size = [1] * x.ndimension()
58 | size[1] = self.num_features
59 | x = x * torch.exp(-self.logs.view(*size)) + self.bias.view(*size)
60 | return x
61 |
62 | def log_det(self):
63 | return self._log_det
64 |
65 | def extra_repr(self):
66 | return 'ActNorm({}, requires_init={})'.format(self.num_features, bool(self.requires_init.item()))
67 |
68 |
69 | class DummyCondActNorm(ActNorm):
70 | def forward(self, x, y, log_det_jac=None, z=None):
71 | return super().forward(x, log_det_jac, z)
72 |
73 | def g(self, x, y, z):
74 | return super().g(x, z)
75 |
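76 | if __name__ == '__main__':
77 |     # Sketch (not in the original file): after the data-dependent init, the
78 |     # first forward pass should give roughly zero-mean, unit-variance channels.
79 |     torch.manual_seed(0)
80 |     norm = ActNorm(3)
81 |     x = torch.randn(64, 3, 8, 8) * 2. + 1.
82 |     y, _, _ = norm(x, torch.zeros(64), None)
83 |     print('per-channel mean:', y.mean((0, 2, 3)))
84 |     print('per-channel std:', y.std((0, 2, 3)))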
--------------------------------------------------------------------------------
/models/realnvp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import distributions
4 | from torch.nn.parameter import Parameter
5 | import numpy as np
6 | from torch.nn import functional as F
7 | from models.normalization import ActNorm
8 |
9 |
10 | class RealNVPold(nn.Module):
11 | def __init__(self, nets, nett, masks, prior, device=None):
12 | super().__init__()
13 |
14 | self.prior = prior
15 | self.mask = nn.Parameter(masks, requires_grad=False)
16 | self.t = torch.nn.ModuleList([nett() for _ in range(len(masks))])
17 | self.s = torch.nn.ModuleList([nets() for _ in range(len(masks))])
18 |
19 | self.to(device)
20 | self.device = device
21 |
22 | def g(self, z):
23 | x = z
24 | for i in range(len(self.t)):
25 | x_ = x*self.mask[i]
26 | s = self.s[i](x_)*(1 - self.mask[i])
27 | t = self.t[i](x_)*(1 - self.mask[i])
28 | x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
29 | return x
30 |
31 | def f(self, x):
32 | log_det_J, z = x.new_zeros(x.shape[0]), x
33 | for i in reversed(range(len(self.t))):
34 | z_ = self.mask[i] * z
35 | s = self.s[i](z_) * (1-self.mask[i])
36 | t = self.t[i](z_) * (1-self.mask[i])
37 | z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
38 | if x.dim() == 2:
39 | log_det_J -= s.sum(dim=1)
40 | else:
41 | log_det_J -= s.sum(dim=(1, 2, 3))
42 |
43 | return z, log_det_J
44 |
45 | def log_prob(self, x):
46 | z, logp = self.f(x)
47 | return self.prior.log_prob(z) + logp
48 |
49 | def sample(self, batchSize):
50 | z = self.prior.sample((batchSize, 1))
51 | logp = self.prior.log_prob(z)
52 | x = self.g(z)
53 | return x
54 |
55 |
56 | def get_toy_nvp(prior=None, device=None):
57 | def nets():
58 | return nn.Sequential(nn.Linear(2, 32),
59 | nn.LeakyReLU(),
60 | nn.Linear(32, 2),
61 | nn.Tanh()
62 | )
63 |
64 | def nett():
65 | return nn.Sequential(nn.Linear(2, 32),
66 | nn.LeakyReLU(),
67 | nn.Linear(32, 2)
68 | )
69 |
70 | if prior is None:
71 | prior = distributions.MultivariateNormal(torch.zeros(2).to(device),
72 | torch.eye(2).to(device))
73 |
74 | masks = torch.from_numpy(np.array([[0, 1], [1, 0]] * 3).astype(np.float32))
75 | return RealNVPold(nets, nett, masks, prior, device=device)
76 |
77 |
78 | class NFGMM(RealNVPold):
79 | def log_prob(self, x, k=None):
80 | if k is None:
81 | z, logp = self.f(x)
82 | return self.prior.log_prob(z) + logp
83 | else:
84 | z, logp = self.f(x)
85 | return self.prior.log_prob(z, k=k) + logp
86 |
87 |
88 | def gmm_prior(k, dim=2, normalize=False):
89 |     from models.distributions import GMM  # torch.distributions has no GMM; the old body also used an undefined global 'args'
90 |     covars = torch.rand(k, dim, dim)
91 |     covars = torch.matmul(covars, covars.transpose(1, 2))
92 |     return GMM(means=torch.randn(k, dim), covariances=covars, weights=torch.FloatTensor([1. / k] * k), normalize=normalize)
93 |
94 |
95 | class WNConv2d(nn.Conv2d):
96 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
97 | padding=0, dilation=1, groups=1, bias=True):
98 | super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
99 | dilation=dilation, groups=groups, bias=bias)
100 | self.scale = nn.Parameter(torch.ones((1,)))
101 | self.scale.reg = True
102 | self.eps = 1e-6
103 |
104 | def forward(self, input):
105 | w = self.weight / (torch.norm(self.weight) + self.eps) * self.scale
106 | return F.conv2d(input, w, self.bias, self.stride,
107 | self.padding, self.dilation, self.groups)
108 |
109 |
110 | class MABatchNorm2d(nn.BatchNorm2d):
111 | def forward(self, input):
112 |         # Code from the PyTorch repo:
113 | self._check_input_dim(input)
114 |
115 | exponential_average_factor = 0.0
116 |
117 | if self.training and self.track_running_stats:
118 | # TODO: if statement only here to tell the jit to skip emitting this when it is None
119 | if self.num_batches_tracked is not None:
120 | self.num_batches_tracked += 1
121 | if self.momentum is None: # use cumulative moving average
122 | exponential_average_factor = 1.0 / float(self.num_batches_tracked)
123 | else: # use exponential moving average
124 | exponential_average_factor = self.momentum
125 |
126 | # --- My code ---
127 |
128 | if self.training:
129 | mean = torch.mean(input, (0, 2, 3), keepdim=True)
130 | var = torch.mean((input - mean)**2, (0, 2, 3), keepdim=True)
131 | self.running_mean.data = self.running_mean.data * (1 - self.momentum) + self.momentum * mean.data.squeeze()
132 | self.running_var.data = self.running_var.data * (1 - self.momentum) + self.momentum * var.data.squeeze()
133 | mean = mean * self.momentum + (1 - self.momentum) * self.running_mean[None, :, None, None]
134 | var = var * self.momentum + (1 - self.momentum) * self.running_var[None, :, None, None]
135 | else:
136 | mean = self.running_mean[None, :, None, None]
137 | var = self.running_var[None, :, None, None]
138 |
139 |         input = (input - mean) / torch.sqrt(var + self.eps)
140 | input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
141 |
142 | return input
143 |
144 |
145 | class ResNetBlock(nn.Module):
146 | def __init__(self, channels, use_bn=True):
147 | super().__init__()
148 | modules = []
149 | if use_bn:
150 | modules.append(nn.BatchNorm2d(channels))
151 | modules += [
152 | nn.ReLU(),
153 | nn.ReflectionPad2d(1),
154 | WNConv2d(channels, channels, 3)]
155 | if use_bn:
156 | modules.append(nn.BatchNorm2d(channels))
157 | modules += [
158 | nn.ReLU(),
159 | nn.ReflectionPad2d(1),
160 | WNConv2d(channels, channels, 3)]
161 |
162 | self.net = nn.Sequential(*modules)
163 |
164 | def forward(self, x):
165 | return self.net(x) + x
166 |
167 |
168 | class SplitAndNorm(nn.Module):
169 | def __init__(self):
170 | super().__init__()
171 | self.scale = nn.Parameter(data=torch.FloatTensor([1.]))
172 | self.scale.reg = True
173 |
174 | def forward(self, x):
175 | k = x.shape[1] // 2
176 | s, t = x[:, :k], x[:, k:]
177 | return torch.tanh(s) * self.scale, t
178 |
179 |
180 | def get_mask(xs, mask_type):
181 | if 'checkerboard' in mask_type:
182 | unit0 = np.array([[0.0, 1.0], [1.0, 0.0]])
183 | unit1 = -unit0 + 1.0
184 | unit = unit0 if mask_type == 'checkerboard0' else unit1
185 | unit = np.reshape(unit, [1, 2, 2])
186 | b = np.tile(unit, [xs[0], xs[1]//2, xs[2]//2])
187 | elif 'channel' in mask_type:
188 | white = np.ones([xs[0]//2, xs[1], xs[2]])
189 | black = np.zeros([xs[0]//2, xs[1], xs[2]])
190 | if mask_type == 'channel0':
191 | b = np.concatenate([white, black], 0)
192 | else:
193 | b = np.concatenate([black, white], 0)
194 |
195 | assert list(b.shape) == list(xs)
196 |
197 | return b
198 |
199 |
200 | class CouplingLayer(nn.Module):
201 | def __init__(self, mask_type, shape, net):
202 | super().__init__()
203 | mask = torch.FloatTensor(get_mask(shape, mask_type))
204 | self.mask = nn.Parameter(mask[None], requires_grad=False)
205 | self.net = net
206 |
207 | def forward(self, x, log_det_jac, z):
208 | return self.f(x, log_det_jac, z)
209 |
210 | def f(self, x, log_det_jac, z):
211 | x1 = self.mask * x
212 | s, t = self.net(x1)
213 | s = (1 - self.mask) * s
214 | t = (1 - self.mask) * t
215 | x = x1 + (1 - self.mask) * (x * torch.exp(s) + t)
216 | log_det_jac += torch.sum(s, dim=(1, 2, 3))
217 | return x, log_det_jac, z
218 |
219 | def g(self, x, z):
220 | x1 = self.mask * x
221 | s, t = self.net(x1)
222 | x = x1 + (1 - self.mask) * (x - t) * torch.exp(-s)
223 | return x, z
224 |
225 |
226 | class Invertable1x1Conv(nn.Module):
227 | # reference https://github.com/openai/glow/blob/eaff2177693a5d84a1cf8ae19e8e0441715b82f8/model.py#L438
228 | def __init__(self, channels):
229 | super().__init__()
230 | # Sample a random orthogonal matrix
231 | w_init = np.linalg.qr(np.random.randn(channels, channels))[0]
232 | self.weight = nn.Parameter(torch.FloatTensor(w_init))
233 |
234 | def forward(self, x, log_det_jac, z):
235 | x = F.conv2d(x, self.weight[:, :, None, None])
236 |         log_det_jac += torch.slogdet(self.weight)[1] * np.prod(x.shape[2:])  # log|det W| per spatial location
237 | return x, log_det_jac, z
238 |
239 | def g(self, x, z):
240 | x = F.conv2d(x, torch.inverse(self.weight)[:, :, None, None])
241 | return x, z
242 |
243 |
244 | class Housholder1x1Conv(nn.Module):
245 | def __init__(self, channels):
246 | super().__init__()
247 | self.v = nn.Parameter(torch.ones((channels,)))
248 | self.id = nn.Parameter(torch.eye(channels), requires_grad=False)
249 | self.channels = channels
250 |
251 | def forward(self, x, log_det_jac, z):
252 | v = self.v
253 | w = self.id - 2 * v[:, None] @ v[None] / (v @ v)
254 | x = F.conv2d(x, w[..., None, None])
255 |         # w is a Householder reflection (orthogonal, |det| = 1), so log|det| = 0
256 | return x, log_det_jac, z
257 |
258 | def g(self, x, z):
259 | v = self.v
260 | w = self.id - 2 * v[:, None] @ v[None] / (v @ v)
261 | x = F.conv2d(x, w[..., None, None])
262 | return x, z
263 |
264 |
265 | class Prior(nn.Module):
266 | def __init__(self, dim):
267 | super().__init__()
268 | self.mean = nn.Parameter(torch.zeros((dim,)), requires_grad=False)
269 | self.cov = nn.Parameter(torch.eye(dim), requires_grad=False)
270 |
271 | def log_prob(self, x):
272 | p = torch.distributions.MultivariateNormal(self.mean, self.cov)
273 | return p.log_prob(x)
274 |
275 |
276 | class RealNVP(nn.Module):
277 | def __init__(self, modules, dim):
278 | super().__init__()
279 | self.modules_ = nn.ModuleList(modules)
280 | self.latent_len = -1
281 | self.x_shape = -1
282 | self.prior = Prior(dim)
283 | self.alpha = 0.05
284 |
285 | def f(self, x):
286 | x = x * (1 - self.alpha) + self.alpha * 0.5
287 | log_det_jac = torch.sum(-torch.log(x) - torch.log(1-x) + np.log(1 - self.alpha), dim=[1, 2, 3])
288 | x = torch.log(x) - torch.log(1-x)
289 |
290 | z = None
291 | for m in self.modules_:
292 | x, log_det_jac, z = m(x, log_det_jac, z)
293 | if z is None:
294 |             z = torch.zeros((x.shape[0], 0), device=x.device)  # no factored-out latents yet
295 | self.x_shape = list(x.shape)[1:]
296 | self.latent_len = z.shape[1]
297 | z = torch.cat([z, x.reshape((x.shape[0], -1))], dim=1)
298 | return x, log_det_jac, z
299 |
300 | def forward(self, x):
301 | return self.log_prob(x)
302 |
303 | def g(self, z):
304 | x = z[:, self.latent_len:].view([z.shape[0]] + self.x_shape)
305 | z = z[:, :self.latent_len]
306 | for m in reversed(self.modules_):
307 | x, z = m.g(x, z)
308 | x = torch.sigmoid(x)
309 | x = (x - self.alpha * 0.5) / (1. - self.alpha)
310 | return x
311 |
312 | def log_prob(self, x):
313 | x, log_det_jac, z = self.f(x)
314 | logp = self.prior.log_prob(z) + log_det_jac
315 | return logp
316 |
317 |
318 | def get_cifar_realnvp():
319 | dim = 32**2 * 3
320 | channels = 64
321 |
322 | def get_net(in_channels, channels):
323 | net = nn.Sequential(
324 | nn.ReflectionPad2d(1),
325 | WNConv2d(in_channels, channels, 3),
326 | ResNetBlock(channels),
327 | ResNetBlock(channels),
328 | ResNetBlock(channels),
329 | ResNetBlock(channels),
330 | ResNetBlock(channels),
331 | ResNetBlock(channels),
332 | ResNetBlock(channels),
333 | ResNetBlock(channels),
334 | nn.BatchNorm2d(channels),
335 | nn.ReLU(),
336 | nn.ReflectionPad2d(1),
337 | WNConv2d(channels, in_channels * 2, 3),
338 | SplitAndNorm()
339 | )
340 | for m in net.modules():
341 | if isinstance(m, nn.Conv2d):
342 |                 m.weight.data.normal_(std=1e-6)
343 |         net[-1].scale.data.fill_(1e-5)  # the trailing SplitAndNorm's scale
344 | return net
345 |
346 | model = [
347 | CouplingLayer('checkerboard0', [3, 32, 32], get_net(3, channels)),
348 | CouplingLayer('checkerboard1', [3, 32, 32], get_net(3, channels)),
349 | CouplingLayer('checkerboard0', [3, 32, 32], get_net(3, channels)),
350 | SpaceToDepth(2),
351 | CouplingLayer('channel0', [12, 16, 16], get_net(12, channels)),
352 | CouplingLayer('channel1', [12, 16, 16], get_net(12, channels)),
353 | CouplingLayer('channel0', [12, 16, 16], get_net(12, channels)),
354 | FactorOut([12, 16, 16]),
355 | CouplingLayer('checkerboard0', [6, 16, 16], get_net(6, channels)),
356 | CouplingLayer('checkerboard1', [6, 16, 16], get_net(6, channels)),
357 | CouplingLayer('checkerboard0', [6, 16, 16], get_net(6, channels)),
358 | CouplingLayer('checkerboard1', [6, 16, 16], get_net(6, channels)),
359 | ]
360 | realnvp = RealNVP(model, dim)
361 | return realnvp
362 |
363 |
364 | def get_mnist_realnvp():
365 | dim = 28**2
366 | channels = 32
367 |
368 | def get_net(in_channels, channels):
369 | net = nn.Sequential(
370 | nn.ReflectionPad2d(1),
371 | WNConv2d(in_channels, channels, 3),
372 | ResNetBlock(channels),
373 | ResNetBlock(channels),
374 | ResNetBlock(channels),
375 | nn.BatchNorm2d(channels),
376 | nn.ReLU(),
377 | nn.ReflectionPad2d(1),
378 | WNConv2d(channels, in_channels * 2, 3),
379 | SplitAndNorm()
380 | )
381 | for m in net.modules():
382 | if isinstance(m, nn.Conv2d):
383 |                 m.weight.data.normal_(std=1e-6)
384 |         net[-1].scale.data.fill_(1e-5)  # the trailing SplitAndNorm's scale
385 | return net
386 |
387 | model = [
388 | CouplingLayer('checkerboard0', [1, 28, 28], get_net(1, channels)),
389 | CouplingLayer('checkerboard1', [1, 28, 28], get_net(1, channels)),
390 | CouplingLayer('checkerboard0', [1, 28, 28], get_net(1, channels)),
391 | SpaceToDepth(2),
392 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)),
393 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)),
394 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)),
395 | FactorOut([4, 14, 14]),
396 | CouplingLayer('checkerboard0', [2, 14, 14], get_net(2, channels)),
397 | CouplingLayer('checkerboard1', [2, 14, 14], get_net(2, channels)),
398 | CouplingLayer('checkerboard0', [2, 14, 14], get_net(2, channels)),
399 | CouplingLayer('checkerboard1', [2, 14, 14], get_net(2, channels)),
400 | ]
401 | realnvp = RealNVP(model, dim)
402 | return realnvp
403 |
404 |
405 | class ConcatNet(nn.Module):
406 | def __init__(self, in_channels, channels):
407 | super().__init__()
408 | self.net1 = nn.Sequential(
409 | nn.Conv2d(in_channels, channels, 3, padding=1),
410 | nn.ReLU(True),
411 | nn.Conv2d(channels, channels, 1),
412 | nn.ReLU(True),
413 | nn.Conv2d(channels, in_channels, 3, padding=1),
414 | )
415 | self.net2 = nn.Sequential(
416 | nn.Conv2d(in_channels, channels, 3, padding=1),
417 | nn.ReLU(True),
418 | nn.Conv2d(channels, channels, 1),
419 | nn.ReLU(True),
420 | nn.Conv2d(channels, in_channels, 3, padding=1),
421 | )
422 | self.split = SplitAndNorm()
423 |
424 | def forward(self, x):
425 | x = torch.cat([self.net1(x), self.net2(x)], dim=1)
426 | return self.split(x)
427 |
428 |
429 | def get_pie(channels=32):
430 | def get_net(in_channels, channels):
431 | net = ConcatNet(in_channels, channels)
432 | for m in net.modules():
433 | if isinstance(m, nn.Conv2d):
434 | m.weight.data.normal_(std=1e-6)
435 | m.bias.data.fill_(0.)
436 | return net
437 |
438 | model = [
439 | SpaceToDepth(2),
440 | Housholder1x1Conv(4),
441 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)),
442 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)),
443 | ActNorm(4),
444 | Housholder1x1Conv(4),
445 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)),
446 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)),
447 | ActNorm(4),
448 | Housholder1x1Conv(4),
449 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)),
450 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)),
451 | ActNorm(4),
452 | Housholder1x1Conv(4),
453 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)),
454 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)),
455 | ActNorm(4),
456 | Housholder1x1Conv(4),
457 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)),
458 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)),
459 | ActNorm(4),
460 | Housholder1x1Conv(4),
461 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)),
462 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)),
463 | ActNorm(4),
464 | Housholder1x1Conv(4),
465 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)),
466 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)),
467 | ActNorm(4),
468 | Housholder1x1Conv(4),
469 | CouplingLayer('channel0', [4, 14, 14], get_net(4, channels)),
470 | CouplingLayer('channel1', [4, 14, 14], get_net(4, channels)),
471 | ActNorm(4),
472 | Housholder1x1Conv(4),
473 | SpaceToDepth(14),
474 | Housholder1x1Conv(784),
475 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
476 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
477 | ActNorm(784),
478 | Housholder1x1Conv(784),
479 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
480 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
481 | ActNorm(784),
482 | Housholder1x1Conv(784),
483 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
484 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
485 | ActNorm(784),
486 | Housholder1x1Conv(784),
487 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
488 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
489 | ActNorm(784),
490 | Housholder1x1Conv(784),
491 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
492 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
493 | ActNorm(784),
494 | Housholder1x1Conv(784),
495 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
496 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
497 | ActNorm(784),
498 | Housholder1x1Conv(784),
499 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
500 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
501 | ActNorm(784),
502 | Housholder1x1Conv(784),
503 | CouplingLayer('channel0', [784, 1, 1], get_net(784, channels)),
504 | CouplingLayer('channel1', [784, 1, 1], get_net(784, channels)),
505 | ActNorm(784),
506 | Housholder1x1Conv(784),
507 | ]
508 |
509 | dim = 784
510 | realnvp = RealNVP(model, dim)
511 | return realnvp
512 |
513 |
514 | def get_realnvp(k, l, in_shape, channels, use_bn=False):
515 | dim = int(np.prod(in_shape))
516 |
517 | def get_net(in_channels, channels):
518 | net = nn.Sequential(
519 | nn.ReflectionPad2d(1),
520 | WNConv2d(in_channels, channels, 3),
521 | ResNetBlock(channels, use_bn),
522 | ResNetBlock(channels, use_bn),
523 | ResNetBlock(channels, use_bn),
524 | nn.BatchNorm2d(channels),
525 | nn.ReLU(),
526 | nn.ReflectionPad2d(1),
527 | WNConv2d(channels, in_channels * 2, 3),
528 | SplitAndNorm()
529 | )
530 | for m in net.modules():
531 | if isinstance(m, nn.Conv2d):
532 |                 m.weight.data.normal_(std=1e-6)
533 |         net[-1].scale.data.fill_(1e-5)  # the trailing SplitAndNorm's scale
534 | return net
535 |
536 | shape = tuple(in_shape)
537 | model = []
538 | for _ in range(l):
539 | for i in range(k):
540 | model.append(Housholder1x1Conv(shape[0]))
541 | model.append(CouplingLayer('checkerboard{}'.format(i % 2), shape, get_net(shape[0], channels)))
542 | model += [SpaceToDepth(2)]
543 | shape = (shape[0] * 4, shape[1] // 2, shape[2] // 2)
544 | for i in range(k):
545 | model.append(Housholder1x1Conv(shape[0]))
546 | model.append(CouplingLayer('channel{}'.format(i % 2), shape, get_net(shape[0], channels)))
547 | model += [FactorOut(list(shape))]
548 | shape = (shape[0] // 2, shape[1], shape[2])
549 |
550 | model += [
551 | CouplingLayer('checkerboard0', shape, get_net(shape[0], channels)),
552 | CouplingLayer('checkerboard1', shape, get_net(shape[0], channels)),
553 | CouplingLayer('checkerboard0', shape, get_net(shape[0], channels)),
554 | CouplingLayer('checkerboard1', shape, get_net(shape[0], channels)),
555 | ]
556 | realnvp = RealNVP(model, dim)
557 | return realnvp
558 |
559 |
560 | class MyModel(nn.Module):
561 | def __init__(self, flow, prior):
562 | super().__init__()
563 | self.flow = flow
564 | self.prior = prior
565 |
566 |     def _flow_term(self, x):
567 |         # returns the flow's log-det term and the latent code z
568 |         _, log_det, z = self.flow([x, None, None])
569 | 
570 |         # TODO: get rid of this
571 |         if z.numel() != x.numel():
572 |             log_det += self.flow.pie.residual()
573 | 
574 |         return log_det, z
575 | 
576 |     def log_prob(self, x):
577 |         log_det, z = self._flow_term(x)
578 |         return log_det + self.prior.log_prob(z)
579 | 
580 |     def log_prob_full(self, x):
581 |         log_det, z = self._flow_term(x)
582 |         log_prior = torch.stack([self.prior.log_prob(z, k=k) for k in range(self.prior.k)])
583 |         return log_det[:, None] + log_prior.transpose(0, 1)
584 |
585 |
586 | class MyPie(nn.Module):
587 | def __init__(self, pie):
588 | super().__init__()
589 | self.pie = pie
590 |
591 | def forward(self, x):
592 | x, _, _ = x
593 | z = self.pie(x)
594 | return None, self.pie.log_det(), z
595 |
--------------------------------------------------------------------------------
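
A quick, hedged sanity check for the blocks above (a sketch, not part of the repo): it wraps a stand-in conditioner in a `CouplingLayer` and verifies that `g` inverts `f`. `ToyNet` is hypothetical; any module returning an `(s, t)` pair, as `SplitAndNorm` does, will do.

```python
import torch
from torch import nn
from models.realnvp import CouplingLayer  # assumed import path

class ToyNet(nn.Module):
    """Hypothetical stand-in for get_net(...): returns (scale, shift)."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, 2 * channels, 3, padding=1)

    def forward(self, x):
        s, t = self.conv(x).chunk(2, dim=1)
        return torch.tanh(s), t  # bounded scale, as in SplitAndNorm

layer = CouplingLayer('checkerboard0', [1, 28, 28], ToyNet(1))
x = torch.randn(4, 1, 28, 28)
y, log_det, _ = layer.f(x, torch.zeros(4), None)
x_rec, _ = layer.g(y, None)
assert torch.allclose(x, x_rec, atol=1e-5)  # g undoes f
```
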
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.distributions as dist
5 | import torch.nn.functional as F
6 |
7 |
8 | class Conv2dZeros(nn.Conv2d):
9 | def __init__(self, in_channels, out_channels, kernel_size, padding):
10 | super().__init__(in_channels, out_channels, kernel_size, padding=padding)
11 | self.weight.data.zero_()
12 | self.bias.data.zero_()
13 |
14 |
15 | class SpaceToDepth(nn.Module):
16 | def __init__(self, block_size):
17 | super().__init__()
18 | self.block_size = block_size
19 | self.block_size_sq = block_size*block_size
20 |
21 | def forward(self, input, *inputs):
22 | output = input.permute(0, 2, 3, 1)
23 | (batch_size, s_height, s_width, s_depth) = output.size()
24 | d_depth = s_depth * self.block_size_sq
25 | d_width = int(s_width / self.block_size)
26 | d_height = int(s_height / self.block_size)
27 | t_1 = output.split(self.block_size, 2)
28 | stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1]
29 | output = torch.stack(stack, 1)
30 | output = output.permute(0, 2, 1, 3)
31 | output = output.permute(0, 3, 1, 2)
32 | return [output] + list(inputs)
33 |
34 | def g(self, input, *inputs):
35 | output = input.permute(0, 2, 3, 1)
36 | (batch_size, input_height, input_width, input_depth) = output.size()
37 | output_depth = int(input_depth / self.block_size_sq)
38 | output_width = int(input_width * self.block_size)
39 | output_height = int(input_height * self.block_size)
40 | t_1 = output.reshape(batch_size, input_height, input_width, self.block_size_sq, output_depth)
41 | spl = t_1.split(self.block_size, 3)
42 | stacks = [t_t.reshape(batch_size, input_height, output_width, output_depth) for t_t in spl]
43 | output = torch.stack(stacks, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).reshape(batch_size,
44 | output_height,
45 | output_width,
46 | output_depth)
47 | output = output.permute(0, 3, 1, 2)
48 | return [output] + list(inputs)
49 |
50 | def extra_repr(self):
51 | return 'SpaceToDepth({0:d}x{0:d})'.format(self.block_size)
52 |
53 |
54 | class CondSpaceToDepth(SpaceToDepth):
55 | def forward(self, x, y, log_det, z):
56 | return super().forward(x, log_det, z)
57 |
58 | def g(self, x, y, z):
59 | return super().g(x, z)
60 |
61 |
62 | class FactorOut(nn.Module):
63 |     def __init__(self, in_shape=None):
64 |         super().__init__()
65 |         self.inp_shape, self.out_shape = in_shape, None
66 |
67 | def forward(self, x, log_det_jac, z):
68 | self.out_shape = list(x.shape)[1:]
69 | self.inp_shape = list(x.shape)[1:]
70 | self.out_shape[0] = self.out_shape[0] // 2
71 | self.out_shape[0] += self.out_shape[0] % 2
72 |
73 | k = self.out_shape[0]
74 | if z is None:
75 | return x[:, k:], log_det_jac, x[:, :k].reshape((x.shape[0], -1))
76 | z = torch.cat([z, x[:, :k].view((x.shape[0], -1))], dim=1)
77 | return x[:, k:], log_det_jac, z
78 |
79 | def g(self, x, z):
80 | k = np.prod(self.out_shape)
81 | x = torch.cat([z[:, -k:].view([x.shape[0]] + self.out_shape), x], dim=1)
82 | z = z[:, :-k]
83 | return x, z
84 |
85 | def extra_repr(self):
86 | return 'FactorOut({:s} -> {:s})'.format(str(self.inp_shape), str(self.out_shape))
87 |
88 |
89 | class CondFactorOut(nn.Module):
90 |     def __init__(self, in_shape=None):
91 |         super().__init__()
92 |         self.inp_shape, self.out_shape = in_shape, None
93 |
94 | def extra_repr(self):
95 | return 'FactorOut({:s} -> {:s})'.format(str(self.inp_shape), str(self.out_shape))
96 |
97 | def forward(self, x, y, log_det_jac, z):
98 | self.out_shape = list(x.shape)[1:]
99 | self.inp_shape = list(x.shape)[1:]
100 | self.out_shape[0] = self.out_shape[0] // 2
101 | self.out_shape[0] += self.out_shape[0] % 2
102 |
103 | k = self.out_shape[0]
104 | if z is None:
105 | return x[:, k:], log_det_jac, x[:, :k].reshape((x.shape[0], -1))
106 | z = torch.cat([z, x[:, :k].view((x.shape[0], -1))], dim=1)
107 | return x[:, k:], log_det_jac, z
108 |
109 | def g(self, x, y, z):
110 | k = np.prod(self.out_shape)
111 | x = torch.cat([z[:, -k:].view([x.shape[0]] + self.out_shape), x], dim=1)
112 | z = z[:, :-k]
113 | return x, z
114 |
115 |
116 | class ToLogits(nn.Module):
117 | '''
118 |     Maps the interval [0, 1] to (-inf, +inf) via the logit (inverse sigmoid)
119 | '''
120 | alpha = 0.05
121 |
122 | def forward(self, x, log_det_jac, z):
123 | # [0, 1] -> [alpha/2, 1 - alpha/2]
124 | x = x * (1 - self.alpha) + self.alpha * 0.5
125 | log_det_jac += torch.sum(-torch.log(x) - torch.log(1-x) + np.log(1 - self.alpha), dim=[1, 2, 3])
126 | x = torch.log(x) - torch.log(1-x)
127 | return x, log_det_jac, z
128 |
129 | def g(self, x, z):
130 | x = torch.sigmoid(x)
131 | x = (x - self.alpha * 0.5) / (1. - self.alpha)
132 | return x, z
133 |
134 | def extra_repr(self):
135 | return 'ToLogits()'
136 |
137 |
138 | class InverseLogits(nn.Module):
139 | def forward(self, x, log_det_jac, z):
140 | log_det_jac += torch.sum(-F.softplus(-x) - F.softplus(x), dim=[1, 2, 3])
141 | x = torch.sigmoid(x)
142 | return x, log_det_jac, z
143 |
144 | def g(self, x, z):
145 | x = torch.log(x) - torch.log(1 - x)
146 | return x, z
147 |
148 | def extra_repr(self):
149 | return 'InverseLogits()'
150 |
151 |
152 | class CondToLogits(nn.Module):
153 | '''
154 |     Maps the interval [0, 1] to (-inf, +inf) via the logit (inverse sigmoid)
155 | '''
156 | alpha = 0.05
157 |
158 | def forward(self, x, y, log_det_jac, z):
159 | # [0, 1] -> [alpha/2, 1 - alpha/2]
160 | x = x * (1 - self.alpha) + self.alpha * 0.5
161 | log_det_jac += torch.sum(-torch.log(x) - torch.log(1-x) + np.log(1 - self.alpha), dim=[1, 2, 3])
162 | x = torch.log(x) - torch.log(1-x)
163 | return x, log_det_jac, z
164 |
165 | def g(self, x, y, z):
166 | x = torch.sigmoid(x)
167 | x = (x - self.alpha * 0.5) / (1. - self.alpha)
168 | return x, z
169 |
170 |
171 | class DummyCond(nn.Module):
172 | def __init__(self, module):
173 | super().__init__()
174 | self.module = module
175 |
176 | def forward(self, x, y, log_det_jac, z):
177 | return self.module.forward(x, log_det_jac, z)
178 |
179 | def g(self, x, y, z):
180 | return self.module.g(x, z)
181 |
182 |
183 | class IdFunction(nn.Module):
184 | def forward(self, *inputs):
185 | return inputs
186 |
187 | def g(self, *inputs):
188 | return inputs
189 |
190 |
191 | class UniformWithLogits(dist.Distribution):
192 | def __init__(self, dim):
193 | super().__init__()
194 | self.dim = dim
195 |
196 | def log_prob(self, x):
197 | return torch.sum(-F.softplus(-x) - F.softplus(x), dim=1)
198 |
199 | def sample(self, shape):
200 | x = torch.rand(list(shape) + [self.dim])
201 | return torch.log(x) - torch.log(1 - x)
202 |
--------------------------------------------------------------------------------
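
The modules above all share the `(x, log_det_jac, z)` calling convention. A hedged sketch (assuming the repo root is on `sys.path`) of how `SpaceToDepth` folds spatial blocks into channels and `FactorOut` moves half the channels into the flat latent `z`, each round-tripping through its `g`:

```python
import torch
from models.utils import SpaceToDepth, FactorOut

s2d = SpaceToDepth(2)
x = torch.randn(4, 3, 32, 32)
y = s2d(x)[0]                       # forward returns [output, *extra_inputs]
assert y.shape == (4, 12, 16, 16)   # 2x2 blocks folded into channels
assert torch.allclose(s2d.g(y)[0], x)

fo = FactorOut()
h, log_det, z = fo(y, torch.zeros(4), None)
assert h.shape == (4, 6, 16, 16) and z.shape == (4, 6 * 16 * 16)
h_full, _ = fo.g(h, z)              # re-attach the factored-out half
assert torch.allclose(h_full, y)
```
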
/myexman/__init__.py:
--------------------------------------------------------------------------------
1 | from .parser import (
2 | ExParser,
3 | simpleroot
4 | )
5 | from .index import (
6 | Index
7 | )
8 | from . import index
9 | from . import parser
10 | __version__ = '0.0.2'
11 |
--------------------------------------------------------------------------------
/myexman/index.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | import pandas as pd
3 | import pathlib
4 | import strconv
5 | import json
6 | import functools
7 | import datetime
8 | from . import parser
9 | import yaml
10 | from argparse import Namespace
11 | __all__ = [
12 | 'Index'
13 | ]
14 |
15 |
16 | def only_value_error(conv):
17 | @functools.wraps(conv)
18 | def new_conv(value):
19 | try:
20 | return conv(value)
21 | except Exception as e:
22 | raise ValueError from e
23 | return new_conv
24 |
25 |
26 | def none2none(none):
27 | if none is None:
28 | return None
29 | else:
30 | raise ValueError
31 |
32 |
33 | converter = strconv.Strconv(converters=[
34 | ('int', strconv.convert_int),
35 | ('float', strconv.convert_float),
36 | ('bool', only_value_error(parser.str2bool)),
37 | ('time', strconv.convert_time),
38 | ('datetime', strconv.convert_datetime),
39 | ('datetime1', lambda time: datetime.datetime.strptime(time, parser.TIME_FORMAT)),
40 | ('date', strconv.convert_date),
41 | ('json', only_value_error(json.loads)),
42 | ])
43 |
44 |
45 | def get_args(path):
46 | with open(path, 'rb') as f:
47 |         return Namespace(**yaml.safe_load(f))
48 |
49 |
50 | class Index(object):
51 | def __init__(self, root):
52 | self.root = pathlib.Path(root)
53 |
54 | @property
55 | def index(self):
56 | return self.root / 'index'
57 |
58 | @property
59 | def marked(self):
60 | return self.root / 'marked'
61 |
62 | def info(self, source=None):
63 | if source is None:
64 | source = self.index
65 | files = source.iterdir()
66 | else:
67 | source = self.marked / source
68 | files = source.glob('**/*/'+parser.PARAMS_FILE)
69 |
70 | def get_dict(cfg):
71 | return configargparse.YAMLConfigFileParser().parse(cfg.open('r'))
72 |
73 |         def convert_column(col):
74 |             converted = list(converter.convert_series(col))
75 |             if any(isinstance(v, str) for v in converted):
76 |                 return col
77 |             return pd.Series(converted, name=col.name, index=col.index)
78 | try:
79 | df = (pd.DataFrame
80 | .from_records((get_dict(c) for c in files))
81 | .apply(lambda s: convert_column(s))
82 | .sort_values('id')
83 | .assign(root=lambda _: _.root.apply(self.root.__truediv__))
84 | .reset_index(drop=True))
85 | cols = df.columns.tolist()
86 | cols.insert(0, cols.pop(cols.index('id')))
87 | return df.reindex(columns=cols)
88 | except FileNotFoundError as e:
89 | raise KeyError(source.name) from e
90 |
--------------------------------------------------------------------------------
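
A hedged usage sketch for `Index`: it gathers the `params.yaml` files symlinked under `index/` into a single pandas DataFrame, one row per run. The root below assumes the default layout `ExParser` creates for `train-flow-ssl.py`, and that at least one run exists:

```python
from myexman import Index

ind = Index('logs/exman-train-flow-ssl.py')
df = ind.info()  # columns are the parsed run parameters, 'id' first
print(df[['id', 'lr', 'epochs']].head())
```
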
/myexman/parser.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | import argparse
3 | import pathlib
4 | import datetime
5 | import yaml
6 | import yaml.representer
7 | import os
8 | import functools
9 | import itertools
10 | from filelock import FileLock
11 | __all__ = [
12 | 'ExParser',
13 | 'simpleroot',
14 | ]
15 |
16 |
17 | TIME_FORMAT_DIR = '%Y-%m-%d-%H-%M-%S'
18 | TIME_FORMAT = '%Y-%m-%dT%H:%M:%S'
19 | DIR_FORMAT = '{num}-{time}'
20 | EXT = 'yaml'
21 | PARAMS_FILE = 'params.'+EXT
22 | FOLDER_DEFAULT = 'exman'
23 | RESERVED_DIRECTORIES = {
24 | 'runs', 'index',
25 | 'tmp', 'marked'
26 | }
27 |
28 |
29 | def yaml_file(name):
30 | return name + '.' + EXT
31 |
32 |
33 | def simpleroot(__file__):
34 | return pathlib.Path(os.path.dirname(os.path.abspath(__file__)))/FOLDER_DEFAULT
35 |
36 |
37 | def represent_as_str(self, data, tostr=str):
38 | return yaml.representer.Representer.represent_str(self, tostr(data))
39 |
40 |
41 | def register_str_converter(*types, tostr=str):
42 | for T in types:
43 | yaml.add_representer(T, functools.partial(represent_as_str, tostr=tostr))
44 |
45 |
46 | register_str_converter(pathlib.PosixPath, pathlib.WindowsPath)
47 |
48 |
49 | def str2bool(s):
50 | true = ('true', 't', 'yes', 'y', 'on', '1')
51 | false = ('false', 'f', 'no', 'n', 'off', '0')
52 |
53 | if s.lower() in true:
54 | return True
55 | elif s.lower() in false:
56 | return False
57 | else:
58 |         raise argparse.ArgumentTypeError('{!r} is not a valid bool; expected one of {}'.format(s, true + false))
59 |
60 |
61 | class ParserWithRoot(configargparse.ArgumentParser):
62 | def __init__(self, *args, root=None, zfill=6,
63 | **kwargs):
64 | super().__init__(*args, **kwargs)
65 | if root is None:
66 | raise ValueError('Root directory is not specified')
67 | root = pathlib.Path(root)
68 | if not root.is_absolute():
69 | raise ValueError(root, 'Root directory is not absolute path')
70 | if not root.exists():
71 | raise ValueError(root, 'Root directory does not exist')
72 | self.root = pathlib.Path(root)
73 | self.zfill = zfill
74 | self.register('type', bool, str2bool)
75 | for directory in RESERVED_DIRECTORIES:
76 | getattr(self, directory).mkdir(exist_ok=True)
77 | self.lock = FileLock(str(self.root/'lock'))
78 |
79 | @property
80 | def runs(self):
81 | return self.root / 'runs'
82 |
83 | @property
84 | def marked(self):
85 | return self.root / 'marked'
86 |
87 | @property
88 | def index(self):
89 | return self.root / 'index'
90 |
91 | @property
92 | def tmp(self):
93 | return self.root / 'tmp'
94 |
95 | def max_ex(self):
96 | max_num = 0
97 | for directory in itertools.chain(self.runs.iterdir(), self.tmp.iterdir()):
98 | num = int(directory.name.split('-', 1)[0])
99 | if num > max_num:
100 | max_num = num
101 | return max_num
102 |
103 | def num_ex(self):
104 | return len(list(self.runs.iterdir()))
105 |
106 | def next_ex(self):
107 | return self.max_ex() + 1
108 |
109 | def next_ex_str(self):
110 | return str(self.next_ex()).zfill(self.zfill)
111 |
112 |
113 | class ExParser(ParserWithRoot):
114 | """
115 | Parser responsible for creating the following structure of experiments
116 | ```
117 | root
118 | |-- runs
119 | | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS
120 | | |-- params.yaml
121 | | `-- ...
122 | |-- index
123 | | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS.yaml (symlink)
124 | |-- marked
125 | | `--
126 | | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS (symlink)
127 | | |-- params.yaml
128 | | `-- ...
129 | `-- tmp
130 | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS
131 | |-- params.yaml
132 | `-- ...
133 | ```
134 | """
135 | def __init__(self, *args, zfill=6, file=None,
136 | args_for_setting_config_path=('--config', ),
137 | automark=(),
138 | **kwargs):
139 | root = os.path.join(os.getcwd(), 'logs', ('exman-' + str(file)))
140 | if not os.path.exists(root):
141 | os.makedirs(root)
142 | super().__init__(*args, root=root, zfill=zfill,
143 | args_for_setting_config_path=args_for_setting_config_path,
144 | config_file_parser_class=configargparse.YAMLConfigFileParser,
145 | ignore_unknown_config_file_keys=True,
146 | **kwargs)
147 | self.automark = automark
148 | self.add_argument('--tmp', action='store_true')
149 |
150 | def _initialize_dir(self, tmp):
151 | try:
152 |             # with self.lock:  # several processes may create a run simultaneously; the lock would avoid the collision
153 | time = datetime.datetime.now()
154 | num = self.next_ex_str()
155 | name = DIR_FORMAT.format(num=num, time=time.strftime(TIME_FORMAT_DIR))
156 | if tmp:
157 | absroot = self.tmp / name
158 | relroot = pathlib.Path('tmp') / name
159 | else:
160 | absroot = self.runs / name
161 | relroot = pathlib.Path('runs') / name
162 | # this process now safely owns root directory
163 | # raises FileExistsError on fail
164 | absroot.mkdir()
165 |         except FileExistsError:  # another process claimed the same number; retry
166 | return self._initialize_dir(tmp)
167 | return absroot, relroot, name, time, num
168 |
169 | def parse_known_args(self, *args, **kwargs):
170 | args, argv = super().parse_known_args(*args, **kwargs)
171 | absroot, relroot, name, time, num = self._initialize_dir(args.tmp)
172 | args.root = absroot
173 | self.yaml_params_path = args.root / PARAMS_FILE
174 | rel_yaml_params_path = pathlib.Path('..', 'runs', name, PARAMS_FILE)
175 | with self.yaml_params_path.open('a') as f:
176 | self.dumpd = args.__dict__.copy()
177 | # dumpd['root'] = relroot
178 | yaml.dump(self.dumpd, f, default_flow_style=False)
179 | print("\ntime: '{}'".format(time.strftime(TIME_FORMAT)), file=f)
180 | print("id:", int(num), file=f)
181 | print(self.yaml_params_path.read_text())
182 | symlink = self.index / yaml_file(name)
183 | if not args.tmp:
184 | symlink.symlink_to(rel_yaml_params_path)
185 | print('Created symlink from', symlink, '->', rel_yaml_params_path)
186 | if self.automark and not args.tmp:
187 | automark_path_part = pathlib.Path(*itertools.chain.from_iterable(
188 | (mark, str(getattr(args, mark, '')))
189 | for mark in self.automark))
190 | markpath = pathlib.Path(self.marked, automark_path_part)
191 | markpath.mkdir(exist_ok=True, parents=True)
192 | relpathmark = pathlib.Path('..', *(['..']*len(automark_path_part.parts))) / 'runs' / name
193 | (markpath / name).symlink_to(relpathmark, target_is_directory=True)
194 | print('Created symlink from', markpath / name, '->', relpathmark)
195 | return args, argv
196 |
--------------------------------------------------------------------------------
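
Tying the parser together, a hedged sketch of the lifecycle described in the `ExParser` docstring: every `parse_args()` call claims the next run number, creates `runs/xxxxxx-<timestamp>/`, dumps `params.yaml` into it, and symlinks it from `index/` (and from `marked/` when `automark` is set):

```python
import os
import myexman

parser = myexman.ExParser(file=os.path.basename(__file__))
parser.add_argument('--lr', default=1e-3, type=float)
args = parser.parse_args()

# args.root points at the freshly created run directory, e.g.
# <cwd>/logs/exman-<script>/runs/000001-<timestamp>
print(args.root)
```
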
/pretrained/model.torch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndrewAtanov/semi-supervised-flow-pytorch/d1748decaccbc59e6bce014e1cb84527173c6b54/pretrained/model.torch
--------------------------------------------------------------------------------
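
The checkpoint is a plain `state_dict`; `train-flow-ssl.py` restores it through its `--pretrained` flag with `model.load_state_dict(torch.load(args.pretrained))`, so loading it manually looks like the sketch below (the model must be built with the same architecture flags the checkpoint was trained with):

```python
import torch

state = torch.load('pretrained/model.torch', map_location='cpu')
# then, on a matching flows.FlowPDF instance:
# model.load_state_dict(state)
```
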
/train-discriminator.py:
--------------------------------------------------------------------------------
1 | import myexman
2 | import torch
3 | from logger import Logger
4 | import torchvision
5 | import os
6 | from torch import nn
7 | import torch.nn.functional as F
8 | import utils
9 | import warnings
10 | import numpy as np
11 |
12 |
13 | parser = myexman.ExParser(file=os.path.basename(__file__))
14 | parser.add_argument('--name', default='')
15 | # Data
16 | parser.add_argument('--data', default='')
17 | parser.add_argument('--emb')
18 | parser.add_argument('--dim', default=196, type=int)
19 | # Optimization
20 | parser.add_argument('--epochs', default=100, type=int)
21 | parser.add_argument('--lr', default=1e-3, type=float)
22 | args = parser.parse_args()
23 |
24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25 |
26 | logger = Logger('logs', base=args.root)
27 |
28 | if args.emb == "f":
29 | emb_train = np.load(os.path.join(args.data, 'zf_train.npy'))[:, -args.dim:]
30 | emb_test = np.load(os.path.join(args.data, 'zf_test.npy'))[:, -args.dim:]
31 | elif args.emb == 'h':
32 | emb_train = np.load(os.path.join(args.data, 'zh_train.npy'))
33 | emb_test = np.load(os.path.join(args.data, 'zh_test.npy'))
34 | else:
35 | raise NotImplementedError
36 |
37 | y_train = np.load(os.path.join(args.data, 'y_train.npy'))
38 | y_test = np.load(os.path.join(args.data, 'y_test.npy'))
39 |
40 |
41 | trainset = torch.utils.data.TensorDataset(torch.FloatTensor(emb_train - emb_train.mean(0)[None]), torch.LongTensor(y_train))
42 | testset = torch.utils.data.TensorDataset(torch.FloatTensor(emb_test - emb_test.mean(0)[None]), torch.LongTensor(y_test))
43 |
44 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)
45 | testloader = torch.utils.data.DataLoader(testset, batch_size=256)
46 |
47 |
48 | net = nn.Sequential(
49 | nn.Linear(args.dim, 256),
50 | nn.LeakyReLU(),
51 |
52 | nn.Linear(256, 256),
53 | nn.Dropout(0.3),
54 | nn.LeakyReLU(),
55 |
56 | nn.Linear(256, 256),
57 | nn.LeakyReLU(),
58 |
59 | nn.Linear(256, 256),
60 | nn.LeakyReLU(),
61 |
62 | nn.Linear(256, len(np.unique(y_train))),
63 | ).to(device)
64 |
65 | opt = torch.optim.Adam(net.parameters(), lr=args.lr)
66 | lr_schedule = utils.LinearLR(opt, args.epochs)
67 |
68 | for e in range(1, 1 + args.epochs):
69 | net.train()
70 | train_loss = 0.
71 | train_acc = 0.
72 | for x, y in trainloader:
73 | x, y = x.to(device), y.to(device)
74 | p = net(x)
75 | loss = F.cross_entropy(p, y)
76 | opt.zero_grad()
77 | loss.backward()
78 | opt.step()
79 | train_loss += loss.item() * x.size(0)
80 | train_acc += (p.argmax(1) == y).sum().item()
81 |
82 | train_loss /= len(trainloader.dataset)
83 | train_acc /= len(trainloader.dataset)
84 |
85 |     net.eval()
86 |     with torch.no_grad():
87 |         correct = sum((net(x.to(device)).argmax(1) == y.to(device)).sum().item() for x, y in testloader)
88 |     test_acc = correct / len(testloader.dataset)
89 |
90 | logger.add_scalar(e, 'train.loss', train_loss)
91 | logger.add_scalar(e, 'train.acc', train_acc)
92 | logger.add_scalar(e, 'test.acc', test_acc)
93 | logger.iter_info()
94 | logger.save()
95 |
--------------------------------------------------------------------------------
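
For reference, a hedged invocation assembled from the flags this script defines; `--data` should point at a directory holding the `z{f,h}_{train,test}.npy` and `y_{train,test}.npy` arrays (presumably written by the get-embeddings scripts), and `--emb f` keeps the last `--dim` coordinates of the flow embedding:

```
python train-discriminator.py --data path/to/embeddings --emb f --dim 196 --epochs 100 --lr 1e-3
```
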
/train-flow-ssl.py:
--------------------------------------------------------------------------------
1 | import myexman
2 | import torch
3 | import utils
4 | import datautils
5 | import os
6 | from logger import Logger
7 | import time
8 | import numpy as np
9 | from models import flows, distributions
10 | import warnings
11 | import torch.nn.functional as F
12 | import argparse
13 |
14 |
15 | def get_metrics(model, loader):
16 | logp, acc = [], []
17 | for x, y in loader:
18 | x = x.to(device)
19 | log_det, z = model.flow(x)
20 | log_prior_full = model.prior.log_prob_full(z)
21 | pred = torch.softmax(log_prior_full, dim=1).argmax(1)
22 | logp.append(utils.tonp(log_det + model.prior.log_prob(z)))
23 | acc.append(utils.tonp(pred) == utils.tonp(y))
24 | return np.mean(np.concatenate(logp)), np.mean(np.concatenate(acc))
25 |
26 |
27 | parser = myexman.ExParser(file=os.path.basename(__file__))
28 | parser.add_argument('--name', default='')
29 | parser.add_argument('--seed', default=0, type=int)
30 | # Data
31 | parser.add_argument('--data', default='mnist')
32 | parser.add_argument('--num_examples', default=-1, type=int)
33 | parser.add_argument('--data_seed', default=0, type=int)
34 | parser.add_argument('--sup_sample_weight', default=-1, type=float)
35 | # Optimization
36 | parser.add_argument('--lr', default=1e-3, type=float)
37 | parser.add_argument('--epochs', default=500, type=int)
38 | parser.add_argument('--train_bs', default=256, type=int)
39 | parser.add_argument('--test_bs', default=512, type=int)
40 | parser.add_argument('--lr_schedule', default='hat')
41 | parser.add_argument('--lr_warmup', default=10, type=int)
42 | parser.add_argument('--log_each', default=1, type=int)
43 | parser.add_argument('--pretrained', default='')
44 | parser.add_argument('--weight_decay', default=0., type=float)
45 | # Model
46 | parser.add_argument('--model', default='mnist-masked')
47 | parser.add_argument('--conv', default='full')
48 | parser.add_argument('--hh_factors', default=2, type=int)
49 | parser.add_argument('--k', default=4, type=int)
50 | parser.add_argument('--l', default=2, type=int)
51 | parser.add_argument('--hid_dim', type=int, nargs='*', default=[])
52 | # Prior
53 | parser.add_argument('--ssl_model', default='cond-flow')
54 | parser.add_argument('--ssl_dim', default=-1, type=int)
55 | parser.add_argument('--ssl_l', default=2, type=int)
56 | parser.add_argument('--ssl_k', default=3, type=int)
57 | parser.add_argument('--ssl_hd', default=256, type=int)
58 | parser.add_argument('--ssl_conv', default='full')
59 | parser.add_argument('--ssl_hh', default=2, type=int)
60 | parser.add_argument('--ssl_nclasses', default=10, type=int)
61 | # SSL
62 | parser.add_argument('--supervised', default=0, type=int)
63 | parser.add_argument('--sup_weight', default=1., type=float)
64 | parser.add_argument('--cl_weight', default=0, type=float)
65 | args = parser.parse_args()
66 |
67 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
68 |
69 | # TODO: make the thread count configurable
70 | torch.set_num_threads(1)
71 |
72 | fmt = {
73 | 'time': '.3f',
74 | }
75 | logger = Logger('logs', base=args.root, fmt=fmt)
76 |
77 | # Load data
78 | np.random.seed(args.data_seed)
79 | torch.manual_seed(args.data_seed)
80 | torch.cuda.manual_seed_all(args.data_seed)
81 | trainloader, testloader, data_shape, bits = datautils.load_dataset(args.data, args.train_bs, args.test_bs,
82 | seed=args.data_seed, num_examples=args.num_examples,
83 | supervised=args.supervised, logs_root=args.root,
84 | sup_sample_weight=args.sup_sample_weight)
85 | # Seed for training process
86 | np.random.seed(args.seed)
87 | torch.manual_seed(args.seed)
88 | torch.cuda.manual_seed_all(args.seed)
89 |
90 | # Create model
91 | dim = int(np.prod(data_shape))
92 | if args.ssl_dim == -1:
93 | args.ssl_dim = dim
94 | deep_prior = distributions.GaussianDiag(args.ssl_dim)
95 | shallow_prior = distributions.GaussianDiag(dim - args.ssl_dim)
96 |
97 | _, c = np.unique(trainloader.dataset.targets[trainloader.dataset.targets != -1], return_counts=True)
98 | yprior = torch.distributions.Categorical(probs=torch.FloatTensor(c/c.sum()).to(device))
99 | ssl_flow = utils.create_cond_flow(args)
100 | # ssl_flow = torch.nn.DataParallel(ssl_flow.to(device))
101 | ssl_flow.to(device)
102 | prior = flows.DiscreteConditionalFlowPDF(ssl_flow, deep_prior, yprior, deep_dim=args.ssl_dim,
103 | shallow_prior=shallow_prior)
104 |
105 | flow = utils.create_flow(args, data_shape)
106 | flow.to(device)
107 | flow = torch.nn.DataParallel(flow)
108 |
109 | model = flows.FlowPDF(flow, prior).to(device)
110 |
111 | torch.save(model.state_dict(), os.path.join(args.root, 'model_init.torch'))
112 |
113 | parameters = [
114 | {'params': [p for p in model.parameters() if p.requires_grad], 'weight_decay': args.weight_decay},
115 | ]
116 | optimizer = torch.optim.Adamax(parameters, lr=args.lr)
117 | if args.lr_schedule == 'no':
118 | lr_scheduler = utils.BaseLR(optimizer)
119 | elif args.lr_schedule == 'linear':
120 | lr_scheduler = utils.LinearLR(optimizer, args.epochs)
121 | elif args.lr_schedule == 'hat':
122 | lr_scheduler = utils.HatLR(optimizer, args.lr_warmup, args.epochs)
123 | else:
124 | raise NotImplementedError
125 |
126 | if args.pretrained != '':
127 | model.load_state_dict(torch.load(args.pretrained))
128 | # model.load_state_dict(torch.load(os.path.join(args.pretrained, 'model.torch')))
129 | # optimizer.load_state_dict(torch.load(os.path.join(args.pretrained, 'optimizer.torch')))
130 |
131 | t0 = time.time()
132 | for epoch in range(1, args.epochs + 1):
133 | train_loss = 0.
134 | train_acc = utils.MovingMetric()
135 | train_elbo = utils.MovingMetric()
136 | train_cl = utils.MovingMetric()
137 |
138 | for x, y in trainloader:
139 | x = x.to(device)
140 | n_sup = (y != -1).sum().item()
141 |
142 | log_det, z = model.flow(x)
143 |
144 | log_prior = torch.ones((x.size(0),)).to(x.device)
145 | if n_sup != z.shape[0]:
146 | log_prior[y == -1] = model.prior.log_prob(z[y == -1])
147 | if n_sup != 0:
148 | log_prior[y != -1] = model.prior.log_prob(z[y != -1], y=y[y != -1].to(x.device))
149 | elbo = log_det + log_prior
150 |
151 | weights = torch.ones((elbo.size(0),)).to(elbo)
152 | weights[y != -1] = args.sup_weight
153 | weights /= weights.sum()
154 |
155 | gen_loss = -(elbo * weights.detach()).sum()
156 |
157 | cl_loss = 0
158 | if n_sup != 0:
159 | logp_full = model.prior.log_prob_full(z[y != -1])
160 | prediction = logp_full
161 | train_acc.add(utils.tonp(prediction.argmax(1).to(y) == y[y != -1]))
162 | if args.cl_weight != 0:
163 | cl_loss = F.cross_entropy(prediction, y[y != -1].to(prediction.device), reduction='none')
164 | train_cl.add(utils.tonp(cl_loss))
165 | cl_loss = cl_loss.mean()
166 |
167 | loss = gen_loss + args.cl_weight * cl_loss
168 |
169 | optimizer.zero_grad()
170 | loss.backward()
171 | optimizer.step()
172 |
173 | train_elbo.add(utils.tonp(elbo))
174 | train_loss += loss.item() * x.size(0)
175 |
176 | train_loss /= len(trainloader.dataset)
177 | lr_scheduler.step()
178 |
179 | if epoch % args.log_each == 0 or epoch == 1:
180 | with torch.no_grad():
181 | test_logp, test_acc = get_metrics(model, testloader)
182 | logger.add_scalar(epoch, 'train.loss', train_loss)
183 | logger.add_scalar(epoch, 'train.elbo', train_elbo.avg())
184 | logger.add_scalar(epoch, 'train.cl', train_cl.avg())
185 | logger.add_scalar(epoch, 'train.acc', train_acc.avg())
186 | logger.add_scalar(epoch, 'test.logp', test_logp)
187 | logger.add_scalar(epoch, 'test.acc', test_acc)
188 | logger.add_scalar(epoch, 'test.bits/dim', utils.bits_dim(test_logp, dim, bits))
189 | logger.add_scalar(epoch, 'time', time.time() - t0)
190 | t0 = time.time()
191 | logger.iter_info()
192 | logger.save()
193 |
194 | torch.save(model.state_dict(), os.path.join(args.root, 'model.torch'))
195 | torch.save(optimizer.state_dict(), os.path.join(args.root, 'optimizer.torch'))
196 |
--------------------------------------------------------------------------------
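
A hedged example run built only from flags defined above; the semantics of `--num_examples` (how many labeled examples to keep) and `--supervised` are assumptions based on how `datautils.load_dataset` is called:

```
python train-flow-ssl.py --data mnist --model mnist-masked --num_examples 100 --supervised 1 --epochs 500
```
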
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn import datasets
4 | import os
5 | import torchvision
6 | from torchvision import transforms
7 | from sklearn.cluster import MiniBatchKMeans
8 | import matplotlib.pyplot as plt
9 | import warnings
10 | from models import flows, coupling
11 |
12 |
13 | def viz_array_grid(array, rows, cols, padding=0, channels_last=False, normalize=False, **kwargs):
14 | # normalization
15 | '''
16 | Args:
17 | array: (N_images, N_channels, H, W) or (N_images, H, W, N_channels)
18 | rows, cols: rows and columns of the plot. rows * cols == array.shape[0]
19 | padding: padding between cells of plot
20 | channels_last: for Tensorflow = True, for PyTorch = False
21 | normalize: `False`, `mean_std`, or `min_max`
22 | Kwargs:
23 | if normalize == 'mean_std':
24 | mean: mean of the distribution. Default 0.5
25 | std: std of the distribution. Default 0.5
26 | if normalize == 'min_max':
27 | min: min of the distribution. Default array.min()
28 | max: max if the distribution. Default array.max()
29 | '''
30 | if not channels_last:
31 | array = np.transpose(array, (0, 2, 3, 1))
32 |
33 | array = array.astype('float32')
34 |
35 | if normalize:
36 | if normalize == 'mean_std':
37 | mean = kwargs.get('mean', 0.5)
38 | mean = np.array(mean).reshape((1, 1, 1, -1))
39 | std = kwargs.get('std', 0.5)
40 | std = np.array(std).reshape((1, 1, 1, -1))
41 | array = array * std + mean
42 | elif normalize == 'min_max':
43 | min_ = kwargs.get('min', array.min())
44 | min_ = np.array(min_).reshape((1, 1, 1, -1))
45 | max_ = kwargs.get('max', array.max())
46 | max_ = np.array(max_).reshape((1, 1, 1, -1))
47 |             array -= min_
48 |             array /= max_ - min_ + 1e-9
49 |
50 | batch_size, H, W, channels = array.shape
51 | assert rows * cols == batch_size
52 |
53 | if channels == 1:
54 | canvas = np.ones((H * rows + padding * (rows - 1),
55 | W * cols + padding * (cols - 1)))
56 | array = array[:, :, :, 0]
57 | elif channels == 3:
58 | canvas = np.ones((H * rows + padding * (rows - 1),
59 | W * cols + padding * (cols - 1),
60 | 3))
61 | else:
62 |         raise TypeError('number of channels must be either 1 or 3')
63 |
64 | for i in range(rows):
65 | for j in range(cols):
66 | img = array[i * cols + j]
67 | start_h = i * padding + i * H
68 | start_w = j * padding + j * W
69 | canvas[start_h: start_h + H, start_w: start_w + W] = img
70 |
71 | canvas = np.clip(canvas, 0, 1)
72 | canvas *= 255.0
73 | canvas = canvas.astype('uint8')
74 | return canvas
75 |
76 |
77 | def params_norm(parameters):
78 | sq = 0.
79 | n = 0
80 | for p in parameters:
81 | sq += (p**2).sum()
82 | n += torch.numel(p)
83 | return np.sqrt(sq.item() / float(n))
84 |
85 |
86 | def tonp(x):
87 | if isinstance(x, np.ndarray):
88 | return x
89 | return x.detach().cpu().numpy()
90 |
91 |
92 | def batch_eval(f, loader):
93 | res = []
94 | for x in loader:
95 | res.append(f(x))
96 | return res
97 |
98 |
99 | def bits_dim(ll, dim, bits=256):
100 | return np.log2(bits) - ll / dim / np.log(2)
101 |
102 |
103 | class LinearLR(torch.optim.lr_scheduler._LRScheduler):
104 | def __init__(self, optimizer, num_epochs, last_epoch=-1):
105 | self.num_epochs = max(num_epochs, 1)
106 | super(LinearLR, self).__init__(optimizer, last_epoch)
107 |
108 | def get_lr(self):
109 | res = []
110 | for lr in self.base_lrs:
111 | res.append(np.maximum(lr * np.minimum(-(self.last_epoch + 1) * 1. / self.num_epochs + 1., 1.), 0.))
112 | return res
113 |
114 |
115 | class HatLR(torch.optim.lr_scheduler._LRScheduler):
116 | def __init__(self, optimizer, warm_up, num_epochs, last_epoch=-1):
117 | if warm_up == 0:
118 | warnings.warn('====> HatLR with warm_up=0 !!! <====')
119 |
120 | self.num_epochs = max(num_epochs, 1)
121 | self.warm_up = warm_up
122 | self.warm_schedule = LinearLR(optimizer, warm_up + 1)
123 | self.warm_schedule.step()
124 | self.anneal_schedule = LinearLR(optimizer, num_epochs - warm_up)
125 | super().__init__(optimizer, last_epoch)
126 |
127 | def get_lr(self):
128 | if self.last_epoch + 1 < self.warm_up:
129 | return [lr - x for lr, x in zip(self.base_lrs, self.warm_schedule.get_lr())]
130 | return self.anneal_schedule.get_lr()
131 |
132 | def step(self, epoch=None):
133 | super().step(epoch=epoch)
134 | if self.last_epoch + 1 < self.warm_up:
135 | self.warm_schedule.step()
136 | else:
137 | self.anneal_schedule.step()
138 |
139 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
140 | param_group['lr'] = lr
141 |
142 |
143 | class BaseLR(torch.optim.lr_scheduler._LRScheduler):
144 | def get_lr(self):
145 | return [group['lr'] for group in self.optimizer.param_groups]
146 |
147 |
148 | def init_kmeans(k, dataloader, model=None, epochs=1, device=None):
149 | kmeans = MiniBatchKMeans(k, batch_size=dataloader.batch_size)
150 | for _ in range(epochs):
151 | for x, _ in dataloader:
152 | if model:
153 | x = model(x.to(device))
154 | x = tonp(x)
155 | kmeans.partial_fit(x)
156 |
157 | mu = kmeans.cluster_centers_
158 | dim = mu.shape[1]
159 | cov = np.zeros((k, dim, dim))
160 | n = np.zeros((k,))
161 | for x, _ in dataloader:
162 | if model:
163 | x = model(x.to(device))
164 | x = tonp(x)
165 | labels = kmeans.predict(x)
166 |         for j in range(k):  # do not shadow k, the number of clusters
167 |             c = labels == j
168 |             n[j] += np.sum(c)
169 |             d = x[c] - mu[None, j]
170 |             cov[j] += np.matmul(d[..., None], d[:, None]).sum(0)
171 | cov /= n[:, None, None]
172 | pi = n / n.sum()
173 | return mu, cov, pi
174 |
175 |
176 | def create_flow(args, data_shape):
177 | if args.model == 'toy':
178 | flow = flows.toy2d_flow(args.conv, args.hh_factors, args.l)
179 | elif args.model == 'id':
180 | flow = flows.Flow([])
181 | elif args.model == 'mnist':
182 | flow = flows.mnist_flow(num_layers=args.l, k_factor=args.k, logits=args.logits,
183 | conv=args.conv, hh_factors=args.hh_factors, hid_dim=args.hid_dim)
184 | elif args.model == 'mnist-masked':
185 | flow = flows.mnist_masked_glow(conv=args.conv, hh_factors=args.hh_factors)
186 | elif args.model == 'ffjord':
187 | # TODO: add FFJORD model
188 | raise NotImplementedError
189 | else:
190 | raise NotImplementedError
191 |
192 | return flow
193 |
194 |
195 | def create_cond_flow(args):
196 | if args.ssl_model == 'cond-flow':
197 | flow = flows.get_flow_cond(args.ssl_l, args.ssl_k, in_channels=args.ssl_dim, hid_dim=args.ssl_hd,
198 | conv=args.ssl_conv, hh_factors=args.ssl_hh, num_cat=args.ssl_nclasses)
199 |     elif args.ssl_model == 'cond-shift':
200 |         flow = flows.ConditionalFlow([coupling.ConditionalShift(args.ssl_dim, args.ssl_nclasses)])
201 |     else:
202 |         raise NotImplementedError(args.ssl_model)
203 |     return flow
204 |
205 |
206 | class MovingMetric(object):
207 | def __init__(self):
208 | self.n = 0
209 | self.sum = 0.
210 |
211 | def add(self, x):
212 | assert np.ndim(x) == 1
213 | self.n += len(x)
214 | self.sum += np.sum(x)
215 |
216 | def avg(self):
217 | return self.sum / self.n if self.n != 0 else np.nan
218 |
--------------------------------------------------------------------------------
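
To make `bits_dim` concrete, a quick arithmetic check: for a 28x28 single-channel image (`dim = 784`, 256 intensity levels), a log-likelihood of -1000 nats gives `log2(256) + 1000 / (784 * ln 2) ≈ 9.84` bits per dimension:

```python
import numpy as np
from utils import bits_dim  # the repo-root utils.py above

print(bits_dim(-1000.0, 28 * 28, bits=256))  # ~9.84
```
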