├── .gitignore ├── LICENSE ├── README.md ├── bm_sequential.py ├── ctfp_tools.py ├── data ├── gbm_2.pkl └── ou_2.pkl ├── eval_ctfp.py ├── eval_latent_ctfp.py ├── experiments ├── ctfp_gbm │ └── pretrained.pth ├── ctfp_mix │ └── pretrained.pth ├── ctfp_ou │ └── pretrained.pth ├── latent_ctfp_gbm │ └── pretrained.pth ├── latent_ctfp_mix │ └── pretrained.pth └── latent_ctfp_ou │ └── pretrained.pth ├── figure.png ├── lib ├── diffeq_solver.py ├── encoder_decoder.py ├── layers │ ├── __init__.py │ ├── cnf.py │ ├── container.py │ ├── coupling.py │ ├── diffeq_layers │ │ ├── __init__.py │ │ ├── basic.py │ │ ├── container.py │ │ ├── resnet.py │ │ └── wrappers.py │ ├── elemwise.py │ ├── glow.py │ ├── norm_flows.py │ ├── normalization.py │ ├── odefunc.py │ ├── odefunc_aug.py │ ├── resnet.py │ ├── squeeze.py │ └── wrappers │ │ └── cnf_regularization.py ├── ode_func.py ├── spectral_norm.py └── utils.py ├── ode_rnn_encoder.py ├── page1.png ├── requirements.txt ├── train_ctfp.py ├── train_latent_ctfp.py └── train_misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | .idea/ 131 | 132 | .DS_Store 133 | *.pyc 134 | data/ou_20.pkl 135 | data/gbm_20.pkl 136 | data/ou_mix.pkl 137 | experiments/*/logs 138 | experiments/*/tb_logs 139 | experiments/*/checkpt*.pth 140 | scratch.txt 141 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License 2 | 3 | Attribution-NonCommercial-ShareAlike 4.0 International 4 | 5 | ======================================================================= 6 | 7 | Creative Commons Corporation ("Creative Commons") is not a law firm and 8 | does not provide legal services or legal advice. Distribution of 9 | Creative Commons public licenses does not create a lawyer-client or 10 | other relationship. Creative Commons makes its licenses and related 11 | information available on an "as-is" basis. Creative Commons gives no 12 | warranties regarding its licenses, any material licensed under their 13 | terms and conditions, or any related information. Creative Commons 14 | disclaims all liability for damages resulting from their use to the 15 | fullest extent possible. 16 | 17 | Using Creative Commons Public Licenses 18 | 19 | Creative Commons public licenses provide a standard set of terms and 20 | conditions that creators and other rights holders may use to share 21 | original works of authorship and other material subject to copyright 22 | and certain other rights specified in the public license below. The 23 | following considerations are for informational purposes only, are not 24 | exhaustive, and do not form part of our licenses. 25 | 26 | Considerations for licensors: Our public licenses are 27 | intended for use by those authorized to give the public 28 | permission to use material in ways otherwise restricted by 29 | copyright and certain other rights. Our licenses are 30 | irrevocable. Licensors should read and understand the terms 31 | and conditions of the license they choose before applying it. 32 | Licensors should also secure all rights necessary before 33 | applying our licenses so that the public can reuse the 34 | material as expected. Licensors should clearly mark any 35 | material not subject to the license. This includes other CC- 36 | licensed material, or material used under an exception or 37 | limitation to copyright. More considerations for licensors: 38 | wiki.creativecommons.org/Considerations_for_licensors 39 | 40 | Considerations for the public: By using one of our public 41 | licenses, a licensor grants the public permission to use the 42 | licensed material under specified terms and conditions. If 43 | the licensor's permission is not necessary for any reason--for 44 | example, because of any applicable exception or limitation to 45 | copyright--then that use is not regulated by the license. 
Our 46 | licenses grant only permissions under copyright and certain 47 | other rights that a licensor has authority to grant. Use of 48 | the licensed material may still be restricted for other 49 | reasons, including because others have copyright or other 50 | rights in the material. A licensor may make special requests, 51 | such as asking that all changes be marked or described. 52 | Although not required by our licenses, you are encouraged to 53 | respect those requests where reasonable. More considerations 54 | for the public: 55 | wiki.creativecommons.org/Considerations_for_licensees 56 | 57 | ======================================================================= 58 | 59 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 60 | Public License 61 | 62 | By exercising the Licensed Rights (defined below), You accept and agree 63 | to be bound by the terms and conditions of this Creative Commons 64 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 65 | ("Public License"). To the extent this Public License may be 66 | interpreted as a contract, You are granted the Licensed Rights in 67 | consideration of Your acceptance of these terms and conditions, and the 68 | Licensor grants You such rights in consideration of benefits the 69 | Licensor receives from making the Licensed Material available under 70 | these terms and conditions. 71 | 72 | 73 | Section 1 -- Definitions. 74 | 75 | a. Adapted Material means material subject to Copyright and Similar 76 | Rights that is derived from or based upon the Licensed Material 77 | and in which the Licensed Material is translated, altered, 78 | arranged, transformed, or otherwise modified in a manner requiring 79 | permission under the Copyright and Similar Rights held by the 80 | Licensor. For purposes of this Public License, where the Licensed 81 | Material is a musical work, performance, or sound recording, 82 | Adapted Material is always produced where the Licensed Material is 83 | synched in timed relation with a moving image. 84 | 85 | b. Adapter's License means the license You apply to Your Copyright 86 | and Similar Rights in Your contributions to Adapted Material in 87 | accordance with the terms and conditions of this Public License. 88 | 89 | c. BY-NC-SA Compatible License means a license listed at 90 | creativecommons.org/compatiblelicenses, approved by Creative 91 | Commons as essentially the equivalent of this Public License. 92 | 93 | d. Copyright and Similar Rights means copyright and/or similar rights 94 | closely related to copyright including, without limitation, 95 | performance, broadcast, sound recording, and Sui Generis Database 96 | Rights, without regard to how the rights are labeled or 97 | categorized. For purposes of this Public License, the rights 98 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 99 | Rights. 100 | 101 | e. Effective Technological Measures means those measures that, in the 102 | absence of proper authority, may not be circumvented under laws 103 | fulfilling obligations under Article 11 of the WIPO Copyright 104 | Treaty adopted on December 20, 1996, and/or similar international 105 | agreements. 106 | 107 | f. Exceptions and Limitations means fair use, fair dealing, and/or 108 | any other exception or limitation to Copyright and Similar Rights 109 | that applies to Your use of the Licensed Material. 110 | 111 | g. License Elements means the license attributes listed in the name 112 | of a Creative Commons Public License. 
The License Elements of this 113 | Public License are Attribution, NonCommercial, and ShareAlike. 114 | 115 | h. Licensed Material means the artistic or literary work, database, 116 | or other material to which the Licensor applied this Public 117 | License. 118 | 119 | i. Licensed Rights means the rights granted to You subject to the 120 | terms and conditions of this Public License, which are limited to 121 | all Copyright and Similar Rights that apply to Your use of the 122 | Licensed Material and that the Licensor has authority to license. 123 | 124 | j. Licensor means the individual(s) or entity(ies) granting rights 125 | under this Public License. 126 | 127 | k. NonCommercial means not primarily intended for or directed towards 128 | commercial advantage or monetary compensation. For purposes of 129 | this Public License, the exchange of the Licensed Material for 130 | other material subject to Copyright and Similar Rights by digital 131 | file-sharing or similar means is NonCommercial provided there is 132 | no payment of monetary compensation in connection with the 133 | exchange. 134 | 135 | l. Share means to provide material to the public by any means or 136 | process that requires permission under the Licensed Rights, such 137 | as reproduction, public display, public performance, distribution, 138 | dissemination, communication, or importation, and to make material 139 | available to the public including in ways that members of the 140 | public may access the material from a place and at a time 141 | individually chosen by them. 142 | 143 | m. Sui Generis Database Rights means rights other than copyright 144 | resulting from Directive 96/9/EC of the European Parliament and of 145 | the Council of 11 March 1996 on the legal protection of databases, 146 | as amended and/or succeeded, as well as other essentially 147 | equivalent rights anywhere in the world. 148 | 149 | n. You means the individual or entity exercising the Licensed Rights 150 | under this Public License. Your has a corresponding meaning. 151 | 152 | 153 | Section 2 -- Scope. 154 | 155 | a. License grant. 156 | 157 | 1. Subject to the terms and conditions of this Public License, 158 | the Licensor hereby grants You a worldwide, royalty-free, 159 | non-sublicensable, non-exclusive, irrevocable license to 160 | exercise the Licensed Rights in the Licensed Material to: 161 | 162 | a. reproduce and Share the Licensed Material, in whole or 163 | in part, for NonCommercial purposes only; and 164 | 165 | b. produce, reproduce, and Share Adapted Material for 166 | NonCommercial purposes only. 167 | 168 | 2. Exceptions and Limitations. For the avoidance of doubt, where 169 | Exceptions and Limitations apply to Your use, this Public 170 | License does not apply, and You do not need to comply with 171 | its terms and conditions. 172 | 173 | 3. Term. The term of this Public License is specified in Section 174 | 6(a). 175 | 176 | 4. Media and formats; technical modifications allowed. The 177 | Licensor authorizes You to exercise the Licensed Rights in 178 | all media and formats whether now known or hereafter created, 179 | and to make technical modifications necessary to do so. The 180 | Licensor waives and/or agrees not to assert any right or 181 | authority to forbid You from making technical modifications 182 | necessary to exercise the Licensed Rights, including 183 | technical modifications necessary to circumvent Effective 184 | Technological Measures. 
For purposes of this Public License, 185 | simply making modifications authorized by this Section 2(a) 186 | (4) never produces Adapted Material. 187 | 188 | 5. Downstream recipients. 189 | 190 | a. Offer from the Licensor -- Licensed Material. Every 191 | recipient of the Licensed Material automatically 192 | receives an offer from the Licensor to exercise the 193 | Licensed Rights under the terms and conditions of this 194 | Public License. 195 | 196 | b. Additional offer from the Licensor -- Adapted Material. 197 | Every recipient of Adapted Material from You 198 | automatically receives an offer from the Licensor to 199 | exercise the Licensed Rights in the Adapted Material 200 | under the conditions of the Adapter's License You apply. 201 | 202 | c. No downstream restrictions. You may not offer or impose 203 | any additional or different terms or conditions on, or 204 | apply any Effective Technological Measures to, the 205 | Licensed Material if doing so restricts exercise of the 206 | Licensed Rights by any recipient of the Licensed 207 | Material. 208 | 209 | 6. No endorsement. Nothing in this Public License constitutes or 210 | may be construed as permission to assert or imply that You 211 | are, or that Your use of the Licensed Material is, connected 212 | with, or sponsored, endorsed, or granted official status by, 213 | the Licensor or others designated to receive attribution as 214 | provided in Section 3(a)(1)(A)(i). 215 | 216 | b. Other rights. 217 | 218 | 1. Moral rights, such as the right of integrity, are not 219 | licensed under this Public License, nor are publicity, 220 | privacy, and/or other similar personality rights; however, to 221 | the extent possible, the Licensor waives and/or agrees not to 222 | assert any such rights held by the Licensor to the limited 223 | extent necessary to allow You to exercise the Licensed 224 | Rights, but not otherwise. 225 | 226 | 2. Patent and trademark rights are not licensed under this 227 | Public License. 228 | 229 | 3. To the extent possible, the Licensor waives any right to 230 | collect royalties from You for the exercise of the Licensed 231 | Rights, whether directly or through a collecting society 232 | under any voluntary or waivable statutory or compulsory 233 | licensing scheme. In all other cases the Licensor expressly 234 | reserves any right to collect such royalties, including when 235 | the Licensed Material is used other than for NonCommercial 236 | purposes. 237 | 238 | 239 | Section 3 -- License Conditions. 240 | 241 | Your exercise of the Licensed Rights is expressly made subject to the 242 | following conditions. 243 | 244 | a. Attribution. 245 | 246 | 1. If You Share the Licensed Material (including in modified 247 | form), You must: 248 | 249 | a. retain the following if it is supplied by the Licensor 250 | with the Licensed Material: 251 | 252 | i. identification of the creator(s) of the Licensed 253 | Material and any others designated to receive 254 | attribution, in any reasonable manner requested by 255 | the Licensor (including by pseudonym if 256 | designated); 257 | 258 | ii. a copyright notice; 259 | 260 | iii. a notice that refers to this Public License; 261 | 262 | iv. a notice that refers to the disclaimer of 263 | warranties; 264 | 265 | v. a URI or hyperlink to the Licensed Material to the 266 | extent reasonably practicable; 267 | 268 | b. indicate if You modified the Licensed Material and 269 | retain an indication of any previous modifications; and 270 | 271 | c. 
indicate the Licensed Material is licensed under this 272 | Public License, and include the text of, or the URI or 273 | hyperlink to, this Public License. 274 | 275 | 2. You may satisfy the conditions in Section 3(a)(1) in any 276 | reasonable manner based on the medium, means, and context in 277 | which You Share the Licensed Material. For example, it may be 278 | reasonable to satisfy the conditions by providing a URI or 279 | hyperlink to a resource that includes the required 280 | information. 281 | 3. If requested by the Licensor, You must remove any of the 282 | information required by Section 3(a)(1)(A) to the extent 283 | reasonably practicable. 284 | 285 | b. ShareAlike. 286 | 287 | In addition to the conditions in Section 3(a), if You Share 288 | Adapted Material You produce, the following conditions also apply. 289 | 290 | 1. The Adapter's License You apply must be a Creative Commons 291 | license with the same License Elements, this version or 292 | later, or a BY-NC-SA Compatible License. 293 | 294 | 2. You must include the text of, or the URI or hyperlink to, the 295 | Adapter's License You apply. You may satisfy this condition 296 | in any reasonable manner based on the medium, means, and 297 | context in which You Share Adapted Material. 298 | 299 | 3. You may not offer or impose any additional or different terms 300 | or conditions on, or apply any Effective Technological 301 | Measures to, Adapted Material that restrict exercise of the 302 | rights granted under the Adapter's License You apply. 303 | 304 | 305 | Section 4 -- Sui Generis Database Rights. 306 | 307 | Where the Licensed Rights include Sui Generis Database Rights that 308 | apply to Your use of the Licensed Material: 309 | 310 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 311 | to extract, reuse, reproduce, and Share all or a substantial 312 | portion of the contents of the database for NonCommercial purposes 313 | only; 314 | 315 | b. if You include all or a substantial portion of the database 316 | contents in a database in which You have Sui Generis Database 317 | Rights, then the database in which You have Sui Generis Database 318 | Rights (but not its individual contents) is Adapted Material, 319 | including for purposes of Section 3(b); and 320 | 321 | c. You must comply with the conditions in Section 3(a) if You Share 322 | all or a substantial portion of the contents of the database. 323 | 324 | For the avoidance of doubt, this Section 4 supplements and does not 325 | replace Your obligations under this Public License where the Licensed 326 | Rights include other Copyright and Similar Rights. 327 | 328 | 329 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 330 | 331 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 332 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 333 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 334 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 335 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 336 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 337 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 338 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 339 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 340 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 341 | 342 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 343 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 344 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 345 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 346 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 347 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 348 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 349 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 350 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 351 | 352 | c. The disclaimer of warranties and limitation of liability provided 353 | above shall be interpreted in a manner that, to the extent 354 | possible, most closely approximates an absolute disclaimer and 355 | waiver of all liability. 356 | 357 | 358 | Section 6 -- Term and Termination. 359 | 360 | a. This Public License applies for the term of the Copyright and 361 | Similar Rights licensed here. However, if You fail to comply with 362 | this Public License, then Your rights under this Public License 363 | terminate automatically. 364 | 365 | b. Where Your right to use the Licensed Material has terminated under 366 | Section 6(a), it reinstates: 367 | 368 | 1. automatically as of the date the violation is cured, provided 369 | it is cured within 30 days of Your discovery of the 370 | violation; or 371 | 372 | 2. upon express reinstatement by the Licensor. 373 | 374 | For the avoidance of doubt, this Section 6(b) does not affect any 375 | right the Licensor may have to seek remedies for Your violations 376 | of this Public License. 377 | 378 | c. For the avoidance of doubt, the Licensor may also offer the 379 | Licensed Material under separate terms or conditions or stop 380 | distributing the Licensed Material at any time; however, doing so 381 | will not terminate this Public License. 382 | 383 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 384 | License. 385 | 386 | 387 | Section 7 -- Other Terms and Conditions. 388 | 389 | a. The Licensor shall not be bound by any additional or different 390 | terms or conditions communicated by You unless expressly agreed. 391 | 392 | b. Any arrangements, understandings, or agreements regarding the 393 | Licensed Material not stated herein are separate from and 394 | independent of the terms and conditions of this Public License. 395 | 396 | 397 | Section 8 -- Interpretation. 398 | 399 | a. For the avoidance of doubt, this Public License does not, and 400 | shall not be interpreted to, reduce, limit, restrict, or impose 401 | conditions on any use of the Licensed Material that could lawfully 402 | be made without permission under this Public License. 403 | 404 | b. To the extent possible, if any provision of this Public License is 405 | deemed unenforceable, it shall be automatically reformed to the 406 | minimum extent necessary to make it enforceable. If the provision 407 | cannot be reformed, it shall be severed from this Public License 408 | without affecting the enforceability of the remaining terms and 409 | conditions. 410 | 411 | c. No term or condition of this Public License will be waived and no 412 | failure to comply consented to unless expressly agreed to by the 413 | Licensor. 414 | 415 | d. 
Nothing in this Public License constitutes or may be interpreted 416 | as a limitation upon, or waiver of, any privileges and immunities 417 | that apply to the Licensor or You, including from the legal 418 | processes of any jurisdiction or authority. 419 | 420 | ======================================================================= 421 | 422 | Creative Commons is not a party to its public 423 | licenses. Notwithstanding, Creative Commons may elect to apply one of 424 | its public licenses to material it publishes and in those instances 425 | will be considered the “Licensor.” The text of the Creative Commons 426 | public licenses is dedicated to the public domain under the CC0 Public 427 | Domain Dedication. Except for the limited purpose of indicating that 428 | material is shared under a Creative Commons public license or as 429 | otherwise permitted by the Creative Commons policies published at 430 | creativecommons.org/policies, Creative Commons does not authorize the 431 | use of the trademark "Creative Commons" or any other trademark or logo 432 | of Creative Commons without its prior written consent including, 433 | without limitation, in connection with any unauthorized modifications 434 | to any of its public licenses or any other arrangements, 435 | understandings, or agreements concerning use of licensed material. For 436 | the avoidance of doubt, this paragraph does not form part of the 437 | public licenses. 438 | 439 | Creative Commons may be contacted at creativecommons.org. 440 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modeling Continuous Stochastic Processes with Dynamic Normalizing Flows 2 | 3 | Code for the paper 4 | 5 | > Ruizhi Deng, Bo Chang, Marcus Brubaker, Greg Mori, Andreas Lehrmann. "Modeling Continuous Stochastic Processes with Dynamic Normalizing Flows" (2020) 6 | [[arxiv]](https://arxiv.org/pdf/2002.10516.pdf) 7 | 8 | ![](page1.png) 9 | ## Dependency installation 10 | 11 | Install the dependencies in requirements.txt with 12 | ```bash 13 | pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html 14 | ``` 15 | The released code uses a newer version of PyTorch (1.4.0). 16 | The experiments in the paper use an older version of PyTorch, which could lead to slightly different results. 17 | Please read the [PyTorch documentation](https://pytorch.org/docs/stable/notes/randomness.html) for more information. 18 | 19 | ## Acknowledgements 20 | 21 | The code makes use of code from the following two projects: 22 | https://github.com/YuliaRubanova/latent_ode 23 | for the paper 24 | > Yulia Rubanova, Ricky Chen, David Duvenaud. "Latent ODEs for Irregularly-Sampled Time Series" (2019) 25 | [[arxiv]](https://arxiv.org/abs/1907.03907) 26 | 27 | https://github.com/rtqichen/ffjord 28 | for the paper 29 | > Will Grathwohl*, Ricky T. Q. Chen*, Jesse Bettencourt, Ilya Sutskever, David Duvenaud. "FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models." _International Conference on Learning Representations_ (2019). 30 | > [[arxiv]](https://arxiv.org/abs/1810.01367) [[bibtex]](http://www.cs.toronto.edu/~rtqichen/bibtex/ffjord.bib) 31 | 32 | We make use of the following files from the code of *FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models*: `train_misc.py`, `lib/layers`, `lib/utils.py`, `lib/spectral_norm.py`.
33 | We make changes to the files `train_misc.py`, `lib/layers/odefunc.py`, and `lib/utils.py`. 34 | 35 | We use `lib/encoder_decoder.py`, `lib/ode_func.py`, `lib/diffeq_solver.py` from the code of *Latent ODEs for Irregularly-Sampled Time Series*. We make changes to the file `lib/encoder_decoder.py`. 36 | 37 | 38 | ## Commands for training the models 39 | 40 | We train the models on data sampled from an observation process with $\lambda=2$. 41 | 42 | ### Training CTFP model on GBM Process 43 | ```bash 44 | python train_ctfp.py --batch_size 100 --test_batch_size 100 --num_blocks 1 --save ctfp_gbm --log_freq 1 --num_workers 2 --layer_type concat --dims 32,64,64,32 --nonlinearity tanh --lr 5e-4 --num_epochs 100 --data_path data/gbm_2.pkl 45 | ``` 46 | ### Training latent CTFP model on GBM Process 47 | ```bash 48 | python train_latent_ctfp.py --batch_size 50 --test_batch_size 5 --num_blocks 1 --save latent_ctfp_gbm --log_freq 1 --num_workers 2 --layer_type concat --dims 32,64,64,32 --nonlinearity tanh --lr 5e-4 --num_epochs 100 --data_path data/gbm_2.pkl 49 | ``` 50 | ### Training CTFP model on OU Process 51 | ```bash 52 | python train_ctfp.py --batch_size 100 --test_batch_size 100 --num_blocks 1 --save ctfp_ou --log_freq 1 --num_workers 2 --layer_type concat --dims 32,64,64,32 --nonlinearity tanh --lr 5e-4 --num_epochs 100 --data_path data/ou_2.pkl --activation identity 53 | ``` 54 | ### Training latent CTFP model on OU Process 55 | ```bash 56 | python train_latent_ctfp.py --batch_size 50 --test_batch_size 5 --num_blocks 1 --save latent_ctfp_ou --log_freq 1 --num_workers 2 --layer_type concat --dims 32,64,64,32 --nonlinearity tanh --lr 5e-4 --num_epochs 300 --data_path data/ou_2.pkl --activation identity --aggressive 57 | ``` 58 | 59 | 60 | ## Commands for evaluating the models 61 | 62 | We evaluate the models on data sampled from an observation process with $\lambda=2$. 63 | 64 | ### Evaluating CTFP model on GBM Process 65 | ```bash 66 | python eval_ctfp.py --test_batch_size 100 --num_blocks 1 --save ctfp_gbm --num_workers 2 --layer_type concat --dims 32,64,64,32 --nonlinearity tanh --lr 5e-4 --num_epochs 100 --resume experiments/ctfp_gbm/pretrained.pth --data_path data/gbm_2.pkl 67 | ``` 68 | ### Evaluating latent CTFP model on GBM Process 69 | ```bash 70 | python eval_latent_ctfp.py --test_batch_size 5 --num_blocks 1 --save latent_ctfp_gbm --num_workers 2 --layer_type concat --dims 32,64,64,32 --nonlinearity tanh --lr 5e-4 --num_epochs 100 --data_path data/gbm_2.pkl --resume experiments/latent_ctfp_gbm/pretrained.pth 71 | ``` 72 | ### Evaluating CTFP model on OU Process 73 | ```bash 74 | python eval_ctfp.py --test_batch_size 100 --num_blocks 1 --save ctfp_ou --num_workers 2 --layer_type concat --dims 32,64,64,32 --nonlinearity tanh --data_path data/ou_2.pkl --activation identity --resume experiments/ctfp_ou/pretrained.pth 75 | ``` 76 | ### Evaluating latent CTFP model on OU Process 77 | ```bash 78 | python eval_latent_ctfp.py --test_batch_size 5 --num_blocks 1 --save latent_ctfp_ou --num_workers 2 --layer_type concat --dims 32,64,64,32 --nonlinearity tanh --data_path data/ou_2.pkl --activation identity --resume experiments/latent_ctfp_ou/pretrained.pth 79 | ``` 80 | 81 | ## Performance Summary 82 | | Model | GBM | OU | 83 | |---|---|---| 84 | | CTFP | 3.107 | 2.902 | 85 | | Latent CTFP | 3.106 | 2.902 | 86 | 87 | Download the data from [this link](https://drive.google.com/file/d/1ZyQ7VdL0Oe0DMMyfrB7jgoqqTcex3yVN/view?usp=sharing) for evaluating the models on GBM and OU data with $\lambda=20$ and for training the models on the mixture of OU data.
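
## Data format

Each `.pkl` file stores a list of variable-length sequences; every sequence is an array of `(time, value)` pairs, which is exactly the layout that `bm_sequential.BMSequence.__getitem__` indexes into. The snippet below is a minimal sketch of producing a file in this format; the homogeneous Poisson observation times and the GBM drift/volatility values are illustrative assumptions, not the exact generator behind `data/gbm_2.pkl`:

```python
import pickle

import numpy as np

rng = np.random.default_rng(0)


def sample_gbm_sequence(lam=2.0, t_max=30.0, x0=1.0, mu=0.2, sigma=0.5):
    # Observation times from a homogeneous Poisson process with intensity lam:
    # i.i.d. exponential inter-arrival times, truncated at t_max.
    times = np.cumsum(rng.exponential(1.0 / lam, size=int(3 * lam * t_max)))
    times = times[times < t_max]
    # Exact GBM transitions between consecutive observation times:
    # X_t = X_s * exp((mu - sigma^2 / 2) * dt + sigma * sqrt(dt) * N(0, 1))
    dts = np.diff(times, prepend=0.0)
    steps = (mu - 0.5 * sigma ** 2) * dts + sigma * np.sqrt(dts) * rng.standard_normal(len(dts))
    values = x0 * np.exp(np.cumsum(steps))
    return np.stack([times, values], axis=1)  # shape (n_obs, 2): (time, value) rows


sequences = [sample_gbm_sequence() for _ in range(10000)]
with open("data/gbm_synthetic.pkl", "wb") as f:  # hypothetical output path
    pickle.dump(sequences, f)
```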
88 | -------------------------------------------------------------------------------- /bm_sequential.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present Royal Bank of Canada 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import pickle 9 | 10 | import numpy as np 11 | import torch 12 | from torch.utils import data 13 | 14 | TRAIN_SPLIT_PERCENTAGE = 0.7  # cumulative split point: first 70% of sequences are for training 15 | VAL_SPLIT_PERCENTAGE = 0.8  # next 10% are for validation; the remaining 20% are for testing 16 | 17 | 18 | def get_test_dataset(args, test_batch_size): 19 | """ 20 | Function for getting the dataset for testing 21 | 22 | Parameters: 23 | args: the arguments from parse_arguments in ctfp_tools 24 | test_batch_size: batch size used for the test data loader 25 | 26 | Returns: 27 | test_loader: the dataloader for testing 28 | """ 29 | test_set = BMSequence(data_path=args.data_path, split=args.test_split) 30 | test_loader = torch.utils.data.DataLoader( 31 | dataset=test_set, 32 | batch_size=test_batch_size, 33 | shuffle=False, 34 | drop_last=False, 35 | num_workers=args.num_workers, 36 | ) 37 | return test_loader 38 | 39 | 40 | def get_dataset(args): 41 | """ 42 | Function for getting the dataset for training and validation 43 | 44 | Parameters: 45 | args: the arguments from parse_arguments in ctfp_tools 46 | 47 | 48 | Returns: 49 | train_loader: data loader of training data 50 | val_loader: data loader of validation data 51 | """ 52 | train_set = BMSequence(data_path=args.data_path, split="train") 53 | val_set = BMSequence(data_path=args.data_path, split="val") 54 | 55 | train_loader = torch.utils.data.DataLoader( 56 | dataset=train_set, 57 | batch_size=args.batch_size, 58 | shuffle=True, 59 | drop_last=False, 60 | num_workers=args.num_workers, 61 | ) 62 | val_loader = torch.utils.data.DataLoader( 63 | dataset=val_set, 64 | batch_size=args.test_batch_size, 65 | shuffle=False, 66 | drop_last=True, 67 | ) 68 | return train_loader, val_loader 69 | 70 | 71 | class BMSequence(data.dataset.Dataset): 72 | """ 73 | Dataset class for observations on irregular time grids of synthetic continuous 74 | time stochastic processes 75 | data_path: path to a pickle file storing the data 76 | split: split of the data, train, val, or test 77 | """ 78 | 79 | def __init__(self, data_path, split="train"): 80 | super(BMSequence, self).__init__() 81 | f = open(data_path, "rb") 82 | self.data = pickle.load(f) 83 | f.close() 84 | self.max_length = 0 85 | for item in self.data: 86 | self.max_length = max(len(item), self.max_length) 87 | total_length = len(self.data) 88 | train_split = int(total_length * TRAIN_SPLIT_PERCENTAGE) 89 | val_split = int(total_length * VAL_SPLIT_PERCENTAGE) 90 | if split == "train": 91 | self.data = self.data[:train_split] 92 | elif split == "val": 93 | self.data = self.data[train_split:val_split] 94 | elif split == "test": 95 | self.data = self.data[val_split:] 96 | 97 | def __len__(self): 98 | return len(self.data) 99 | 100 | def __getitem__(self, index): 101 | item = np.array(self.data[index]) 102 | item_len = item.shape[0] 103 | item_times = item[:, 0] 104 | item_times_shift = np.zeros_like(item_times) 105 | item_times_shift[1:] = item_times[:-1] 106 | item_values = item[:, 1] 107 | padded_times = torch.zeros(self.max_length) 108 | ## Pad all the sequences to the max length with value of 100 109 | ## Any value greater than zero can be used 110 | padded_values =
torch.zeros(self.max_length) + 100 111 | masks = torch.ByteTensor(self.max_length).zero_() 112 | padded_times[:item_len] = torch.Tensor(item_times).type(torch.FloatTensor) 113 | padded_values[:item_len] = torch.Tensor(item_values).type(torch.FloatTensor) 114 | masks[:item_len] = 1 115 | padded_variance = torch.ones(self.max_length) 116 | padded_variance[:item_len] = torch.Tensor(item_times - item_times_shift).type( 117 | torch.FloatTensor 118 | ) 119 | return ( 120 | padded_values.unsqueeze(1), 121 | padded_times.unsqueeze(1), 122 | padded_variance, 123 | masks, 124 | ) 125 | -------------------------------------------------------------------------------- /data/gbm_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorealisAI/continuous-time-flow-process/d380a7984d408d1f6c849219028cc7baffdd1e1e/data/gbm_2.pkl -------------------------------------------------------------------------------- /data/ou_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorealisAI/continuous-time-flow-process/d380a7984d408d1f6c849219028cc7baffdd1e1e/data/ou_2.pkl -------------------------------------------------------------------------------- /eval_ctfp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present Royal Bank of Canada 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | import time 10 | 11 | import lib.utils as utils 12 | import numpy as np 13 | import torch 14 | 15 | from bm_sequential import get_test_dataset as get_dataset 16 | from ctfp_tools import build_augmented_model_tabular 17 | from ctfp_tools import parse_arguments 18 | from ctfp_tools import run_ctfp_model as run_model 19 | from train_misc import ( 20 | create_regularization_fns, 21 | ) 22 | from train_misc import set_cnf_options 23 | 24 | torch.backends.cudnn.benchmark = True 25 | 26 | if __name__ == "__main__": 27 | args = parse_arguments() 28 | logger = utils.get_logger( 29 | logpath=os.path.join(args.save, "logs_test"), filepath=os.path.abspath(__file__) 30 | ) 31 | 32 | if args.layer_type == "blend": 33 | logger.info( 34 | "!! Setting time_length from None to 1.0 due to use of Blend layers." 
35 | ) 36 | args.time_length = 1.0 37 | 38 | logger.info(args) 39 | # get device 40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 41 | if args.use_cpu: 42 | device = torch.device("cpu") 43 | cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) 44 | 45 | # load dataset 46 | test_loader = get_dataset(args, args.test_batch_size) 47 | 48 | # build model 49 | regularization_fns, regularization_coeffs = create_regularization_fns(args) 50 | aug_model = build_augmented_model_tabular( 51 | args, 52 | args.aug_size + args.effective_shape, 53 | regularization_fns=regularization_fns, 54 | ) 55 | set_cnf_options(args, aug_model) 56 | logger.info(aug_model) 57 | 58 | # restore parameters 59 | itr = 0 60 | if args.resume is not None: 61 | checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage) 62 | aug_model.load_state_dict(checkpt["state_dict"]) 63 | 64 | if torch.cuda.is_available() and not args.use_cpu: 65 | aug_model = torch.nn.DataParallel(aug_model).cuda() 66 | 67 | best_loss = float("inf") 68 | aug_model.eval() 69 | with torch.no_grad(): 70 | logger.info("Testing...") 71 | losses = [] 72 | num_observes = [] 73 | for _, x in enumerate(test_loader): 74 | ## x is a tuple of (values, times, variances, masks) 75 | start = time.time() 76 | # cast data and move to device 77 | x = map(cvt, x) 78 | values, times, vars, masks = x 79 | loss = run_model(args, aug_model, values, times, vars, masks) 80 | losses.append(loss.data.cpu().numpy()) 81 | num_observes.append(torch.sum(masks).data.cpu().numpy()) 82 | loss = np.sum(np.array(losses) * np.array(num_observes)) / np.sum(num_observes) 83 | logger.info("Bit/dim {:.4f}".format(loss)) 84 | -------------------------------------------------------------------------------- /eval_latent_ctfp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present Royal Bank of Canada 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | 10 | import lib.utils as utils 11 | import numpy as np 12 | import torch 13 | import torch.optim as optim 14 | 15 | from bm_sequential import get_test_dataset as get_dataset 16 | from ctfp_tools import build_augmented_model_tabular 17 | from ctfp_tools import parse_arguments 18 | from ctfp_tools import run_latent_ctfp_model as run_model 19 | from ode_rnn_encoder import create_ode_rnn_encoder 20 | from train_misc import ( 21 | create_regularization_fns, 22 | ) 23 | from train_misc import set_cnf_options, count_parameters 24 | 25 | torch.backends.cudnn.benchmark = True 26 | 27 | if __name__ == "__main__": 28 | args = parse_arguments() 29 | # logger 30 | logger = utils.get_logger( 31 | logpath=os.path.join(args.save, "logs_test"), filepath=os.path.abspath(__file__) 32 | ) 33 | 34 | if args.layer_type == "blend": 35 | logger.info( 36 | "!! Setting time_length from None to 1.0 due to use of Blend layers."
37 | ) 38 | args.time_length = 1.0 39 | 40 | logger.info(args) 41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 42 | if args.use_cpu: 43 | device = torch.device("cpu") 44 | cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) 45 | 46 | # load dataset 47 | test_loader = get_dataset(args, args.test_batch_size) 48 | 49 | # build model 50 | # Build the encoder 51 | if args.encoder == "ode_rnn": 52 | encoder = create_ode_rnn_encoder(args, device) 53 | else: 54 | raise NotImplementedError 55 | regularization_fns, regularization_coeffs = create_regularization_fns(args) 56 | 57 | aug_model = build_augmented_model_tabular( 58 | args, 59 | args.aug_size + args.effective_shape + args.latent_size, 60 | regularization_fns=regularization_fns, 61 | ) 62 | 63 | set_cnf_options(args, aug_model) 64 | logger.info(aug_model) 65 | logger.info( 66 | "Number of trainable parameters: {}".format(count_parameters(aug_model)) 67 | ) 68 | 69 | # optimizer 70 | optimizer = optim.Adam( 71 | list(aug_model.parameters()) + list(encoder.parameters()), 72 | lr=args.lr, 73 | weight_decay=args.weight_decay, 74 | ) 75 | num_params = sum(p.numel() for p in aug_model.parameters() if p.requires_grad) 76 | 77 | if args.aggressive: 78 | encoder_optimizer = optim.Adam( 79 | encoder.parameters(), lr=args.lr, weight_decay=args.weight_decay 80 | ) 81 | enc_num_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad) 82 | print( 83 | "Total Num of Parameters: %d, Encoder Num of Parameters: %d" 84 | % (num_params + enc_num_params, enc_num_params) 85 | ) 86 | 87 | # restore parameters 88 | itr = 0 89 | if args.resume is not None: 90 | checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage) 91 | aug_model.load_state_dict(checkpt["state_dict"]) 92 | encoder.load_state_dict(checkpt["encoder_state_dict"]) 93 | 94 | if torch.cuda.is_available() and not args.use_cpu: 95 | aug_model = torch.nn.DataParallel(aug_model).cuda() 96 | encoder = torch.nn.DataParallel(encoder).cuda() 97 | 98 | aug_model.eval() 99 | encoder.eval() 100 | with torch.no_grad(): 101 | logger.info("Testing...") 102 | losses = [] 103 | num_observes = [] 104 | for _, x in enumerate(test_loader): 105 | ## x is a tuple of (values, times, variances, masks) 106 | x = map(cvt, x) 107 | values, times, vars, masks = x 108 | loss = run_model( 109 | args, encoder, aug_model, values, times, vars, masks, evaluation=True 110 | ) 111 | losses.append(loss.data.cpu().numpy()) 112 | num_observes.append(torch.sum(masks).data.cpu().numpy()) 113 | loss = np.sum(np.array(losses) * np.array(num_observes)) / np.sum(num_observes) 114 | logger.info("Bit/dim {:.4f}".format(loss)) 115 | -------------------------------------------------------------------------------- /experiments/ctfp_gbm/pretrained.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorealisAI/continuous-time-flow-process/d380a7984d408d1f6c849219028cc7baffdd1e1e/experiments/ctfp_gbm/pretrained.pth -------------------------------------------------------------------------------- /experiments/ctfp_mix/pretrained.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorealisAI/continuous-time-flow-process/d380a7984d408d1f6c849219028cc7baffdd1e1e/experiments/ctfp_mix/pretrained.pth -------------------------------------------------------------------------------- /experiments/ctfp_ou/pretrained.pth:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorealisAI/continuous-time-flow-process/d380a7984d408d1f6c849219028cc7baffdd1e1e/experiments/ctfp_ou/pretrained.pth -------------------------------------------------------------------------------- /experiments/latent_ctfp_gbm/pretrained.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorealisAI/continuous-time-flow-process/d380a7984d408d1f6c849219028cc7baffdd1e1e/experiments/latent_ctfp_gbm/pretrained.pth -------------------------------------------------------------------------------- /experiments/latent_ctfp_mix/pretrained.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorealisAI/continuous-time-flow-process/d380a7984d408d1f6c849219028cc7baffdd1e1e/experiments/latent_ctfp_mix/pretrained.pth -------------------------------------------------------------------------------- /experiments/latent_ctfp_ou/pretrained.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorealisAI/continuous-time-flow-process/d380a7984d408d1f6c849219028cc7baffdd1e1e/experiments/latent_ctfp_ou/pretrained.pth -------------------------------------------------------------------------------- /figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorealisAI/continuous-time-flow-process/d380a7984d408d1f6c849219028cc7baffdd1e1e/figure.png -------------------------------------------------------------------------------- /lib/diffeq_solver.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2019 Yulia Rubanova 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | ########################### 24 | # Latent ODEs for Irregularly-Sampled Time Series 25 | # Author: Yulia Rubanova 26 | # Link: https://github.com/YuliaRubanova/latent_ode 27 | ########################### 28 | 29 | import torch 30 | import torch.nn as nn 31 | 32 | # git clone https://github.com/rtqichen/torchdiffeq.git 33 | from torchdiffeq import odeint as odeint 34 | 35 | 36 | ##################################################################################################### 37 | 38 | 39 | class DiffeqSolver(nn.Module): 40 | def __init__( 41 | self, 42 | input_dim, 43 | ode_func, 44 | method, 45 | latents, 46 | odeint_rtol=1e-4, 47 | odeint_atol=1e-5, 48 | device=torch.device("cpu"), 49 | ): 50 | super(DiffeqSolver, self).__init__() 51 | 52 | self.ode_method = method 53 | self.latents = latents 54 | self.device = device 55 | self.ode_func = ode_func 56 | 57 | self.odeint_rtol = odeint_rtol 58 | self.odeint_atol = odeint_atol 59 | 60 | def forward(self, first_point, time_steps_to_predict, backwards=False): 61 | """ 62 | # Decode the trajectory through ODE Solver 63 | """ 64 | n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1] 65 | n_dims = first_point.size()[-1] 66 | 67 | pred_y = odeint( 68 | self.ode_func, 69 | first_point, 70 | time_steps_to_predict, 71 | rtol=self.odeint_rtol, 72 | atol=self.odeint_atol, 73 | method=self.ode_method, 74 | ) 75 | pred_y = pred_y.permute(1, 2, 0, 3) 76 | 77 | assert torch.mean(pred_y[:, :, 0, :] - first_point) < 0.001 78 | assert pred_y.size()[0] == n_traj_samples 79 | assert pred_y.size()[1] == n_traj 80 | 81 | return pred_y 82 | 83 | def sample_traj_from_prior( 84 | self, starting_point_enc, time_steps_to_predict, n_traj_samples=1 85 | ): 86 | """ 87 | # Decode the trajectory through ODE Solver using samples from the prior 88 | 89 | time_steps_to_predict: time steps at which we want to sample the new trajectory 90 | """ 91 | func = self.ode_func.sample_next_point_from_prior 92 | 93 | pred_y = odeint( 94 | func, 95 | starting_point_enc, 96 | time_steps_to_predict, 97 | rtol=self.odeint_rtol, 98 | atol=self.odeint_atol, 99 | method=self.ode_method, 100 | ) 101 | # shape: [n_traj_samples, n_traj, n_tp, n_dim] 102 | pred_y = pred_y.permute(1, 2, 0, 3) 103 | return pred_y 104 | -------------------------------------------------------------------------------- /lib/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2019 Yulia Rubanova 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | ########################### 24 | # Latent ODEs for Irregularly-Sampled Time Series 25 | # Author: Yulia Rubanova 26 | # Link: https://github.com/YuliaRubanova/latent_ode 27 | ########################### 28 | 29 | import torch 30 | import torch.nn as nn 31 | import lib.utils as utils 32 | from torch.nn.modules.rnn import GRU 33 | 34 | 35 | def get_device(tensor): 36 | device = torch.device("cpu") 37 | if tensor.is_cuda: 38 | device = tensor.get_device() 39 | return device 40 | 41 | 42 | # GRU description: 43 | # http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/ 44 | class GRU_unit(nn.Module): 45 | def __init__( 46 | self, 47 | latent_dim, 48 | input_dim, 49 | update_gate=None, 50 | reset_gate=None, 51 | new_state_net=None, 52 | n_units=100, 53 | device=torch.device("cpu"), 54 | ): 55 | super(GRU_unit, self).__init__() 56 | 57 | if update_gate is None: 58 | self.update_gate = nn.Sequential( 59 | nn.Linear(latent_dim * 2 + input_dim, n_units), 60 | nn.Tanh(), 61 | nn.Linear(n_units, latent_dim), 62 | nn.Sigmoid(), 63 | ) 64 | utils.init_network_weights(self.update_gate) 65 | else: 66 | self.update_gate = update_gate 67 | 68 | if reset_gate is None: 69 | self.reset_gate = nn.Sequential( 70 | nn.Linear(latent_dim * 2 + input_dim, n_units), 71 | nn.Tanh(), 72 | nn.Linear(n_units, latent_dim), 73 | nn.Sigmoid(), 74 | ) 75 | utils.init_network_weights(self.reset_gate) 76 | else: 77 | self.reset_gate = reset_gate 78 | 79 | if new_state_net is None: 80 | self.new_state_net = nn.Sequential( 81 | nn.Linear(latent_dim * 2 + input_dim, n_units), 82 | nn.Tanh(), 83 | nn.Linear(n_units, latent_dim * 2), 84 | ) 85 | utils.init_network_weights(self.new_state_net) 86 | else: 87 | self.new_state_net = new_state_net 88 | 89 | def forward(self, y_mean, y_std, x, masked_update=True): 90 | y_concat = torch.cat([y_mean, y_std, x], -1) 91 | 92 | update_gate = self.update_gate(y_concat) 93 | reset_gate = self.reset_gate(y_concat) 94 | concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1) 95 | 96 | new_state, new_state_std = utils.split_last_dim(self.new_state_net(concat)) 97 | new_state_std = new_state_std.abs() 98 | 99 | new_y = (1 - update_gate) * new_state + update_gate * y_mean 100 | new_y_std = (1 - update_gate) * new_state_std + update_gate * y_std 101 | 102 | assert not torch.isnan(new_y).any() 103 | 104 | if masked_update: 105 | # IMPORTANT: assumes that x contains both data and mask 106 | # update the hidden state only if at least one feature is present for the current time point 107 | n_data_dims = x.size(-1) // 2 108 | mask = x[:, :, n_data_dims:] 109 | utils.check_mask(x[:, :, :n_data_dims], mask) 110 | 111 | mask = (torch.sum(mask, -1, keepdim=True) > 0).float() 112 | 113 | assert not torch.isnan(mask).any() 114 | 115 | new_y = mask * new_y + (1 - mask) * y_mean 116 | new_y_std = mask * new_y_std + (1 - mask) * y_std 117 | 118 | if torch.isnan(new_y).any(): 119 | print("new_y is nan!") 120 | print(mask) 121 | print(y_mean) 122 | print(new_y) 123 | exit() 124 | 125 | new_y_std = new_y_std.abs() 126 | return new_y, new_y_std 127 | 128 | 129 | class Encoder_z0_RNN(nn.Module): 130 | def __init__(
self, 132 | latent_dim, 133 | input_dim, 134 | lstm_output_size=20, 135 | use_delta_t=True, 136 | device=torch.device("cpu"), 137 | ): 138 | 139 | super(Encoder_z0_RNN, self).__init__() 140 | 141 | self.gru_rnn_output_size = lstm_output_size 142 | self.latent_dim = latent_dim 143 | self.input_dim = input_dim 144 | self.device = device 145 | self.use_delta_t = use_delta_t 146 | 147 | self.hiddens_to_z0 = nn.Sequential( 148 | nn.Linear(self.gru_rnn_output_size, 50), 149 | nn.Tanh(), 150 | nn.Linear(50, latent_dim * 2), 151 | ) 152 | 153 | utils.init_network_weights(self.hiddens_to_z0) 154 | 155 | input_dim = self.input_dim 156 | 157 | if use_delta_t: 158 | self.input_dim += 1 159 | self.gru_rnn = GRU(self.input_dim, self.gru_rnn_output_size).to(device) 160 | 161 | def forward(self, data, time_steps, run_backwards=True): 162 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 163 | 164 | # data shape: [n_traj, n_tp, n_dims] 165 | # shape required for rnn: (seq_len, batch, input_size) 166 | # t0: not used here 167 | n_traj = data.size(0) 168 | 169 | assert not torch.isnan(data).any() 170 | assert not torch.isnan(time_steps).any() 171 | 172 | data = data.permute(1, 0, 2) 173 | 174 | if run_backwards: 175 | # Look at data in the reverse order: from later points to the first 176 | data = utils.reverse(data) 177 | 178 | if self.use_delta_t: 179 | delta_t = time_steps[1:] - time_steps[:-1] 180 | if run_backwards: 181 | # we are going backwards in time with 182 | delta_t = utils.reverse(delta_t) 183 | # append zero delta t in the end 184 | delta_t = torch.cat((delta_t, torch.zeros(1).to(self.device))) 185 | delta_t = delta_t.unsqueeze(1).repeat((1, n_traj)).unsqueeze(-1) 186 | data = torch.cat((delta_t, data), -1) 187 | 188 | outputs, _ = self.gru_rnn(data) 189 | 190 | # LSTM output shape: (seq_len, batch, num_directions * hidden_size) 191 | last_output = outputs[-1] 192 | 193 | self.extra_info = {"rnn_outputs": outputs, "time_points": time_steps} 194 | 195 | mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output)) 196 | std = std.abs() 197 | 198 | assert not torch.isnan(mean).any() 199 | assert not torch.isnan(std).any() 200 | 201 | return mean.unsqueeze(0), std.unsqueeze(0) 202 | 203 | 204 | class Encoder_z0_ODE_RNN(nn.Module): 205 | # Derive z0 by running ode backwards. 206 | # For every y_i we have two versions: encoded from data and derived from ODE by running it backwards from t_i+1 to t_i 207 | # Compute a weighted sum of y_i from data and y_i from ode. 
Use weighted y_i as an initial value for ODE runing from t_i to t_i-1 208 | # Continue until we get to z0 209 | def __init__( 210 | self, 211 | latent_dim, 212 | input_dim, 213 | z0_diffeq_solver=None, 214 | z0_dim=None, 215 | GRU_update=None, 216 | n_gru_units=100, 217 | device=torch.device("cpu"), 218 | ): 219 | 220 | super(Encoder_z0_ODE_RNN, self).__init__() 221 | 222 | if z0_dim is None: 223 | self.z0_dim = latent_dim 224 | else: 225 | self.z0_dim = z0_dim 226 | 227 | if GRU_update is None: 228 | self.GRU_update = GRU_unit( 229 | latent_dim, input_dim, n_units=n_gru_units, device=device 230 | ).to(device) 231 | else: 232 | self.GRU_update = GRU_update 233 | 234 | self.z0_diffeq_solver = z0_diffeq_solver 235 | self.latent_dim = latent_dim 236 | self.input_dim = input_dim 237 | self.device = device 238 | self.extra_info = None 239 | 240 | self.transform_z0 = nn.Sequential( 241 | nn.Linear(latent_dim * 2, 100), 242 | nn.Tanh(), 243 | nn.Linear(100, self.z0_dim * 2), 244 | ) 245 | utils.init_network_weights(self.transform_z0) 246 | 247 | def forward(self, data, time_steps, run_backwards=True, save_info=False): 248 | # data, time_steps -- observations and their time stamps 249 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 250 | assert not torch.isnan(data).any() 251 | assert not torch.isnan(time_steps).any() 252 | 253 | n_traj, n_tp, n_dims = data.size() 254 | if len(time_steps) == 1: 255 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(self.device) 256 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(self.device) 257 | 258 | xi = data[:, 0, :].unsqueeze(0) 259 | 260 | last_yi, last_yi_std = self.GRU_update(prev_y, prev_std, xi) 261 | extra_info = None 262 | else: 263 | 264 | last_yi, last_yi_std, _, extra_info = self.run_odernn( 265 | data, time_steps, run_backwards=run_backwards, save_info=save_info 266 | ) 267 | 268 | means_z0 = last_yi.reshape(1, n_traj, self.latent_dim) 269 | std_z0 = last_yi_std.reshape(1, n_traj, self.latent_dim) 270 | 271 | mean_z0, std_z0 = utils.split_last_dim( 272 | self.transform_z0(torch.cat((means_z0, std_z0), -1)) 273 | ) 274 | std_z0 = std_z0.abs() 275 | if save_info: 276 | self.extra_info = extra_info 277 | 278 | return mean_z0, std_z0 279 | 280 | def run_odernn(self, data, time_steps, run_backwards=True, save_info=False): 281 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 282 | n_traj, n_tp, n_dims = data.size() 283 | extra_info = [] 284 | 285 | t0 = time_steps[-1] 286 | if run_backwards: 287 | t0 = time_steps[0] 288 | 289 | device = get_device(data) 290 | 291 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device) 292 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device) 293 | 294 | prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1] 295 | 296 | interval_length = time_steps[-1] - time_steps[0] 297 | minimum_step = interval_length / 50 298 | 299 | # print("minimum step: {}".format(minimum_step)) 300 | 301 | assert not torch.isnan(data).any() 302 | assert not torch.isnan(time_steps).any() 303 | 304 | latent_ys = [] 305 | # Run ODE backwards and combine the y(t) estimates using gating 306 | time_points_iter = range(0, len(time_steps)) 307 | if run_backwards: 308 | time_points_iter = reversed(time_points_iter) 309 | 310 | for i in time_points_iter: 311 | if (prev_t - t_i) < minimum_step: 312 | time_points = torch.stack((prev_t, t_i)) 313 | inc = self.z0_diffeq_solver.ode_func(prev_t, prev_y) * (t_i - prev_t) 314 | 315 | assert not torch.isnan(inc).any() 316 | 317 | ode_sol = 
317 |                 ode_sol = prev_y + inc
318 |                 ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)
319 | 
320 |                 assert not torch.isnan(ode_sol).any()
321 |             else:
322 |                 n_intermediate_tp = max(2, ((prev_t - t_i) / minimum_step).int())
323 | 
324 |                 time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp)
325 |                 ode_sol = self.z0_diffeq_solver(prev_y, time_points)
326 | 
327 |                 assert not torch.isnan(ode_sol).any()
328 | 
329 |             if torch.mean(torch.abs(ode_sol[:, :, 0, :] - prev_y)) >= 0.001:
330 |                 print("Error: first point of the ODE is not equal to initial value")
331 |                 print(torch.mean(torch.abs(ode_sol[:, :, 0, :] - prev_y)))
332 |                 exit()
333 |             # assert torch.mean(torch.abs(ode_sol[:, :, 0, :] - prev_y)) < 0.001
334 | 
335 |             yi_ode = ode_sol[:, :, -1, :]
336 |             xi = data[:, i, :].unsqueeze(0)
337 |             yi, yi_std = self.GRU_update(yi_ode, prev_std, xi)
338 | 
339 |             prev_y, prev_std = yi, yi_std
340 |             prev_t, t_i = time_steps[i], time_steps[i - 1]
341 | 
342 |             latent_ys.append(yi)
343 | 
344 |             if save_info:
345 |                 d = {
346 |                     "yi_ode": yi_ode.detach(),  # "yi_from_data": yi_from_data,
347 |                     "yi": yi.detach(),
348 |                     "yi_std": yi_std.detach(),
349 |                     "time_points": time_points.detach(),
350 |                     "ode_sol": ode_sol.detach(),
351 |                 }
352 |                 extra_info.append(d)
353 | 
354 |         latent_ys = torch.stack(latent_ys, 1)
355 | 
356 |         assert not torch.isnan(yi).any()
357 |         assert not torch.isnan(yi_std).any()
358 | 
359 |         return yi, yi_std, latent_ys, extra_info
360 | 
-------------------------------------------------------------------------------- /lib/layers/__init__.py: --------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | # Link: https://github.com/rtqichen/ffjord
23 | 
24 | from . 
import diffeq_layers 25 | from .cnf import * 26 | from .container import * 27 | from .coupling import * 28 | from .elemwise import * 29 | from .glow import * 30 | from .norm_flows import * 31 | from .normalization import * 32 | from .odefunc import * 33 | from .odefunc_aug import * 34 | from .squeeze import * 35 | -------------------------------------------------------------------------------- /lib/layers/cnf.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import torch 25 | import torch.nn as nn 26 | from torchdiffeq import odeint_adjoint as odeint 27 | 28 | from .wrappers.cnf_regularization import RegularizedODEfunc 29 | 30 | __all__ = ["CNF"] 31 | 32 | 33 | class CNF(nn.Module): 34 | def __init__( 35 | self, 36 | odefunc, 37 | T=1.0, 38 | train_T=False, 39 | regularization_fns=None, 40 | solver="dopri5", 41 | atol=1e-5, 42 | rtol=1e-5, 43 | ): 44 | super(CNF, self).__init__() 45 | if train_T: 46 | self.register_parameter( 47 | "sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T))) 48 | ) 49 | else: 50 | self.register_buffer("sqrt_end_time", torch.sqrt(torch.tensor(T))) 51 | 52 | nreg = 0 53 | if regularization_fns is not None: 54 | odefunc = RegularizedODEfunc(odefunc, regularization_fns) 55 | nreg = len(regularization_fns) 56 | self.odefunc = odefunc 57 | self.nreg = nreg 58 | self.regularization_states = None 59 | self.solver = solver 60 | self.atol = atol 61 | self.rtol = rtol 62 | self.test_solver = solver 63 | self.test_atol = atol 64 | self.test_rtol = rtol 65 | self.solver_options = {} 66 | 67 | def forward(self, z, logpz=None, integration_times=None, reverse=False): 68 | 69 | if logpz is None: 70 | _logpz = torch.zeros(z.shape[0], 1).to(z) 71 | else: 72 | _logpz = logpz 73 | if integration_times is None: 74 | integration_times = torch.tensor( 75 | [0.0, self.sqrt_end_time * self.sqrt_end_time] 76 | ).to(z) 77 | if reverse: 78 | integration_times = _flip(integration_times, 0) 79 | 80 | # Refresh the odefunc statistics. 81 | self.odefunc.before_odeint() 82 | 83 | # Add regularization states. 
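        # These are zero-initialized accumulators, one per regularization
        # function, that odeint integrates alongside (z, _logpz); their
        # atol/rtol entries are set to 1e20 below so they never drive
        # dopri5's adaptive step-size control.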
84 | reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg)) 85 | 86 | if self.training: 87 | state_t = odeint( 88 | self.odefunc, 89 | (z, _logpz) + reg_states, 90 | integration_times.to(z), 91 | atol=[self.atol, self.atol] + [1e20] * len(reg_states) 92 | if self.solver == "dopri5" 93 | else self.atol, 94 | rtol=[self.rtol, self.rtol] + [1e20] * len(reg_states) 95 | if self.solver == "dopri5" 96 | else self.rtol, 97 | method=self.solver, 98 | options=self.solver_options, 99 | ) 100 | else: 101 | state_t = odeint( 102 | self.odefunc, 103 | (z, _logpz), 104 | integration_times.to(z), 105 | atol=self.test_atol, 106 | rtol=self.test_rtol, 107 | method=self.test_solver, 108 | ) 109 | 110 | if len(integration_times) == 2: 111 | state_t = tuple(s[1] for s in state_t) 112 | 113 | z_t, logpz_t = state_t[:2] 114 | self.regularization_states = state_t[2:] 115 | 116 | if logpz is not None: 117 | return z_t, logpz_t 118 | else: 119 | return z_t 120 | 121 | def get_regularization_states(self): 122 | reg_states = self.regularization_states 123 | self.regularization_states = None 124 | return reg_states 125 | 126 | def num_evals(self): 127 | return self.odefunc._num_evals.item() 128 | 129 | 130 | def _flip(x, dim): 131 | indices = [slice(None)] * x.dim() 132 | indices[dim] = torch.arange( 133 | x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device 134 | ) 135 | return x[tuple(indices)] 136 | -------------------------------------------------------------------------------- /lib/layers/container.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
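# NOTE: SequentialFlow below runs its chain front-to-back on the forward
# pass and back-to-front when reverse=True, which is what makes the
# composition of invertible layers itself invertible.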
22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import torch.nn as nn 25 | 26 | 27 | class SequentialFlow(nn.Module): 28 | """A generalized nn.Sequential container for normalizing flows.""" 29 | 30 | def __init__(self, layersList): 31 | super(SequentialFlow, self).__init__() 32 | self.chain = nn.ModuleList(layersList) 33 | 34 | def forward(self, x, logpx=None, reverse=False, inds=None): 35 | if inds is None: 36 | if reverse: 37 | inds = range(len(self.chain) - 1, -1, -1) 38 | else: 39 | inds = range(len(self.chain)) 40 | 41 | if logpx is None: 42 | for i in inds: 43 | x = self.chain[i](x, reverse=reverse) 44 | return x 45 | else: 46 | for i in inds: 47 | x, logpx = self.chain[i](x, logpx, reverse=reverse) 48 | return x, logpx 49 | -------------------------------------------------------------------------------- /lib/layers/coupling.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
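# NOTE: CouplingLayer below implements a Real NVP-style affine coupling:
# x is split into (x1, x2), y2 = x2 * scale(x1) + shift(x1) while x1
# passes through unchanged, so the log-determinant of the Jacobian is
# simply sum(log scale(x1)).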
22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import torch 25 | import torch.nn as nn 26 | 27 | __all__ = ["CouplingLayer", "MaskedCouplingLayer"] 28 | 29 | 30 | class CouplingLayer(nn.Module): 31 | """Used in 2D experiments.""" 32 | 33 | def __init__(self, d, intermediate_dim=64, swap=False): 34 | nn.Module.__init__(self) 35 | self.d = d - (d // 2) 36 | self.swap = swap 37 | self.net_s_t = nn.Sequential( 38 | nn.Linear(self.d, intermediate_dim), 39 | nn.ReLU(inplace=True), 40 | nn.Linear(intermediate_dim, intermediate_dim), 41 | nn.ReLU(inplace=True), 42 | nn.Linear(intermediate_dim, (d - self.d) * 2), 43 | ) 44 | 45 | def forward(self, x, logpx=None, reverse=False): 46 | 47 | if self.swap: 48 | x = torch.cat([x[:, self.d:], x[:, : self.d]], 1) 49 | 50 | in_dim = self.d 51 | out_dim = x.shape[1] - self.d 52 | 53 | s_t = self.net_s_t(x[:, :in_dim]) 54 | scale = torch.sigmoid(s_t[:, :out_dim] + 2.0) 55 | shift = s_t[:, out_dim:] 56 | 57 | logdetjac = torch.sum( 58 | torch.log(scale).view(scale.shape[0], -1), 1, keepdim=True 59 | ) 60 | 61 | if not reverse: 62 | y1 = x[:, self.d:] * scale + shift 63 | delta_logp = -logdetjac 64 | else: 65 | y1 = (x[:, self.d:] - shift) / scale 66 | delta_logp = logdetjac 67 | 68 | y = ( 69 | torch.cat([x[:, : self.d], y1], 1) 70 | if not self.swap 71 | else torch.cat([y1, x[:, : self.d]], 1) 72 | ) 73 | 74 | if logpx is None: 75 | return y 76 | else: 77 | return y, logpx + delta_logp 78 | 79 | 80 | class MaskedCouplingLayer(nn.Module): 81 | """Used in the tabular experiments.""" 82 | 83 | def __init__(self, d, hidden_dims, mask_type="alternate", swap=False): 84 | nn.Module.__init__(self) 85 | self.d = d 86 | self.register_buffer("mask", sample_mask(d, mask_type, swap).view(1, d)) 87 | self.net_scale = build_net(d, hidden_dims, activation="tanh") 88 | self.net_shift = build_net(d, hidden_dims, activation="relu") 89 | 90 | def forward(self, x, logpx=None, reverse=False): 91 | 92 | scale = torch.exp(self.net_scale(x * self.mask)) 93 | shift = self.net_shift(x * self.mask) 94 | 95 | masked_scale = scale * (1 - self.mask) + torch.ones_like(scale) * self.mask 96 | masked_shift = shift * (1 - self.mask) 97 | 98 | logdetjac = torch.sum( 99 | torch.log(masked_scale).view(scale.shape[0], -1), 1, keepdim=True 100 | ) 101 | 102 | if not reverse: 103 | y = x * masked_scale + masked_shift 104 | delta_logp = -logdetjac 105 | else: 106 | y = (x - masked_shift) / masked_scale 107 | delta_logp = logdetjac 108 | 109 | if logpx is None: 110 | return y 111 | else: 112 | return y, logpx + delta_logp 113 | 114 | 115 | def sample_mask(dim, mask_type, swap): 116 | if mask_type == "alternate": 117 | # Index-based masking in MAF paper. 118 | mask = torch.zeros(dim) 119 | mask[::2] = 1 120 | if swap: 121 | mask = 1 - mask 122 | return mask 123 | elif mask_type == "channel": 124 | # Masking type used in Real NVP paper. 
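        # The first half of the dimensions is kept (mask = 1) and the
        # second half is transformed (mask = 0); `swap` exchanges the
        # two roles.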
125 | mask = torch.zeros(dim) 126 | mask[: dim // 2] = 1 127 | if swap: 128 | mask = 1 - mask 129 | return mask 130 | else: 131 | raise ValueError("Unknown mask_type {}".format(mask_type)) 132 | 133 | 134 | def build_net(input_dim, hidden_dims, activation="relu"): 135 | dims = (input_dim,) + tuple(hidden_dims) + (input_dim,) 136 | activation_modules = {"relu": nn.ReLU(inplace=True), "tanh": nn.Tanh()} 137 | 138 | chain = [] 139 | for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): 140 | chain.append(nn.Linear(in_dim, out_dim)) 141 | if i < len(hidden_dims): 142 | chain.append(activation_modules[activation]) 143 | return nn.Sequential(*chain) 144 | -------------------------------------------------------------------------------- /lib/layers/diffeq_layers/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | from .basic import * 25 | from .container import * 26 | from .resnet import * 27 | from .wrappers import * 28 | -------------------------------------------------------------------------------- /lib/layers/diffeq_layers/basic.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | # Link: https://github.com/rtqichen/ffjord
23 | 
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 | 
28 | 
29 | def weights_init(m):
30 |     classname = m.__class__.__name__
31 |     if classname.find("Linear") != -1 or classname.find("Conv") != -1:
32 |         nn.init.constant_(m.weight, 0)
33 |         nn.init.normal_(m.bias, 0, 0.01)
34 | 
35 | 
36 | class HyperLinear(nn.Module):
37 |     def __init__(self, dim_in, dim_out, hypernet_dim=8, n_hidden=1, activation=nn.Tanh):
38 |         super(HyperLinear, self).__init__()
39 |         self.dim_in = dim_in
40 |         self.dim_out = dim_out
41 |         self.params_dim = self.dim_in * self.dim_out + self.dim_out
42 | 
43 |         layers = []
44 |         dims = [1] + [hypernet_dim] * n_hidden + [self.params_dim]
45 |         for i in range(1, len(dims)):
46 |             layers.append(nn.Linear(dims[i - 1], dims[i]))
47 |             if i < len(dims) - 1:
48 |                 layers.append(activation())
49 |         self._hypernet = nn.Sequential(*layers)
50 |         self._hypernet.apply(weights_init)
51 | 
52 |     def forward(self, t, x):
53 |         params = self._hypernet(t.view(1, 1)).view(-1)
54 |         b = params[: self.dim_out].view(self.dim_out)
55 |         w = params[self.dim_out:].view(self.dim_out, self.dim_in)
56 |         return F.linear(x, w, b)
57 | 
58 | 
59 | class InverseHyperLinear(nn.Module):
60 |     def __init__(
61 |         self,
62 |         dim_in,
63 |         dim_out,
64 |         hypernet_dim=8,
65 |         n_hidden=1,
66 |         activation=nn.Tanh,
67 |         small_number=1e-6,
68 |     ):
69 |         super(InverseHyperLinear, self).__init__()
70 |         self.dim_in = dim_in
71 |         self.dim_out = dim_out
72 |         self.params_dim = self.dim_in * self.dim_out + self.dim_out
73 |         self.small_number = small_number
74 |         layers = []
75 |         dims = [1] + [hypernet_dim] * n_hidden + [self.params_dim]
76 |         for i in range(1, len(dims)):
77 |             layers.append(nn.Linear(dims[i - 1], dims[i]))
78 |             if i < len(dims) - 1:
79 |                 layers.append(activation())
80 |         self._hypernet = nn.Sequential(*layers)
81 |         self._hypernet.apply(weights_init)
82 | 
83 |     def forward(self, t, x):
84 |         t[t < self.small_number] = self.small_number
85 |         params = self._hypernet(1.0 / t.view(1, 1)).view(-1)
86 |         b = params[: self.dim_out].view(self.dim_out)
87 |         w = params[self.dim_out:].view(self.dim_out, self.dim_in)
88 |         return F.linear(x, w, b)
89 | 
90 | 
91 | class IgnoreLinear(nn.Module):
92 |     def __init__(self, dim_in, dim_out):
93 |         super(IgnoreLinear, self).__init__()
94 |         self._layer = nn.Linear(dim_in, dim_out)
95 | 
96 |     def forward(self, t, x):
97 |         return self._layer(x)
98 | 
99 | 
100 | class ConcatLinear(nn.Module):
101 |     def __init__(self, dim_in, dim_out):
102 |         super(ConcatLinear, self).__init__()
103 |         self._layer = nn.Linear(dim_in + 1, dim_out)
104 | 
105 |     def forward(self, t, x):
106 |         tt = torch.ones_like(x[:, :1]) * t
107 |         ttx = torch.cat([tt, x], 1)
108 |         return self._layer(ttx)
109 | 
110 | 
111 | class ConcatLinear_v2(nn.Module):
112 |     def __init__(self, dim_in, dim_out):
113 |         super(ConcatLinear_v2, self).__init__()
114 |         self._layer = nn.Linear(dim_in, dim_out)
115 |         self._hyper_bias = nn.Linear(1, dim_out, bias=False)
116 | 
117 |     def forward(self, t, x):
118 |         return self._layer(x) + self._hyper_bias(t.view(1, 1))
119 | 
120 | 
121 | class SquashLinear(nn.Module):
122 |     def __init__(self, dim_in, dim_out):
123 |         super(SquashLinear, self).__init__()
124 |         self._layer = 
nn.Linear(dim_in, dim_out) 125 | self._hyper = nn.Linear(1, dim_out) 126 | 127 | def forward(self, t, x): 128 | return self._layer(x) * torch.sigmoid(self._hyper(t.view(1, 1))) 129 | 130 | 131 | class ConcatSquashLinear(nn.Module): 132 | def __init__(self, dim_in, dim_out): 133 | super(ConcatSquashLinear, self).__init__() 134 | self._layer = nn.Linear(dim_in, dim_out) 135 | self._hyper_bias = nn.Linear(1, dim_out, bias=False) 136 | self._hyper_gate = nn.Linear(1, dim_out) 137 | 138 | def forward(self, t, x): 139 | return self._layer(x) * torch.sigmoid( 140 | self._hyper_gate(t.view(1, 1)) 141 | ) + self._hyper_bias(t.view(1, 1)) 142 | 143 | 144 | class HyperConv2d(nn.Module): 145 | def __init__( 146 | self, 147 | dim_in, 148 | dim_out, 149 | ksize=3, 150 | stride=1, 151 | padding=0, 152 | dilation=1, 153 | groups=1, 154 | bias=True, 155 | transpose=False, 156 | ): 157 | super(HyperConv2d, self).__init__() 158 | assert ( 159 | dim_in % groups == 0 and dim_out % groups == 0 160 | ), "dim_in and dim_out must both be divisible by groups." 161 | self.dim_in = dim_in 162 | self.dim_out = dim_out 163 | self.ksize = ksize 164 | self.stride = stride 165 | self.padding = padding 166 | self.dilation = dilation 167 | self.groups = groups 168 | self.bias = bias 169 | self.transpose = transpose 170 | 171 | self.params_dim = int(dim_in * dim_out * ksize * ksize / groups) 172 | if self.bias: 173 | self.params_dim += dim_out 174 | self._hypernet = nn.Linear(1, self.params_dim) 175 | self.conv_fn = F.conv_transpose2d if transpose else F.conv2d 176 | 177 | self._hypernet.apply(weights_init) 178 | 179 | def forward(self, t, x): 180 | params = self._hypernet(t.view(1, 1)).view(-1) 181 | weight_size = int( 182 | self.dim_in * self.dim_out * self.ksize * self.ksize / self.groups 183 | ) 184 | if self.transpose: 185 | weight = params[:weight_size].view( 186 | self.dim_in, self.dim_out // self.groups, self.ksize, self.ksize 187 | ) 188 | else: 189 | weight = params[:weight_size].view( 190 | self.dim_out, self.dim_in // self.groups, self.ksize, self.ksize 191 | ) 192 | bias = params[: self.dim_out].view(self.dim_out) if self.bias else None 193 | return self.conv_fn( 194 | x, 195 | weight=weight, 196 | bias=bias, 197 | stride=self.stride, 198 | padding=self.padding, 199 | groups=self.groups, 200 | dilation=self.dilation, 201 | ) 202 | 203 | 204 | class IgnoreConv2d(nn.Module): 205 | def __init__( 206 | self, 207 | dim_in, 208 | dim_out, 209 | ksize=3, 210 | stride=1, 211 | padding=0, 212 | dilation=1, 213 | groups=1, 214 | bias=True, 215 | transpose=False, 216 | ): 217 | super(IgnoreConv2d, self).__init__() 218 | module = nn.ConvTranspose2d if transpose else nn.Conv2d 219 | self._layer = module( 220 | dim_in, 221 | dim_out, 222 | kernel_size=ksize, 223 | stride=stride, 224 | padding=padding, 225 | dilation=dilation, 226 | groups=groups, 227 | bias=bias, 228 | ) 229 | 230 | def forward(self, t, x): 231 | return self._layer(x) 232 | 233 | 234 | class SquashConv2d(nn.Module): 235 | def __init__( 236 | self, 237 | dim_in, 238 | dim_out, 239 | ksize=3, 240 | stride=1, 241 | padding=0, 242 | dilation=1, 243 | groups=1, 244 | bias=True, 245 | transpose=False, 246 | ): 247 | super(SquashConv2d, self).__init__() 248 | module = nn.ConvTranspose2d if transpose else nn.Conv2d 249 | self._layer = module( 250 | dim_in + 1, 251 | dim_out, 252 | kernel_size=ksize, 253 | stride=stride, 254 | padding=padding, 255 | dilation=dilation, 256 | groups=groups, 257 | bias=bias, 258 | ) 259 | self._hyper = nn.Linear(1, dim_out) 260 | 261 | 
def forward(self, t, x):
262 |         return self._layer(x) * torch.sigmoid(self._hyper(t.view(1, 1))).view(
263 |             1, -1, 1, 1
264 |         )
265 | 
266 | 
267 | class ConcatConv2d(nn.Module):
268 |     def __init__(
269 |         self,
270 |         dim_in,
271 |         dim_out,
272 |         ksize=3,
273 |         stride=1,
274 |         padding=0,
275 |         dilation=1,
276 |         groups=1,
277 |         bias=True,
278 |         transpose=False,
279 |     ):
280 |         super(ConcatConv2d, self).__init__()
281 |         module = nn.ConvTranspose2d if transpose else nn.Conv2d
282 |         self._layer = module(
283 |             dim_in + 1,
284 |             dim_out,
285 |             kernel_size=ksize,
286 |             stride=stride,
287 |             padding=padding,
288 |             dilation=dilation,
289 |             groups=groups,
290 |             bias=bias,
291 |         )
292 | 
293 |     def forward(self, t, x):
294 |         tt = torch.ones_like(x[:, :1, :, :]) * t
295 |         ttx = torch.cat([tt, x], 1)
296 |         return self._layer(ttx)
297 | 
298 | 
299 | class ConcatConv2d_v2(nn.Module):
300 |     def __init__(
301 |         self,
302 |         dim_in,
303 |         dim_out,
304 |         ksize=3,
305 |         stride=1,
306 |         padding=0,
307 |         dilation=1,
308 |         groups=1,
309 |         bias=True,
310 |         transpose=False,
311 |     ):
312 |         super(ConcatConv2d_v2, self).__init__()
313 |         module = nn.ConvTranspose2d if transpose else nn.Conv2d
314 |         self._layer = module(
315 |             dim_in,
316 |             dim_out,
317 |             kernel_size=ksize,
318 |             stride=stride,
319 |             padding=padding,
320 |             dilation=dilation,
321 |             groups=groups,
322 |             bias=bias,
323 |         )
324 |         self._hyper_bias = nn.Linear(1, dim_out, bias=False)
325 | 
326 |     def forward(self, t, x):
327 |         return self._layer(x) + self._hyper_bias(t.view(1, 1)).view(1, -1, 1, 1)
328 | 
329 | 
330 | class ConcatSquashConv2d(nn.Module):
331 |     def __init__(
332 |         self,
333 |         dim_in,
334 |         dim_out,
335 |         ksize=3,
336 |         stride=1,
337 |         padding=0,
338 |         dilation=1,
339 |         groups=1,
340 |         bias=True,
341 |         transpose=False,
342 |     ):
343 |         super(ConcatSquashConv2d, self).__init__()
344 |         module = nn.ConvTranspose2d if transpose else nn.Conv2d
345 |         self._layer = module(
346 |             dim_in,
347 |             dim_out,
348 |             kernel_size=ksize,
349 |             stride=stride,
350 |             padding=padding,
351 |             dilation=dilation,
352 |             groups=groups,
353 |             bias=bias,
354 |         )
355 |         self._hyper_gate = nn.Linear(1, dim_out)
356 |         self._hyper_bias = nn.Linear(1, dim_out, bias=False)
357 | 
358 |     def forward(self, t, x):
359 |         return self._layer(x) * torch.sigmoid(self._hyper_gate(t.view(1, 1))).view(
360 |             1, -1, 1, 1
361 |         ) + self._hyper_bias(t.view(1, 1)).view(1, -1, 1, 1)
362 | 
363 | 
364 | class ConcatCoordConv2d(nn.Module):
365 |     def __init__(
366 |         self,
367 |         dim_in,
368 |         dim_out,
369 |         ksize=3,
370 |         stride=1,
371 |         padding=0,
372 |         dilation=1,
373 |         groups=1,
374 |         bias=True,
375 |         transpose=False,
376 |     ):
377 |         super(ConcatCoordConv2d, self).__init__()
378 |         module = nn.ConvTranspose2d if transpose else nn.Conv2d
379 |         self._layer = module(
380 |             dim_in + 3,
381 |             dim_out,
382 |             kernel_size=ksize,
383 |             stride=stride,
384 |             padding=padding,
385 |             dilation=dilation,
386 |             groups=groups,
387 |             bias=bias,
388 |         )
389 | 
390 |     def forward(self, t, x):
391 |         b, c, h, w = x.shape
392 |         hh = torch.arange(h).to(x).view(1, 1, h, 1).expand(b, 1, h, w)
393 |         ww = torch.arange(w).to(x).view(1, 1, 1, w).expand(b, 1, h, w)
394 |         tt = t.to(x).view(1, 1, 1, 1).expand(b, 1, h, w)
395 |         x_aug = torch.cat([x, tt, hh, ww], 1)
396 |         return self._layer(x_aug)
397 | 
398 | 
399 | class GatedLinear(nn.Module):
400 |     def __init__(self, in_features, out_features):
401 |         super(GatedLinear, self).__init__()
402 |         self.layer_f = nn.Linear(in_features, out_features)
403 |         self.layer_g = nn.Linear(in_features, out_features)
404 
| 405 | def forward(self, x): 406 | f = self.layer_f(x) 407 | g = torch.sigmoid(self.layer_g(x)) 408 | return f * g 409 | 410 | 411 | class GatedConv(nn.Module): 412 | def __init__( 413 | self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1 414 | ): 415 | super(GatedConv, self).__init__() 416 | self.layer_f = nn.Conv2d( 417 | in_channels, 418 | out_channels, 419 | kernel_size, 420 | stride=stride, 421 | padding=padding, 422 | dilation=1, 423 | groups=groups, 424 | ) 425 | self.layer_g = nn.Conv2d( 426 | in_channels, 427 | out_channels, 428 | kernel_size, 429 | stride=stride, 430 | padding=padding, 431 | dilation=1, 432 | groups=groups, 433 | ) 434 | 435 | def forward(self, x): 436 | f = self.layer_f(x) 437 | g = torch.sigmoid(self.layer_g(x)) 438 | return f * g 439 | 440 | 441 | class GatedConvTranspose(nn.Module): 442 | def __init__( 443 | self, 444 | in_channels, 445 | out_channels, 446 | kernel_size, 447 | stride=1, 448 | padding=0, 449 | output_padding=0, 450 | groups=1, 451 | ): 452 | super(GatedConvTranspose, self).__init__() 453 | self.layer_f = nn.ConvTranspose2d( 454 | in_channels, 455 | out_channels, 456 | kernel_size, 457 | stride=stride, 458 | padding=padding, 459 | output_padding=output_padding, 460 | groups=groups, 461 | ) 462 | self.layer_g = nn.ConvTranspose2d( 463 | in_channels, 464 | out_channels, 465 | kernel_size, 466 | stride=stride, 467 | padding=padding, 468 | output_padding=output_padding, 469 | groups=groups, 470 | ) 471 | 472 | def forward(self, x): 473 | f = self.layer_f(x) 474 | g = torch.sigmoid(self.layer_g(x)) 475 | return f * g 476 | 477 | 478 | class BlendLinear(nn.Module): 479 | def __init__(self, dim_in, dim_out, layer_type=nn.Linear, **unused_kwargs): 480 | super(BlendLinear, self).__init__() 481 | self._layer0 = layer_type(dim_in, dim_out) 482 | self._layer1 = layer_type(dim_in, dim_out) 483 | 484 | def forward(self, t, x): 485 | y0 = self._layer0(x) 486 | y1 = self._layer1(x) 487 | return y0 + (y1 - y0) * t 488 | 489 | 490 | class BlendConv2d(nn.Module): 491 | def __init__( 492 | self, 493 | dim_in, 494 | dim_out, 495 | ksize=3, 496 | stride=1, 497 | padding=0, 498 | dilation=1, 499 | groups=1, 500 | bias=True, 501 | transpose=False, 502 | **unused_kwargs 503 | ): 504 | super(BlendConv2d, self).__init__() 505 | module = nn.ConvTranspose2d if transpose else nn.Conv2d 506 | self._layer0 = module( 507 | dim_in, 508 | dim_out, 509 | kernel_size=ksize, 510 | stride=stride, 511 | padding=padding, 512 | dilation=dilation, 513 | groups=groups, 514 | bias=bias, 515 | ) 516 | self._layer1 = module( 517 | dim_in, 518 | dim_out, 519 | kernel_size=ksize, 520 | stride=stride, 521 | padding=padding, 522 | dilation=dilation, 523 | groups=groups, 524 | bias=bias, 525 | ) 526 | 527 | def forward(self, t, x): 528 | y0 = self._layer0(x) 529 | y1 = self._layer1(x) 530 | return y0 + (y1 - y0) * t 531 | -------------------------------------------------------------------------------- /lib/layers/diffeq_layers/container.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to 
permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import torch 25 | import torch.nn as nn 26 | 27 | from .wrappers import diffeq_wrapper 28 | 29 | 30 | class SequentialDiffEq(nn.Module): 31 | """A container for a sequential chain of layers. Supports both regular and diffeq layers.""" 32 | 33 | def __init__(self, *layers): 34 | super(SequentialDiffEq, self).__init__() 35 | self.layers = nn.ModuleList([diffeq_wrapper(layer) for layer in layers]) 36 | 37 | def forward(self, t, x): 38 | for layer in self.layers: 39 | x = layer(t, x) 40 | return x 41 | 42 | 43 | class MixtureODELayer(nn.Module): 44 | """Produces a mixture of experts where output = sigma(t) * f(t, x). 45 | Time-dependent weights sigma(t) help learn to blend the experts without resorting to a highly stiff f. 46 | Supports both regular and diffeq experts. 47 | """ 48 | 49 | def __init__(self, experts): 50 | super(MixtureODELayer, self).__init__() 51 | assert len(experts) > 1 52 | wrapped_experts = [diffeq_wrapper(ex) for ex in experts] 53 | self.experts = nn.ModuleList(wrapped_experts) 54 | self.mixture_weights = nn.Linear(1, len(self.experts)) 55 | 56 | def forward(self, t, y): 57 | dys = [] 58 | for f in self.experts: 59 | dys.append(f(t, y)) 60 | dys = torch.stack(dys, 0) 61 | weights = self.mixture_weights(t).view(-1, *([1] * (dys.ndimension() - 1))) 62 | 63 | dy = torch.sum(dys * weights, dim=0, keepdim=False) 64 | return dy 65 | -------------------------------------------------------------------------------- /lib/layers/diffeq_layers/resnet.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import torch.nn as nn 25 | 26 | from . import basic 27 | from . import container 28 | 29 | NGROUPS = 16 30 | 31 | 32 | class ResNet(container.SequentialDiffEq): 33 | def __init__(self, dim, intermediate_dim, n_resblocks, conv_block=None): 34 | super(ResNet, self).__init__() 35 | 36 | if conv_block is None: 37 | conv_block = basic.ConcatCoordConv2d 38 | 39 | self.dim = dim 40 | self.intermediate_dim = intermediate_dim 41 | self.n_resblocks = n_resblocks 42 | 43 | layers = [] 44 | layers.append( 45 | conv_block(dim, intermediate_dim, ksize=3, stride=1, padding=1, bias=False) 46 | ) 47 | for _ in range(n_resblocks): 48 | layers.append(BasicBlock(intermediate_dim, conv_block)) 49 | layers.append(nn.GroupNorm(NGROUPS, intermediate_dim, eps=1e-4)) 50 | layers.append(nn.ReLU(inplace=True)) 51 | layers.append(conv_block(intermediate_dim, dim, ksize=1, bias=False)) 52 | 53 | super(ResNet, self).__init__(*layers) 54 | 55 | def __repr__(self): 56 | return "{name}({dim}, intermediate_dim={intermediate_dim}, n_resblocks={n_resblocks})".format( 57 | name=self.__class__.__name__, **self.__dict__ 58 | ) 59 | 60 | 61 | class BasicBlock(nn.Module): 62 | expansion = 1 63 | 64 | def __init__(self, dim, conv_block=None): 65 | super(BasicBlock, self).__init__() 66 | 67 | if conv_block is None: 68 | conv_block = basic.ConcatCoordConv2d 69 | 70 | self.norm1 = nn.GroupNorm(NGROUPS, dim, eps=1e-4) 71 | self.relu1 = nn.ReLU(inplace=True) 72 | self.conv1 = conv_block(dim, dim, ksize=3, stride=1, padding=1, bias=False) 73 | self.norm2 = nn.GroupNorm(NGROUPS, dim, eps=1e-4) 74 | self.relu2 = nn.ReLU(inplace=True) 75 | self.conv2 = conv_block(dim, dim, ksize=3, stride=1, padding=1, bias=False) 76 | 77 | def forward(self, t, x): 78 | residual = x 79 | 80 | out = self.norm1(x) 81 | out = self.relu1(out) 82 | out = self.conv1(t, out) 83 | 84 | out = self.norm2(out) 85 | out = self.relu2(out) 86 | out = self.conv2(t, out) 87 | 88 | out += residual 89 | 90 | return out 91 | -------------------------------------------------------------------------------- /lib/layers/diffeq_layers/wrappers.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | # Link: https://github.com/rtqichen/ffjord
23 | 
24 | from inspect import signature
25 | 
26 | import torch.nn as nn
27 | 
28 | __all__ = ["diffeq_wrapper", "reshape_wrapper"]
29 | 
30 | 
31 | class DiffEqWrapper(nn.Module):
32 |     def __init__(self, module):
33 |         super(DiffEqWrapper, self).__init__()
34 |         self.module = module
35 |         if len(signature(self.module.forward).parameters) == 1:
36 |             self.diffeq = lambda t, y: self.module(y)
37 |         elif len(signature(self.module.forward).parameters) == 2:
38 |             self.diffeq = self.module
39 |         else:
40 |             raise ValueError(
41 |                 "Differential equation needs to either take (t, y) or (y,) as input."
42 |             )
43 | 
44 |     def forward(self, t, y):
45 |         return self.diffeq(t, y)
46 | 
47 |     def __repr__(self):
48 |         return self.diffeq.__repr__()
49 | 
50 | 
51 | def diffeq_wrapper(layer):
52 |     return DiffEqWrapper(layer)
53 | 
54 | 
55 | class ReshapeDiffEq(nn.Module):
56 |     def __init__(self, input_shape, net):
57 |         super(ReshapeDiffEq, self).__init__()
58 |         assert (
59 |             len(signature(net.forward).parameters) == 2
60 |         ), "use diffeq_wrapper before reshape_wrapper."
61 |         self.input_shape = input_shape
62 |         self.net = net
63 | 
64 |     def forward(self, t, x):
65 |         batchsize = x.shape[0]
66 |         x = x.view(batchsize, *self.input_shape)
67 |         return self.net(t, x).view(batchsize, -1)
68 | 
69 |     def __repr__(self):
70 |         return self.net.__repr__()
71 | 
72 | 
73 | def reshape_wrapper(input_shape, layer):
74 |     return ReshapeDiffEq(input_shape, layer)
75 | 
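A minimal usage sketch (added for illustration; not part of the upstream file) showing how diffeq_wrapper adapts a time-independent module to the (t, y) interface the ODE solvers expect:

if __name__ == "__main__":
    import torch

    plain_net = nn.Linear(3, 3)    # forward takes only (y,)
    f = diffeq_wrapper(plain_net)  # now callable as f(t, y); t is ignored
    dy = f(torch.tensor(0.5), torch.randn(7, 3))
    print(dy.shape)                # torch.Size([7, 3])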
-------------------------------------------------------------------------------- /lib/layers/elemwise.py: --------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | # Link: https://github.com/rtqichen/ffjord
23 | 
24 | import math
25 | 
26 | import torch
27 | import torch.nn as nn
28 | 
29 | _DEFAULT_ALPHA = 1e-6
30 | 
31 | 
32 | class ZeroMeanTransform(nn.Module):
33 |     def __init__(self):
34 |         nn.Module.__init__(self)
35 | 
36 |     def forward(self, x, logpx=None, reverse=False):
37 |         if reverse:
38 |             x = x + 0.5
39 |             if logpx is None:
40 |                 return x
41 |             return x, logpx
42 |         else:
43 |             x = x - 0.5
44 |             if logpx is None:
45 |                 return x
46 |             return x, logpx
47 | 
48 | 
49 | class LogitTransform(nn.Module):
50 |     """
51 |     The preprocessing step used in Real NVP:
52 |     y = logit(a + (1 - 2a) * x)
53 |     x = (sigmoid(y) - a) / (1 - 2a)
54 |     """
55 | 
56 |     def __init__(self, alpha=_DEFAULT_ALPHA, effective_shape=None):
57 |         nn.Module.__init__(self)
58 |         self.alpha = alpha
59 |         self.effective_shape = effective_shape
60 | 
61 |     def forward(self, x, logpx=None, reverse=False):
62 |         if reverse:
63 |             return _sigmoid(x, logpx, self.alpha, self.effective_shape)
64 |         else:
65 |             return _logit(x, logpx, self.alpha, self.effective_shape)
66 | 
67 | 
68 | class SigmoidTransform(nn.Module):
69 |     """Reverse of LogitTransform."""
70 | 
71 |     def __init__(self, alpha=_DEFAULT_ALPHA, effective_shape=None):
72 |         nn.Module.__init__(self)
73 |         self.alpha = alpha
74 |         self.effective_shape = effective_shape
75 | 
76 |     def forward(self, x, logpx=None, reverse=False):
77 |         if reverse:
78 |             return _logit(x, logpx, self.alpha, self.effective_shape)
79 |         else:
80 |             return _sigmoid(x, logpx, self.alpha, self.effective_shape)
81 | 
82 | 
83 | def _logit(x, logpx=None, alpha=_DEFAULT_ALPHA, effective_shape=None):
84 |     s = alpha + (1 - 2 * alpha) * x
85 |     y = torch.log(s) - torch.log(1 - s)
86 |     if logpx is None:
87 |         return y
88 | 
89 |     if effective_shape is None:
90 |         return y, logpx - _logdetgrad(x, alpha).view(x.size(0), -1).sum(1, keepdim=True)
91 |     return (
92 |         y,
93 |         logpx
94 |         - _logdetgrad(x.view(x.size(0), -1)[:, :effective_shape], alpha)
95 |         .view(x.size(0), -1)
96 |         .sum(1, keepdim=True),
97 |     )
98 | 
99 | 
100 | def _sigmoid(y, logpy=None, alpha=_DEFAULT_ALPHA, effective_shape=None):
101 |     x = (torch.sigmoid(y) - alpha) / (1 - 2 * alpha)
102 |     if logpy is None:
103 |         return x
104 |     if effective_shape is None:
105 |         return x, logpy + _logdetgrad(x, alpha).view(x.size(0), -1).sum(1, keepdim=True)
106 |     return (
107 |         x,
108 |         logpy
109 |         + _logdetgrad(x.view(x.size(0), -1)[:, :effective_shape], alpha)
110 |         .view(x.size(0), -1)
111 |         .sum(1, keepdim=True),
112 |     )
113 | 
114 | 
115 | def _logdetgrad(x, alpha):
116 |     s = alpha + (1 - 2 * alpha) * x
117 |     logdetgrad = -torch.log(s - s * s) + math.log(1 - 2 * alpha)
118 |     return logdetgrad
119 | 
-------------------------------------------------------------------------------- /lib/layers/glow.py: --------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or 
substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | 29 | class BruteForceLayer(nn.Module): 30 | def __init__(self, dim): 31 | super(BruteForceLayer, self).__init__() 32 | self.weight = nn.Parameter(torch.eye(dim)) 33 | 34 | def forward(self, x, logpx=None, reverse=False): 35 | 36 | if not reverse: 37 | y = F.linear(x, self.weight) 38 | if logpx is None: 39 | return y 40 | else: 41 | return y, logpx - self._logdetgrad.expand_as(logpx) 42 | 43 | else: 44 | y = F.linear(x, self.weight.double().inverse().float()) 45 | if logpx is None: 46 | return y 47 | else: 48 | return y, logpx + self._logdetgrad.expand_as(logpx) 49 | 50 | @property 51 | def _logdetgrad(self): 52 | return torch.log(torch.abs(torch.det(self.weight.double()))).float() 53 | -------------------------------------------------------------------------------- /lib/layers/norm_flows.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | # Link: https://github.com/rtqichen/ffjord
23 | 
24 | import math
25 | 
26 | import torch
27 | import torch.nn as nn
28 | from torch.autograd import grad
29 | 
30 | 
31 | class PlanarFlow(nn.Module):
32 |     def __init__(self, nd=1):
33 |         super(PlanarFlow, self).__init__()
34 |         self.nd = nd
35 |         self.activation = torch.tanh
36 | 
37 |         self.register_parameter("u", nn.Parameter(torch.randn(self.nd)))
38 |         self.register_parameter("w", nn.Parameter(torch.randn(self.nd)))
39 |         self.register_parameter("b", nn.Parameter(torch.randn(1)))
40 |         self.reset_parameters()
41 | 
42 |     def reset_parameters(self):
43 |         stdv = 1.0 / math.sqrt(self.nd)
44 |         self.u.data.uniform_(-stdv, stdv)
45 |         self.w.data.uniform_(-stdv, stdv)
46 |         self.b.data.fill_(0)
47 |         self.make_invertible()
48 | 
49 |     def make_invertible(self):  # nudge u along w toward w^T u >= -1, the
50 |         u = self.u.data         # invertibility condition for tanh planar flows
51 |         w = self.w.data
52 |         dot = torch.dot(u, w)
53 |         m = -1 + math.log(1 + math.exp(dot))  # softplus(dot) - 1 >= -1
54 |         du = (m - dot) / torch.norm(w) * w
55 |         u = u + du
56 |         self.u.data = u
57 | 
58 |     def forward(self, z, logp=None, reverse=False):
59 |         """Computes f(z) and log q(f(z))"""
60 | 
61 |         assert not reverse, "Planar normalizing flow cannot be reversed."
62 | 
63 |         # Apply the planar map exactly once: f(z) = z + u * tanh(w^T z + b).
64 |         # The log-density correction in log_density() uses the Jacobian of
65 |         # the map evaluated at the input point z.
66 |         f = self.sample(z)
67 | 
68 |         if logp is not None:
69 |             qf = self.log_density(z, logp)
70 |             return f, qf
71 |         else:
72 |             return f
73 | 
74 |     def sample(self, z):
75 |         """Computes f(z)"""
76 |         h = self.activation(torch.mm(z, self.w.view(self.nd, 1)) + self.b)
77 |         output = z + self.u.expand_as(z) * h
78 |         return output
79 | 
80 |     def _detgrad(self, z):
81 |         """Computes |det df/dz|"""
82 |         with torch.enable_grad():
83 |             z = z.requires_grad_(True)
84 |             h = self.activation(torch.mm(z, self.w.view(self.nd, 1)) + self.b)
85 |             psi = grad(
86 |                 h,
87 |                 z,
88 |                 grad_outputs=torch.ones_like(h),
89 |                 create_graph=True,
90 |                 only_inputs=True,
91 |             )[0]
92 |         u_dot_psi = torch.mm(psi, self.u.view(self.nd, 1))
93 |         detgrad = 1 + u_dot_psi
94 |         return detgrad
95 | 
96 |     def log_density(self, z, logqz):
97 |         """Computes log density of the flow given the log density of z"""
98 |         return logqz - torch.log(self._detgrad(z) + 1e-8)
99 | 
-------------------------------------------------------------------------------- /lib/layers/normalization.py: --------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import torch 25 | import torch.nn as nn 26 | from torch.nn import Parameter 27 | 28 | __all__ = ["MovingBatchNorm1d", "MovingBatchNorm2d"] 29 | 30 | 31 | class MovingBatchNormNd(nn.Module): 32 | def __init__( 33 | self, 34 | num_features, 35 | eps=1e-4, 36 | decay=0.1, 37 | bn_lag=0.0, 38 | affine=True, 39 | effective_shape=None, 40 | ): 41 | super(MovingBatchNormNd, self).__init__() 42 | self.num_features = num_features 43 | self.affine = affine 44 | self.eps = eps 45 | self.decay = decay 46 | self.bn_lag = bn_lag 47 | self.effective_shape = effective_shape 48 | self.register_buffer("step", torch.zeros(1)) 49 | if self.affine: 50 | self.weight = Parameter(torch.Tensor(num_features)) 51 | self.bias = Parameter(torch.Tensor(num_features)) 52 | else: 53 | self.register_parameter("weight", None) 54 | self.register_parameter("bias", None) 55 | self.register_buffer("running_mean", torch.zeros(num_features)) 56 | self.register_buffer("running_var", torch.ones(num_features)) 57 | self.reset_parameters() 58 | 59 | @property 60 | def shape(self): 61 | raise NotImplementedError 62 | 63 | def reset_parameters(self): 64 | self.running_mean.zero_() 65 | self.running_var.fill_(1) 66 | if self.affine: 67 | self.weight.data.zero_() 68 | self.bias.data.zero_() 69 | 70 | def forward(self, x, logpx=None, reverse=False): 71 | if reverse: 72 | return self._reverse(x, logpx) 73 | else: 74 | return self._forward(x, logpx) 75 | 76 | def _forward(self, x, logpx=None): 77 | c = x.size(1) 78 | used_mean = self.running_mean.clone().detach() 79 | used_var = self.running_var.clone().detach() 80 | 81 | if self.training: 82 | # compute batch statistics 83 | x_t = x.transpose(0, 1).contiguous().view(c, -1) 84 | batch_mean = torch.mean(x_t, dim=1) 85 | batch_var = torch.var(x_t, dim=1) 86 | 87 | # moving average 88 | if self.bn_lag > 0: 89 | used_mean = batch_mean - (1 - self.bn_lag) * ( 90 | batch_mean - used_mean.detach() 91 | ) 92 | used_mean /= 1.0 - self.bn_lag ** (self.step[0] + 1) 93 | used_var = batch_var - (1 - self.bn_lag) * ( 94 | batch_var - used_var.detach() 95 | ) 96 | used_var /= 1.0 - self.bn_lag ** (self.step[0] + 1) 97 | 98 | # update running estimates 99 | self.running_mean -= self.decay * (self.running_mean - batch_mean.data) 100 | self.running_var -= self.decay * (self.running_var - batch_var.data) 101 | self.step += 1 102 | 103 | # perform normalization 104 | used_mean = used_mean.view(*self.shape).expand_as(x) 105 | used_var = used_var.view(*self.shape).expand_as(x) 106 | 107 | y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps)) 108 | 109 | if self.affine: 110 | weight = self.weight.view(*self.shape).expand_as(x) 111 | bias = self.bias.view(*self.shape).expand_as(x) 112 | y = y * torch.exp(weight) + bias 113 | 114 | if logpx is None: 115 | return y 116 | else: 117 | if self.effective_shape is None: 118 | return y, logpx - self._logdetgrad(x, used_var).view(x.size(0), -1).sum( 119 | 1, keepdim=True 120 | ) 121 | else: 122 | return ( 123 | y, 124 | logpx 125 | - self._logdetgrad( 126 | x.view(x.size(0), -1)[:, : self.effective_shape], used_var 127 | ) 128 | .view(x.size(0), -1) 129 | .sum(1, keepdim=True), 130 | ) 131 | 132 | def _reverse(self, 
y, logpy=None): 133 | used_mean = self.running_mean 134 | used_var = self.running_var 135 | 136 | if self.affine: 137 | weight = self.weight.view(*self.shape).expand_as(y) 138 | bias = self.bias.view(*self.shape).expand_as(y) 139 | y = (y - bias) * torch.exp(-weight) 140 | 141 | used_mean = used_mean.view(*self.shape).expand_as(y) 142 | used_var = used_var.view(*self.shape).expand_as(y) 143 | x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean 144 | 145 | if logpy is None: 146 | return x 147 | else: 148 | return ( 149 | x, 150 | logpy 151 | + self._logdetgrad( 152 | x.view(x.size(0), -1)[:, : self.effective_shape], used_var 153 | ) 154 | .view(x.size(0), -1) 155 | .sum(1, keepdim=True), 156 | ) 157 | 158 | def _logdetgrad(self, x, used_var): 159 | logdetgrad = -0.5 * torch.log(used_var + self.eps) 160 | if self.affine: 161 | weight = self.weight.view(*self.shape).expand(*x.size()) 162 | logdetgrad += weight 163 | return logdetgrad 164 | 165 | def __repr__(self): 166 | return ( 167 | "{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag}," 168 | " affine={affine})".format(name=self.__class__.__name__, **self.__dict__) 169 | ) 170 | 171 | 172 | def stable_var(x, mean=None, dim=1): 173 | if mean is None: 174 | mean = x.mean(dim, keepdim=True) 175 | mean = mean.view(-1, 1) 176 | res = torch.pow(x - mean, 2) 177 | max_sqr = torch.max(res, dim, keepdim=True)[0] 178 | var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr 179 | var = var.view(-1) 180 | # change nan to zero 181 | var[var != var] = 0 182 | return var 183 | 184 | 185 | class MovingBatchNorm1d(MovingBatchNormNd): 186 | @property 187 | def shape(self): 188 | return [1, -1] 189 | 190 | 191 | class MovingBatchNorm2d(MovingBatchNormNd): 192 | @property 193 | def shape(self): 194 | return [1, -1, 1, 1] 195 | -------------------------------------------------------------------------------- /lib/layers/odefunc.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import copy 25 | 26 | import numpy as np 27 | import torch 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | 31 | from . 
import diffeq_layers 32 | from .squeeze import squeeze, unsqueeze 33 | 34 | __all__ = ["ODEnet", "AutoencoderDiffEqNet", "ODEfunc", "AutoencoderODEfunc"] 35 | 36 | 37 | def divergence_bf(dx, y, **unused_kwargs): 38 | sum_diag = 0.0 39 | for i in range(y.shape[1]): 40 | sum_diag += ( 41 | torch.autograd.grad(dx[:, i].sum(), y, create_graph=True)[0] 42 | .contiguous()[:, i] 43 | .contiguous() 44 | ) 45 | return sum_diag.contiguous() 46 | 47 | 48 | def _get_minibatch_jacobian(y, x): 49 | """Computes the Jacobian of y wrt x assuming minibatch-mode. 50 | 51 | Args: 52 | y: (N, ...) with a total of D_y elements in ... 53 | x: (N, ...) with a total of D_x elements in ... 54 | Returns: 55 | The minibatch Jacobian matrix of shape (N, D_y, D_x) 56 | """ 57 | assert y.shape[0] == x.shape[0] 58 | y = y.view(y.shape[0], -1) 59 | 60 | # Compute Jacobian row by row. 61 | jac = [] 62 | for j in range(y.shape[1]): 63 | dy_j_dx = torch.autograd.grad( 64 | y[:, j], x, torch.ones_like(y[:, j]), retain_graph=True, create_graph=True 65 | )[0].view(x.shape[0], -1) 66 | jac.append(torch.unsqueeze(dy_j_dx, 1)) 67 | jac = torch.cat(jac, 1) 68 | return jac 69 | 70 | 71 | def divergence_approx(f, y, e=None): 72 | e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0] 73 | e_dzdx_e = e_dzdx * e 74 | approx_tr_dzdx = e_dzdx_e.view(y.shape[0], -1).sum(dim=1) 75 | return approx_tr_dzdx 76 | 77 | 78 | def sample_rademacher_like(y): 79 | return torch.randint(low=0, high=2, size=y.shape).to(y) * 2 - 1 80 | 81 | 82 | def sample_gaussian_like(y): 83 | return torch.randn_like(y) 84 | 85 | 86 | class Swish(nn.Module): 87 | def __init__(self): 88 | super(Swish, self).__init__() 89 | self.beta = nn.Parameter(torch.tensor(1.0)) 90 | 91 | def forward(self, x): 92 | return x * torch.sigmoid(self.beta * x) 93 | 94 | 95 | class Lambda(nn.Module): 96 | def __init__(self, f): 97 | super(Lambda, self).__init__() 98 | self.f = f 99 | 100 | def forward(self, x): 101 | return self.f(x) 102 | 103 | 104 | NONLINEARITIES = { 105 | "tanh": nn.Tanh(), 106 | "relu": nn.ReLU(), 107 | "softplus": nn.Softplus(), 108 | "elu": nn.ELU(), 109 | "swish": Swish(), 110 | "square": Lambda(lambda x: x ** 2), 111 | "identity": Lambda(lambda x: x), 112 | } 113 | 114 | 115 | class ODEnet(nn.Module): 116 | """ 117 | Helper class to make neural nets for use in continuous normalizing flows 118 | """ 119 | 120 | def __init__( 121 | self, 122 | hidden_dims, 123 | input_shape, 124 | strides, 125 | conv, 126 | layer_type="concat", 127 | nonlinearity="softplus", 128 | num_squeeze=0, 129 | ): 130 | super(ODEnet, self).__init__() 131 | self.num_squeeze = num_squeeze 132 | if conv: 133 | assert len(strides) == len(hidden_dims) + 1 134 | base_layer = { 135 | "ignore": diffeq_layers.IgnoreConv2d, 136 | "hyper": diffeq_layers.HyperConv2d, 137 | "squash": diffeq_layers.SquashConv2d, 138 | "concat": diffeq_layers.ConcatConv2d, 139 | "concat_v2": diffeq_layers.ConcatConv2d_v2, 140 | "concatsquash": diffeq_layers.ConcatSquashConv2d, 141 | "blend": diffeq_layers.BlendConv2d, 142 | "concatcoord": diffeq_layers.ConcatCoordConv2d, 143 | }[layer_type] 144 | else: 145 | strides = [None] * (len(hidden_dims) + 1) 146 | base_layer = { 147 | "ignore": diffeq_layers.IgnoreLinear, 148 | "hyper": diffeq_layers.HyperLinear, 149 | "squash": diffeq_layers.SquashLinear, 150 | "concat": diffeq_layers.ConcatLinear, 151 | "concat_v2": diffeq_layers.ConcatLinear_v2, 152 | "concatsquash": diffeq_layers.ConcatSquashLinear, 153 | "blend": diffeq_layers.BlendLinear, 154 | "concatcoord": 
diffeq_layers.ConcatLinear, 155 | }[layer_type] 156 | 157 | # build layers and add them 158 | layers = [] 159 | activation_fns = [] 160 | hidden_shape = input_shape 161 | 162 | for dim_out, stride in zip(hidden_dims + (input_shape[0],), strides): 163 | if stride is None: 164 | layer_kwargs = {} 165 | elif stride == 1: 166 | layer_kwargs = { 167 | "ksize": 3, 168 | "stride": 1, 169 | "padding": 1, 170 | "transpose": False, 171 | } 172 | elif stride == 2: 173 | layer_kwargs = { 174 | "ksize": 4, 175 | "stride": 2, 176 | "padding": 1, 177 | "transpose": False, 178 | } 179 | elif stride == -2: 180 | layer_kwargs = { 181 | "ksize": 4, 182 | "stride": 2, 183 | "padding": 1, 184 | "transpose": True, 185 | } 186 | else: 187 | raise ValueError("Unsupported stride: {}".format(stride)) 188 | 189 | layer = base_layer(hidden_shape[0], dim_out, **layer_kwargs) 190 | layers.append(layer) 191 | activation_fns.append(NONLINEARITIES[nonlinearity]) 192 | 193 | hidden_shape = list(copy.copy(hidden_shape)) 194 | hidden_shape[0] = dim_out 195 | if stride == 2: 196 | hidden_shape[1], hidden_shape[2] = ( 197 | hidden_shape[1] // 2, 198 | hidden_shape[2] // 2, 199 | ) 200 | elif stride == -2: 201 | hidden_shape[1], hidden_shape[2] = ( 202 | hidden_shape[1] * 2, 203 | hidden_shape[2] * 2, 204 | ) 205 | 206 | self.layers = nn.ModuleList(layers) 207 | self.activation_fns = nn.ModuleList(activation_fns[:-1]) 208 | 209 | def forward(self, t, y): 210 | dx = y 211 | # squeeze 212 | for _ in range(self.num_squeeze): 213 | dx = squeeze(dx, 2) 214 | for l, layer in enumerate(self.layers): 215 | dx = layer(t, dx) 216 | # if not last layer, use nonlinearity 217 | if l < len(self.layers) - 1: 218 | dx = self.activation_fns[l](dx) 219 | # unsqueeze 220 | for _ in range(self.num_squeeze): 221 | dx = unsqueeze(dx, 2) 222 | return dx 223 | 224 | 225 | class AutoencoderDiffEqNet(nn.Module): 226 | """ 227 | Helper class to make neural nets for use in continuous normalizing flows 228 | """ 229 | 230 | def __init__( 231 | self, 232 | hidden_dims, 233 | input_shape, 234 | strides, 235 | conv, 236 | layer_type="concat", 237 | nonlinearity="softplus", 238 | ): 239 | super(AutoencoderDiffEqNet, self).__init__() 240 | assert layer_type in ("ignore", "hyper", "concat", "concatcoord", "blend") 241 | assert nonlinearity in ("tanh", "relu", "softplus", "elu") 242 | 243 | self.nonlinearity = { 244 | "tanh": F.tanh, 245 | "relu": F.relu, 246 | "softplus": F.softplus, 247 | "elu": F.elu, 248 | }[nonlinearity] 249 | if conv: 250 | assert len(strides) == len(hidden_dims) + 1 251 | base_layer = { 252 | "ignore": diffeq_layers.IgnoreConv2d, 253 | "hyper": diffeq_layers.HyperConv2d, 254 | "squash": diffeq_layers.SquashConv2d, 255 | "concat": diffeq_layers.ConcatConv2d, 256 | "blend": diffeq_layers.BlendConv2d, 257 | "concatcoord": diffeq_layers.ConcatCoordConv2d, 258 | }[layer_type] 259 | else: 260 | strides = [None] * (len(hidden_dims) + 1) 261 | base_layer = { 262 | "ignore": diffeq_layers.IgnoreLinear, 263 | "hyper": diffeq_layers.HyperLinear, 264 | "squash": diffeq_layers.SquashLinear, 265 | "concat": diffeq_layers.ConcatLinear, 266 | "blend": diffeq_layers.BlendLinear, 267 | "concatcoord": diffeq_layers.ConcatLinear, 268 | }[layer_type] 269 | 270 | # build layers and add them 271 | encoder_layers = [] 272 | decoder_layers = [] 273 | hidden_shape = input_shape 274 | for i, (dim_out, stride) in enumerate( 275 | zip(hidden_dims + (input_shape[0],), strides) 276 | ): 277 | if i <= len(hidden_dims) // 2: 278 | layers = encoder_layers 279 | else: 280 | 
layers = decoder_layers 281 | 282 | if stride is None: 283 | layer_kwargs = {} 284 | elif stride == 1: 285 | layer_kwargs = { 286 | "ksize": 3, 287 | "stride": 1, 288 | "padding": 1, 289 | "transpose": False, 290 | } 291 | elif stride == 2: 292 | layer_kwargs = { 293 | "ksize": 4, 294 | "stride": 2, 295 | "padding": 1, 296 | "transpose": False, 297 | } 298 | elif stride == -2: 299 | layer_kwargs = { 300 | "ksize": 4, 301 | "stride": 2, 302 | "padding": 1, 303 | "transpose": True, 304 | } 305 | else: 306 | raise ValueError("Unsupported stride: {}".format(stride)) 307 | 308 | layers.append(base_layer(hidden_shape[0], dim_out, **layer_kwargs)) 309 | 310 | hidden_shape = list(copy.copy(hidden_shape)) 311 | hidden_shape[0] = dim_out 312 | if stride == 2: 313 | hidden_shape[1], hidden_shape[2] = ( 314 | hidden_shape[1] // 2, 315 | hidden_shape[2] // 2, 316 | ) 317 | elif stride == -2: 318 | hidden_shape[1], hidden_shape[2] = ( 319 | hidden_shape[1] * 2, 320 | hidden_shape[2] * 2, 321 | ) 322 | 323 | self.encoder_layers = nn.ModuleList(encoder_layers) 324 | self.decoder_layers = nn.ModuleList(decoder_layers) 325 | 326 | def forward(self, t, y): 327 | h = y 328 | for layer in self.encoder_layers: 329 | h = self.nonlinearity(layer(t, h)) 330 | 331 | dx = h 332 | for i, layer in enumerate(self.decoder_layers): 333 | dx = layer(t, dx) 334 | # if not last layer, use nonlinearity 335 | if i < len(self.decoder_layers) - 1: 336 | dx = self.nonlinearity(dx) 337 | return h, dx 338 | 339 | 340 | class ODEfunc(nn.Module): 341 | def __init__( 342 | self, diffeq, divergence_fn="approximate", residual=False, rademacher=False 343 | ): 344 | super(ODEfunc, self).__init__() 345 | assert divergence_fn in ("brute_force", "approximate") 346 | 347 | # self.diffeq = diffeq_layers.wrappers.diffeq_wrapper(diffeq) 348 | self.diffeq = diffeq 349 | self.residual = residual 350 | self.rademacher = rademacher 351 | 352 | if divergence_fn == "brute_force": 353 | self.divergence_fn = divergence_bf 354 | elif divergence_fn == "approximate": 355 | self.divergence_fn = divergence_approx 356 | 357 | self.register_buffer("_num_evals", torch.tensor(0.0)) 358 | 359 | def before_odeint(self, e=None): 360 | self._e = e 361 | self._num_evals.fill_(0) 362 | 363 | def num_evals(self): 364 | return self._num_evals.item() 365 | 366 | def forward(self, t, states): 367 | assert len(states) >= 2 368 | y = states[0] 369 | 370 | # increment num evals 371 | self._num_evals += 1 372 | 373 | # convert to tensor 374 | t = torch.tensor(t).type_as(y) 375 | batchsize = y.shape[0] 376 | 377 | # Sample and fix the noise. 378 | if self._e is None: 379 | if self.rademacher: 380 | self._e = sample_rademacher_like(y) 381 | else: 382 | self._e = sample_gaussian_like(y) 383 | 384 | with torch.set_grad_enabled(True): 385 | y.requires_grad_(True) 386 | t.requires_grad_(True) 387 | for s_ in states[2:]: 388 | s_.requires_grad_(True) 389 | dy = self.diffeq(t, y, *states[2:]) 390 | # Hack for 2D data to use brute force divergence computation. 
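# Editor's sketch (added explanation, not in the original source): the branch
# below chooses between exact and stochastic divergence. For a D-dimensional
# state, divergence_bf computes tr(df/dy) exactly with D vector-Jacobian
# products, which is only cheap for tiny D (here D == 2, at evaluation time).
# Otherwise divergence_approx uses Hutchinson's estimator with the fixed
# probe e sampled above:
#     tr(J) = E_e[e^T J e],  with E[e e^T] = I  (Gaussian or Rademacher e),
# i.e. a single torch.autograd.grad(dy, y, e) call per ODE function evaluation.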
391 | if not self.training and dy.view(dy.shape[0], -1).shape[1] == 2:
392 | divergence = divergence_bf(dy, y).view(batchsize, 1)
393 | else:
394 | divergence = self.divergence_fn(dy, y, e=self._e).view(batchsize, 1)
395 | if self.residual:
396 | dy = dy - y
397 | divergence -= torch.ones_like(divergence) * torch.tensor(
398 | np.prod(y.shape[1:]), dtype=torch.float32
399 | ).to(divergence)
400 | return tuple(
401 | [dy, -divergence]
402 | + [torch.zeros_like(s_).requires_grad_(True) for s_ in states[2:]]
403 | )
404 | 
405 | 
406 | class AutoencoderODEfunc(nn.Module):
407 | def __init__(
408 | self,
409 | autoencoder_diffeq,
410 | divergence_fn="approximate",
411 | residual=False,
412 | rademacher=False,
413 | ):
414 | assert divergence_fn in (
415 | "approximate",
416 | ), "Only approximate divergence supported at the moment. (TODO)"
417 | assert isinstance(autoencoder_diffeq, AutoencoderDiffEqNet)
418 | super(AutoencoderODEfunc, self).__init__()
419 | self.residual = residual
420 | self.autoencoder_diffeq = autoencoder_diffeq
421 | self.rademacher = rademacher
422 | 
423 | self.register_buffer("_num_evals", torch.tensor(0.0))
424 | 
425 | def before_odeint(self, e=None):
426 | self._e = e
427 | self._num_evals.fill_(0)
428 | 
429 | def forward(self, t, y_and_logpy):
430 | y, _ = y_and_logpy  # remove logpy
431 | 
432 | # increment num evals
433 | self._num_evals += 1
434 | 
435 | # convert to tensor
436 | t = torch.tensor(t).type_as(y)
437 | batchsize = y.shape[0]
438 | 
439 | with torch.set_grad_enabled(True):
440 | y.requires_grad_(True)
441 | t.requires_grad_(True)
442 | h, dy = self.autoencoder_diffeq(t, y)
443 | 
444 | # Sample and fix the noise.
445 | if self._e is None:
446 | if self.rademacher:
447 | self._e = sample_rademacher_like(h)
448 | else:
449 | self._e = sample_gaussian_like(h)
450 | 
451 | e_vjp_dhdy = torch.autograd.grad(h, y, self._e, create_graph=True)[0]
452 | e_vjp_dfdy = torch.autograd.grad(dy, h, e_vjp_dhdy, create_graph=True)[0]
453 | divergence = torch.sum(
454 | (e_vjp_dfdy * self._e).view(batchsize, -1), 1, keepdim=True
455 | )
456 | 
457 | if self.residual:
458 | dy = dy - y
459 | divergence -= torch.ones_like(divergence) * torch.tensor(
460 | np.prod(y.shape[1:]), dtype=torch.float32
461 | ).to(divergence)
462 | 
463 | return dy, -divergence
464 | 
--------------------------------------------------------------------------------
/lib/layers/odefunc_aug.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present Royal Bank of Canada
2 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 | 
9 | # This code is based on the ffjord project, which can be found at https://github.com/rtqichen/ffjord
10 | 
11 | import copy
12 | 
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | 
17 | from .
import diffeq_layers
18 | from .odefunc import NONLINEARITIES, sample_gaussian_like, sample_rademacher_like
19 | from .squeeze import squeeze, unsqueeze
20 | 
21 | __all__ = ["AugODEnet", "AugODEfunc"]
22 | 
23 | 
24 | def divergence_bf_aug(dx, y, effective_dim, **unused_kwargs):
25 | """
26 | Computes the exact divergence (trace of the Jacobian) for an augmented ODE.
27 | 
28 | Parameters:
29 | dx: output of the neural ODE function
30 | y: input to the neural ODE function
31 | effective_dim (int): the first n dimensions of the input transformed by
32 | the normalizing flow, over which the log determinant is computed
33 | Returns:
34 | sum_diag: exact trace of the Jacobian restricted to the first effective_dim dimensions
35 | """
36 | sum_diag = 0.0
37 | assert effective_dim <= y.shape[1]
38 | for i in range(effective_dim):
39 | sum_diag += (
40 | torch.autograd.grad(dx[:, i].sum(), y, create_graph=True)[0]
41 | .contiguous()[:, i]
42 | .contiguous()
43 | )
44 | return sum_diag.contiguous()
45 | 
46 | 
47 | def divergence_approx_aug(f, y, effective_dim, e=None):
48 | """
49 | Estimates the divergence (trace of the Jacobian) for an augmented ODE
50 | using Hutchinson's estimator.
51 | 
52 | Parameters:
53 | f: output of the neural ODE function
54 | y: input to the neural ODE function
55 | effective_dim (int): the first n dimensions of the input transformed by
56 | the normalizing flow (the probe e is expected to be zero elsewhere)
57 | 
58 | Returns:
59 | approx_tr_dzdx: estimated trace of df/dy
60 | """
61 | e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
62 | e_dzdx_e = e_dzdx * e
63 | approx_tr_dzdx = e_dzdx_e.view(y.shape[0], -1).sum(dim=1)
64 | return approx_tr_dzdx
65 | 
66 | 
67 | class AugODEnet(nn.Module):
68 | """
69 | Class to make neural nets for use in augmented continuous normalizing flows.
70 | Currently only one-dimensional data is supported.
71 | 
72 | Parameters:
73 | hidden_dims (list): the hidden dimensions of the neural ODE function
74 | aug_dim (int): dimension along which the input is augmented
75 | effective_shape (int): the size of the input to be transformed
76 | aug_mapping (bool): whether the augmented dimensions are passed through
77 | their own network (otherwise their dynamics are zero)
78 | aug_hidden_dims (list): list of hidden dimensions for the network of the augmented input
79 | """
80 | 
81 | def __init__(
82 | self,
83 | hidden_dims,
84 | input_shape,
85 | effective_shape,
86 | strides,
87 | conv,
88 | layer_type="concat",
89 | nonlinearity="softplus",
90 | num_squeeze=0,
91 | aug_dim=0,
92 | aug_mapping=True,
93 | aug_hidden_dims=None,
94 | ):
95 | 
96 | super(AugODEnet, self).__init__()
97 | self.aug_mapping = aug_mapping
98 | self.num_squeeze = num_squeeze
99 | self.effective_shape = effective_shape
100 | if conv:
101 | raise NotImplementedError
102 | else:
103 | strides = [None] * (len(hidden_dims) + 1)
104 | base_layer = {
105 | "ignore": diffeq_layers.IgnoreLinear,
106 | "hyper": diffeq_layers.HyperLinear,
107 | "squash": diffeq_layers.SquashLinear,
108 | "concat": diffeq_layers.ConcatLinear,
109 | "concat_v2": diffeq_layers.ConcatLinear_v2,
110 | "concatsquash": diffeq_layers.ConcatSquashLinear,
111 | "blend": diffeq_layers.BlendLinear,
112 | "concatcoord": diffeq_layers.ConcatLinear,
113 | }[layer_type]
114 | 
115 | # build layers and add them
116 | layers = []
117 | activation_fns = []
118 | hidden_shape = input_shape
119 | if self.aug_mapping:
120 | aug_layers = []
121 | aug_activation_fns = []
122 | aug_hidden_shape = list(copy.copy(input_shape))
123 | aug_hidden_shape[aug_dim] = input_shape[aug_dim] - effective_shape
124 | if aug_hidden_dims is None:
125 | 
aug_hidden_dims = copy.copy(hidden_dims) 126 | 127 | for dim_out, stride in zip(hidden_dims + (effective_shape,), strides): 128 | if stride is None: 129 | layer_kwargs = {} 130 | elif stride == 1: 131 | layer_kwargs = { 132 | "ksize": 3, 133 | "stride": 1, 134 | "padding": 1, 135 | "transpose": False, 136 | } 137 | elif stride == 2: 138 | layer_kwargs = { 139 | "ksize": 4, 140 | "stride": 2, 141 | "padding": 1, 142 | "transpose": False, 143 | } 144 | elif stride == -2: 145 | layer_kwargs = { 146 | "ksize": 4, 147 | "stride": 2, 148 | "padding": 1, 149 | "transpose": True, 150 | } 151 | else: 152 | raise ValueError("Unsupported stride: {}".format(stride)) 153 | 154 | layer = base_layer(hidden_shape[0], dim_out, **layer_kwargs) 155 | layers.append(layer) 156 | activation_fns.append(NONLINEARITIES[nonlinearity]) 157 | 158 | hidden_shape = list(copy.copy(hidden_shape)) 159 | hidden_shape[0] = dim_out 160 | if stride == 2: 161 | hidden_shape[1], hidden_shape[2] = ( 162 | hidden_shape[1] // 2, 163 | hidden_shape[2] // 2, 164 | ) 165 | elif stride == -2: 166 | hidden_shape[1], hidden_shape[2] = ( 167 | hidden_shape[1] * 2, 168 | hidden_shape[2] * 2, 169 | ) 170 | if self.aug_mapping: 171 | for dim_out, stride in zip( 172 | aug_hidden_dims + (input_shape[aug_dim] - effective_shape,), strides 173 | ): 174 | if stride is None: 175 | layer_kwargs = {} 176 | elif stride == 1: 177 | layer_kwargs = { 178 | "ksize": 3, 179 | "stride": 1, 180 | "padding": 1, 181 | "transpose": False, 182 | } 183 | elif stride == 2: 184 | layer_kwargs = { 185 | "ksize": 4, 186 | "stride": 2, 187 | "padding": 1, 188 | "transpose": False, 189 | } 190 | elif stride == -2: 191 | layer_kwargs = { 192 | "ksize": 4, 193 | "stride": 2, 194 | "padding": 1, 195 | "transpose": True, 196 | } 197 | else: 198 | raise ValueError("Unsupported stride: {}".format(stride)) 199 | 200 | layer = base_layer(aug_hidden_shape[0], dim_out, **layer_kwargs) 201 | aug_layers.append(layer) 202 | aug_activation_fns.append(NONLINEARITIES[nonlinearity]) 203 | 204 | aug_hidden_shape = list(copy.copy(aug_hidden_shape)) 205 | aug_hidden_shape[0] = dim_out 206 | if stride == 2: 207 | aug_hidden_shape[1], aug_hidden_shape[2] = ( 208 | aug_hidden_shape[1] // 2, 209 | aug_hidden_shape[2] // 2, 210 | ) 211 | elif stride == -2: 212 | aug_hidden_shape[1], aug_hidden_shape[2] = ( 213 | aug_hidden_shape[1] * 2, 214 | aug_hidden_shape[2] * 2, 215 | ) 216 | 217 | self.layers = nn.ModuleList(layers) 218 | self.activation_fns = nn.ModuleList(activation_fns[:-1]) 219 | if self.aug_mapping: 220 | self.aug_layers = nn.ModuleList(aug_layers) 221 | self.aug_activation_fns = nn.ModuleList(aug_activation_fns[:-1]) 222 | 223 | def forward(self, t, y): 224 | dx = y 225 | # squeeze 226 | aug = y[:, self.effective_shape:] 227 | # aug = y[:, self] 228 | for _ in range(self.num_squeeze): 229 | dx = squeeze(dx, 2) 230 | for l, layer in enumerate(self.layers): 231 | dx = layer(t, dx) 232 | # if not last layer, use nonlinearity 233 | if l < len(self.layers) - 1: 234 | dx = self.activation_fns[l](dx) 235 | # unsqueeze 236 | for _ in range(self.num_squeeze): 237 | dx = unsqueeze(dx, 2) 238 | 239 | if self.aug_mapping: 240 | for l, layer in enumerate(self.aug_layers): 241 | aug = layer(t, aug) 242 | if l < len(self.aug_layers) - 1: 243 | aug = self.aug_activation_fns[l](aug) 244 | else: 245 | aug = torch.zeros_like(aug) 246 | 247 | dx = torch.cat([dx, aug], dim=1) 248 | return dx 249 | 250 | 251 | class AugODEfunc(nn.Module): 252 | """ 253 | Wrapper to make neural nets for use in augmented 
continuous normalizing flows
254 | """
255 | 
256 | def __init__(
257 | self,
258 | diffeq,
259 | divergence_fn="approximate",
260 | residual=False,
261 | rademacher=False,
262 | effective_shape=None,
263 | ):
264 | super(AugODEfunc, self).__init__()
265 | # effective_shape is the effective dimension used for likelihood estimation.
266 | # It is either an integer or a list of integers.
267 | assert divergence_fn in ("brute_force", "approximate")
268 | 
269 | self.diffeq = diffeq
270 | self.residual = residual
271 | self.rademacher = rademacher
272 | 
273 | if divergence_fn == "brute_force":
274 | self.divergence_fn = divergence_bf_aug
275 | elif divergence_fn == "approximate":
276 | self.divergence_fn = divergence_approx_aug
277 | 
278 | self.register_buffer("_num_evals", torch.tensor(0.0))
279 | assert effective_shape is not None
280 | self.effective_shape = effective_shape
281 | 
282 | def before_odeint(self, e=None):
283 | self._e = e
284 | self._num_evals.fill_(0)
285 | 
286 | def num_evals(self):
287 | return self._num_evals.item()
288 | 
289 | def forward(self, t, states):
290 | assert len(states) >= 2
291 | y = states[0]
292 | # increment num evals
293 | self._num_evals += 1
294 | 
295 | # convert to tensor
296 | t = torch.tensor(t).type_as(y)
297 | batchsize = y.shape[0]
298 | 
299 | # Sample and fix the noise.
300 | if self._e is None:
301 | self._e = torch.zeros_like(y)
302 | if isinstance(self.effective_shape, int):
303 | sample_like = y[:, : self.effective_shape]
304 | else:
305 | sample_like = y
306 | for dim, size in enumerate(self.effective_shape):
307 | sample_like = sample_like.narrow(dim + 1, 0, size)
308 | 
309 | if self.rademacher:
310 | sample = sample_rademacher_like(sample_like)
311 | else:
312 | sample = sample_gaussian_like(sample_like)
313 | if isinstance(self.effective_shape, int):
314 | self._e[:, : self.effective_shape] = sample
315 | else:
316 | pad_size = []
317 | for idx in range(len(self.effective_shape)):  # one (left, right) pad pair per dimension, last dim first
318 | pad_size.append(0)
319 | pad_size.append(y.shape[-idx - 1] - self.effective_shape[-idx - 1])
320 | pad_size = tuple(pad_size)
321 | # pad the probe with zeros along the augmented (non-effective) dimensions
322 | self._e = nn.functional.pad(sample, pad_size, mode="constant", value=0)
323 | 
324 | with torch.set_grad_enabled(True):
325 | y.requires_grad_(True)
326 | t.requires_grad_(True)
327 | for s_ in states[2:]:
328 | s_.requires_grad_(True)
329 | dy = self.diffeq(t, y, *states[2:])
330 | # Hack for 2D data to use brute force divergence computation.
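# Editor's note (added explanation, not in the original source): as in ODEfunc,
# evaluation falls back to the exact divergence for 2-dimensional states.
# In the augmented case only the first effective_shape dimensions carry
# density, so divergence_bf_aug sums only those Jacobian diagonal entries,
# while divergence_approx_aug relies on the probe self._e being zero-padded
# outside the effective dimensions to restrict Hutchinson's estimate to them.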
331 | if not self.training and dy.view(dy.shape[0], -1).shape[1] == 2:
332 | divergence = divergence_bf_aug(dy, y, self.effective_shape).view(
333 | batchsize, 1
334 | )
335 | else:
336 | divergence = self.divergence_fn(
337 | dy, y, self.effective_shape, e=self._e
338 | ).view(batchsize, 1)
339 | if self.residual:
340 | dy = dy - y
341 | if isinstance(self.effective_shape, int):
342 | divergence -= (
343 | torch.ones_like(divergence)
344 | * torch.tensor(
345 | np.prod(y.shape[1:]) * self.effective_shape / y.shape[1],
346 | dtype=torch.float32,
347 | ).to(divergence)
348 | )
349 | else:
350 | divergence -= (
351 | torch.ones_like(divergence)
352 | * torch.tensor(
353 | np.prod(self.effective_shape),
354 | dtype=torch.float32,
355 | ).to(divergence)
356 | )
357 | return tuple(
358 | [dy, -divergence]
359 | + [torch.zeros_like(s_).requires_grad_(True) for s_ in states[2:]]
360 | )
361 | 
--------------------------------------------------------------------------------
/lib/layers/resnet.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | # Link: https://github.com/rtqichen/ffjord
23 | 
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 | 
27 | 
28 | class BasicBlock(nn.Module):
29 | expansion = 1
30 | 
31 | def __init__(self, dim):
32 | super(BasicBlock, self).__init__()
33 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False)
34 | self.bn1 = nn.GroupNorm(2, dim, eps=1e-4)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False)
37 | self.bn2 = nn.GroupNorm(2, dim, eps=1e-4)
38 | 
39 | def forward(self, x):
40 | residual = x
41 | 
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 | 
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 | 
49 | out += residual
50 | out = self.relu(out)
51 | 
52 | return out
53 | 
54 | 
55 | class ResNeXtBottleneck(nn.Module):
56 | """
57 | ResNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
58 | """
59 | 
60 | def __init__(self, dim, cardinality=4, base_depth=32):
61 | """Constructor
62 | Args:
63 | dim: input and output channel dimensionality
64 | cardinality: number of convolution groups
65 | base_depth: base number of channels in each group;
66 | the grouped convolution width is D = cardinality * base_depth
67 | (this folds base_width and widen_factor from the original ResNeXt
68 | into a single parameter)
69 | """
70 | super(ResNeXtBottleneck, self).__init__()
71 | D = cardinality * base_depth
72 | self.conv_reduce = nn.Conv2d(
73 | dim, D, kernel_size=1, stride=1, padding=0, bias=False
74 | )
75 | self.bn_reduce = nn.BatchNorm2d(D)
76 | self.conv_grp = nn.Conv2d(
77 | D, D, kernel_size=3, stride=1, padding=1, groups=cardinality, bias=False
78 | )
79 | self.bn = nn.BatchNorm2d(D)
80 | self.conv_expand = nn.Conv2d(
81 | D, dim, kernel_size=1, stride=1, padding=0, bias=False
82 | )
83 | self.bn_expand = nn.BatchNorm2d(dim)
84 | 
85 | def forward(self, x):
86 | bottleneck = self.conv_reduce.forward(x)
87 | bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True)
88 | bottleneck = self.conv_grp.forward(bottleneck)
89 | bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True)
90 | bottleneck = self.conv_expand.forward(bottleneck)
91 | bottleneck = self.bn_expand.forward(bottleneck)
92 | return F.relu(x + bottleneck, inplace=True)
93 | 
--------------------------------------------------------------------------------
/lib/layers/squeeze.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import torch.nn as nn 25 | 26 | __all__ = ["SqueezeLayer"] 27 | 28 | 29 | class SqueezeLayer(nn.Module): 30 | def __init__(self, downscale_factor): 31 | super(SqueezeLayer, self).__init__() 32 | self.downscale_factor = downscale_factor 33 | 34 | def forward(self, x, logpx=None, reverse=False): 35 | if reverse: 36 | return self._upsample(x, logpx) 37 | else: 38 | return self._downsample(x, logpx) 39 | 40 | def _downsample(self, x, logpx=None): 41 | squeeze_x = squeeze(x, self.downscale_factor) 42 | if logpx is None: 43 | return squeeze_x 44 | else: 45 | return squeeze_x, logpx 46 | 47 | def _upsample(self, y, logpy=None): 48 | unsqueeze_y = unsqueeze(y, self.downscale_factor) 49 | if logpy is None: 50 | return unsqueeze_y 51 | else: 52 | return unsqueeze_y, logpy 53 | 54 | 55 | def unsqueeze(input, upscale_factor=2): 56 | """ 57 | [:, C*r^2, H, W] -> [:, C, H*r, W*r] 58 | """ 59 | batch_size, in_channels, in_height, in_width = input.size() 60 | out_channels = in_channels // (upscale_factor ** 2) 61 | 62 | out_height = in_height * upscale_factor 63 | out_width = in_width * upscale_factor 64 | 65 | input_view = input.contiguous().view( 66 | batch_size, out_channels, upscale_factor, upscale_factor, in_height, in_width 67 | ) 68 | 69 | output = input_view.permute(0, 1, 4, 2, 5, 3).contiguous() 70 | return output.view(batch_size, out_channels, out_height, out_width) 71 | 72 | 73 | def squeeze(input, downscale_factor=2): 74 | """ 75 | [:, C, H*r, W*r] -> [:, C*r^2, H, W] 76 | """ 77 | batch_size, in_channels, in_height, in_width = input.size() 78 | out_channels = in_channels * (downscale_factor ** 2) 79 | 80 | out_height = in_height // downscale_factor 81 | out_width = in_width // downscale_factor 82 | 83 | input_view = input.contiguous().view( 84 | batch_size, 85 | in_channels, 86 | out_height, 87 | downscale_factor, 88 | out_width, 89 | downscale_factor, 90 | ) 91 | 92 | output = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() 93 | return output.view(batch_size, out_channels, out_height, out_width) 94 | -------------------------------------------------------------------------------- /lib/layers/wrappers/cnf_regularization.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import torch 25 | import torch.nn as nn 26 | 27 | 28 | class RegularizedODEfunc(nn.Module): 29 | def __init__(self, odefunc, regularization_fns): 30 | super(RegularizedODEfunc, self).__init__() 31 | self.odefunc = odefunc 32 | self.regularization_fns = regularization_fns 33 | 34 | def before_odeint(self, *args, **kwargs): 35 | self.odefunc.before_odeint(*args, **kwargs) 36 | 37 | def forward(self, t, state): 38 | class SharedContext(object): 39 | pass 40 | 41 | with torch.enable_grad(): 42 | x, logp = state[:2] 43 | x.requires_grad_(True) 44 | logp.requires_grad_(True) 45 | dstate = self.odefunc(t, (x, logp)) 46 | if len(state) > 2: 47 | dx, dlogp = dstate[:2] 48 | reg_states = tuple( 49 | reg_fn(x, logp, dx, dlogp, SharedContext) 50 | for reg_fn in self.regularization_fns 51 | ) 52 | return dstate + reg_states 53 | else: 54 | return dstate 55 | 56 | @property 57 | def _num_evals(self): 58 | return self.odefunc._num_evals 59 | 60 | 61 | def _batch_root_mean_squared(tensor): 62 | tensor = tensor.view(tensor.shape[0], -1) 63 | return torch.mean(torch.norm(tensor, p=2, dim=1) / tensor.shape[1] ** 0.5) 64 | 65 | 66 | def l1_regularzation_fn(x, logp, dx, dlogp, unused_context): 67 | del x, logp, dlogp 68 | return torch.mean(torch.abs(dx)) 69 | 70 | 71 | def l2_regularzation_fn(x, logp, dx, dlogp, unused_context): 72 | del x, logp, dlogp 73 | return _batch_root_mean_squared(dx) 74 | 75 | 76 | def directional_l2_regularization_fn(x, logp, dx, dlogp, unused_context): 77 | del logp, dlogp 78 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0] 79 | return _batch_root_mean_squared(directional_dx) 80 | 81 | 82 | def jacobian_frobenius_regularization_fn(x, logp, dx, dlogp, context): 83 | del logp, dlogp 84 | if hasattr(context, "jac"): 85 | jac = context.jac 86 | else: 87 | jac = _get_minibatch_jacobian(dx, x) 88 | context.jac = jac 89 | return _batch_root_mean_squared(jac) 90 | 91 | 92 | def jacobian_diag_frobenius_regularization_fn(x, logp, dx, dlogp, context): 93 | del logp, dlogp 94 | if hasattr(context, "jac"): 95 | jac = context.jac 96 | else: 97 | jac = _get_minibatch_jacobian(dx, x) 98 | context.jac = jac 99 | diagonal = jac.view(jac.shape[0], -1)[ 100 | :, :: jac.shape[1] 101 | ] # assumes jac is minibatch square, ie. (N, M, M). 102 | return _batch_root_mean_squared(diagonal) 103 | 104 | 105 | def jacobian_offdiag_frobenius_regularization_fn(x, logp, dx, dlogp, context): 106 | del logp, dlogp 107 | if hasattr(context, "jac"): 108 | jac = context.jac 109 | else: 110 | jac = _get_minibatch_jacobian(dx, x) 111 | context.jac = jac 112 | diagonal = jac.view(jac.shape[0], -1)[ 113 | :, :: jac.shape[1] 114 | ] # assumes jac is minibatch square, ie. (N, M, M). 115 | ss_offdiag = torch.sum(jac.view(jac.shape[0], -1) ** 2, dim=1) - torch.sum( 116 | diagonal ** 2, dim=1 117 | ) 118 | ms_offdiag = ss_offdiag / (diagonal.shape[1] * (diagonal.shape[1] - 1)) 119 | return torch.mean(ms_offdiag) 120 | 121 | 122 | def _get_minibatch_jacobian(y, x, create_graph=False): 123 | """Computes the Jacobian of y wrt x assuming minibatch-mode. 124 | 125 | Args: 126 | y: (N, ...) with a total of D_y elements in ... 127 | x: (N, ...) with a total of D_x elements in ... 128 | Returns: 129 | The minibatch Jacobian matrix of shape (N, D_y, D_x) 130 | """ 131 | assert y.shape[0] == x.shape[0] 132 | y = y.view(y.shape[0], -1) 133 | 134 | # Compute Jacobian row by row. 
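# Editor's note (added explanation, not in the original source): each backward
# pass below extracts one row of the Jacobian, since grad(y[:, j], x, ones)
# returns d y_j / d x for every sample in the batch. The cost is therefore
# D_y vector-Jacobian products, which is why this helper is only used by the
# regularization functions above and not on the main divergence path.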
135 | jac = [] 136 | for j in range(y.shape[1]): 137 | dy_j_dx = torch.autograd.grad( 138 | y[:, j], x, torch.ones_like(y[:, j]), retain_graph=True, create_graph=True 139 | )[0].view(x.shape[0], -1) 140 | jac.append(torch.unsqueeze(dy_j_dx, 1)) 141 | jac = torch.cat(jac, 1) 142 | return jac 143 | -------------------------------------------------------------------------------- /lib/ode_func.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2019 Yulia Rubanova 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | ########################### 24 | # Latent ODEs for Irregularly-Sampled Time Series 25 | # Author: Yulia Rubanova 26 | # Link: https://github.com/YuliaRubanova/latent_ode 27 | ########################### 28 | 29 | import lib.utils as utils 30 | import torch 31 | import torch.nn as nn 32 | 33 | 34 | ##################################################################################################### 35 | 36 | 37 | class ODEFunc(nn.Module): 38 | def __init__(self, input_dim, latent_dim, ode_func_net, device=torch.device("cpu")): 39 | """ 40 | input_dim: dimensionality of the input 41 | latent_dim: dimensionality used for ODE. Analog of a continous latent state 42 | """ 43 | super(ODEFunc, self).__init__() 44 | 45 | self.input_dim = input_dim 46 | self.device = device 47 | 48 | utils.init_network_weights(ode_func_net) 49 | self.gradient_net = ode_func_net 50 | 51 | def forward(self, t_local, y, backwards=False): 52 | """ 53 | Perform one step in solving ODE. 
Given the current data point y and the current time point t_local, returns the gradient dy/dt at this time point.
54 | 
55 | t_local: current time point
56 | y: value at the current time point
57 | """
58 | grad = self.get_ode_gradient_nn(t_local, y)
59 | if backwards:
60 | grad = -grad
61 | return grad
62 | 
63 | def get_ode_gradient_nn(self, t_local, y):
64 | return self.gradient_net(y)
65 | 
66 | def sample_next_point_from_prior(self, t_local, y):
67 | """
68 | t_local: current time point
69 | y: value at the current time point
70 | """
71 | return self.get_ode_gradient_nn(t_local, y)
72 | 
73 | 
74 | #####################################################################################################
75 | 
76 | 
77 | class ODEFunc_w_Poisson(ODEFunc):
78 | def __init__(
79 | self,
80 | input_dim,
81 | latent_dim,
82 | ode_func_net,
83 | lambda_net,
84 | device=torch.device("cpu"),
85 | ):
86 | """
87 | input_dim: dimensionality of the input
88 | latent_dim: dimensionality used for ODE. Analog of a continuous latent state
89 | """
90 | super(ODEFunc_w_Poisson, self).__init__(
91 | input_dim, latent_dim, ode_func_net, device
92 | )
93 | 
94 | self.latent_ode = ODEFunc(
95 | input_dim=input_dim,
96 | latent_dim=latent_dim,
97 | ode_func_net=ode_func_net,
98 | device=device,
99 | )
100 | 
101 | self.latent_dim = latent_dim
102 | self.lambda_net = lambda_net
103 | # The computation of the Poisson likelihood can become numerically unstable.
104 | # The integral of lambda(t) dt can take large values; in fact, it equals the expected number of events on the interval [0, T].
105 | # The exponent of lambda can also take large values.
106 | # So we divide lambda by a constant and then multiply the integral of lambda by the same constant.
107 | self.const_for_lambda = torch.Tensor([100.0]).to(device)
108 | 
109 | def extract_poisson_rate(self, augmented, final_result=True):
110 | y, log_lambdas, int_lambda = None, None, None
111 | 
112 | assert augmented.size(-1) == self.latent_dim + self.input_dim
113 | latent_lam_dim = self.latent_dim // 2
114 | 
115 | if len(augmented.size()) == 3:
116 | int_lambda = augmented[:, :, -self.input_dim:]
117 | y_latent_lam = augmented[:, :, : -self.input_dim]
118 | 
119 | log_lambdas = self.lambda_net(y_latent_lam[:, :, -latent_lam_dim:])
120 | y = y_latent_lam[:, :, :-latent_lam_dim]
121 | 
122 | elif len(augmented.size()) == 4:
123 | int_lambda = augmented[:, :, :, -self.input_dim:]
124 | y_latent_lam = augmented[:, :, :, : -self.input_dim]
125 | 
126 | log_lambdas = self.lambda_net(y_latent_lam[:, :, :, -latent_lam_dim:])
127 | y = y_latent_lam[:, :, :, :-latent_lam_dim]
128 | 
129 | # Multiply the integral over lambda by the constant
130 | # only when we have finished the integral computation (i.e.
this is not a call in get_ode_gradient_nn) 131 | if final_result: 132 | int_lambda = int_lambda * self.const_for_lambda 133 | 134 | # Latents for performing reconstruction (y) have the same size as latent poisson rate (log_lambdas) 135 | assert y.size(-1) == latent_lam_dim 136 | 137 | return y, log_lambdas, int_lambda, y_latent_lam 138 | 139 | def get_ode_gradient_nn(self, t_local, augmented): 140 | y, log_lam, int_lambda, y_latent_lam = self.extract_poisson_rate( 141 | augmented, final_result=False 142 | ) 143 | dydt_dldt = self.latent_ode(t_local, y_latent_lam) 144 | 145 | log_lam = log_lam - torch.log(self.const_for_lambda) 146 | return torch.cat((dydt_dldt, torch.exp(log_lam)), -1) 147 | -------------------------------------------------------------------------------- /lib/spectral_norm.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | """ 25 | Spectral Normalization from https://arxiv.org/abs/1802.05957 26 | """ 27 | import types 28 | 29 | import torch 30 | from torch.nn.functional import normalize 31 | 32 | POWER_ITERATION_FN = "spectral_norm_power_iteration" 33 | 34 | 35 | class SpectralNorm(object): 36 | def __init__(self, name="weight", dim=0, eps=1e-12): 37 | self.name = name 38 | self.dim = dim 39 | self.eps = eps 40 | 41 | def compute_weight(self, module, n_power_iterations): 42 | if n_power_iterations < 0: 43 | raise ValueError( 44 | "Expected n_power_iterations to be non-negative, but " 45 | "got n_power_iterations={}".format(n_power_iterations) 46 | ) 47 | 48 | weight = getattr(module, self.name + "_orig") 49 | u = getattr(module, self.name + "_u") 50 | v = getattr(module, self.name + "_v") 51 | weight_mat = weight 52 | if self.dim != 0: 53 | # permute dim to front 54 | weight_mat = weight_mat.permute( 55 | self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim] 56 | ) 57 | height = weight_mat.size(0) 58 | weight_mat = weight_mat.reshape(height, -1) 59 | with torch.no_grad(): 60 | for _ in range(n_power_iterations): 61 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v` 62 | # are the first left and right singular vectors. 63 | # This power iteration produces approximations of `u` and `v`. 
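# Editor's sketch of the update (standard power iteration, assumed here):
#     v <- normalize(W^T u),  u <- normalize(W v),  sigma ≈ u^T W v
# Note that the forward pre-hook below calls compute_weight with
# n_power_iterations=0, so the hook only rescales by the current sigma
# estimate; the iteration itself is advanced when training code invokes the
# module's spectral_norm_power_iteration method registered in apply().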
64 | v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
65 | u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
66 | setattr(module, self.name + "_u", u)
67 | setattr(module, self.name + "_v", v)
68 | 
69 | sigma = torch.dot(u, torch.matmul(weight_mat, v))
70 | weight = weight / sigma
71 | setattr(module, self.name, weight)
72 | 
73 | def remove(self, module):
74 | weight = getattr(module, self.name)
75 | delattr(module, self.name)
76 | delattr(module, self.name + "_u")
77 | delattr(module, self.name + "_orig")
78 | module.register_parameter(self.name, torch.nn.Parameter(weight))
79 | 
80 | def get_update_method(self, module):
81 | def update_fn(module, n_power_iterations):
82 | self.compute_weight(module, n_power_iterations)
83 | 
84 | return update_fn
85 | 
86 | def __call__(self, module, unused_inputs):
87 | del unused_inputs
88 | self.compute_weight(module, n_power_iterations=0)
89 | 
90 | # requires_grad might be either True or False during inference.
91 | if not module.training:
92 | r_g = getattr(module, self.name + "_orig").requires_grad
93 | setattr(
94 | module,
95 | self.name,
96 | getattr(module, self.name).detach().requires_grad_(r_g),
97 | )
98 | 
99 | @staticmethod
100 | def apply(module, name, dim, eps):
101 | fn = SpectralNorm(name, dim, eps)
102 | weight = module._parameters[name]
103 | height = weight.size(dim)
104 | 
105 | u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
106 | v = normalize(
107 | weight.new_empty(int(weight.numel() / height)).normal_(0, 1),
108 | dim=0,
109 | eps=fn.eps,
110 | )
111 | delattr(module, fn.name)
112 | module.register_parameter(fn.name + "_orig", weight)
113 | # We still need to assign weight back as fn.name because all sorts of
114 | # things may assume that it exists, e.g., when initializing weights.
115 | # However, we can't directly assign as it could be an nn.Parameter and
116 | # gets added as a parameter. Instead, we register weight.data as a
117 | # buffer, which will cause weight to be included in the state dict
118 | # and also supports nn.init due to shared storage.
119 | module.register_buffer(fn.name, weight.data)
120 | module.register_buffer(fn.name + "_u", u)
121 | module.register_buffer(fn.name + "_v", v)
122 | 
123 | setattr(
124 | module,
125 | POWER_ITERATION_FN,
126 | types.MethodType(fn.get_update_method(module), module),
127 | )
128 | 
129 | module.register_forward_pre_hook(fn)
130 | return fn
131 | 
132 | 
133 | def inplace_spectral_norm(module, name="weight", dim=None, eps=1e-12):
134 | r"""Applies spectral normalization to a parameter in the given module.
135 | 
136 | .. math::
137 | \mathbf{W} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
138 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
139 | 
140 | Spectral normalization stabilizes the training of discriminators (critics)
141 | in Generative Adversarial Networks (GANs) by rescaling the weight tensor
142 | with the spectral norm :math:`\sigma` of the weight matrix, calculated using
143 | the power iteration method. If the dimension of the weight tensor is greater
144 | than 2, it is reshaped to 2D in the power iteration method to get the spectral
145 | norm. This is implemented via a hook that calculates the spectral norm and
146 | rescales the weight before every :meth:`~Module.forward` call.
147 | 
148 | See `Spectral Normalization for Generative Adversarial Networks`_ .
149 | 
150 | .. 
_`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
151 | 
152 | Args:
153 | module (nn.Module): containing module
154 | name (str, optional): name of weight parameter
155 | n_power_iterations: not a parameter of this function; power iterations
156 | are run via the registered spectral_norm_power_iteration method
157 | dim (int, optional): dimension corresponding to number of outputs,
158 | the default is 0, except for modules that are instances of
159 | ConvTranspose1/2/3d, when it is 1
160 | eps (float, optional): epsilon for numerical stability in
161 | calculating norms
162 | 
163 | Returns:
164 | The original module with the spectral norm hook
165 | 
166 | Example::
167 | 
168 | >>> m = spectral_norm(nn.Linear(20, 40))
169 | Linear (20 -> 40)
170 | >>> m.weight_u.size()
171 | torch.Size([20])
172 | 
173 | """
174 | if dim is None:
175 | if isinstance(
176 | module,
177 | (
178 | torch.nn.ConvTranspose1d,
179 | torch.nn.ConvTranspose2d,
180 | torch.nn.ConvTranspose3d,
181 | ),
182 | ):
183 | dim = 1
184 | else:
185 | dim = 0
186 | SpectralNorm.apply(module, name, dim=dim, eps=eps)
187 | return module
188 | 
189 | 
190 | def remove_spectral_norm(module, name="weight"):
191 | r"""Removes the spectral normalization reparameterization from a module.
192 | 
193 | Args:
194 | module (nn.Module): containing module
195 | name (str, optional): name of weight parameter
196 | 
197 | Example:
198 | >>> m = spectral_norm(nn.Linear(40, 10))
199 | >>> remove_spectral_norm(m)
200 | """
201 | for k, hook in module._forward_pre_hooks.items():
202 | if isinstance(hook, SpectralNorm) and hook.name == name:
203 | hook.remove(module)
204 | del module._forward_pre_hooks[k]
205 | return module
206 | 
207 | raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2019 Yulia Rubanova
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | # Link: https://github.com/YuliaRubanova/latent_ode 23 | 24 | import logging 25 | import math 26 | import os 27 | from numbers import Number 28 | 29 | import numpy as np 30 | import torch 31 | import torch.nn as nn 32 | import torch.optim as optim 33 | 34 | 35 | def optimizer_factory(args, parameter_list): 36 | if args.optimizer == "adam": 37 | optimizer = optim.Adam( 38 | parameter_list, 39 | lr=args.lr, 40 | weight_decay=args.weight_decay, 41 | amsgrad=args.amsgrad, 42 | ) 43 | elif args.optimizer == "sgd": 44 | optimizer = optim.SGD(parameter_list, lr=args.lr, momentum=args.momentum) 45 | else: 46 | raise NotImplementedError 47 | num_params = sum(p.numel() for p in parameter_list if p.requires_grad) 48 | return optimizer, num_params 49 | 50 | 51 | def linspace_vector(start, end, n_points): 52 | # start is either one value or a vector 53 | size = np.prod(start.size()) 54 | 55 | assert start.size() == end.size() 56 | if size == 1: 57 | # start and end are 1d-tensors 58 | res = torch.linspace(start, end, n_points) 59 | else: 60 | # start and end are vectors 61 | res = torch.Tensor() 62 | for i in range(0, start.size(0)): 63 | res = torch.cat((res, torch.linspace(start[i], end[i], n_points)), 0) 64 | res = torch.t(res.reshape(start.size(0), n_points)) 65 | return res 66 | 67 | 68 | def split_last_dim(data): 69 | last_dim = data.size()[-1] 70 | last_dim = last_dim // 2 71 | 72 | if len(data.size()) == 3: 73 | res = data[:, :, :last_dim], data[:, :, last_dim:] 74 | 75 | if len(data.size()) == 2: 76 | res = data[:, :last_dim], data[:, last_dim:] 77 | return res 78 | 79 | 80 | def check_mask(data, mask): 81 | # check that "mask" argument indeed contains a mask for data 82 | n_zeros = torch.sum(mask == 0.0).cpu().numpy() 83 | n_ones = torch.sum(mask == 1.0).cpu().numpy() 84 | 85 | # mask should contain only zeros and ones 86 | assert (n_zeros + n_ones) == np.prod(list(mask.size())) 87 | 88 | # all masked out elements should be zeros 89 | assert torch.sum(data[mask == 0.0] != 0.0) == 0 90 | 91 | 92 | def init_network_weights(net, std=0.1): 93 | for m in net.modules(): 94 | if isinstance(m, nn.Linear): 95 | nn.init.normal_(m.weight, mean=0, std=std) 96 | nn.init.constant_(m.bias, val=0) 97 | 98 | 99 | def create_net(n_inputs, n_outputs, n_layers=1, n_units=100, nonlinear=nn.Tanh): 100 | layers = [nn.Linear(n_inputs, n_units)] 101 | for i in range(n_layers): 102 | layers.append(nonlinear()) 103 | layers.append(nn.Linear(n_units, n_units)) 104 | 105 | layers.append(nonlinear()) 106 | layers.append(nn.Linear(n_units, n_outputs)) 107 | return nn.Sequential(*layers) 108 | 109 | 110 | def get_device(tensor): 111 | device = torch.device("cpu") 112 | if tensor.is_cuda: 113 | device = tensor.get_device() 114 | return device 115 | 116 | 117 | def sample_standard_gaussian(mu, sigma): 118 | device = get_device(mu) 119 | d = torch.distributions.normal.Normal( 120 | torch.Tensor([0.0]).to(device), torch.Tensor([1.0]).to(device) 121 | ) 122 | r = d.sample(mu.size()).squeeze(-1) 123 | return r * sigma.float() + mu.float() 124 | 125 | 126 | def makedirs(dirname): 127 | if not os.path.exists(dirname): 128 | os.makedirs(dirname) 129 | 130 | 131 | def get_logger( 132 | logpath, filepath, package_files=[], displaying=True, saving=True, debug=False 133 | ): 134 | logger = logging.getLogger() 135 | if debug: 136 | level = logging.DEBUG 137 | else: 138 | level = logging.INFO 139 | logger.setLevel(level) 140 | if saving: 141 | info_file_handler = logging.FileHandler(logpath, mode="a") 142 | 
info_file_handler.setLevel(level) 143 | logger.addHandler(info_file_handler) 144 | if displaying: 145 | console_handler = logging.StreamHandler() 146 | console_handler.setLevel(level) 147 | logger.addHandler(console_handler) 148 | logger.info(filepath) 149 | with open(filepath, "r") as f: 150 | logger.info(f.read()) 151 | 152 | for f in package_files: 153 | logger.info(f) 154 | with open(f, "r") as package_f: 155 | logger.info(package_f.read()) 156 | 157 | return logger 158 | 159 | 160 | class AverageMeter(object): 161 | """Computes and stores the average and current value""" 162 | 163 | def __init__(self): 164 | self.reset() 165 | 166 | def reset(self): 167 | self.val = 0 168 | self.avg = 0 169 | self.sum = 0 170 | self.count = 0 171 | 172 | def update(self, val, n=1): 173 | self.val = val 174 | self.sum += val * n 175 | self.count += n 176 | self.avg = self.sum / self.count 177 | 178 | 179 | class RunningAverageMeter(object): 180 | """Computes and stores the average and current value""" 181 | 182 | def __init__(self, momentum=0.99): 183 | self.momentum = momentum 184 | self.reset() 185 | 186 | def reset(self): 187 | self.val = None 188 | self.avg = 0 189 | 190 | def update(self, val): 191 | if self.val is None: 192 | self.avg = val 193 | else: 194 | self.avg = self.avg * self.momentum + val * (1 - self.momentum) 195 | self.val = val 196 | 197 | 198 | def inf_generator(iterable): 199 | """Allows training with DataLoaders in a single infinite loop: 200 | for i, (x, y) in enumerate(inf_generator(train_loader)): 201 | """ 202 | iterator = iterable.__iter__() 203 | while True: 204 | try: 205 | yield iterator.__next__() 206 | except StopIteration: 207 | iterator = iterable.__iter__() 208 | 209 | 210 | def save_checkpoint(state, save, epoch): 211 | if not os.path.exists(save): 212 | os.makedirs(save) 213 | filename = os.path.join(save, "checkpt-%04d.pth" % epoch) 214 | torch.save(state, filename) 215 | 216 | 217 | def isnan(tensor): 218 | return tensor != tensor 219 | 220 | 221 | def logsumexp(value, dim=None, keepdim=False): 222 | """Numerically stable implementation of the operation 223 | value.exp().sum(dim, keepdim).log() 224 | """ 225 | if dim is not None: 226 | m, _ = torch.max(value, dim=dim, keepdim=True) 227 | value0 = value - m 228 | if keepdim is False: 229 | m = m.squeeze(dim) 230 | return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) 231 | else: 232 | m = torch.max(value) 233 | sum_exp = torch.sum(torch.exp(value - m)) 234 | if isinstance(sum_exp, Number): 235 | return m + math.log(sum_exp) 236 | else: 237 | return m + torch.log(sum_exp) 238 | -------------------------------------------------------------------------------- /ode_rnn_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present Royal Bank of Canada 2 | # Copyright (c) 2019 Yulia Rubanova 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | # 8 | 9 | import lib.utils as utils 10 | # This code is based on the latent ODE project, which can be found at: https://github.com/YuliaRubanova/latent_ode 11 | import torch.nn as nn 12 | from lib.diffeq_solver import DiffeqSolver 13 | from lib.encoder_decoder import Encoder_z0_ODE_RNN 14 | from lib.ode_func import ODEFunc 15 | 16 | 17 | def create_ode_rnn_encoder(args, device): 18 | """ 19 | This function creates the ode-rnn model as an encoder 20 | args: the arguments from parse_arguments in ctfp_tools 21 | device: cpu or gpu to run the model 22 | returns an ode-rnn model 23 | """ 24 | enc_input_dim = args.input_size * 2 ## concatenate the mask with input 25 | 26 | ode_func_net = utils.create_net( 27 | args.rec_size, 28 | args.rec_size, 29 | n_layers=args.rec_layers, 30 | n_units=args.units, 31 | nonlinear=nn.Tanh, 32 | ) 33 | 34 | rec_ode_func = ODEFunc( 35 | input_dim=enc_input_dim, 36 | latent_dim=args.rec_size, 37 | ode_func_net=ode_func_net, 38 | device=device, 39 | ).to(device) 40 | 41 | z0_diffeq_solver = DiffeqSolver( 42 | enc_input_dim, 43 | rec_ode_func, 44 | "euler", 45 | args.latent_size, 46 | odeint_rtol=1e-3, 47 | odeint_atol=1e-4, 48 | device=device, 49 | ) 50 | 51 | encoder_z0 = Encoder_z0_ODE_RNN( 52 | args.rec_size, 53 | enc_input_dim, 54 | z0_diffeq_solver, 55 | z0_dim=args.latent_size, 56 | n_gru_units=args.gru_units, 57 | device=device, 58 | ).to(device) 59 | return encoder_z0 60 | -------------------------------------------------------------------------------- /page1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorealisAI/continuous-time-flow-process/d380a7984d408d1f6c849219028cc7baffdd1e1e/page1.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backcall==0.2.0 2 | cycler==0.10.0 3 | decorator==4.4.2 4 | ipdb==0.11 5 | ipython==7.16.1 6 | ipython-genutils==0.2.0 7 | jedi==0.17.2 8 | joblib==0.17.0 9 | kiwisolver==1.3.0 10 | matplotlib==3.1.3 11 | numpy==1.17.4 12 | parso==0.7.1 13 | pexpect==4.8.0 14 | pickleshare==0.7.5 15 | Pillow==8.0.1 16 | prompt-toolkit==3.0.8 17 | protobuf==3.13.0 18 | ptyprocess==0.6.0 19 | Pygments==2.7.2 20 | pyparsing==2.4.7 21 | python-dateutil==2.8.1 22 | scikit-learn==0.23.1 23 | scipy==1.5.3 24 | six==1.13.0 25 | tensorboardX==2.0 26 | threadpoolctl==2.1.0 27 | torch==1.4.0+cu92 28 | torchdiffeq==0.0.1 29 | torchvision==0.5.0+cu92 30 | traitlets==4.3.3 31 | wcwidth==0.2.5 32 | -------------------------------------------------------------------------------- /train_ctfp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present Royal Bank of Canada 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
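# Script overview: this trainer fits the CTFP model directly. It builds an
# augmented CNF flow with build_augmented_model_tabular, computes the sequence
# negative log-likelihood through run_ctfp_model (imported as run_model below),
# and checkpoints the best validation model under args.save.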
6 | 7 | 8 | import os 9 | import os.path as osp 10 | import time 11 | 12 | import lib.utils as utils 13 | import numpy as np 14 | import torch 15 | from lib.utils import optimizer_factory 16 | 17 | from bm_sequential import get_dataset 18 | from ctfp_tools import build_augmented_model_tabular 19 | from ctfp_tools import run_ctfp_model as run_model, parse_arguments 20 | from train_misc import ( 21 | create_regularization_fns, 22 | get_regularization, 23 | append_regularization_to_log, 24 | ) 25 | from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time 26 | 27 | RUNNINGAVE_PARAM = 0.7 28 | torch.backends.cudnn.benchmark = True 29 | 30 | 31 | def save_model(args, aug_model, optimizer, epoch, itr, save_path): 32 | """ 33 | Save a checkpoint of the CTFP model during training 34 | 35 | Parameters: 36 | args: the arguments from parse_arguments in ctfp_tools 37 | aug_model: the CTFP Model 38 | optimizer: optimizer of CTFP model 39 | epoch: training epoch 40 | itr: training iteration 41 | save_path: path to save the model 42 | """ 43 | torch.save( 44 | { 45 | "args": args, 46 | "state_dict": aug_model.module.state_dict() 47 | if torch.cuda.is_available() and not args.use_cpu 48 | else aug_model.state_dict(), 49 | "optim_state_dict": optimizer.state_dict(), 50 | "last_epoch": epoch, 51 | "iter": itr, 52 | }, 53 | save_path, 54 | ) 55 | 56 | 57 | if __name__ == "__main__": 58 | args = parse_arguments() 59 | # logger 60 | utils.makedirs(args.save) 61 | logger = utils.get_logger( 62 | logpath=os.path.join(args.save, "logs"), filepath=os.path.abspath(__file__) 63 | ) 64 | 65 | if args.layer_type == "blend": 66 | logger.info( 67 | "!! Setting time_length from None to 1.0 due to use of Blend layers." 68 | ) 69 | args.time_length = 1.0 70 | logger.info(args) 71 | if not args.no_tb_log: 72 | from tensorboardX import SummaryWriter 73 | 74 | writer = SummaryWriter(osp.join(args.save, "tb_logs")) 75 | writer.add_text("args", str(args)) 76 | # get device 77 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 78 | if args.use_cpu: 79 | device = torch.device("cpu") 80 | cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) 81 | 82 | # load dataset 83 | train_loader, val_loader = get_dataset(args) 84 | 85 | # build model 86 | regularization_fns, regularization_coeffs = create_regularization_fns(args) 87 | 88 | aug_model = build_augmented_model_tabular( 89 | args, 90 | args.aug_size + args.effective_shape, 91 | regularization_fns=regularization_fns, 92 | ) 93 | 94 | set_cnf_options(args, aug_model) 95 | logger.info(aug_model) 96 | 97 | logger.info( 98 | "Number of trainable parameters: {}".format(count_parameters(aug_model)) 99 | ) 100 | 101 | # optimizer 102 | parameter_list = list(aug_model.parameters()) 103 | optimizer, num_params = optimizer_factory(args, parameter_list) 104 | print("Num of Parameters: %d" % num_params) 105 | 106 | # restore parameters 107 | itr = 0 108 | if args.resume is not None: 109 | checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage) 110 | aug_model.load_state_dict(checkpt["state_dict"]) 111 | if "optim_state_dict" in checkpt.keys(): 112 | optimizer.load_state_dict(checkpt["optim_state_dict"]) 113 | # Manually move optimizer state to device.
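# (torch.load above maps all storages to CPU via map_location, so tensors
# inside the restored optimizer state, such as Adam's moment buffers, must be
# cast back to the training device before the next optimizer.step() call.)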
114 | for state in optimizer.state.values(): 115 | for k, v in state.items(): 116 | if torch.is_tensor(v): 117 | state[k] = cvt(v) 118 | if "iter" in checkpt.keys(): 119 | itr = checkpt["iter"] 120 | if "last_epoch" in checkpt.keys(): 121 | args.begin_epoch = checkpt["last_epoch"] + 1 122 | 123 | if torch.cuda.is_available() and not args.use_cpu: 124 | aug_model = torch.nn.DataParallel(aug_model).cuda() 125 | 126 | # Running-average meters for logging training statistics. 127 | 128 | time_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM) 129 | loss_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM) 130 | steps_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM) 131 | grad_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM) 132 | tt_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM) 133 | 134 | best_loss = float("inf") 135 | for epoch in range(args.begin_epoch, args.num_epochs + 1): 136 | aug_model.train() 137 | for temp_idx, x in enumerate(train_loader): 138 | ## x is a tuple of (values, times, stdv, masks) 139 | start = time.time() 140 | optimizer.zero_grad() 141 | 142 | # cast data and move to device 143 | x = map(cvt, x) 144 | values, times, vars, masks = x 145 | # compute loss 146 | loss = run_model(args, aug_model, values, times, vars, masks) 147 | 148 | total_time = count_total_time(aug_model) 149 | ## Assume the base distribution is Brownian motion 150 | 151 | if regularization_coeffs: 152 | reg_states = get_regularization(aug_model, regularization_coeffs) 153 | reg_loss = sum( 154 | reg_state * coeff 155 | for reg_state, coeff in zip(reg_states, regularization_coeffs) 156 | if coeff != 0 157 | ) 158 | loss = loss + reg_loss 159 | 160 | loss.backward() 161 | grad_norm = torch.nn.utils.clip_grad_norm_( 162 | aug_model.parameters(), args.max_grad_norm 163 | ) 164 | optimizer.step() 165 | 166 | time_meter.update(time.time() - start) 167 | loss_meter.update(loss.item()) 168 | steps_meter.update(count_nfe(aug_model)) 169 | grad_meter.update(grad_norm) 170 | tt_meter.update(total_time) 171 | 172 | if not args.no_tb_log: 173 | writer.add_scalar("train/NLL", loss.cpu().data.item(), itr) 174 | 175 | if itr % args.log_freq == 0: 176 | log_message = ( 177 | "Iter {:04d} | Time {:.4f}({:.4f}) | Bit/dim {:.4f}({:.4f}) | " 178 | "Steps {:.0f}({:.2f}) | Grad Norm {:.4f}({:.4f}) | Total Time " 179 | "{:.2f}({:.2f})".format( 180 | itr, 181 | time_meter.val, 182 | time_meter.avg, 183 | loss_meter.val, 184 | loss_meter.avg, 185 | steps_meter.val, 186 | steps_meter.avg, 187 | grad_meter.val, 188 | grad_meter.avg, 189 | tt_meter.val, 190 | tt_meter.avg, 191 | ) 192 | ) 193 | if regularization_coeffs: 194 | log_message = append_regularization_to_log( 195 | log_message, regularization_fns, reg_states 196 | ) 197 | logger.info(log_message) 198 | 199 | itr += 1 200 | 201 | if epoch % args.val_freq == 0: 202 | with torch.no_grad(): 203 | start = time.time() 204 | logger.info("validating...") 205 | losses = [] 206 | num_observes = [] 207 | aug_model.eval() 208 | for temp_idx, x in enumerate(val_loader): 209 | ## x is a tuple of (values, times, stdv, masks) 210 | 211 | 212 | # cast data and move to device 213 | x = map(cvt, x) 214 | values, times, vars, masks = x 215 | # compute loss 216 | loss = run_model(args, aug_model, values, times, vars, masks) 217 | losses.append(loss.data.cpu().numpy()) 218 | num_observes.append(torch.sum(masks).data.cpu().numpy()) 219 | 220 | loss = np.sum(np.array(losses) * np.array(num_observes)) / np.sum( 221 | num_observes 222 | ) 223 | if not args.no_tb_log: 224 | writer.add_scalar("val/NLL",
loss, epoch) 225 | logger.info( 226 | "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}".format( 227 | epoch, time.time() - start, loss 228 | ) 229 | ) 230 | 231 | save_model( 232 | args, 233 | aug_model, 234 | optimizer, 235 | epoch, 236 | itr, 237 | os.path.join(args.save, "checkpt_last.pth"), 238 | ) 239 | save_model( 240 | args, 241 | aug_model, 242 | optimizer, 243 | epoch, 244 | itr, 245 | os.path.join(args.save, "checkpt_%d.pth" % epoch), 246 | ) 247 | 248 | if loss < best_loss: 249 | best_loss = loss 250 | save_model( 251 | args, 252 | aug_model, 253 | optimizer, 254 | epoch, 255 | itr, 256 | os.path.join(args.save, "checkpt_best.pth"), 257 | ) 258 | -------------------------------------------------------------------------------- /train_latent_ctfp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present Royal Bank of Canada 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import os.path as osp 9 | import time 10 | 11 | import lib.utils as utils 12 | import numpy as np 13 | import torch 14 | from lib.utils import optimizer_factory 15 | 16 | from bm_sequential import get_dataset 17 | from ctfp_tools import build_augmented_model_tabular 18 | from ctfp_tools import parse_arguments 19 | from ctfp_tools import run_latent_ctfp_model as run_model 20 | from ode_rnn_encoder import create_ode_rnn_encoder 21 | from train_misc import ( 22 | create_regularization_fns, 23 | get_regularization, 24 | append_regularization_to_log, 25 | ) 26 | from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time 27 | 28 | RUNNINGAVE_PARAM = 0.7 29 | torch.backends.cudnn.benchmark = True 30 | 31 | 32 | def save_model( 33 | args, aug_model, encoder, optimizer, epoch, itr, path, encoder_optimizer=None 34 | ): 35 | """ 36 | Save a checkpoint of the latent CTFP model during training 37 | 38 | Parameters: 39 | args: the arguments from parse_arguments in ctfp_tools 40 | aug_model: the CTFP Model as the decoder 41 | encoder: the ode-rnn model as the encoder 42 | optimizer: optimizer of the latent CTFP model, updating 43 | all the encoder and decoder parameters 44 | epoch: training epoch 45 | itr: training iteration 46 | path: path to save the model 47 | encoder_optimizer: the optimizer for the encoder, used in aggressive training 48 | """ 49 | torch.save( 50 | { 51 | "args": args, 52 | "state_dict": aug_model.module.state_dict() 53 | if torch.cuda.is_available() and not args.use_cpu 54 | else aug_model.state_dict(), 55 | "encoder_state_dict": encoder.module.state_dict() 56 | if torch.cuda.is_available() and not args.use_cpu 57 | else encoder.state_dict(), 58 | "optim_state_dict": optimizer.state_dict(), 59 | "enc_optim_state_dict": encoder_optimizer.state_dict() 60 | if args.aggressive 61 | else None, 62 | "last_epoch": epoch, 63 | "iter": itr, 64 | }, 65 | path, 66 | ) 67 | 68 | 69 | if __name__ == "__main__": 70 | args = parse_arguments() 71 | # logger 72 | utils.makedirs(args.save) 73 | logger = utils.get_logger( 74 | logpath=os.path.join(args.save, "logs"), filepath=os.path.abspath(__file__) 75 | ) 76 | 77 | if args.layer_type == "blend": 78 | logger.info( 79 | "!! Setting time_length from None to 1.0 due to use of Blend layers."
80 | ) 81 | args.time_length = 1.0 82 | 83 | logger.info(args) 84 | if not args.no_tb_log: 85 | from tensorboardX import SummaryWriter 86 | 87 | writer = SummaryWriter(osp.join(args.save, "tb_logs")) 88 | writer.add_text("args", str(args)) 89 | # get device 90 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 91 | if args.use_cpu: 92 | device = torch.device("cpu") 93 | cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) 94 | 95 | # load dataset 96 | train_loader, val_loader = get_dataset(args) 97 | 98 | # build model 99 | # Build the encoder 100 | if args.encoder == "ode_rnn": 101 | encoder = create_ode_rnn_encoder(args, device) 102 | else: 103 | raise NotImplementedError 104 | regularization_fns, regularization_coeffs = create_regularization_fns(args) 105 | 106 | aug_model = build_augmented_model_tabular( 107 | args, 108 | args.aug_size + args.effective_shape + args.latent_size, 109 | regularization_fns=regularization_fns, 110 | ) 111 | 112 | set_cnf_options(args, aug_model) 113 | logger.info(aug_model) 114 | logger.info( 115 | "Number of trainable parameters: {}".format(count_parameters(aug_model)) 116 | ) 117 | 118 | # optimizer 119 | parameter_list = list(aug_model.parameters()) + list(encoder.parameters()) 120 | optimizer, num_params = optimizer_factory(args, parameter_list) 121 | 122 | if args.aggressive: 123 | encoder_optimizer, enc_num_params = optimizer_factory( 124 | args, list(encoder.parameters()) 125 | ) 126 | else: 127 | encoder_optimizer = None 128 | enc_num_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad) 129 | print( 130 | "Total Num of Parameters: %d, Encoder Num of Parameters: %d" 131 | % (num_params, enc_num_params) 132 | ) 133 | 134 | # restore parameters 135 | itr = 0 136 | if args.resume is not None: 137 | checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage) 138 | aug_model.load_state_dict(checkpt["state_dict"]) 139 | encoder.load_state_dict(checkpt["encoder_state_dict"]) 140 | if "optim_state_dict" in checkpt.keys(): 141 | optimizer.load_state_dict(checkpt["optim_state_dict"]) 142 | # Manually move optimizer state to device. 143 | for state in optimizer.state.values(): 144 | for k, v in state.items(): 145 | if torch.is_tensor(v): 146 | state[k] = cvt(v) 147 | if ( 148 | args.aggressive 149 | and "enc_optim_state_dict" in checkpt.keys() 150 | and checkpt["enc_optim_state_dict"] is not None 151 | ): 152 | encoder_optimizer.load_state_dict(checkpt["enc_optim_state_dict"]) 153 | for state in encoder_optimizer.state.values(): 154 | for k, v in state.items(): 155 | if torch.is_tensor(v): 156 | state[k] = cvt(v) 157 | 158 | if "iter" in checkpt.keys(): 159 | itr = checkpt["iter"] 160 | if "last_epoch" in checkpt.keys(): 161 | args.begin_epoch = checkpt["last_epoch"] + 1 162 | 163 | if torch.cuda.is_available() and not args.use_cpu: 164 | aug_model = torch.nn.DataParallel(aug_model).cuda() 165 | encoder = torch.nn.DataParallel(encoder).cuda() 166 | 167 | # Running-average meters for logging training statistics.
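# Aggressive-training note: in the loop below, when args.aggressive is set,
# encoder_optimizer is stepped on most iterations while the joint optimizer
# (covering encoder and decoder) is stepped only when itr % args.decoder_frequency
# == 0, so the encoder (inference network) is updated more often than the decoder.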
168 | time_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM) 169 | loss_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM) 170 | steps_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM) 171 | grad_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM) 172 | tt_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM) 173 | 174 | best_loss = float("inf") 175 | for epoch in range(args.begin_epoch, args.num_epochs + 1): 176 | aug_model.train() 177 | encoder.train() 178 | for temp_idx, x in enumerate(train_loader): 179 | ## x is a tuple of (values, times, stdv, masks) 180 | start = time.time() 181 | optimizer.zero_grad() 182 | aug_model.zero_grad() 183 | encoder.zero_grad() 184 | 185 | # cast data and move to device 186 | x = map(cvt, x) 187 | values, times, vars, masks = x 188 | loss, loss_training = run_model( 189 | args, encoder, aug_model, values, times, vars, masks 190 | ) 191 | total_time = count_total_time(aug_model) 192 | 193 | ## Assume the base distribution is Brownian motion 194 | 195 | if regularization_coeffs: 196 | reg_states = get_regularization(aug_model, regularization_coeffs) 197 | reg_loss = sum( 198 | reg_state * coeff 199 | for reg_state, coeff in zip(reg_states, regularization_coeffs) 200 | if coeff != 0 201 | ) 202 | loss_training = loss_training + reg_loss 203 | 204 | loss_training.backward() 205 | 206 | grad_norm = torch.nn.utils.clip_grad_norm_( 207 | list(aug_model.parameters()) + list(encoder.parameters()), 208 | args.max_grad_norm, 209 | ) 210 | 211 | if args.aggressive and (itr == 0 or itr % args.decoder_frequency != 0): 212 | encoder_optimizer.step() 213 | else: 214 | optimizer.step() 215 | 216 | time_meter.update(time.time() - start) 217 | loss_meter.update(loss.item()) 218 | steps_meter.update(count_nfe(aug_model)) 219 | grad_meter.update(grad_norm) 220 | tt_meter.update(total_time) 221 | if not args.no_tb_log: 222 | writer.add_scalar("train/NLL", loss.cpu().data.item(), itr) 223 | 224 | if itr % args.log_freq == 0: 225 | log_message = ( 226 | "Iter {:04d} | Time {:.4f}({:.4f}) | Bit/dim {:.4f}({:.4f}) | " 227 | "Steps {:.0f}({:.2f}) | Grad Norm {:.4f}({:.4f}) | Total Time " 228 | "{:.2f}({:.2f})".format( 229 | itr, 230 | time_meter.val, 231 | time_meter.avg, 232 | loss_meter.val, 233 | loss_meter.avg, 234 | steps_meter.val, 235 | steps_meter.avg, 236 | grad_meter.val, 237 | grad_meter.avg, 238 | tt_meter.val, 239 | tt_meter.avg, 240 | ) 241 | ) 242 | if regularization_coeffs: 243 | log_message = append_regularization_to_log( 244 | log_message, regularization_fns, reg_states 245 | ) 246 | logger.info(log_message) 247 | 248 | itr += 1 249 | 250 | # compute validation loss 251 | if epoch % args.val_freq == 0: 252 | aug_model.eval() 253 | encoder.eval() 254 | with torch.no_grad(): 255 | start = time.time() 256 | logger.info("validating...") 257 | losses = [] 258 | num_observes = [] 259 | for temp_idx, x in enumerate(val_loader): 260 | ## x is a tuple of (values, times, stdv, masks) 261 | x = map(cvt, x) 262 | values, times, vars, masks = x 263 | loss = run_model( 264 | args, 265 | encoder, 266 | aug_model, 267 | values, 268 | times, 269 | vars, 270 | masks, 271 | evaluation=True, 272 | ) 273 | losses.append(loss.data.cpu().numpy()) 274 | num_observes.append(torch.sum(masks).data.cpu().numpy()) 275 | 276 | loss = np.sum(np.array(losses) * np.array(num_observes)) / np.sum( 277 | num_observes 278 | ) 279 | if not args.no_tb_log: 280 | writer.add_scalar("val/NLL", loss, epoch) 281 | logger.info( 282 | "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}".format( 283 | epoch,
time.time() - start, loss 284 | ) 285 | ) 286 | 287 | save_model( 288 | args, 289 | aug_model, 290 | encoder, 291 | optimizer, 292 | epoch, 293 | itr, 294 | os.path.join(args.save, "checkpt_last.pth"), 295 | encoder_optimizer, 296 | ) 297 | save_model( 298 | args, 299 | aug_model, 300 | encoder, 301 | optimizer, 302 | epoch, 303 | itr, 304 | os.path.join(args.save, "checkpt_%d.pth" % epoch), 305 | encoder_optimizer, 306 | ) 307 | 308 | if loss < best_loss: 309 | best_loss = loss 310 | save_model( 311 | args, 312 | aug_model, 313 | encoder, 314 | optimizer, 315 | epoch, 316 | itr, 317 | os.path.join(args.save, "checkpt_best.pth"), 318 | encoder_optimizer, 319 | ) 320 | -------------------------------------------------------------------------------- /train_misc.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # Link: https://github.com/rtqichen/ffjord 23 | 24 | import math 25 | 26 | import lib.layers as layers 27 | import lib.layers.wrappers.cnf_regularization as reg_lib 28 | import lib.spectral_norm as spectral_norm 29 | import six 30 | from lib.layers.odefunc import divergence_bf, divergence_approx 31 | 32 | 33 | def standard_normal_logprob(z): 34 | logZ = -0.5 * math.log(2 * math.pi) 35 | return logZ - z.pow(2) / 2 36 | 37 | 38 | def set_cnf_options(args, model): 39 | def _set(module): 40 | if isinstance(module, layers.CNF): 41 | # Set training settings 42 | module.solver = args.solver 43 | module.atol = args.atol 44 | module.rtol = args.rtol 45 | if args.step_size is not None: 46 | module.solver_options["step_size"] = args.step_size 47 | 48 | # If using fixed-grid adams, restrict order to not be too high.
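# (Fixed-grid multistep Adams methods of high order tend to have poor
# stability, hence the cap at order 4 below.)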
49 | if args.solver in ["fixed_adams", "explicit_adams"]: 50 | module.solver_options["max_order"] = 4 51 | 52 | # Set the test settings 53 | module.test_solver = args.test_solver if args.test_solver else args.solver 54 | module.test_atol = args.test_atol if args.test_atol else args.atol 55 | module.test_rtol = args.test_rtol if args.test_rtol else args.rtol 56 | 57 | if isinstance(module, layers.ODEfunc): 58 | module.rademacher = args.rademacher 59 | module.residual = args.residual 60 | 61 | model.apply(_set) 62 | 63 | 64 | def override_divergence_fn(model, divergence_fn): 65 | def _set(module): 66 | if isinstance(module, layers.ODEfunc): 67 | if divergence_fn == "brute_force": 68 | module.divergence_fn = divergence_bf 69 | elif divergence_fn == "approximate": 70 | module.divergence_fn = divergence_approx 71 | 72 | model.apply(_set) 73 | 74 | 75 | def count_nfe(model): 76 | class AccNumEvals(object): 77 | def __init__(self): 78 | self.num_evals = 0 79 | 80 | def __call__(self, module): 81 | if isinstance(module, layers.ODEfunc): 82 | self.num_evals += module.num_evals() 83 | 84 | accumulator = AccNumEvals() 85 | model.apply(accumulator) 86 | return accumulator.num_evals 87 | 88 | 89 | def count_parameters(model): 90 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 91 | 92 | 93 | def count_total_time(model): 94 | class Accumulator(object): 95 | def __init__(self): 96 | self.total_time = 0 97 | 98 | def __call__(self, module): 99 | if isinstance(module, layers.CNF): 100 | self.total_time = ( 101 | self.total_time + module.sqrt_end_time * module.sqrt_end_time 102 | ) 103 | 104 | accumulator = Accumulator() 105 | model.apply(accumulator) 106 | return accumulator.total_time 107 | 108 | 109 | def add_spectral_norm(model, logger=None): 110 | """Applies spectral norm to all modules within the scope of a CNF.""" 111 | 112 | def apply_spectral_norm(module): 113 | if "weight" in module._parameters: 114 | if logger: 115 | logger.info("Adding spectral norm to {}".format(module)) 116 | spectral_norm.inplace_spectral_norm(module, "weight") 117 | 118 | def find_cnf(module): 119 | if isinstance(module, layers.CNF): 120 | module.apply(apply_spectral_norm) 121 | else: 122 | for child in module.children(): 123 | find_cnf(child) 124 | 125 | find_cnf(model) 126 | 127 | 128 | def spectral_norm_power_iteration(model, n_power_iterations=1): 129 | def recursive_power_iteration(module): 130 | if hasattr(module, spectral_norm.POWER_ITERATION_FN): 131 | getattr(module, spectral_norm.POWER_ITERATION_FN)(n_power_iterations) 132 | 133 | model.apply(recursive_power_iteration) 134 | 135 | 136 | REGULARIZATION_FNS = { 137 | "l1int": reg_lib.l1_regularzation_fn, 138 | "l2int": reg_lib.l2_regularzation_fn, 139 | "dl2int": reg_lib.directional_l2_regularization_fn, 140 | "JFrobint": reg_lib.jacobian_frobenius_regularization_fn, 141 | "JdiagFrobint": reg_lib.jacobian_diag_frobenius_regularization_fn, 142 | "JoffdiagFrobint": reg_lib.jacobian_offdiag_frobenius_regularization_fn, 143 | } 144 | 145 | INV_REGULARIZATION_FNS = {v: k for k, v in six.iteritems(REGULARIZATION_FNS)} 146 | 147 | 148 | def append_regularization_to_log(log_message, regularization_fns, reg_states): 149 | for i, reg_fn in enumerate(regularization_fns): 150 | log_message = ( 151 | log_message 152 | + " | " 153 | + INV_REGULARIZATION_FNS[reg_fn] 154 | + ": {:.8f}".format(reg_states[i].item()) 155 | ) 156 | return log_message 157 | 158 | 159 | def create_regularization_fns(args): 160 | regularization_fns = [] 161 | regularization_coeffs = 
[] 162 | 163 | for arg_key, reg_fn in six.iteritems(REGULARIZATION_FNS): 164 | if getattr(args, arg_key) is not None: 165 | regularization_fns.append(reg_fn) 166 | regularization_coeffs.append(getattr(args, arg_key)) 167 | 168 | regularization_fns = tuple(regularization_fns) 169 | regularization_coeffs = tuple(regularization_coeffs) 170 | return regularization_fns, regularization_coeffs 171 | 172 | 173 | def get_regularization(model, regularization_coeffs): 174 | if len(regularization_coeffs) == 0: 175 | return None 176 | 177 | acc_reg_states = tuple([0.0] * len(regularization_coeffs)) 178 | for module in model.modules(): 179 | if isinstance(module, layers.CNF): 180 | acc_reg_states = tuple( 181 | acc + reg 182 | for acc, reg in zip(acc_reg_states, module.get_regularization_states()) 183 | ) 184 | return acc_reg_states 185 | 186 | 187 | def build_model_tabular(args, dims, regularization_fns=None): 188 | hidden_dims = tuple(map(int, args.dims.split("-"))) 189 | 190 | def build_cnf(): 191 | diffeq = layers.ODEnet( 192 | hidden_dims=hidden_dims, 193 | input_shape=(dims,), 194 | strides=None, 195 | conv=False, 196 | layer_type=args.layer_type, 197 | nonlinearity=args.nonlinearity, 198 | ) 199 | odefunc = layers.ODEfunc( 200 | diffeq=diffeq, 201 | divergence_fn=args.divergence_fn, 202 | residual=args.residual, 203 | rademacher=args.rademacher, 204 | ) 205 | cnf = layers.CNF( 206 | odefunc=odefunc, 207 | T=args.time_length, 208 | train_T=args.train_T, 209 | regularization_fns=regularization_fns, 210 | solver=args.solver, 211 | ) 212 | return cnf 213 | 214 | chain = [build_cnf() for _ in range(args.num_blocks)] 215 | if args.batch_norm: 216 | bn_layers = [ 217 | layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag) 218 | for _ in range(args.num_blocks) 219 | ] 220 | bn_chain = [layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)] 221 | for a, b in zip(chain, bn_layers): 222 | bn_chain.append(a) 223 | bn_chain.append(b) 224 | chain = bn_chain 225 | model = layers.SequentialFlow(chain) 226 | 227 | set_cnf_options(args, model) 228 | 229 | return model 230 | --------------------------------------------------------------------------------
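To make these helpers concrete, the sketch below drives build_model_tabular directly on toy 2-d data. It assumes this repository is on PYTHONPATH; all hyperparameter values are illustrative placeholders rather than the settings used in the CTFP experiments.

from argparse import Namespace

import torch

from train_misc import build_model_tabular, standard_normal_logprob

# Illustrative hyperparameters; every field below is read by build_model_tabular
# or by set_cnf_options, which it calls internally.
args = Namespace(
    dims="32-32", layer_type="concat", nonlinearity="tanh",
    divergence_fn="approximate", residual=False, rademacher=False,
    time_length=1.0, train_T=False, solver="dopri5",
    atol=1e-5, rtol=1e-5, step_size=None,
    test_solver=None, test_atol=None, test_rtol=None,
    num_blocks=1, batch_norm=False, bn_lag=0.0,
)

model = build_model_tabular(args, dims=2)

# The flow maps data x to base samples z while accumulating the change in
# log-density; the log-likelihood is the standard-normal log-prob of z minus
# that accumulated term.
x = torch.randn(16, 2)
zero = torch.zeros(x.shape[0], 1)
z, delta_logp = model(x, zero)
logpx = standard_normal_logprob(z).sum(1, keepdim=True) - delta_logp
print("mean log-likelihood:", logpx.mean().item())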