├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── afhq_preprocess.py ├── afhq_sample.py ├── assets ├── afhq.png ├── crowd-nav.gif └── lidar.png ├── configs ├── experiment │ ├── afhq.yaml │ ├── gmm.yaml │ ├── lidar.yaml │ ├── opinion.yaml │ ├── opinion_2d.yaml │ ├── stunnel.yaml │ ├── stunnel_eam.yaml │ └── vneck.yaml └── train.yaml ├── environment.yml ├── gsbm ├── dataset.py ├── ema.py ├── evaluator.py ├── experimental.py ├── gaussian_path.py ├── interp1d.py ├── match_loss.py ├── network.py ├── nn.py ├── opinion.py ├── path_integral.py ├── pl_model.py ├── plotting.py ├── sde.py ├── state_cost.py ├── unet.py ├── utils.py └── vae.py ├── notebooks ├── afhq_sample.ipynb └── example_CondSOC.ipynb ├── scripts └── train.sh ├── setup.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # 163 | data/* 164 | outputs/* -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 
45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to generalized-schrodinger-bridge-matching 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://bugbounty.meta.com/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to generalized-schrodinger-bridge-matching, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree.
-------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. 
More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. 
For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. 
Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. 
For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Generalized Schrödinger Bridge Matching (GSBM)

2 |
3 | Guan-Horng Liu¹ ·  4 | Yaron Lipman²,³ ·  5 | Maximilian Nickel³ 6 |
7 | Brian Karrer³ ·  8 | Evangelos A. Theodorou¹ ·  9 | Ricky T. Q. Chen³
10 | ¹Georgia Tech   ²Weizmann Institute of Science   ³FAIR, Meta
11 |
12 | 13 |
14 | 15 |
16 | 17 | [![arxiv](https://img.shields.io/badge/ICLR2024-blue)](https://arxiv.org/abs/2310.02233) 18 | [![twitter](https://img.shields.io/badge/twitter-thread-green)](https://x.com/guanhorng_liu/status/1709983646558404913) 19 | 20 |
21 | 22 | [Generalized Schrödinger Bridge Matching](https://arxiv.org/abs/2310.02233) (**GSBM**) is a new matching algorithm 23 | for learning diffusion models between two distributions with task-specific optimality structures. 24 | Examples of such task-specific structures include mean-field interactions in population propagation (_1st and 2nd figures_), a geometric prior given by a LiDAR manifold (_3rd figure_), and latent-guided unpaired image translation (_right figure_). 25 | 26 | 27 |
28 | 29 |
30 | 31 | 32 | ## Installation 33 | ```bash 34 | conda env create -f environment.yml 35 | pip install -e . 36 | ``` 37 | 38 | ## Crowd navigation & opinion depolarization 39 | ```bash 40 | python train.py experiment=$EXP seed=0,1,2,3,4 -m 41 | ``` 42 | where `EXP` is one of the settings in `configs/experiment/*.yaml`. Commands that generate results similar to those shown in our paper can be found in [`scripts/train.sh`](https://github.com/facebookresearch/gsbm/blob/main/scripts/train.sh). By default, checkpoints and figures are saved under the folder `outputs`. 43 | 44 | ## Unsupervised image translation 45 | 46 | ### Download dataset 47 | Download the official AFHQ dataset from [stargan-v2](https://github.com/clovaai/stargan-v2#animal-faces-hq-dataset-afhq), then preprocess the images with 48 | ```bash 49 | python afhq_preprocess.py --dir $DIR_AFHQ 50 | ``` 51 | where `DIR_AFHQ` is the path to the AFHQ dataset (_e.g._, `../stargan-v2/data/afhq`). 52 | 53 | Download the [LiDAR data](https://rtqichen.com/data/rainier2-thin.las) and place it in the `data` folder. 54 | 55 | All downloaded files will be stored under the folder `data`. 56 | 57 | ### Sampling from a trained model 58 | 59 | See 60 | [`notebooks/afhq_sample.ipynb`](https://github.com/facebookresearch/gsbm/blob/main/notebooks/afhq_sample.ipynb). 61 | 62 | ### Training from scratch 63 | We train GSBM on 4 nodes, each with eight 32GB V100 GPUs. 64 | ```bash 65 | python train.py experiment=afhq nnodes=4 -m 66 | ``` 67 | 68 | To sample from a checkpoint `$CKPT` saved under `outputs/multiruns/afhq/$CKPT`, run 69 | ```bash 70 | python afhq_sample.py --ckpt $CKPT --transfer $TRNSF \ 71 | [--nfe $NFE] [--batch $BATCH] 72 | ``` 73 | where `TRNSF` can be either `cat2dog` or `dog2cat`. By default, we set `NFE=1000` and `BATCH=512`. To optionally parallelize the sampling across multiple devices, add `--partition 0_4` so that the dataset is partitioned into 4 subsets (indices 0, 1, 2, 3) and only the first partition (index 0) is run. Similarly, `--partition 1_4` runs the second partition, and so on. The reconstructed images will be saved under the parent of `outputs/multiruns/afhq/$CKPT`, in folders named `samples` and `trajs`. 74 | 75 | ## Implementation 76 | 77 | GSBM alternately solves the Conditional Stochastic Optimal Control (**CondSOC**) problem and the resulting marginal **Matching** problem. We implement GSBM on top of PyTorch Lightning with the [following configurations](https://github.com/facebookresearch/gsbm/blob/main/train.py#L120-L123): 78 | 79 | - We solve **CondSOC** and **Matching** in the validation and training epochs, respectively. `pl.Trainer` is instantiated with `num_sanity_val_steps=-1` and `check_val_every_n_epoch=1` so that a validation epoch is executed before the initial training epoch and after each subsequent training epoch. 80 | - The results of **CondSOC** are gathered in [`validation_epoch_end`](https://github.com/facebookresearch/gsbm/blob/main/gsbm/pl_model.py#L353) and stored as `train_data`, which is then used to initialize [`train_dataloader`](https://github.com/facebookresearch/gsbm/blob/main/gsbm/pl_model.py#L399-L407). We set `reload_dataloaders_every_n_epochs=1` to refresh `train_dataloader` with the latest **CondSOC** results. 81 | - For multi-GPU training, we distribute the **CondSOC** optimization across devices by setting `replace_sampler_ddp=False` and then instantiating [`val_dataloader`](https://github.com/facebookresearch/gsbm/blob/main/gsbm/pl_model.py#L259-L280) on each device with a different seed.
- The training direction (forward or backward) is alternated in [`training_epoch_end`](https://github.com/facebookresearch/gsbm/blob/main/gsbm/pl_model.py#L254), which is called _after_ the validation epoch.
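Put together, these flags correspond to a `pl.Trainer` configured roughly as follows (a minimal sketch of the settings described in the list above, assuming the PyTorch Lightning 1.8.x API pinned in `environment.yml`; the actual instantiation lives in `train.py`):

```python
import pytorch_lightning as pl

# Sketch of the Trainer flags described above -- not the verbatim code in train.py.
trainer = pl.Trainer(
    num_sanity_val_steps=-1,              # full validation (CondSOC) pass before training starts
    check_val_every_n_epoch=1,            # run CondSOC after every Matching epoch
    reload_dataloaders_every_n_epochs=1,  # rebuild train_dataloader from the latest CondSOC couplings
    replace_sampler_ddp=False,            # per-device val_dataloader seeding distributes CondSOC
)
```

With this configuration, every **Matching** (training) epoch consumes the couplings produced by the validation epoch that immediately precedes it.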
83 | 84 | The overall procedure is as follows: 85 | 86 | > [validation epoch (sanity)] **CondSOC** with random coupling \ 87 | → [training epoch #0] **Matching** forward drift \ 88 | → [validation epoch #0] **CondSOC** given forward model coupling \ 89 | → [training epoch #1] **Matching** backward drift \ 90 | → [validation epoch #1] **CondSOC** given backward model coupling \ 91 | → [training epoch #2] **Matching** forward drift \ 92 | → ... 93 | 94 | If you wish to implement GSBM for your own distribution matching tasks, we recommend fine-tuning the **CondSOC** optimization independently as in [`notebooks/example_CondSOC.ipynb`](https://github.com/facebookresearch/gsbm/blob/main/notebooks/example_CondSOC.ipynb). Once you are happy with the **CondSOC** results, you can seamlessly integrate it into the main GSBM algorithm. 95 | 96 | ## Citation 97 | If you find this repository helpful for your publications, 98 | please consider citing our paper: 99 | ``` 100 | @inproceedings{liu2024gsbm, 101 | title={{Generalized Schr{\"o}dinger bridge matching}}, 102 | author={Liu, Guan-Horng and Lipman, Yaron and Nickel, Maximilian and Karrer, Brian and Theodorou, Evangelos A and Chen, Ricky TQ}, 103 | booktitle={International Conference on Learning Representations}, 104 | year={2024} 105 | } 106 | ``` 107 | 108 | ## License 109 | The majority of `generalized-schrodinger-bridge-matching` is licensed under [CC BY-NC](LICENSE.md); however, portions of the project are adapted from other sources and are under separate license terms: files from https://github.com/ghliu/deepgsb are licensed under the Apache 2.0 license, and files from https://github.com/openai/guided-diffusion are licensed under the MIT license. 110 | -------------------------------------------------------------------------------- /afhq_preprocess.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc.
and affiliates.""" 2 | 3 | import os 4 | import math 5 | import glob 6 | from PIL import Image 7 | import numpy as np 8 | import argparse 9 | from pathlib import Path 10 | import logging 11 | 12 | import torch 13 | from torch.utils.data import Dataset 14 | import torchvision.transforms as transforms 15 | 16 | logging.basicConfig(level=logging.INFO) 17 | log = logging.getLogger("[afhq preprocessing]") 18 | 19 | 20 | class NumpyFolder(Dataset): 21 | def __init__(self, folder, resize): 22 | files_grabbed = [] 23 | for tt in ("*.png", "*.jpg"): 24 | files_grabbed.extend(glob.glob(os.path.join(folder, tt))) 25 | files_grabbed = sorted(files_grabbed) 26 | 27 | self.img_paths = files_grabbed 28 | 29 | self.transform = transforms.Compose( 30 | [ 31 | transforms.Resize(resize), 32 | ] 33 | ) 34 | 35 | def __len__(self): 36 | return len(self.img_paths) 37 | 38 | def __getitem__(self, idx): 39 | image_set = Image.open(self.img_paths[idx]) 40 | image_set = self.transform(image_set) 41 | image_tensor = np.asarray(image_set) 42 | return image_tensor 43 | 44 | 45 | def main(opt): 46 | for split in ("train", "val"): 47 | for ani in ("cat", "dog", "wild"): 48 | log.info(f"Extracting {split=}, {ani=}, resolution={opt.resolution} ...") 49 | 50 | dataset = NumpyFolder(opt.dir / f"{split}/{ani}", resize=opt.resolution) 51 | np_imgs = np.stack([dataset[ii] for ii in range(len(dataset))], axis=0) 52 | 53 | os.makedirs(opt.save, exist_ok=True) 54 | fn = opt.save / f"afhq{opt.resolution}-{split}-{ani}.npz" 55 | np.savez(fn, data=np_imgs) 56 | log.info(f"Saved in {fn=}!") 57 | 58 | 59 | if __name__ == "__main__": 60 | """ 61 | # PATH_TO_AFHQ="../stargan-v2/data/afhq" 62 | python afhq_preprocess.py --dir $PATH_TO_AFHQ 63 | """ 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("--dir", type=Path, default=None) 66 | parser.add_argument("--save", type=Path, default="data/") 67 | parser.add_argument("--resolution", type=int, default=64) 68 | 69 | opt = parser.parse_args() 70 | main(opt) 71 | -------------------------------------------------------------------------------- /afhq_sample.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | import os 4 | import copy 5 | import argparse 6 | import random 7 | from pathlib import Path 8 | from easydict import EasyDict as edict 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import torch.distributed as dist 14 | from torch.multiprocessing import Process 15 | from torch.utils.data import DataLoader, Subset 16 | import torchvision.utils as tu 17 | 18 | from gsbm.utils import restore_model 19 | from gsbm.dataset import get_dist_boundary 20 | 21 | import PIL.Image 22 | 23 | import colored_traceback.always 24 | from ipdb import set_trace as debug 25 | 26 | import pytorch_lightning as pl 27 | 28 | BASE_DIR = Path("outputs/multiruns/afhq") 29 | 30 | T = 20 # log_steps 31 | 32 | 33 | def main(opt, log): 34 | 35 | log(opt) 36 | 37 | ## Load model 38 | ckpt = BASE_DIR / opt.ckpt 39 | model, cfg = restore_model(ckpt, device=opt.device) 40 | model.eval() 41 | 42 | ## Build dataset 43 | dataset, start_idx, end_idx = build_data(opt, cfg) 44 | log(f"[Dataset] {opt.transfer=}, {opt.split=}, total size={len(dataset)}!") 45 | log( 46 | f"[Dataset] Built partition={opt.partition}, {start_idx=}, {end_idx=}! Now size={end_idx-start_idx}!" 
47 | ) 48 | 49 | ## Sample Setup 50 | pl.utilities.seed.seed_everything(opt.seed) 51 | 52 | exp_cfg = f"{ckpt.name[:-5]}_nfe{opt.nfe}_{opt.transfer}_{opt.split}" # remove ".ckpt" at the end 53 | sample_dir = BASE_DIR / ckpt.parent.parent / "samples" / exp_cfg 54 | traj_dir = BASE_DIR / ckpt.parent.parent / "trajs" / exp_cfg 55 | os.makedirs(sample_dir, exist_ok=True) 56 | os.makedirs(traj_dir, exist_ok=True) 57 | log(f"[Sample] Samples will be stored in {sample_dir}!") 58 | 59 | ## Sample 60 | trajs = [] 61 | num = 0 62 | sample_idx = torch.arange(start_idx, end_idx) 63 | for idx, batch_idx in enumerate(sample_idx.split(opt.batch)): 64 | image = torch.stack([dataset[i] for i in batch_idx], dim=0) 65 | xinit = image.reshape(len(batch_idx), -1).to(opt.device) 66 | B, D = xinit.shape 67 | 68 | direction = get_direction(opt) 69 | log("[Sample] Sampling ...") 70 | output = model.sample( 71 | xinit, 72 | log_steps=T, 73 | direction=direction, 74 | nfe=opt.nfe, 75 | verbose=opt.verbose and idx == 0, 76 | ) 77 | xs = output["xs"].detach().cpu() 78 | assert xs.shape == (B, T, D) 79 | 80 | ## Save images 81 | gen_image = xs[:, (-1 if direction == "fwd" else 0)].reshape_as(image) 82 | gen_image_np = ( 83 | (gen_image * 127.5 + 128) 84 | .clip(0, 255) 85 | .to(torch.uint8) 86 | .permute(0, 2, 3, 1) 87 | .cpu() 88 | .numpy() 89 | ) 90 | for idx, image_np in zip(batch_idx, gen_image_np): 91 | image_path = sample_dir / f"{idx:04d}.png" 92 | PIL.Image.fromarray(image_np, "RGB").save(image_path) 93 | 94 | trajs.append(xs) 95 | num += B 96 | log(f"Collected {num} images!") 97 | 98 | ## Save trajs 99 | all_trajs = torch.cat(trajs, axis=0) 100 | traj_path = traj_dir / f"p{opt.partition}.pt" 101 | torch.save(all_trajs, traj_path) 102 | log("Done!") 103 | 104 | 105 | def get_direction(opt): 106 | if opt.transfer == "cat2dog": 107 | return "fwd" 108 | elif opt.transfer == "dog2cat": 109 | return "bwd" 110 | else: 111 | raise ValueError(f"Unknown transfer: {opt.transfer}") 112 | 113 | 114 | def get_init_p(opt, cfg): 115 | p0, p1, p0_val, p1_val = get_dist_boundary(cfg) 116 | if opt.transfer[:3] == "cat": 117 | p = p0 if opt.split == "train" else p0_val 118 | elif opt.transfer[:3] == "dog": 119 | p = p1 if opt.split == "train" else p1_val 120 | else: 121 | raise ValueError(f"Unknown transfer: {opt.transfer}") 122 | return p 123 | 124 | 125 | def build_partition(opt, full_dataset): 126 | n_samples = len(full_dataset) 127 | 128 | part_idx, n_part = [int(s) for s in opt.partition.split("_")] 129 | assert part_idx < n_part and part_idx >= 0 130 | # assert n_samples % n_part == 0 131 | 132 | n_samples_per_part = n_samples // n_part 133 | start_idx = part_idx * n_samples_per_part 134 | end_idx = (part_idx + 1) * n_samples_per_part 135 | 136 | if part_idx == (n_part - 1): 137 | end_idx = n_samples 138 | 139 | return start_idx, end_idx 140 | 141 | 142 | def build_data(opt, cfg): 143 | pinit = get_init_p(opt, cfg) 144 | dataset = pinit.dataset 145 | start_idx, end_idx = build_partition(opt, dataset) 146 | return dataset, start_idx, end_idx 147 | 148 | 149 | if __name__ == "__main__": 150 | """ 151 | python afhq_sample.py --ckpt 2023.09.09/163416/0/checkpoints/epoch-015_step-400000.ckpt \ 152 | --transfer cat2dog --partition 0_4 --nfe 1000 153 | """ 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument("--seed", type=int, default=0) 156 | parser.add_argument("--device", type=str, default="cuda") 157 | parser.add_argument("--nfe", type=int, default=1000) 158 | parser.add_argument("--ckpt", type=Path, default=None) 159 | parser.add_argument( 160 | "--transfer", type=str,
default=None, choices=["cat2dog", "dog2cat"] 161 | ) 162 | parser.add_argument("--split", type=str, default="train", choices=["train", "val"]) 163 | parser.add_argument("--partition", type=str, default="0_1") 164 | parser.add_argument("--batch", type=int, default=512) 165 | parser.add_argument("--verbose", action="store_true") 166 | 167 | opt = parser.parse_args() 168 | 169 | assert opt.nfe > T 170 | 171 | def do_nothing(*arg): 172 | return 173 | 174 | log = print if opt.verbose else do_nothing 175 | main(opt, log) 176 | -------------------------------------------------------------------------------- /assets/afhq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/generalized-schrodinger-bridge-matching/a8ab5b500dcea8b1c0df84822188f69758a08442/assets/afhq.png -------------------------------------------------------------------------------- /assets/crowd-nav.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/generalized-schrodinger-bridge-matching/a8ab5b500dcea8b1c0df84822188f69758a08442/assets/crowd-nav.gif -------------------------------------------------------------------------------- /assets/lidar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/generalized-schrodinger-bridge-matching/a8ab5b500dcea8b1c0df84822188f69758a08442/assets/lidar.png -------------------------------------------------------------------------------- /configs/experiment/afhq.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: afhq 4 | dim: 12288 5 | image_size: [3,64,64] 6 | 7 | lim: 15 8 | 9 | optim: 10 | max_epochs: 100 11 | batch_size: 16 12 | lr: 1e-4 13 | 14 | ### problem 15 | prob: 16 | name: ${name} 17 | sigma: 0.5 18 | p0: 19 | name: afhq 20 | resize: ${image_size[2]} 21 | animals: [cat] 22 | p1: 23 | name: afhq 24 | resize: ${image_size[2]} 25 | animals: [dog] 26 | 27 | ### network 28 | field: vector 29 | unet: 30 | num_head_channels: 64 31 | image_size: 64 32 | in_channels: 3 33 | model_channels: 128 # optim.batch_size=64 34 | # model_channels: 160 # optim.batch_size=64 35 | # model_channels: 192 # optim.batch_size=32 36 | out_channels: ${unet.in_channels} 37 | num_res_blocks: 4 # 3 38 | resblock_updown: true 39 | use_new_attention_order: true 40 | use_scale_shift_norm: true 41 | attention_resolutions: 42 | - 32 43 | - 16 44 | - 8 45 | dropout: 0.1 46 | channel_mult: 47 | - 1 48 | - 2 49 | - 2 #3 50 | - 2 #4 51 | num_heads: 4 52 | use_checkpoint: false #true 53 | with_fourier_features: false 54 | 55 | ### gsbm matching (Alg 1) 56 | matching: 57 | loss: bm 58 | 59 | ### gsbm conditional SOC (Alg 3 & 4) 60 | csoc: 61 | name: ${name} 62 | 63 | ## train dataloader (B * epd_fct = data size) 64 | B: 25600 # number of couplings 65 | epd_fct: 500 # times each coupling appears in each epoch 66 | 67 | ## spline param 68 | T_mean: 8 # number of knots mean spline 69 | T_gamma: 8 # number of knots gamma spline 70 | 71 | ## spline optim 72 | optim: adam # optimizer {sgd, adam} 73 | S: 30 # number of timesteps 74 | N: 4 # number of trajs per couplings 75 | lr_mean: 0.01 # lr of mean spline 76 | lr_gamma: 0.03 # lr of gamma spline 77 | nitr: 100 # optim steps 78 | mB: 64 # micro batch size FIXME tune base on spline 79 | scale_by_sigma: true # scale control by sigma (equiv to KL) 80 | weight_c: 0.2 
81 | weight_s: 5.0 82 | 83 | ## impt weight 84 | IW: false 85 | 86 | vae: 87 | ckpt: ./data/vae.ckpt 88 | image_size: ${image_size} 89 | 90 | plot: 91 | name: ${name} 92 | 93 | hydra: 94 | launcher: 95 | gpus_per_node: 8 96 | constraint: volta32gb 97 | mem_per_gpu: 64gb 98 | -------------------------------------------------------------------------------- /configs/experiment/gmm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: gmm 4 | dim: 2 5 | 6 | ### problem 7 | prob: 8 | name: ${name} 9 | sigma: 1 10 | p0: 11 | name: gmm 12 | # radius: 4 13 | # num: 4 14 | mu: [[4., 0.], [0., 4.], [-4., 0.,], [0., -4.]] 15 | var: 1.0 16 | p1: 17 | name: gmm 18 | # radius: 16 19 | # num: 8 20 | mu: [[16, 0], [11.31, 11.31], [0, 16,], [-11.31, 11.31], [-16, 0], [-11.31, -11.31], [0, -16], [11.31, -11.31]] 21 | var: 1.0 22 | 23 | ### network 24 | field: vector 25 | net: toy 26 | 27 | ### gsbm matching (Alg 1) 28 | matching: 29 | loss: bm 30 | 31 | ### gsbm conditional SOC (Alg 3 & 4) 32 | csoc: 33 | name: ${name} 34 | 35 | ## train dataloader (B * epd_fct = data size) 36 | B: 5120 # number of couplings 37 | epd_fct: 500 # times each coupling appears in each epoch 38 | 39 | ## spline param 40 | T_mean: 15 # number of knots mean spline 41 | T_gamma: 30 # number of knots gamma spline 42 | 43 | ## spline optim 44 | optim: sgd # optimizer {sgd, adam} 45 | S: 100 # number of timesteps 46 | N: 4 # number of trajs per couplings 47 | lr_mean: 0.4 # lr of mean spline 48 | lr_gamma: 0.2 # lr of gamma spline 49 | nitr: 2000 # optim steps 50 | mB: 256 # micro batch size 51 | momentum: 0.0 # mSGD 52 | scale_by_sigma: true # scale control by sigma (equiv to KL) 53 | 54 | ## impt weight 55 | IW: false 56 | IW_N: ${csoc.epd_fct} 57 | IW_S: 300 58 | 59 | state_cost: 60 | type: [obs, ent] 61 | obs: 1500. # obstacle cost 62 | ent: 5. # entropy interaction cost 63 | cgst: 0. 
# congestion interaction cost 64 | 65 | plot: 66 | name: ${name} 67 | lim: 18 68 | -------------------------------------------------------------------------------- /configs/experiment/lidar.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: lidar 4 | dim: 3 5 | 6 | optim: 7 | max_epochs: 10 8 | 9 | ### problem 10 | prob: 11 | name: ${name} 12 | sigma: 1.0 13 | p0: 14 | name: lidarproj 15 | mu: [[-4, -2, 0.5], [-3.75, -1.125, 0.5], [-3.5, -0.25, 0.5], [-3.25, 0.675, 0.5], [-3, 1.5, 0.5]] 16 | var: 0.02 17 | lcfg: ${lidar} 18 | p1: 19 | name: lidarproj 20 | mu: [[2, -2, 0.5], [2.6, -1.25, 0.5], [3.2, -0.5, 0.5]] 21 | var: 0.03 22 | lcfg: ${lidar} 23 | 24 | ### network 25 | field: vector 26 | net: toy 27 | 28 | ### gsbm matching (Alg 1) 29 | matching: 30 | loss: bm 31 | 32 | ### gsbm conditional SOC (Alg 3 & 4) 33 | csoc: 34 | name: ${name} 35 | 36 | ## train dataloader (B * epd_fct = data size) 37 | B: 2560 # number of couplings 38 | epd_fct: 1000 # times each coupling appears in each epoch 39 | 40 | ## spline param 41 | T_mean: 30 # number of knots mean spline 42 | T_gamma: 30 # number of knots gamma spline 43 | 44 | ## spline optim 45 | optim: sgd # optimizer {sgd, adam} 46 | S: 100 # number of timesteps 47 | N: 4 # number of trajs per couplings 48 | lr_mean: 0.03 # lr of mean spline 49 | lr_gamma: 0.03 # lr of gamma spline 50 | nitr: 200 # optim steps 51 | mB: 256 # micro batch size 52 | momentum: 0.5 # mSGD 53 | scale_by_sigma: true # scale control by sigma (equiv to KL) 54 | 55 | ## impt weight 56 | IW: false 57 | IW_N: ${csoc.epd_fct} 58 | IW_S: 300 59 | 60 | ### LiDAR state cost 61 | lidar: 62 | lim: 5 63 | filename: ./data/rainier2-thin.las 64 | k: 20 65 | closeness_weight: 5000 66 | boundary_weight: 5000 67 | height_weight: 5000 68 | 69 | plot: 70 | name: ${name} 71 | lim: ${lidar.lim} 72 | -------------------------------------------------------------------------------- /configs/experiment/opinion.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: opinion 4 | dim: 1000 5 | 6 | optim: 7 | batch_size: 256 8 | val_batch_size: 256 9 | 10 | ### problem 11 | prob: 12 | name: ${name} 13 | sigma: 0.5 14 | p0: 15 | name: opinion 16 | dim: ${dim} 17 | mu: 0.0 18 | var: 0.25 19 | var_1st_dim: 4.0 20 | p1: 21 | name: opinion 22 | dim: ${dim} 23 | mu: 0.0 24 | var: 4.0 25 | 26 | ### polarize drift 27 | pdrift: 28 | D: ${dim} 29 | S: 500 # interval in deepgsb 30 | strength: 6.0 31 | m_coeff: 8.0 32 | 33 | ### network 34 | field: vector 35 | net: opinion 36 | 37 | ### gsbm matching (Alg 1) 38 | matching: 39 | loss: bm 40 | 41 | ### gsbm conditional SOC (Alg 3 & 4) 42 | csoc: 43 | name: ${name} 44 | 45 | ## train dataloader (B * epd_fct = data size) 46 | B: 10240 # number of couplings 47 | epd_fct: 1000 # times each coupling appears in each epoch 48 | 49 | ## spline param 50 | T_mean: 30 # number of knots mean spline 51 | T_gamma: 30 # number of knots gamma spline 52 | 53 | ## spline optim 54 | optim: sgd # optimizer {sgd, adam} 55 | S: 100 # number of timesteps 56 | N: 4 # number of trajs per couplings 57 | lr_mean: 0.03 # lr of mean spline 58 | lr_gamma: 0.02 # lr of gamma spline 59 | nitr: 700 # optim steps 60 | mB: 256 # micro batch size 61 | momentum: 0.0 # mSGD 62 | scale_by_sigma: true # scale control by sigma (equiv to KL) 63 | 64 | ## impt weight 65 | IW: false 66 | 67 | state_cost: 68 | type: [cgst] 69 | ent: 0. 70 | cgst: 10. 
71 | 72 | nfe: 300 73 | 74 | plot: 75 | name: ${name} 76 | lim: 10 77 | 78 | hydra: 79 | launcher: 80 | gpus_per_node: 8 81 | constraint: volta32gb 82 | -------------------------------------------------------------------------------- /configs/experiment/opinion_2d.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: opinion 4 | dim: 2 5 | 6 | optim: 7 | max_epochs: 4 8 | batch_size: 1024 9 | val_batch_size: 1024 10 | 11 | ### problem 12 | prob: 13 | name: ${name} 14 | sigma: 0.5 15 | p0: 16 | name: opinion 17 | dim: ${dim} 18 | mu: 0.0 19 | var: 0.25 20 | var_1st_dim: 0.5 21 | p1: 22 | name: opinion 23 | dim: ${dim} 24 | mu: 0.0 25 | var: 3.0 26 | 27 | ### polarize drift 28 | pdrift: 29 | D: ${dim} 30 | S: 500 # interval in deepgsb 31 | strength: 3.0 32 | m_coeff: 8.0 33 | 34 | ### network 35 | field: vector 36 | net: opinion 37 | 38 | ### gsbm matching (Alg 1) 39 | matching: 40 | loss: bm 41 | 42 | ### gsbm conditional SOC (Alg 3 & 4) 43 | csoc: 44 | name: ${name} 45 | 46 | ## train dataloader (B * epd_fct = data size) 47 | B: 5120 # number of couplings 48 | epd_fct: 100 # times each coupling appears in each epoch 49 | 50 | ## spline param 51 | T_mean: 30 # number of knots mean spline 52 | T_gamma: 30 # number of knots gamma spline 53 | 54 | ## spline optim 55 | optim: sgd # optimizer {sgd, adam} 56 | S: 100 # number of timesteps 57 | N: 4 # number of trajs per couplings 58 | lr_mean: 0.03 # lr of mean spline 59 | lr_gamma: 0.02 # lr of gamma spline 60 | nitr: 1000 # optim steps 61 | mB: 1024 # micro batch size 62 | momentum: 0.0 # mSGD 63 | scale_by_sigma: true # scale control by sigma (equiv to KL) 64 | 65 | ## impt weight 66 | IW: false 67 | 68 | state_cost: 69 | type: [cgst] 70 | ent: 0. 71 | cgst: 10. 72 | 73 | nfe: 300 74 | 75 | plot: 76 | name: ${name} 77 | lim: 10 78 | -------------------------------------------------------------------------------- /configs/experiment/stunnel.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: stunnel 4 | dim: 2 5 | 6 | ### problem 7 | prob: 8 | name: ${name} 9 | sigma: 1.0 10 | p0: 11 | name: gaussian 12 | mu: [-11, -1] 13 | var: 0.5 14 | p1: 15 | name: gaussian 16 | mu: [11, 1] 17 | var: 0.5 18 | 19 | ### network 20 | field: vector 21 | net: toy 22 | 23 | ### gsbm matching (Alg 1) 24 | matching: 25 | loss: bm 26 | 27 | ### gsbm conditional SOC (Alg 3 & 4) 28 | csoc: 29 | name: ${name} 30 | 31 | ## train dataloader (B * epd_fct = data size) 32 | B: 2560 # number of couplings 33 | epd_fct: 100 # times each coupling appears in each epoch 34 | 35 | ## spline param 36 | T_mean: 15 # number of knots mean spline 37 | T_gamma: 30 # number of knots gamma spline 38 | 39 | ## spline optim 40 | optim: sgd # optimizer {sgd, adam} 41 | S: 100 # number of timesteps 42 | N: 4 # number of trajs per couplings 43 | lr_mean: 0.2 # lr of mean spline 44 | lr_gamma: 0.1 # lr of gamma spline 45 | nitr: 1000 # optim steps 46 | mB: 1024 # micro batch size 47 | momentum: 0.0 # mSGD 48 | scale_by_sigma: true # scale control by sigma (equiv to KL) 49 | 50 | ## impt weight 51 | IW: false 52 | IW_N: ${csoc.epd_fct} 53 | IW_S: 300 54 | 55 | state_cost: 56 | type: [obs, cgst] 57 | obs: 1500. # obstacle cost 58 | ent: 0. # entropy interaction cost 59 | cgst: 50. 
# congestion interaction cost 60 | 61 | plot: 62 | name: ${name} 63 | lim: 15 64 | -------------------------------------------------------------------------------- /configs/experiment/stunnel_eam.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: stunnel 4 | dim: 2 5 | 6 | optim: 7 | max_epochs: 20 8 | 9 | ### problem 10 | prob: 11 | name: ${name} 12 | sigma: 1.0 13 | p0: 14 | name: gaussian 15 | mu: [-11, -1] 16 | var: 0.5 17 | p1: 18 | name: gaussian 19 | mu: [11, 1] 20 | var: 0.5 21 | 22 | ### network 23 | field: potential 24 | net: toy 25 | 26 | ### gsbm matching (Alg 2) 27 | matching: 28 | loss: eam 29 | lap: approx # {exact, approx} 30 | batch_t: 100 31 | 32 | ### gsbm conditional SOC (Alg 3 & 4) 33 | csoc: 34 | name: ${name} 35 | 36 | ## train dataloader (B * epd_fct = data size) 37 | B: 2560 # number of couplings 38 | epd_fct: 100 # times each coupling appears in each epoch 39 | 40 | ## spline param 41 | T_mean: 15 # number of knots mean spline 42 | T_gamma: 30 # number of knots gamma spline 43 | 44 | ## spline optim 45 | optim: sgd # optimizer {sgd, adam} 46 | S: 100 # number of timesteps 47 | N: 4 # number of trajs per couplings 48 | lr_mean: 0.2 # lr of mean spline 49 | lr_gamma: 0.1 # lr of gamma spline 50 | nitr: 1000 # optim steps 51 | mB: 1024 # micro batch size 52 | momentum: 0.0 # mSGD 53 | scale_by_sigma: true # scale control by sigma (equiv to KL) 54 | 55 | ## impt weight 56 | IW: false 57 | IW_N: ${csoc.epd_fct} 58 | IW_S: 300 59 | 60 | 61 | state_cost: 62 | type: [obs, cgst] 63 | obs: 1500. # obstacle cost 64 | ent: 0. # entropy interaction cost 65 | cgst: 50. # congestion interaction cost 66 | 67 | plot: 68 | name: ${name} 69 | lim: 15 70 | -------------------------------------------------------------------------------- /configs/experiment/vneck.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: vneck 4 | dim: 2 5 | 6 | ### problem 7 | prob: 8 | name: ${name} 9 | sigma: 1.0 10 | p0: 11 | name: gaussian 12 | mu: [-7, 0] 13 | var: 0.2 14 | p1: 15 | name: gaussian 16 | mu: [7, 0] 17 | var: 0.2 18 | 19 | ### network 20 | field: vector 21 | net: toy 22 | 23 | ### gsbm matching (Alg 1) 24 | matching: 25 | loss: bm 26 | 27 | ### gsbm conditional SOC (Alg 3 & 4) 28 | csoc: 29 | name: ${name} 30 | 31 | ## train dataloader (B * epd_fct = data size) 32 | B: 2560 # number of couplings 33 | epd_fct: 100 # times each coupling appears in each epoch 34 | 35 | ## spline param 36 | T_mean: 15 # number of knots mean spline 37 | T_gamma: 30 # number of knots gamma spline 38 | 39 | ## spline optim 40 | optim: sgd # optimizer {sgd, adam} 41 | S: 100 # number of timesteps 42 | N: 4 # number of trajs per couplings 43 | lr_mean: 0.3 # lr of mean spline 44 | lr_gamma: 0.2 # lr of gamma spline 45 | nitr: 3000 # optim steps 46 | mB: 256 # micro batch size 47 | momentum: 0.0 # mSGD 48 | scale_by_sigma: true # scale control by sigma (equiv to KL) 49 | 50 | ## impt weight 51 | IW: false 52 | IW_N: ${csoc.epd_fct} 53 | IW_S: 300 54 | 55 | state_cost: 56 | type: [obs, ent] 57 | obs: 3000. # obstacle cost 58 | ent: 8. # entropy interaction cost 59 | cgst: 0. 
# congestion interaction cost 60 | 61 | plot: 62 | name: ${name} 63 | lim: 10 64 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - experiment: null 4 | - override hydra/launcher: submitit_slurm 5 | - override hydra/job_logging: colorlog 6 | - override hydra/hydra_logging: colorlog 7 | 8 | optim: 9 | max_epochs: 30 10 | num_workers: 2 11 | batch_size: 64 12 | val_batch_size: 1024 13 | num_iterations: 2000000 14 | 15 | scheduler: cosine 16 | grad_clip: 0.0 17 | lr: 1e-3 18 | wd: 0.0 19 | eps: 1e-8 20 | ema_decay: 0.999 21 | 22 | eval_coupling: false 23 | nfe: 1000 24 | use_wandb: false 25 | resume: null 26 | seed: 0 27 | nnodes: 1 28 | 29 | hydra: 30 | job: 31 | chdir: True 32 | run: 33 | dir: ./outputs/runs/${name}/${now:%Y.%m.%d}/${now:%H%M%S} 34 | sweep: 35 | dir: ./outputs/multiruns/${name}/${now:%Y.%m.%d}/${now:%H%M%S} 36 | subdir: ${hydra.job.num} 37 | launcher: 38 | max_num_timeout: 100000 39 | timeout_min: 4319 40 | cpus_per_task: 10 41 | gpus_per_node: 1 42 | tasks_per_node: ${hydra.launcher.gpus_per_node} 43 | nodes: ${nnodes} 44 | partition: learnlab 45 | constraint: volta16gb 46 | mem_per_gpu: 30gb 47 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gsbm 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - nvidia 6 | dependencies: 7 | - python=3.9 8 | - pytorch 9 | - pytorch-cuda=11.7 10 | - matplotlib 11 | - jupyter 12 | - numpy 13 | - pip 14 | - tqdm 15 | - torchmetrics 16 | - ipdb 17 | - pip: 18 | - submitit 19 | - pre-commit 20 | - black==22.6.0 21 | - ipykernel 22 | - torchdiffeq 23 | - scikit-learn 24 | - pytorch-lightning==1.8.5.post0 25 | - hydra-core==1.2.0 26 | - hydra-submitit-launcher==1.2.0 27 | - hydra_colorlog==1.2.0 28 | - click 29 | - wandb 30 | - biopython 31 | - pyevtk 32 | - ipympl 33 | - geomloss 34 | - colored-traceback 35 | - POT 36 | - termcolor 37 | - torchvision 38 | - laspy 39 | - gdown 40 | -------------------------------------------------------------------------------- /gsbm/dataset.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. 
and affiliates.""" 2 | 3 | import os 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision.transforms as transforms 10 | 11 | from .utils import get_repo_path 12 | from .state_cost import LIDARStateCost 13 | from ipdb import set_trace as debug 14 | 15 | 16 | # make transforms.Lambda(lambda x: x * 2 - 1) picklable 17 | class Normalize(object): 18 | def __call__(self, img): 19 | return img * 2 - 1 20 | 21 | def __repr__(self): 22 | return self.__class__.__name__ 23 | 24 | 25 | class AFHQ(Dataset): 26 | animals = ["cat", "dog", "wild"] 27 | 28 | def __init__(self, resize, animals, split="train"): 29 | assert split in ("train", "val") 30 | 31 | np_imgs = [] 32 | for ani in animals: 33 | assert ani in self.animals 34 | out = np.load(get_repo_path() / "data" / f"afhq{resize}-{split}-{ani}.npz") 35 | np_imgs.append(out["data"]) 36 | 37 | np_imgs = np.concatenate(np_imgs, axis=0) 38 | th_imgs = torch.from_numpy(np_imgs) / 255.0 # [0, 1] 39 | 40 | self.th_imgs = th_imgs.permute(0, 3, 1, 2) 41 | transform_list = ( 42 | [ 43 | transforms.RandomHorizontalFlip(), 44 | ] 45 | if split == "train" 46 | else [] 47 | ) 48 | transform_list.append(Normalize()) 49 | self.transform = transforms.Compose(transform_list) 50 | 51 | def __len__(self): 52 | return self.th_imgs.shape[0] 53 | 54 | def __getitem__(self, idx): 55 | image_tensor = self.transform(self.th_imgs[idx]) 56 | return image_tensor 57 | 58 | 59 | class ImageSampler: 60 | def __init__(self, dataset, generator=None): 61 | self.dataset = dataset 62 | self.generator = generator 63 | 64 | def set_generator(self, generator): 65 | self.generator = generator 66 | 67 | def __call__(self, n): 68 | ii = torch.randint(0, len(self.dataset), (n,), generator=self.generator) 69 | out = torch.stack([self.dataset[i] for i in ii], dim=0) 70 | return out.reshape(n, -1) 71 | 72 | 73 | def normal_logprob(z, mean, log_std): 74 | mean = mean + torch.tensor(0.0) 75 | log_std = log_std + torch.tensor(0.0) 76 | c = torch.tensor([math.log(2 * math.pi)]).to(z) 77 | inv_sigma = torch.exp(-log_std) 78 | tmp = (z - mean) * inv_sigma 79 | return -0.5 * (tmp * tmp + 2 * log_std + c) 80 | 81 | 82 | class Gaussian: 83 | def __init__(self, mu, var, generator): 84 | self.mean = torch.tensor(mu).float() 85 | self.std = torch.tensor(var).sqrt() 86 | self.generator = generator 87 | 88 | def set_generator(self, generator): 89 | self.generator = generator 90 | 91 | def __call__(self, n): 92 | noise_shape = (n,) + self.mean.shape 93 | return ( 94 | torch.randn(*noise_shape, generator=self.generator).to(self.mean) * self.std 95 | + self.mean 96 | ) 97 | 98 | 99 | class Opinion(Gaussian): 100 | def __init__(self, dim, mu, var, var_1st_dim, generator=None): 101 | mu = mu * torch.ones(dim) 102 | var = var * torch.ones(dim) 103 | if var_1st_dim is not None: 104 | var[0] = var_1st_dim 105 | super(Opinion, self).__init__(mu, var, generator) 106 | 107 | 108 | class GaussianMM: 109 | def __init__(self, mu, var, generator=None): 110 | super().__init__() 111 | self.centers = torch.tensor(mu) 112 | self.logstd = torch.tensor(var).log() / 2.0 113 | self.K = self.centers.shape[0] 114 | self.generator = generator 115 | 116 | def set_generator(self, generator): 117 | self.generator = generator 118 | 119 | def logprob(self, x): 120 | """Computes the log probability.""" 121 | logprobs = normal_logprob( 122 | x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd 123 | ) 124 | logprobs = torch.sum(logprobs, dim=2) 125 | return 
torch.logsumexp(logprobs, dim=1) - math.log(self.K)
126 | 
127 |     def __call__(self, n_samples):
128 |         idx = torch.randint(self.K, (n_samples,), generator=self.generator).to(self.centers.device)
129 |         mean = self.centers[idx]
130 |         return (
131 |             torch.randn(*mean.shape, generator=self.generator).to(mean)
132 |             * torch.exp(self.logstd)
133 |             + mean
134 |         )
135 | 
136 | 
137 | class LiDARProjector:
138 |     """Takes an existing dataset and projects all points onto the manifold."""
139 | 
140 |     def __init__(self, dataset, lcfg):
141 |         self.manifold = LIDARStateCost(lcfg)
142 |         self.dataset = dataset
143 | 
144 |     def set_generator(self, generator):
145 |         self.dataset.set_generator(generator)
146 | 
147 |     def __call__(self, n_samples):
148 |         samples = self.dataset(n_samples)
149 |         projx = self.manifold.get_tangent_proj(samples)
150 |         samples = projx(samples)
151 |         return samples
152 | 
153 | 
154 | class PairDataset(Dataset):
155 |     def __init__(self, x0, x1, expand_factor=1):
156 |         assert len(x0) == len(x1)
157 |         self.x0 = x0
158 |         self.x1 = x1
159 |         self.expand_factor = expand_factor
160 | 
161 |     def __len__(self):
162 |         return len(self.x0) * self.expand_factor
163 | 
164 |     def __getitem__(self, idx):
165 |         return {"x0": self.x0[idx % len(self.x0)], "x1": self.x1[idx % len(self.x0)]}
166 | 
167 | 
168 | class SplineDataset(Dataset):
169 |     def __init__(self, mean_t, mean_xt, gamma_s, gamma_xs, expand_factor=1):
170 |         """
171 |         mean_t: (T,)
172 |         mean_xt: (B, T, D)
173 |         gamma_s: (S,)
174 |         gamma_xs: (B, S, 1)
175 |         """
176 |         (B, T, D), (S,) = mean_xt.shape, gamma_s.shape
177 |         assert T > 3 and S > 3
178 |         assert mean_t.shape == (T,)
179 |         assert gamma_xs.shape == (B, S, 1)
180 | 
181 |         self.mean_t = mean_t.detach().cpu().clone()
182 |         self.mean_xt = mean_xt.detach().cpu().clone()
183 |         self.gamma_s = gamma_s.detach().cpu().clone()
184 |         self.gamma_xs = gamma_xs.detach().cpu().clone()
185 | 
186 |         self.expand_factor = expand_factor
187 | 
188 |     def __len__(self):
189 |         return self.mean_xt.shape[0] * self.expand_factor
190 | 
191 |     def __getitem__(self, idx):
192 |         _idx = idx % self.mean_xt.shape[0]
193 | 
194 |         x0 = self.mean_xt[_idx, 0]
195 |         x1 = self.mean_xt[_idx, -1]
196 |         mean_xt = self.mean_xt[_idx]
197 |         gamma_xs = self.gamma_xs[_idx]
198 | 
199 |         return {
200 |             "x0": x0,
201 |             "x1": x1,
202 |             "mean_t": self.mean_t,
203 |             "mean_xt": mean_xt,
204 |             "gamma_s": self.gamma_s,
205 |             "gamma_xs": gamma_xs,
206 |         }
207 | 
208 | 
209 | class SplineIWDataset(Dataset):
210 |     def __init__(self, spline_ds, IW_t, IW_xs, weights):
211 |         """
212 |         IW_t: (TT,)
213 |         IW_xs: (B, N, TT, D)
214 |         weights: (B, N)
215 |         """
216 |         B, N, TT, D = IW_xs.shape
217 |         assert weights.shape == (
218 |             B,
219 |             N,
220 |         )
221 |         assert IW_t.shape == (TT,)
222 | 
223 |         self.IW_t = IW_t.detach().cpu().clone()
224 |         self.IW_xs = IW_xs.detach().cpu().clone()
225 |         self.weights = weights.detach().cpu().clone()
226 | 
227 |         self.expand_factor = spline_ds.expand_factor
228 |         spline_ds.expand_factor = 1
229 |         self.spline_ds = spline_ds
230 |         assert len(self.spline_ds) == B
231 |         assert spline_ds.mean_xt.shape[2] == D
232 | 
233 |     def __len__(self):
234 |         return self.IW_xs.shape[0] * self.expand_factor
235 | 
236 |     def __getitem__(self, idx):
237 |         _idx = idx % self.IW_xs.shape[0]
238 | 
239 |         out = self.spline_ds.__getitem__(_idx)
240 |         out["IW_t"] = self.IW_t
241 |         out["IW_xs"] = self.IW_xs[_idx]
242 |         out["weights"] = self.weights[_idx]
243 |         return out
244 | 
245 | 
246 | def get_sampler(p, gen=None, **kwargs):
247 |     name = p.name
248 | 
249 |     if name == "gaussian":
250 |         return Gaussian(p.mu, p.var, generator=gen)
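    # Usage sketch (hypothetical config fragment, values mirroring vneck.yaml):
    # with p = {"name": "gaussian", "mu": [-7, 0], "var": 0.2},
    # get_sampler(p)(1024) draws 1024 samples of shape (1024, 2) from N(mu, var * I).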
251 |     elif name == "gmm":
252 |         return GaussianMM(p.mu, p.var, generator=gen)
253 |     elif name == "opinion":
254 |         return Opinion(p.dim, p.mu, p.var, p.get("var_1st_dim", None), generator=gen)
255 |     elif name == "lidarproj":
256 |         dataset = GaussianMM(p.mu, p.var, generator=gen)
257 |         return LiDARProjector(dataset, p.lcfg)
258 |     elif name == "afhq":
259 |         dataset = AFHQ(resize=p.resize, animals=p.animals, **kwargs)
260 |         return ImageSampler(dataset)
261 |     else:
262 |         raise ValueError(f"Unknown distribution option: {name}")
263 | 
264 | 
265 | def get_dist_boundary(cfg):
266 |     p0 = get_sampler(cfg.prob.p0)
267 |     p1 = get_sampler(cfg.prob.p1)
268 |     p0_val = get_sampler(cfg.prob.p0, split="val")
269 |     p1_val = get_sampler(cfg.prob.p1, split="val")
270 |     return p0, p1, p0_val, p1_val
271 | 
--------------------------------------------------------------------------------
/gsbm/ema.py:
--------------------------------------------------------------------------------
1 | """Copyright (c) Meta Platforms, Inc. and affiliates."""
2 | 
3 | import torch
4 | 
5 | 
6 | class EMA(torch.nn.Module):
7 |     def __init__(self, model: torch.nn.Module, decay: float = 0.999):
8 |         super().__init__()
9 |         self.model = model
10 |         self.decay = decay
11 | 
12 |         # Put this in a buffer so that it gets included in the state dict
13 |         self.register_buffer("num_updates", torch.tensor(0))
14 | 
15 |         self.shadow_params = torch.nn.ParameterList(
16 |             [
17 |                 torch.nn.Parameter(p.clone().detach(), requires_grad=False)
18 |                 for p in model.parameters()
19 |                 if p.requires_grad
20 |             ]
21 |         )
22 |         self.backup_params = []
23 | 
24 |     def train(self, mode: bool):
25 |         if self.training and not mode:
26 |             # Switching from train mode to eval mode. Backup the model parameters and
27 |             # overwrite with shadow params
28 |             self.backup()
29 |             self.copy_to_model()
30 |         elif not self.training and mode:
31 |             # Switching from eval to train mode. Restore the `backup_params`
32 |             self.restore_to_model()
33 | 
34 |         super().train(mode)
35 | 
36 |     def update_ema(self):
37 |         self.num_updates += 1
38 |         num_updates = self.num_updates.item()
39 |         decay = min(self.decay, (1 + num_updates) / (10 + num_updates))
40 |         with torch.no_grad():
41 |             params = [p for p in self.model.parameters() if p.requires_grad]
42 |             for shadow, param in zip(self.shadow_params, params):
43 |                 shadow.sub_((1 - decay) * (shadow - param))
44 | 
45 |     def forward(self, *args, **kwargs):
46 |         return self.model(*args, **kwargs)
47 | 
48 |     def copy_to_model(self):
49 |         # copy the shadow (ema) parameters to the model
50 |         params = [p for p in self.model.parameters() if p.requires_grad]
51 |         for shadow, param in zip(self.shadow_params, params):
52 |             param.data.copy_(shadow.data)
53 | 
54 |     def backup(self):
55 |         # Backup the current model parameters
56 |         if len(self.backup_params) > 0:
57 |             for p, b in zip(self.model.parameters(), self.backup_params):
58 |                 b.data.copy_(p.data)
59 |         else:
60 |             self.backup_params = [param.clone() for param in self.model.parameters()]
61 | 
62 |     def restore_to_model(self):
63 |         # Restores the backed up parameters to the model.
64 |         for param, backup in zip(self.model.parameters(), self.backup_params):
65 |             param.data.copy_(backup.data)
66 | 
--------------------------------------------------------------------------------
/gsbm/evaluator.py:
--------------------------------------------------------------------------------
1 | """Copyright (c) Meta Platforms, Inc. and affiliates."""
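# Evaluation utilities: boundary-distribution metrics (sliced Wasserstein,
# Sinkhorn, MMD) against cached reference samples, plus control and state
# costs for the crowd-navigation experiments.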
2 | 
3 | import os
4 | from pathlib import Path
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.distributions as td
9 | 
10 | from geomloss import SamplesLoss
11 | from ot.sliced import sliced_wasserstein_distance
12 | from .state_cost import build_obstacle_cost, congestion_cost, zero_cost_fn
13 | from .utils import get_repo_path
14 | 
15 | from ipdb import set_trace as debug
16 | 
17 | 
18 | def build_evaluator(cfg):
19 |     if cfg.prob.name in ["opinion", "afhq", "lidar"]:
20 |         return DumpEvaluator()
21 |     else:
22 |         return CrowdNavEvaluator(cfg)
23 | 
24 | 
25 | def cpu_everything(*args):
26 |     return [a.cpu() for a in args] if len(args) > 1 else args[0].cpu()
27 | 
28 | 
29 | def shuffle(t):
30 |     """
31 |     t: (B, *) --> (B, *)
32 |     """
33 |     return t[torch.randperm(t.shape[0])]
34 | 
35 | 
36 | class MMD_loss(nn.Module):
37 |     def __init__(self, kernel_mul=2.0, kernel_num=5):
38 |         super(MMD_loss, self).__init__()
39 |         self.kernel_num = kernel_num
40 |         self.kernel_mul = kernel_mul
41 |         self.fix_sigma = None
42 | 
43 | 
44 |     def gaussian_kernel(
45 |         self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None
46 |     ):
47 |         n_samples = int(source.size()[0]) + int(target.size()[0])
48 |         total = torch.cat([source, target], dim=0)
49 | 
50 |         total0 = total.unsqueeze(0).expand(
51 |             int(total.size(0)), int(total.size(0)), int(total.size(1))
52 |         )
53 |         total1 = total.unsqueeze(1).expand(
54 |             int(total.size(0)), int(total.size(0)), int(total.size(1))
55 |         )
56 |         L2_distance = ((total0 - total1) ** 2).sum(2)
57 |         if fix_sigma:
58 |             bandwidth = fix_sigma
59 |         else:
60 |             bandwidth = torch.sum(L2_distance.data) / (n_samples**2 - n_samples)
61 |         bandwidth /= kernel_mul ** (kernel_num // 2)
62 |         bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
63 |         kernel_val = [
64 |             torch.exp(-L2_distance / bandwidth_temp)
65 |             for bandwidth_temp in bandwidth_list
66 |         ]
67 |         return sum(kernel_val)
68 | 
69 |     def forward(self, source, target):
70 |         batch_size = int(source.size()[0])
71 |         kernels = self.gaussian_kernel(
72 |             source,
73 |             target,
74 |             kernel_mul=self.kernel_mul,
75 |             kernel_num=self.kernel_num,
76 |             fix_sigma=self.fix_sigma,
77 |         )
78 |         XX = kernels[:batch_size, :batch_size]
79 |         YY = kernels[batch_size:, batch_size:]
80 |         XY = kernels[:batch_size, batch_size:]
81 |         YX = kernels[batch_size:, :batch_size]
82 |         loss = torch.mean(XX + YY - XY - YX)
83 |         return loss
84 | 
85 | 
86 | @torch.no_grad()
87 | def est_entropy_cost(xt, std=0.2):
88 |     """
89 |     xt: (B, T, D) --> (B, T)
90 |     """
91 |     B, T, D = xt.shape
92 | 
93 |     ### build B*T indep Gaussians with given std / bandwidth
94 |     normals = td.Normal(
95 |         xt.reshape(B * T, D),
96 |         std * torch.ones(B * T, D).to(xt),
97 |     )
98 |     indep_normals = td.Independent(normals, 1)
99 | 
100 |     ### evaluate log-prob of all `B` samples at each timestamp
101 |     ### w.r.t. `B` Gaussians
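    # I.e., a kernel density estimate with Gaussian bandwidth `std`: the
    # returned log_pt approximates log p_t(x_i) for each sample i at each
    # timestamp, which is what the entropy interaction cost integrates.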
102 |     xxt = xt.unsqueeze(1).expand(-1, B, -1, -1)
103 |     assert xxt.shape == (B, B, T, D)
104 | 
105 |     log_pt_01 = indep_normals.log_prob(xxt.reshape(B, B * T, D)).reshape(B, B, T)
106 |     pt = log_pt_01.exp().mean(dim=1)  # (B, T)
107 | 
108 |     log_pt = pt.log()
109 |     assert not torch.isnan(log_pt).any()
110 |     assert log_pt.shape == (B, T)
111 |     return log_pt
112 | 
113 | 
114 | ##########################################################
115 | 
116 | 
117 | class DumpEvaluator:
118 |     def __call__(self, samples):
119 |         return {}
120 | 
121 | 
122 | class CrowdNavEvaluator:
123 |     B = 1000
124 |     D = 2
125 | 
126 |     def __init__(self, cfg) -> None:
127 |         self.ccfg = cfg.csoc
128 |         self.scfg = cfg.state_cost
129 |         self.sigma = cfg.prob.sigma
130 | 
131 |         self.obstacle_cost = build_obstacle_cost(cfg.prob.name)
132 |         self.sinkhorn_cfg = {"p": 2, "blur": 0.05, "scaling": 0.95}
133 |         self.ref_x0, self.ref_x1 = self.build_ref_x(cfg)
134 | 
135 |     def build_ref_x(self, cfg):
136 |         ref_fn = get_repo_path() / "data" / f"{cfg.prob.name}.pt"
137 | 
138 |         if not ref_fn.exists():
139 |             from .dataset import get_sampler
140 | 
141 |             ref_x0 = get_sampler(cfg.prob.p0)(self.B)
142 |             ref_x1 = get_sampler(cfg.prob.p1)(self.B)
143 |             torch.save({"ref_x0": ref_x0, "ref_x1": ref_x1}, ref_fn)
144 |             print(f"Saved new reference file to {ref_fn}!")
145 |             return ref_x0, ref_x1
146 |         else:
147 |             ref_pt = torch.load(ref_fn, map_location="cpu")
148 |             return ref_pt["ref_x0"], ref_pt["ref_x1"]
149 | 
150 |     def boundary_metrics(self, xs):
151 |         ## Resample batch dimension if needed
152 |         B, T, D = xs.shape
153 |         if B < self.B:
154 |             rand_idx = torch.randint(0, B, (self.B,))
155 |             xs = xs[rand_idx]
156 |         elif B > self.B:
157 |             rand_idx = torch.randperm(B)[: self.B]
158 |             xs = xs[rand_idx]
159 |         assert xs.shape == (self.B, T, D)
160 | 
161 |         ## Build x0, x1
162 |         x0, x1 = shuffle(xs[:, 0]), shuffle(xs[:, -1])
163 |         assert x0.shape == self.ref_x0.shape == x1.shape == self.ref_x1.shape
164 | 
165 |         ## Compute metrics
166 |         metrics = dict()
167 |         metrics["SWD_0"] = sliced_wasserstein_distance(x0, self.ref_x0)
168 |         metrics["SWD_1"] = sliced_wasserstein_distance(x1, self.ref_x1)
169 | 
170 |         sinkhorn = SamplesLoss("sinkhorn", **self.sinkhorn_cfg)
171 |         metrics["Sinkhorn_0"] = sinkhorn(x0, self.ref_x0)
172 |         metrics["Sinkhorn_1"] = sinkhorn(x1, self.ref_x1)
173 | 
174 |         mmd = MMD_loss()
175 |         metrics["MMD_0"] = mmd(x0, self.ref_x0)
176 |         metrics["MMD_1"] = mmd(x1, self.ref_x1)
177 |         return metrics
178 | 
179 |     def state_costs(self, xs):
180 |         (B, T, D), scfg = xs.shape, self.scfg
181 |         assert "obs" in scfg.type and scfg.obs > 0
182 | 
183 |         cost_s = scfg.obs * self.obstacle_cost(xs)
184 |         if "ent" in scfg.type and scfg.ent > 0:
185 |             cost_s = cost_s + scfg.ent * est_entropy_cost(xs)
186 |         elif "cgst" in scfg.type and scfg.cgst > 0:
187 |             cost_s = cost_s + scfg.cgst * congestion_cost(xs)
188 | 
189 |         assert cost_s.shape == (B, T)
190 |         return cost_s
191 | 
192 |     def cost_metrics(self, xs, us):
193 |         B, T, D = xs.shape
194 |         assert us.shape == (B, T, D)
195 | 
196 |         scale = (0.5 / (self.sigma**2)) if self.ccfg.scale_by_sigma else 0.5
197 |         cost_c = scale * (us**2).sum(dim=-1)
198 |         cost_s = self.state_costs(xs)
199 |         assert cost_c.shape == cost_s.shape == (B, T)
200 | 
201 |         metrics = dict()
202 |         metrics["control_cost"] = cost_c.mean()
203 |         metrics["state_cost"] = cost_s.mean()
204 |         metrics["total_cost"] = metrics["control_cost"] + metrics["state_cost"]
205 |         return metrics
206 | 
207 |     def __call__(self, samples):
208 |         """
209 |         xs: (B, T, D)
210 |         us: (B, T, D)
211 |         """
212 |         xs, us = cpu_everything(samples["xs"], samples["us"])
213 |         B, T, D = xs.shape
214 |         assert us.shape == (B, T, D)
215 | 
216 |         metrics = {}
217 |         metrics.update(self.boundary_metrics(xs))
218 |         metrics.update(self.cost_metrics(xs, us))
219 |         for k, v in metrics.items():
220 |             metrics[k] = v.item()
221 |         return metrics
222 | 
--------------------------------------------------------------------------------
/gsbm/experimental.py:
--------------------------------------------------------------------------------
1 | """Copyright (c) Meta Platforms, Inc. and affiliates."""
2 | 
3 | import math
4 | 
5 | import torch
6 | from torchdiffeq import odeint
7 | 
8 | from .sde import DIRECTIONS
9 | from .gaussian_path import EndPointGaussianPath
10 | 
11 | ################################################################################################
12 | ############################### Brownian bridge + quadratic cost ###############################
13 | ################################################################################################
14 | 
15 | 
16 | # marginal sample
17 | def qt_quad_cost(t, x0, x1, sigma, alpha):
18 | 
19 |     B, D = x0.shape
20 |     assert x1.shape == (B, D) and t.shape == (B,)
21 | 
22 |     eta = torch.tensor(sigma * math.sqrt(2 * alpha))
23 |     bar_eta = eta * (1 - t)
24 | 
25 |     ct = torch.sinh(bar_eta) / torch.sinh(eta)
26 |     et = torch.sinh(bar_eta) * (1.0 / torch.tanh(bar_eta) - 1.0 / torch.tanh(eta))
27 | 
28 |     mean_t = ct[..., None] * x0 + et[..., None] * x1  # (B, D)
29 |     cov_t = sigma**2 * et * torch.sinh(bar_eta) / eta
30 |     std_t = (cov_t).sqrt()[..., None]  # (B,1)
31 |     xt = mean_t + std_t * torch.randn_like(mean_t)
32 | 
33 |     assert xt.shape == x0.shape == x1.shape
34 |     return xt
35 | 
36 | 
37 | # optimal drift
38 | def ut_quad_cost(t, x0, x1, xt, direction, sigma, alpha):
39 |     assert direction in DIRECTIONS
40 | 
41 |     eta = torch.tensor(sigma * math.sqrt(2 * alpha))
42 | 
43 |     if direction == "fwd":
44 |         c1 = eta / torch.sinh(eta * (1 - t))
45 |         c2 = eta / torch.tanh(eta * (1 - t))
46 |         ut = c1[..., None] * x1 - c2[..., None] * xt
47 |     else:
48 |         c1 = eta / torch.sinh(eta * t)
49 |         c2 = eta / torch.tanh(eta * t)
50 |         ut = c1[..., None] * x0 - c2[..., None] * xt
51 |     assert ut.shape == x0.shape == x1.shape
52 |     return ut
53 | 
54 | 
55 | ################################################################################################
56 | ############################## Sim-free Gaussian path trajectory ###############################
57 | ################################################################################################
58 | 
59 | 
60 | class EndPointGaussianPathv2(EndPointGaussianPath):
61 |     odeint_kwargs = {
62 |         "method": "scipy_solver",
63 |         "atol": 1e-4,
64 |         "rtol": 1e-7,
65 |     }
66 |     cov_eps = 1e-7
67 |     cov_decom = "cholesky"
68 | 
69 |     def cov_mtx(self, t):
70 |         """
71 |         t: (T, ) --> cov_mtx: (B, T, T)
72 |         """
73 |         B, (T,) = self.B, t.shape
74 | 
75 |         Var_t = self.gamma(t).squeeze(-1) ** 2
76 |         gt = odeint_gt(B, t, self.gamma, self.sigma, "cpu", self.odeint_kwargs)
77 |         assert Var_t.shape == gt.shape == (B, T)
78 | 
79 |         Tidx, Sidx = torch.meshgrid(torch.arange(T), torch.arange(T), indexing="ij")
80 |         min_st = torch.min(Tidx, Sidx)
81 |         max_st = torch.max(Tidx, Sidx)
82 |         C = (gt[:, max_st] - gt[:, min_st]).exp() * Var_t[:, min_st]
83 |         assert C.shape == (B, T, T)
84 |         return C + self.cov_eps
85 | 
86 |     def sample_xs(self, T, N, eps=0.01):
87 |         """joint
88 |         t: (T)
89 |         xs: (B, N, T, D)
90 |         """
91 |         B, D, device = self.B, self.D, self.device
92 | 
93 |         t = 
torch.linspace(eps, 1 - eps, T, device=device) 94 | 95 | ## Compute covariance & its decomposition 96 | C = self.cov_mtx(t) 97 | if self.cov_decom == "cholesky": 98 | A = torch.linalg.cholesky(C) 99 | elif self.cov_decom == "eigh": 100 | L, Q = torch.linalg.eigh(C) 101 | A = Q @ torch.diag_embed(L).sqrt() 102 | assert C.shape == A.shape == (B, T, T) 103 | 104 | ## Compute std_t and mean_t 105 | noise_t = torch.randn(N, B, T, D, device=device) 106 | std_t = (A @ noise_t).transpose(0, 1) # <-- this took most memory 107 | assert std_t.shape == (B, N, T, D) 108 | 109 | torch.cuda.empty_cache() # clear out memory 110 | 111 | mean_t = self.mean(t) # (B, T, D) 112 | assert mean_t.shape == (B, T, D) 113 | 114 | ## sim-free xs 115 | xs = mean_t.unsqueeze(1) + std_t 116 | assert xs.shape == (B, N, T, D) 117 | return t, xs 118 | 119 | 120 | def odeint_gt(B, ts, gamma, sigma, device, ocfg): 121 | """Implementation of Eq 32 122 | ts: (T,) 123 | gamma: (S,) -> (B, S) 124 | === 125 | gt: (B, T) 126 | """ 127 | (T,) = ts.shape 128 | orig_device = gamma.device 129 | 130 | def f(t, _gt): 131 | t = t.reshape(1) 132 | std, dstd = torch.autograd.functional.jvp( 133 | gamma, t, torch.ones_like(t), create_graph=False 134 | ) 135 | B, T, D = std.shape 136 | assert dstd.shape == (B, T, D) and _gt.shape == (B, T) 137 | return (dstd - sigma**2 / (2 * std)) / std 138 | 139 | ts = ts.to(device) 140 | gamma = gamma.to(device) 141 | g0 = torch.zeros(B, 1, device=device) 142 | gt = odeint(f, g0, ts, **ocfg) 143 | gt = gt.transpose(0, 1).squeeze(-1) 144 | assert gt.shape == (B, T) 145 | 146 | gamma = gamma.to(orig_device) 147 | return gt.to(orig_device) 148 | -------------------------------------------------------------------------------- /gsbm/gaussian_path.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | import math 4 | import copy 5 | 6 | import numpy as np 7 | import torch 8 | from tqdm import trange 9 | 10 | from . 
import interp1d
11 | from .sde import DIRECTIONS
12 | 
13 | from ipdb import set_trace as debug
14 | 
15 | ################################################################################################
16 | 
17 | 
18 | class EndPointSpline(torch.nn.Module):
19 |     def __init__(self, t, xt, spline_type="linear"):
20 |         """
21 |         t: (T,)
22 |         xt: (B, T, D)
23 |         """
24 |         super(EndPointSpline, self).__init__()
25 |         B, T, D = xt.shape
26 |         assert t.shape == (T,) and T > 2, "Need at least 3 points"
27 |         assert t.device == xt.device
28 | 
29 |         t = t.detach().clone()
30 |         xt = xt.permute(1, 0, 2).detach().clone()
31 | 
32 |         # fixed sizes
33 |         self.B = B  # number of (x0,x1) pairs
34 |         self.T = T  # number of control points / time steps
35 |         self.D = D  # dimension
36 |         self.spline_type = spline_type
37 | 
38 |         self.register_buffer("t", t)
39 |         self.register_buffer("t_epd", t.reshape(-1, 1).expand(-1, B))
40 |         self.register_buffer("x0", xt[0].reshape(1, B, D))
41 |         self.register_buffer("x1", xt[-1].reshape(1, B, D))
42 |         self.register_parameter("knots", torch.nn.Parameter(xt[1:-1]))
43 | 
44 |     @property
45 |     def device(self):
46 |         return self.parameters().__next__().device
47 | 
48 |     @property
49 |     def xt(self):  # (B, T, D)
50 |         return torch.cat([self.x0, self.knots, self.x1], dim=0).permute(1, 0, 2)
51 | 
52 |     def interp(self, query_t):
53 |         """
54 |         query_t: (S,) --> yt: (B, S, D)
55 |         """
56 | 
57 |         (S,) = query_t.shape
58 |         query_t = query_t.reshape(-1, 1).expand(-1, self.B)
59 |         assert query_t.shape == (S, self.B)
60 | 
61 |         mask = None
62 |         xt = torch.cat([self.x0, self.knots, self.x1], dim=0)  # (T, B, D)
63 |         if self.spline_type == "linear":
64 |             yt = interp1d.linear_interp1d(self.t_epd, xt, mask, query_t)
65 |         elif self.spline_type == "cubic":
66 |             yt = interp1d.cubic_interp1d(self.t_epd, xt, mask, query_t)
67 |         yt = yt.permute(1, 0, 2)
68 |         assert yt.shape == (self.B, S, self.D), yt.shape
69 |         return yt
70 | 
71 |     def forward(self, t):
72 |         """
73 |         t: (S,) --> yt: (B, S, D)
74 |         """
75 |         return self.interp(t)
76 | 
77 | 
78 | class StdSpline(EndPointSpline):
79 |     def __init__(self, t, xt, sigma, spline_type="linear"):
80 |         """
81 |         t: (T,)
82 |         xt: (B, T, 1)
83 |         """
84 |         super(StdSpline, self).__init__(t, xt, spline_type=spline_type)
85 |         assert self.D == 1
86 |         self.sigma = sigma
87 |         self.softplus = torch.nn.Softplus()
88 | 
89 |     def forward(self, t):
90 |         """
91 |         t: (S,) --> yt: (B, S, 1)
92 |         """
93 |         base = self.sigma * (t * (1 - t)).sqrt()
94 |         xt = self.interp(t)
95 |         return base.reshape(1, -1, 1) * self.softplus(xt)
96 | 
97 | 
98 | ################################################################################################
99 | 
100 | 
101 | class EndPointGaussianPath(torch.nn.Module):
102 |     def __init__(self, t, xt, s, ys, sigma, basedrift):
103 |         super(EndPointGaussianPath, self).__init__()
104 | 
105 |         (B, T, D), (S,) = xt.shape, s.shape
106 |         assert t.shape == (T,) and ys.shape == (B, S, 1)
107 | 
108 |         self.B = B  # number of (x0,x1) pairs
109 |         self.T = T  # number of control points for mean spline
110 |         self.S = S  # number of control points for std spline
111 |         self.D = D  # dimension
112 | 
113 |         self.sigma = sigma
114 |         self.mean = EndPointSpline(t, xt)
115 |         self.gamma = StdSpline(s, ys, sigma)
116 |         self.basedrift = basedrift
117 | 
118 |     @property
119 |     def device(self):
120 |         return self.parameters().__next__().device
121 | 
122 |     @property
123 |     def mean_ctl_pts(self):
124 |         return self.mean.xt.detach().cpu()
125 | 
126 |     @property
127 |     def std_ctl_pts(self):
128 |         return 
self.gamma(self.gamma.t).detach().cpu() 129 | 130 | def sample_xt(self, t, N): 131 | """ 132 | N: number of xt for each (x0,x1) 133 | t: (T,) --> xt: (B, N, T, D) 134 | """ 135 | 136 | mean_t = self.mean(t) # (B, T, D) 137 | B, T, D = mean_t.shape 138 | 139 | assert t.shape == (T,) 140 | std_t = self.gamma(t).reshape(B, 1, T, 1) # (B, 1, T, 1) 141 | 142 | noise = torch.randn(B, N, T, D, device=t.device) # (B, N, T, D) 143 | 144 | xt = mean_t.unsqueeze(1) + std_t * noise 145 | assert xt.shape == noise.shape 146 | return xt 147 | 148 | def ft(self, t, xt, direction): 149 | """ 150 | t: (T,) 151 | xt: (B, N, T, D) 152 | === 153 | ft: (B, N, T, D) 154 | """ 155 | B, N, T, D = xt.shape 156 | assert t.shape == (T,) 157 | 158 | sign = 1.0 if direction == "fwd" else -1 159 | 160 | ft = self.basedrift( 161 | xt.reshape(B * N, T, D), 162 | t, 163 | ).reshape(B, N, T, D) 164 | return sign * ft 165 | 166 | def drift(self, t, xt, direction): 167 | """Implementation of the drift of Gaussian path in Eq 8 168 | t: (T,) 169 | xt: (B, N, T, D) 170 | === 171 | drift: (B, N, T, D) 172 | """ 173 | assert (t > 0).all() and (t < 1).all() 174 | 175 | B, N, T, D = xt.shape 176 | assert t.shape == (T,) 177 | 178 | mean, dmean = torch.autograd.functional.jvp( 179 | self.mean, t, torch.ones_like(t), create_graph=self.training 180 | ) 181 | assert mean.shape == dmean.shape == (B, T, D) 182 | 183 | dmean = dmean.reshape(B, 1, T, D) 184 | mean = mean.reshape(B, 1, T, D) 185 | 186 | std, dstd = torch.autograd.functional.jvp( 187 | self.gamma, t, torch.ones_like(t), create_graph=self.training 188 | ) 189 | assert std.shape == dstd.shape == (B, T, 1) 190 | 191 | if direction == "fwd": 192 | # u = ∂m + a (x - m), 193 | # a = (\dot γ - σ^2 / 2γ) / γ 194 | # = -1 / (1-t), if γ is the std of brownian bridge 195 | a = (dstd - self.sigma**2 / (2 * std)) / std 196 | if self.sigma == 0: 197 | a = torch.zeros_like(a) # handle deterministic cases 198 | drift = dmean + a.reshape(B, 1, T, 1) * (xt - mean) 199 | else: 200 | # u = -∂m + a (x - m), 201 | # a = (-\dot γ - σ^2 / 2γ) / γ 202 | # = -1 / t, if γ is the std of brownian bridge 203 | a = (-dstd - self.sigma**2 / (2 * std)) / std 204 | if self.sigma == 0: 205 | a = torch.zeros_like(a) # handle deterministic cases 206 | drift = -dmean + a.reshape(B, 1, T, 1) * (xt - mean) 207 | 208 | assert drift.shape == xt.shape 209 | return drift 210 | 211 | def ut(self, t, xt, direction): 212 | """ 213 | t: (T,) 214 | xt: (B, N, T, D) 215 | === 216 | ut: (B, N, T, D) 217 | """ 218 | ft = self.ft(t, xt, direction) 219 | drift = self.drift(t, xt, direction) 220 | assert drift.shape == ft.shape == xt.shape 221 | return drift - ft 222 | 223 | def forward(self, t, N, direction): 224 | """ 225 | t: (T,) 226 | === 227 | xt: (B, N, T, D) 228 | ut: (B, N, T, D) 229 | """ 230 | xt = self.sample_xt(t, N) 231 | 232 | B, N, T, D = xt.shape 233 | assert t.shape == (T,) 234 | 235 | ut = self.ut(t, xt, direction) 236 | assert ut.shape == xt.shape 237 | 238 | return xt, ut 239 | 240 | 241 | ################################################################################################ 242 | 243 | 244 | def build_loss_fn(gpath, sigma, V, ccfg): 245 | def loss_fn(t, xt, ut): 246 | B, N, T, D = xt.shape 247 | assert t.shape == (T,) and ut.shape == (B, N, T, D) 248 | 249 | cost_s = V(xt, t, gpath).reshape(B, N, T) 250 | scale = (0.5 / (sigma**2)) if ccfg.scale_by_sigma else 0.5 251 | cost_c = scale * (ut**2).sum(dim=-1) 252 | assert cost_s.shape == cost_c.shape == (B, N, T) 253 | return (cost_s + cost_c).mean() 
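    # A minimal usage sketch (hypothetical names and values, for illustration
    # only): `build_loss_fn` is typically paired with `fit` below, e.g.
    #     loss_fn = build_loss_fn(gpath, sigma=1.0, V=state_cost_fn, ccfg=cfg.csoc)
    #     results = fit(cfg.csoc, gpath, "fwd", loss_fn, verbose=True)
    # where `state_cost_fn(xt, t, gpath)` returns per-point state costs that
    # reshape to (B, N, T), matching the `V` signature assumed in `fit` below.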
254 | 255 | return loss_fn 256 | 257 | 258 | def build_img_loss_fn(gpath, sigma, V, ccfg): 259 | 260 | ### define inputs for VAE 261 | x0 = gpath.mean.xt[:, 0].detach() 262 | x1 = gpath.mean.xt[:, -1].detach() 263 | recon_xt, _ = V.latent_interp(x0, x1, ccfg.S) 264 | 265 | def loss_fn(t, xt, ut): 266 | B, N, T, D = xt.shape 267 | assert T == ccfg.S 268 | assert t.shape == (T,) and ut.shape == (B, N, T, D) 269 | 270 | cost_s = V(xt, t, recon_xt.detach()).reshape(B, N, T) 271 | scale = (0.5 / (sigma**2)) if ccfg.scale_by_sigma else 0.5 272 | cost_c = scale * (ut**2).mean(dim=-1) 273 | assert cost_s.shape == cost_c.shape == (B, N, T) 274 | return (ccfg.weight_s * cost_s + ccfg.weight_c * cost_c).mean() 275 | 276 | return loss_fn 277 | 278 | 279 | def build_optim(gpath, ccfg): 280 | if ccfg.optim == "sgd": 281 | return torch.optim.SGD( 282 | [ 283 | {"params": gpath.mean.parameters(), "lr": ccfg.lr_mean}, 284 | {"params": gpath.gamma.parameters(), "lr": ccfg.lr_gamma}, 285 | ], 286 | momentum=ccfg.momentum, 287 | ) 288 | elif ccfg.optim == "adam": 289 | return torch.optim.Adam( 290 | [ 291 | {"params": gpath.mean.parameters(), "lr": ccfg.lr_mean}, 292 | {"params": gpath.gamma.parameters(), "lr": ccfg.lr_gamma}, 293 | ], 294 | ) 295 | else: 296 | raise ValueError(f"Unsupported Spline optimizer {ccfg.optim}!") 297 | 298 | 299 | def fit(ccfg, gpath, direction, loss_fn, eps=0.001, verbose=False): 300 | """ 301 | V: xt: (*, T, D), t: (T,), gpath --> (*, T) 302 | """ 303 | assert direction in DIRECTIONS 304 | 305 | results = {"name": ccfg.name} 306 | results["init_mean"] = gpath.mean_ctl_pts 307 | results["init_gamma"] = gpath.std_ctl_pts 308 | 309 | ### setup 310 | B, D, N, T, device = gpath.B, gpath.D, ccfg.N, ccfg.S, gpath.device 311 | optim = build_optim(gpath, ccfg) 312 | 313 | ### optimize spline 314 | gpath.train() 315 | losses = np.zeros(ccfg.nitr) 316 | bar = trange(ccfg.nitr) if verbose else range(ccfg.nitr) 317 | for itr in bar: 318 | optim.zero_grad() 319 | 320 | t = torch.linspace(eps, 1 - eps, T, device=device) 321 | xt, ut = gpath(t, N, direction) 322 | assert xt.shape == ut.shape == (B, N, T, D) 323 | 324 | loss = loss_fn(t, xt, ut) 325 | 326 | loss.backward() 327 | optim.step() 328 | losses[itr] = loss.cpu().item() 329 | if verbose: 330 | bar.set_description(f"loss={losses[itr]}") 331 | 332 | gpath.eval() 333 | 334 | results["final_mean"] = gpath.mean_ctl_pts 335 | results["final_gamma"] = gpath.std_ctl_pts 336 | results["gpath"] = copy.deepcopy(gpath).cpu() 337 | results["losses"] = losses 338 | 339 | return results 340 | -------------------------------------------------------------------------------- /gsbm/interp1d.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def linear_interp1d(t, xt, mask, s): 8 | """Linear splines. 
9 | B: batch, T: timestep, D: dim, S: query timestep 10 | Inputs: 11 | t: (T, B) 12 | xt: (T, B, D) 13 | mask: (T, B) 14 | s: (S, B) 15 | Outputs: 16 | xs: (S, B, D) 17 | """ 18 | T, N, D = xt.shape 19 | S = s.shape[0] 20 | 21 | if mask is None: 22 | mask = torch.ones_like(t).bool() 23 | 24 | m = (xt[1:] - xt[:-1]) / (t[1:] - t[:-1] + 1e-10).unsqueeze(-1) 25 | 26 | left = torch.searchsorted(t[1:].T.contiguous(), s.T.contiguous(), side="left").T 27 | mask_l = F.one_hot(left, T).permute(0, 2, 1).reshape(S, T, N, 1) 28 | 29 | t = t.reshape(1, T, N, 1) 30 | xt = xt.reshape(1, T, N, D) 31 | m = m.reshape(1, T - 1, N, D) 32 | s = s.reshape(S, N, 1) 33 | 34 | x0 = torch.sum(t * mask_l, dim=1) 35 | p0 = torch.sum(xt * mask_l, dim=1) 36 | m0 = torch.sum(m * mask_l[:, :-1], dim=1) 37 | 38 | t = s - x0 39 | 40 | return t * m0 + p0 41 | 42 | 43 | def cubic_interp1d(t, xt, mask, s): 44 | """ 45 | Inputs: 46 | t: (T, N) 47 | xt: (T, N, D) 48 | mask: (T, N) 49 | s: (S, N) 50 | """ 51 | T, N, D = xt.shape 52 | S = s.shape[0] 53 | 54 | if t.shape == s.shape: 55 | if torch.linalg.norm(t - s) == 0: 56 | return xt 57 | 58 | if mask is None: 59 | mask = torch.ones_like(t).bool() 60 | 61 | mask = mask.unsqueeze(-1) 62 | 63 | fd = (xt[1:] - xt[:-1]) / (t[1:] - t[:-1] + 1e-10).unsqueeze(-1) 64 | # Set tangents for the interior points. 65 | m = torch.cat([(fd[1:] + fd[:-1]) / 2, torch.zeros_like(fd[0:1])], dim=0) 66 | # Set tangent for the right end point. 67 | m = torch.where(torch.cat([mask[2:], torch.zeros_like(mask[0:1])]), m, fd) 68 | # Set tangent for the left end point. 69 | m = torch.cat([fd[[0]], m], dim=0) 70 | 71 | mask = mask.squeeze(-1) 72 | 73 | left = torch.searchsorted(t[1:].T.contiguous(), s.T.contiguous(), side="left").T 74 | right = (left + 1) % mask.sum(0).long() 75 | mask_l = F.one_hot(left, T).permute(0, 2, 1).reshape(S, T, N, 1) 76 | mask_r = F.one_hot(right, T).permute(0, 2, 1).reshape(S, T, N, 1) 77 | 78 | t = t.reshape(1, T, N, 1) 79 | xt = xt.reshape(1, T, N, D) 80 | m = m.reshape(1, T, N, D) 81 | s = s.reshape(S, N, 1) 82 | 83 | x0 = torch.sum(t * mask_l, dim=1) 84 | x1 = torch.sum(t * mask_r, dim=1) 85 | p0 = torch.sum(xt * mask_l, dim=1) 86 | p1 = torch.sum(xt * mask_r, dim=1) 87 | m0 = torch.sum(m * mask_l, dim=1) 88 | m1 = torch.sum(m * mask_r, dim=1) 89 | 90 | dx = x1 - x0 91 | t = (s - x0) / (dx + 1e-10) 92 | 93 | return ( 94 | t**3 * (2 * p0 + m0 - 2 * p1 + m1) 95 | + t**2 * (-3 * p0 + 3 * p1 - 2 * m0 - m1) 96 | + t * m0 97 | + p0 98 | ) 99 | -------------------------------------------------------------------------------- /gsbm/match_loss.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. 
and affiliates."""
2 | 
3 | import math
4 | import copy
5 | 
6 | import numpy as np
7 | import torch
8 | from torch.func import vmap, grad, jacrev
9 | 
10 | from .sde import DIRECTIONS
11 | 
12 | from ipdb import set_trace as debug
13 | 
14 | 
15 | def laplacian(s):
16 |     """Accepts a function s:R^D -> R."""
17 |     H = jacrev(grad(s))
18 |     return lambda x, t: torch.trace(H(x, t))
19 | 
20 | 
21 | # entropic action matching loss + analytic laplacian
22 | def _eam_inter_exact(net, xt, t, sigma):
23 | 
24 |     st = net(xt, t)
25 |     dsdx, dsdt = torch.autograd.grad(st.sum(), (xt, t), create_graph=True)
26 | 
27 |     control = 0.5 * (dsdx**2).sum(dim=1, keepdim=True)
28 |     dsdt = dsdt.reshape(-1, 1)
29 | 
30 |     lap = vmap(laplacian(lambda x, t: net(x, t).sum()), in_dims=(0, 0))(xt, t)
31 |     ent_term = sigma**2 / 2 * lap.reshape(-1, 1)
32 | 
33 |     return control, dsdt, ent_term
34 | 
35 | 
36 | def _eam_inter_approx(net, xt, t, sigma):
37 | 
38 |     dsdt_fn = grad(lambda x, t: net(x, t).sum(), argnums=1)
39 |     dsdx_fn = grad(lambda x, t: net(x, t).sum(), argnums=0)
40 | 
41 |     eps = torch.randint(low=0, high=2, size=xt.shape).to(xt.device).float() * 2 - 1
42 |     dsdx, jvp_val = torch.autograd.functional.jvp(
43 |         lambda x: dsdx_fn(x, t), (xt,), (eps,), create_graph=True
44 |     )
45 |     lap = (jvp_val * eps).sum(1, keepdims=True)
46 | 
47 |     dsdt = dsdt_fn(xt, t)
48 | 
49 |     control = 0.5 * (dsdx**2).sum(dim=1, keepdim=True)
50 |     dsdt = dsdt.reshape(-1, 1)
51 |     ent_term = 0.5 * sigma**2 * lap
52 | 
53 |     return control, dsdt, ent_term
54 | 
55 | 
56 | # entropic action matching loss with additional traj dimension
57 | def eam_loss_trajs(net, xs, t, x0, x1, sigma, direction, lap="approx"):
58 |     assert direction in DIRECTIONS
59 | 
60 |     B, T, D = xs.shape
61 |     assert x0.shape == x1.shape == (B, D)
62 |     assert t.shape == (T,)
63 | 
64 |     # boundary terms
65 |     s0 = net(x0, torch.zeros(B, device=x0.device))
66 |     s1 = net(x1, torch.ones(B, device=x1.device))
67 | 
68 |     # intermediate terms
69 |     xs.requires_grad_(True)
70 |     t.requires_grad_(True)
71 | 
72 |     xs_N = xs.reshape(-1, D)
73 |     t_N = t.reshape(1, T).expand(B, -1).reshape(-1)
74 | 
75 |     if lap == "approx":
76 |         control, dsdt, ent_term = _eam_inter_approx(net, xs_N, t_N, sigma)
77 |     elif lap == "exact":
78 |         control, dsdt, ent_term = _eam_inter_exact(net, xs_N, t_N, sigma)
79 |     else:
80 |         raise ValueError(f"Unsupported laplacian option: {lap}!")
81 | 
82 |     control = control.reshape(B, T)
83 |     dsdt = dsdt.reshape(B, T)
84 |     ent_term = ent_term.reshape(B, T)
85 | 
86 |     ### reweight loss
87 |     # if direction == "fwd":
88 |     #     w_t_fn = lambda t: t
89 |     #     dwdt_fn = lambda t: 1
90 |     #     return (w_t_fn(0) * s0 - w_t_fn(1) * s1 + w_t_fn(t) * inter + dwdt_fn(t) * net(xt, t)).mean()
91 |     # else:
92 |     #     w_t_fn = lambda t: -t
93 |     #     dwdt_fn = lambda t: -1
94 |     #     return (-w_t_fn(0) * s0 + w_t_fn(1) * s1 + w_t_fn(t) * inter + dwdt_fn(t) * net(xt, t)).mean()
95 | 
96 |     if direction == "fwd":
97 |         return (s0 - s1 + (control + dsdt + ent_term).mean(dim=1)).mean()
98 |     elif direction == "bwd":
99 |         return (-s0 + s1 + (control - dsdt + ent_term).mean(dim=1)).mean()
100 |     else:
101 |         raise ValueError(f"Unsupported direction option: {direction}!")
102 | 
103 | 
104 | # bridge matching loss
105 | def bm_loss(drift, xt, t, vt):
106 |     pred_vt = drift(xt, t)
107 |     assert pred_vt.shape == vt.shape == xt.shape
108 |     return torch.square(pred_vt - vt).mean()
109 | 
--------------------------------------------------------------------------------
/gsbm/network.py:
--------------------------------------------------------------------------------
1
| """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .ema import EMA 7 | from .nn import ( 8 | timestep_embedding, 9 | Unbatch, 10 | SiLU, 11 | ResNet_FC, 12 | ) 13 | from ipdb import set_trace as debug 14 | 15 | 16 | def build_net(cfg): 17 | if hasattr(cfg, "unet"): 18 | field = UNetVectorField(cfg.unet) 19 | else: 20 | field = { 21 | "toy-potential": ToyPotentialField, 22 | "toy-vector": ToyVectorField, 23 | "opinion-vector": OpinionVectorField, 24 | }.get(f"{cfg.net}-{cfg.field}")(cfg.dim) 25 | return EMA(Unbatch(field), cfg.optim.ema_decay) 26 | 27 | 28 | class ToyPotentialField(nn.Module): 29 | def __init__(self, data_dim: int = 2, hidden_dim: int = 128): 30 | super(ToyPotentialField, self).__init__() 31 | 32 | self.xt_module = ResNet_FC(data_dim + 1, hidden_dim, num_res_blocks=3) 33 | 34 | self.out_module = nn.Sequential( 35 | nn.Linear(hidden_dim, hidden_dim), 36 | SiLU(), 37 | nn.Linear(hidden_dim, 1), 38 | ) 39 | 40 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 41 | """ 42 | x: (b,nx) 43 | t: (b,) 44 | """ 45 | h = torch.hstack([t.reshape(-1, 1), x]) 46 | h = self.xt_module(h) 47 | out = self.out_module(h) 48 | return out 49 | 50 | 51 | class ToyVectorField(nn.Module): 52 | def __init__( 53 | self, 54 | data_dim: int = 2, 55 | hidden_dim: int = 128, 56 | time_embed_dim: int = 128, 57 | step_scale: int = 1000, 58 | ): 59 | super(ToyVectorField, self).__init__() 60 | 61 | self.step_scale = step_scale 62 | self.time_embed_dim = time_embed_dim 63 | hid = hidden_dim 64 | 65 | self.t_module = nn.Sequential( 66 | nn.Linear(self.time_embed_dim, hid), 67 | SiLU(), 68 | nn.Linear(hid, hid), 69 | ) 70 | 71 | self.x_module = nn.Sequential( 72 | nn.Linear(data_dim, hid), 73 | SiLU(), 74 | nn.Linear(hid, hid), 75 | SiLU(), 76 | nn.Linear(hid, hid), 77 | ) 78 | 79 | self.out_module = nn.Sequential( 80 | nn.Linear(hid, hid), 81 | SiLU(), 82 | nn.Linear(hid, data_dim), 83 | ) 84 | 85 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 86 | """ 87 | x: (b,nx) 88 | t: (b,) 89 | """ 90 | 91 | steps = t * self.step_scale 92 | t_emb = timestep_embedding(steps, self.time_embed_dim) 93 | t_out = self.t_module(t_emb) 94 | x_out = self.x_module(x) 95 | out = self.out_module(x_out + t_out) 96 | 97 | return out 98 | 99 | 100 | class OpinionVectorField(nn.Module): 101 | def __init__( 102 | self, data_dim=1000, hidden_dim=256, time_embed_dim=128, step_scale=1000 103 | ): 104 | super(OpinionVectorField, self).__init__() 105 | 106 | self.step_scale = step_scale 107 | self.time_embed_dim = time_embed_dim 108 | hid = hidden_dim 109 | 110 | self.t_module = nn.Sequential( 111 | nn.Linear(time_embed_dim, hid), 112 | SiLU(), 113 | nn.Linear(hid, hid), 114 | ) 115 | self.x_module = ResNet_FC(data_dim, hid, num_res_blocks=5) 116 | 117 | self.out_module = nn.Sequential( 118 | nn.Linear(hid, hid), 119 | SiLU(), 120 | nn.Linear(hid, data_dim), 121 | ) 122 | 123 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 124 | """ 125 | Apply the model to an input batch. 126 | :param x: an [N x C x ...] Tensor of inputs. 127 | :param t: a 1-D batch of timesteps. 
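        A shape-only sketch (illustrative sizes, not taken from the configs):
            net = OpinionVectorField(data_dim=1000)
            out = net(torch.randn(8, 1000), torch.rand(8))  # out: (8, 1000)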
128 | """ 129 | 130 | t = t * self.step_scale 131 | t_emb = timestep_embedding(t, self.time_embed_dim) 132 | t_out = self.t_module(t_emb) 133 | x_out = self.x_module(x) 134 | out = self.out_module(x_out + t_out) 135 | 136 | return out 137 | 138 | 139 | class UNetVectorField(nn.Module): 140 | def __init__(self, cfg, timesteps=1000): 141 | super(UNetVectorField, self).__init__() 142 | 143 | from .unet import UNetModel 144 | 145 | self.net = UNetModel(**cfg) 146 | self.timesteps = timesteps 147 | 148 | def forward(self, x, t) -> torch.Tensor: 149 | """ 150 | x: (b,nx) range: [-1,1] 151 | t: (b,) timesteps 152 | """ 153 | B, D = x.shape 154 | assert t.shape == (B,) 155 | assert D == 3 * 64 * 64 156 | 157 | batch = {} 158 | batch["noisy_x"] = x.reshape(B, 3, 64, 64) 159 | timestep = t * self.timesteps 160 | 161 | return self.net(batch, timestep).reshape(B, D) 162 | -------------------------------------------------------------------------------- /gsbm/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | Taken from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py 4 | """ 5 | 6 | import math 7 | 8 | import torch as th 9 | import torch.nn as nn 10 | 11 | 12 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 13 | class SiLU(nn.Module): 14 | def forward(self, x): 15 | return x * th.sigmoid(x) 16 | 17 | 18 | class GroupNorm32(nn.GroupNorm): 19 | def forward(self, x): 20 | return super().forward(x.float()).type(x.dtype) 21 | 22 | 23 | class MaskMixin: 24 | pass 25 | 26 | 27 | class MaskedConv1d(nn.Conv1d, MaskMixin): 28 | def forward(self, x, mask=None, **kwargs): 29 | if mask is not None: 30 | x = x * mask 31 | return super().forward(x) 32 | 33 | 34 | class MaskedConv2d(nn.Conv2d, MaskMixin): 35 | def forward(self, x, mask=None, **kwargs): 36 | if mask is not None: 37 | x = x * mask 38 | return super().forward(x) 39 | 40 | 41 | class MaskedConv3d(nn.Conv3d, MaskMixin): 42 | def forward(self, x, mask=None, **kwargs): 43 | if mask is not None: 44 | x = x * mask 45 | return super().forward(x) 46 | 47 | 48 | def conv_nd(dims, *args, **kwargs): 49 | """ 50 | Create a 1D, 2D, or 3D convolution module. 51 | """ 52 | if dims == 1: 53 | return MaskedConv1d(*args, **kwargs) 54 | elif dims == 2: 55 | return MaskedConv2d(*args, **kwargs) 56 | elif dims == 3: 57 | return MaskedConv3d(*args, **kwargs) 58 | raise ValueError(f"unsupported dimensions: {dims}") 59 | 60 | 61 | def linear(*args, **kwargs): 62 | """ 63 | Create a linear module. 64 | """ 65 | return nn.Linear(*args, **kwargs) 66 | 67 | 68 | def avg_pool_nd(dims, *args, **kwargs): 69 | """ 70 | Create a 1D, 2D, or 3D average pooling module. 71 | """ 72 | if dims == 1: 73 | return nn.AvgPool1d(*args, **kwargs) 74 | elif dims == 2: 75 | return nn.AvgPool2d(*args, **kwargs) 76 | elif dims == 3: 77 | return nn.AvgPool3d(*args, **kwargs) 78 | raise ValueError(f"unsupported dimensions: {dims}") 79 | 80 | 81 | def update_ema(target_params, source_params, rate=0.99): 82 | """ 83 | Update target parameters to be closer to those of source parameters using 84 | an exponential moving average. 85 | :param target_params: the target parameter sequence. 86 | :param source_params: the source parameter sequence. 87 | :param rate: the EMA rate (closer to 1 means slower). 
88 | """ 89 | for targ, src in zip(target_params, source_params): 90 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 91 | 92 | 93 | def zero_module(module): 94 | """ 95 | Zero out the parameters of a module and return it. 96 | """ 97 | for p in module.parameters(): 98 | p.detach().zero_() 99 | return module 100 | 101 | 102 | def scale_module(module, scale): 103 | """ 104 | Scale the parameters of a module and return it. 105 | """ 106 | for p in module.parameters(): 107 | p.detach().mul_(scale) 108 | return module 109 | 110 | 111 | def mean_flat(tensor): 112 | """ 113 | Take the mean over all non-batch dimensions. 114 | """ 115 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 116 | 117 | 118 | class LayerNorm(th.nn.LayerNorm): 119 | def forward(self, x): 120 | d = x.ndim - 1 121 | return super().forward(x.transpose(1, d)).transpose(1, d) 122 | 123 | 124 | def normalization(channels, normalization_type: str = "group_norm"): 125 | """ 126 | Make a standard normalization layer. 127 | :param channels: number of input channels. 128 | :return: an nn.Module for normalization. 129 | """ 130 | if normalization_type == "group_norm": 131 | return GroupNorm32(32, channels) 132 | elif normalization_type == "layer_norm": 133 | return LayerNorm(channels) 134 | else: 135 | raise ValueError("Unknown normalization type!") 136 | 137 | 138 | def timestep_embedding(timesteps, dim, max_period=10000): 139 | """ 140 | Create sinusoidal timestep embeddings. 141 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 142 | These may be fractional. 143 | :param dim: the dimension of the output. 144 | :param max_period: controls the minimum frequency of the embeddings. 145 | :return: an [N x dim] Tensor of positional embeddings. 146 | """ 147 | half = dim // 2 148 | freqs = th.exp( 149 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 150 | ).to(device=timesteps.device) 151 | args = timesteps[:, None].float() * freqs[None] 152 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 153 | if dim % 2: 154 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 155 | return embedding 156 | 157 | 158 | def checkpoint(func, inputs, kwargs, params, flag): 159 | """ 160 | Evaluate a function without caching intermediate activations, allowing for 161 | reduced memory at the expense of extra compute in the backward pass. 162 | :param func: the function to evaluate. 163 | :param inputs: the argument sequence to pass to `func`. 164 | :param params: a sequence of parameters `func` depends on but does not 165 | explicitly take as arguments. 166 | :param flag: if False, disable gradient checkpointing. 167 | """ 168 | if flag: 169 | # Use pytorch's activation checkpointing. 
This has support for fp16 autocast 170 | return th.utils.checkpoint.checkpoint(func, *inputs, **kwargs) 171 | # args = tuple(inputs) + tuple(params) 172 | # return CheckpointFunction.apply(func, len(inputs), *args) 173 | else: 174 | return func(*inputs, **kwargs) 175 | 176 | 177 | class CheckpointFunction(th.autograd.Function): 178 | @staticmethod 179 | def forward(ctx, run_function, length, *args): 180 | ctx.run_function = run_function 181 | ctx.input_tensors = list(args[:length]) 182 | ctx.input_params = list(args[length:]) 183 | with th.no_grad(): 184 | output_tensors = ctx.run_function(*ctx.input_tensors) 185 | return output_tensors 186 | 187 | @staticmethod 188 | def backward(ctx, *output_grads): 189 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 190 | with th.enable_grad(): 191 | # Fixes a bug where the first op in run_function modifies the 192 | # Tensor storage in place, which is not allowed for detach()'d 193 | # Tensors. 194 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 195 | output_tensors = ctx.run_function(*shallow_copies) 196 | input_grads = th.autograd.grad( 197 | output_tensors, 198 | ctx.input_tensors + ctx.input_params, 199 | output_grads, 200 | allow_unused=True, 201 | ) 202 | del ctx.input_tensors 203 | del ctx.input_params 204 | del output_tensors 205 | return (None, None) + input_grads 206 | 207 | 208 | class Unbatch(nn.Module): 209 | def __init__(self, net): 210 | super().__init__() 211 | self.net = net 212 | 213 | def forward(self, x, t, *args, **kwargs): 214 | has_batch = x.ndim > 1 215 | if not has_batch: 216 | x = x.reshape(1, -1) # (1,nx) 217 | t = t.reshape(-1) # (1,) 218 | v = self.net(x, t, *args, **kwargs) 219 | if not has_batch: 220 | v = v[0] 221 | return v 222 | 223 | 224 | class ResNet_FC(nn.Module): 225 | def __init__(self, data_dim, hidden_dim, num_res_blocks): 226 | super().__init__() 227 | self.hidden_dim = hidden_dim 228 | self.map = nn.Linear(data_dim, hidden_dim) 229 | self.res_blocks = nn.ModuleList( 230 | [self.build_res_block() for _ in range(num_res_blocks)] 231 | ) 232 | 233 | def build_linear(self, in_features, out_features): 234 | linear = nn.Linear(in_features, out_features) 235 | return linear 236 | 237 | def build_res_block(self): 238 | hid = self.hidden_dim 239 | layers = [] 240 | widths = [hid] * 4 241 | for i in range(len(widths) - 1): 242 | layers.append(self.build_linear(widths[i], widths[i + 1])) 243 | layers.append(SiLU()) 244 | return nn.Sequential(*layers) 245 | 246 | def forward(self, x): 247 | h = self.map(x) 248 | for res_block in self.res_blocks: 249 | h = (h + res_block(h)) / 2 250 | return h 251 | -------------------------------------------------------------------------------- /gsbm/opinion.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | 7 | from ipdb import set_trace as debug 8 | 9 | 10 | def t_to_idx(t: torch.Tensor, T: int) -> torch.Tensor: 11 | return (t * (T - 1)).round().long() 12 | 13 | 14 | @torch.no_grad() 15 | def est_directional_similarity(xs: torch.Tensor, n_est: int = 1000) -> torch.Tensor: 16 | """xs: (batch, nx). Returns (n_est, ) between 0 and 1.""" 17 | # xs: (batch, nx) 18 | batch, nx = xs.shape 19 | 20 | # Center first. 
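    # (Centering removes the mean opinion so that only relative directions
    # matter.) The estimator below draws n_est random index pairs, normalizes
    # the sampled opinions to unit vectors, and maps the angle between each
    # pair to [0, 1] via D_ij = 1 - angle / pi: values near 1 indicate aligned
    # (polarized) opinions, values near 0 indicate opposed ones.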
21 |     xs = xs - torch.mean(xs, dim=0, keepdim=True)
22 | 
23 |     rand_idxs1 = torch.randint(batch, [n_est], dtype=torch.long)
24 |     rand_idxs2 = torch.randint(batch, [n_est], dtype=torch.long)
25 | 
26 |     # (n_est, nx)
27 |     xs1 = xs[rand_idxs1]
28 |     # (n_est, nx)
29 |     xs2 = xs[rand_idxs2]
30 | 
31 |     # Normalize to unit vector.
32 |     xs1 /= torch.linalg.norm(xs1, dim=1, keepdim=True)
33 |     xs2 /= torch.linalg.norm(xs2, dim=1, keepdim=True)
34 | 
35 |     # (n_est, )
36 |     cos_angle = torch.sum(xs1 * xs2, dim=1).clip(-1.0, 1.0)
37 |     assert cos_angle.shape == (n_est,)
38 | 
39 |     # Should be in [0, pi).
40 |     angle = torch.acos(cos_angle)
41 |     assert (0 <= angle).all()
42 |     assert (angle <= torch.pi).all()
43 | 
44 |     D_ij = 1.0 - angle / torch.pi
45 |     assert D_ij.shape == (n_est,)
46 | 
47 |     return D_ij
48 | 
49 | 
50 | def opinion_thresh(inner: torch.Tensor) -> torch.Tensor:
51 |     return 2.0 * (inner > 0) - 1.0
52 | 
53 | 
54 | def compute_mean_drift_term(mf_x: torch.Tensor, xi: torch.Tensor) -> torch.Tensor:
55 |     """Decompose the polarizing dynamics of Eq (18) in the paper into 2 parts for faster computation:
56 |     f_polarize(x,p,ξ)
57 |     = E_{y~p}[a(x,y,ξ) * bar_y], where a(x,y,ξ) = sign(<x,ξ>) * sign(<y,ξ>)
58 |       and bar_y = y / |y|^{0.5}
59 |     = sign(<x,ξ>) * E_{y~p}[sign(<y,ξ>) * bar_y], since sign(<x,ξ>) is independent of y
60 |     = A(x,ξ) * B(p,ξ)
61 |     Hence, bar_f_polarize = bar_A(x,ξ) * bar_B(p,ξ).
62 |     This function computes only bar_B(p,ξ).
63 |     """
64 |     # mf_x: (B, *, D), xi: (*, D)
65 |     # output: (*, D)
66 | 
67 |     B, Ts, D = mf_x.shape[0], mf_x.shape[1:-1], mf_x.shape[-1]
68 |     assert xi.shape == (*Ts, D)
69 | 
70 |     mf_x_norm = torch.linalg.norm(mf_x, dim=-1, keepdim=True)
71 |     assert torch.all(mf_x_norm > 0.0)
72 | 
73 |     normalized_mf_x = mf_x / torch.sqrt(mf_x_norm)
74 |     assert normalized_mf_x.shape == (B, *Ts, D)
75 | 
76 |     # Compute the mean drift term: 1/J sum_j a(y_j) y_j / sqrt(| y_j |).
77 |     mf_agree_j = opinion_thresh(torch.sum(mf_x * xi, dim=-1, keepdim=True))
78 |     assert mf_agree_j.shape == (B, *Ts, 1)
79 | 
80 |     mean_drift_term = torch.mean(mf_agree_j * normalized_mf_x, dim=0)
81 |     assert mean_drift_term.shape == (*Ts, D)
82 | 
83 |     mean_drift_term_norm = torch.linalg.norm(mean_drift_term, dim=-1, keepdim=True)
84 |     mean_drift_term = mean_drift_term / torch.sqrt(mean_drift_term_norm)
85 |     assert mean_drift_term.shape == (*Ts, D)
86 | 
87 |     return mean_drift_term
88 | 
89 | 
90 | def opinion_f(
91 |     x: torch.Tensor, mf_drift: torch.Tensor, xi: torch.Tensor
92 | ) -> torch.Tensor:
93 |     """This function computes the polarizing dynamics of Eq (18) via
94 |     bar_f_polarize(x,p,ξ) = bar_A(x,ξ) * bar_B(p,ξ),
95 |     where bar_B(p,ξ) is pre-computed in compute_mean_drift_term and passed in as mf_drift.
96 |     """
97 |     # x: (b, T, nx), mf_drift: (T, nx), xi: (T, nx)
98 |     # out: (b, T, nx)
99 | 
100 |     b, T, nx = x.shape
101 |     assert xi.shape == mf_drift.shape == (T, nx)
102 | 
103 |     agree_i = opinion_thresh(torch.sum(x * xi, dim=-1, keepdim=True))
104 |     # Make sure we are not dividing by 0.
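    # (opinion_thresh returns values in {-1, +1}, so this is a defensive guard;
    # it keeps the sqrt(|agree_i|) normalization below strictly positive.)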
105 |     agree_i[agree_i == 0] = 1.0
106 | 
107 |     abs_sqrt_agree_i = torch.sqrt(torch.abs(agree_i))
108 |     assert torch.all(abs_sqrt_agree_i > 0.0)
109 | 
110 |     norm_agree_i = agree_i / abs_sqrt_agree_i
111 |     assert norm_agree_i.shape == (b, T, 1)
112 | 
113 |     f = norm_agree_i * mf_drift
114 |     assert f.shape == (b, T, nx)
115 | 
116 |     return f
117 | 
118 | 
119 | def build_f_mul(T, coeff=8.0) -> torch.Tensor:
120 |     # Scale f with a heuristic, time-decaying multiplier so it doesn't diverge exponentially fast
121 |     # and ruin normalization: the more polarized the opinions are, the faster they grow.
122 |     ts = torch.linspace(0.0, 1.0, T)
123 |     f_mul = torch.clip(1.0 - torch.exp(coeff * (ts - 1.0)) + 1e-5, min=1e-4, max=1.0)
124 |     f_mul = f_mul**5.0
125 |     return f_mul
126 | 
127 | 
128 | def build_xis(T, D) -> torch.Tensor:
129 |     # Generate random unit vectors.
130 |     rng = np.random.default_rng(seed=4078213)
131 |     xis = rng.standard_normal([T, D])
132 | 
133 |     # Construct xis that vary continuously over time, via a mean-reverting Brownian motion.
134 |     xi = xis[0]
135 |     bm_xis = [xi]
136 |     std = 0.4
137 |     dt = 1.0 / T
138 |     for t in range(1, T):
139 |         xi = xi - (2.0 * xi) * dt + std * math.sqrt(dt) * xis[t]
140 |         bm_xis.append(xi)
141 |     assert len(bm_xis) == xis.shape[0]
142 | 
143 |     xis = torch.Tensor(np.stack(bm_xis))
144 |     xis /= torch.linalg.norm(xis, dim=-1, keepdim=True)
145 | 
146 |     # Safeguard: print a checksum so any change in self.xis across runs is easy to spot.
147 |     print("USING BM XI! xis.sum(): {}".format(torch.sum(xis)))
148 |     assert xis.shape == (T, D)
149 |     return xis
150 | 
151 | 
152 | @torch.no_grad()
153 | def proj_pca(xs_f: torch.Tensor):
154 |     """
155 |     xs_f: (B, T, D)
156 |     ===
157 |     proj_xs_f: (B, T, 2)
158 |     V: (D, 2) s.t. proj_xs = xs @ V
159 |     """
160 |     # xs_f: (batch, T, nx)
161 |     # Only the final timestep of xs_f is used to fit the PCA basis.
162 |     batch, T, nx = xs_f.shape
163 | 
164 |     # (batch * T, nx)
165 |     flat_xsf = xs_f.reshape(-1, *xs_f.shape[2:])
166 | 
167 |     # Center by subtracting the mean.
168 |     # (batch, nx)
169 |     final_xs_f = xs_f[:, -1, :]
170 | 
171 |     mean_pca_xs = torch.mean(final_xs_f, dim=0, keepdim=True)
172 |     final_xs_f = final_xs_f - mean_pca_xs  # out-of-place: don't mutate xs_f (flat_xsf is a view of it)
173 | 
174 |     # If the batch is too large, the SVD will run out of memory; subsample to 200.
175 |     if batch > 200:
176 |         rand_idxs = torch.randperm(batch)[:200]
177 |         final_xs_f = final_xs_f[rand_idxs]
178 | 
179 |     # U: (n, n)
180 |     # S: (min(n, nx),)  -- vector of singular values
181 |     # VT: (nx, nx)
182 |     U, S, VT = torch.linalg.svd(final_xs_f)
183 | 
184 |     # log.info("Singular values of xs_f at final timestep:")
185 |     # log.info(S)
186 | 
187 |     # Keep the top two principal directions.
188 |     VT = VT[:2, :]
189 |     # VT = VT[[0, -1], :]
190 | 
191 |     assert VT.shape == (2, nx)
192 |     V = VT.T
193 | 
194 |     # Project the full trajectories (all timesteps) onto V.
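# Note: V was fit on the final-timestep marginal only; the same centering and
# basis are now applied to every timestep of the flattened trajectories.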
195 | flat_xsf -= mean_pca_xs 196 | 197 | proj_xs_f = flat_xsf @ V 198 | proj_xs_f = proj_xs_f.reshape(batch, T, *proj_xs_f.shape[1:]) 199 | 200 | return proj_xs_f, V 201 | 202 | 203 | class PolarizeDyn(torch.nn.Module): 204 | def __init__(self, pcfg) -> None: 205 | super(PolarizeDyn, self).__init__() 206 | 207 | self.S = pcfg.S 208 | self.D = pcfg.D 209 | self.polarize_strength = pcfg.strength 210 | self.register_buffer("xis", build_xis(pcfg.S, pcfg.D)) 211 | self.register_buffer("f_muls", build_f_mul(pcfg.S, coeff=pcfg.m_coeff)) 212 | self.register_buffer("mf_drift", torch.zeros(pcfg.S, pcfg.D)) 213 | self.register_buffer("is_mf_drift_set", torch.tensor(False)) 214 | 215 | def set_mf_drift(self, mf_xs): 216 | """run on cpu to prevent OOM 217 | mf_xs: (B, S, D) 218 | """ 219 | assert mf_xs.shape[1:] == (self.S, self.D) 220 | 221 | xis = self.xis.detach().cpu() 222 | mf_drift = compute_mean_drift_term(mf_xs, xis) 223 | assert mf_drift.shape == (self.S, self.D) 224 | 225 | self.mf_drift = mf_drift.to(self.xis) 226 | self.is_mf_drift_set = torch.tensor(True).to(self.xis) 227 | 228 | def forward(self, xs, t): 229 | """ 230 | xs: (*, T, D) 231 | t: (T,) 232 | === 233 | out: (*, T, D) 234 | """ 235 | Bs, (T, D) = xs.shape[:-2], xs.shape[-2:] 236 | assert t.shape == (T,) 237 | 238 | xs = xs.reshape(-1, T, D) 239 | 240 | t_idx = t_to_idx(t, self.S) 241 | fmul = self.f_muls[t_idx].to(xs) 242 | assert t_idx.shape == fmul.shape == (T,) 243 | 244 | xi = self.xis[t_idx].to(xs) 245 | if self.is_mf_drift_set: 246 | mf_drift = self.mf_drift[t_idx].to(xs) 247 | else: 248 | mf_drift = compute_mean_drift_term(xs, xi) 249 | assert xi.shape == mf_drift.shape == (T, D) 250 | 251 | f = self.polarize_strength * opinion_f(xs, mf_drift, xi) 252 | assert f.shape == xs.shape 253 | 254 | f = fmul.reshape(1, -1, 1) * f 255 | assert f.shape == xs.shape 256 | 257 | return f 258 | -------------------------------------------------------------------------------- /gsbm/path_integral.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. 
and affiliates.""" 2 | 3 | import torch 4 | 5 | from .sde import DIRECTIONS, sdeint 6 | 7 | from ipdb import set_trace as debug 8 | 9 | 10 | def t_to_idx(t: torch.Tensor, T: int) -> torch.Tensor: 11 | return (t * (T - 1)).round().long() 12 | 13 | 14 | def impt_weight_fn(s, xs, us, ws, V, sigma): 15 | """ 16 | s: (S,) 17 | xs: (B, N, S, D) 18 | us: (B, N, S, D) 19 | ws: (B, N, S, D) 20 | === 21 | weights: (B, N) 22 | """ 23 | 24 | B, N, S, D = xs.shape 25 | assert s.shape == (S,) 26 | assert (s < 1).all() and (s > 0).all() 27 | assert us.shape == ws.shape == (B, N, S, D) 28 | assert sigma > 0 29 | 30 | dt = torch.cat([s[[1]] - s[[0]], s[1:] - s[:-1]]).reshape(1, 1, -1) 31 | assert dt.shape == (1, 1, S) and (dt > 0).all() 32 | 33 | state_cost = V(xs, s) * dt 34 | control_cost = (0.5 / sigma**2) * (us**2).sum(dim=-1) * dt 35 | girsanov_cost = (1.0 / sigma) * (us * ws).sum(dim=-1) 36 | assert state_cost.shape == control_cost.shape == girsanov_cost.shape == (B, N, S) 37 | 38 | total_cost = (state_cost + control_cost + girsanov_cost).mean(dim=-1) # (B, N) 39 | total_cost = total_cost - total_cost.min(dim=-1, keepdim=True)[0] # (B, 1) 40 | assert total_cost.shape == (B, N) 41 | # print(state_cost.abs().max(), control_cost.abs().max(), girsanov_cost.abs().max()) 42 | 43 | weights = torch.exp(-total_cost) 44 | weights = weights / weights.sum(dim=-1, keepdim=True) 45 | assert weights.shape == (B, N) 46 | return weights 47 | 48 | 49 | def impt_sample_xs(ccfg, gpath, sigma, direction, V, eps=0.001): 50 | assert direction in DIRECTIONS and sigma > 0 51 | 52 | B, D, N, S = gpath.B, gpath.D, ccfg.IW_N, ccfg.IW_S 53 | device = gpath.device 54 | 55 | tinit = eps if direction == "fwd" else (1.0 - eps) 56 | xinit = gpath.sample_xt(tinit * torch.ones(1, device=device), N=N) # (B, N, 1, D) 57 | xinit = xinit.reshape(B * N, D) 58 | 59 | def drift(xt, t): 60 | """ 61 | xt: (B*N, D) 62 | t: (B*N,) 63 | === 64 | ut: (B*N, D) 65 | """ 66 | assert torch.allclose(t, t[0] * torch.ones_like(t)) 67 | _t = t[0].reshape(1) 68 | _xt = xt.reshape(B, N, 1, D) 69 | ut = gpath.ut(_t, _xt, direction) # drift - ft 70 | ft = gpath.ft(_t, _xt, direction) 71 | assert _xt.shape == ut.shape == ft.shape 72 | 73 | return (ut + ft).reshape(B * N, D) 74 | 75 | diffusion = lambda x, t: sigma 76 | out = sdeint( 77 | xinit, 78 | drift, 79 | diffusion, 80 | direction, 81 | nfe=S - 1, 82 | log_steps=S, 83 | eps=eps, 84 | return_ws=True, 85 | ) 86 | s, xs, us, ws = out["t"], out["xs"], out["us"], out["ws"] 87 | 88 | xs = xs.reshape(B, N, S, D) 89 | us = us.reshape(B, N, S, D) 90 | ws = ws.reshape(B, N, S, D) 91 | 92 | VV = lambda x, t: V(x, t, gpath) 93 | weights = impt_weight_fn(s, xs, us, ws, VV, sigma) 94 | return {"IW_t": s, "IW_xs": xs, "weights": weights} 95 | 96 | 97 | def impt_weighted(t, xs, weights): 98 | """ 99 | t: (T,) 100 | xs: (B, N, S, D) 101 | weights: (B, N) 102 | === 103 | out: (B, T, D) 104 | """ 105 | (T,), (B, N, S, D) = t.shape, xs.shape 106 | assert weights.shape == (B, N) 107 | 108 | permvector = torch.multinomial(weights, T, replacement=True) 109 | permvector = permvector.unsqueeze(-1).expand(-1, -1, D) 110 | assert permvector.shape == (B, T, D) 111 | 112 | ## subsample time grid 113 | ys = xs[:, :, t_to_idx(t, S)] 114 | assert ys.shape == (B, N, T, D) 115 | 116 | # https://github.com/pytorch/pytorch/issues/30574#issuecomment-1199665661 117 | yt = ys.gather(1, permvector.unsqueeze(1)).squeeze(1) 118 | assert yt.shape == (B, T, D) 119 | return yt 120 | 
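# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for exposition; random tensors stand in for
# real SDE rollouts, and a zero state cost replaces the problem-specific V).
# It exercises the self-normalized importance weights from `impt_weight_fn`
# and the weighted resampling in `impt_weighted`. Run with:
#   python -m gsbm.path_integral
if __name__ == "__main__":
    B, N, S, D, T = 2, 8, 16, 3, 5
    s = torch.linspace(0.01, 0.99, S)  # interior time grid, strictly in (0, 1)
    xs = torch.randn(B, N, S, D)  # N candidate trajectories per endpoint pair
    us = torch.randn(B, N, S, D)  # drifts along each trajectory
    ws = 0.1 * torch.randn(B, N, S, D)  # Brownian increments
    V = lambda x, t: torch.zeros(x.shape[:-1])  # zero state cost: (B, N, S)

    weights = impt_weight_fn(s, xs, us, ws, V, sigma=1.0)
    assert weights.shape == (B, N)

    # Resample T marginal snapshots according to the weights.
    t = torch.linspace(0.1, 0.9, T)
    yt = impt_weighted(t, xs, weights)
    assert yt.shape == (B, T, D)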
-------------------------------------------------------------------------------- /gsbm/pl_model.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | from typing import Any, List 4 | import os 5 | import numpy as np 6 | import math 7 | from datetime import datetime 8 | from rich.console import Console 9 | from easydict import EasyDict as edict 10 | import copy 11 | 12 | import torch 13 | import pytorch_lightning as pl 14 | from torch.utils.data import DataLoader 15 | import torchvision.utils as tu 16 | 17 | from .network import build_net 18 | from .state_cost import build_state_cost 19 | from .evaluator import build_evaluator 20 | from .sde import build_basedrift, sdeint 21 | 22 | from .dataset import PairDataset, SplineDataset, SplineIWDataset 23 | from . import gaussian_path as gpath_lib 24 | from . import path_integral as pi_lib 25 | from . import match_loss as match_lib 26 | from . import utils 27 | 28 | from .plotting import ( 29 | save_fig, 30 | save_xs, 31 | plot_gpath, 32 | plot_iw, 33 | plot_boundaries, 34 | plot_xs_opinion, 35 | ) 36 | 37 | # put gc.collect after io writing to prevent c10::CUDAError in multi-threading 38 | # https://github.com/pytorch/pytorch/issues/67978#issuecomment-1661986812 39 | import gc 40 | from ipdb import set_trace as debug 41 | 42 | console = Console() 43 | 44 | 45 | class GSBMLitModule(pl.LightningModule): 46 | def __init__(self, cfg, p0, p1, p0_val, p1_val): 47 | super().__init__() 48 | 49 | os.makedirs("figs", exist_ok=True) 50 | 51 | self.cfg = cfg 52 | self.p0 = p0 53 | self.p1 = p1 54 | self.p0_val = p0_val 55 | self.p1_val = p1_val 56 | 57 | ### Problem 58 | self.sigma = cfg.prob.sigma 59 | self.V = build_state_cost(cfg) 60 | self.basedrift = build_basedrift(cfg) 61 | self.evaluator = build_evaluator(cfg) 62 | 63 | ### SB Model 64 | self.direction = None 65 | self.fwd_net = build_net(cfg) 66 | self.bwd_net = build_net(cfg) 67 | 68 | def print(self, content, prefix=True): 69 | if self.trainer.is_global_zero: 70 | if prefix: 71 | now = f"[[cyan]{datetime.now():%Y-%m-%d %H:%M:%S}[/cyan]]" 72 | if self.direction is None: 73 | base = f"[[blue]Init[/blue]] " 74 | else: 75 | base = f"[[blue]Ep {self.current_epoch} ({self.direction})[/blue]] " 76 | console.print(now, highlight=False, end="") 77 | console.print(base, end="") 78 | console.print(f"{content}") 79 | 80 | @property 81 | def wandb_logger(self): 82 | ## assume wandb is added to the end of loggers 83 | return self.loggers[-1] 84 | 85 | @property 86 | def is_img_prob(self): 87 | return self.cfg.prob.name in [ 88 | "afhq", 89 | ] 90 | 91 | @property 92 | def logging_batch_idxs(self): 93 | return np.linspace(0, self.trainer.num_training_batches - 1, 10).astype(int) 94 | 95 | @property 96 | def ocfg(self): 97 | return self.cfg.optim 98 | 99 | @property 100 | def ccfg(self): 101 | return self.cfg.csoc 102 | 103 | @property 104 | def mcfg(self): 105 | return self.cfg.matching 106 | 107 | @property 108 | def pcfg(self): 109 | if self.cfg.prob.name == "lidar": 110 | pcfg = edict(self.cfg.plot) 111 | pcfg.dataset = self.V.dataset 112 | return pcfg 113 | return self.cfg.plot 114 | 115 | @property 116 | def device(self): 117 | return self.fwd_net.parameters().__next__().device 118 | 119 | @property 120 | def net(self): 121 | return self.fwd_net if self.direction == "fwd" else self.bwd_net 122 | 123 | @property 124 | def direction_r(self): 125 | return "bwd" if self.direction == "fwd" else "fwd" 126 | 127 | 
def build_ft(self, direction):
128 |         def ft(x, t):
129 |             """
130 |             x: (B, D)
131 |             t: (B,)
132 |             ===
133 |             out: (B, D)
134 |             """
135 |             B, D = x.shape
136 |             sign = 1.0 if direction == "fwd" else -1.0
137 |             assert t.shape == (B,) and torch.allclose(t, t[0] * torch.ones_like(t))
138 |             return sign * self.basedrift(x.unsqueeze(1), t[0].reshape(1)).squeeze(1)
139 | 
140 |         return ft
141 | 
142 |     def build_ut(self, direction, backprop_snet=False):
143 |         """
144 |         ut: x: (B, D), t: (B,) --> (B, D)
145 |         """
146 |         net = self.fwd_net if direction == "fwd" else self.bwd_net
147 |         if self.cfg.field == "vector":
148 |             ut = net
149 |         elif self.cfg.field == "potential":
150 | 
151 |             def ut(x, t):
152 |                 with torch.enable_grad():
153 |                     x = x.detach().clone()
154 |                     x.requires_grad_(True)
155 |                     out = net(x, t)
156 |                     return torch.autograd.grad(
157 |                         out.sum(), x, create_graph=backprop_snet
158 |                     )[0]
159 | 
160 |         else:
161 |             raise ValueError(f"Unsupported field: {self.cfg.field}!")
162 |         return ut
163 | 
164 |     def build_drift(self, direction, backprop_snet=False):
165 |         ft = self.build_ft(direction)
166 |         ut = self.build_ut(direction, backprop_snet=backprop_snet)
167 |         drift = lambda x, t: ut(x, t) + ft(x, t)
168 |         return drift
169 | 
170 |     @torch.no_grad()
171 |     def sample(self, xinit, log_steps, direction, drift=None, nfe=None, verbose=False):
172 |         drift = self.build_drift(direction) if drift is None else drift
173 |         diffusion = lambda x, t: self.sigma
174 |         nfe = nfe or self.cfg.nfe
175 |         output = sdeint(
176 |             xinit,
177 |             drift,
178 |             diffusion,
179 |             direction,
180 |             nfe=nfe,
181 |             log_steps=log_steps,
182 |             verbose=verbose,
183 |         )
184 |         return output
185 | 
186 |     def sample_t(self, batch):
187 |         if self.mcfg.loss == "eam":
188 |             t0 = torch.rand(1)
189 |             t = (t0 + math.sqrt(2) * torch.arange(batch)) % 1
190 |             t.clamp_(min=0.001, max=0.999)
191 |         elif self.mcfg.loss == "bm":
192 |             eps = 1e-4
193 |             t = torch.rand(batch).reshape(-1) * (1 - 2 * eps) + eps
194 |         else:
195 |             raise ValueError(f"Unsupported matching loss option: {self.mcfg.loss}!")
196 | 
197 |         assert t.shape == (batch,)
198 |         return t
199 | 
200 |     def sample_gpath(self, batch):
201 |         ### Setup
202 |         gpath = gpath_lib.EndPointGaussianPath(
203 |             batch["mean_t"][0],
204 |             batch["mean_xt"],
205 |             batch["gamma_s"][0],
206 |             batch["gamma_xs"],
207 |             self.sigma,
208 |             self.basedrift,
209 |         )
210 |         x0, x1 = batch["x0"], batch["x1"]
211 |         B, D = x0.shape
212 | 
213 |         ### Sample t and xt
214 |         T = B if self.mcfg.loss == "bm" else self.mcfg.batch_t
215 |         if not self.ccfg.IW:
216 |             t = self.sample_t(T).to(x0)
217 |             with torch.no_grad():
218 |                 xt = gpath.sample_xt(t, N=1)
219 |         else:
220 |             IW_t, IW_xs, weights = batch["IW_t"][0], batch["IW_xs"], batch["weights"]
221 | 
222 |             # weights should be positive and self-normalized
223 |             assert (weights > 0).all()
224 |             assert torch.allclose(weights.sum(dim=1), torch.ones(B).to(weights))
225 | 
226 |             rand_idx = torch.randint(low=0, high=len(IW_t), size=(T,))
227 |             t = IW_t[rand_idx]
228 |             xt = pi_lib.impt_weighted(t, IW_xs, weights).unsqueeze(1)
229 |             assert t.shape == (T,) and xt.shape == (B, 1, T, D)
230 | 
231 |         ### Sample vt and build output
232 |         if self.mcfg.loss == "bm":
233 |             assert B == T
234 |             vt = gpath.ut(t, xt, self.direction)
235 |             xt = xt[torch.arange(B), 0, torch.arange(B)]
236 |             vt = vt[torch.arange(B), 0, torch.arange(B)]
237 |             assert xt.shape == vt.shape == (B, D)
238 | 
239 |         elif self.mcfg.loss == "eam":
240 |             vt = None
241 |             xt = xt.squeeze(1)
242 |             assert xt.shape == (B, T, D)
243 | 
244 |         return x0, x1, t, xt, vt
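    # Matching step (M-step): "bm" regresses the network's drift at (t, xt) onto
    # the conditional drift vt of the fitted Gaussian path, while "eam" fits an
    # entropic action on whole sub-trajectories (note vt is None in that branch).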
245 | 
246 |     def training_step(self, batch: Any, batch_idx: int):
247 |         ### Sample from Gaussian path
248 |         x0, x1, t, xt, vt = self.sample_gpath(batch)
249 | 
250 |         ### Apply bridge matching or entropic action matching
251 |         if self.mcfg.loss == "bm":
252 |             ut = self.build_ut(self.direction, backprop_snet=True)
253 |             loss = match_lib.bm_loss(ut, xt, t, vt)
254 |         elif self.mcfg.loss == "eam":
255 |             loss = match_lib.eam_loss_trajs(
256 |                 self.net,
257 |                 xt,
258 |                 t,
259 |                 x0,
260 |                 x1,
261 |                 self.sigma,
262 |                 self.direction,
263 |                 lap=self.mcfg.lap,
264 |             )
265 |         else:
266 |             raise ValueError(f"Unsupported match_loss option: {self.mcfg.loss}!")
267 | 
268 |         if torch.isfinite(loss):
269 |             self.log("train/loss", loss, on_step=True, on_epoch=True)
270 |         else:
271 |             ### Skip the step if the loss is not finite (NaN/Inf).
272 |             self.print(f"Skipping iteration because loss is {loss.item()}.")
273 |             return None
274 | 
275 |         if batch_idx in self.logging_batch_idxs:
276 |             self.print(
277 |                 f"[M-step] batch idx: {batch_idx+1}/{self.trainer.num_training_batches} ..."
278 |             )
279 | 
280 |         return {"loss": loss}
281 | 
282 |     def training_epoch_end(self, outputs: List[Any]):
283 |         ### ** The only place where we modify the direction!! **
284 |         self.direction = self.direction_r
285 |         self.print("", prefix=False)  # line break
286 | 
287 |     def localize(self, p):
288 |         g = torch.Generator()
289 |         g.manual_seed(g.seed() + self.global_rank)
290 |         local_p = copy.deepcopy(p)
291 |         local_p.set_generator(g)
292 |         return local_p
293 | 
294 |     def val_dataloader(self):
295 |         totalB, n_device = self.ccfg.B, utils.n_device()
296 |         B = totalB // n_device
297 |         self.print(f"[Data] Building {totalB} train_data ...")
298 |         self.print(
299 |             f"[Data] Found {n_device} devices, each will generate {B} samples ..."
300 |         )
301 | 
302 |         x0 = self.localize(self.p0)(B)
303 |         x1 = self.localize(self.p1)(B)
304 |         return DataLoader(
305 |             PairDataset(x0, x1),
306 |             num_workers=self.ocfg.num_workers,
307 |             batch_size=self.ccfg.mB,
308 |             persistent_workers=self.ocfg.num_workers > 0,
309 |             shuffle=False,
310 |             pin_memory=True,
311 |         )
312 | 
313 |     def compute_coupling(self, batch, direction, eval_coupling):
314 |         x0, x1, T = batch["x0"], batch["x1"], self.ccfg.T_mean
315 |         if direction is None:
316 |             t = torch.linspace(0, 1, T).to(x0)
317 |             xt = (1 - t[None, :, None]) * x0.unsqueeze(1) + t[
318 |                 None, :, None
319 |             ] * x1.unsqueeze(1)
320 |         else:
321 |             xinit = x0 if direction == "fwd" else x1
322 |             output = self.sample(xinit, log_steps=T, direction=direction)
323 |             t, xt = output["t"], output["xs"]
324 | 
325 |         if eval_coupling:
326 |             metrics = self.evaluator(output)
327 |             for k, v in metrics.items():
328 |                 self.log(f"metrics/{k}", v, on_epoch=True)
329 | 
330 |         return t, xt
331 | 
332 |     def validation_step(self, batch: Any, batch_idx: int):
333 |         log_step = batch_idx == 0
334 |         ccfg, direction = self.ccfg, self.direction
335 |         postfix = f"{self.current_epoch:03d}" if direction is not None else "init"
336 | 
337 |         (B, D), T, S, sigma = batch["x0"].shape, ccfg.T_mean, ccfg.T_gamma, self.sigma
338 | 
339 |         ### Initialize mean spline (with coupling)
340 |         eval_coupling = log_step and self.cfg.eval_coupling
341 |         self.print(f"[R-step] Simulating {direction or 'init'} coupling ...")
342 |         t, xt = self.compute_coupling(batch, direction, eval_coupling)
343 |         self.print(f"[R-step] Simulated {xt.shape=}!")
344 |         if log_step:
345 |             self.log_coupling(t, xt, direction, f"coupling-{postfix}")
346 |         assert xt.shape == (B, T, D) and t.shape == (T,)
347 | 
348 |         ### Initialize std spline
349 |         s = torch.linspace(0, 1, S).to(t)
350 |         ys = torch.zeros(B, S, 1).to(xt)
351 | 
352 |         ### Fit Gaussian paths (update xt, ys)
353 |         gpath = gpath_lib.EndPointGaussianPath(t, xt, s, ys, sigma, self.basedrift)
354 |         if self.is_img_prob:
355 |             loss_fn = gpath_lib.build_img_loss_fn(gpath, sigma, self.V, ccfg)
356 |         else:
357 |             loss_fn = gpath_lib.build_loss_fn(gpath, sigma, self.V, ccfg)
358 |         with torch.enable_grad():
359 |             verbose = log_step and self.trainer.is_global_zero
360 |             result = gpath_lib.fit(
361 |                 ccfg, gpath, direction or "fwd", loss_fn, verbose=verbose
362 |             )
363 |         self.print(f"[R-step] Fit {B} gaussian paths!")
364 |         if log_step:
365 |             self.log_gpath(result, f"gpath-{postfix}")
366 | 
367 |         ### Build output
368 |         xt = gpath.mean.xt.detach().clone()
369 |         ys = gpath.gamma.xt.detach().clone()
370 |         assert xt.shape == (B, T, D) and ys.shape == (B, S, 1)
371 |         output = {"mean_t": t, "mean_xt": xt, "gamma_s": s, "gamma_xs": ys}
372 | 
373 |         ### (optional) Handle importance weighting
374 |         if ccfg.IW:
375 |             with torch.no_grad():
376 |                 iw_output = pi_lib.impt_sample_xs(
377 |                     ccfg, gpath, sigma, direction or "fwd", V=self.V
378 |                 )
379 |             output.update(iw_output)
380 |             self.print(f"[R-step] Computed IW with weights.shape={iw_output['weights'].shape}!")
381 |             if log_step:
382 |                 self.log_IW(iw_output, f"iw-{postfix}")
383 | 
384 |         ### (optional) Handle opinion drift
385 |         if ccfg.name == "opinion":
386 |             tt = torch.linspace(0, 1, self.cfg.pdrift.S).to(t)
387 |             mf_x = gpath.sample_xt(tt, N=1).squeeze(1)
388 |             assert mf_x.shape == (B, len(tt), D)
389 |             output["mf_x"] = mf_x.detach().cpu()
390 | 
391 |         return output
392 | 
393 |     def validation_epoch_end(self, outputs: List[Any]):
394 |         ### Handle opinion drift
395 |         if self.cfg.prob.name == "opinion":
396 |             mf_xs = utils.gather(outputs, "mf_x")
397 |             # if utils.is_distributed():
398 |             #     mf_xs = utils.all_gather(mf_xs)
399 |             self.basedrift.set_mf_drift(mf_xs)
400 |             self.print(f"[Opinion] Set MF drift shape={mf_xs.shape}!")
401 | 
402 |         ccfg = self.ccfg
403 |         T, S, D = ccfg.T_mean, ccfg.T_gamma, self.cfg.dim
404 | 
405 |         ## gather mean_t, gamma_s
406 |         mean_t = outputs[0]["mean_t"].detach().cpu()
407 |         gamma_s = outputs[0]["gamma_s"].detach().cpu()
408 |         assert mean_t.shape == (T,) and gamma_s.shape == (S,)
409 | 
410 |         ## gather mean_xt, gamma_xs
411 |         mean_xt = utils.gather(outputs, "mean_xt")
412 |         gamma_xs = utils.gather(outputs, "gamma_xs")
413 |         B = mean_xt.shape[0]
414 |         assert mean_xt.shape == (B, T, D)
415 |         assert gamma_xs.shape == (B, S, 1)
416 | 
417 |         self.train_data = SplineDataset(
418 |             mean_t, mean_xt, gamma_s, gamma_xs, expand_factor=ccfg.epd_fct
419 |         )
420 |         self.print(f"[Data] Fit a total of {B} gaussian paths as train_data!")
421 | 
422 |         ### (optional) IW
423 |         if ccfg.IW:
424 |             iN, iS = ccfg.IW_N, ccfg.IW_S
425 |             IW_t = outputs[0]["IW_t"].detach().cpu()
426 |             IW_xs = utils.gather(outputs, "IW_xs")
427 |             weights = utils.gather(outputs, "weights")
428 |             assert IW_t.shape == (iS,) and IW_xs.shape == (B, iN, iS, D)
429 |             assert weights.shape == (B, iN)
430 |             self.print(f"[Data] Computed importance weights {weights.shape=} for train_data!")
431 |             self.train_data = SplineIWDataset(self.train_data, IW_t, IW_xs, weights)
432 | 
433 |         if self.direction is None:
434 |             self.direction = "fwd"
435 |         self.print("", prefix=False)  # line break
436 | 
437 |         torch.cuda.empty_cache()
438 | 
439 |     def train_dataloader(self):
440 |         dataloader = DataLoader(
441 |             self.train_data,
442 |             num_workers=self.ocfg.num_workers,
443 |             batch_size=self.ocfg.batch_size,
444 |             persistent_workers=self.ocfg.num_workers > 0,
445 |             shuffle=True,
446 |             pin_memory=True,
447 |             drop_last=True,
448 |         )
449 |         return dataloader
450 | 
451 |     def configure_optimizers(self):
452 |         optimizer = torch.optim.AdamW(
453 |             self.parameters(),
454 |             lr=self.ocfg.lr,
455 |             weight_decay=self.ocfg.wd,
456 |             eps=self.ocfg.eps,
457 |         )
458 | 
459 |         if self.ocfg.get("scheduler", "cosine") == "cosine":
460 |             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
461 |                 optimizer,
462 |                 T_max=self.ocfg.num_iterations,
463 |             )
464 |             return {
465 |                 "optimizer": optimizer,
466 |                 "lr_scheduler": {
467 |                     "scheduler": scheduler,
468 |                     "interval": "step",
469 |                 },
470 |             }
471 |         else:
472 |             return {
473 |                 "optimizer": optimizer,
474 |             }
475 | 
476 |     def optimizer_step(self, *args, **kwargs):
477 |         super().optimizer_step(*args, **kwargs)
478 |         self.net.update_ema()
479 | 
480 |     def log_coupling(self, t, xs, direction, fn, log_steps=5):
481 |         """
482 |         t: (T,) xs: (B, T, D)
483 |         """
484 |         B, T, D = xs.shape
485 |         assert t.shape == (T,)
486 | 
487 |         if self.is_img_prob:
488 |             # subsample batch to mB and timesteps to log_steps
489 |             mB = 10  # mini-batch
490 |             xs = xs[:mB][:, np.linspace(0, T - 1, log_steps).astype(int)]
491 |             xs = xs.reshape(mB * log_steps, D)
492 |             self.log_images(xs, log_steps, fn, "viz")
493 |         else:
494 |             save_xs(t, xs, log_steps, direction, self.pcfg, fn)
495 | 
496 |         gc.collect()
497 | 
498 |     def log_gpath(self, result, fn):
499 |         plot_gpath(result, self.pcfg)
500 |         save_fig(fn)
501 |         gc.collect()
502 | 
503 |         if self.is_img_prob:
504 |             B, T, D = result["init_mean"].shape
505 |             assert T < 10
506 | 
507 |             self.log_images(result["init_mean"][:10], T, f"{fn}-init_mean", "init_mean")
508 |             self.log_images(
509 |                 result["final_mean"][:10], T, f"{fn}-final_mean", "final_mean"
510 |             )
511
| 512 | def log_IW(self, result, fn): 513 | IW_t, IW_xs, ws = result["IW_t"], result["IW_xs"], result["weights"] 514 | result["IW_xt"] = pi_lib.impt_weighted(IW_t, IW_xs, ws) 515 | plot_iw(result, self.pcfg) 516 | save_fig(fn) 517 | gc.collect() 518 | 519 | def log_images(self, x, T, fn, key=None): 520 | """ 521 | x: (B, 3*D*D) --> grid images: (3, (B//T)*D, T*D) 522 | """ 523 | images = x.reshape(-1, *self.cfg.image_size) 524 | images = torch.clamp((images + 1) / 2, 0.0, 1.0).cpu() 525 | tu.save_image(images, f"figs/{fn}.png", nrow=T, pad_value=1.0) 526 | 527 | if key is not None and self.cfg.use_wandb: 528 | grid_image = tu.make_grid(images, nrow=T, pad_value=1.0) 529 | self.wandb_logger.log_image(key=f"images/{key}", images=[grid_image]) 530 | 531 | gc.collect() 532 | 533 | def log_boundary(self, p0, p1, p0_val, p1_val): 534 | if self.is_img_prob: 535 | self.log_images(p0(64), 8, "init-p0") 536 | self.log_images(p1(64), 8, "init-p1") 537 | self.log_images(p0_val(64), 8, "init-p0-val") 538 | self.log_images(p1_val(64), 8, "init-p1-val") 539 | else: 540 | plot_boundaries(p0, p1, self.pcfg) 541 | save_fig("train_dist") 542 | plot_boundaries(p0_val, p1_val, self.pcfg) 543 | save_fig("val_dist") 544 | gc.collect() 545 | 546 | def log_basedrift(self, p0): 547 | ft = self.build_ft("fwd") 548 | result = self.sample(p0(512), log_steps=5, direction="fwd", drift=ft, nfe=500) 549 | plot_xs_opinion(result["t"], result["xs"], 5, "Init", self.pcfg) 550 | save_fig("ft") 551 | -------------------------------------------------------------------------------- /gsbm/plotting.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | import numpy as np 4 | 5 | import matplotlib 6 | from matplotlib import pyplot as plt 7 | from matplotlib import cm 8 | from matplotlib.patches import Circle, Ellipse, Rectangle 9 | 10 | import torch 11 | 12 | from .state_cost import ( 13 | obstacle_cfg_gmm, 14 | obstacle_cfg_vneck, 15 | obstacle_cfg_stunnel, 16 | obstacle_cfg_drunken_spider, 17 | ) 18 | from .opinion import est_directional_similarity, proj_pca 19 | 20 | from ipdb import set_trace as debug 21 | 22 | cmap = "Greens" 23 | fontsize = 10 24 | 25 | plt.rcParams.update({"font.size": fontsize}) 26 | 27 | cpuize = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t 28 | 29 | ################################################################################################ 30 | 31 | 32 | def get_fig_axes(ncol, nrow=1, ax_length_in=2.0, lim=None): 33 | figsize = (ncol * ax_length_in, nrow * ax_length_in) 34 | fig = plt.figure(figsize=figsize) 35 | axes = fig.subplots(nrow, ncol) 36 | 37 | if lim is not None: 38 | axs = [axes] if nrow == 1 and ncol == 1 else axes.reshape(-1) 39 | for ax in axs: 40 | ax.set(xlim=[-lim, lim], ylim=[-lim, lim]) 41 | 42 | return fig, axes 43 | 44 | 45 | def save_fig(fn, pdf=False): 46 | plt.tight_layout() 47 | if pdf: 48 | plt.savefig(f"figs/{fn}.pdf") 49 | else: 50 | plt.savefig(f"figs/{fn}.png", dpi=300) 51 | plt.close() 52 | 53 | 54 | def get_colors(n_snapshot, cmap=cmap): 55 | cm1 = cm.get_cmap(cmap) 56 | colors = cm1(np.linspace(0.2, 0.8, n_snapshot)) 57 | return colors 58 | 59 | 60 | @torch.no_grad() 61 | def plot_scatter(ax, x, s=2, c=None, zorder=0, marker=None, title=None, alpha=1.0): 62 | """ 63 | x: (B, 2) 64 | """ 65 | x = cpuize(x) 66 | ax.scatter(x[:, 0], x[:, 1], s=s, c=c, marker=marker, zorder=zorder, alpha=alpha) 67 | if title: 68 | ax.set_title(title) 69 | 70 | 71 | 
@torch.no_grad() 72 | def plot_traj(ax, xs, title=None, **kwargs): 73 | """ 74 | xs: (B, T, D) 75 | """ 76 | for x in xs: 77 | x = cpuize(x) 78 | ax.plot(x[:, 0], x[:, 1], **kwargs) 79 | if title: 80 | ax.set_title(title) 81 | 82 | 83 | @torch.no_grad() 84 | def plot_boundaries(p0, p1, pcfg): 85 | fig, axs = get_fig_axes(ncol=2, lim=pcfg.lim) 86 | 87 | plot_scatter(axs[0], p0(512)) 88 | plot_scatter(axs[1], p1(512)) 89 | 90 | axs[0].set_title(r"$\mu$ at $t$=0", fontsize=fontsize) 91 | axs[1].set_title(r"$\nu$ at $t$=1", fontsize=fontsize) 92 | 93 | plot_obstacles(axs[0], pcfg.name) 94 | plot_obstacles(axs[1], pcfg.name) 95 | 96 | 97 | def plot_obstacles(ax, name, zorder=0): 98 | if name == "gmm": 99 | centers, radius = obstacle_cfg_gmm() 100 | for c in centers: 101 | circle = Circle(xy=np.array(c), radius=radius, zorder=zorder) 102 | 103 | ax.add_artist(circle) 104 | circle.set_clip_box(ax.bbox) 105 | circle.set_facecolor("darkgray") 106 | circle.set_edgecolor(None) 107 | 108 | elif name == "vneck": 109 | c_sq, coef = obstacle_cfg_vneck() 110 | x = np.linspace(-6, 6, 100) 111 | y1 = np.sqrt(c_sq + coef * np.square(x)) 112 | y2 = np.ones_like(x) * y1[0] 113 | 114 | ax.fill_between(x, y1, y2, color="darkgray", edgecolor=None, zorder=zorder) 115 | ax.fill_between(x, -y1, -y2, color="darkgray", edgecolor=None, zorder=zorder) 116 | 117 | elif name == "stunnel": 118 | a, b, cc, centers = obstacle_cfg_stunnel() 119 | for c in centers: 120 | elp = Ellipse( 121 | xy=np.array(c), 122 | width=2 * np.sqrt(cc / a), 123 | height=2 * np.sqrt(cc / b), 124 | zorder=zorder, 125 | ) 126 | 127 | ax.add_artist(elp) 128 | elp.set_clip_box(ax.bbox) 129 | elp.set_facecolor("darkgray") 130 | elp.set_edgecolor(None) 131 | 132 | elif name == "drunken_spider": 133 | xys, widths, heights = obstacle_cfg_drunken_spider() 134 | 135 | for xy, width, height in zip(xys, widths, heights): 136 | rec = Rectangle(xy=xy, width=width, height=height, zorder=0) 137 | ax.add_artist(rec) 138 | rec.set_clip_box(ax.bbox) 139 | rec.set_facecolor("darkgray") 140 | rec.set_edgecolor(None) 141 | 142 | 143 | @torch.no_grad() 144 | def plot_directional_sim(ax, xt): 145 | """ 146 | xt: (B, 2) 147 | """ 148 | 149 | B, D = xt.shape 150 | assert D == 2 151 | 152 | n_est = 5000 153 | directional_sim = est_directional_similarity(xt, n_est) 154 | assert directional_sim.shape == (n_est,) 155 | 156 | directional_sim = directional_sim.detach().cpu().numpy() 157 | 158 | bins = 15 159 | _, _, patches = ax.hist( 160 | directional_sim, 161 | bins=bins, 162 | ) 163 | 164 | colors = plt.cm.coolwarm(np.linspace(1.0, 0.0, bins)) 165 | 166 | for c, p in zip(colors, patches): 167 | plt.setp(p, "facecolor", c) 168 | 169 | ymax = 1000 if xt.shape[1] == 2 else 2000 170 | ax.relim() 171 | ax.autoscale() 172 | ax.set_ylim(0, ymax) 173 | ax.set_xlim(0, 1) 174 | ax.set_xticks([]) 175 | ax.set_yticks([]) 176 | ax.set_xticks([], minor=True) 177 | ax.set_yticks([], minor=True) 178 | 179 | 180 | def show_image(images, ncol=10): 181 | 182 | images = torch.clamp((images + 1) / 2, 0.0, 1.0) 183 | 184 | n = len(images) 185 | nrow = n // ncol + (1 if n % ncol > 0 else 0) 186 | assert ncol * nrow >= n 187 | 188 | fig, axs = get_fig_axes(nrow=nrow, ncol=ncol, ax_length_in=1) 189 | for ax in axs.reshape(-1): 190 | ax.set_xticks([]) 191 | ax.set_yticks([]) 192 | ax.spines["top"].set_visible(False) 193 | ax.spines["right"].set_visible(False) 194 | ax.spines["bottom"].set_visible(False) 195 | ax.spines["left"].set_visible(False) 196 | 197 | images = cpuize(images) 198 | for ax, 
image in zip(axs.reshape(-1), images):
199 |         ax.imshow(image.permute(1, 2, 0).numpy())
200 |     plt.tight_layout()
201 | 
202 | 
203 | ################################################################################################
204 | ################################## Plot Gaussian path (Alg 3) ##################################
205 | ################################################################################################
206 | 
207 | 
208 | @torch.no_grad()
209 | def plot_gpath(result, pcfg):
210 |     if result["gpath"].D == 2:
211 |         # crowd nav, opinion 2D
212 |         plot_gpath_2d(result, pcfg)
213 |     elif pcfg.name == "lidar":
214 |         plot_gpath_lidar(result, pcfg)
215 |     else:
216 |         plot_gpath_nd(result)
217 | 
218 | 
219 | @torch.no_grad()
220 | def plot_lidar(ax, dataset, xs, S=5):
221 |     B, T, D = xs.shape
222 | 
223 |     # Plot the surface.
224 |     ax.scatter(
225 |         dataset[:, 0],
226 |         dataset[:, 1],
227 |         dataset[:, 2],
228 |         s=0.3,
229 |         c=dataset[:, 2],
230 |         cmap="viridis_r",
231 |         alpha=1.0,
232 |     )
233 |     ax.axes.set_xlim3d(left=-4.8, right=4.8)
234 |     ax.axes.set_ylim3d(bottom=-4.8, top=4.8)
235 |     ax.axes.set_zlim3d(bottom=0.0, top=2.0)
236 |     ax.set_zticks([0, 1.0, 2.0])
237 | 
238 |     # Plot marginal samples.
239 |     cmap = matplotlib.cm.get_cmap("Spectral")
240 |     steps_to_log = np.linspace(0, T - 1, S).astype(int)
241 |     xs = xs.cpu().detach().clone()
242 |     for idx, step in enumerate(steps_to_log):
243 |         ax.scatter(
244 |             xs[:512, step, 0],
245 |             xs[:512, step, 1],
246 |             xs[:512, step, 2],
247 |             s=10.0,
248 |             c=cmap(idx / (len(steps_to_log) - 1)),
249 |         )
250 | 
251 | 
252 | @torch.no_grad()
253 | def plot_gpath_lidar(result, pcfg):
254 |     ax_length_in = 5
255 |     fig = plt.figure(figsize=(3 * ax_length_in, 2 * ax_length_in))
256 | 
257 |     ## Plot gamma (only the first pair)
258 |     ax = fig.add_subplot(231)
259 |     ax.plot(result["init_gamma"][0], "-x")
260 |     ax.plot(result["final_gamma"][0], "-x")
261 |     ax.set_title(r"$\gamma(t)$ Init vs Optimized")
262 | 
263 |     ## Plot loss
264 |     ax = fig.add_subplot(234)
265 |     losses = result["losses"]
266 | 
267 |     ax.plot(losses)
268 |     ax.set_title(f"Loss, last={losses[-1]:.1f}")
269 | 
270 |     ## Plot init mean
271 |     xs = result["init_mean"]
272 |     ax = fig.add_subplot(232, projection="3d", computed_zorder=False)
273 |     ax.view_init(elev=50, azim=-115, roll=0)
274 |     plot_lidar(ax, pcfg.dataset, xs, S=5)
275 |     ax.set_title("init_mean")
276 | 
277 |     ax = fig.add_subplot(233, projection="3d", computed_zorder=False)
278 |     ax.view_init(elev=90, azim=0, roll=0)
279 |     plot_lidar(ax, pcfg.dataset, xs, S=5)
280 |     ax.set_title("init_mean")
281 | 
282 |     ## Plot final mean
283 |     xs = result["final_mean"]
284 |     ax = fig.add_subplot(235, projection="3d", computed_zorder=False)
285 |     ax.view_init(elev=50, azim=-115, roll=0)
286 |     plot_lidar(ax, pcfg.dataset, xs, S=5)
287 |     ax.set_title("final_mean")
288 | 
289 |     ax = fig.add_subplot(236, projection="3d", computed_zorder=False)
290 |     ax.view_init(elev=90, azim=0, roll=0)
291 |     plot_lidar(ax, pcfg.dataset, xs, S=5)
292 |     ax.set_title("final_mean")
293 | 
294 | 
295 | @torch.no_grad()
296 | def plot_gpath_nd(result):
297 |     fig, axs = get_fig_axes(ncol=2, ax_length_in=2.5)
298 | 
299 |     ## Plot gamma (only the first pair)
300 |     axs[0].plot(result["init_gamma"][0], "-x")
301 |     axs[0].plot(result["final_gamma"][0], "-x")
302 |     axs[0].set_title(r"$\gamma(t)$ Init vs Optimized")
303 | 
304 |     ## Plot loss
305 |     losses = result["losses"]
306 |     axs[1].plot(losses)
307 |     axs[1].set_title(f"Loss, last={losses[-1]:.1f}")
308 | 
309 |
310 | @torch.no_grad()
311 | def plot_gpath_2d(result, pcfg):
312 | 
313 |     fig, axs = get_fig_axes(ncol=5, ax_length_in=2.5, lim=pcfg.lim)
314 |     for ax in [axs[0], axs[1], axs[4]]:
315 |         plot_obstacles(ax, result["name"])
316 |     for ax in [axs[2], axs[3]]:
317 |         ax.relim()
318 |         ax.autoscale()
319 | 
320 |     B, T, D = result["init_mean"].shape
321 | 
322 |     ## Plot mean & std (only the first pair)
323 |     colors = get_colors(T)
324 |     plot_scatter(axs[0], result["init_mean"][0], c=colors, title="Init Mean")
325 |     plot_scatter(axs[1], result["final_mean"][0], c=colors, title="Optimized Mean")
326 |     axs[2].plot(result["init_gamma"][0], "-x")
327 |     axs[2].plot(result["final_gamma"][0], "-x")
328 |     axs[2].set_title(r"$\gamma(t)$ Init vs Optimized")
329 | 
330 |     ## Plot loss
331 |     losses = result["losses"]
332 |     axs[3].plot(losses)
333 |     axs[3].set_title(f"Loss, last={losses[-1]:.1f}")
334 | 
335 |     ## Plot marginal xt
336 |     mB, S, N = 512, 5, 64
337 |     with torch.no_grad():
338 |         xt = result["gpath"].sample_xt(torch.linspace(0, 1, S), N=N)  # (B, N, S, D)
339 |     xt = xt.permute(2, 0, 1, 3).reshape(S, B * N, D)
340 |     for i, x in enumerate(xt):
341 |         rand_idx = torch.randperm(B * N)[:mB]
342 |         plot_scatter(axs[4], x[rand_idx], c=f"C{i}")
343 |     axs[4].set_title("Optimized Xt")
344 | 
345 | 
346 | ################################################################################################
347 | ################################### Plot importance sampling ##################################
348 | ################################################################################################
349 | 
350 | 
351 | def plot_iw(results, pcfg):
352 |     t = results["IW_t"].detach().cpu()
353 |     xt = results["IW_xt"].detach().cpu()
354 |     xs = results["IW_xs"].detach().cpu()
355 |     (B, N, S, D), (T,) = xs.shape, t.shape
356 |     assert xt.shape == (B, T, D)
357 | 
358 |     _, sort_idx = torch.sort(t)
359 |     xt = xt[:, sort_idx]
360 |     colors = get_colors(T)
361 | 
362 |     fig, axs = get_fig_axes(ncol=4, nrow=2, ax_length_in=2, lim=pcfg.lim)
363 |     for ax in axs.reshape(-1):
364 |         plot_obstacles(ax, pcfg.name)
365 | 
366 |     for b in range(4):
367 |         axs[0, b].set_title(f"Sample #{b}")
368 | 
369 |         plot_traj(axs[0, b], xs[b])
370 |         plot_scatter(axs[1, b], xt[b], c=colors)
371 |         plot_scatter(axs[1, b], xt[[b], 0], s=20, c="C0")
372 |         plot_scatter(axs[1, b], xt[[b], -1], s=20, c="C1")
373 | 
374 |     axs[0, 0].set_ylabel("SDEs")
375 |     axs[1, 0].set_ylabel("IW Xt")
376 | 
377 | 
378 | ################################################################################################
379 | ##################################### Plot simulated trajs #####################################
380 | ################################################################################################
381 | 
382 | 
383 | @torch.no_grad()
384 | def save_xs(t, xs, log_steps, direction, pcfg, fn):
385 |     if pcfg.name == "lidar":
386 |         plot_xs_lidar(xs, log_steps, pcfg)
387 |     elif pcfg.name == "opinion":
388 |         plot_xs_opinion(t, xs, log_steps, direction, pcfg)
389 |     else:
390 |         plot_xs_crowd_nav(t, xs, log_steps, direction, pcfg)
391 |     save_fig(fn)
392 | 
393 | 
394 | @torch.no_grad()
395 | def plot_xs_opinion(t, xs, log_steps, direction, pcfg):
396 |     """
397 |     t: (T,)
398 |     xs: (B, T, D)
399 |     """
400 |     B, T, D = xs.shape
401 |     assert t.shape == (T,) and log_steps <= T
402 | 
403 |     ### PCA projection
404 |     xs = xs.detach().clone()
405 |     if D > 2:
406 |         xs, _ = proj_pca(xs)
407 |     assert xs.shape == (B, T, 2)
408 | 
409 |     ### Plot
410 |     fig, axs = get_fig_axes(nrow=2,
ncol=log_steps, lim=pcfg.lim) 411 | steps_to_log = np.linspace(0, T - 1, log_steps).astype(int) 412 | for i, step in enumerate(steps_to_log): 413 | plot_scatter(axs[0, i], xs[:, step], s=2, zorder=0) 414 | axs[0, i].set_title(r"$t$=" + f"{t[step]:.2f}") 415 | plot_directional_sim(axs[1, i], xs[:, step]) 416 | 417 | axs[0, 0].set_ylabel(direction) 418 | if D > 2: 419 | axs[1, 0].set_ylabel("PCA 1 vs 2") 420 | 421 | 422 | @torch.no_grad() 423 | def plot_xs_crowd_nav(t, xs, log_steps, direction, pcfg): 424 | B, T, D = xs.shape 425 | assert t.shape == (T,) and log_steps <= T 426 | 427 | fig, axs = get_fig_axes(ncol=log_steps, lim=pcfg.lim) 428 | steps_to_log = np.linspace(0, T - 1, log_steps).astype(int) 429 | 430 | for ax, step in zip(axs, steps_to_log): 431 | plot_obstacles(ax, pcfg.name) 432 | plot_scatter(ax, xs[:512, step]) 433 | ax.set_title(r"$t$=" + f"{t[step]:.2f}") 434 | axs[0].set_ylabel(direction) 435 | 436 | 437 | @torch.no_grad() 438 | def plot_xs_lidar(xs, log_steps, pcfg): 439 | fig = plt.figure() 440 | 441 | ax = fig.add_subplot(121, projection="3d", computed_zorder=False) 442 | ax.view_init(elev=50, azim=-115, roll=0) 443 | plot_lidar(ax, pcfg.dataset, xs, S=log_steps) 444 | 445 | ax = fig.add_subplot(122, projection="3d", computed_zorder=False) 446 | ax.view_init(elev=90, azim=0, roll=0) 447 | plot_lidar(ax, pcfg.dataset, xs, S=log_steps) 448 | -------------------------------------------------------------------------------- /gsbm/sde.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | from tqdm import trange 7 | from ipdb import set_trace as debug 8 | 9 | DIRECTIONS = ["fwd", "bwd"] 10 | 11 | 12 | def build_basedrift(cfg): 13 | if cfg.prob.name == "opinion": 14 | from .opinion import PolarizeDyn 15 | 16 | basedrift = PolarizeDyn(cfg.pdrift) 17 | else: 18 | basedrift = ZeroBaseDrift() 19 | return basedrift 20 | 21 | 22 | class ZeroBaseDrift(torch.nn.Module): 23 | def __init__(self): 24 | super(ZeroBaseDrift, self).__init__() 25 | 26 | def forward(self, xt, t): 27 | return torch.zeros_like(xt) 28 | 29 | 30 | def _assert_increasing(ts: torch.Tensor) -> None: 31 | assert (ts[1:] > ts[:-1]).all(), "time must be strictly increasing" 32 | 33 | 34 | def sdeint( 35 | xinit, 36 | drift, 37 | diffusion, 38 | direction, 39 | nfe, 40 | log_steps=5, 41 | eps=0, 42 | verbose=False, 43 | return_ws=False, 44 | ): 45 | """ 46 | xinit: (B, D) 47 | drift: (B, D) + (B,) --> (B, D) 48 | diffusion: (B, D) + (B,) --> (B, D) 49 | nfe = T - 1 50 | === 51 | t: (S,) 52 | xs: (B, S, D) 53 | us: (B, S, D) 54 | """ 55 | assert direction in DIRECTIONS 56 | 57 | T, (B, D), S, device = nfe + 1, xinit.shape, log_steps, xinit.device 58 | 59 | # build ts 60 | timesteps = ( 61 | torch.linspace(eps, 1 - eps, T, device=device) 62 | if direction == "fwd" 63 | else torch.linspace(1 - eps, eps, T, device=device) 64 | ) 65 | 66 | # logging 67 | steps_to_log = np.linspace(0, T - 1, log_steps).astype(int) 68 | xs = [] 69 | us = [] 70 | if return_ws: 71 | ws = [] 72 | 73 | x = xinit.detach() 74 | bar = trange(nfe) if verbose else range(nfe) 75 | for idx in bar: 76 | t, tnxt = timesteps[idx], timesteps[idx + 1] 77 | dt = (tnxt - t).abs() 78 | dw = math.sqrt(dt) * torch.randn(*x.shape, device=x.device) 79 | 80 | tt = t.repeat(x.shape[0]) 81 | u = drift(x, tt) 82 | g = diffusion(x, tt) 83 | 84 | if idx in steps_to_log: 85 | xs.append(x) 86 | us.append(u) 87 | if 
return_ws: 88 | ws.append(dw) 89 | 90 | x = x + u * dt + g * dw 91 | 92 | assert len(xs) == len(us) == S - 1 93 | 94 | # log last step, u = 0 95 | xs.append(x) 96 | us.append(torch.zeros_like(x)) 97 | if return_ws: 98 | ws.append(torch.zeros_like(x)) 99 | ts = timesteps[steps_to_log] 100 | 101 | if direction == "bwd": 102 | xs = xs[::-1] 103 | us = us[::-1] 104 | ts = ts.flip( 105 | dims=[ 106 | 0, 107 | ] 108 | ) 109 | if return_ws: 110 | ws = ws[::-1] 111 | 112 | xs = torch.stack(xs, dim=1).detach() 113 | us = torch.stack(us, dim=1).detach() 114 | assert xs.shape == us.shape == (B, S, D) and ts.shape == (S,) 115 | _assert_increasing(ts) 116 | out = {"t": ts, "xs": xs, "us": us} 117 | if return_ws: 118 | ws = torch.stack(ws, dim=1).detach() 119 | assert ws.shape == us.shape 120 | out["ws"] = ws 121 | return out 122 | -------------------------------------------------------------------------------- /gsbm/state_cost.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | from functools import partial 4 | from scipy.spatial import cKDTree 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.distributions as td 9 | 10 | from .vae import load_model, lerp, slerp 11 | from .utils import get_repo_path 12 | 13 | from ipdb import set_trace as debug 14 | 15 | 16 | def build_state_cost(cfg): 17 | if cfg.prob.name == "afhq": 18 | return VAEStateCost(cfg.vae) 19 | elif cfg.prob.name == "lidar": 20 | return LIDARStateCost(cfg.lidar) 21 | elif cfg.prob.name == "opinion": 22 | return OpinionStateCost(cfg.state_cost) 23 | else: 24 | return CrowdNavStateCost(cfg.prob.name, cfg.state_cost) 25 | 26 | 27 | ########################################################## 28 | ################### state cost functions ################# 29 | ########################################################## 30 | 31 | 32 | class VAEStateCost(torch.nn.Module): 33 | def __init__(self, vcfg): 34 | super().__init__() 35 | self.vcfg = vcfg 36 | self.vae = load_model(get_repo_path() / vcfg.ckpt).eval() 37 | 38 | @torch.no_grad() 39 | def recon(self, images): 40 | zs = self.vae.encoder(images)[0] 41 | return self.vae.decoder(zs) 42 | 43 | @torch.no_grad() 44 | def latent_interp(self, x0, x1, S, type="slerp"): 45 | """ 46 | x0, x1: (B, D) --> xt: (B, S, D), zt: (B, S, Z) 47 | """ 48 | img_size = self.vcfg.image_size 49 | B, D = x0.shape 50 | self.vae.to(x0) 51 | 52 | x0_img = x0.reshape(B, *img_size).to(x0) 53 | x1_img = x1.reshape(B, *img_size).to(x0) 54 | 55 | z0 = self.vae.encoder(x0_img)[0].reshape(1, B, -1) 56 | z1 = self.vae.encoder(x1_img)[0].reshape(1, B, -1) 57 | Zdim = z0.shape[-1] 58 | 59 | t = torch.linspace(0, 1, S).to(x0).reshape(-1, 1, 1) 60 | if type == "lerp": 61 | zt = lerp(z0, z1, t).reshape(-1, Zdim) 62 | elif type == "slerp": 63 | zt = slerp(z0, z1, t).reshape(-1, Zdim) 64 | xt = self.vae.decoder(zt) 65 | 66 | recon_xt = xt.reshape(S, B, D).permute(1, 0, 2) 67 | recon_zt = zt.reshape(S, B, Zdim).permute(1, 0, 2) 68 | return recon_xt, recon_zt 69 | 70 | def forward(self, xt, t, recon_xt): 71 | """ 72 | xt: (B, N, T, D) 73 | t: (T,) 74 | recon_xt: (B, T, D) 75 | === 76 | (B, N, T) 77 | """ 78 | B, N, T, D = xt.shape 79 | assert t.shape == (T,) 80 | assert recon_xt.shape == (B, T, D) 81 | 82 | recon_xt = recon_xt.reshape(B, 1, T, D).expand(-1, N, -1, -1) 83 | 84 | # loss = ((xt - recon_xt)**2).mean(dim=-1) # L2 85 | loss = ((xt - recon_xt).abs()).mean(dim=-1) # L1 slightly better 86 | 
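        # The resulting cost penalizes deviation from the VAE-reconstructed
        # reference trajectory recon_xt, keeping samples close to the image manifold.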
return loss.reshape(B, N, T) 87 | 88 | 89 | class LIDARStateCost(torch.nn.Module): 90 | 91 | def __init__(self, lcfg): 92 | super().__init__() 93 | import laspy 94 | 95 | las = laspy.read(get_repo_path() / lcfg.filename) 96 | self.k = lcfg.k 97 | self.closeness_weight = lcfg.closeness_weight 98 | self.height_weight = lcfg.height_weight 99 | self.boundary_weight = lcfg.boundary_weight 100 | self.lim = lcfg.lim 101 | 102 | # Extract only "ground" points. 103 | mask = las.classification == 2 104 | 105 | # Extract points. 106 | x_offset, x_scale = las.header.offsets[0], las.header.scales[0] 107 | y_offset, y_scale = las.header.offsets[1], las.header.scales[1] 108 | z_offset, z_scale = las.header.offsets[2], las.header.scales[2] 109 | dataset = np.vstack( 110 | ( 111 | las.X[mask] * x_scale + x_offset, 112 | las.Y[mask] * y_scale + y_offset, 113 | las.Z[mask] * z_scale + z_offset, 114 | ) 115 | ).transpose() 116 | 117 | # Scale to [-5, 5]. 118 | mi = dataset.min(axis=0, keepdims=True) 119 | ma = dataset.max(axis=0, keepdims=True) 120 | dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0] 121 | 122 | self.dataset = dataset 123 | 124 | # Build K-D tree for approximate nearest neighbor searches. 125 | self.tree = cKDTree(dataset) 126 | 127 | def get_tangent_plane(self, points, temp=1e-3): 128 | """ 129 | Estimates a tangent plane by taking the k nearest points. 130 | 131 | Then returns the projection operator to this tangent plane. 132 | 133 | Args: 134 | points: PyTorch tensor of shape (..., 3). 135 | 136 | Returns: 137 | Function that is the projection operator. Takes same size as the input (..., 3) --> (..., 3). 138 | """ 139 | 140 | # Query the nearest k points. 141 | # Note: this goes through CPU. 142 | points_np = points.detach().cpu().numpy() 143 | _, idx = self.tree.query(points_np, k=self.k) 144 | nearest_pts = self.dataset[idx] 145 | nearest_pts = torch.tensor(nearest_pts).to(points) 146 | 147 | dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) 148 | weights = torch.exp(-dists / temp) 149 | 150 | # Fits plane with least vertical distance. 151 | w = LIDARStateCost.fit_plane(nearest_pts, weights) 152 | return w 153 | 154 | def get_tangent_proj(self, points): 155 | w = self.get_tangent_plane(points) 156 | return partial(LIDARStateCost.projection_op, w=w) 157 | 158 | def boundary_penalty_1d(self, x, lim=5.0): 159 | cost = torch.sigmoid((x - lim) / 0.1) 160 | cost = cost + 1 - torch.sigmoid((x + lim) / 0.1) 161 | return cost 162 | 163 | def forward(self, xt, *args, **kwargs): 164 | shape = xt.shape[:-1] 165 | assert xt.shape[-1] == 3 166 | N = np.prod(shape) 167 | xt = xt.reshape(N, 3) 168 | 169 | projx = self.get_tangent_proj(xt) 170 | xt_projected = projx(xt) 171 | 172 | # Distance to the manifold. 173 | closeness = (xt_projected - xt).pow(2).sum(-1).reshape(*shape) 174 | 175 | # Don't leave the [-5, 5] boundary. 176 | boundary = self.boundary_penalty_1d(xt[:, 0]) + self.boundary_penalty_1d( 177 | xt[:, 1] 178 | ) 179 | boundary = boundary.reshape(*shape) 180 | 181 | # State cost is the height of the current point. 182 | # Note: we project x. This ensures the gradient is projected onto the tangent plane. 183 | height = torch.exp(xt_projected[:, 2]).reshape(*shape) 184 | 185 | return ( 186 | self.closeness_weight * closeness 187 | + self.height_weight * height 188 | + self.boundary_weight * boundary 189 | ) 190 | 191 | @staticmethod 192 | def fit_plane(points, weights=None): 193 | """Expects points to be of shape (..., 3). 
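        With design matrix D = [x, y, 1] and optional per-point weights W, this
        solves the normal equations (D^T W D) w = D^T W z, i.e. a (weighted)
        least-squares fit of the vertical offsets.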
194 |         Returns [a, b, c] such that the plane is defined as
195 |             ax + by + c = z
196 |         """
197 |         D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
198 |         z = points[..., 2]
199 |         if weights is None:
200 |             Dtrans = D.transpose(-1, -2)
201 |         else:
202 |             DW = D * weights  # apply per-point weights: weighted least squares
203 |             Dtrans = DW.transpose(-1, -2)
204 |         w = torch.linalg.solve(
205 |             torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
206 |         ).squeeze(-1)
207 |         return w
208 | 
209 |     @staticmethod
210 |     def projection_op(x, w):
211 |         """Projects points to a plane defined by w."""
212 |         # Normal vector to the tangent plane.
213 |         n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)
214 | 
215 |         pn = torch.sum(x * n, dim=-1, keepdim=True)
216 |         nn = torch.sum(n * n, dim=-1, keepdim=True)
217 | 
218 |         # Offset.
219 |         d = w[..., 2:3]
220 | 
221 |         # Projection of x onto n.
222 |         projn_x = ((pn + d) / nn) * n
223 | 
224 |         # Remove component in the normal direction.
225 |         return x - projn_x
226 | 
227 | 
228 | class CrowdNavStateCost(torch.nn.Module):
229 |     def __init__(self, name, scfg):
230 |         super().__init__()
231 |         self.scfg = scfg
232 |         self.obstacle_cost = build_obstacle_cost(name)
233 | 
234 |     def forward(self, xt, t, gpath):
235 |         """
236 |         xt: (*, T, D)
237 |         t: (T,)
238 |         ===
239 |         cost: (*, T)
240 |         """
241 |         (T, D), scfg = xt.shape[-2:], self.scfg
242 |         assert t.shape == (T,) and D == 2
243 |         assert "obs" in scfg.type and scfg.obs > 0
244 | 
245 |         V = scfg.obs * self.obstacle_cost(xt)
246 | 
247 |         if "ent" in scfg.type and scfg.ent > 0:
248 |             V = V + scfg.ent * entropy_cost(xt, t, gpath)
249 |         elif "cgst" in scfg.type and scfg.cgst > 0:
250 |             V = V + scfg.cgst * congestion_cost(xt)
251 | 
252 |         assert V.shape == xt.shape[:-1]
253 |         return V
254 | 
255 | 
256 | class OpinionStateCost(torch.nn.Module):
257 |     def __init__(self, scfg):
258 |         super().__init__()
259 |         self.scfg = scfg
260 | 
261 |     def forward(self, xt, t, gpath):
262 |         """
263 |         xt: (*, T, D)
264 |         t: (T,)
265 |         ===
266 |         cost: (*, T)
267 |         """
268 |         (T, D), scfg = xt.shape[-2:], self.scfg
269 |         assert t.shape == (T,)
270 | 
271 |         V = zero_cost_fn(xt)
272 |         if "ent" in scfg.type and scfg.ent > 0:
273 |             V = V + scfg.ent * entropy_cost(xt, t, gpath)
274 |         elif "cgst" in scfg.type and scfg.cgst > 0:
275 |             V = V + scfg.cgst * congestion_cost(xt)
276 | 
277 |         assert V.shape == xt.shape[:-1]
278 |         return V
279 | 
280 | 
281 | def zero_cost_fn(x: torch.Tensor, *args) -> torch.Tensor:
282 |     return torch.zeros(*x.shape[:-1], device=x.device)
283 | 
284 | 
285 | ##########################################################
286 | ################## obstacle cost functions ###############
287 | ##########################################################
288 | 
289 | 
290 | def build_obstacle_cost(name):
291 |     return {
292 |         "gmm": obstacle_cost_gmm,
293 |         "stunnel": obstacle_cost_stunnel,
294 |         "vneck": obstacle_cost_vneck,
295 |         "drunken_spider": obstacle_cost_drunken_spider,
296 |     }.get(name)
297 | 
298 | 
299 | def obstacle_cfg_drunken_spider():
300 |     xys = [[-7, 0.5], [-7, -7.5]]
301 |     widths = [14, 14]
302 |     heights = [7, 7]
303 |     return xys, widths, heights
304 | 
305 | 
306 | def obstacle_cost_drunken_spider(xt):
307 |     """
308 |     xt: (*, 2) -> (*,)
309 |     """
310 |     assert xt.shape[-1] == 2
311 | 
312 |     x, y = xt[..., 0], xt[..., 1]
313 | 
314 |     def cost_fn(xy, width, height):
315 | 
316 |         xbound = xy[0], xy[0] + width
317 |         ybound = xy[1], xy[1] + height
318 | 
319 |         a = -5 * (x - xbound[0]) * (x - xbound[1])
320 |         b = -5 * (y - ybound[0]) * (y - ybound[1])
321 | 322 | cost = F.softplus(a, beta=20, threshold=1) * F.softplus(b, beta=20, threshold=1) 323 | assert cost.shape == xt.shape[:-1] 324 | return cost 325 | 326 | return sum( 327 | cost_fn(xy, width, height) 328 | for xy, width, height in zip(*obstacle_cfg_drunken_spider()) 329 | ) 330 | 331 | 332 | def obstacle_cfg_gmm(): 333 | centers = [[6, 6], [6, -6], [-6, -6]] 334 | radius = 1.5 335 | return centers, radius 336 | 337 | 338 | def obstacle_cfg_stunnel(): 339 | a, b, c = 20, 1, 90 340 | centers = [[5, 6], [-5, -6]] 341 | return a, b, c, centers 342 | 343 | 344 | def obstacle_cfg_vneck(): 345 | c_sq = 0.36 346 | coef = 5 347 | return c_sq, coef 348 | 349 | 350 | def obstacle_cost_gmm(xt): 351 | 352 | Bs, D = xt.shape[:-1], xt.shape[-1] 353 | assert D == 2 354 | xt = xt.reshape(-1, xt.shape[-1]) 355 | 356 | batch_xt = xt.shape[0] 357 | 358 | centers, radius = obstacle_cfg_gmm() 359 | 360 | obs1 = torch.tensor(centers[0]).repeat((batch_xt, 1)).to(xt.device) 361 | obs2 = torch.tensor(centers[1]).repeat((batch_xt, 1)).to(xt.device) 362 | obs3 = torch.tensor(centers[2]).repeat((batch_xt, 1)).to(xt.device) 363 | 364 | dist1 = torch.norm(xt - obs1, dim=-1) 365 | dist2 = torch.norm(xt - obs2, dim=-1) 366 | dist3 = torch.norm(xt - obs3, dim=-1) 367 | 368 | cost1 = F.softplus(100 * (radius - dist1), beta=1, threshold=20) 369 | cost2 = F.softplus(100 * (radius - dist2), beta=1, threshold=20) 370 | cost3 = F.softplus(100 * (radius - dist3), beta=1, threshold=20) 371 | return (cost1 + cost2 + cost3).reshape(*Bs) 372 | 373 | 374 | def obstacle_cost_stunnel(xt): 375 | """ 376 | xt: (*, 2) -> (*,) 377 | """ 378 | 379 | a, b, c, centers = obstacle_cfg_stunnel() 380 | 381 | Bs, D = xt.shape[:-1], xt.shape[-1] 382 | assert D == 2 383 | 384 | _xt = xt.reshape(-1, D) 385 | x, y = _xt[:, 0], _xt[:, 1] 386 | 387 | d = a * (x - centers[0][0]) ** 2 + b * (y - centers[0][1]) ** 2 388 | # c1 = 1500 * (d < c) 389 | c1 = F.softplus(c - d, beta=1, threshold=20) 390 | 391 | d = a * (x - centers[1][0]) ** 2 + b * (y - centers[1][1]) ** 2 392 | # c2 = 1500 * (d < c) 393 | c2 = F.softplus(c - d, beta=1, threshold=20) 394 | 395 | cost = (c1 + c2).reshape(*Bs) 396 | return cost 397 | 398 | 399 | def obstacle_cost_vneck(xt): 400 | """ 401 | xt: (*, 2) -> (*,) 402 | """ 403 | assert xt.shape[-1] == 2 404 | 405 | c_sq, coef = obstacle_cfg_vneck() 406 | 407 | xt_sq = torch.square(xt) 408 | d = coef * xt_sq[..., 0] - xt_sq[..., 1] 409 | 410 | return F.softplus(-c_sq - d, beta=1, threshold=20) 411 | # return 15000 * (d < -c_sq) 412 | 413 | 414 | ########################################################## 415 | ################ interaction cost functions ############## 416 | ########################################################## 417 | 418 | 419 | def entropy_cost(xt, t, gpath): 420 | """ 421 | xt: (B, N, T, D), t: (T,) --> (B, N, T) 422 | """ 423 | B, N, T, D = xt.shape 424 | assert t.shape == (T,) 425 | 426 | ### build B*T indep Gaussians 427 | mean_t = gpath.mean(t).detach() 428 | gamma_t = gpath.gamma(t).detach() 429 | assert mean_t.shape == (B, T, D) 430 | assert gamma_t.shape == (B, T, 1) 431 | 432 | normals = td.Normal( 433 | mean_t.reshape(B * T, D), 434 | gamma_t.reshape(B * T, 1) * torch.ones(B * T, D, device=gpath.device), 435 | ) 436 | indep_normals = td.Independent(normals, 1) 437 | 438 | ### evaluate log-prob of all `B*N` samples at each timestamp 439 | ### w.r.t. 
`B` Gaussians 440 | xxt = xt.unsqueeze(2).expand(-1, -1, B, -1, -1) 441 | assert xxt.shape == (B, N, B, T, D) 442 | 443 | log_pt_01 = indep_normals.log_prob(xxt.reshape(B * N, B * T, D)).reshape( 444 | B * N, B, T 445 | ) 446 | pt = log_pt_01.exp().mean(dim=1) # (B*N, T) 447 | # pt = pt / pt.sum(dim=0, keepdim=True)#.detach() 448 | log_pt = pt.log().reshape(B, N, T) 449 | 450 | assert not torch.isnan(log_pt).any() 451 | assert log_pt.shape == (B, N, T) 452 | return log_pt 453 | 454 | 455 | def congestion_cost(xt): 456 | """ 457 | xt: (*, T, D) --> (*, T) 458 | """ 459 | 460 | T, D = xt.shape[-2:] 461 | 462 | yt = xt.reshape(-1, T, D) 463 | yt = yt[torch.randperm(yt.shape[0])].reshape_as(xt) # detach? 464 | 465 | dd = xt - yt 466 | dist = torch.sum(dd * dd, dim=-1) 467 | congestion = 2.0 / (dist + 1.0) 468 | assert congestion.shape == xt.shape[:-1] 469 | return congestion 470 | -------------------------------------------------------------------------------- /gsbm/utils.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | from typing import Union, Dict, Any 4 | import os 5 | from glob import glob 6 | import numpy as np 7 | import importlib 8 | from omegaconf import OmegaConf 9 | from pathlib import Path 10 | 11 | import torch 12 | from torch import distributed as dist 13 | 14 | 15 | def get_repo_path(): 16 | curr_dir = Path(os.path.dirname(os.path.realpath(__file__))) 17 | return curr_dir.parent 18 | 19 | 20 | def get_job_directory(file_or_checkpoint: Union[str, Dict[str, Any]]) -> str: 21 | found = False 22 | if isinstance(file_or_checkpoint, dict): 23 | chkpnt = file_or_checkpoint 24 | key = [x for x in chkpnt["callbacks"].keys() if "Checkpoint" in x][0] 25 | file = chkpnt["callbacks"][key]["dirpath"] 26 | else: 27 | file = file_or_checkpoint 28 | 29 | hydra_files = [] 30 | directory = os.path.dirname(file) 31 | while not found: 32 | hydra_files = glob( 33 | os.path.join(os.path.join(directory, ".hydra/config.yaml")), 34 | recursive=True, 35 | ) 36 | if len(hydra_files) > 0: 37 | break 38 | directory = os.path.dirname(directory) 39 | if directory == "": 40 | raise ValueError("Failed to find hydra config!") 41 | assert len(hydra_files) == 1, "Found ambiguous hydra config files!" 42 | job_dir = os.path.dirname(os.path.dirname(hydra_files[0])) 43 | return job_dir 44 | 45 | 46 | def restore_model(checkpoint, pl_name="gsbm.pl_model", device=None): 47 | ckpt = torch.load(checkpoint, map_location="cpu") 48 | job_dir = get_job_directory(checkpoint) 49 | cfg = OmegaConf.load(os.path.join(job_dir, ".hydra/config.yaml")) 50 | # print(f"Loaded cfg from {job_dir=}!") 51 | 52 | from .dataset import get_dist_boundary 53 | 54 | p0, p1, p0_val, p1_val = get_dist_boundary(cfg) 55 | pl_module = importlib.import_module(f"{pl_name}") 56 | model = pl_module.GSBMLitModule(cfg, p0, p1, p0_val, p1_val) 57 | model.load_state_dict(ckpt["state_dict"]) 58 | 59 | if device is not None: 60 | model = model.to(device) 61 | 62 | return model, cfg 63 | 64 | 65 | def chunk_multi_output(input, chunk_op, split_size): 66 | """ 67 | input: (B, *) 68 | chunk_op: (b, *) --> [(b, *), (b, *), ...] 69 | === 70 | output: [(B, *), (B, *), ...] 
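    e.g. (illustrative): chunk_multi_output(x, lambda xb, i: (2 * xb, xb + 1), 128)
    processes x in chunks of 128 rows and returns [2 * x, x + 1].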
71 | """ 72 | B = input.shape[0] 73 | input_chunks = torch.split(input, split_size) 74 | 75 | output = None 76 | for chunk_idx, input_chunk in enumerate(input_chunks): 77 | output_chunk = chunk_op(input_chunk, chunk_idx) 78 | assert isinstance(output_chunk, tuple) 79 | 80 | # initialize output format 81 | if output is None: 82 | output = [[] for _ in range(len(output_chunk))] 83 | 84 | for n, o in enumerate(output_chunk): 85 | output[n].append(o) 86 | 87 | for n in range(len(output)): 88 | output[n] = torch.cat(output[n], dim=0) 89 | assert output[n].shape[0] == B 90 | return output 91 | 92 | 93 | def is_distributed(): 94 | return dist.is_available() and dist.is_initialized() 95 | 96 | 97 | def all_gather(tensor: torch.Tensor): 98 | gathered_tensor = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())] 99 | with torch.no_grad(): 100 | dist.all_gather(gathered_tensor, tensor) 101 | gathered_tensor = torch.cat(gathered_tensor, dim=0) 102 | return gathered_tensor 103 | 104 | 105 | def gather(outputs, key): 106 | return torch.cat([o[key].detach().cpu() for o in outputs], dim=0) 107 | 108 | 109 | def n_device(): 110 | return dist.get_world_size() if is_distributed() else 1 111 | -------------------------------------------------------------------------------- /gsbm/vae.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | import math 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def load_model(checkpoint, device=None): 11 | chkpnt = torch.load(checkpoint, map_location="cpu") 12 | 13 | # Get just the model state dict. 14 | sd = chkpnt["state_dict"] 15 | new_sd = {} 16 | for key in sd.keys(): 17 | if key.startswith("model."): 18 | new_sd[key[6:]] = sd[key] 19 | 20 | try: 21 | cfg = chkpnt["cfg"] 22 | model = VAE(z_dim=cfg.z_dim, beta=cfg.beta, x_std=cfg.x_std) 23 | except: 24 | model = VAE(z_dim=256) 25 | model.load_state_dict(new_sd) 26 | if device is not None: 27 | model.to(device) 28 | 29 | # Important since there is batch norm. 30 | model.eval() 31 | 32 | return model 33 | 34 | 35 | def lerp(x0, x1, t): 36 | """Assumes all inputs can be broadcasted to the same shape.""" 37 | return x0 + t * (x1 - x0) 38 | 39 | 40 | def slerp(x0, x1, t): 41 | """ 42 | Assumes all inputs can be broadcasted to the same shape (..., D). 43 | Performs the slerp on the last dimension. 
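    Illustrative usage (hypothetical latents): for z0, z1 of shape (B, D) and a scalar t in [0, 1],
    slerp(z0, z1, t) follows the great circle between the directions of z0 and z1, returning z0 at
    t = 0 and z1 at t = 1. When z0 and z1 are nearly parallel, sin(omega) -> 0 and lerp above is
    the numerically safer fallback.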
44 | """ 45 | low_norm = x0 / torch.norm(x0, dim=-1, keepdim=True) 46 | high_norm = x1 / torch.norm(x1, dim=-1, keepdim=True) 47 | omega = torch.acos((low_norm * high_norm).sum(-1)) 48 | so = torch.sin(omega) 49 | res = (torch.sin((1.0 - t) * omega) / so).unsqueeze(-1) * x0 + ( 50 | torch.sin(t * omega) / so 51 | ).unsqueeze(-1) * x1 52 | return res 53 | 54 | 55 | class Swish(nn.Module): 56 | def __init__(self): 57 | super(Swish, self).__init__() 58 | self.beta = nn.Parameter(torch.tensor([0.5])) 59 | 60 | def forward(self, x): 61 | return x * torch.sigmoid_(x * F.softplus(self.beta)) 62 | 63 | 64 | class ResizeConv2d(nn.Module): 65 | def __init__( 66 | self, in_channels, out_channels, kernel_size, scale_factor, mode="nearest" 67 | ): 68 | super().__init__() 69 | self.scale_factor = scale_factor 70 | self.mode = mode 71 | self.conv = nn.Conv2d( 72 | in_channels, out_channels, kernel_size, stride=1, padding=1 73 | ) 74 | 75 | def forward(self, x): 76 | x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) 77 | x = self.conv(x) 78 | return x 79 | 80 | 81 | class BasicBlockEnc(nn.Module): 82 | def __init__(self, in_planes, stride=1, same_width=False): 83 | super().__init__() 84 | 85 | if same_width: 86 | planes = in_planes 87 | else: 88 | planes = in_planes * stride 89 | 90 | self.conv1 = nn.Conv2d( 91 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 92 | ) 93 | self.bn1 = nn.BatchNorm2d(planes) 94 | self.actfn1 = Swish() 95 | self.conv2 = nn.Conv2d( 96 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 97 | ) 98 | self.bn2 = nn.BatchNorm2d(planes) 99 | self.actfn2 = Swish() 100 | 101 | if stride == 1: 102 | self.shortcut = nn.Sequential() 103 | else: 104 | self.shortcut = nn.Sequential( 105 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 106 | nn.BatchNorm2d(planes), 107 | ) 108 | 109 | def forward(self, x): 110 | out = self.actfn1(self.bn1(self.conv1(x))) 111 | out = self.bn2(self.conv2(out)) 112 | out += self.shortcut(x) 113 | out = self.actfn2(out) 114 | return out 115 | 116 | 117 | class BasicBlockDec(nn.Module): 118 | def __init__(self, in_planes, stride=1, same_width=False): 119 | super().__init__() 120 | 121 | if same_width: 122 | planes = in_planes 123 | else: 124 | planes = int(in_planes / stride) 125 | 126 | self.conv2 = nn.Conv2d( 127 | in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False 128 | ) 129 | self.bn2 = nn.BatchNorm2d(in_planes) 130 | self.actfn2 = Swish() 131 | 132 | if stride == 1: 133 | self.conv1 = nn.Conv2d( 134 | in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False 135 | ) 136 | self.bn1 = nn.BatchNorm2d(planes) 137 | self.shortcut = nn.Sequential() 138 | else: 139 | self.conv1 = ResizeConv2d( 140 | in_planes, planes, kernel_size=3, scale_factor=stride 141 | ) 142 | self.bn1 = nn.BatchNorm2d(planes) 143 | self.shortcut = nn.Sequential( 144 | ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride), 145 | nn.BatchNorm2d(planes), 146 | ) 147 | self.actfn1 = Swish() 148 | 149 | def forward(self, x): 150 | out = self.actfn2(self.bn2(self.conv2(x))) 151 | out = self.bn1(self.conv1(out)) 152 | out += self.shortcut(x) 153 | out = self.actfn1(out) 154 | return out 155 | 156 | 157 | class ResNet18Enc(nn.Module): 158 | def __init__(self, num_Blocks=[2, 2, 2, 2], z_dim=10, nc=3, additional_layers=0): 159 | super().__init__() 160 | self.in_planes = 64 161 | self.z_dim = z_dim 162 | self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False) 
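        # Downsampling note: the stride-2 stem above plus the three stride-2 stages below shrink the
        # spatial size 16x overall (halved again per extra block when additional_layers > 0) before
        # adaptive_avg_pool2d collapses it to 1x1; the linear head then emits 2 * z_dim values that
        # forward() splits into mu and logstd.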
163 | self.bn1 = nn.BatchNorm2d(64) 164 | self.actfn1 = Swish() 165 | self.layer1 = self._make_layer(BasicBlockEnc, 64, num_Blocks[0], stride=1) 166 | self.layer2 = self._make_layer(BasicBlockEnc, 128, num_Blocks[1], stride=2) 167 | self.layer3 = self._make_layer(BasicBlockEnc, 256, num_Blocks[2], stride=2) 168 | self.layer4 = self._make_layer(BasicBlockEnc, 512, num_Blocks[3], stride=2) 169 | 170 | if additional_layers > 0: 171 | layers = [ 172 | self._make_layer( 173 | BasicBlockEnc, 512, num_Blocks[3], stride=2, same_width=True 174 | ) 175 | for _ in range(additional_layers) 176 | ] 177 | self.additional_layers = nn.Sequential(*layers) 178 | else: 179 | self.additional_layers = None 180 | 181 | self.linear = nn.Linear(512, 2 * z_dim) 182 | 183 | def _make_layer(self, BasicBlockEnc, planes, num_Blocks, stride, same_width=False): 184 | strides = [stride] + [1] * (num_Blocks - 1) 185 | layers = [] 186 | for stride in strides: 187 | layers += [BasicBlockEnc(self.in_planes, stride, same_width=same_width)] 188 | self.in_planes = planes 189 | return nn.Sequential(*layers) 190 | 191 | def forward(self, x): 192 | x = self.actfn1(self.bn1(self.conv1(x))) 193 | x = self.layer1(x) 194 | x = self.layer2(x) 195 | x = self.layer3(x) 196 | x = self.layer4(x) 197 | if self.additional_layers is not None: 198 | x = self.additional_layers(x) 199 | x = F.adaptive_avg_pool2d(x, 1) 200 | x = x.view(x.size(0), -1) 201 | x = self.linear(x) 202 | mu = x[:, : self.z_dim] 203 | logstd = x[:, self.z_dim :] 204 | return mu, logstd 205 | 206 | 207 | class ResNet18Dec(nn.Module): 208 | def __init__(self, num_Blocks=[2, 2, 2, 2], z_dim=10, nc=3, additional_layers=0): 209 | super().__init__() 210 | self.in_planes = 512 211 | 212 | self.linear = nn.Linear(z_dim, 512) 213 | 214 | if additional_layers > 0: 215 | layers = [ 216 | self._make_layer( 217 | BasicBlockDec, 512, num_Blocks[3], stride=2, same_width=True 218 | ) 219 | for _ in range(additional_layers) 220 | ] 221 | self.additional_layers = nn.Sequential(*layers) 222 | else: 223 | self.additional_layers = None 224 | 225 | self.layer4 = self._make_layer(BasicBlockDec, 256, num_Blocks[3], stride=2) 226 | self.layer3 = self._make_layer(BasicBlockDec, 128, num_Blocks[2], stride=2) 227 | self.layer2 = self._make_layer(BasicBlockDec, 64, num_Blocks[1], stride=2) 228 | self.layer1 = self._make_layer(BasicBlockDec, 64, num_Blocks[0], stride=1) 229 | self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2) 230 | 231 | def _make_layer(self, BasicBlockDec, planes, num_Blocks, stride, same_width=False): 232 | strides = [stride] + [1] * (num_Blocks - 1) 233 | layers = [] 234 | for stride in reversed(strides): 235 | layers += [BasicBlockDec(self.in_planes, stride, same_width=same_width)] 236 | self.in_planes = planes 237 | return nn.Sequential(*layers) 238 | 239 | def forward(self, z): 240 | x = self.linear(z) 241 | x = x.view(z.size(0), 512, 1, 1) 242 | x = F.interpolate(x, scale_factor=4) 243 | if self.additional_layers is not None: 244 | x = self.additional_layers(x) 245 | x = self.layer4(x) 246 | x = self.layer3(x) 247 | x = self.layer2(x) 248 | x = self.layer1(x) 249 | x = torch.tanh(self.conv1(x)) # squash output to [-1, 1] 250 | return x 251 | 252 | 253 | class VAE(nn.Module): 254 | def __init__(self, z_dim, beta=1.0, x_std=0.1, additional_layers=0): 255 | super().__init__() 256 | self.encoder = ResNet18Enc(z_dim=z_dim, additional_layers=additional_layers) 257 | self.decoder = ResNet18Dec(z_dim=z_dim, additional_layers=additional_layers) 258 | self.z_dim = z_dim 
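        # beta weights the KL term in compute_elbo (a beta-VAE-style trade-off), and x_std is the
        # fixed std of the Gaussian decoder p(x | z), stored below as logsigma.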
259 | self.beta = beta 260 | self.logsigma = math.log(x_std) # logstd of p(x | z). 261 | 262 | def sample_latent(self, x): 263 | mean, logstd = self.encoder(x) 264 | z = torch.randn_like(mean) * torch.exp(logstd) + mean 265 | return z, mean, logstd 266 | 267 | def forward(self, x): 268 | z, mean, logstd = self.sample_latent(x) 269 | x = self.decoder(z) 270 | return x, z, mean, logstd 271 | 272 | def sample(self, num_samples, device): 273 | z = torch.randn(num_samples, self.z_dim).to(device) 274 | return self.decoder(z) 275 | 276 | def reconstruct(self, x): 277 | return self(x)[0] 278 | 279 | def compute_elbo(self, x): 280 | """Computes the ELBO.""" 281 | bsz = x.shape[0] 282 | x_recon, z, mean, logstd = self(x) 283 | 284 | logqz = normal_logprob(z, mean, logstd).reshape(bsz, -1).sum(1) 285 | logpz = normal_logprob(z, 0.0, 0.0).reshape(bsz, -1).sum(1) 286 | logpx = normal_logprob(x, x_recon, self.logsigma).reshape(bsz, -1).sum(1) 287 | 288 | return logpx + self.beta * (logpz - logqz) 289 | 290 | 291 | def normal_logprob(z, mean, log_std): 292 | mean = (mean + torch.tensor(0.0)).to(z) 293 | log_std = (log_std + torch.tensor(0.0)).to(z) 294 | c = torch.tensor([math.log(2 * math.pi)]).to(z) 295 | inv_sigma = torch.exp(-log_std) 296 | tmp = (z - mean) * inv_sigma 297 | return -0.5 * (tmp * tmp + 2 * log_std + c) 298 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # stunnel (obstacle + congestion) 4 | python train.py experiment=stunnel prob.sigma=0.5,1,2 -m # bm 5 | python train.py experiment=stunnel prob.sigma=0.5,1,2 csoc.IW=true -m # bm-IW 6 | python train.py experiment=stunnel_eam prob.sigma=0.5,1,2 -m # eam 7 | python train.py experiment=stunnel_eam prob.sigma=0.5,1,2 csoc.IW=true -m # eam-IW 8 | 9 | # vneck (obstacle + entropy) 10 | python train.py experiment=vneck prob.sigma=1,2 -m 11 | 12 | # gmm (obstacle + entropy) 13 | python train.py experiment=gmm prob.sigma=1,2 -m 14 | 15 | # lidar 16 | python train.py experiment=lidar -m 17 | 18 | # opinion_2d 19 | python train.py experiment=opinion_2d -m 20 | 21 | # opinion 22 | python train.py experiment=opinion -m 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. and affiliates.""" 2 | 3 | import setuptools 4 | 5 | setuptools.setup( 6 | name="gsbm", 7 | version="1.0.0", 8 | author="Guan-Horng Liu", 9 | author_email="ghliu@gatech.edu", 10 | description="Generalized Schrödinger bridge matching", 11 | url="https://github.com/facebookresearch/generalized-schrodinger-bridge-matching", 12 | packages=setuptools.find_packages(), 13 | ) 14 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Meta Platforms, Inc. 
and affiliates.""" 2 | 3 | import os 4 | import os.path as osp 5 | import sys 6 | from datetime import datetime 7 | from omegaconf import DictConfig, OmegaConf 8 | import hydra 9 | from hydra.core.hydra_config import HydraConfig 10 | import logging 11 | import json 12 | from glob import glob 13 | import torch 14 | import torch.nn as nn 15 | import pytorch_lightning as pl 16 | from pytorch_lightning.callbacks import ModelCheckpoint 17 | from pytorch_lightning.callbacks import LearningRateMonitor 18 | 19 | from gsbm.dataset import get_dist_boundary 20 | from gsbm.pl_model import GSBMLitModule 21 | 22 | import colored_traceback.always 23 | 24 | from ipdb import set_trace as debug 25 | 26 | torch.backends.cudnn.benchmark = True 27 | log = logging.getLogger(__name__) 28 | 29 | 30 | @hydra.main(version_base=None, config_path="configs", config_name="train") 31 | def main(cfg: DictConfig): 32 | logging.getLogger("pytorch_lightning").setLevel(logging.getLevelName("INFO")) 33 | 34 | hydra_config = HydraConfig.get() 35 | 36 | # Get the number of nodes we are training on 37 | nnodes = hydra_config.launcher.get("nodes", 1) 38 | print("nnodes", nnodes) 39 | 40 | if cfg.get("seed", None) is not None: 41 | pl.utilities.seed.seed_everything(cfg.seed) 42 | 43 | print(cfg) 44 | 45 | print("Found {} CUDA devices.".format(torch.cuda.device_count())) 46 | for i in range(torch.cuda.device_count()): 47 | props = torch.cuda.get_device_properties(i) 48 | print( 49 | "{} \t Memory: {:.2f}GB".format(props.name, props.total_memory / (1024**3)) 50 | ) 51 | 52 | keys = [ 53 | "SLURM_NODELIST", 54 | "SLURM_JOB_ID", 55 | "SLURM_NTASKS", 56 | "SLURM_JOB_NAME", 57 | "SLURM_PROCID", 58 | "SLURM_LOCALID", 59 | "SLURM_NODEID", 60 | ] 61 | log.info(json.dumps({k: os.environ.get(k, None) for k in keys}, indent=4)) 62 | 63 | cmd_str = " \\\n".join([f"python {sys.argv[0]}"] + ["\t" + x for x in sys.argv[1:]]) 64 | with open("cmd.sh", "w") as fout: 65 | print("#!/bin/bash\n", file=fout) 66 | print(cmd_str, file=fout) 67 | 68 | log.info(f"CWD: {os.getcwd()}") 69 | 70 | # Construct model 71 | p0, p1, p0_val, p1_val = get_dist_boundary(cfg) 72 | model = GSBMLitModule(cfg, p0, p1, p0_val, p1_val) 73 | model.log_boundary(p0, p1, p0_val, p1_val) 74 | if cfg.prob.name == "opinion": 75 | model.log_basedrift(p0) 76 | # print(model) 77 | 78 | # Checkpointing, logging, and other misc. 
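    # Note: ModelCheckpoint writes into <cwd>/checkpoints/ alongside Hydra's .hydra/config.yaml,
    # which is the layout gsbm.utils.get_job_directory (and hence restore_model) walks upward to
    # find when reloading a finished run.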
79 | callbacks = [
80 | ModelCheckpoint(
81 | dirpath="checkpoints",
82 | filename="epoch-{epoch:03d}_step-{step}",
83 | auto_insert_metric_name=False,
84 | save_top_k=-1, # save all models whenever callback occurs
85 | save_last=True,
86 | every_n_epochs=1,
87 | verbose=True,
88 | ),
89 | LearningRateMonitor(),
90 | ]
91 |
92 | slurm_plugin = pl.plugins.environments.SLURMEnvironment(auto_requeue=False)
93 |
94 | cfg_dict = OmegaConf.to_container(cfg, resolve=True)
95 | cfg_dict["cwd"] = os.getcwd()
96 | loggers = [pl.loggers.CSVLogger(save_dir=".")]
97 | if cfg.use_wandb:
98 | now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
99 | loggers.append(
100 | pl.loggers.WandbLogger(
101 | save_dir=".",
102 | name=f"{cfg.prob.name}_{now}",
103 | project="GSBM",
104 | log_model=False,
105 | config=cfg_dict,
106 | resume=True,
107 | )
108 | )
109 |
110 | strategy = "ddp" if torch.cuda.device_count() > 1 else None
111 |
112 | trainer = pl.Trainer(
113 | max_epochs=cfg.optim.max_epochs,
114 | accelerator="gpu",
115 | strategy=strategy,
116 | logger=loggers,
117 | num_nodes=nnodes,
118 | callbacks=callbacks,
119 | precision=cfg.get("precision", 32),
120 | gradient_clip_val=cfg.optim.grad_clip,
121 | plugins=slurm_plugin if slurm_plugin.detect() else None,
122 | reload_dataloaders_every_n_epochs=1, # GSBM: refresh on-policy samples every epoch
123 | num_sanity_val_steps=-1, # GSBM: validate before training -> random coupling
124 | check_val_every_n_epoch=1, # GSBM: validate -> Markovian coupling
125 | replace_sampler_ddp=False, # GSBM: avoid gather_all, use device-wise dataloader
126 | enable_progress_bar=False,
127 | )
128 |
129 | # If we specified a checkpoint to resume from, use it
130 | checkpoint = cfg.get("resume", None)
131 |
132 | # Check if a checkpoint exists in this working directory. If so, we are resuming from a pre-emption.
133 | # This takes precedence over a checkpoint specified on the command line.
134 | checkpoints = glob("checkpoints/**/*.ckpt", recursive=True)
135 | if len(checkpoints) > 0:
136 | # Use the checkpoint with the latest modification time
137 | checkpoint = sorted(checkpoints, key=os.path.getmtime)[-1]
138 |
139 | # Load dataset (train loader will be generated online)
140 | trainer.fit(model, ckpt_path=checkpoint)
141 |
142 | metric_dict = trainer.callback_metrics
143 |
144 | for k, v in metric_dict.items():
145 | metric_dict[k] = float(v)
146 |
147 | with open("metrics.json", "w") as fout:
148 | print(json.dumps(metric_dict), file=fout)
149 |
150 | return metric_dict
151 |
152 |
153 | if __name__ == "__main__":
154 | main()
155 | --------------------------------------------------------------------------------
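Putting the pieces together, a minimal restore-and-evaluate sketch (assumptions: a finished run directory with its .hydra config intact; the checkpoint path, device, and batch size below are placeholders, not values from this repo):

    import torch
    from gsbm.utils import restore_model
    from gsbm.state_cost import obstacle_cost_stunnel

    # restore_model walks up from the checkpoint to locate .hydra/config.yaml,
    # rebuilds the GSBMLitModule from that config, and loads the weights.
    model, cfg = restore_model("outputs/<run>/checkpoints/last.ckpt", device="cuda")

    # State costs map (*, 2) positions to (*,) penalties, e.g. the soft penalty
    # inside the two stunnel ellipses.
    xt = torch.randn(16, 2, device="cuda")
    cost = obstacle_cost_stunnel(xt)  # shape: (16,)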