├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── config.demo.yaml
├── config.yaml
├── densities.py
├── flows.py
├── images
│   ├── demo.png
│   ├── discrete-c-concave.png
│   ├── fig3.png
│   └── table2.png
├── main.py
├── manifolds.py
├── plot-components.py
├── plot-demo.py
├── requirements.txt
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | exp
2 | exp_local
3 | t
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. 
Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 
14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. 
A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. 
Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. 
The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. 
Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. 
indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 
291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. 
This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. 
To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 
406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | 409 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Riemannian Convex Potential Maps 2 | 3 | This repository is by 4 | [Brandon Amos](http://bamos.github.io), 5 | [Samuel Cohen](https://samcohen16.github.io/) 6 | and 7 | [Yaron Lipman](http://www.wisdom.weizmann.ac.il/~ylipman/) 8 | and contains the [JAX](https://jax.readthedocs.io/en/latest/) 9 | source code to reproduce the 10 | experiments in our ICML 2021 paper on 11 | [Riemannian Convex Potential Maps](https://arxiv.org/abs/2106.10272). 12 | 13 | 14 | > Modeling distributions on Riemannian manifolds is a crucial 15 | > component in understanding non-Euclidean data that arises, e.g., in 16 | > physics and geology. The budding approaches in this space are 17 | > limited by representational and computational tradeoffs. We propose 18 | > and study a class of flows that uses convex potentials from 19 | > Riemannian optimal transport. These are universal and can model 20 | > distributions on any compact Riemannian manifold without requiring 21 | > domain knowledge of the manifold to be integrated into the 22 | > architecture. We demonstrate that these flows can model standard 23 | > distributions on spheres, and tori, on synthetic and geological 24 | > data. 25 | 26 | ![](images/demo.png) 27 | 28 | 29 | # Reproducing our experiments 30 | 31 | [config.yaml](config.yaml) contains the basic config for 32 | setting up our experiments. 33 | We currently use hydra 1.0.3. 
34 | By default it contains the options to 35 | reproduce the multimodal sphere flow: 36 | 37 | 38 | 39 | 40 | This can be run with: 41 | 42 | ``` 43 | $ ./main.py 44 | workspace: /private/home/bda/repos/rcpm/exp_local/2021.06.21/053411 45 | Iter 1000 | Loss -10.906 | KL 0.017 | ESS 96.74% | 9.54e-02s/it 46 | Iter 2000 | Loss -10.908 | KL 0.013 | ESS 97.43% | 1.90e-02s/it 47 | Iter 3000 | Loss -10.911 | KL 0.012 | ESS 97.71% | 1.75e-02s/it 48 | Iter 4000 | Loss -10.912 | KL 0.010 | ESS 98.02% | 1.63e-02s/it 49 | Iter 5000 | Loss -10.912 | KL 0.009 | ESS 98.19% | 1.46e-02s/it 50 | ... 51 | Iter 30000 | Loss -10.915 | KL 0.006 | ESS 98.75% | 1.78e-02s/it 52 | ``` 53 | 54 | This will create a work directory in `exp_local` with 55 | the models and debugging information. 56 | You can use 57 | [plot-components.py](plot-components.py) 58 | to further analyze the components of the learned flow, 59 | and 60 | [plot-demo.py](plot-demo.py) 61 | to produce the grid visualization from Figure 2 62 | of our paper. 63 | 64 | # Training with the likelihood 65 | 66 | This can be done for the checkerboard dataset with: 67 | 68 | ``` 69 | $ ./main.py loss=likelihood base=SphereUniform target=SphereCheckerboard 70 | ``` 71 | 72 | # Other JAX sphere flow library 73 | [katalinic/sdflows](https://github.com/katalinic/sdflows) 74 | provides a great JAX re-implementation of 75 | [Normalizing Flows on Tori and Spheres](https://arxiv.org/abs/2002.02428). 
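The `key=value` arguments in the commands above replace entries in [config.yaml](config.yaml). As a rough illustration of that behaviour, the following plain-Python sketch applies dotted overrides to a small subset of the defaults. `DEFAULTS` and `apply_overrides` are hypothetical names introduced here; this is not Hydra's implementation and not code from this repository.

```python
# Illustrative sketch only: mimics the effect of key=value overrides on a
# nested config. It is not how Hydra is implemented.
import copy

# A small subset of the defaults from config.yaml.
DEFAULTS = {
    "loss": "kl",
    "base": "SphereUniform",
    "target": "RezendeSphereFourMode",
    "optim": {"learning_rate": 1e-3, "beta1": 0.9, "beta2": 0.999},
}


def apply_overrides(cfg, overrides):
    """Return a copy of cfg with dotted key=value overrides applied (hypothetical helper)."""
    out = copy.deepcopy(cfg)
    for item in overrides:
        key, value = item.split("=", 1)
        *parents, leaf = key.split(".")
        node = out
        for name in parents:
            node = node[name]  # KeyError if the section does not exist
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        # Keep the type of the default value (e.g. float for learning_rate).
        node[leaf] = type(node[leaf])(value)
    return out


cfg = apply_overrides(
    DEFAULTS,
    ["loss=likelihood", "target=SphereCheckerboard", "optim.learning_rate=1e-4"],
)
print(cfg["loss"], cfg["base"], cfg["target"], cfg["optim"]["learning_rate"])
```

Keys that are not overridden, such as `base` here, keep their default values, which is why the likelihood command can omit any option it does not change.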
76 | 77 | # Citations 78 | If you find this repository helpful for your publications, 79 | please consider citing our paper: 80 | 81 | ``` 82 | @inproceedings{cohen2021riemannian, 83 | title = {{Riemannian Convex Potential Maps}}, 84 | author = {Cohen, Samuel and Amos, Brandon and Lipman, Yaron}, 85 | booktitle = {Proceedings of the 38th International Conference on Machine Learning}, 86 | pages = {2028--2038}, 87 | year = {2021}, 88 | editor = {Meila, Marina and Zhang, Tong}, 89 | volume = {139}, 90 | series = {Proceedings of Machine Learning Research}, 91 | month = {18--24 Jul}, 92 | publisher = {PMLR}, 93 | pdf = {http://proceedings.mlr.press/v139/cohen21a/cohen21a.pdf}, 94 | url = {https://proceedings.mlr.press/v139/cohen21a.html} 95 | } 96 | ``` 97 | 98 | # Licensing 99 | This repository is licensed under the 100 | [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/). 101 | -------------------------------------------------------------------------------- /config.demo.yaml: -------------------------------------------------------------------------------- 1 | # defaults: 2 | # - hydra/sweeper: nevergrad 3 | # - hydra/launcher: submitit_slurm 4 | 5 | seed: 0 6 | exp: t 7 | batch_size: 256 8 | iterations: 10000 9 | eval_samples: 20000 10 | log_frequency: 1000 11 | disable_init_plots: True 12 | disable_evol_plots: True 13 | 14 | sphere: 15 | _target_: manifolds.Sphere 16 | D: 3 17 | jitter: 1e-2 18 | 19 | manifold: ${sphere} 20 | loss: likelihood 21 | base: SphereBaseWrappedNormal 22 | target: SphereDemo 23 | 24 | optim: 25 | _target_: flax.optim.Adam 26 | learning_rate: 1e-3 27 | beta1: 0.9 28 | beta2: 0.999 29 | 30 | flow: 31 | _target_: flows.SequentialFlow 32 | n_transforms: 1 33 | single_transform_cfg: 34 | _target_: flows.ExpMapFlow 35 | potential_cfg: ${infaff_potential} 36 | 37 | 38 | infaff_potential: 39 | _target_: flows.InfAffine 40 | n_components: 68 41 | init_alpha_mode: uniform 42 | init_alpha_linear_scale: 1. 
43 | init_alpha_minval: 0.4 44 | init_alpha_range: 0.01 45 | cost_gamma: 0.1 46 | min_zero_gamma: null 47 | 48 | hydra: 49 | run: 50 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S} 51 | # sweep: 52 | # dir: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${exp} 53 | # subdir: ${hydra.job.num} 54 | # launcher: 55 | # max_num_timeout: 100000 56 | # timeout_min: 4319 57 | # partition: learnfair 58 | # mem_gb: 64 59 | # gpus_per_node: 1 60 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # defaults: 2 | # - hydra/sweeper: nevergrad 3 | # - hydra/launcher: submitit_slurm 4 | 5 | seed: 0 6 | exp: t 7 | batch_size: 256 8 | iterations: 1e6 9 | eval_samples: 20000 10 | log_frequency: 1000 11 | disable_init_plots: True 12 | disable_evol_plots: True 13 | 14 | sphere: 15 | _target_: manifolds.Sphere 16 | D: 3 17 | jitter: 1e-2 18 | 19 | product: 20 | _target_: manifolds.Product 21 | manifolds_str: S1,S1 22 | D: 4 23 | 24 | torus: 25 | _target_: manifolds.Torus 26 | D: 4 27 | 28 | manifold: ${sphere} 29 | loss: kl 30 | # loss: likelihood 31 | base: SphereUniform 32 | # base: SphereBaseWrappedNormal 33 | target: RezendeSphereFourMode 34 | # target: LouSphereSingleMode 35 | # target: LouSphereFourModes 36 | # target: SphereCheckerboard 37 | 38 | # manifold: ${torus} 39 | # loss: likelihood 40 | # base: ProductUniformComponents 41 | # target: RezendeTorusUnimodal 42 | 43 | optim: 44 | _target_: flax.optim.Adam 45 | learning_rate: 1e-3 46 | beta1: 0.9 47 | beta2: 0.999 48 | 49 | flow: 50 | _target_: flows.SequentialFlow 51 | n_transforms: 5 52 | single_transform_cfg: 53 | _target_: flows.ExpMapFlow 54 | # potential_cfg: ${radial_potential} 55 | potential_cfg: ${infaff_potential} 56 | 57 | 58 | infaff_potential: 59 | _target_: flows.InfAffine 60 | n_components: 68 61 | init_alpha_mode: uniform 62 | init_alpha_linear_scale: 1. 
63 | init_alpha_minval: 0.4 64 | init_alpha_range: 0.01 65 | cost_gamma: 0.1 66 | min_zero_gamma: null 67 | 68 | 69 | multi_infaff_potential: 70 | _target_: flows.MultiInfAffine 71 | n_components: 200 72 | init_alpha_minval: 0.4 73 | init_alpha_range: 1 74 | n_layers: 3 75 | cost_gamma: 0.05 76 | min_zero_gamma: 0.05 77 | 78 | 79 | radial_potential: 80 | _target_: flows.RadialPotential 81 | n_radial_components: 12 82 | init_beta_minval: 1. 83 | init_beta_range: 2. 84 | 85 | 86 | hydra: 87 | run: 88 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S} 89 | # sweep: 90 | # dir: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${exp} 91 | # subdir: ${hydra.job.num} 92 | # launcher: 93 | # max_num_timeout: 100000 94 | # timeout_min: 4319 95 | # partition: priority 96 | # comment: ICML 97 | # mem_gb: 64 98 | # gpus_per_node: 1 99 | # sweeper: 100 | # optim: 101 | # optimizer: RandomSearch 102 | # budget: 200 103 | # num_workers: 200 104 | # parametrization: 105 | # optim.learning_rate: 106 | # lower: 5e-6 107 | # upper: 1e-1 108 | # log: True 109 | # step: 10 110 | # optim.beta1: [0.1, 0.3, 0.5, 0.7, 0.9] 111 | # optim.beta2: [0.1, 0.3, 0.5, 0.7, 0.9, 0.99, 0.999] 112 | # # infaff_potential.n_components: 113 | # # lower: 50 114 | # # upper: 1000 115 | # # integer: True 116 | # # log: True 117 | # # step: 10 118 | # # infaff_potential.init_alpha_minval: 119 | # # lower: 1e-5 120 | # # upper: 10. 121 | # # log: True 122 | # # step: 10 123 | # # infaff_potential.init_alpha_range: 124 | # # lower: 1e-3 125 | # # upper: 1. 126 | # # log: True 127 | # # infaff_potential.cost_gamma: [0.01, 0.05, 0.1, 0.5] 128 | # # infaff_potential.min_zero_gamma: [null, 0.01, 0.05, 0.1, 0.5] 129 | 130 | # multi_infaff_potential.n_components: 131 | # lower: 50 132 | # upper: 1000 133 | # integer: True 134 | # log: True 135 | # step: 10 136 | # multi_infaff_potential.init_alpha_minval: 137 | # lower: 1e-5 138 | # upper: 10. 
139 | # log: True 140 | # step: 10 141 | # multi_infaff_potential.init_alpha_range: 142 | # lower: 1e-3 143 | # upper: 1. 144 | # log: True 145 | # step: 10 146 | # multi_infaff_potential.cost_gamma: [0.01, 0.05, 0.1, 0.5] 147 | # multi_infaff_potential.min_zero_gamma: [null, 0.01, 0.05, 0.1, 0.5] 148 | # multi_infaff_potential.n_layers: 149 | # lower: 1 150 | # upper: 5 151 | # integer: True 152 | # step: 1 153 | -------------------------------------------------------------------------------- /densities.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import random 6 | from jax.scipy.stats import norm 7 | 8 | import numpy as np 9 | 10 | from functools import partial 11 | import sys 12 | 13 | from dataclasses import dataclass 14 | from abc import ABC, abstractmethod 15 | from manifolds import Manifold, Sphere 16 | 17 | import utils 18 | import pickle 19 | 20 | from scipy.stats import gaussian_kde 21 | from jax.numpy import newaxis 22 | 23 | def get(manifold, name): 24 | if name == 'SphereBaseWrappedNormal': 25 | assert isinstance(manifold, Sphere) 26 | loc = manifold.zero() 27 | scale = jnp.full(manifold.D-1, .3) 28 | return WrappedNormal(manifold=manifold, loc=loc, scale=scale) 29 | elif name == 'LouSphereSingleMode': 30 | assert isinstance(manifold, Sphere) 31 | loc = manifold.projx(-jnp.ones(manifold.D)) 32 | scale = jnp.full(manifold.D-1, .3) 33 | return WrappedNormal(manifold=manifold, loc=loc, scale=scale) 34 | elif 'Earth' in name: 35 | try: 36 | name, year = name.split('_') 37 | return getattr(sys.modules[__name__], name)(manifold=manifold, year = year) 38 | except: 39 | print(f"Error loading data class {name}") 40 | raise 41 | else: 42 | try: 43 | return getattr(sys.modules[__name__], name)(manifold=manifold) 44 | except: 45 | print(f"Error loading data class {name}") 46 | raise 47 | 48 | def 
get_uniform(manifold): 49 | if isinstance(manifold, Sphere): 50 | return get(manifold, 'SphereUniform') 51 | else: 52 | assert False 53 | 54 | @dataclass 55 | class Density(ABC): 56 | manifold: Manifold 57 | 58 | @abstractmethod 59 | def log_prob(self, x): 60 | pass 61 | 62 | @abstractmethod 63 | def sample(self, key, n_samples): 64 | pass 65 | 66 | def __hash__(self): return 0 # For jitting 67 | 68 | 69 | @dataclass 70 | class Earth(Density): 71 | year: int 72 | def __post_init__(self): 73 | self.data = pickle.load(open('../../../data/earth_data_sphere_' + self.year + '.pkl','rb')) 74 | self.data = self.data[self.data[:,-1]>-0.8] 75 | self.data = jnp.array(self.data)[::5] 76 | self.data = jax.ops.index_update(self.data, jax.ops.index[:, -1], -self.data[:, -1]) 77 | self.dens = dens = gaussian_kde(self.data.T, 0.1) 78 | L = jnp.linalg.cholesky(self.dens.covariance*2*jnp.pi) 79 | self.log_det = 2*jnp.log(jnp.diag(L)).sum() 80 | self.inv_cov = jnp.array(self.dens.inv_cov) 81 | self.weights = jnp.array(self.dens.weights) 82 | 83 | #Can't pickle kde object 84 | self.dens = 0. 
85 | def log_prob(self,xs): 86 | def fun(point): 87 | diff = self.data.T - point 88 | tdiff = jnp.dot(self.inv_cov, diff) 89 | energy = jnp.sum(diff * tdiff, axis=0) 90 | log_to_sum = 2.0 * jnp.log(self.weights) - self.log_det - energy 91 | result = jax.scipy.special.logsumexp(0.5 * log_to_sum) 92 | return result 93 | fun_map = jax.vmap(fun) 94 | return fun_map(xs[:,:,newaxis]) 95 | 96 | def sample(self, key, n_samples): 97 | key, k1 = jax.random.split(key, 2) 98 | indexes = jax.random.randint(k1, [n_samples], 0, self.data.shape[0]-1) 99 | return self.data[indexes] 100 | 101 | 102 | 103 | 104 | class SphereUniform(Density): 105 | def log_prob(self, xs): 106 | # TODO, support other spheres 107 | assert xs.ndim == 2 108 | n_batch, D = xs.shape 109 | assert D == self.manifold.D 110 | 111 | if self.manifold.D == 2: 112 | SA = 2.*jnp.pi 113 | elif self.manifold.D == 3: 114 | SA = 4.*jnp.pi 115 | else: 116 | raise NotImplementedError() 117 | 118 | return jnp.full([n_batch], jnp.log(1. / SA)) 119 | 120 | def sample(self, key, n_samples): 121 | xs = random.normal(key, shape=[n_samples, self.manifold.D]) 122 | return self.manifold.projx(xs) 123 | 124 | 125 | @dataclass 126 | class WrappedNormal(Density): 127 | loc: jnp.ndarray 128 | scale: jnp.ndarray 129 | 130 | def log_prob(self, z): 131 | u = self.manifold.log(self.loc, z) 132 | y = self.manifold.zero_like(self.loc) 133 | v = self.manifold.transp(self.loc, y, u) 134 | v = self.manifold.squeeze_tangent(v) 135 | n_logprob = norm.logpdf(v, scale=self.scale).sum(axis=-1) 136 | logdet = self.manifold.logdetexp(self.loc, u) 137 | assert n_logprob.shape == logdet.shape 138 | log_prob = n_logprob - logdet 139 | return log_prob 140 | 141 | def sample(self, key, n_samples): 142 | v = self.scale * random.normal(key, [n_samples, self.manifold.D-1]) 143 | v = self.manifold.unsqueeze_tangent(v) 144 | x = self.manifold.zero_like(self.loc) 145 | u = self.manifold.transp(x, self.loc, v) 146 | z = self.manifold.exponential_map(self.loc, 
u) 147 | return z 148 | 149 | def __hash__(self): return 0 # For jitting 150 | 151 | @dataclass 152 | class SphereDemo(Density): 153 | def __post_init__(self): 154 | self.modes = [] 155 | locs = [ 156 | jnp.array([0.3, 1., 1.]), 157 | jnp.array([0.3, -1., 1.]), 158 | jnp.array([0.3, 1., -1.]), 159 | jnp.array([0.3, -1., -1.]), 160 | ] 161 | locs = [self.manifold.projx(loc) for loc in locs] 162 | scale = jnp.full(self.manifold.D-1, .3) 163 | self.dists = [ 164 | WrappedNormal(manifold=self.manifold, loc=loc, scale=scale) 165 | for loc in locs 166 | ] 167 | 168 | def log_prob(self, z): 169 | raise NotImplementedError() 170 | 171 | def sample(self, key, n_samples): 172 | keys = random.split(key, len(self.dists)) 173 | n = int(np.ceil(n_samples/len(self.dists))) 174 | samples = jnp.concatenate([ 175 | d.sample(key, n) for key, d in zip(keys, self.dists) 176 | ], axis=0) 177 | samples = random.permutation(key, samples) 178 | return samples[:n_samples] 179 | 180 | def __hash__(self): return 0 # For jitting 181 | 182 | @dataclass 183 | class LouSphereFourModes(Density): 184 | def __post_init__(self): 185 | self.modes = [] 186 | one = jnp.ones(3) 187 | oned = jnp.ones(3) 188 | oned = jax.ops.index_update(oned, jax.ops.index[2], -1.) 
189 | locs = [one, -one, oned, -oned] 190 | locs = [self.manifold.projx(loc) for loc in locs] 191 | scale = jnp.full(self.manifold.D-1, .3) 192 | self.dists = [ 193 | WrappedNormal(manifold=self.manifold, loc=loc, scale=scale) 194 | for loc in locs 195 | ] 196 | 197 | def log_prob(self, z): 198 | raise NotImplementedError() 199 | 200 | def sample(self, key, n_samples): 201 | keys = random.split(key, len(self.dists)) 202 | n = int(np.ceil(n_samples/len(self.dists))) 203 | samples = jnp.concatenate([ 204 | d.sample(key, n) for key, d in zip(keys, self.dists) 205 | ], axis=0) 206 | samples = random.permutation(key, samples) 207 | return samples[:n_samples] 208 | 209 | def __hash__(self): return 0 # For jitting 210 | 211 | 212 | class RezendeSphereFourMode(Density): 213 | # https://github.com/katalinic/sdflows/blob/master/optimisation.py#L12 214 | target_mu = utils.spherical_to_euclidean(jnp.array([ 215 | [1.5, 0.7 + jnp.pi / 2], 216 | [1., -1. + jnp.pi / 2], 217 | [5., 0.6 + jnp.pi / 2], 218 | [4., -0.7 + jnp.pi / 2] 219 | ])) 220 | 221 | def log_prob(self, x): 222 | # TODO: This is unnormalized 223 | assert x.ndim == 2 224 | return jnp.log(jnp.sum(jnp.exp(10. 
* x.dot(self.target_mu.T)), axis=1)) 225 | 226 | def sample(self, key, n_samples): 227 | raise NotImplementedError() 228 | 229 | class RezendeTorusUnimodal(Density): 230 | psi = [4.18, 5.96] 231 | 232 | def log_prob(self, x): 233 | assert x.ndim == 2 234 | 235 | theta1, theta2 = utils.S1euclideantospherical(x[:,:2]), utils.S1euclideantospherical(x[:,2:]) 236 | 237 | return jnp.log(jnp.exp(jnp.cos(theta1-self.psi[0]) + jnp.cos(theta2-self.psi[1]))) 238 | 239 | def sample(self, key, n_samples): 240 | raise NotImplementedError() 241 | 242 | class RezendeCorrelated(Density): 243 | psi = 1.94 244 | 245 | def log_prob(self, x): 246 | assert x.ndim == 2 247 | 248 | theta1, theta2 = utils.S1euclideantospherical(x[:,:2]), utils.S1euclideantospherical(x[:,2:]) 249 | 250 | return jnp.log(jnp.exp(jnp.cos(theta1 + theta2 - self.psi))) 251 | 252 | def sample(self, key, n_samples): 253 | raise NotImplementedError() 254 | 255 | class SphereCheckerboard(Density): 256 | def log_prob(self, x): 257 | # TODO: Could be optimized 258 | # TODO: Assumes x is uniformly distributed 259 | 260 | lonlat = utils.euclidean_to_spherical(x) 261 | s = jnp.pi/2-.2 # long side length 262 | 263 | def in_board(z, s): 264 | # z is lonlat 265 | lon = z[0] 266 | lat = z[1] 267 | 268 | if np.pi <= lon < np.pi+s or np.pi-2*s <= lon < np.pi-s: 269 | v = np.pi/2 <= lat < np.pi/2+s/2 or \ 270 | np.pi/2-s <= lat < np.pi/2-s/2 271 | elif np.pi-2*s <= lon < np.pi+2*s: 272 | v = np.pi/2+s/2 <= lat < np.pi/2+s or \ 273 | np.pi/2-s/2 <= lat < np.pi/2 274 | else: 275 | v = 0. 
276 | 277 | v = float(v) 278 | return v 279 | 280 | probs = [] 281 | for i in range(lonlat.shape[0]): 282 | probs.append(in_board(lonlat[i,:], s)) 283 | probs = jnp.stack(probs) 284 | probs /= jnp.sum(probs) 285 | probs = jnp.log(probs) 286 | return probs 287 | 288 | def sample(self, key, n_samples): 289 | s = jnp.pi/2.-.2 # long side length 290 | offsets = jnp.array([ 291 | (0,0), (s, s/2), (s, -s/2), (0, -s), (-s, s/2), 292 | (-s, -s/2), (-2*s, 0), (-2*s, -s)]) 293 | 294 | # (x,y) ~ uniform([pi,pi + s] times [pi/2, pi/2 + s/2]) 295 | k1, k2, k3 = jax.random.split(key, 3) 296 | x1 = random.uniform(k1, [n_samples]) * s + jnp.pi 297 | x2 = random.uniform(k1, [n_samples]) * s + jnp.pi 298 | x2 = random.uniform(k2, [n_samples]) * s/2. + jnp.pi/2. 299 | 300 | samples = jnp.stack([x1, x2], axis=1) 301 | off = offsets[random.randint( 302 | k3, [n_samples], minval=0, maxval=len(offsets))] 303 | 304 | samples += off 305 | 306 | samples = utils.spherical_to_euclidean(samples) 307 | return samples 308 | 309 | 310 | @dataclass 311 | class ProductUniformComponents(Density): 312 | def __post_init__(self): 313 | self.base_dists = [] 314 | for man in self.manifold.manifolds: 315 | self.base_dists.append(get_uniform(man)) 316 | 317 | def log_prob(self, xs): 318 | #Note this is not necessarily uniform 319 | assert xs.ndim == 2 320 | n_batch = xs.shape[0] 321 | log_probas = jnp.zeros([n_batch]) 322 | d = 0 323 | for i, base_dist in enumerate(self.base_dists): 324 | D = self.manifold.manifolds[i].D 325 | log_probas += base_dist.log_prob(xs[:, d:d+D]) 326 | d = d + D 327 | return log_probas 328 | 329 | def sample(self, key, n_samples): 330 | #Note this is not necessarily uniform 331 | xs = [] 332 | keys = jax.random.split(key, len(self.base_dists)) 333 | for key, base_dist in zip(keys, self.base_dists): 334 | samples_man = base_dist.sample(key = key, n_samples = n_samples) 335 | xs.append(samples_man) 336 | xs = jnp.concatenate(xs, 1) 337 | return xs 338 | 339 | def __hash__(self): 
return 0 # For jitting 340 | -------------------------------------------------------------------------------- /flows.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import random 6 | from jax.nn import initializers as init 7 | from jax.scipy.special import logsumexp 8 | 9 | from flax import linen as nn 10 | 11 | import hydra 12 | import omegaconf 13 | 14 | from manifolds import Manifold, Sphere, Product 15 | import densities 16 | import utils 17 | 18 | 19 | def init_uniform(minval, maxval, dtype=jnp.float32): 20 | def init(key, shape, dtype=dtype): 21 | return random.uniform(key, shape, dtype, minval=minval, maxval=maxval) 22 | return init 23 | 24 | def init_manifold_samples(dist, dtype=jnp.float32): 25 | def init(key, shape, dtype=dtype): 26 | D, N = shape 27 | samples = dist.sample(key, N).T 28 | assert samples.shape == (D, N) 29 | return samples 30 | return init 31 | 32 | def init_full(val, dtype=jnp.float32): 33 | def init(key, shape, dtype=dtype): 34 | return jnp.full(shape, val) 35 | return init 36 | 37 | 38 | class RadialPotential(nn.Module): 39 | n_radial_components: int 40 | init_beta_minval: float 41 | init_beta_range: float 42 | manifold: Manifold 43 | 44 | def setup(self): 45 | assert isinstance(self.manifold, Sphere) 46 | mu_init = densities.get_uniform(self.manifold) 47 | self.betas = self.param( 48 | 'betas', init_uniform( 49 | minval=self.init_beta_minval, 50 | maxval=self.init_beta_minval+self.init_beta_range 51 | ), [self.n_radial_components]) 52 | self.mus = self.param( 53 | 'mus', init_manifold_samples(mu_init), 54 | [self.manifold.D, self.n_radial_components]) 55 | self.alphas = self.param( 56 | 'alphas', init_full(1./self.n_radial_components), 57 | [self.n_radial_components]) 58 | 59 | 60 | def __call__(self, xs): 61 | single = xs.ndim == 1 62 | if single: 63 | xs = jnp.expand_dims(xs, 0) 64 | 
65 | assert xs.ndim == 2 66 | assert xs.shape[1] == self.manifold.D 67 | n_batch = xs.shape[0] 68 | 69 | betas = nn.softplus(self.betas) 70 | mus = self.mus / jnp.linalg.norm(self.mus, axis=0, keepdims=True) 71 | alphas = nn.softmax(self.alphas) 72 | 73 | F = jnp.sum( 74 | (alphas/betas)*jnp.exp(betas * (jnp.matmul(xs, mus) - 1)), 75 | axis=-1 76 | ) 77 | if single: 78 | F = jnp.squeeze(F, 0) 79 | 80 | return F 81 | 82 | class InfAffine(nn.Module): 83 | n_components: int 84 | init_alpha_mode: str 85 | init_alpha_linear_scale: float 86 | init_alpha_minval: float 87 | init_alpha_range: float 88 | manifold: Manifold 89 | cost_gamma: float 90 | min_zero_gamma: float 91 | 92 | def setup(self): 93 | if self.cost_gamma == 'None': self.cost_gamma = None 94 | if self.min_zero_gamma == 'None': self.min_zero_gamma = None 95 | 96 | if isinstance(self.min_zero_gamma, str): 97 | self.min_zero_gamma = float(self.min_zero_gamma) 98 | 99 | if isinstance(self.manifold, Product): 100 | mu_init = densities.get(self.manifold, 'ProductUniformComponents') 101 | else: 102 | mu_init = densities.get_uniform(self.manifold) 103 | 104 | self.mus = self.param( 105 | 'mus', init_manifold_samples(mu_init), 106 | [self.manifold.D, self.n_components]) 107 | if self.init_alpha_mode == 'linear': 108 | alphas = self.init_alpha_linear_scale*self.mus[:,0].dot(self.mus) 109 | self.alphas = self.param( 110 | 'alphas', lambda key, shape: alphas, 111 | [self.n_components]) 112 | elif self.init_alpha_mode == 'uniform': 113 | self.alphas = self.param( 114 | 'alphas', init_uniform( 115 | minval=self.init_alpha_minval, 116 | maxval=self.init_alpha_minval+self.init_alpha_range), 117 | [self.n_components]) 118 | else: 119 | assert False 120 | 121 | def __call__(self, xs): 122 | single = xs.ndim == 1 123 | if single: 124 | xs = jnp.expand_dims(xs, 0) 125 | 126 | assert xs.ndim == 2 127 | assert xs.shape[1] == self.manifold.D 128 | n_batch = xs.shape[0] 129 | 130 | mus = self.manifold.projx(self.mus.T) 131 | mus = 
mus.T 132 | 133 | costs = self.manifold.cost(xs, mus) + self.alphas 134 | 135 | if self.cost_gamma is not None and self.cost_gamma > 0.: 136 | F = self.cost_gamma * logsumexp( 137 | -costs/self.cost_gamma, axis = 1) 138 | else: 139 | F = - jnp.min(costs, 1) 140 | 141 | if self.min_zero_gamma is not None and self.min_zero_gamma > 0.: 142 | Fz = jnp.stack((F, jnp.zeros_like(F)), axis=-1) 143 | F = self.min_zero_gamma * logsumexp( 144 | -Fz/self.min_zero_gamma, axis=-1) 145 | 146 | if single: 147 | F = jnp.squeeze(F, 0) 148 | return F 149 | 150 | 151 | class MultiInfAffine(nn.Module): 152 | n_layers: int 153 | n_components: int 154 | init_alpha_minval: float 155 | init_alpha_range: float 156 | manifold: Manifold 157 | cost_gamma: float 158 | min_zero_gamma: float 159 | 160 | def setup(self): 161 | if self.cost_gamma == 'None': self.cost_gamma = None 162 | if self.min_zero_gamma == 'None': self.min_zero_gamma = None 163 | 164 | mu_init = densities.get_uniform(self.manifold) 165 | 166 | self.mus = [] 167 | self.alphas = [] 168 | self.ws = [] 169 | input_sz = self.manifold.D 170 | for i in range(self.n_layers): 171 | 172 | key = f'mu{i:02d}' 173 | mu = self.param( 174 | key, init_manifold_samples(mu_init), 175 | [self.manifold.D, self.n_components]) 176 | setattr(self, key, mu) 177 | 178 | key = f'alpha{i:02d}' 179 | alpha = self.param( 180 | key, init_uniform( 181 | minval=self.init_alpha_minval, 182 | maxval=self.init_alpha_minval+self.init_alpha_range), 183 | [self.n_components]) 184 | setattr(self, key, alpha) 185 | 186 | key = f'w{i:02d}' 187 | w = self.param( 188 | key, init_uniform(minval=0., maxval=1.), [1]) 189 | setattr(self, key, w) 190 | 191 | self.mus.append(mu) 192 | self.alphas.append(alpha) 193 | self.ws.append(w) 194 | 195 | 196 | def __call__(self, xs): 197 | single = xs.ndim == 1 198 | if single: 199 | xs = jnp.expand_dims(xs, 0) 200 | 201 | assert xs.ndim == 2 202 | assert xs.shape[1] == self.manifold.D 203 | 204 | F = 0. 
205 | for i, (mu, alpha, w) in enumerate( 206 | zip(self.mus, self.alphas, self.ws)): 207 | 208 | mu = self.manifold.projx(mu.T) 209 | mu = mu.T 210 | 211 | costs = self.manifold.cost(xs, mu) + alpha 212 | 213 | w = jnp.exp(-w**2)[0] 214 | 215 | if self.cost_gamma is not None and self.cost_gamma > 0.: 216 | mincosts = self.cost_gamma * logsumexp( 217 | -costs/self.cost_gamma, axis = 1) 218 | else: 219 | mincosts = - jnp.min(costs, 1) 220 | 221 | F = w * nn.relu(F) + (1-w) * mincosts 222 | 223 | 224 | if self.min_zero_gamma is not None and self.min_zero_gamma > 0.: 225 | Fz = jnp.stack((F, jnp.zeros_like(F)), axis=-1) 226 | F = self.min_zero_gamma * logsumexp( 227 | -Fz/self.min_zero_gamma, axis=-1) 228 | 229 | if single: 230 | F = jnp.squeeze(F, 0) 231 | return F 232 | 233 | 234 | class ExpMapFlow(nn.Module): 235 | potential_cfg: omegaconf.dictconfig.DictConfig 236 | manifold: Manifold 237 | 238 | def setup(self): 239 | self.potential_mod = hydra.utils.instantiate( 240 | dict(self.potential_cfg), manifold=self.manifold, 241 | _recursive_=False, _convert_='object', 242 | ) 243 | 244 | 245 | def __call__(self, xs, t = 1): 246 | assert xs.ndim == 2 247 | n_batch = xs.shape[0] 248 | 249 | def dF_riemannian(xs): 250 | assert xs.ndim == 1 251 | dF = jax.jacfwd(self.potential)(xs) 252 | dF = self.manifold.tangent_projection(xs, dF) 253 | return dF 254 | 255 | def flow(xs): 256 | assert xs.ndim == 1 257 | dF = dF_riemannian(xs) 258 | z = self.manifold.exponential_map(xs, t * dF) 259 | return z 260 | 261 | def flow_jacobian(xs): 262 | assert xs.ndim == 1 263 | J = jax.jacfwd(flow)(xs) 264 | return J 265 | 266 | def flow_and_jac(xs): 267 | z = flow(xs) 268 | dF = dF_riemannian(xs) 269 | J = flow_jacobian(xs) 270 | return z, dF, J 271 | 272 | z, dF, J = jax.vmap(flow_and_jac)(xs) 273 | 274 | E = self.manifold.tangent_orthonormal_basis(xs, dF) 275 | JE = jnp.matmul(J, E) 276 | JETJE = jnp.einsum('nji,njk->nik', JE, JE) 277 | 278 | sign, logdet = jnp.linalg.slogdet(JETJE) 279 | 
logdet *= 0.5 280 | 281 | return z, logdet, sign 282 | 283 | def potential(self, xs): 284 | F = self.potential_mod(xs) 285 | return F 286 | 287 | 288 | class SequentialFlow(nn.Module): 289 | n_transforms: int 290 | manifold: Manifold 291 | single_transform_cfg: omegaconf.dictconfig.DictConfig 292 | 293 | def setup(self): 294 | transforms = [] 295 | for i in range(self.n_transforms): 296 | mod = hydra.utils.instantiate( 297 | dict(self.single_transform_cfg), 298 | manifold=self.manifold, 299 | _recursive_=False, _convert_='object', 300 | ) 301 | transforms.append(mod) 302 | 303 | # hack for https://github.com/google/flax/issues/524 304 | key = f'transform{i:02d}' 305 | setattr(self, key, mod) 306 | self.transforms = transforms 307 | 308 | def __call__(self, orig_xs, debug=False, t = 1): 309 | ldjs = 0. 310 | all_xs = [] 311 | all_ldjs = [] 312 | all_ldj_signs = [] 313 | Fs = [] 314 | 315 | xs = orig_xs 316 | for transform in self.transforms: 317 | xs, ldj, ldj_sign = transform(xs, t = t) 318 | if debug: 319 | F = transform.potential(orig_xs) 320 | all_xs.append(xs) 321 | all_ldjs.append(ldj) 322 | all_ldj_signs.append(ldj_sign) 323 | Fs.append(F) 324 | ldjs += ldj 325 | 326 | if not debug: 327 | return xs, ldjs 328 | else: 329 | all_xs = jnp.stack(all_xs) 330 | all_ldjs = jnp.stack(all_ldjs) 331 | all_ldj_signs = jnp.stack(all_ldj_signs) 332 | return all_xs, all_ldjs, all_ldj_signs, Fs, ldjs 333 | -------------------------------------------------------------------------------- /images/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rcpm/HEAD/images/demo.png -------------------------------------------------------------------------------- /images/discrete-c-concave.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rcpm/HEAD/images/discrete-c-concave.png 
-------------------------------------------------------------------------------- /images/fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rcpm/HEAD/images/fig3.png -------------------------------------------------------------------------------- /images/table2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/rcpm/HEAD/images/table2.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | import sys 5 | from IPython.core import ultratb 6 | sys.excepthook = ultratb.FormattedTB( 7 | mode='Plain', color_scheme='Neutral', call_pdb=1) 8 | 9 | import numpy as np 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | import jax 14 | import jax.numpy as jnp 15 | from jax.config import config; config.update("jax_enable_x64", True) 16 | 17 | import pickle as pkl 18 | 19 | from flax import linen as nn 20 | 21 | import time 22 | 23 | import hydra 24 | 25 | import csv 26 | import os 27 | 28 | import functools 29 | 30 | import flows 31 | import utils 32 | import densities 33 | 34 | from setproctitle import setproctitle 35 | setproctitle('iccnn') 36 | 37 | 38 | def kl_ess(log_model_prob, log_target_prob): 39 | weights = jnp.exp(log_target_prob) / jnp.exp(log_model_prob) 40 | Z = jnp.mean(weights) 41 | KL = jnp.mean(log_model_prob - log_target_prob) + jnp.log(Z) 42 | ESS = jnp.sum(weights) ** 2 / jnp.sum(weights ** 2) 43 | return Z, KL, ESS 44 | 45 | 46 | class Workspace: 47 | def __init__(self, cfg): 48 | self.cfg = cfg 49 | 50 | self.work_dir = os.getcwd() 51 | print(f'workspace: {self.work_dir}') 52 | 53 | self.manifold = hydra.utils.instantiate(self.cfg.manifold) 54 | self.base = 
densities.get(self.manifold, self.cfg.base) 55 | self.target = densities.get(self.manifold, self.cfg.target) 56 | 57 | self.key = jax.random.PRNGKey(self.cfg.seed) 58 | 59 | self.flow = hydra.utils.instantiate( 60 | self.cfg.flow, manifold=self.manifold, 61 | _recursive_=False, _convert_='object', 62 | ) 63 | self.key, k1, k2, k3, k4, k5 = jax.random.split(self.key, 6) 64 | batch = self.base.sample(k1, self.cfg.batch_size) 65 | init_params = self.flow.init(k2, batch) 66 | 67 | self.base_samples = self.base.sample(k3, self.cfg.eval_samples) 68 | self.base_log_probs = self.base.log_prob(self.base_samples) 69 | if self.cfg.loss == 'likelihood': 70 | self.eval_target_samples = self.target.sample( 71 | k5, self.cfg.eval_samples) 72 | 73 | optimizer_def = hydra.utils.instantiate(self.cfg.optim) 74 | self.optimizer = optimizer_def.create(init_params) 75 | 76 | self.iter = 0 77 | 78 | def run(self): 79 | if self.cfg.loss == 'kl': 80 | self.train_kl() 81 | elif self.cfg.loss == 'likelihood': 82 | self.train_likelihood() 83 | else: 84 | assert False 85 | 86 | def train_kl(self): 87 | @jax.jit 88 | def loss(params, base_samples, base_log_probs): 89 | z, ldjs = self.flow.apply(params, base_samples) 90 | loss = (base_log_probs - ldjs - 91 | self.target.log_prob(z)).mean() 92 | return loss 93 | 94 | @jax.jit 95 | def update(optimizer, base_samples, base_log_probs): 96 | l, grads = jax.value_and_grad(loss)( 97 | optimizer.target, base_samples, base_log_probs) 98 | optimizer = optimizer.apply_gradient(grads) 99 | return l, optimizer 100 | 101 | logf, writer = self._init_logging() 102 | 103 | times = [] 104 | if self.iter == 0: 105 | model_samples, ldjs = self.flow.apply( 106 | self.optimizer.target, self.base_samples) 107 | self.manifold.plot_samples( 108 | model_samples, save=f'{self.iter:06d}.png') 109 | 110 | self.manifold.plot_density(self.target.log_prob, 'target.png') 111 | 112 | while self.iter < self.cfg.iterations: 113 | start = time.time() 114 | self.key, subkey = 
jax.random.split(self.key) 115 | base_samples = self.base.sample(subkey, self.cfg.batch_size) 116 | base_log_probs = self.base.log_prob(base_samples) 117 | l, self.optimizer = update( 118 | self.optimizer, base_samples, base_log_probs) 119 | 120 | times.append(time.time() - start) 121 | self.iter += 1 122 | if self.iter % self.cfg.log_frequency == 0: 123 | l = loss(self.optimizer.target, 124 | self.base_samples, self.base_log_probs) 125 | 126 | model_samples, ldjs = self.flow.apply( 127 | self.optimizer.target, self.base_samples) 128 | self.manifold.plot_samples( 129 | model_samples, save=f'{self.iter:06d}.png') 130 | if not self.cfg.disable_evol_plots: 131 | for i, t in enumerate(jnp.linspace(0.1,1,11)): 132 | model_samples, ldjs = self.flow.apply( 133 | self.optimizer.target, self.base_samples, t = t) 134 | self.manifold.plot_samples( 135 | model_samples, 136 | save=f'{self.iter:06d}_{i}.png') 137 | 138 | 139 | log_prob = self.base_log_probs - ldjs 140 | _, kl, ess = kl_ess( 141 | log_prob, self.target.log_prob(model_samples)) 142 | ess = ess / self.cfg.eval_samples * 100 143 | msg = "Iter {} | Loss {:.3f} | KL {:.3f} | ESS {:.2f}% | {:.2e}s/it" 144 | print(msg.format( 145 | self.iter, l, kl, ess, jnp.mean(jnp.array(times)))) 146 | writer.writerow({ 147 | 'iter': self.iter, 'loss': l, 'kl': kl, 'ess': ess 148 | }) 149 | logf.flush() 150 | self.save('latest') 151 | 152 | times = [] 153 | 154 | 155 | def train_likelihood(self): 156 | @jax.jit 157 | def logprob(params, target_samples, t = 1): 158 | zs, ldjs = self.flow.apply(params, target_samples, t = t) 159 | log_prob = ldjs + self.base.log_prob(zs) 160 | return log_prob 161 | 162 | @jax.jit 163 | def loss(params, target_samples): 164 | return -logprob(params, target_samples).mean() 165 | 166 | @jax.jit 167 | def update(optimizer, target_samples): 168 | l, grads = jax.value_and_grad(loss)( 169 | optimizer.target, target_samples) 170 | optimizer = optimizer.apply_gradient(grads) 171 | return l, optimizer 172 | 173 
| target_sample_jit = jax.jit(self.target.sample, static_argnums=(1,)) 174 | base_sample_jit = jax.jit(self.base.sample, static_argnums=(1,)) 175 | 176 | logf, writer = self._init_logging() 177 | 178 | times = [] 179 | 180 | if self.iter == 0 and not self.cfg.disable_init_plots: 181 | model_samples, ldjs = self.flow.apply( 182 | self.optimizer.target, self.eval_target_samples) 183 | try: 184 | self.manifold.plot_density( 185 | self.target.log_prob, save=f'target_density.png') 186 | except: 187 | pass 188 | self.manifold.plot_samples( 189 | self.eval_target_samples, save=f'target_samples.png') 190 | self.manifold.plot_samples( 191 | base_sample_jit(self.key, self.cfg.eval_samples), 192 | save=f'base_samples.png') 193 | self.manifold.plot_density( 194 | self.base.log_prob, save=f'base_density.png') 195 | self.manifold.plot_samples( 196 | model_samples, save=f'samples_{self.iter:06d}.png') 197 | self.manifold.plot_density( 198 | functools.partial(logprob, self.optimizer.target), 199 | save=f'density_{self.iter:06d}.png') 200 | if not self.cfg.disable_evol_plots: 201 | for i, t in enumerate(jnp.linspace(0.1,1,11)): 202 | self.manifold.plot_density( 203 | functools.partial(logprob, self.optimizer.target, t = t), 204 | save=f'density_{self.iter:06d}_{i}.png') 205 | 206 | 207 | 208 | while self.iter < self.cfg.iterations: 209 | start = time.time() 210 | self.key, subkey = jax.random.split(self.key) 211 | target_samples = target_sample_jit(subkey, self.cfg.batch_size) 212 | l, self.optimizer = update(self.optimizer, target_samples) 213 | 214 | times.append(time.time() - start) 215 | self.iter += 1 216 | if self.iter % self.cfg.log_frequency == 0: 217 | l = loss(self.optimizer.target, self.eval_target_samples) 218 | model_samples, ldjs = self.flow.apply( 219 | self.optimizer.target, self.eval_target_samples) 220 | self.manifold.plot_samples( 221 | model_samples, save=f'samples_{self.iter:06d}.png') 222 | self.manifold.plot_density( 223 | functools.partial(logprob, 
self.optimizer.target), 224 | save=f'density_{self.iter:06d}.png') 225 | if not self.cfg.disable_evol_plots: 226 | for i, t in enumerate(jnp.linspace(0.1,1,10)): 227 | self.manifold.plot_density( 228 | functools.partial(logprob, self.optimizer.target, t = t), 229 | save=f'density_{self.iter:06d}_{i}.png') 230 | 231 | 232 | 233 | msg = "Iter {} | Loss {:.3f} | {:.2e}s/it" 234 | print(msg.format( 235 | self.iter, l, jnp.mean(jnp.array(times)))) 236 | writer.writerow({ 237 | 'iter': self.iter, 'loss': l, 238 | }) 239 | logf.flush() 240 | self.save('latest') 241 | times = [] 242 | 243 | 244 | 245 | def save(self, tag='latest'): 246 | path = os.path.join(self.work_dir, f'{tag}.pkl') 247 | with open(path, 'wb') as f: 248 | pkl.dump(self, f) 249 | 250 | 251 | def _init_logging(self): 252 | logf = open('log.csv', 'a') 253 | fieldnames = ['iter', 'loss', 'kl', 'ess'] 254 | writer = csv.DictWriter(logf, fieldnames=fieldnames) 255 | if os.stat('log.csv').st_size == 0: 256 | writer.writeheader() 257 | logf.flush() 258 | return logf, writer 259 | 260 | 261 | # Import like this for pickling 262 | from main import Workspace as W 263 | 264 | @hydra.main(config_path=".", config_name="config.yaml", version_base="1.1") 265 | def main(cfg): 266 | fname = os.getcwd() + '/latest.pkl' # must match the path written by Workspace.save 267 | if os.path.exists(fname): 268 | print(f'Resuming from {fname}') 269 | with open(fname, 'rb') as f: 270 | workspace = pkl.load(f) 271 | else: 272 | workspace = W(cfg) 273 | 274 | workspace.run() 275 | 276 | if __name__ == '__main__': 277 | main() 278 | -------------------------------------------------------------------------------- /manifolds.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | import numpy as np 4 | import jax.numpy as jnp 5 | from jax import random 6 | import jax 7 | from jax.scipy.linalg import block_diag 8 | 9 | from spherical_kde import SphericalKDE 10 | import matplotlib.pyplot as plt 11 | from mpl_toolkits.mplot3d import Axes3D 12 | from scipy.stats import gaussian_kde 13 | import os 14 | from dataclasses import dataclass 15 | from abc import ABC, abstractmethod 16 | 17 | import utils 18 | import cartopy.crs as ccrs 19 | 20 | import matplotlib 21 | 22 | @dataclass 23 | class Manifold(ABC): 24 | D: int # Dimension of the ambient Euclidean space 25 | 26 | @abstractmethod 27 | def exponential_map(self, x, v): 28 | pass 29 | 30 | @abstractmethod 31 | def tangent_projection(self, x, v): 32 | pass 33 | 34 | @abstractmethod 35 | def projx(self, x): 36 | pass 37 | 38 | # @abstractmethod 39 | # def dist(self, x, y): 40 | # pass 41 | 42 | @abstractmethod 43 | def cost(self, x, y): 44 | pass 45 | 46 | @abstractmethod 47 | def tangent_orthonormal_basis(self, x, dF): 48 | pass 49 | 50 | 51 | eps = 1e-5 # TODO: Other stabilization? 
52 | divsin = lambda x: x / jnp.sin(x) 53 | sindiv = lambda x: jnp.sin(x) / (x + eps) 54 | divsinh = lambda x: x / jnp.sinh(x) 55 | sinhdiv = lambda x: jnp.sinh(x) / (x + eps) 56 | 57 | def lorentz_cross(x, y): 58 | z = jnp.cross(x, y) 59 | z = z.at[...,0].set(-z[...,0]) 60 | return z 61 | 62 | @dataclass 63 | class Sphere(Manifold): 64 | jitter: float = 1e-2 65 | 66 | NUM_POINTS = 100 67 | 68 | theta = jnp.linspace(0, 2 * np.pi, 2 * NUM_POINTS) 69 | phi = jnp.linspace(0, np.pi, NUM_POINTS) 70 | tp = jnp.array(np.meshgrid(theta, phi, indexing='ij')) 71 | tp = tp.transpose([1, 2, 0]).reshape(-1, 2) 72 | 73 | def exponential_map(self, x, v): 74 | v_norm = jnp.linalg.norm(v, axis=-1, keepdims=True) 75 | return x * jnp.cos(v_norm) + v * sindiv(v_norm) 76 | 77 | def log(self, x, y): 78 | xy = (x * y).sum(axis=-1, keepdims=True) 79 | xy = jnp.clip(xy, a_min=-1 + 1e-6, a_max=1 - 1e-6) 80 | val = jnp.arccos(xy) 81 | return divsin(val) * (y - xy * x) 82 | 83 | def tangent_projection(self, x, u): 84 | proj_u = u - x*x.dot(u) 85 | return proj_u 86 | 87 | def tangent_orthonormal_basis(self, x, dF): 88 | assert x.ndim == 2 89 | 90 | if x.shape[1] == 2: 91 | E = x[:, jnp.array([1,0])] * jnp.array([-1., 1.]) 92 | E = E.reshape(*E.shape, 1) 93 | elif x.shape[1] == 3: 94 | # The potential's Riemannian derivative dF is on the 95 | # tangent space, so on S2 we normalize this and 96 | # find the only remaining orthogonal direction. 97 | norm_v = dF / jnp.linalg.norm(dF, axis=-1, keepdims=True) 98 | E = jnp.dstack([norm_v, jnp.cross(x, norm_v)]) 99 | else: 100 | raise NotImplementedError() 101 | 102 | return E 103 | 104 | def dist(self, x, y): 105 | inner = jnp.matmul(x, y) 106 | inner = inner/(1 + self.jitter) 107 | return jnp.arccos(inner) 108 | 109 | def cost(self, x, y): 110 | return self.dist(x, y)**2 / 2. 
111 | 112 | def projx(self, x): 113 | x /= jnp.linalg.norm(x, axis=-1, keepdims=True) 114 | return x 115 | 116 | def transp(self, x, y, u): 117 | yu = jnp.sum(y * u, axis=-1, keepdims=True) 118 | xy = jnp.sum(x * y, axis=-1, keepdims=True) 119 | return u - yu/(1 + xy) * (x + y) 120 | 121 | def logdetexp(self, x, u): 122 | norm_u = jnp.linalg.norm(u, axis=-1) 123 | val = jnp.log(jnp.abs(sindiv(norm_u))) 124 | return (u.shape[-1]-2) * val 125 | 126 | 127 | def zero(self): 128 | y = jnp.zeros(self.D) 129 | y = y.at[...,0].set(-1.) 130 | return y 131 | 132 | def zero_like(self, x): 133 | y = jnp.zeros_like(x) 134 | y = y.at[...,0].set(-1.) 135 | return y 136 | 137 | def squeeze_tangent(self, x): 138 | return x[..., 1:] 139 | 140 | def unsqueeze_tangent(self, x): 141 | return jnp.concatenate((jnp.zeros_like(x[..., :1]), x), axis=-1) 142 | 143 | def plot_samples(self, model_samples, kde_factor=0.1, save='t.png'): 144 | spherical_samples = utils.euclidean_to_spherical(model_samples) 145 | kde = SphericalKDE( 146 | spherical_samples[:,0], spherical_samples[:,1], bandwidth=kde_factor) 147 | heatmap = np.exp(kde(self.tp[:,0], self.tp[:,1]).reshape( 148 | 2 * self.NUM_POINTS, self.NUM_POINTS)) 149 | self.plot_mollweide(heatmap, save=save) 150 | 151 | def plot_density(self, log_prob_fn, save='t.png'): 152 | density = log_prob_fn(utils.spherical_to_euclidean(self.tp)) 153 | density = jnp.exp(density) 154 | heatmap = density.reshape(2 * self.NUM_POINTS, self.NUM_POINTS) 155 | self.plot_mollweide(heatmap, save=save) 156 | 157 | def plot_mollweide(self, heatmap, save): 158 | tt, pp = np.meshgrid( 159 | self.theta - np.pi, self.phi - np.pi / 2, indexing='ij') 160 | 161 | proj = ccrs.Mollweide() 162 | fig = plt.figure(figsize=(3,2), dpi=200) 163 | ax = fig.add_subplot(111, projection='mollweide') 164 | norm = matplotlib.colors.Normalize() 165 | ax.pcolormesh(tt, pp, heatmap, cmap='magma', norm = norm) 166 | ax.set_axis_off() 167 | plt.savefig(save) 168 | os.system(f"convert {save} 
-trim {save} &") 169 | plt.close(fig) 170 | 171 | 172 | 173 | class Euclidean(Manifold): 174 | def exponential_map(self, x, v): 175 | return x + v 176 | 177 | def tangent_projection(self, x, u): 178 | return u 179 | 180 | def cost(self, x, y): 181 | return 0.5 * self.dist(x,y)**2 182 | 183 | def dist(self, x, y): 184 | return - jnp.matmul(x, y) 185 | 186 | def tangent_orthonormal_basis(self, x, dF): 187 | tang_vecs = [jnp.eye(x.shape[1]) for i in range(x.shape[0])] 188 | return jnp.stack(tang_vecs, 0) 189 | 190 | 191 | 192 | def get(manifold): 193 | if manifold == 'S1': 194 | return Sphere(D = 2) 195 | elif manifold == 'S2': 196 | return Sphere(D = 3) 197 | elif manifold == 'R': 198 | return Euclidean(D = 1) 199 | else: 200 | assert False 201 | 202 | @dataclass 203 | class Product(Manifold): 204 | manifolds_str: str = 'S1,S1' 205 | 206 | def __post_init__(self): 207 | self.manifolds = [] 208 | for man in self.manifolds_str.split(','): 209 | self.manifolds.append(get(man)) 210 | 211 | def exponential_map(self, x, v): 212 | exp_prod = [] 213 | d = 0 214 | for man in self.manifolds: 215 | exp_man = man.exponential_map(x[d:d+man.D], v[d:d+man.D]) 216 | exp_prod.append(exp_man) 217 | d = d + man.D 218 | exp_prod = jnp.concatenate(exp_prod) 219 | return exp_prod 220 | 221 | def tangent_projection(self, x, u): 222 | proj_prod = [] 223 | d = 0 224 | for man in self.manifolds: 225 | proj_man = man.tangent_projection(x[d:d+man.D], u[d:d+man.D]) 226 | proj_prod.append(proj_man) 227 | d = d + man.D 228 | proj_prod = jnp.concatenate(proj_prod) 229 | return proj_prod 230 | 231 | def cost(self, x, y): 232 | cost_prod = jnp.zeros([x.shape[0], y.T.shape[0]]) 233 | d = 0 234 | for man in self.manifolds: 235 | cost_prod += man.cost(x[:,d:d+man.D], y[d:d+man.D,:]) 236 | d = d + man.D 237 | return cost_prod 238 | 239 | def dist(self, x, y): 240 | pass 241 | 242 | def tangent_orthonormal_basis(self, x, dF): 243 | d = 0 244 | map_block_diag = jax.vmap(block_diag) 245 | blocks = [] 246 | 
for man in self.manifolds: 247 | onb_man = man.tangent_orthonormal_basis(x[:,d:d+man.D], dF[:,d:d+man.D]) 248 | blocks.append(onb_man) 249 | d = d + man.D 250 | onb = map_block_diag(*(blocks)) 251 | return onb 252 | 253 | def projx(self, x): 254 | x_proj = [] 255 | d = 0 256 | for man in self.manifolds: 257 | x_proj_man = man.projx(x[:,d:d+man.D]) 258 | d = d + man.D 259 | x_proj.append(x_proj_man) 260 | x_proj = jnp.concatenate(x_proj, 1) 261 | return x_proj 262 | 263 | def plot_samples(self, model_samples, save='t.png'): 264 | pass 265 | 266 | def plot_density(self, log_prob_fn, save='t.png'): 267 | pass 268 | 269 | 270 | 271 | @dataclass 272 | class Torus(Product): 273 | manifolds: str = 'S1,S1' 274 | 275 | NUM_POINTS = 160 276 | 277 | theta = jnp.linspace(0, 2 * np.pi, 2 * NUM_POINTS) 278 | phi = jnp.linspace(0, 2 * np.pi, NUM_POINTS) 279 | tp = jnp.array(np.meshgrid(theta, phi, indexing='ij')) 280 | tp = tp.transpose([1, 2, 0]).reshape(-1, 2) 281 | 282 | def plot_samples(self, model_samples, save='t.png'): 283 | theta1 = utils.S1euclideantospherical(model_samples[:,:2]) 284 | theta2 = utils.S1euclideantospherical(model_samples[:,2:]) 285 | 286 | x, y, z = utils.productS1toTorus(theta1, theta2) 287 | data = jnp.stack((x, y, z), 1) 288 | estimated_density = gaussian_kde( 289 | data.T, 0.2) 290 | 291 | x_grid, y_grid, z_grid = utils.productS1toTorus(self.tp[:,0], self.tp[:,1]) 292 | grid = jnp.stack((x_grid, y_grid, z_grid), 1) 293 | probas_grid = estimated_density(grid.T) 294 | 295 | fig = plt.figure() 296 | ax = Axes3D(fig) 297 | #TODO: fix this - I negate because the mode is at the bottom of the torus in unimodal density 298 | ax.scatter(-x_grid, -y_grid, -z_grid, alpha = 0.2, c = probas_grid) 299 | ax.set_xlim(-1,1) 300 | ax.set_ylim(-1,1) 301 | ax.set_zlim(-1,1) 302 | plt.axis('off') 303 | plt.savefig(save) 304 | 305 | 306 | def plot_density(self, log_prob_fn, save='t.png'): 307 | euc1 = jnp.stack((jnp.cos(self.tp[:,0]), jnp.sin(self.tp[:,0])),1) 308 | euc2 = 
jnp.stack((jnp.cos(self.tp[:,1]), jnp.sin(self.tp[:,1])),1) 309 | prod_euc = jnp.concatenate((euc1,euc2),1) 310 | 311 | density = log_prob_fn(prod_euc) 312 | density = jnp.exp(density) 313 | 314 | x_grid, y_grid, z_grid = utils.productS1toTorus(self.tp[:,0], self.tp[:,1]) 315 | grid = jnp.stack((x_grid, y_grid, z_grid), 1) 316 | 317 | fig = plt.figure() 318 | 319 | ax = Axes3D(fig) 320 | #TODO: fix this - I negate because the mode is at the bottom of the torus in unimodal density 321 | ax.scatter(-x_grid, -y_grid, -z_grid, alpha = 0.2, c = density) 322 | ax.set_xlim(-1,1) 323 | ax.set_ylim(-1,1) 324 | ax.set_zlim(-1,1) 325 | plt.axis('off') 326 | 327 | plt.savefig(save) 328 | 329 | 330 | @dataclass 331 | class InfCylinder(Product): 332 | manifolds: str = 'S1,R' 333 | -------------------------------------------------------------------------------- /plot-components.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | import argparse 8 | import os 9 | import sys 10 | import pickle as pkl 11 | import shutil 12 | from omegaconf import OmegaConf 13 | from collections import namedtuple 14 | 15 | from scipy.stats import gaussian_kde 16 | 17 | import matplotlib.pyplot as plt 18 | plt.style.use('bmh') 19 | from matplotlib import cm 20 | 21 | import utils 22 | 23 | import sys 24 | from IPython.core import ultratb 25 | sys.excepthook = ultratb.FormattedTB( 26 | mode='Verbose', color_scheme='Linux', call_pdb=1) 27 | 28 | NUM_POINTS = 150 29 | 30 | theta = jnp.linspace(0, 2 * jnp.pi, 2 * NUM_POINTS) 31 | phi = jnp.linspace(0, jnp.pi, NUM_POINTS) 32 | tp = jnp.array(np.meshgrid(theta, phi, indexing='ij')) 33 | tp = tp.transpose([1, 2, 0]).reshape(-1, 2) 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('exp_root', type=str) 38 | args = parser.parse_args() 39 | 40 | fname = f"{args.exp_root}/latest.pkl" 41 | with open(fname, 'rb') as f: 42 | W = pkl.load(f) 43 | 44 | n_transforms = W.cfg.flow.n_transforms 45 | nrows, ncols = n_transforms+1, 3 46 | fig, axs = plt.subplots( 47 | nrows, ncols, figsize=(6*ncols, 4*nrows), 48 | subplot_kw={'projection': 'mollweide'} 49 | ) 50 | if nrows == 1: 51 | axs = np.expand_dims(axs, 0) 52 | # if ncols == 1: 53 | # axs = np.expand_dims(axs, -1) 54 | 55 | axs[0,0].set_title('Potential', fontsize=20) 56 | axs[0,1].set_title('LDJ', fontsize=20) 57 | axs[0,2].set_title('Distribution', fontsize=20) 58 | 59 | all_xs, all_ldjs, all_ldj_signs, Fs, ldjs = W.flow.apply( 60 | W.optimizer.target, utils.spherical_to_euclidean(tp), debug=True) 61 | all_ldjs = jnp.stack(all_ldjs) 62 | Fs = jnp.stack(Fs) 63 | ldj_bounds = (jnp.min(all_ldjs), jnp.max(all_ldjs)) 64 | F_bounds = (jnp.min(Fs), jnp.max(Fs)) 65 | 66 | for t in range(n_transforms): 67 | plot_heatmap(Fs[t].reshape(2*NUM_POINTS, NUM_POINTS), axs[t,0], 68 | vbounds=F_bounds) 69 | 
plot_heatmap(all_ldjs[t].reshape(2*NUM_POINTS, NUM_POINTS), 70 | axs[t,1], vbounds=ldj_bounds) 71 | plot_density(all_xs[t], axs[t,2]) 72 | 73 | axs[-1,0].set_axis_off() 74 | axs[-1,2].set_axis_off() 75 | axs[-1,1].set_title('Cumulative LDJ', fontsize=20) 76 | plot_heatmap(ldjs.reshape(2*NUM_POINTS, NUM_POINTS), axs[-1,1]) 77 | 78 | fname = f"{args.exp_root}/components.png" 79 | print(f'Saving to {fname}') 80 | fig.tight_layout() 81 | fig.subplots_adjust(wspace=0, hspace=0) 82 | fig.savefig(fname) 83 | os.system(f"convert {fname} -trim {fname}") 84 | 85 | fig, ax = plt.subplots(1, 1, figsize=(6, 4), 86 | subplot_kw={'projection': 'mollweide'}) 87 | 88 | plot_heatmap(ldjs.reshape(2*NUM_POINTS, NUM_POINTS), ax) 89 | fname = f"{args.exp_root}/ldj.png" 90 | print(f'Saving to {fname}') 91 | fig.tight_layout() 92 | fig.subplots_adjust(wspace=0, hspace=0) 93 | fig.savefig(fname) 94 | os.system(f"convert {fname} -trim {fname}") 95 | 96 | 97 | def plot_density(xs, ax): 98 | estimated_density = gaussian_kde( 99 | utils.euclidean_to_spherical(xs).T, 0.2) 100 | heatmap = estimated_density(tp.T).reshape(2 * NUM_POINTS, NUM_POINTS) 101 | plot_heatmap(heatmap, ax) 102 | 103 | 104 | def plot_heatmap(fs, ax, cmap=plt.cm.magma, vbounds=None): 105 | tt, pp = jnp.meshgrid(theta - jnp.pi, phi - jnp.pi / 2, indexing='ij') 106 | vmin = vmax = None 107 | if vbounds is not None: 108 | vmin, vmax = vbounds 109 | ax.pcolormesh(tt, pp, fs, cmap=cmap, vmin=vmin, vmax=vmax) 110 | ax.set_axis_off() 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /plot-demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | import argparse 8 | import os 9 | import sys 10 | import pickle as pkl 11 | import shutil 12 | from omegaconf import OmegaConf 13 | from collections import namedtuple 14 | 15 | from scipy.stats import gaussian_kde 16 | 17 | import matplotlib.pyplot as plt 18 | plt.style.use('bmh') 19 | from matplotlib import cm 20 | from matplotlib.collections import LineCollection 21 | 22 | 23 | import utils 24 | 25 | import sys 26 | from IPython.core import ultratb 27 | sys.excepthook = ultratb.FormattedTB( 28 | mode='Verbose', color_scheme='Linux', call_pdb=1) 29 | 30 | NUM_POINTS = 100 31 | 32 | theta = jnp.linspace(0, 2 * jnp.pi, 2 * NUM_POINTS) 33 | phi = jnp.linspace(0, jnp.pi, NUM_POINTS) 34 | tp = jnp.array(np.meshgrid(theta, phi, indexing='ij')) 35 | tp = tp.transpose([1, 2, 0]).reshape(-1, 2) 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('exp_root', type=str) 40 | args = parser.parse_args() 41 | 42 | fname = f"{args.exp_root}/latest.pkl" 43 | with open(fname, 'rb') as f: 44 | W = pkl.load(f) 45 | 46 | nrows, ncols = 1, 1 47 | fig, ax = plt.subplots( 48 | nrows, ncols, figsize=(6*ncols, 4*nrows), 49 | subplot_kw={'projection': 'mollweide'} 50 | ) 51 | 52 | all_xs, _, _, Fs, _ = W.flow.apply( 53 | W.optimizer.target, utils.spherical_to_euclidean(tp), debug=True) 54 | plot_heatmap(Fs[0].reshape(2*NUM_POINTS, NUM_POINTS), ax) 55 | 56 | fname = f"{args.exp_root}/potential.png" 57 | print(f'Saving to {fname}') 58 | fig.tight_layout() 59 | fig.subplots_adjust(wspace=0, hspace=0) 60 | fig.savefig(fname) 61 | os.system(f"convert {fname} -trim {fname}") 62 | 63 | 64 | nrows, ncols = 1, 1 65 | fig, ax = plt.subplots( 66 | nrows, ncols, figsize=(6*ncols, 4*nrows), 67 | subplot_kw={'projection': 'mollweide'} 68 | ) 69 | 70 | def plot_grid(x,y, ax=None, **kwargs): 71 | ax = ax or plt.gca() 72 | segs1 = np.stack((x,y), axis=2) 73 | segs2 = segs1.transpose(1,0,2) 74 | 
ax.add_collection(LineCollection(segs1, **kwargs)) 75 | ax.add_collection(LineCollection(segs2, **kwargs)) 76 | 77 | b = 0.2 78 | lw = 0.5 79 | grid_x, grid_y = np.meshgrid( 80 | np.linspace(-np.pi+b, np.pi-b, 50), 81 | np.linspace(-np.pi+b, np.pi-b, 50)) 82 | plot_grid(grid_x, grid_y, ax, color='lightgrey', lw=lw) 83 | 84 | grid_sphere = utils.spherical_to_euclidean( 85 | jnp.stack((grid_x+np.pi, (grid_y+np.pi)/2.)).reshape(2, -1).T 86 | ) 87 | F_grid_sphere, _ = W.flow.apply(W.optimizer.target, grid_sphere) 88 | F_grid = utils.euclidean_to_spherical(F_grid_sphere) 89 | F_grid_x = F_grid[:,0].reshape(grid_x.shape) - np.pi 90 | F_grid_y = F_grid[:,1].reshape(grid_x.shape)*2. - np.pi 91 | plot_grid(F_grid_x, F_grid_y, color='C0', lw=lw) 92 | 93 | ax.set_axis_off() 94 | fname = f"{args.exp_root}/grid.png" 95 | print(f'Saving to {fname}') 96 | fig.tight_layout() 97 | fig.subplots_adjust(wspace=0, hspace=0) 98 | fig.savefig(fname) 99 | os.system(f"convert {fname} -trim {fname}") 100 | 101 | def plot_density(xs, ax): 102 | estimated_density = gaussian_kde( 103 | utils.euclidean_to_spherical(xs).T, 0.2) 104 | heatmap = estimated_density(tp.T).reshape(2 * NUM_POINTS, NUM_POINTS) 105 | plot_heatmap(heatmap, ax) 106 | 107 | def plot_heatmap(fs, ax): 108 | tt, pp = jnp.meshgrid(theta - jnp.pi, phi - jnp.pi / 2, indexing='ij') 109 | ax.pcolormesh(tt, pp, fs, cmap=plt.cm.magma) 110 | ax.set_axis_off() 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cartopy==0.21.0 2 | flax==0.3.4 3 | hydra-core==1.3.2 4 | ipython==8.12.0 5 | jax==0.2.14 6 | jaxlib==0.1.67 7 | matplotlib==3.1 8 | numpy==1.22 9 | omegaconf==2.3.0 10 | scipy==1.10.0 11 | setproctitle==1.1.10 12 | spherical_kde==0.1.2 13 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import jax.numpy as np 4 | from math import pi 5 | import jax 6 | 7 | def spherical_to_euclidean(theta_phi): 8 | single= theta_phi.ndim == 1 9 | if single: 10 | theta_phi = np.expand_dims(theta_phi, 0) 11 | theta, phi = np.split(theta_phi, 2, 1) 12 | return np.concatenate(( 13 | np.sin(phi) * np.cos(theta), 14 | np.sin(phi) * np.sin(theta), 15 | np.cos(phi) 16 | ), 1) 17 | 18 | 19 | def euclidean_to_spherical(xyz): 20 | single = xyz.ndim == 1 21 | if single: 22 | xyz = np.expand_dims(xyz, 0) 23 | x, y, z = np.split(xyz, 3, 1) 24 | return np.concatenate(( 25 | np.arctan2(y, x), 26 | np.arccos(z) 27 | ), 1) 28 | 29 | def S1euclideantospherical(euc_coords): 30 | return np.arctan2(euc_coords[:,1], euc_coords[:,0]) 31 | 32 | def productS1toTorus(theta1, theta2): 33 | R = 1 34 | r = 0.3 35 | 36 | x = (R + r * np.cos(theta1))*np.cos(theta2) 37 | y = (R + r * np.cos(theta1))*np.sin(theta2) 38 | z = r * np.sin(theta1) 39 | return x,y,z 40 | --------------------------------------------------------------------------------
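The two coordinate helpers in utils.py (`spherical_to_euclidean` and `euclidean_to_spherical`) are meant to be inverses of each other. The sketch below is a scalar, standard-library re-implementation written only for illustration; it is not the repository's JAX code. The `roundtrip` helper and the clamp before `acos` are assumptions added here. It shows that the polar angle is recovered exactly, while the azimuth comes back in the range of `atan2`, so an input azimuth above pi returns shifted by 2*pi.

```python
import math


def spherical_to_euclidean(theta, phi):
    """Map (azimuth theta, polar angle phi) to a unit vector (x, y, z)."""
    return (
        math.sin(phi) * math.cos(theta),
        math.sin(phi) * math.sin(theta),
        math.cos(phi),
    )


def euclidean_to_spherical(x, y, z):
    """Inverse map: theta is returned in [-pi, pi], phi in [0, pi]."""
    # Clamp is an addition in this sketch; it guards acos against rounding.
    z = max(-1.0, min(1.0, z))
    return math.atan2(y, x), math.acos(z)


def roundtrip(theta, phi):
    """Hypothetical helper: angles -> unit vector -> angles."""
    return euclidean_to_spherical(*spherical_to_euclidean(theta, phi))


# Azimuth below pi: both angles are recovered unchanged.
theta_a, phi_a = roundtrip(1.0, 0.5)

# Azimuth above pi: theta comes back as 4.0 - 2*pi, phi is unchanged.
theta_b, phi_b = roundtrip(4.0, 0.5)
```

This wrap-around matters when comparing converted angles against a grid defined on [0, 2*pi], as the plotting grids in this repository are: the two agree only modulo 2*pi.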