├── .gitignore ├── LICENSE ├── README.md ├── assets ├── GMM.gif ├── Stunnel.gif ├── Vneck.gif ├── opinion_after.gif └── opinion_before.gif ├── configs ├── __init__.py ├── gmm.py ├── opinion.py ├── opinion_1k.py ├── stunnel.py └── vneck.py ├── deepgsb ├── __init__.py ├── deepgsb.py ├── eval_metrics.py ├── loss_lib.py ├── replay_buffer.py ├── sb_policy.py └── util.py ├── git_utils.py ├── main.py ├── make_animation.py ├── mfg ├── __init__.py ├── constraint.py ├── mfg.py ├── opinion_lib.py ├── plotting.py ├── sde.py ├── state_cost.py └── util.py ├── models ├── __init__.py ├── opinion_net.py ├── toy_net.py └── util.py ├── options.py ├── requirements.yaml └── run.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Generation results 2 | results/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *__pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | scripts/* 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | 138 | data/* 139 | checkpoint/* 140 | plots/* 141 | *vscode/* 142 | deprecated.py 143 | runs/* 144 | .DS_Store 145 | ._.DS_Store 146 | *ipynb 147 | deprecated/* 148 | *gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Generalized Schrödinger Bridge
[[NeurIPS 2022 Oral](https://arxiv.org/abs/2209.09893)] 2 | 3 | Official PyTorch implementation of the paper 4 | "_**Deep** **G**eneralized **S**chrödinger **B**ridge_ (**DeepGSB**)" which introduces 5 | a new class of diffusion models as a scalable numerical solver for Mean-Field Games (MFGs), _e.g._, population modeling & opinion depolarization, with hard distributional constraints. 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 |
Population modeling (crowd navigation)Opinion depolarization
drawingdrawingdrawingdrawingdrawing
21 | 22 | This repo is co-maintained by [Guan-Horng Liu](https://ghliu.github.io/), [Tianrong Chen](https://tianrongchen.github.io/), and [Oswin So](https://oswinso.xyz/). Contact us if you have any questions! If you find this library useful, please cite :arrow_down: 23 | ``` 24 | @inproceedings{liu2022deep, 25 | title={Deep Generalized Schr{\"o}dinger Bridge}, 26 | author={Liu, Guan-Horng and Chen, Tianrong and So, Oswin and Theodorou, Evangelos A}, 27 | booktitle={Advances in Neural Information Processing Systems}, 28 | year={2022} 29 | } 30 | ``` 31 | 32 | 33 | ## Install 34 | 35 | Install the dependencies with [Anaconda](https://www.anaconda.com/products/individual) and activate the environment `deepgsb` with 36 | ```bash 37 | conda env create --file requirements.yaml 38 | conda activate deepgsb 39 | ``` 40 | 41 | 42 | ## Run & Evaluate 43 | 44 | The repo contains 2 classes of Mean-Field Games, namely 45 | - **population modeling**: `GMM`, `Vneck`, `Stunnel` 46 | - **opinion depolarization**: `opinion`, `opinion-1k` (dim=1000). 47 | 48 | The commands to generate similar results shown in our paper can be found in `run.sh`. Results, checkpoints, and tensorboard log files will be saved respectively to the folders `results/`, `checkpoint/`, and `runs/`. 49 | ```bash 50 | bash run.sh # can be {GMM, Vneck, Stunnel, opinion, opinion-1k} 51 | ``` 52 | 53 | You can visualize the trained DeepGSB policies by making gif animation 54 | ```bash 55 | python make_animation.py --load --name 56 | ``` 57 | 58 | ## Structure 59 | 60 | We briefly document the file structure to ease the effort if you wish to integrate DeepGSB with your work flow. 61 | ```bash 62 | deepgsb/ 63 | ├── deepgsb.py # the DeepGSB MFG solver 64 | ├── sb_policy.py # the parametrized Schrödinger Bridge policy 65 | ├── loss_lib.py # all loss functions (IPF/KL, TD, FK/grad) 66 | ├── eval_metrics.py # all logging metrics (Wasserstein, etc) 67 | ├── replay_buffer.py 68 | └── util.py 69 | mfg/ 70 | ├── mfg.py # the Mean-Field Game environment 71 | ├── constraint.py # the distributional boundary constraint (p0, pT) 72 | ├── state_cost.py # all mean-field interaction state costs (F) 73 | ├── sde.py # the associated stochastic processes (f, sigma) 74 | ├── opinion_lib.py # all utilities for opinion depolarization MFG 75 | ├── plotting.py 76 | └── util.py 77 | models/ # the deep networks for parametrizing SB policy 78 | configs/ # the configurations for each MFG 79 | ``` 80 | -------------------------------------------------------------------------------- /assets/GMM.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/DeepGSB/d11a8a19b90443cb3af28a90f35c532045e5eebc/assets/GMM.gif -------------------------------------------------------------------------------- /assets/Stunnel.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/DeepGSB/d11a8a19b90443cb3af28a90f35c532045e5eebc/assets/Stunnel.gif -------------------------------------------------------------------------------- /assets/Vneck.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/DeepGSB/d11a8a19b90443cb3af28a90f35c532045e5eebc/assets/Vneck.gif -------------------------------------------------------------------------------- /assets/opinion_after.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/DeepGSB/d11a8a19b90443cb3af28a90f35c532045e5eebc/assets/opinion_after.gif -------------------------------------------------------------------------------- /assets/opinion_before.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/DeepGSB/d11a8a19b90443cb3af28a90f35c532045e5eebc/assets/opinion_before.gif -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | def get_default(problem_name, sb_param): 4 | env_name = problem_name.lower() 5 | config_fn = env_name + "_" + sb_param.replace("-", "_") 6 | 7 | # print(env_name) 8 | module = importlib.import_module(f"configs.{env_name}") 9 | assert hasattr(module, config_fn) 10 | return getattr(module, config_fn)() 11 | -------------------------------------------------------------------------------- /configs/gmm.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | def _common(): 4 | config = edict() 5 | config.problem_name = 'GMM' 6 | config.x_dim = 2 7 | 8 | # sde 9 | config.seed = 42 10 | config.t0 = 0.0 11 | config.T = 1.0 12 | config.interval = 100 13 | config.diffusion_std = 1.0 14 | 15 | # training 16 | config.num_itr = 250 17 | config.train_bs_x = 128 18 | config.rb_bs_x = 128 19 | 20 | # sampling & evaluation 21 | config.samp_bs = 5000 22 | config.snapshot_freq = 1 23 | # config.ckpt_freq = 2 24 | 25 | # optimization 26 | config.optimizer = 'AdamW' 27 | 28 | return config 29 | 30 | def gmm_actor_critic(): 31 | config = _common() 32 | 33 | # paramatrization 34 | config.sb_param = 'actor-critic' 35 | config.policy_net = 'toy' 36 | 37 | # optimization 38 | config.lr = 5e-4 39 | config.lr_y = 1e-3 40 | config.lr_gamma = 0.999 41 | 42 | # tuning 43 | config.num_stage = 40 44 | config.multistep_td = True 45 | config.samp_method = 'gauss' 46 | 47 | return config 48 | 49 | def gmm_critic(): 50 | config = _common() 51 | 52 | # paramatrization 53 | config.sb_param = 'critic' 54 | config.policy_net = 'toy' 55 | 56 | # optimization 57 | config.lr = 5e-4 58 | config.lr_gamma = 0.999 59 | 60 | # tuning 61 | config.num_stage = 40 62 | config.multistep_td = True 63 | config.samp_method = 'gauss' 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /configs/opinion.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | def _common(): 5 | config = edict() 6 | config.problem_name = "opinion" 7 | config.x_dim = 2 8 | 9 | # sde 10 | config.seed = 42 11 | config.t0 = 0.0 12 | config.T = 3.0 13 | config.interval = 300 14 | config.diffusion_std = 0.1 15 | 16 | # training 17 | config.num_stage = 40 18 | config.num_itr = 100 19 | config.train_bs_x = 128 20 | config.rb_bs_x = 128 21 | 22 | # sampling & evaluation 23 | config.samp_bs = 2000 24 | config.snapshot_freq = 1 25 | # config.ckpt_freq = 2 26 | 27 | # optimization 28 | config.optimizer = "AdamW" 29 | 30 | return config 31 | 32 | 33 | def opinion_actor_critic(): 34 | config = _common() 35 | 36 | # paramatrization 37 | config.sb_param = "actor-critic" 38 | config.policy_net = "toy" 39 | 40 | # optimization 41 | coeff = 3. 42 | config.lr = coeff * 5e-4 43 | config.lr_y = 1e-3 44 | config.lr_gamma = 0.999 45 | 46 | # tuning 47 | config.num_stage = 40 48 | config.multistep_td = True 49 | config.use_rb_loss = True 50 | 51 | config.samp_method = "gauss" 52 | 53 | config.weights = {'kl': 0.8, 'non-kl': 0.05} 54 | 55 | return config 56 | 57 | 58 | def opinion_critic(): 59 | config = _common() 60 | 61 | # paramatrization 62 | config.sb_param = "critic" 63 | config.policy_net = "toy" 64 | 65 | # optimization 66 | config.lr = 5e-4 67 | config.lr_gamma = 0.999 68 | 69 | # tuning 70 | config.num_stage = 40 71 | config.multistep_td = True 72 | config.use_rb_loss = True 73 | 74 | config.samp_method = "gauss" 75 | 76 | config.weights = {'kl': 0.6, 'non-kl': 0.002} 77 | 78 | return config -------------------------------------------------------------------------------- /configs/opinion_1k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | def _common(): 5 | config = edict() 6 | config.problem_name = "opinion_1k" 7 | config.x_dim = 1000 8 | 9 | # sde 10 | config.seed = 42 11 | config.t0 = 0.0 12 | config.T = 3.0 13 | config.interval = 500 14 | config.diffusion_std = 0.5 15 | 16 | # training 17 | config.train_bs_x = 128 18 | config.rb_bs_x = 128 19 | config.buffer_size = 5000 20 | 21 | # sampling & evaluation 22 | config.samp_bs = 2500 23 | config.snapshot_freq = 1 24 | # config.ckpt_freq = 2 25 | 26 | # optimization 27 | config.optimizer = "AdamW" 28 | 29 | return config 30 | 31 | 32 | def opinion_1k_actor_critic(): 33 | config = _common() 34 | 35 | # paramatrization 36 | config.sb_param = "actor-critic" 37 | config.policy_net = "opinion_net" 38 | 39 | # optimization 40 | coeff = 1. 41 | 42 | config.lr = coeff * 5e-4 43 | config.lr_y = 1e-3 44 | config.lr_gamma = 0.999 45 | 46 | # tuning 47 | config.num_stage = 130 48 | config.num_itr = 250 49 | # config.multistep_td = True 50 | # config.use_rb_loss = True 51 | 52 | config.samp_method = "gauss" 53 | 54 | config.weights = {'kl': 0.8, 'non-kl': 0.05} 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /configs/stunnel.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | def _common(): 4 | config = edict() 5 | config.problem_name = 'Stunnel' 6 | config.x_dim = 2 7 | 8 | # sde 9 | config.seed = 42 10 | config.t0 = 0.0 11 | config.T = 3.0 12 | config.interval = 300 13 | config.diffusion_std = 1.0 14 | 15 | # training 16 | config.num_itr = 500 17 | config.train_bs_x = 128 18 | config.rb_bs_x = 128 19 | 20 | # sampling & evaluation 21 | config.samp_bs = 5000 22 | config.snapshot_freq = 1 23 | # config.ckpt_freq = 2 24 | 25 | # optimization 26 | config.optimizer = 'AdamW' 27 | 28 | return config 29 | 30 | def stunnel_actor_critic(): 31 | config = _common() 32 | 33 | # paramatrization 34 | config.sb_param = 'actor-critic' 35 | config.policy_net = 'toy' 36 | 37 | # optimization 38 | config.lr = 5e-4 39 | config.lr_y = 1e-3 40 | config.lr_gamma = 0.999 41 | 42 | # tuning 43 | config.num_stage = 40 44 | config.multistep_td = True 45 | config.use_rb_loss = True 46 | config.samp_method = 'jacobi' 47 | 48 | return config 49 | 50 | def stunnel_critic(): 51 | config = _common() 52 | 53 | # paramatrization 54 | config.sb_param = 'critic' 55 | config.policy_net = 'toy' 56 | 57 | # optimization 58 | config.lr = 5e-4 59 | config.lr_gamma = 0.999 60 | 61 | # tuning 62 | config.num_stage = 40 63 | config.multistep_td = True 64 | config.samp_method = 'jacobi' 65 | config.ema = 0.9 66 | 67 | return config 68 | -------------------------------------------------------------------------------- /configs/vneck.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | def _common(): 4 | config = edict() 5 | config.problem_name = 'Vneck' 6 | config.x_dim = 2 7 | 8 | # sde 9 | config.seed = 42 10 | config.t0 = 0.0 11 | config.T = 2.0 12 | config.interval = 200 13 | config.diffusion_std = 1.0 14 | 15 | # training 16 | config.num_itr = 250 17 | config.train_bs_x = 128 18 | config.rb_bs_x = 128 19 | 20 | # sampling & evaluation 21 | config.samp_bs = 5000 22 | config.snapshot_freq = 1 23 | # config.ckpt_freq = 2 24 | 25 | # optimization 26 | config.optimizer = 'AdamW' 27 | 28 | return config 29 | 30 | def vneck_actor_critic(): 31 | config = _common() 32 | 33 | # paramatrization 34 | config.sb_param = 'actor-critic' 35 | config.policy_net = 'toy' 36 | 37 | # optimization 38 | config.lr = 5e-4 39 | config.lr_y = 1e-3 40 | config.lr_gamma = 0.999 41 | 42 | # tuning 43 | config.num_stage = 40 44 | config.multistep_td = True 45 | config.use_rb_loss = True 46 | 47 | config.samp_method = 'jacobi' 48 | 49 | return config 50 | 51 | def vneck_critic(): 52 | config = _common() 53 | 54 | # paramatrization 55 | config.sb_param = 'critic' 56 | config.policy_net = 'toy' 57 | 58 | # optimization 59 | config.lr = 5e-4 60 | config.lr_gamma = 0.999 61 | 62 | # tuning 63 | config.num_stage = 40 64 | config.multistep_td = True 65 | config.use_rb_loss = True 66 | 67 | # gauss converges faster but with slightly worse W2 68 | config.samp_method = 'jacobi' 69 | 70 | return config 71 | -------------------------------------------------------------------------------- /deepgsb/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepgsb import DeepGSB -------------------------------------------------------------------------------- /deepgsb/deepgsb.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import gc 3 | import logging 4 | import os 5 | import pickle 6 | import time 7 | from typing import Dict, Optional, Tuple 8 | 9 | import torch 10 | from easydict import EasyDict as edict 11 | from torch.optim import SGD, Adagrad, Adam, AdamW, RMSprop, lr_scheduler 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | from mfg import MFG, MFGPolicy 15 | from options import Options 16 | 17 | from . import eval_metrics, loss_lib, sb_policy, util 18 | from .replay_buffer import Buffer 19 | 20 | log = logging.getLogger(__file__) 21 | 22 | OptSchedPair = Tuple[torch.optim.Optimizer, lr_scheduler._LRScheduler] 23 | 24 | 25 | def build_optimizer_sched(opt: Options, policy) -> OptSchedPair: 26 | optim_name = { 27 | 'Adam': Adam, 28 | 'AdamW': AdamW, 29 | 'Adagrad': Adagrad, 30 | 'RMSprop': RMSprop, 31 | 'SGD': SGD, 32 | }.get(opt.optimizer) 33 | 34 | optim_dict = { 35 | "lr": opt.lr, 36 | 'weight_decay':opt.l2_norm, 37 | } 38 | if opt.optimizer == 'SGD': 39 | optim_dict['momentum'] = 0.9 40 | 41 | if policy.param == "actor-critic": 42 | optimizer = optim_name([ 43 | {'params': policy.Znet.parameters()}, # use original optim_dict 44 | {'params': policy.Ynet.parameters(), 'lr': opt.lr_y}, 45 | ], **optim_dict) 46 | elif policy.param == "critic": 47 | optimizer = optim_name(policy.parameters(), **optim_dict) 48 | else: 49 | raise ValueError(f"Expected either actor-critic or critic, got {policy.param}") 50 | 51 | if opt.lr_gamma < 1.0: 52 | sched = lr_scheduler.StepLR(optimizer, step_size=opt.lr_step, gamma=opt.lr_gamma) 53 | else: 54 | sched = None 55 | 56 | return optimizer, sched 57 | 58 | def get_grad_loss_norm(opt: Options) -> str: 59 | problem_name: str = opt.problem_name 60 | 61 | if problem_name in ["GMM", "opinion", "opinion_1k"]: 62 | # Use L1 on GMM. 63 | return "l1" 64 | if problem_name in []: 65 | # Use L2 loss for ?? 66 | return "l2" 67 | 68 | # Otherwise, default to Huber loss. 69 | return "huber" 70 | 71 | class DeepGSB: 72 | def __init__(self, opt: Options, mfg: MFG, save_opt: bool = True): 73 | super(DeepGSB, self).__init__() 74 | 75 | # Save opt. 76 | if save_opt: 77 | opt_pkl_path = opt.ckpt_path + "/options.pkl" 78 | with open(opt_pkl_path, "wb") as f: 79 | pickle.dump(opt, f) 80 | log.info("Saved options pickle to {}!".format(opt_pkl_path)) 81 | 82 | self.start_time = time.time() 83 | 84 | self.mfg = mfg 85 | 86 | # build forward (z_f) and backward (z_b) policies 87 | self.z_f = sb_policy.build(opt, mfg.sde, 'forward') # p0 -> pT 88 | self.z_b = sb_policy.build(opt, mfg.sde, 'backward') # p0 -> pT 89 | if mfg.uses_mf_drift(): 90 | mfg.initialize_mf_drift(self.z_f) 91 | 92 | self.optimizer_f, self.sched_f = build_optimizer_sched(opt, self.z_f) 93 | self.optimizer_b, self.sched_b = build_optimizer_sched(opt, self.z_b) 94 | 95 | self.buffer_f = Buffer(opt, 'forward') if opt.use_rb_loss else None 96 | self.buffer_b = Buffer(opt, 'backward') if opt.use_rb_loss else None 97 | 98 | self.it_f = self.it_b = 0 99 | if opt.log_tb: # tensorboard related things 100 | self.writer=SummaryWriter(log_dir=opt.log_dir) 101 | 102 | if opt.load: 103 | util.restore_checkpoint(opt, self, opt.load) 104 | 105 | @property 106 | def is_critic_param(self) -> bool: 107 | return self.z_f.param == self.z_b.param == "critic" 108 | 109 | @property 110 | def is_actor_critic_param(self) -> bool: 111 | return self.z_f.param == self.z_b.param == "actor-critic" 112 | 113 | def get_count(self, direction: str) -> int: 114 | return self.it_f if direction == "forward" else self.it_b 115 | 116 | def update_count(self, direction: str) -> int: 117 | if direction == 'forward': 118 | self.it_f += 1 119 | return self.it_f 120 | elif direction == 'backward': 121 | self.it_b += 1 122 | return self.it_b 123 | else: 124 | raise RuntimeError() 125 | 126 | def get_optimizer_sched(self, z: MFGPolicy) -> OptSchedPair: 127 | if z == self.z_f: 128 | return self.optimizer_f, self.sched_f 129 | elif z == self.z_b: 130 | return self.optimizer_b, self.sched_b 131 | else: 132 | raise RuntimeError() 133 | 134 | @torch.no_grad() 135 | def sample_train_data(self, opt: Options, train_direction: str) -> edict: 136 | policy_opt, policy_impt = { 137 | 'forward': [self.z_f, self.z_b], # train forward, sample from backward 138 | 'backward': [self.z_b, self.z_f], # train backward, sample from forward 139 | }.get(train_direction) 140 | 141 | # prepare training data 142 | train_ts = self.mfg.ts.detach() 143 | 144 | # update mf_drift if we need it and we're sampling forward traj 145 | update_mf_drift = (self.mfg.uses_mf_drift() and policy_impt.direction == 'forward') 146 | 147 | ema, ema_impt = policy_opt.get_ema(), policy_impt.get_ema() 148 | with ema.average_parameters(), ema_impt.average_parameters(): 149 | policy_impt.freeze() 150 | policy_opt.freeze() 151 | 152 | xs, zs, ws, _ = self.mfg.sample_traj(policy_impt, update_mf_drift=update_mf_drift) 153 | train_xs = xs.detach().cpu(); del xs 154 | train_zs = zs.detach().cpu(); del zs 155 | train_ws = ws.detach().cpu(); del ws 156 | 157 | log.info('generate train data from [sampling]!') 158 | 159 | assert train_xs.shape[0] == opt.samp_bs 160 | assert train_xs.shape[1] == len(train_ts) 161 | assert train_xs.shape == train_zs.shape 162 | gc.collect() 163 | 164 | return edict( 165 | xs=train_xs, zs=train_zs, ws=train_ws, ts=train_ts 166 | ) 167 | 168 | def train_stage(self, opt: Options, stage: int, train_direction: str, datas: Optional[edict]=None) -> None: 169 | policy_opt, policy_impt = { 170 | 'forward': [self.z_f, self.z_b], # train forwad, sample from backward 171 | 'backward': [self.z_b, self.z_f], # train backward, sample from forward 172 | }.get(train_direction) 173 | 174 | buffer_impt, buffer_opt = { 175 | 'forward': [self.buffer_f, self.buffer_b], 176 | 'backward': [self.buffer_b, self.buffer_f], 177 | }.get(policy_impt.direction) 178 | 179 | if datas is None: 180 | datas = self.sample_train_data(opt, train_direction) 181 | 182 | # Compute the cost and statistical distance for the forward / backward trajectories 183 | if opt.log_tb: 184 | t1 = time.time() 185 | self.log_validate_metrics(opt, train_direction, datas) 186 | log.info("Done logging validate metrics! Took {:.1f} s!".format(time.time() - t1)) 187 | 188 | # update buffers 189 | if opt.use_rb_loss: 190 | buffer_impt.append(datas) 191 | 192 | self.train_ep(opt, stage, train_direction, datas, policy_opt, policy_impt, buffer_opt, buffer_impt) 193 | 194 | def train_ep( 195 | self, 196 | opt: Options, 197 | stage: int, 198 | direction: str, 199 | datas: edict, 200 | policy: MFGPolicy, 201 | policy_impt: MFGPolicy, 202 | buffer_opt: Optional[Buffer], 203 | buffer_impt: Optional[Buffer], 204 | ) -> None: 205 | train_xs, train_zs, train_ws, train_ts = datas.xs, datas.zs, datas.ws, datas.ts 206 | 207 | assert train_xs.shape[0] == opt.samp_bs 208 | assert train_zs.shape[0] == opt.samp_bs 209 | assert train_ts.shape[0] == opt.interval 210 | assert direction == policy.direction 211 | 212 | optimizer, sched = self.get_optimizer_sched(policy) 213 | optimizer_impt, _ = self.get_optimizer_sched(policy_impt) 214 | 215 | policy.activate() # activate Y (and Z) 216 | policy_impt.freeze() # freeze Y_impt (and Z_impt) 217 | 218 | if stage>0 and opt.use_rb_loss: assert len(buffer_opt)>0 and len(buffer_impt)>0 219 | 220 | mfg = self.mfg 221 | samp_direction = policy_impt.direction 222 | for it in range(opt.num_itr): 223 | step = self.update_count(direction) 224 | 225 | # -------- sample x_idx and t_idx \in [0, interval] -------- 226 | samp_x_idx = torch.randint(opt.samp_bs, (opt.train_bs_x,)) 227 | 228 | dim01 = [opt.train_bs_x, opt.interval] 229 | 230 | # -------- build sample -------- 231 | ts = train_ts.detach() 232 | xs = train_xs[samp_x_idx].to(opt.device) 233 | zs_impt = train_zs[samp_x_idx].to(opt.device) 234 | dw = train_ws[samp_x_idx].to(opt.device) 235 | 236 | if mfg.uses_xs_all(): 237 | samp_x_idx2 = torch.randint(opt.samp_bs, (opt.train_bs_x,)) 238 | xs_all = train_xs[samp_x_idx2].to(opt.device) 239 | else: 240 | xs_all = None 241 | 242 | optimizer.zero_grad() 243 | optimizer_impt.zero_grad() 244 | 245 | # -------- compute KL loss -------- 246 | loss_kl, zs, kl, _ = loss_lib.compute_kl_loss( 247 | opt, dim01, mfg, samp_direction, 248 | ts.detach(), xs.detach(), zs_impt.detach(), 249 | policy, return_all=True 250 | ) 251 | 252 | # -------- compute bsde TD loss -------- 253 | loss_bsde_td = loss_lib.compute_bsde_td_loss( 254 | opt, mfg, samp_direction, 255 | ts.detach(), xs.detach(), zs.detach(), dw.detach(), kl.detach(), 256 | policy, policy_impt, xs_all 257 | ) 258 | 259 | # -------- compute boundary loss -------- 260 | loss_boundary = loss_lib.compute_boundary_loss( 261 | opt, mfg, ts.detach(), xs.detach(), policy_impt, policy, 262 | ) 263 | 264 | # -------- compute mismatch loss between Z and \nabla_x Y -------- 265 | loss_grad = torch.Tensor([0.0]) 266 | if self.is_actor_critic_param: 267 | norm = get_grad_loss_norm(opt) 268 | loss_grad = loss_lib.compute_grad_loss(opt, ts.detach(), xs.detach(), mfg.sde, policy, norm) 269 | 270 | # -------- compute replay buffer loss --------- 271 | loss_bsde_td_rb = torch.Tensor([0.0]) 272 | if opt.use_rb_loss: 273 | loss_bsde_td_rb = loss_lib.compute_bsde_td_loss_from_buffer( 274 | opt, mfg, buffer_impt, ts.detach(), policy, policy_impt, xs_all 275 | ) 276 | 277 | # -------- compute loss and backprop -------- 278 | if self.is_critic_param: 279 | if opt.weighted_loss: 280 | w_kl, w_nkl = opt.weights['kl'], opt.weights['non-kl'] 281 | loss = w_kl * loss_kl + w_nkl * (loss_bsde_td + loss_boundary + loss_bsde_td_rb) 282 | else: 283 | loss = loss_kl + loss_bsde_td + loss_boundary + loss_bsde_td_rb 284 | 285 | elif self.is_actor_critic_param: 286 | if opt.weighted_loss: 287 | w_kl, w_nkl = opt.weights['kl'], opt.weights['non-kl'] 288 | loss = w_kl * loss_kl + w_nkl * (loss_boundary + loss_bsde_td + loss_grad + loss_bsde_td_rb) 289 | else: 290 | loss = loss_kl + loss_boundary + loss_bsde_td + loss_grad + loss_bsde_td_rb 291 | else: 292 | raise RuntimeError("") 293 | 294 | assert not torch.isnan(loss) 295 | loss.backward() 296 | 297 | optimizer.step() 298 | policy.update_ema() 299 | 300 | if sched is not None: sched.step() 301 | 302 | # -------- logging -------- 303 | loss = edict( 304 | kl=loss_kl, grad=loss_grad, 305 | boundary=loss_boundary, bsde_td=loss_bsde_td, 306 | ) 307 | if it % 20 == 0: 308 | self.log_train(opt, it, stage, loss, optimizer, direction) 309 | 310 | def train(self, opt: Options) -> None: 311 | self.evaluate(opt, 0) 312 | 313 | for stage in range(opt.num_stage): 314 | if opt.samp_method == 'jacobi': 315 | datas1 = self.sample_train_data(opt, 'forward') 316 | datas2 = self.sample_train_data(opt, 'backward') 317 | self.train_stage(opt, stage, 'forward', datas=datas1) 318 | self.train_stage(opt, stage, 'backward', datas=datas2) 319 | 320 | elif opt.samp_method == 'gauss': 321 | self.train_stage(opt, stage, 'forward') 322 | self.train_stage(opt, stage, 'backward') 323 | 324 | t1 = time.time() 325 | self.evaluate(opt, stage+1) 326 | log.info("Finished evaluate! Took {:.2f}s.".format(time.time() - t1)) 327 | 328 | if opt.log_tb: self.writer.close() 329 | 330 | @torch.no_grad() 331 | def evaluate(self, opt: Options, stage: int) -> None: 332 | snapshot, ckpt = util.evaluate_stage(opt, stage) 333 | if snapshot: 334 | self.z_f.freeze(); self.z_b.freeze() 335 | self.mfg.save_snapshot(self.z_f, self.z_b, stage) 336 | 337 | if ckpt and stage > 0: 338 | keys = ['z_f','optimizer_f','z_b','optimizer_b', "it_f", "it_b"] 339 | util.save_checkpoint(opt, self, keys, stage) 340 | 341 | def compute_validate_metrics(self, opt: Options, train_direction: str, datas: edict) -> Dict[str, torch.Tensor]: 342 | # Sample direction is opposite of train_direction. 343 | xs, zs, ts = datas.xs, datas.zs, datas.ts 344 | 345 | b, T, nx = xs.shape 346 | assert zs.shape == (b, T, nx) 347 | assert ts.shape == (T,) 348 | 349 | mfg = self.mfg 350 | dt = mfg.dt 351 | 352 | metrics = {} 353 | 354 | # Compute "polarization" via the "condition number" of the covariance matrix. 355 | if "opinion" in mfg.problem_name: 356 | metrics.update( 357 | eval_metrics.compute_conv_l1_metrics(mfg, xs, train_direction) 358 | ) 359 | 360 | # Compute Wasserstein distance between the x0 / xT and p0 / pT. 361 | metrics.update( 362 | eval_metrics.compute_sinkhorn_metrics(opt, mfg, xs, train_direction) 363 | ) 364 | 365 | # Compute the state + control cost. 366 | if mfg.uses_xs_all(): 367 | xs_all = xs.to(opt.device) 368 | xs_all = xs_all[torch.randperm(opt.samp_bs)] 369 | else: 370 | xs_all = None 371 | 372 | est_mf_cost, logp = eval_metrics.compute_est_mf_cost( 373 | opt, self, xs, ts, dt, xs_all, return_logp=True 374 | ) 375 | s_cost, mf_cost = eval_metrics.compute_state_cost(opt, mfg, xs, ts, xs_all, logp, dt) 376 | del logp, xs_all 377 | control_cost = eval_metrics.compute_control_cost(opt, zs, dt) 378 | 379 | mean_s_cost, mean_control_cost, mean_mf_cost = s_cost.mean(), control_cost.mean(), mf_cost.mean() 380 | 381 | mean_nonmf_cost = mean_s_cost + mean_control_cost 382 | mean_total_cost = mean_nonmf_cost + mfg.mf_coeff * mean_mf_cost 383 | 384 | metrics["est_mf_cost"] = est_mf_cost 385 | metrics["state_cost"] = mean_s_cost 386 | metrics["nonmf_cost"] = mean_nonmf_cost 387 | metrics["mf_cost"] = mean_mf_cost 388 | metrics["control_cost"] = mean_control_cost 389 | metrics["total_cost"] = mean_total_cost 390 | 391 | return metrics 392 | 393 | @torch.no_grad() 394 | def log_validate_metrics(self, opt: Options, direction: str, datas: edict) -> None: 395 | 396 | def tag(name: str) -> str: 397 | return os.path.join(f"{direction}-loss", name) 398 | step = self.get_count(direction) 399 | 400 | metrics = self.compute_validate_metrics(opt, direction, datas) 401 | 402 | # Log all metrics. 403 | for key in metrics: 404 | self.writer.add_scalar(tag(key), metrics[key], global_step=step) 405 | del metrics 406 | 407 | def log_train( 408 | self, opt: Options, it: int, stage: int, loss: edict, optimizer: torch.optim.Optimizer, direction: str 409 | ) -> None: 410 | time_elapsed = util.get_time(time.time()-self.start_time) 411 | lr = optimizer.param_groups[0]['lr'] 412 | log.info("[SB {0}] stage {1}/{2} | itr {3}/{4} | lr {5} | loss {6} | time {7}" 413 | .format( 414 | "fwd" if direction=="forward" else "bwd", 415 | str(1+stage).zfill(2), 416 | opt.num_stage, 417 | str(1+it).zfill(3), 418 | opt.num_itr, 419 | "{:.2e}".format(lr), 420 | util.get_loss_str(loss), 421 | "{0}:{1:02d}:{2:05.2f}".format(*time_elapsed), 422 | )) 423 | 424 | step = self.get_count(direction) 425 | if opt.log_tb: 426 | assert isinstance(loss, edict) 427 | for key, val in loss.items(): 428 | # assert val > 0 # for taking log 429 | self.writer.add_scalar( 430 | os.path.join(f'{direction}-loss', f'{key}'), val.detach(), global_step=step 431 | ) 432 | 433 | # Also log the current stage. 434 | self.writer.add_scalar(os.path.join(f"{direction}-loss", "stage"), stage, global_step=step) 435 | 436 | # Log the LR. 437 | self.writer.add_scalar(os.path.join(f"{direction}-opt", "lr"), lr, global_step=step) 438 | -------------------------------------------------------------------------------- /deepgsb/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | import math 4 | import time 5 | from typing import Dict, List, Tuple, Union, TYPE_CHECKING 6 | 7 | import torch 8 | from geomloss import SamplesLoss 9 | 10 | from mfg import MFG 11 | from options import Options 12 | 13 | from . import util 14 | 15 | if TYPE_CHECKING: 16 | from deepgsb.deepgsb import DeepGSB 17 | 18 | log = logging.getLogger(__file__) 19 | 20 | 21 | def get_bound(mfg: MFG, xs: torch.Tensor, train_direction: str): 22 | return { 23 | "forward": [xs[:, 0], mfg.p0, "0"], # train forward, xs are sampled from backward 24 | "backward": [xs[:, -1], mfg.pT, "T"], # train backward, xs are sampled from forward 25 | }.get(train_direction) 26 | 27 | 28 | @torch.no_grad() 29 | def compute_conv_l1_metrics(mfg: MFG, xs: torch.Tensor, train_direction: str) -> Dict[str, torch.Tensor]: 30 | b, T, nx = xs.shape 31 | 32 | metrics = {} 33 | 34 | # (samp_bs, ...) 35 | x_bound, p_bound, time_str = get_bound(mfg, xs, train_direction) 36 | 37 | # Max singular value / min singular value. 38 | S = torch.linalg.svdvals(x_bound) 39 | # Largest to smallest. 40 | eigvals = (S ** 2) / (b - 1) 41 | cond_number = eigvals[0] / eigvals[-1] 42 | del S, eigvals 43 | 44 | # Eigvalsh returns smallest to largest. 45 | true_cov = p_bound.distribution.covariance_matrix.cpu() 46 | true_eigvals = torch.linalg.eigvalsh(true_cov) 47 | true_cond_number = true_eigvals[-1] / true_eigvals[0] 48 | 49 | metrics[f"cond_num diff {time_str}"] = torch.abs(cond_number - true_cond_number) 50 | 51 | # Pred stds (project on each axis). 52 | pred_stds = torch.std(x_bound, dim=0) 53 | # True stds (project on each axis). 54 | true_stds = p_bound.distribution.stddev.cpu() 55 | 56 | std_l1 = torch.abs(true_stds - pred_stds).sum() 57 | metrics[f"std l1 {time_str}"] = std_l1 58 | 59 | # Also just compare the covariance matrix. 60 | pred_cov = torch.cov(x_bound.T) 61 | assert pred_cov.shape == true_cov.shape 62 | cov_l1 = torch.abs(true_cov - pred_cov).sum() 63 | metrics[f"cov l1 {time_str}"] = cov_l1 64 | 65 | return metrics 66 | 67 | 68 | @torch.no_grad() 69 | def compute_sinkhorn_metrics(opt: Options, mfg: MFG, xs: torch.Tensor, train_direction: str) -> Dict[str, torch.Tensor]: 70 | b, T, nx = xs.shape 71 | 72 | metrics = {} 73 | 74 | # Higher scaling results in more accurate estimate. 75 | if opt.x_dim > 2: 76 | # L1 norm is better in high dimensional setting. 77 | # If it's high dimension, empiricla Wasserstein is probably not that good of a metric, so we don't need 78 | # to spend as much time on computing this. 79 | sinkhorn = SamplesLoss("sinkhorn", p=1, blur=5e-2, scaling=0.9) 80 | else: 81 | sinkhorn = SamplesLoss("sinkhorn", p=2, blur=5e-3, scaling=0.9) 82 | 83 | x_bound, p_bound, time_str = get_bound(mfg, xs, train_direction) 84 | W2_name = f"W2_{time_str}" 85 | 86 | p_samples = p_bound.sample(batch=b) 87 | 88 | log.info("Computing sinkhorn....") 89 | t1 = time.time() 90 | x_bound = x_bound.to(opt.device) 91 | metrics[W2_name] = sinkhorn(x_bound, p_samples) 92 | log.info("Done! sinkhorn took {:.1f} s".format(time.time() - t1)) 93 | 94 | return metrics 95 | 96 | 97 | @torch.no_grad() 98 | def compute_control_cost(opt: Options, zs: torch.Tensor, dt: float) -> torch.Tensor: 99 | b, T, nx = zs.shape 100 | 101 | zs = zs.to(opt.device) 102 | control_cost = torch.square(zs).sum((1, 2)) * dt 103 | return control_cost.detach().cpu() 104 | 105 | 106 | @torch.no_grad() 107 | def compute_state_cost( 108 | opt: Options, mfg: MFG, xs: torch.Tensor, ts: torch.Tensor, xs_all: torch.Tensor, logp: torch.Tensor, dt: float 109 | ) -> Tuple[torch.Tensor, torch.Tensor]: 110 | b, T, nx = xs.shape 111 | dim01 = (b, T) 112 | 113 | flat_xs = util.flatten_dim01(xs).to(opt.device) 114 | s_cost, mf_cost = mfg.state_cost_fn(flat_xs, ts, logp, xs_all) 115 | del flat_xs 116 | 117 | # (samp_bs, T) 118 | s_cost, mf_cost = util.unflatten_dim01(s_cost, dim01), util.unflatten_dim01(mf_cost, dim01) 119 | 120 | s_cost = (s_cost.sum(1) * dt).detach().cpu() 121 | mf_cost = (mf_cost.sum(1) * dt).detach().cpu() 122 | 123 | return s_cost, mf_cost 124 | 125 | 126 | @torch.jit.script 127 | def serial_logkde(train_xs: torch.Tensor, xs: torch.Tensor, bw: float, max_batch: int = 32) -> torch.Tensor: 128 | # xs: (b1, T, *) 129 | # train_xs: (b2, T, *) 130 | # out: (b1, T) 131 | 132 | b1, T, nx = xs.shape 133 | b2, T, nx = train_xs.shape 134 | assert xs.shape == (b1, T, nx) and train_xs.shape == (b2, T, nx) 135 | 136 | coeff = b2 * math.sqrt(2 * math.pi * bw) 137 | log_coeff = math.log(coeff) 138 | 139 | xs_chunks = torch.split(xs, max_batch, dim=0) 140 | out_chunks = [] 141 | for xs in xs_chunks: 142 | # 1: Compute diffs. (b1, b2, T, *) 143 | diffs = train_xs.unsqueeze(0) - xs.unsqueeze(1) 144 | # (b1, b2, T) 145 | norm_sq = torch.sum(torch.square(diffs), dim=-1) 146 | 147 | # (b1, b2, T) -> (b1, T) 148 | logsumexp = torch.logsumexp(-norm_sq / (2 * bw), dim=1) 149 | 150 | out = logsumexp - log_coeff 151 | out_chunks.append(out) 152 | 153 | out = torch.cat(out_chunks, dim=0) 154 | assert out.shape == (b1, T) 155 | 156 | return out 157 | 158 | 159 | @torch.no_grad() 160 | def compute_est_mf_cost( 161 | opt: Options, 162 | runner: "DeepGSB", 163 | xs: torch.Tensor, 164 | ts: torch.Tensor, 165 | dt: float, 166 | xs_all: torch.Tensor, 167 | return_logp: bool = False, 168 | ) -> Union[List[torch.Tensor], torch.Tensor]: 169 | if not runner.mfg.uses_logp(): 170 | return [torch.zeros(1), None] if return_logp else torch.zeros(1) 171 | 172 | b, T, nx = xs.shape 173 | dim01 = (b, T) 174 | 175 | bw = 0.2 ** 2 176 | 177 | xs = xs.to(opt.device) 178 | if b > 500: 179 | rand_idxs = torch.randperm(b)[:500] 180 | reduced_xs = xs[rand_idxs] 181 | del rand_idxs 182 | else: 183 | reduced_xs = xs 184 | 185 | max_batch = 32 if opt.x_dim == 2 else 2 186 | logp = serial_logkde(reduced_xs, xs, bw=bw, max_batch=max_batch) 187 | del reduced_xs 188 | 189 | assert logp.shape == (b, T) 190 | logp = util.flatten_dim01(logp).detach().cpu() 191 | 192 | gc.collect() 193 | 194 | # Evaluate logp in chunks to prevent cuda OOM. 195 | ts = ts.repeat(b) 196 | flat_xs = util.flatten_dim01(xs) 197 | 198 | max_chunk_size = 512 199 | flat_xs_chunks = torch.split(flat_xs, max_chunk_size, dim=0) 200 | ts_chunks = torch.split(ts, max_chunk_size, dim=0) 201 | 202 | est_logp = [] 203 | for flat_xs_chunk, ts_chunk in zip(flat_xs_chunks, ts_chunks): 204 | flat_xs_chunk = flat_xs_chunk.to(opt.device) 205 | ts_chunk = ts_chunk.to(opt.device) 206 | value_t_ema = runner.z_f.compute_value(flat_xs_chunk, ts_chunk, use_ema=True) 207 | value_impt_t_ema = runner.z_b.compute_value(flat_xs_chunk, ts_chunk, use_ema=True) 208 | est_logp.append((value_t_ema + value_impt_t_ema).detach().cpu()) 209 | del flat_xs_chunk, ts_chunk, value_t_ema, value_impt_t_ema 210 | est_logp = torch.cat(est_logp, dim=0) 211 | 212 | # flat_xs, ts = flat_xs.to(opt.device), ts.to(opt.device) 213 | # value_t_ema = self.z_f.compute_value(flat_xs, ts, use_ema=True) 214 | # value_impt_t_ema = self.z_b.compute_value(flat_xs, ts, use_ema=True) 215 | # est_logp = value_t_ema + value_impt_t_ema 216 | 217 | assert logp.shape == est_logp.shape == (b * T,) 218 | 219 | _, est_mf_cost = runner.mfg.state_cost_fn(flat_xs, ts, est_logp, xs_all) 220 | del est_logp 221 | 222 | est_mf_cost = util.unflatten_dim01(est_mf_cost, dim01) 223 | est_mf_cost = torch.mean(est_mf_cost.sum(1) * dt).detach().cpu() 224 | 225 | return [est_mf_cost, logp] if return_logp else est_mf_cost 226 | -------------------------------------------------------------------------------- /deepgsb/loss_lib.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, Iterable 2 | 3 | import torch 4 | from torch.nn.functional import huber_loss 5 | 6 | from mfg import MFG, MFGPolicy 7 | from mfg.sde import SimpleSDE 8 | from options import Options 9 | 10 | from . import util 11 | from .replay_buffer import Buffer 12 | 13 | DivGZRetType = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] 14 | KLLossRetType = Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] 15 | 16 | 17 | @torch.jit.script 18 | def rev_cumsum(x: torch.Tensor, dim: int) -> torch.Tensor: 19 | x = torch.flip(x, dims=[dim]) 20 | x = torch.cumsum(x, dim=dim) 21 | return torch.flip(x, dims=[dim]) 22 | 23 | 24 | def sample_e(opt: Options, x: torch.Tensor) -> torch.Tensor: 25 | if opt.noise_type == "gaussian": 26 | return torch.randn_like(x) 27 | elif opt.noise_type == "rademacher": 28 | return torch.randint(low=0, high=2, size=x.shape).to(x.device) * 2 - 1 29 | else: 30 | raise ValueError(f"Unsupport noise type {opt.noise_type}!") 31 | 32 | 33 | def compute_div_gz( 34 | opt: Options, mfg: MFG, ts: torch.Tensor, xs: torch.Tensor, policy: MFGPolicy, return_zs: bool = False 35 | ) -> DivGZRetType: 36 | b, T, nx = xs.shape 37 | assert ts.shape == (T,) 38 | 39 | if mfg.uses_state_dependent_drift(): 40 | f = mfg.f(xs, ts, policy.direction) 41 | assert f.shape == (b, T, nx) 42 | f = util.flatten_dim01(f) 43 | else: 44 | f = 0 45 | 46 | ts = ts.repeat(xs.shape[0]) 47 | xs = util.flatten_dim01(xs) 48 | zs = policy(xs, ts) 49 | 50 | g_ts = mfg.g(ts) 51 | g_ts = g_ts[:, None] 52 | gzs = g_ts * zs - f 53 | 54 | e = sample_e(opt, xs) 55 | e_dzdx = torch.autograd.grad(gzs, xs, e, create_graph=True)[0] 56 | div_gz = e_dzdx * e 57 | 58 | return [div_gz, zs] if return_zs else div_gz 59 | 60 | 61 | def compute_kl( 62 | opt: Options, 63 | mfg: MFG, 64 | ts: torch.Tensor, 65 | xs: torch.Tensor, 66 | zs_impt: torch.Tensor, 67 | policy: MFGPolicy, 68 | return_zs: bool = False, 69 | ) -> DivGZRetType: 70 | b, T, nx = xs.shape 71 | assert ts.shape == (T,) 72 | 73 | zs_impt = util.flatten_dim01(zs_impt) 74 | assert zs_impt.shape == (b * T, nx) 75 | 76 | with torch.enable_grad(): 77 | _xs = xs.detach() 78 | xs = _xs.requires_grad_(True) 79 | 80 | div_gz, zs = compute_div_gz(opt, mfg, ts, xs, policy, return_zs=True) 81 | assert div_gz.shape == zs.shape == (b * T, nx) 82 | 83 | # (b * T, xdim) 84 | kl = zs * (0.5 * zs + zs_impt) + div_gz 85 | assert kl.shape == (b * T, nx) 86 | 87 | return [kl, zs] if return_zs else kl 88 | 89 | 90 | def compute_norm_loss( 91 | norm: str, predict: torch.Tensor, label: torch.Tensor, batch_x: int, dt: float, delta: float = 1.0 92 | ) -> torch.Tensor: 93 | assert norm in ["l1", "l2", "huber"] 94 | assert predict.shape == label.shape 95 | 96 | if norm == "l1": 97 | return 0.5 * ((predict - label).abs() * dt).sum() / batch_x 98 | elif norm == "l2": 99 | return 0.5 * ((predict - label) ** 2 * dt).sum() / batch_x 100 | elif norm == "huber": 101 | return huber_loss(predict, label, reduction="sum", delta=delta) * dt / batch_x 102 | 103 | 104 | def compute_kl_loss( 105 | opt: Options, 106 | dim01: Iterable[int], 107 | mfg: MFG, 108 | samp_direction: str, 109 | ts: torch.Tensor, 110 | xs: torch.Tensor, 111 | zs_impt: torch.Tensor, 112 | policy: MFGPolicy, 113 | return_all: bool = False, 114 | ) -> KLLossRetType: 115 | 116 | kl, zs = compute_kl(opt, mfg, ts, xs, zs_impt, policy, return_zs=True) 117 | zs_impt = util.flatten_dim01(zs_impt) 118 | assert kl.shape == zs.shape == zs_impt.shape 119 | 120 | _, x_init, p_init, _, x_term, p_term = mfg.get_init_term_bound(xs, samp_direction) 121 | 122 | # computationally same as kl yet better interpretation 123 | bsde_y_yhat = util.unflatten_dim01((kl + 0.5 * zs_impt ** 2).sum(dim=-1) * mfg.dt, dim01).sum( 124 | dim=1 125 | ) # (batch_x, len_t) --> (batch_x) 126 | loss_kl = (p_init.log_prob(x_init) + bsde_y_yhat - p_term.log_prob(x_term)).mean() 127 | 128 | return [loss_kl, zs, kl, bsde_y_yhat] if return_all else loss_kl 129 | 130 | 131 | def compute_bsde_td_loss_multistep( 132 | mfg: MFG, 133 | samp_direction: str, 134 | ts: torch.Tensor, 135 | xs: torch.Tensor, 136 | zs: torch.Tensor, 137 | dw: torch.Tensor, 138 | kl: torch.Tensor, 139 | policy: MFGPolicy, 140 | policy_impt: MFGPolicy, 141 | xs_all: torch.Tensor, 142 | norm: str = "huber", 143 | ) -> torch.Tensor: 144 | # xs, xs_all: (b, T, nx) 145 | # zs: (b * T, nx) 146 | # dw: (b, T, nx) 147 | # kl: (b * T, nx) 148 | b, T, nx = xs.shape 149 | ts = ts.repeat(xs.shape[0]) 150 | xs = util.flatten_dim01(xs) 151 | dw = util.flatten_dim01(dw) 152 | zs = zs.reshape(*dw.shape) 153 | 154 | dim01 = [b, T] 155 | 156 | # (b * T, ) 157 | value_ema = policy.compute_value(xs, ts, use_ema=True) 158 | # ( b * T, ) 159 | value_impt_ema = policy_impt.compute_value(xs, ts, use_ema=True) 160 | 161 | log_p = value_ema + value_impt_ema 162 | state_cost = mfg.compute_state_cost(xs, ts, xs_all, log_p) 163 | 164 | # (b * T, ) 165 | bsde2_1step = (kl.sum(dim=-1) - state_cost) * mfg.dt + (zs * dw).sum(dim=-1) 166 | 167 | # (b, T) 168 | bsde2_1step = util.unflatten_dim01(bsde2_1step, dim01) 169 | 170 | # (b * T, ) 171 | value = policy.compute_value(xs, ts) 172 | 173 | # (b, T, nx) 174 | xs = util.unflatten_dim01(xs, dim01) 175 | 176 | t_bnd, x_bnd, p_bnd = mfg.get_bound(xs, samp_direction, "init") 177 | if samp_direction == "forward": 178 | with torch.no_grad(): 179 | # (b, ) 180 | value_impt_bnd = util.unflatten_dim01(value_impt_ema, dim01)[:, 0] 181 | # (b, ) 182 | target_bnd = p_bnd.log_prob(x_bnd) - value_impt_bnd 183 | 184 | # (b, ) 185 | bsde2_cumstep = torch.cumsum(bsde2_1step, dim=1) 186 | # (b, T) -> (b, T - 1) 187 | target = (target_bnd[:, None] + bsde2_cumstep)[:, :-1] 188 | # (b, T) -> (b, T - 1) 189 | predict = util.unflatten_dim01(value, dim01)[:, 1:] 190 | else: 191 | with torch.no_grad(): 192 | # (b, ) 193 | value_impt_bnd = util.unflatten_dim01(value_impt_ema, dim01)[:, -1] 194 | # (b, ) 195 | target_bnd = p_bnd.log_prob(x_bnd) - value_impt_bnd 196 | 197 | # Cumulative sum from T to 0. 198 | bsde2_cumstep = rev_cumsum(bsde2_1step, dim=1) 199 | target = (target_bnd[:, None] + bsde2_cumstep)[:, 1:] 200 | predict = util.unflatten_dim01(value, dim01)[:, :-1] 201 | 202 | label = target.detach() 203 | return compute_norm_loss(norm, predict, label, b, mfg.dt, delta=2.0) 204 | 205 | 206 | def compute_bsde_td_loss_singlestep( 207 | mfg: MFG, 208 | samp_direction: str, 209 | ts: torch.Tensor, 210 | xs: torch.Tensor, 211 | zs: torch.Tensor, 212 | dw: torch.Tensor, 213 | kl: torch.Tensor, 214 | policy: MFGPolicy, 215 | policy_impt: MFGPolicy, 216 | xs_all: torch.Tensor, 217 | norm: str = "l2", 218 | ) -> torch.Tensor: 219 | b, T, nx = xs.shape 220 | assert samp_direction in ["forward", "backward"] 221 | 222 | # (1) flattent all input from (b,T, ...) to (b*T, ...) 223 | ts = ts.repeat(xs.shape[0]) 224 | xs = util.flatten_dim01(xs) 225 | dw = util.flatten_dim01(dw) 226 | zs = zs.reshape(*dw.shape) 227 | 228 | value = policy.compute_value(xs, ts) 229 | value_ema = policy.compute_value(xs, ts, use_ema=True) 230 | value_impt_ema = policy_impt.compute_value(xs, ts, use_ema=True) 231 | 232 | # (2) compute state cost (i.e., F in Eq 14) 233 | log_p = value_ema + value_impt_ema 234 | state_cost = mfg.compute_state_cost(xs, ts, xs_all, log_p) 235 | 236 | # (3) construct δY, predicted Y, target Y (from ema) 237 | bsde_1step = (kl.sum(dim=-1) - state_cost) * mfg.dt + (zs * dw).sum(dim=-1) 238 | bsde_predict = value 239 | bsde_target = value_ema 240 | 241 | assert bsde_predict.shape == bsde_1step.shape == bsde_target.shape 242 | 243 | bsde_predict = util.unflatten_dim01(bsde_predict, [b, T]) 244 | bsde_1step = util.unflatten_dim01(bsde_1step, [b, T]) 245 | bsde_target = util.unflatten_dim01(bsde_target, [b, T]) 246 | 247 | # (4) compute TD prediction and target in Eq 14 248 | # [forward] dYhat_t: predict=Yhat_{t+1}, label=Yhat_t + [...] 249 | # [forward] dY_s: predict Y_{s+1}, label=Y_s + [...] 250 | if samp_direction == "forward": 251 | label = (bsde_target + bsde_1step)[:, :-1].detach() 252 | predict = bsde_predict[:, 1:] 253 | elif samp_direction == "backward": 254 | label = (bsde_target + bsde_1step)[:, 1:].detach() 255 | predict = bsde_predict[:, :-1] 256 | else: 257 | raise ValueError(f"samp_direction should be either forward or backward, got {samp_direction}") 258 | 259 | return compute_norm_loss(norm, predict, label, b, mfg.dt) 260 | 261 | 262 | def compute_bsde_td_loss( 263 | opt: Options, 264 | mfg: MFG, 265 | samp_direction: str, 266 | ts: torch.Tensor, 267 | xs: torch.Tensor, 268 | zs: torch.Tensor, 269 | dw: torch.Tensor, 270 | kl: torch.Tensor, 271 | policy: MFGPolicy, 272 | policy_impt: MFGPolicy, 273 | xs_all: torch.Tensor, 274 | ) -> torch.Tensor: 275 | if opt.multistep_td: 276 | bsde_td_loss = compute_bsde_td_loss_multistep 277 | else: 278 | bsde_td_loss = compute_bsde_td_loss_singlestep 279 | return bsde_td_loss(mfg, samp_direction, ts, xs, zs, dw, kl, policy, policy_impt, xs_all) 280 | 281 | 282 | def compute_bsde_td_loss_from_buffer( 283 | opt: Options, mfg, buffer: Buffer, ts: torch.Tensor, policy: MFGPolicy, policy_impt: MFGPolicy, xs_all: torch.Tensor 284 | ) -> torch.Tensor: 285 | samp_direction = policy_impt.direction 286 | assert policy.direction != buffer.direction 287 | assert policy_impt.direction == buffer.direction 288 | 289 | batch_x = opt.rb_bs_x 290 | xs, zs_impt, dw = buffer.sample_traj(batch_x) 291 | 292 | assert xs.shape == zs_impt.shape == (batch_x, opt.interval, opt.x_dim) 293 | 294 | kl, zs = compute_kl(opt, mfg, ts, xs, zs_impt, policy, return_zs=True) 295 | assert kl.shape == zs.shape and kl.shape[0] == (batch_x * opt.interval) 296 | 297 | return compute_bsde_td_loss(opt, mfg, samp_direction, ts, xs, zs, dw, kl, policy, policy_impt, xs_all) 298 | 299 | 300 | def compute_boundary_loss( 301 | opt: Options, mfg: MFG, ts: torch.Tensor, xs: torch.Tensor, policy2: MFGPolicy, policy: MFGPolicy, norm: str = "l2" 302 | ) -> torch.Tensor: 303 | 304 | t_bound, x_bound, p_bound = mfg.get_bound(xs.detach(), policy.direction, "term") 305 | 306 | batch_x, nx = x_bound.shape 307 | 308 | t_bound = t_bound.repeat(batch_x) 309 | 310 | value2_ema = policy2.compute_value(x_bound, t_bound, use_ema=True) 311 | label = (p_bound.log_prob(x_bound) - value2_ema).detach() 312 | predict = policy.compute_value(x_bound, t_bound) 313 | 314 | return compute_norm_loss(norm, predict, label, batch_x, mfg.dt) 315 | 316 | 317 | def compute_grad_loss( 318 | opt: Options, ts: torch.Tensor, xs: torch.Tensor, dyn: SimpleSDE, policy: MFGPolicy, norm: str 319 | ) -> torch.Tensor: 320 | batch_x = xs.shape[0] 321 | ts = ts.repeat(batch_x) 322 | xs = util.flatten_dim01(xs) 323 | 324 | predict = policy(xs, ts) 325 | label = dyn.std * policy.compute_value_grad(xs, ts).detach() 326 | 327 | return compute_norm_loss(norm, predict, label, batch_x, dyn.dt, delta=1.0) 328 | -------------------------------------------------------------------------------- /deepgsb/replay_buffer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from easydict import EasyDict as edict 5 | 6 | from options import Options 7 | 8 | 9 | class Buffer: 10 | def __init__(self, opt: Options, direction: str): 11 | self.opt = opt 12 | self.max_samples = opt.buffer_size 13 | self.direction = direction 14 | 15 | self.nx = opt.x_dim 16 | 17 | self.it = 0 18 | self.n_samples = 0 19 | self.samples = torch.empty(self.max_samples, self.opt.interval, 3 * self.nx, device="cpu") 20 | 21 | def __len__(self) -> int: 22 | return self.n_samples 23 | 24 | def append(self, datas: edict) -> None: 25 | xs, zs, dws = datas.xs, datas.zs, datas.ws 26 | assert xs.shape == zs.shape == dws.shape 27 | 28 | it, max_samples, batch = self.it, self.max_samples, xs.shape[0] 29 | sample = torch.cat([xs, zs, dws], dim=-1).detach().cpu() 30 | self.samples[it : it + batch] = sample[0 : min(batch, max_samples - it), ...] 31 | if batch > max_samples - it: 32 | _it = batch - (max_samples - it) 33 | self.samples[0:_it] = sample[-_it:, ...] 34 | assert ((it + batch) % max_samples) == _it 35 | 36 | self.it = (it + batch) % max_samples 37 | self.n_samples = min(self.n_samples + batch, max_samples) 38 | 39 | def clear(self) -> None: 40 | self.samples = torch.empty_like(self.samples) 41 | self.n_samples = self.it = 0 42 | 43 | def sample_traj(self, batch_x: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 44 | samp_x_idx = torch.randint(self.n_samples, (batch_x,)) 45 | samples = self.samples[samp_x_idx].to(self.opt.device) 46 | 47 | xs = samples[..., 0 * self.nx : 1 * self.nx].detach() # (batch_x, T, nx) 48 | zs = samples[..., 1 * self.nx : 2 * self.nx].detach() # (batch_x, T, nx) 49 | ws = samples[..., 2 * self.nx : 3 * self.nx].detach() # (batch_x, T, nx) 50 | 51 | return xs, zs, ws 52 | -------------------------------------------------------------------------------- /deepgsb/sb_policy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple, TypeVar 3 | 4 | import torch 5 | from torch_ema import ExponentialMovingAverage 6 | 7 | from mfg import MFGPolicy 8 | from mfg.sde import SimpleSDE 9 | from models import build_opinion_net_policy, build_toy_net_policy 10 | from options import Options 11 | 12 | from . import util 13 | 14 | log = logging.getLogger(__file__) 15 | 16 | _TorchModule = TypeVar("_TorchModule", bound=torch.nn.Module) 17 | 18 | 19 | def build(opt: Options, dyn: SimpleSDE, direction: str) -> MFGPolicy: 20 | log.info("build {} SB model...".format(direction)) 21 | 22 | # ------ build SB policy ------ 23 | Ynet = _build_net(opt, "Y") 24 | 25 | if opt.sb_param == "critic": 26 | policy = SB_paramY(opt, direction, dyn, Ynet).to(opt.device) 27 | elif opt.sb_param == "actor-critic": 28 | Znet = _build_net(opt, "Z") 29 | policy = SB_paramYZ(opt, direction, dyn, Ynet, Znet).to(opt.device) 30 | else: 31 | raise RuntimeError(f"unknown sb net type {opt.sb_param}") 32 | 33 | log.info("# param in SBPolicy = {}".format(util.count_parameters(policy))) 34 | 35 | return policy 36 | 37 | 38 | def _build_net(opt: Options, YorZ: str) -> torch.nn.Module: 39 | assert YorZ in ["Y", "Z"] 40 | 41 | if opt.policy_net == "toy": 42 | net = build_toy_net_policy(opt, YorZ) 43 | elif opt.policy_net == "opinion_net": 44 | net = build_opinion_net_policy(opt, YorZ) 45 | else: 46 | raise RuntimeError() 47 | return net 48 | 49 | 50 | def _freeze(net: _TorchModule) -> _TorchModule: 51 | for p in net.parameters(): 52 | p.requires_grad = False 53 | return net 54 | 55 | 56 | def _activate(net: _TorchModule) -> _TorchModule: 57 | for p in net.parameters(): 58 | p.requires_grad = True 59 | return net 60 | 61 | 62 | class SchrodingerBridgeModel(MFGPolicy): 63 | def __init__(self, opt: Options, direction: str, dyn: SimpleSDE, use_t_idx: bool = True): 64 | super(SchrodingerBridgeModel, self).__init__(direction, dyn) 65 | self.opt = opt 66 | self.use_t_idx = use_t_idx 67 | self.g = opt.diffusion_std 68 | 69 | def _preprocessing(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 70 | # make sure t.shape = [batch] 71 | t = t.squeeze() 72 | if t.dim() == 0: 73 | t = t.repeat(x.shape[0]) 74 | assert t.dim() == 1 and t.shape[0] == x.shape[0] 75 | 76 | if self.use_t_idx: 77 | t = t / self.opt.T * self.opt.interval 78 | return x, t 79 | 80 | def compute_value(self, x: torch.Tensor, t: torch.Tensor, use_ema: bool = False) -> torch.Tensor: 81 | raise NotImplementedError() 82 | 83 | def compute_policy(self, x, t) -> torch.Tensor: 84 | raise NotImplementedError() 85 | 86 | def forward(self, x, t): # set default calling to policy 87 | return self.compute_policy(x, t) 88 | 89 | def freeze(self): 90 | self.eval() 91 | self.zero_grad() 92 | 93 | def activate(self): 94 | self.train() 95 | 96 | 97 | class SB_paramY(SchrodingerBridgeModel): 98 | def __init__(self, opt: Options, direction: str, dyn: SimpleSDE, Ynet: torch.nn.Module): 99 | super(SB_paramY, self).__init__(opt, direction, dyn, use_t_idx=True) 100 | self.Ynet = Ynet 101 | self.emaY = ExponentialMovingAverage(self.Ynet.parameters(), decay=opt.ema) 102 | self.param = "critic" 103 | 104 | def compute_value(self, x: torch.Tensor, t: torch.Tensor, use_ema: bool = False) -> torch.Tensor: 105 | x, t = self._preprocessing(x, t) 106 | if use_ema: 107 | with self.emaY.average_parameters(): 108 | return self._standardizeYnet(self.Ynet(x, t).squeeze()) 109 | return self._standardizeYnet(self.Ynet(x, t).squeeze()) 110 | 111 | def compute_value_grad(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 112 | x, t = self._preprocessing(x, t) 113 | requires_grad = x.requires_grad 114 | with torch.enable_grad(): 115 | x.requires_grad_(True) 116 | y = self._standardizeYnet(self.Ynet(x, t)) 117 | out = torch.autograd.grad(y.sum(), x, create_graph=self.training)[0] 118 | x.requires_grad_(requires_grad) # restore original setup 119 | if not self.training: 120 | self.zero_grad() # out = out.detach() 121 | return out 122 | 123 | def compute_policy(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 124 | # Z = g * nabla Y 125 | return self.g * self.compute_value_grad(x, t) 126 | 127 | def _standardizeYnet(self, out: torch.Tensor) -> torch.Tensor: 128 | # standardize the Ynet output 129 | return out / self.g 130 | 131 | def freeze(self) -> None: 132 | self.Ynet = _freeze(self.Ynet) 133 | super(SB_paramY, self).freeze() 134 | 135 | def activate(self) -> None: 136 | self.Ynet = _activate(self.Ynet) 137 | super(SB_paramY, self).activate() 138 | 139 | def get_ema(self) -> ExponentialMovingAverage: 140 | return self.emaY 141 | 142 | def update_ema(self) -> None: 143 | self.emaY.update() 144 | 145 | def state_dict(self): 146 | return { 147 | "Ynet": self.Ynet.state_dict(), 148 | "emaY": self.emaY.state_dict(), 149 | } 150 | 151 | def load_state_dict(self, state_dict) -> None: 152 | self.Ynet.load_state_dict(state_dict["Ynet"]) 153 | self.emaY.load_state_dict(state_dict["emaY"]) 154 | 155 | 156 | class SB_paramYZ(SB_paramY): 157 | def __init__(self, opt: Options, direction: str, dyn: SimpleSDE, Ynet: torch.nn.Module, Znet: torch.nn.Module): 158 | super(SB_paramYZ, self).__init__(opt, direction, dyn, Ynet) 159 | self.Znet = Znet 160 | self.emaZ = ExponentialMovingAverage(self.Znet.parameters(), decay=opt.ema) 161 | self.param = "actor-critic" 162 | 163 | def compute_policy(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 164 | x, t = self._preprocessing(x, t) 165 | return self.Znet(x, t) 166 | 167 | def freeze(self) -> None: 168 | self.Znet = _freeze(self.Znet) 169 | super(SB_paramYZ, self).freeze() 170 | 171 | def activate(self, only_Ynet: bool = False) -> None: 172 | self.Znet = _activate(self.Znet) if not only_Ynet else _freeze(self.Znet) 173 | super(SB_paramYZ, self).activate() 174 | 175 | def get_ema(self) -> ExponentialMovingAverage: 176 | return self.emaZ 177 | 178 | def update_ema(self, only_Ynet: bool = False) -> None: 179 | if not only_Ynet: 180 | self.emaZ.update() 181 | super(SB_paramYZ, self).update_ema() 182 | 183 | def state_dict(self): 184 | sdict = { 185 | "Znet": self.Znet.state_dict(), 186 | "emaZ": self.emaZ.state_dict(), 187 | } 188 | sdict.update(super(SB_paramYZ, self).state_dict()) 189 | return sdict 190 | 191 | def load_state_dict(self, state_dict): 192 | self.Znet.load_state_dict(state_dict["Znet"]) 193 | self.emaZ.load_state_dict(state_dict["emaZ"]) 194 | super(SB_paramYZ, self).load_state_dict(state_dict) 195 | -------------------------------------------------------------------------------- /deepgsb/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import TYPE_CHECKING, Iterable, Tuple 3 | 4 | import termcolor 5 | import torch 6 | from easydict import EasyDict as edict 7 | 8 | if TYPE_CHECKING: 9 | from deepgsb.deepgsb import DeepGSB 10 | from options import Options 11 | 12 | log = logging.getLogger(__file__) 13 | 14 | 15 | # convert to colored strings 16 | def red(content): return termcolor.colored(str(content),"red",attrs=["bold"]) 17 | def green(content): return termcolor.colored(str(content),"green",attrs=["bold"]) 18 | def blue(content): return termcolor.colored(str(content),"blue",attrs=["bold"]) 19 | def cyan(content): return termcolor.colored(str(content),"cyan",attrs=["bold"]) 20 | def yellow(content): return termcolor.colored(str(content),"yellow",attrs=["bold"]) 21 | def magenta(content): return termcolor.colored(str(content),"magenta",attrs=["bold"]) 22 | 23 | def count_parameters(model: torch.nn.Module): 24 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 25 | 26 | def evaluate_stage(opt, stage): 27 | """ Determine what metrics to evaluate for the current stage, 28 | if metrics is None, use the frequency in opt to decide it. 29 | """ 30 | match = lambda freq: (freq>0 and stage%freq==0) 31 | return [match(opt.snapshot_freq), match(opt.ckpt_freq)] 32 | 33 | def get_time(sec: float) -> Tuple[int, int, float]: 34 | h = int(sec//3600) 35 | m = int((sec//60)%60) 36 | s = sec%60 37 | return h, m, s 38 | 39 | def flatten_dim01(x: torch.Tensor) -> torch.Tensor: 40 | # (dim0, dim1, *dim2) --> (dim0x1, *dim2) 41 | return x.reshape(-1, *x.shape[2:]) 42 | 43 | def unflatten_dim01(x: torch.Tensor, dim01) -> torch.Tensor: 44 | # (dim0x1, *dim2) --> (dim0, dim1, *dim2) 45 | return x.reshape(*dim01, *x.shape[1:]) 46 | 47 | def restore_checkpoint(opt: "Options", runner: "DeepGSB", load_name: str) -> None: 48 | assert load_name is not None 49 | log.info("#loading checkpoint {}...".format(load_name)) 50 | 51 | full_keys = ['z_f','optimizer_f','z_b','optimizer_b'] 52 | 53 | with torch.cuda.device(opt.gpu): 54 | checkpoint = torch.load(load_name,map_location=opt.device) 55 | ckpt_keys=[*checkpoint.keys()] 56 | 57 | for k in ckpt_keys: 58 | thing = getattr(runner,k) 59 | if hasattr(thing, "load_state_dict"): 60 | getattr(runner,k).load_state_dict(checkpoint[k]) 61 | else: 62 | setattr(runner, k, checkpoint[k]) 63 | 64 | if len(full_keys) != len(ckpt_keys): 65 | missing_keys = { k for k in set(full_keys) - set(ckpt_keys) } 66 | extra_keys = {k for k in set(ckpt_keys) - set(full_keys)} 67 | 68 | if len(missing_keys) > 0: 69 | log.warning("Does not load model for {}, check is it correct".format(missing_keys)) 70 | else: 71 | log.warning("Loaded extra keys not found in full_keys: {}".format(extra_keys)) 72 | 73 | else: 74 | log.info('#successfully loaded all the modules') 75 | 76 | # runner.ema_f.copy_to() 77 | # runner.ema_b.copy_to() 78 | # print(green('#loading form ema shadow parameter for polices')) 79 | log.info("#######summary of checkpoint##########") 80 | 81 | def save_checkpoint(opt: "Options", runner: "DeepGSB", keys: Iterable[str], stage_it: int) -> None: 82 | checkpoint = {} 83 | fn = opt.ckpt_path + "/stage_{0:04}.npz".format(stage_it) 84 | with torch.cuda.device(opt.gpu): 85 | for k in keys: 86 | variable = getattr(runner, k) 87 | if hasattr(variable, "state_dict"): 88 | checkpoint[k] = variable.state_dict() 89 | else: 90 | checkpoint[k] = variable 91 | 92 | torch.save(checkpoint, fn) 93 | print(green("checkpoint saved: {}".format(fn))) 94 | 95 | def get_loss_str(loss) -> str: 96 | if isinstance(loss, edict): 97 | return ' '.join([ f'({key})' + f'{val.item():+2.3f}' for key, val in loss.items()]) 98 | else: 99 | return f'{loss.item():+.4f}' 100 | 101 | 102 | -------------------------------------------------------------------------------- /git_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | import subprocess 4 | 5 | log = logging.getLogger(__name__) 6 | 7 | 8 | def git_describe() -> str: 9 | label = subprocess.check_output(["git", "describe", "--always", "HEAD"]).strip().decode("utf-8") 10 | return label 11 | 12 | 13 | def git_clean() -> bool: 14 | git_output = subprocess.check_output(["git", "status", "--porcelain"]).strip().decode("utf-8") 15 | is_clean = len(git_output) == 0 16 | 17 | return is_clean 18 | 19 | 20 | def git_diff() -> str: 21 | diff = subprocess.check_output(["git", "diff", "HEAD"]).strip().decode("utf-8") 22 | if len(diff) == 0 or diff[-1] != "\n": 23 | diff = diff + "\n" 24 | return diff 25 | 26 | 27 | def log_git_info(run_dir: pathlib.Path) -> None: 28 | label = git_describe() 29 | if git_clean(): 30 | log.info("Working tree is clean! HEAD is {}".format(label)) 31 | return 32 | 33 | diff_path = run_dir / "diff.patch" 34 | log.warning("Continuing with dirty working tree! HEAD is {}".format(label)) 35 | log.warning("Saving the results of git diff to {}".format(diff_path)) 36 | with open(diff_path, "w") as f: 37 | f.write(git_diff()) 38 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import logging 4 | import pathlib 5 | import sys 6 | 7 | import colored_traceback.always 8 | import torch 9 | 10 | import options 11 | from git_utils import log_git_info 12 | from deepgsb import DeepGSB 13 | from mfg import MFG 14 | 15 | from rich.console import Console 16 | from rich.logging import RichHandler 17 | from options import Options 18 | 19 | def setup_logger(log_dir: pathlib.Path) -> None: 20 | log_dir.mkdir(exist_ok=True, parents=True) 21 | 22 | log_file = open(log_dir / "log.txt", "w") 23 | file_console = Console(file=log_file, width=150) 24 | logging.basicConfig( 25 | level=logging.INFO, 26 | format="%(message)s", 27 | datefmt="[%X]", 28 | force=True, 29 | handlers=[RichHandler(), RichHandler(console=file_console)], 30 | ) 31 | 32 | def run(opt: Options): 33 | log = logging.getLogger(__name__) 34 | log.info("=======================================================") 35 | log.info(" Deep Generalized Schrodinger Bridge ") 36 | log.info("=======================================================") 37 | log.info("Command used:\n{}".format(" ".join(sys.argv))) 38 | 39 | mfg = MFG(opt) 40 | deepgsb = DeepGSB(opt, mfg) 41 | deepgsb.train(opt) 42 | 43 | def main(): 44 | print("setting configurations...") 45 | opt = options.set() 46 | 47 | run_dir = pathlib.Path("results") / opt.dir 48 | setup_logger(run_dir) 49 | log_git_info(run_dir) 50 | 51 | if not opt.cpu: 52 | with torch.cuda.device(opt.gpu): 53 | run(opt) 54 | else: 55 | run(opt) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /make_animation.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import logging 4 | import pathlib 5 | import pickle 6 | import sys 7 | import os 8 | import argparse 9 | 10 | import ipdb 11 | import torch 12 | import numpy as np 13 | 14 | from mfg import MFG 15 | from deepgsb import DeepGSB 16 | 17 | from rich.logging import RichHandler 18 | 19 | import matplotlib.pyplot as plt 20 | 21 | import imageio 22 | from mfg.plotting import * 23 | 24 | from mfg.opinion_lib import est_directional_similarity 25 | 26 | from ipdb import set_trace as debug 27 | 28 | def setup_logger() -> None: 29 | logging.basicConfig( 30 | level=logging.INFO, 31 | format="%(message)s", 32 | datefmt="[%X]", 33 | force=True, 34 | handlers=[RichHandler(),], 35 | ) 36 | 37 | def restore_ckpt_option(opt): 38 | assert opt.load is not None 39 | ckpt_path = pathlib.Path(opt.load) 40 | assert ckpt_path.exists() 41 | 42 | options_pkl_path = ckpt_path.parent / "options.pkl" 43 | assert options_pkl_path.exists() 44 | 45 | # Load options pkl and overwrite the load. 46 | with open(options_pkl_path, "rb") as f: 47 | ckpt_options = pickle.load(f) 48 | ckpt_options.load = opt.load 49 | 50 | return ckpt_options 51 | 52 | def build_steps(direction, interval, total_steps=100): 53 | steps = np.linspace(0, interval-1, total_steps).astype(int) 54 | if direction == "backward": 55 | steps = np.flip(steps) 56 | return steps 57 | 58 | def get_title(opt, direction): 59 | return { 60 | "GMM": "GMM", 61 | "Stunnel": "S-tunnel", 62 | "Vneck": "V-neck", 63 | "opinion": "Opinion", 64 | "opinion_1k": "Opinion 1k", 65 | }.get(opt.problem_name) + f" ({direction} policy)" 66 | 67 | @torch.no_grad() 68 | def plot_directional_sim(opt, xs, ax) -> None: 69 | 70 | n_est = 5000 71 | directional_sim = est_directional_similarity(xs, n_est) 72 | assert directional_sim.shape == (n_est, ) 73 | 74 | directional_sim = to_numpy(directional_sim) 75 | 76 | bins = 15 77 | _, _, patches = ax.hist(directional_sim, bins=bins, ) 78 | 79 | colors = plt.cm.coolwarm(np.linspace(1.0, 0.0, bins)) 80 | 81 | for c, p in zip(colors, patches): 82 | plt.setp(p, 'facecolor', c) 83 | 84 | ymax = 1000 if opt.x_dim == 2 else 2000 85 | ax.set_ylim(0, ymax) 86 | ax.set_xlim(0, 1) 87 | ax.set_xticks([]) 88 | ax.set_yticks([]) 89 | ax.set_xticks([], minor=True) 90 | ax.set_yticks([], minor=True) 91 | 92 | @torch.no_grad() 93 | def make_gif(opt, policy_f, policy_b, mfg, gif_name=None, plot_dim=[0,1]): 94 | 95 | file_path = os.path.join(".tmp", opt.group, opt.name) 96 | os.makedirs(file_path, exist_ok=True) 97 | 98 | xs_f, xs_b, xs_f_np, xs_b_np = sample_traj(opt, mfg, mfg.ts, policy_f, policy_b, plot_dim) 99 | 100 | xlims = get_lims(opt) 101 | ylims = get_ylims(opt) 102 | 103 | filenames = [] 104 | for xs, xs_np, policy in zip([xs_f, xs_b], [xs_f_np, xs_b_np], [policy_f, policy_b]): 105 | 106 | if "opinion" in opt.problem_name and policy.direction == "backward": 107 | # skip backward opinion due to the mean-field drift 108 | continue 109 | 110 | y_mesher = get_func_mesher(opt, mfg.ts, 200, policy.compute_value) if opt.x_dim == 2 else None 111 | 112 | colors = get_colors(xs_np.shape[1]) 113 | title = get_title(opt, policy.direction) 114 | # title = "Polarize (before apply DeepGSB)" 115 | # title = "Depolarize (after apply DeepGSB)" 116 | 117 | steps = build_steps(policy.direction, xs_np.shape[1], total_steps=100) 118 | for step in steps: 119 | # prepare plotting 120 | fig = plt.figure(figsize=(3,3), constrained_layout=True) 121 | ax = fig.subplots(1, 1) 122 | 123 | # plot policy and value 124 | plot_obs(opt, ax, zorder=0) 125 | ax.scatter(xs_np[:,step,0], xs_np[:,step,1], s=1.5, color=colors[step], alpha=0.5, zorder=1) 126 | if y_mesher is not None: 127 | cp = ax.contour(*y_mesher(step), levels=10, cmap="copper", linewidths=1, zorder=2) 128 | ax.clabel(cp, inline=True, fontsize=6) 129 | setup_ax(ax, title, xlims, ylims, title_fontsize=12) 130 | 131 | if "opinion" in opt.problem_name: 132 | axins = ax.inset_axes([0.59, 0.01, 0.4, 0.4]) 133 | plot_directional_sim(opt, xs[:,step], axins) 134 | axins.text( 135 | 0.5, 0.9, r"$t$=" + f"{step/xs.shape[1]:0.2f}" + r"$T$", 136 | transform=axins.transAxes, fontsize=7, ha='center', va='center' 137 | ) 138 | 139 | # save fig 140 | filename = f"{file_path}/{policy.direction}_{str(step).zfill(3)}.png" 141 | filenames.append(filename) 142 | plt.savefig(filename) 143 | plt.close(fig) 144 | 145 | # build gif 146 | images = list(map(lambda filename: imageio.imread(filename), filenames)) 147 | imageio.mimsave(f'{gif_name or opt.problem_name}.gif', images, duration=0.04) # modify the frame duration as needed 148 | 149 | # Remove files 150 | for filename in set(filenames): 151 | os.remove(filename) 152 | 153 | def run(ckpt_options, gif_name=None): 154 | mfg = MFG(ckpt_options) 155 | deepgsb = DeepGSB(ckpt_options, mfg, save_opt=False) 156 | make_gif(ckpt_options, deepgsb.z_f, deepgsb.z_b, mfg, gif_name=gif_name) 157 | 158 | 159 | def main(): 160 | parser = argparse.ArgumentParser() 161 | parser.add_argument("--load", type=str) 162 | parser.add_argument("--name", type=str, default=None) 163 | arg = parser.parse_args() 164 | 165 | setup_logger() 166 | log = logging.getLogger(__name__) 167 | log.info("Command used:\n{}".format(" ".join(sys.argv))) 168 | 169 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 170 | ckpt_options = restore_ckpt_option(arg) 171 | 172 | if not ckpt_options.cpu: 173 | with torch.cuda.device(ckpt_options.gpu): 174 | run(ckpt_options, gif_name=arg.name) 175 | else: 176 | run(ckpt_options, gif_name=arg.name) 177 | 178 | if __name__ == "__main__": 179 | with ipdb.launch_ipdb_on_exception(): 180 | main() 181 | -------------------------------------------------------------------------------- /mfg/__init__.py: -------------------------------------------------------------------------------- 1 | from .mfg import MFG, MFGPolicy -------------------------------------------------------------------------------- /mfg/constraint.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import NamedTuple, Optional 3 | 4 | import numpy as np 5 | import torch 6 | import torch.distributions as td 7 | 8 | log = logging.getLogger(__file__) 9 | 10 | 11 | class Sampler: 12 | def __init__(self, distribution: td.Distribution, batch_size: int, device: str): 13 | self.distribution = distribution 14 | self.batch_size = batch_size 15 | self.device = device 16 | 17 | def log_prob(self, x: torch.Tensor) -> torch.Tensor: 18 | return self.distribution.log_prob(x) 19 | 20 | def sample(self, batch: Optional[int] = None) -> torch.Tensor: 21 | if batch is None: 22 | batch = self.batch_size 23 | return self.distribution.sample([batch]).to(self.device) 24 | 25 | 26 | class ProblemDists(NamedTuple): 27 | p0: Sampler 28 | pT: Sampler 29 | 30 | 31 | def build_constraint(problem_name: str, batch_size: int, device: str) -> ProblemDists: 32 | log.info("build distributional constraints ...") 33 | 34 | distribution_builder = { 35 | "GMM": gmm_builder, 36 | "Stunnel": stunnel_builder, 37 | "Vneck": vneck_builder, 38 | "opinion": opinion_builder, 39 | "opinion_1k": opinion_1k_builder, 40 | }.get(problem_name) 41 | 42 | return distribution_builder(batch_size, device) 43 | 44 | 45 | def gmm_builder(batch_size: int, device: str) -> ProblemDists: 46 | 47 | # ----- pT ----- 48 | radius, num = 16, 8 49 | arc = 2 * np.pi / num 50 | xs = [np.cos(arc * idx) * radius for idx in range(num)] 51 | ys = [np.sin(arc * idx) * radius for idx in range(num)] 52 | 53 | mix = td.Categorical( 54 | torch.ones( 55 | num, 56 | ) 57 | ) 58 | comp = td.Independent(td.Normal(torch.Tensor([[x, y] for x, y in zip(xs, ys)]), torch.ones(num, 2)), 1) 59 | dist = td.MixtureSameFamily(mix, comp) 60 | pT = Sampler(dist, batch_size, device) 61 | 62 | # ----- p0 ----- 63 | dist = td.MultivariateNormal(torch.zeros(2), torch.eye(2)) 64 | p0 = Sampler(dist, batch_size, device) 65 | 66 | return ProblemDists(p0, pT) 67 | 68 | 69 | def vneck_builder(batch_size: int, device: str) -> ProblemDists: 70 | 71 | # ----- pT ----- 72 | dist = td.MultivariateNormal(torch.Tensor([7, 0]), 0.2 * torch.eye(2)) 73 | pT = Sampler(dist, batch_size, device) 74 | 75 | # ----- p0 ----- 76 | dist = td.MultivariateNormal(-torch.Tensor([7, 0]), 0.2 * torch.eye(2)) 77 | p0 = Sampler(dist, batch_size, device) 78 | 79 | return ProblemDists(p0, pT) 80 | 81 | 82 | def stunnel_builder(batch_size: int, device: str) -> ProblemDists: 83 | 84 | # ----- pT ----- 85 | dist = td.MultivariateNormal(torch.Tensor([11, 1]), 0.5 * torch.eye(2)) 86 | pT = Sampler(dist, batch_size, device) 87 | 88 | # ----- p0 ----- 89 | dist = td.MultivariateNormal(-torch.Tensor([11, 1]), 0.5 * torch.eye(2)) 90 | p0 = Sampler(dist, batch_size, device) 91 | 92 | return ProblemDists(p0, pT) 93 | 94 | 95 | def opinion_builder(batch_size: int, device: str) -> ProblemDists: 96 | 97 | p0_std = 0.25 98 | pT_std = 3.0 99 | 100 | # ----- p0 ----- 101 | mu0 = torch.zeros(2) 102 | covar0 = p0_std * torch.eye(2) 103 | 104 | # Start with kind-of polarized opinions. 105 | covar0[0, 0] = 0.5 106 | 107 | # ----- pT ----- 108 | muT = torch.zeros(2) 109 | # Want to finish with more homogenous opinions. 110 | covarT = pT_std * torch.eye(2) 111 | 112 | dist = td.MultivariateNormal(muT, covarT) 113 | pT = Sampler(dist, batch_size, device) 114 | 115 | dist = td.MultivariateNormal(mu0, covar0) 116 | p0 = Sampler(dist, batch_size, device) 117 | 118 | return ProblemDists(p0, pT) 119 | 120 | 121 | def opinion_1k_builder(batch_size: int, device: str) -> ProblemDists: 122 | 123 | p0_std = 0.25 124 | pT_std = 3.0 125 | 126 | # ----- p0 ----- 127 | mu0 = torch.zeros(1000) 128 | covar0 = p0_std * torch.eye(1000) 129 | 130 | # Start with kind-of polarized opinions. 131 | covar0[0, 0] = 4.0 132 | 133 | # ----- pT ----- 134 | muT = torch.zeros(1000) 135 | # Want to finish with more homogenous opinions. 136 | covarT = pT_std * torch.eye(1000) 137 | 138 | dist = td.MultivariateNormal(muT, covarT) 139 | pT = Sampler(dist, batch_size, device) 140 | 141 | dist = td.MultivariateNormal(mu0, covar0) 142 | p0 = Sampler(dist, batch_size, device) 143 | 144 | return ProblemDists(p0, pT) 145 | -------------------------------------------------------------------------------- /mfg/mfg.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable, Optional, Tuple 3 | 4 | import torch 5 | 6 | from options import Options 7 | 8 | from .constraint import Sampler, build_constraint 9 | from .plotting import snapshot 10 | from .sde import SampledTraj, build_sde 11 | from .state_cost import build_state_cost_fn 12 | 13 | log = logging.getLogger(__file__) 14 | 15 | Bound = Tuple[torch.Tensor, torch.Tensor, Sampler] 16 | InitTermBound = Tuple[torch.Tensor, torch.Tensor, Sampler, torch.Tensor, torch.Tensor, Sampler] 17 | 18 | # if the state cost for this problem requires logp. 19 | logp_list = [ 20 | "Vneck", 21 | "opinion", 22 | "opinion_1k", 23 | ] 24 | 25 | # if the state cost for this problem requires xs_all. 26 | xs_all_list = ["GMM", "Stunnel", "Vneck"] 27 | 28 | # if the uncontrolled drift use MF 29 | mf_drift_list = ["opinion", "opinion_1k"] 30 | 31 | # if the uncontrolled drift depends on state 32 | state_dependent_drift_list = ["opinion", "opinion_1k"] 33 | 34 | 35 | def get_bound_index(direction: str, bound: str) -> int: 36 | assert direction in ["forward", "backward"] 37 | assert bound in ["init", "term"] 38 | 39 | if direction == "forward" and bound == "init": 40 | return 0 41 | elif direction == "forward" and bound == "term": 42 | return -1 43 | elif direction == "backward" and bound == "init": 44 | return -1 45 | elif direction == "backward" and bound == "term": 46 | return 0 47 | 48 | 49 | class MFGPolicy(torch.nn.Module): 50 | def __init__(self, direction, dyn): 51 | super(MFGPolicy, self).__init__() 52 | self.direction = direction 53 | self.dyn = dyn 54 | 55 | 56 | class MFG: 57 | def __init__(self, opt: Options): 58 | 59 | self.opt = opt 60 | self.problem_name = opt.problem_name 61 | 62 | self.ts = torch.linspace(opt.t0, opt.T, opt.interval) 63 | self.mf_coeff = opt.MF_cost 64 | 65 | self.p0, self.pT = build_constraint(opt.problem_name, opt.samp_bs, opt.device) 66 | self.state_cost_fn, self.obstacle_cost_fn = build_state_cost_fn(opt.problem_name) 67 | 68 | self.pbound = [self.p0, self.pT] 69 | 70 | self.sde = build_sde(opt, self.p0, self.pT) 71 | 72 | @property 73 | def dt(self) -> float: 74 | return self.sde.dt 75 | 76 | def f(self, x: torch.Tensor, t: torch.Tensor, direction: str) -> torch.Tensor: 77 | return self.sde.f(x, t, direction) 78 | 79 | def g(self, t: torch.Tensor) -> torch.Tensor: 80 | return self.sde.g(t) 81 | 82 | def uses_logp(self) -> bool: 83 | """Returns true if the state cost for this problem requires logp.""" 84 | return self.problem_name in logp_list 85 | 86 | def uses_xs_all(self) -> bool: 87 | """Returns true if the state cost for this problem requires xs_all.""" 88 | return self.problem_name in xs_all_list 89 | 90 | def uses_mf_drift(self) -> bool: 91 | """Returns true if the uncontrolled drift for this problem involves mean field.""" 92 | return self.problem_name in mf_drift_list 93 | 94 | def uses_state_dependent_drift(self) -> bool: 95 | """Returns true if the uncontrolled drift for this problem depends on state x.""" 96 | return self.problem_name in state_dependent_drift_list 97 | 98 | def initialize_mf_drift(self, policy: MFGPolicy) -> None: 99 | self.sde.initialize_mf_drift(self.ts, policy) 100 | 101 | def sample_traj(self, policy: MFGPolicy, **kwargs) -> SampledTraj: 102 | return self.sde.sample_traj(self.ts, policy, **kwargs) 103 | 104 | def compute_state_cost( 105 | self, xs: torch.Tensor, ts: torch.Tensor, xs_all: torch.Tensor, log_p: torch.Tensor 106 | ) -> torch.Tensor: 107 | s_cost, mf_cost = self.state_cost_fn(xs, ts, log_p, xs_all) 108 | return s_cost + self.mf_coeff * mf_cost 109 | 110 | def get_bound(self, xs: torch.Tensor, direction: str, bound: str) -> Bound: 111 | t_index = get_bound_index(direction, bound) 112 | 113 | b, T, nx = xs.shape 114 | assert len(self.ts) == T and nx == self.opt.x_dim 115 | 116 | return self.ts[t_index], xs[:, t_index, ...], self.pbound[t_index] 117 | 118 | def get_init_term_bound(self, xs: torch.Tensor, direction: str) -> InitTermBound: 119 | return (*self.get_bound(xs, direction, "init"), *self.get_bound(xs, direction, "term")) 120 | 121 | def save_snapshot(self, policy_f, policy_b, stage: int) -> None: 122 | plot_logp = self.opt.x_dim == 2 123 | snapshot(self.opt, policy_f, policy_b, self, stage, plot_logp=plot_logp) 124 | -------------------------------------------------------------------------------- /mfg/opinion_lib.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from options import Options 9 | 10 | log = logging.getLogger(__file__) 11 | 12 | 13 | @torch.no_grad() 14 | @torch.jit.script 15 | def est_directional_similarity(xs: torch.Tensor, n_est: int = 1000) -> torch.Tensor: 16 | """xs: (batch, nx). Returns (n_est, ) between 0 and 1.""" 17 | # xs: (batch, nx) 18 | batch, nx = xs.shape 19 | 20 | # Center first. 21 | xs = xs - torch.mean(xs, dim=0, keepdim=True) 22 | 23 | rand_idxs1 = torch.randint(batch, [n_est], dtype=torch.long) 24 | rand_idxs2 = torch.randint(batch, [n_est], dtype=torch.long) 25 | 26 | # (n_est, nx) 27 | xs1 = xs[rand_idxs1] 28 | # (n_est, nx) 29 | xs2 = xs[rand_idxs2] 30 | 31 | # Normalize to unit vector. 32 | xs1 /= torch.linalg.norm(xs1, dim=1, keepdim=True) 33 | xs2 /= torch.linalg.norm(xs2, dim=1, keepdim=True) 34 | 35 | # (n_est, ) 36 | cos_angle = torch.sum(xs1 * xs2, dim=1).clip(-1.0, 1.0) 37 | assert cos_angle.shape == (n_est,) 38 | 39 | # Should be in [0, pi). 40 | angle = torch.acos(cos_angle) 41 | assert (0 <= angle).all() 42 | assert (angle <= torch.pi).all() 43 | 44 | D_ij = 1.0 - angle / torch.pi 45 | assert D_ij.shape == (n_est,) 46 | 47 | return D_ij 48 | 49 | 50 | def opinion_thresh(inner: torch.Tensor) -> torch.Tensor: 51 | return 2.0 * (inner > 0) - 1.0 52 | 53 | 54 | @torch.jit.script 55 | def compute_mean_drift_term(mf_x: torch.Tensor, xi: torch.Tensor) -> torch.Tensor: 56 | """Decompose the polarize dynamic Eq (18) in paper into 2 parts for faster computation: 57 | f_polarize(x,p,ξ) 58 | = E_{y~p}[a(x,y,ξ) * bar_y], where a(x,y,ξ) = sign()*sign() 59 | and bar_y = y / |y|^{0.5} 60 | = sign() * E_{y~p}[sign() * bar_y], since sign() is independent of y 61 | = A(x,ξ) * B(p,ξ) 62 | Hence, bar_f_polarize = bar_A(x,ξ) * bar_B(p,ξ) 63 | This function computes only bar_B(p,ξ). 64 | """ 65 | # mf_x: (b, nx), xi: (nx,) 66 | # output: (nx,) 67 | 68 | b, nx = mf_x.shape 69 | assert xi.shape == (nx,) 70 | 71 | mf_x_norm = torch.linalg.norm(mf_x, dim=-1, keepdim=True) 72 | assert torch.all(mf_x_norm > 0.0) 73 | 74 | normalized_mf_x = mf_x / torch.sqrt(mf_x_norm) 75 | assert normalized_mf_x.shape == (b, nx) 76 | 77 | # Compute the mean drift term: 1/J sum_j a(y_j) y_j / sqrt(| y_j |). 78 | mf_agree_j = opinion_thresh(torch.sum(mf_x * xi, dim=-1, keepdim=True)) 79 | assert mf_agree_j.shape == (b, 1) 80 | 81 | mean_drift_term = torch.mean(mf_agree_j * normalized_mf_x, dim=0) 82 | assert mean_drift_term.shape == (nx,) 83 | 84 | mean_drift_term_norm = torch.linalg.norm(mean_drift_term, dim=-1, keepdim=True) 85 | mean_drift_term = mean_drift_term / torch.sqrt(mean_drift_term_norm) 86 | assert mean_drift_term.shape == (nx,) 87 | 88 | return mean_drift_term 89 | 90 | 91 | @torch.jit.script 92 | def opinion_f(x: torch.Tensor, mf_drift: torch.Tensor, xi: torch.Tensor) -> torch.Tensor: 93 | """This function computes the polarize dynamic in Eq (18) by 94 | bar_f_polarize(x,p,ξ) = bar_A(x,ξ) * bar_B(p,ξ) 95 | where bar_B(p,ξ) is pre-computed in func compute_mean_drift_term and passed in as mf_drift. 96 | """ 97 | # x: (b, T, nx), mf_drift: (T, nx), xi: (T, nx) 98 | # out: (b, T, nx) 99 | 100 | b, T, nx = x.shape 101 | assert xi.shape == mf_drift.shape == (T, nx) 102 | 103 | agree_i = opinion_thresh(torch.sum(x * xi, dim=-1, keepdim=True)) 104 | # Make sure we are not dividing by 0. 105 | agree_i[agree_i == 0] = 1.0 106 | 107 | abs_sqrt_agree_i = torch.sqrt(torch.abs(agree_i)) 108 | assert torch.all(abs_sqrt_agree_i > 0.0) 109 | 110 | norm_agree_i = agree_i / abs_sqrt_agree_i 111 | assert norm_agree_i.shape == (b, T, 1) 112 | 113 | f = norm_agree_i * mf_drift 114 | assert f.shape == (b, T, nx) 115 | 116 | return f 117 | 118 | 119 | def build_f_mul(opt: Options) -> torch.Tensor: 120 | # set f_mul with some heuristic so that it doesn't diverge exponentially fast 121 | # and yield bad normalization, since the more polarized the opinion is the faster it will grow 122 | ts = torch.linspace(opt.t0, opt.T, opt.interval) 123 | coeff = 8.0 124 | f_mul = torch.clip(1.0 - torch.exp(coeff * (ts - opt.T)) + 1e-5, min=1e-4, max=1.0) 125 | f_mul = f_mul ** 5.0 126 | return f_mul 127 | 128 | 129 | def build_xis(opt: Options) -> torch.Tensor: 130 | # Generate random unit vectors. 131 | rng = np.random.default_rng(seed=4078213) 132 | xis = rng.standard_normal([opt.interval, opt.x_dim]) 133 | 134 | # Construct a xis that has some degree of "continuous" over time, as a brownian motion. 135 | xi = xis[0] 136 | bm_xis = [xi] 137 | std = 0.4 138 | for t in range(1, opt.interval): 139 | xi = xi - (2.0 * xi) * 0.01 + std * math.sqrt(0.01) * xis[t] 140 | bm_xis.append(xi) 141 | assert len(bm_xis) == xis.shape[0] 142 | 143 | xis = torch.Tensor(np.stack(bm_xis)) 144 | xis /= torch.linalg.norm(xis, dim=-1, keepdim=True) 145 | 146 | # Just safeguard if the self.xis becomes different. 147 | log.info("USING BM XI! xis.sum(): {}".format(torch.sum(xis))) 148 | return xis 149 | -------------------------------------------------------------------------------- /mfg/plotting.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import os 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from . import util 9 | from . import opinion_lib 10 | from .state_cost import gmm_obstacle_cfg , vneck_obstacle_cfg, stunnel_obstacle_cfg 11 | 12 | import matplotlib.pyplot as plt 13 | import matplotlib.colors as mcol 14 | from matplotlib.patches import Circle, Ellipse 15 | 16 | from ipdb import set_trace as debug 17 | 18 | 19 | def to_numpy(t): 20 | return t.detach().cpu().numpy() 21 | 22 | def get_lims(opt): 23 | return { 24 | "GMM": [-16.25, 16.25], 25 | "Stunnel": [-15, 15], 26 | "Vneck": [-10, 10], 27 | "opinion": [-10, 10], 28 | "opinion_1k": [-10, 10], 29 | }.get(opt.problem_name) 30 | 31 | def get_ylims(opt): 32 | return { 33 | "GMM": [-16.25, 16.25], 34 | "Stunnel": [-10, 10], 35 | "Vneck": [-5, 5], 36 | "opinion": [-10, 10], 37 | "opinion_1k": [-10, 10], 38 | }.get(opt.problem_name) 39 | 40 | def get_colors(n_snapshot): 41 | # assert n_snapshot % 2 == 1 42 | cm1 = mcol.LinearSegmentedColormap.from_list("MyCmapName",["b","r"]) 43 | colors = cm1(np.linspace(0.0, 1.0, n_snapshot)) 44 | return colors 45 | 46 | def create_mesh(opt, n_grid, lims, convert_to_numpy=True): 47 | import warnings 48 | 49 | _x = torch.linspace(*(lims+[n_grid])) 50 | 51 | # Suppress warning about indexing arg becoming required. 52 | with warnings.catch_warnings(): 53 | X, Y = torch.meshgrid(_x, _x) 54 | 55 | xs = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=1).to(opt.device) 56 | return [to_numpy(X), to_numpy(Y), xs] if convert_to_numpy else [X, Y, xs] 57 | 58 | def get_func_mesher(opt, ts, grid_n, func, out_dim=1): 59 | if func is None: return None 60 | lims = get_lims(opt) 61 | X1, X2, XS = create_mesh(opt, grid_n, lims) 62 | out_shape = [grid_n,grid_n] if out_dim==1 else [grid_n,grid_n,out_dim] 63 | 64 | def mesher(idx): 65 | # print(func, ts[idx]) 66 | arg_xs = XS.detach() 67 | arg_ts = ts[idx].repeat(grid_n**2).detach() 68 | fn_out = func(arg_xs, arg_ts) 69 | return X1, X2, to_numpy(fn_out.reshape(*out_shape)) 70 | 71 | return mesher 72 | 73 | def plot_obs(opt, ax, scale=1., zorder=0): 74 | if opt.problem_name == 'GMM': 75 | centers, radius = gmm_obstacle_cfg() 76 | for c in centers: 77 | circle = Circle(xy=np.array(c), radius=radius, zorder=zorder) 78 | 79 | ax.add_artist(circle) 80 | circle.set_clip_box(ax.bbox) 81 | circle.set_facecolor("darkgray") 82 | circle.set_edgecolor(None) 83 | 84 | elif opt.problem_name == 'Vneck': 85 | c_sq, coef = vneck_obstacle_cfg() 86 | x = np.linspace(-6,6,100) 87 | y1 = np.sqrt(c_sq + coef * np.square(x)) 88 | y2 = np.ones_like(x) * y1[0] 89 | 90 | ax.fill_between(x, y1, y2, color="darkgray", edgecolor=None, zorder=zorder) 91 | ax.fill_between(x, -y1, -y2, color="darkgray", edgecolor=None, zorder=zorder) 92 | 93 | elif opt.problem_name == 'Stunnel': 94 | a, b, cc, centers = stunnel_obstacle_cfg() 95 | for c in centers: 96 | elp = Ellipse( 97 | xy=np.array(c)*scale, width=2*np.sqrt(cc/a)*scale, height=2*np.sqrt(cc/b)*scale, zorder=zorder 98 | ) 99 | 100 | ax.add_artist(elp) 101 | elp.set_clip_box(ax.bbox) 102 | elp.set_facecolor("darkgray") 103 | elp.set_edgecolor(None) 104 | 105 | def setup_ax(ax, title, xlims, ylims, title_fontsize=18): 106 | ax.axis('equal') 107 | ax.set_xlim(*xlims) 108 | ax.set_ylim(*ylims) 109 | ax.set_title(title, fontsize=title_fontsize) 110 | ax.set_xticks([]) 111 | ax.set_yticks([]) 112 | ax.set_xticks([], minor=True) 113 | ax.set_yticks([], minor=True) 114 | 115 | def plot_traj_snapshot(opt, xs, axes, sample_steps, titles, y_mesher=None): 116 | 117 | n_snapshot = len(axes) 118 | assert len(sample_steps) == len(titles) == n_snapshot 119 | 120 | if sample_steps is None: 121 | sample_steps = np.linspace(0, xs.shape[1]-1, n_snapshot).astype(int) 122 | 123 | xlims = get_lims(opt) 124 | ylims = get_ylims(opt) 125 | 126 | colors = get_colors(n_snapshot) 127 | 128 | for ax, step, title, color in zip(axes, sample_steps, titles, colors): 129 | plot_obs(opt, ax, zorder=0) 130 | 131 | ax.scatter(xs[:,step,0],xs[:,step,1], s=1.5, color=color, alpha=0.5, zorder=1) 132 | if y_mesher is not None: 133 | cp = ax.contour(*y_mesher(step), levels=10, cmap="copper", linewidths=1, zorder=2) 134 | ax.clabel(cp, inline=True, fontsize=6) 135 | setup_ax(ax, title, xlims, ylims) 136 | 137 | def plot_directional_sim(opt, ax, stage, xs_term) -> None: 138 | n_est = 5000 139 | directional_sim = opinion_lib.est_directional_similarity(xs_term, n_est) 140 | assert directional_sim.shape == (n_est, ) 141 | 142 | directional_sim = to_numpy(directional_sim) 143 | 144 | bins = 100 145 | ax.hist(directional_sim, bins=bins) 146 | ax.set(xlabel="Directional Similarity", title="Stage={:3}".format(stage), xlim=(0., 1.)) 147 | 148 | 149 | def get_fig_axes_steps(interval, n_snapshot=5, ax_length_in: float = 4): 150 | n_row, n_col = 1, n_snapshot 151 | figsize = (n_col*ax_length_in, n_row*ax_length_in) 152 | 153 | fig = plt.figure(figsize=figsize, constrained_layout=True) 154 | axes = fig.subplots(n_row, n_col) 155 | steps = np.linspace(0, interval-1, n_snapshot).astype(int) 156 | 157 | return fig, axes, steps 158 | 159 | @torch.no_grad() 160 | def sample_traj(opt, mfg, ts, policy_f, policy_b, plot_dim): 161 | xs_f, _, _, _ = mfg.sde.sample_traj(ts, policy_f) 162 | xs_b, _, _, _ = mfg.sde.sample_traj(ts, policy_b) 163 | util.assert_zero_grads(policy_f) 164 | util.assert_zero_grads(policy_b) 165 | 166 | if opt.x_dim > 2: 167 | xs_f, xs_b = util.proj_pca(xs_f, xs_b, reverse=False) 168 | 169 | xs_f_np, xs_b_np = to_numpy(xs_f[..., plot_dim]), to_numpy(xs_b[..., plot_dim]) 170 | 171 | return xs_f, xs_b, xs_f_np, xs_b_np 172 | 173 | @torch.no_grad() 174 | def snapshot(opt, policy_f, policy_b, mfg, stage, plot_logp=False, plot_dim=[0,1]): 175 | 176 | # sample forward & backward trajs 177 | ts = mfg.ts 178 | 179 | xs_f, xs_b, xs_f_np, xs_b_np = sample_traj(opt, mfg, ts, policy_f, policy_b, plot_dim) 180 | 181 | interval = len(ts) 182 | 183 | for xs, policy in zip([xs_f_np, xs_b_np], [policy_f, policy_b]): 184 | 185 | # prepare plotting 186 | titles = [r'$t$ = 0', r'$t$ = 0.25$T$', r'$t$ = 0.50$T$', r'$t$ = 0.75$T$', r'$t = T$'] 187 | fig, axes, sample_steps = get_fig_axes_steps(interval, n_snapshot=len(titles)) 188 | assert len(titles) == len(axes) == len(sample_steps) 189 | 190 | # plot policy and value 191 | y_mesher = get_func_mesher(opt, ts, 200, policy.compute_value) if opt.x_dim == 2 else None 192 | plot_traj_snapshot( 193 | opt, xs, axes, sample_steps, titles, y_mesher=y_mesher, 194 | ) 195 | 196 | plt.savefig(os.path.join('results', opt.dir, policy.direction, f'stage{stage}.pdf')) 197 | plt.close(fig) 198 | 199 | if "opinion" in opt.problem_name: 200 | for xs_term, policy in zip([xs_f[:,-1], xs_b[:,0]], [policy_f, policy_b]): 201 | fig, ax = plt.subplots(figsize=(8, 4), constrained_layout=True) 202 | plot_directional_sim(opt, ax, stage, xs_term) 203 | plt.savefig(os.path.join( 204 | 'results', opt.dir, f"directional_sim_{policy.direction}", f'stage{stage}.pdf' 205 | )) 206 | plt.close(fig) 207 | 208 | if plot_logp: 209 | assert opt.x_dim == 2 210 | def logp_fn(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 211 | y_f = policy_f.compute_value(x, t) 212 | y_b = policy_b.compute_value(x, t) 213 | return y_f + y_b 214 | 215 | # prepare plotting 216 | titles = [r'$t$ = 0', r'$t$ = 0.25$T$', r'$t$ = 0.50$T$', r'$t$ = 0.75$T$', r'$t = T$'] 217 | fig, axes, sample_steps = get_fig_axes_steps(interval, n_snapshot=len(titles)) 218 | assert len(titles) == len(axes) == len(sample_steps) 219 | 220 | # plot logp 221 | logp_mesher = get_func_mesher(opt, ts, 200, logp_fn) 222 | plot_traj_snapshot( 223 | opt, xs_f_np, axes, sample_steps, titles, y_mesher=logp_mesher, 224 | ) 225 | 226 | plt.savefig(os.path.join('results', opt.dir, "logp", f'stage{stage}.pdf')) 227 | plt.close(fig) 228 | 229 | return xs_f, xs_b 230 | -------------------------------------------------------------------------------- /mfg/sde.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | import math 4 | from typing import Callable, Optional, NamedTuple 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from options import Options 10 | 11 | from . import opinion_lib, util 12 | from .constraint import Sampler 13 | 14 | log = logging.getLogger(__file__) 15 | 16 | DriftFn = Callable[[torch.Tensor], torch.Tensor] 17 | 18 | class SampledTraj(NamedTuple): 19 | xs: torch.Tensor 20 | zs: torch.Tensor 21 | ws: torch.Tensor 22 | x_term: torch.Tensor 23 | 24 | def _assert_increasing(name: str, ts: torch.Tensor) -> None: 25 | assert (ts[1:] > ts[:-1]).all(), "{} must be strictly increasing".format(name) 26 | 27 | 28 | def base_drift_builder(opt: Options) -> Optional[DriftFn]: 29 | if opt.problem_name in ["Vneck", "Stunnel"]: 30 | 31 | def base_drift(x): 32 | b, T, nx = x.shape 33 | const = torch.Tensor([6.0, 0.0], device=x.device) 34 | assert const.shape == (nx,) 35 | return const.repeat(b, T, 1) 36 | 37 | else: 38 | base_drift = None 39 | return base_drift 40 | 41 | 42 | def t_to_idx(t: torch.Tensor, interval: int, T: float) -> torch.Tensor: 43 | return (t / T * (interval - 1)).round().long() 44 | 45 | 46 | class BaseSDE(metaclass=abc.ABCMeta): 47 | def __init__(self, opt: Options, p0: Sampler, pT: Sampler): 48 | self.opt = opt 49 | self.dt = opt.T / opt.interval 50 | self.p0 = p0 51 | self.pT = pT 52 | self.mf_drifts: Optional[torch.Tensor] = None 53 | 54 | @abc.abstractmethod 55 | def _f(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 56 | raise NotImplementedError 57 | 58 | @abc.abstractmethod 59 | def _g(self, t: torch.Tensor) -> torch.Tensor: 60 | raise NotImplementedError 61 | 62 | def f(self, x: torch.Tensor, t: torch.Tensor, direction: str) -> torch.Tensor: 63 | # x: (b, T, nx), t: (T,) 64 | # out: (b, T, nx) 65 | b, T, nx = x.shape 66 | assert t.shape == (T,) 67 | 68 | sign = 1.0 if direction == "forward" else -1.0 69 | _f = self._f(x, t) 70 | assert _f.shape == (b, T, nx) 71 | return sign * _f 72 | 73 | def g(self, t: torch.Tensor) -> torch.Tensor: 74 | return self._g(t) 75 | 76 | def dw(self, x: torch.Tensor, dt: Optional[float] = None) -> torch.Tensor: 77 | dt = self.dt if dt is None else dt 78 | return torch.randn_like(x) * np.sqrt(dt) 79 | 80 | def propagate( 81 | self, 82 | t: torch.Tensor, 83 | x: torch.Tensor, 84 | z: torch.Tensor, 85 | direction: str, 86 | f: Optional[DriftFn] = None, 87 | dw: Optional[torch.Tensor] = None, 88 | dt: Optional[float] = None, 89 | ) -> torch.Tensor: 90 | g = self.g(t) 91 | f = self.f(x, t, direction) if f is None else f 92 | dt = self.dt if dt is None else dt 93 | dw = self.dw(x, dt) if dw is None else dw 94 | 95 | return x + (f + g * z) * dt + g * dw 96 | 97 | @torch.no_grad() 98 | def initialize_mf_drift(self, ts: torch.Tensor, policy) -> None: 99 | raise NotImplementedError 100 | 101 | @torch.no_grad() 102 | def update_mf_drift(self, x: torch.Tensor, t_idx: int) -> None: 103 | raise NotImplementedError 104 | 105 | def sample_traj(self, ts: torch.Tensor, policy, update_mf_drift: bool = False) -> SampledTraj: 106 | 107 | # first we need to know whether we're doing forward or backward sampling 108 | direction = policy.direction 109 | assert direction in ["forward", "backward"] 110 | 111 | # set up ts and init_distribution 112 | _assert_increasing("ts", ts) 113 | init_dist = self.p0 if direction == "forward" else self.pT 114 | ts = ts if direction == "forward" else torch.flip(ts, dims=[0]) 115 | 116 | x = init_dist.sample(batch=self.opt.samp_bs) 117 | (b, nx), T = x.shape, len(ts) 118 | assert nx == self.opt.x_dim 119 | 120 | xs = torch.empty((b, T, nx)) 121 | zs = torch.empty_like(xs) 122 | ws = torch.empty_like(xs) 123 | if update_mf_drift: 124 | self.mf_drifts = torch.empty(T, nx) 125 | 126 | # don't use tqdm for fbsde since it'll resample every itr 127 | for idx, t in enumerate(ts): 128 | t_idx = idx if direction == "forward" else T - idx - 1 129 | assert t_idx == t_to_idx(t, self.opt.interval, self.opt.T), (t_idx, t) 130 | 131 | if update_mf_drift: 132 | self.update_mf_drift(x, t_idx) 133 | 134 | # f = self.f(x,t,direction) 135 | # handle propagation of single time step 136 | f = self.f(x.unsqueeze(1), t.unsqueeze(0), direction).squeeze(1) 137 | z = policy(x, t) 138 | dw = self.dw(x) 139 | util.assert_zero_grads(policy) 140 | 141 | xs[:, t_idx, ...] = x 142 | zs[:, t_idx, ...] = z 143 | ws[:, t_idx, ...] = dw 144 | 145 | x = self.propagate(t, x, z, direction, f=f, dw=dw) 146 | 147 | x_term = x 148 | 149 | return SampledTraj(xs, zs, ws, x_term) 150 | 151 | 152 | class SimpleSDE(BaseSDE): 153 | def __init__(self, opt: Options, p: Sampler, q: Sampler, base_drift: Optional[DriftFn] = None): 154 | super(SimpleSDE, self).__init__(opt, p, q) 155 | self.std = opt.diffusion_std 156 | self.base_drift = base_drift 157 | 158 | def _f(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 159 | return torch.zeros_like(x) if self.base_drift is None else self.base_drift(x) 160 | 161 | def _g(self, t: torch.Tensor) -> torch.Tensor: 162 | return torch.Tensor([self.std]) 163 | 164 | 165 | class OpinionSDE(SimpleSDE): 166 | """modified from the party model: 167 | See Eq (4) in https://www.cs.cornell.edu/home/kleinber/ec21-polarization.pdf 168 | """ 169 | 170 | def __init__(self, opt: Options, p: Sampler, q: Sampler): 171 | super(OpinionSDE, self).__init__(opt, p, q) 172 | assert "opinion" in opt.problem_name 173 | 174 | self.f_mul = opinion_lib.build_f_mul(opt) 175 | self.xis = opinion_lib.build_xis(opt) 176 | self.polarize_strength = 1.0 if opt.x_dim == 2 else 6.0 177 | self.mf_drifts = None 178 | 179 | def _f(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 180 | 181 | b, T, nx = x.shape 182 | assert nx == self.opt.x_dim 183 | 184 | idx = t_to_idx(t, self.opt.interval, self.opt.T) 185 | 186 | fmul = self.f_mul[idx].to(x.device).unsqueeze(-1) 187 | xi = self.xis[idx].to(x.device) 188 | mf_drift = self.mf_drifts[idx].to(x.device) 189 | assert fmul.shape == (T, 1) 190 | assert xi.shape == mf_drift.shape == (T, nx) 191 | 192 | f = self.polarize_strength * opinion_lib.opinion_f(x, mf_drift, xi) 193 | assert f.shape == x.shape 194 | 195 | f = fmul * f 196 | assert f.shape == x.shape 197 | 198 | return f 199 | 200 | @torch.no_grad() 201 | def initialize_mf_drift(self, ts: torch.Tensor, policy) -> None: 202 | self.sample_traj(ts, policy, update_mf_drift=True) 203 | 204 | @torch.no_grad() 205 | def update_mf_drift(self, x: torch.Tensor, t_idx: int) -> None: 206 | xi = self.xis[t_idx].to(x.device) 207 | mf_drift = opinion_lib.compute_mean_drift_term(x, xi) 208 | self.mf_drifts[t_idx] = mf_drift.detach().cpu() 209 | 210 | 211 | def build_sde(opt: Options, p: Sampler, q: Sampler) -> SimpleSDE: 212 | log.info("build base sde...") 213 | 214 | if "opinion" in opt.problem_name: 215 | return OpinionSDE(opt, p, q) 216 | else: 217 | base_drift = base_drift_builder(opt) 218 | return SimpleSDE(opt, p, q, base_drift=base_drift) 219 | -------------------------------------------------------------------------------- /mfg/state_cost.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | 4 | StateCostFn = Callable[[torch.Tensor], torch.Tensor] 5 | MFCostFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] 6 | 7 | def build_state_cost_fn(problem_name: str, return_obs_cost_fn: bool = False): 8 | obstacle_cost_fn = get_obstacle_cost_fn(problem_name) 9 | mf_cost_fn = get_mf_cost_fn(problem_name) 10 | 11 | def state_cost_fn(xs, ts, logp, xs_all): 12 | obstacle_cost = obstacle_cost_fn(xs) 13 | mf_cost = mf_cost_fn(xs, logp, xs_all) 14 | return obstacle_cost, mf_cost 15 | 16 | return state_cost_fn, obstacle_cost_fn if return_obs_cost_fn else state_cost_fn 17 | 18 | def get_obstacle_cost_fn(problem_name: str) -> StateCostFn: 19 | return { 20 | 'GMM': obstacle_cost_fn_gmm, 21 | 'Stunnel': obstacle_cost_fn_vneck, 22 | 'Vneck':obstacle_cost_fn_stunnel, 23 | 'opinion': zero_cost_fn, 24 | 'opinion_1k': zero_cost_fn, 25 | }.get(problem_name) 26 | 27 | def get_mf_cost_fn(problem_name: str) -> MFCostFn: 28 | return { 29 | 'GMM': zero_cost_fn, 30 | 'Stunnel': congestion_cost, 31 | 'Vneck': entropy_cost, 32 | 'opinion': entropy_cost, 33 | 'opinion_1k': entropy_cost, 34 | }.get(problem_name) 35 | 36 | ########################################################## 37 | ################ mean-field cost functions ############### 38 | ########################################################## 39 | 40 | def entropy_cost(xs: torch.Tensor, logp: torch.Tensor, xs_all: torch.Tensor) -> torch.Tensor: 41 | if logp is None: 42 | raise ValueError("Add this problem to logp_list.") 43 | 44 | return logp + 1 45 | 46 | def congestion_cost(xs: torch.Tensor, logp: torch.Tensor, xs_all: torch.Tensor) -> torch.Tensor: 47 | assert xs.ndim == 2 48 | xs = xs.reshape(-1, xs_all.shape[1], *xs.shape[1:]) 49 | 50 | assert xs.ndim == 3 # should be (batch_x, batch_y, x_dim) 51 | assert xs.shape[1:] == xs_all.shape[1:] 52 | 53 | dd = xs - xs_all # batch_x, batch_t, xdim 54 | dist = torch.sum(dd * dd, dim=-1) # batch_x, batch_t 55 | out = 2.0 / (dist + 1.0) 56 | cost = out.reshape(-1, *out.shape[2:]) 57 | return cost 58 | 59 | ########################################################## 60 | ################## obstacle cost functions ############### 61 | ########################################################## 62 | 63 | def zero_cost_fn(x: torch.Tensor, *args) -> torch.Tensor: 64 | return torch.zeros(*x.shape[:-1]) 65 | 66 | def gmm_obstacle_cfg(): 67 | centers = [[6,6], [6,-6], [-6,-6]] 68 | radius = 1.5 69 | return centers, radius 70 | 71 | def stunnel_obstacle_cfg(): 72 | a, b, c = 20, 1, 90 73 | centers = [[5,6], [-5,-6]] 74 | return a, b, c, centers 75 | 76 | def vneck_obstacle_cfg(): 77 | c_sq = 0.36 78 | coef = 5 79 | return c_sq, coef 80 | 81 | @torch.jit.script 82 | def obstacle_cost_fn_gmm(xs: torch.Tensor) -> torch.Tensor: 83 | xs = xs.reshape(-1,xs.shape[-1]) 84 | 85 | batch_xt = xs.shape[0] 86 | 87 | centers, radius = gmm_obstacle_cfg() 88 | 89 | obs1 = torch.tensor(centers[0]).repeat((batch_xt,1)).to(xs.device) 90 | obs2 = torch.tensor(centers[1]).repeat((batch_xt,1)).to(xs.device) 91 | obs3 = torch.tensor(centers[2]).repeat((batch_xt,1)).to(xs.device) 92 | 93 | dist1 = torch.norm(xs - obs1, dim=-1) 94 | dist2 = torch.norm(xs - obs2, dim=-1) 95 | dist3 = torch.norm(xs - obs3, dim=-1) 96 | 97 | cost1 = 1500 * (dist1 < radius) 98 | cost2 = 1500 * (dist2 < radius) 99 | cost3 = 1500 * (dist3 < radius) 100 | 101 | return cost1 + cost2 + cost3 102 | 103 | @torch.jit.script 104 | def obstacle_cost_fn_vneck(xs: torch.Tensor) -> torch.Tensor: 105 | 106 | a, b, c, centers = stunnel_obstacle_cfg() 107 | 108 | _xs = xs.reshape(-1,xs.shape[-1]) 109 | x, y = _xs[:,0], _xs[:,1] 110 | 111 | d = a*(x-centers[0][0])**2 + b*(y-centers[0][1])**2 112 | c1 = 1500 * (d < c) 113 | 114 | d = a*(x-centers[1][0])**2 + b*(y-centers[1][1])**2 115 | c2 = 1500 * (d < c) 116 | 117 | return c1+c2 118 | 119 | @torch.jit.script 120 | def obstacle_cost_fn_stunnel(xs: torch.Tensor) -> torch.Tensor: 121 | c_sq, coef = vneck_obstacle_cfg() 122 | 123 | xs_sq = torch.square(xs) 124 | d = coef * xs_sq[..., 0] - xs_sq[..., 1] 125 | 126 | return 1500 * (d < -c_sq) 127 | -------------------------------------------------------------------------------- /mfg/util.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Tuple 2 | import torch 3 | 4 | 5 | def assert_zero_grads(network: torch.nn.Module) -> None: 6 | for param in network.parameters(): 7 | if not (param.grad is None or torch.allclose(param.grad, torch.zeros_like(param.grad))): debug() 8 | assert param.grad is None or torch.allclose(param.grad, torch.zeros_like(param.grad)) 9 | 10 | @torch.no_grad() 11 | def proj_pca(xs_f: torch.Tensor, xs_b: torch.Tensor, reverse: bool) -> Tuple[torch.Tensor, torch.Tensor]: 12 | # xs: (batch, T, nx) 13 | # Only use final timestep of xs_f for PCA. 14 | batch, T, nx = xs_f.shape 15 | 16 | # (batch * T, nx) 17 | flat_xsf = xs_f.reshape(-1, *xs_f.shape[2:]) 18 | flat_xsb = xs_b.reshape(-1, *xs_b.shape[2:]) 19 | 20 | # Center by subtract mean. 21 | # (batch, nx) 22 | if reverse: 23 | # If reverse, use xs_b[0] instead of xs_f[T] 24 | final_xs_f = xs_b[:, 0, :] 25 | else: 26 | final_xs_f = xs_f[:, -1, :] 27 | 28 | mean_pca_xs = torch.mean(final_xs_f, dim=0, keepdim=True) 29 | final_xs_f -= mean_pca_xs 30 | 31 | # if batch is too large, it will run out of memory. 32 | if batch > 200: 33 | rand_idxs = torch.randperm(batch)[:200] 34 | final_xs_f = final_xs_f[rand_idxs] 35 | 36 | # U: (batch, k) 37 | # S: (k, k) 38 | # VT: (k, nx) 39 | U, S, VT = torch.linalg.svd(final_xs_f) 40 | 41 | # log.info("Singular values of xs_f at final timestep:") 42 | # log.info(S) 43 | 44 | # Keep the first and last directions. 45 | VT = VT[:2, :] 46 | # VT = VT[[0, -1], :] 47 | 48 | assert VT.shape == (2, nx) 49 | V = VT.T 50 | 51 | # Project both xs_f and xs_b onto V. 52 | flat_xsf -= mean_pca_xs 53 | flat_xsb -= mean_pca_xs 54 | 55 | proj_xs_f = flat_xsf @ V 56 | proj_xs_f = proj_xs_f.reshape(batch, T, *proj_xs_f.shape[1:]) 57 | 58 | proj_xs_b = flat_xsb @ V 59 | proj_xs_b = proj_xs_b.reshape(batch, T, *proj_xs_b.shape[1:]) 60 | 61 | return proj_xs_f, proj_xs_b -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .toy_net import build_toy_net_policy 2 | from .opinion_net import build_opinion_net_policy 3 | -------------------------------------------------------------------------------- /models/opinion_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from options import Options 5 | from .util import SiLU, timestep_embedding, ResNet_FC 6 | 7 | 8 | class OpinionYImpl(torch.nn.Module): 9 | def __init__(self, data_dim: int, time_embed_dim: int, hid: int, out_hid: int): 10 | super().__init__() 11 | 12 | self.t_module = nn.Sequential( 13 | nn.Linear(time_embed_dim, hid), 14 | SiLU(), 15 | nn.Linear(hid, hid), 16 | ) 17 | self.x_module = ResNet_FC(data_dim, hid, num_res_blocks=5) 18 | 19 | self.out_module = nn.Sequential( 20 | nn.Linear(hid + hid, out_hid), 21 | SiLU(), 22 | nn.Linear(out_hid, out_hid), 23 | SiLU(), 24 | nn.Linear(out_hid, 1), 25 | ) 26 | 27 | def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: 28 | t_out = self.t_module(t_emb) 29 | x_out = self.x_module(x) 30 | 31 | t_out = t_out.expand(x_out.shape) 32 | 33 | out = self.out_module(torch.cat([x_out, t_out], dim=-1)) 34 | 35 | return out 36 | 37 | 38 | class OpinionY(torch.nn.Module): 39 | def __init__(self, data_dim: int = 1000, hid: int = 128, out_hid: int = 128, time_embed_dim: int = 128): 40 | super(OpinionY,self).__init__() 41 | 42 | self.time_embed_dim = time_embed_dim 43 | self.yt_impl = OpinionYImpl(data_dim, time_embed_dim, hid, out_hid) 44 | self.yt_impl = torch.jit.script(self.yt_impl) 45 | 46 | @property 47 | def inner_dtype(self): 48 | """ 49 | Get the dtype used by the torso of the model. 50 | """ 51 | return next(self.input_blocks.parameters()).dtype 52 | 53 | def forward(self, x:torch.Tensor, t: torch.Tensor) -> torch.Tensor: 54 | """ 55 | Apply the model to an input batch. 56 | :param x: an [N x C x ...] Tensor of inputs. 57 | :param t: a 1-D batch of timesteps. 58 | """ 59 | 60 | # make sure t.shape = [T] 61 | if len(t.shape)==0: 62 | t=t[None] 63 | 64 | t_emb = timestep_embedding(t, self.time_embed_dim) 65 | 66 | out = self.yt_impl(x, t_emb) 67 | 68 | return out 69 | 70 | class OpinionZImpl(torch.nn.Module): 71 | def __init__(self, data_dim: int, time_embed_dim: int, hid: int): 72 | super().__init__() 73 | 74 | self.t_module = nn.Sequential( 75 | nn.Linear(time_embed_dim, hid), 76 | SiLU(), 77 | nn.Linear(hid, hid), 78 | ) 79 | self.x_module = ResNet_FC(data_dim, hid, num_res_blocks=5) 80 | 81 | self.out_module = nn.Sequential( 82 | nn.Linear(hid, hid), 83 | SiLU(), 84 | nn.Linear(hid, data_dim), 85 | ) 86 | 87 | def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: 88 | t_out = self.t_module(t_emb) 89 | x_out = self.x_module(x) 90 | out = self.out_module(x_out+t_out) 91 | return out 92 | 93 | 94 | class OpinionZ(torch.nn.Module): 95 | def __init__(self, data_dim=1000, hidden_dim=256, time_embed_dim=128): 96 | super(OpinionZ,self).__init__() 97 | 98 | self.time_embed_dim = time_embed_dim 99 | self.z_impl = OpinionZImpl(data_dim, time_embed_dim, hidden_dim) 100 | self.z_impl = torch.jit.script(self.z_impl) 101 | 102 | @property 103 | def inner_dtype(self): 104 | """ 105 | Get the dtype used by the torso of the model. 106 | """ 107 | return next(self.input_blocks.parameters()).dtype 108 | 109 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 110 | """ 111 | Apply the model to an input batch. 112 | :param x: an [N x C x ...] Tensor of inputs. 113 | :param t: a 1-D batch of timesteps. 114 | """ 115 | 116 | # make sure t.shape = [T] 117 | if len(t.shape)==0: 118 | t = t[None] 119 | 120 | t_emb = timestep_embedding(t, self.time_embed_dim) 121 | out = self.z_impl(x, t_emb) 122 | 123 | return out 124 | 125 | def build_opinion_net_policy(opt: Options, YorZ: str) -> torch.nn.Module: 126 | assert opt.x_dim == 1000 127 | # 2nets: {"hid":200, "out_hid":400, "time_embed_dim":256} 128 | return { 129 | "Y": OpinionY, 130 | "Z": OpinionZ, 131 | }.get(YorZ)() 132 | -------------------------------------------------------------------------------- /models/toy_net.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from options import Options 7 | 8 | from .util import SiLU, timestep_embedding 9 | 10 | 11 | class ToyY(torch.nn.Module): 12 | def __init__(self, data_dim: int = 2, hidden_dim: int = 128, time_embed_dim: int = 128): 13 | super(ToyY, self).__init__() 14 | 15 | self.time_embed_dim = time_embed_dim 16 | hid = hidden_dim 17 | 18 | self.t_module = nn.Sequential( 19 | nn.Linear(self.time_embed_dim, hid), 20 | SiLU(), 21 | nn.Linear(hid, hid), 22 | ) 23 | 24 | self.x_module = nn.Sequential( 25 | nn.Linear(data_dim, hid), 26 | SiLU(), 27 | nn.Linear(hid, hid), 28 | SiLU(), 29 | nn.Linear(hid, hid), 30 | ) 31 | 32 | self.out_module = nn.Sequential( 33 | nn.Linear(hid + hid, hid), 34 | SiLU(), 35 | nn.Linear(hid, hid), 36 | SiLU(), 37 | nn.Linear(hid, 1), 38 | ) 39 | 40 | @property 41 | def inner_dtype(self): 42 | """ 43 | Get the dtype used by the torso of the model. 44 | """ 45 | return next(self.input_blocks.parameters()).dtype 46 | 47 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 48 | """ 49 | Apply the model to an input batch. 50 | :param x: an [N x C x ...] Tensor of inputs. 51 | :param t: a 1-D batch of timesteps. 52 | """ 53 | 54 | # make sure t.shape = [T] 55 | if len(t.shape) == 0: 56 | t = t[None] 57 | 58 | t_emb = timestep_embedding(t, self.time_embed_dim) 59 | t_out = self.t_module(t_emb) 60 | x_out = self.x_module(x) 61 | xt_out = torch.cat([x_out, t_out], dim=1) 62 | out = self.out_module(xt_out) 63 | 64 | return out 65 | 66 | 67 | class ToyZ(torch.nn.Module): 68 | def __init__(self, data_dim: int = 2, hidden_dim: int = 128, time_embed_dim: int = 128): 69 | super(ToyZ, self).__init__() 70 | 71 | self.time_embed_dim = time_embed_dim 72 | hid = hidden_dim 73 | 74 | self.t_module = nn.Sequential( 75 | nn.Linear(self.time_embed_dim, hid), 76 | SiLU(), 77 | nn.Linear(hid, hid), 78 | ) 79 | 80 | self.x_module = nn.Sequential( 81 | nn.Linear(data_dim, hid), 82 | SiLU(), 83 | nn.Linear(hid, hid), 84 | SiLU(), 85 | nn.Linear(hid, hid), 86 | ) 87 | 88 | self.out_module = nn.Sequential( 89 | nn.Linear(hid, hid), 90 | SiLU(), 91 | nn.Linear(hid, data_dim), 92 | ) 93 | 94 | @property 95 | def inner_dtype(self): 96 | """ 97 | Get the dtype used by the torso of the model. 98 | """ 99 | return next(self.input_blocks.parameters()).dtype 100 | 101 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 102 | """ 103 | Apply the model to an input batch. 104 | :param x: an [N x C x ...] Tensor of inputs. 105 | :param t: a 1-D batch of timesteps. 106 | """ 107 | 108 | # make sure t.shape = [T] 109 | if len(t.shape) == 0: 110 | t = t[None] 111 | 112 | t_emb = timestep_embedding(t, self.time_embed_dim) 113 | t_out = self.t_module(t_emb) 114 | x_out = self.x_module(x) 115 | out = self.out_module(x_out + t_out) 116 | 117 | return out 118 | 119 | 120 | def build_toy_net_policy(opt: Options, YorZ: str) -> torch.nn.Module: 121 | assert opt.x_dim == 2 122 | return { 123 | "Y": ToyY, 124 | "Z": ToyZ, 125 | }.get(YorZ)() 126 | -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | def timestep_embedding(timesteps, dim, max_period=10000): 9 | """ 10 | Create sinusoidal timestep embeddings. 11 | 12 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 13 | These may be fractional. 14 | :param dim: the dimension of the output. 15 | :param max_period: controls the minimum frequency of the embeddings. 16 | :return: an [N x dim] Tensor of positional embeddings. 17 | """ 18 | half = dim // 2 19 | freqs = torch.exp( 20 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 21 | ).to(device=timesteps.device) 22 | args = timesteps[:, None].float() * freqs[None] 23 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 24 | if dim % 2: 25 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 26 | return embedding 27 | 28 | class SiLU(nn.Module): 29 | def forward(self, x): 30 | return x * torch.sigmoid(x) 31 | 32 | class ResNet_FC(nn.Module): 33 | def __init__(self, data_dim, hidden_dim, num_res_blocks): 34 | super().__init__() 35 | self.hidden_dim = hidden_dim 36 | self.map=nn.Linear(data_dim, hidden_dim) 37 | self.res_blocks = nn.ModuleList( 38 | [self.build_res_block() for _ in range(num_res_blocks)]) 39 | 40 | def build_linear(self, in_features, out_features): 41 | linear = nn.Linear(in_features, out_features) 42 | return linear 43 | 44 | def build_res_block(self): 45 | hid = self.hidden_dim 46 | layers = [] 47 | widths =[hid]*4 48 | for i in range(len(widths) - 1): 49 | layers.append(self.build_linear(widths[i], widths[i + 1])) 50 | layers.append(SiLU()) 51 | return nn.Sequential(*layers) 52 | 53 | def forward(self, x): 54 | h=self.map(x) 55 | for res_block in self.res_blocks: 56 | h = (h + res_block(h)) / 2 57 | return h 58 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | from dataclasses import dataclass 6 | from typing import Dict, Optional 7 | 8 | import numpy as np 9 | import torch 10 | 11 | import configs 12 | 13 | 14 | @dataclass 15 | class Options: 16 | problem_name: str 17 | seed: int 18 | gpu: int 19 | load: Optional[str] 20 | dir: str 21 | group: str 22 | name: str 23 | log_fn: Optional[str] 24 | log_tb: bool 25 | cpu: bool 26 | t0: float 27 | T: float 28 | interval: int 29 | policy_net: str 30 | diffusion_std: float 31 | train_bs_x: int 32 | num_stage: int 33 | num_itr: int 34 | samp_bs: int 35 | samp_method: str 36 | rb_bs_x: int 37 | MF_cost: float 38 | lr: float 39 | lr_y: Optional[float] 40 | lr_gamma: float 41 | lr_step: int 42 | l2_norm: float 43 | optimizer: str 44 | noise_type: str 45 | ema: float 46 | snapshot_freq: int 47 | ckpt_freq: int 48 | sb_param: str 49 | use_rb_loss: bool 50 | multistep_td: bool 51 | buffer_size: int 52 | weighted_loss: bool 53 | x_dim: int 54 | device: str 55 | ckpt_path: str 56 | eval_path: str 57 | log_dir: str 58 | # Additional options set in problem config. 59 | weights: Optional[Dict[str, float]] = None 60 | 61 | 62 | def set(): 63 | # fmt: off 64 | # --------------- basic --------------- 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--problem-name", type=str) 67 | parser.add_argument("--seed", type=int, default=0) 68 | parser.add_argument("--gpu", type=int, default=0, help="GPU device") 69 | parser.add_argument("--load", type=str, default=None, help="load the checkpoints") 70 | parser.add_argument("--dir", type=str, default=None, help="directory name to save the experiments under results/") 71 | parser.add_argument("--group", type=str, default='0', help="father node of directionary for saving checkpoint") 72 | parser.add_argument("--name", type=str, default=None, help="son node of directionary for saving checkpoint") 73 | parser.add_argument("--log-fn", type=str, default=None, help="name of tensorboard logging") 74 | parser.add_argument("--log-tb", action="store_true", help="logging with tensorboard") 75 | parser.add_argument("--cpu", action="store_true", help="use cpu device") 76 | 77 | # --------------- DeepGSB & MFG --------------- 78 | parser.add_argument("--t0", type=float, default=0.0, help="time integral start time") 79 | parser.add_argument("--T", type=float, default=1.0, help="time integral end time") 80 | parser.add_argument("--interval", type=int, default=100, help="number of interval") 81 | parser.add_argument("--policy-net", type=str, help="model class of policy network") 82 | parser.add_argument("--diffusion-std", type=float, default=1.0, help="diffusion scalar in SDE") 83 | parser.add_argument("--sb-param", type=str, choices=['actor-critic', 'critic']) 84 | parser.add_argument("--MF-cost", type=float, default=0.0, help="coefficient of MF cost") 85 | 86 | # --------------- training & sampling --------------- 87 | parser.add_argument("--train-bs-x", type=int, help="batch size for sampling data") 88 | parser.add_argument("--num-stage", type=int, help="number of stage") 89 | parser.add_argument("--num-itr", type=int, help="number of training iterations (for each stage)") 90 | parser.add_argument("--samp-bs", type=int, help="batch size for all trajectory sampling purposes") 91 | parser.add_argument("--samp-method", type=str, default='jacobi', choices=['jacobi','gauss']) # gauss seidel 92 | parser.add_argument("--rb-bs-x", type=int, help="batch size when sampling from replay buffer") 93 | parser.add_argument("--use-rb-loss", action="store_true", help="whether or not to use the replay buffer loss") 94 | parser.add_argument("--multistep-td", action="store_true", help="whether or not to use the multi-step TD loss") 95 | parser.add_argument("--buffer-size", type=int, default=20000, help="the maximum size of replay buffer") 96 | parser.add_argument("--weighted-loss", action="store_true", help="whether or not to reweight the combined loss") 97 | 98 | # --------------- optimizer and loss --------------- 99 | parser.add_argument("--lr", type=float, help="learning rate for Znet") 100 | parser.add_argument("--lr-y", type=float, default=None, help="learning rate for Ynet") 101 | parser.add_argument("--lr-gamma", type=float, default=1.0, help="learning rate decay ratio") 102 | parser.add_argument("--lr-step", type=int, default=1000, help="learning rate decay step size") 103 | parser.add_argument("--l2-norm", type=float, default=0.0, help="weight decay rate") 104 | parser.add_argument("--optimizer", type=str, default='AdamW', help="optmizer") 105 | parser.add_argument("--noise-type", type=str, default='gaussian', choices=['gaussian','rademacher'], help='choose noise type to approximate Trace term') 106 | parser.add_argument("--ema", type=float, default=0.99) 107 | 108 | # ---------------- evaluation ---------------- 109 | parser.add_argument("--snapshot-freq", type=int, default=1, help="snapshot frequency w.r.t stages") 110 | parser.add_argument("--ckpt-freq", type=int, default=1, help="checkpoint saving frequency w.r.t stages") 111 | 112 | # fmt: on 113 | 114 | problem_name = parser.parse_args().problem_name 115 | sb_param = parser.parse_args().sb_param 116 | 117 | parser.set_defaults(**configs.get_default(problem_name, sb_param)) 118 | opt = parser.parse_args() 119 | # ========= seed & torch setup ========= 120 | if opt.seed is not None: 121 | # https://github.com/pytorch/pytorch/issues/7068 122 | seed = opt.seed 123 | random.seed(seed) 124 | os.environ["PYTHONHASHSEED"] = str(seed) 125 | np.random.seed(seed) 126 | torch.manual_seed(seed) 127 | torch.cuda.manual_seed(seed) 128 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 129 | torch.backends.cudnn.enabled = True 130 | torch.backends.cudnn.benchmark = True 131 | # torch.backends.cudnn.deterministic = True 132 | 133 | torch.set_default_tensor_type("torch.cuda.FloatTensor") 134 | # torch.autograd.set_detect_anomaly(True) 135 | 136 | # ========= auto setup & path handle ========= 137 | opt.device = "cuda:" + str(opt.gpu) 138 | 139 | if opt.name is None: 140 | opt.name = opt.dir 141 | 142 | opt.ckpt_path = os.path.join("checkpoint", opt.group, opt.name) 143 | os.makedirs(opt.ckpt_path, exist_ok=True) 144 | if opt.snapshot_freq: 145 | opt.eval_path = os.path.join("results", opt.dir) 146 | os.makedirs(os.path.join(opt.eval_path, "forward"), exist_ok=True) 147 | os.makedirs(os.path.join(opt.eval_path, "backward"), exist_ok=True) 148 | os.makedirs(os.path.join(opt.eval_path, "logp"), exist_ok=True) 149 | if "opinion" in opt.problem_name: 150 | os.makedirs( 151 | os.path.join(opt.eval_path, "directional_sim_forward"), exist_ok=True 152 | ) 153 | os.makedirs( 154 | os.path.join(opt.eval_path, "directional_sim_backward"), exist_ok=True 155 | ) 156 | 157 | if opt.log_tb: 158 | opt.log_dir = os.path.join( 159 | "runs", opt.dir 160 | ) # if opt.log_fn is not None else None 161 | if os.path.exists(opt.log_dir): 162 | shutil.rmtree(opt.log_dir) # remove folder & its files 163 | 164 | opt = Options(**vars(opt)) 165 | return opt 166 | -------------------------------------------------------------------------------- /requirements.yaml: -------------------------------------------------------------------------------- 1 | name: deepgsb 2 | channels: 3 | - anaconda 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.8 9 | - numpy 10 | - scipy 11 | - termcolor 12 | - easydict 13 | - ipdb 14 | - pytorch==1.11.0 15 | - torchvision 16 | - cudatoolkit=10.2 17 | - tqdm 18 | - scikit-learn 19 | - imageio 20 | - matplotlib 21 | - tensorboard 22 | - pip 23 | - pip: 24 | - colored-traceback 25 | - torch-ema 26 | - gdown 27 | - rich 28 | - geomloss -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | 2 | DATE=$(date +%m.%d) 3 | 4 | EXP=$1 5 | 6 | if [[ "$EXP" == "GMM" || "$EXP" == "all" ]]; then 7 | # GMM (std1, obs1500, multistep, gauss) 8 | BASE='--problem-name GMM --ckpt-freq 2 --snapshot-freq 2 --log-tb' 9 | python main.py $BASE --sb-param critic --dir gmm-$DATE/deepgsb-c-std1 10 | python main.py $BASE --sb-param actor-critic --dir gmm-$DATE/deepgsb-ac-std1 11 | fi 12 | 13 | if [[ "$EXP" == "Vneck" || "$EXP" == "all" ]]; then 14 | # Vneck (std1, obs1500, mf0/3, multistep, jacobi, use_rb) 15 | BASE='--problem-name Vneck --ckpt-freq 2 --snapshot-freq 2 --log-tb' 16 | python main.py $BASE --sb-param critic --dir vneck-$DATE/deepgsb-c-std1-mf3 --MF-cost 3.0 17 | python main.py $BASE --sb-param critic --dir vneck-$DATE/deepgsb-c-std1-mf0 --MF-cost 0.0 18 | 19 | python main.py $BASE --sb-param actor-critic --dir vneck-$DATE/deepgsb-ac-std1-mf3 --MF-cost 3.0 20 | python main.py $BASE --sb-param actor-critic --dir vneck-$DATE/deepgsb-ac-std1-mf0 --MF-cost 0.0 21 | fi 22 | 23 | if [[ "$EXP" == "Stunnel" || "$EXP" == "all" ]]; then 24 | # Stunnel (obs1500, congestion1, multistep, jacobi, use_rb) 25 | BASE='--problem-name Stunnel --ckpt-freq 2 --snapshot-freq 2 --log-tb --MF-cost 0.5 ' 26 | 27 | python main.py $BASE --sb-param critic --dir stunnel-$DATE/deepgsb-c-std0.5-mf1 --diffusion-std 0.5 28 | python main.py $BASE --sb-param critic --dir stunnel-$DATE/deepgsb-c-std1-mf1 --diffusion-std 1.0 29 | python main.py $BASE --sb-param critic --dir stunnel-$DATE/deepgsb-c-std2-mf1 --diffusion-std 2.0 30 | 31 | python main.py $BASE --sb-param actor-critic --dir stunnel-$DATE/deepgsb-ac-std0.5-mf1 --diffusion-std 0.5 32 | python main.py $BASE --sb-param actor-critic --dir stunnel-$DATE/deepgsb-ac-std1-mf1 --diffusion-std 1.0 33 | python main.py $BASE --sb-param actor-critic --dir stunnel-$DATE/deepgsb-ac-std2-mf1 --diffusion-std 2.0 34 | fi 35 | 36 | if [[ "$EXP" == "opinion" || "$EXP" == "all" ]]; then 37 | # opinion (std0.1, multistep, gauss, use_rb) 38 | BASE='--problem-name opinion --ckpt-freq 10 --snapshot-freq 2 --log-tb --MF-cost 1.0 --weighted-loss' 39 | python main.py $BASE --sb-param critic --dir opinion-2d-$DATE/deepgsb-c-std0.1-mf1-w 40 | python main.py $BASE --sb-param actor-critic --dir opinion-2d-$DATE/deepgsb-ac-std0.1-mf1-w 41 | fi 42 | 43 | if [[ "$EXP" == "opinion-1k" || "$EXP" == "all" ]]; then 44 | # opinion (std0.5, singlestep, gauss) 45 | BASE='--problem-name opinion_1k --ckpt-freq 20 --snapshot-freq 10 --log-tb --weighted-loss' 46 | python main.py $BASE --sb-param actor-critic --dir opinion-1k-$DATE/deepgsb-ac-std0.5-mf1-w --MF-cost 1.0 47 | python main.py $BASE --sb-param actor-critic --dir opinion-1k-$DATE/deepgsb-ac-std0.5-mf0-w --MF-cost 0.0 48 | fi 49 | --------------------------------------------------------------------------------