├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── GMM.gif
│   ├── Stunnel.gif
│   ├── Vneck.gif
│   ├── opinion_after.gif
│   └── opinion_before.gif
├── configs
│   ├── __init__.py
│   ├── gmm.py
│   ├── opinion.py
│   ├── opinion_1k.py
│   ├── stunnel.py
│   └── vneck.py
├── deepgsb
│   ├── __init__.py
│   ├── deepgsb.py
│   ├── eval_metrics.py
│   ├── loss_lib.py
│   ├── replay_buffer.py
│   ├── sb_policy.py
│   └── util.py
├── git_utils.py
├── main.py
├── make_animation.py
├── mfg
│   ├── __init__.py
│   ├── constraint.py
│   ├── mfg.py
│   ├── opinion_lib.py
│   ├── plotting.py
│   ├── sde.py
│   ├── state_cost.py
│   └── util.py
├── models
│   ├── __init__.py
│   ├── opinion_net.py
│   ├── toy_net.py
│   └── util.py
├── options.py
├── requirements.yaml
└── run.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generation results
2 | results/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *__pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | scripts/*
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
137 |
138 | data/*
139 | checkpoint/*
140 | plots/*
141 | *vscode/*
142 | deprecated.py
143 | runs/*
144 | .DS_Store
145 | ._.DS_Store
146 | *ipynb
147 | deprecated/*
148 | *gif
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Generalized Schrödinger Bridge [[NeurIPS 2022 Oral](https://arxiv.org/abs/2209.09893)]
2 |
3 | Official PyTorch implementation of the paper
4 | "_**Deep** **G**eneralized **S**chrödinger **B**ridge_ (**DeepGSB**)" which introduces
5 | a new class of diffusion models as a scalable numerical solver for Mean-Field Games (MFGs), _e.g._, population modeling & opinion depolarization, with hard distributional constraints.
6 |
7 |
8 | <table>
9 | <tr>
10 | <th colspan="3">Population modeling (crowd navigation)</th>
11 | <th colspan="2">Opinion depolarization</th>
12 | </tr>
13 | <tr>
14 | <td><img src="assets/GMM.gif"></td>
15 | <td><img src="assets/Vneck.gif"></td>
16 | <td><img src="assets/Stunnel.gif"></td>
17 | <td><img src="assets/opinion_before.gif"></td>
18 | <td><img src="assets/opinion_after.gif"></td>
19 | </tr>
20 | </table>
21 |
22 | This repo is co-maintained by [Guan-Horng Liu](https://ghliu.github.io/), [Tianrong Chen](https://tianrongchen.github.io/), and [Oswin So](https://oswinso.xyz/). Contact us if you have any questions! If you find this library useful, please cite :arrow_down:
23 | ```
24 | @inproceedings{liu2022deep,
25 | title={Deep Generalized Schr{\"o}dinger Bridge},
26 | author={Liu, Guan-Horng and Chen, Tianrong and So, Oswin and Theodorou, Evangelos A},
27 | booktitle={Advances in Neural Information Processing Systems},
28 | year={2022}
29 | }
30 | ```
31 |
32 |
33 | ## Install
34 |
35 | Install the dependencies with [Anaconda](https://www.anaconda.com/products/individual) and activate the environment `deepgsb` with
36 | ```bash
37 | conda env create --file requirements.yaml
38 | conda activate deepgsb
39 | ```
40 |
41 |
42 | ## Run & Evaluate
43 |
44 | The repo contains two classes of Mean-Field Games, namely
45 | - **population modeling**: `GMM`, `Vneck`, `Stunnel`
46 | - **opinion depolarization**: `opinion`, `opinion-1k` (dim=1000).
47 |
48 | Commands for reproducing results similar to those shown in our paper can be found in `run.sh`. Results, checkpoints, and tensorboard log files will be saved to the folders `results/`, `checkpoint/`, and `runs/`, respectively.
49 | ```bash
50 | bash run.sh <problem> # <problem> can be {GMM, Vneck, Stunnel, opinion, opinion-1k}
51 | ```
52 |
53 | You can visualize the trained DeepGSB policies by making GIF animations:
54 | ```bash
55 | python make_animation.py --load <checkpoint> --name <experiment name>
56 | ```
57 |
58 | ## Structure
59 |
60 | We briefly document the file structure to ease integration of DeepGSB into your workflow.
61 | ```bash
62 | deepgsb/
63 | ├── deepgsb.py # the DeepGSB MFG solver
64 | ├── sb_policy.py # the parametrized Schrödinger Bridge policy
65 | ├── loss_lib.py # all loss functions (IPF/KL, TD, FK/grad)
66 | ├── eval_metrics.py # all logging metrics (Wasserstein, etc)
67 | ├── replay_buffer.py
68 | └── util.py
69 | mfg/
70 | ├── mfg.py # the Mean-Field Game environment
71 | ├── constraint.py # the distributional boundary constraint (p0, pT)
72 | ├── state_cost.py # all mean-field interaction state costs (F)
73 | ├── sde.py # the associated stochastic processes (f, sigma)
74 | ├── opinion_lib.py # all utilities for opinion depolarization MFG
75 | ├── plotting.py
76 | └── util.py
77 | models/ # the deep networks for parametrizing SB policy
78 | configs/ # the configurations for each MFG
79 | ```
80 |
--------------------------------------------------------------------------------
/assets/GMM.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ghliu/DeepGSB/d11a8a19b90443cb3af28a90f35c532045e5eebc/assets/GMM.gif
--------------------------------------------------------------------------------
/assets/Stunnel.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ghliu/DeepGSB/d11a8a19b90443cb3af28a90f35c532045e5eebc/assets/Stunnel.gif
--------------------------------------------------------------------------------
/assets/Vneck.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ghliu/DeepGSB/d11a8a19b90443cb3af28a90f35c532045e5eebc/assets/Vneck.gif
--------------------------------------------------------------------------------
/assets/opinion_after.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ghliu/DeepGSB/d11a8a19b90443cb3af28a90f35c532045e5eebc/assets/opinion_after.gif
--------------------------------------------------------------------------------
/assets/opinion_before.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ghliu/DeepGSB/d11a8a19b90443cb3af28a90f35c532045e5eebc/assets/opinion_before.gif
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | def get_default(problem_name, sb_param):
4 | env_name = problem_name.lower()
5 | config_fn = env_name + "_" + sb_param.replace("-", "_")
6 |
7 | # print(env_name)
8 | module = importlib.import_module(f"configs.{env_name}")
9 | assert hasattr(module, config_fn)
10 | return getattr(module, config_fn)()
11 |
--------------------------------------------------------------------------------
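A quick usage sketch for `get_default` above, grounded in its resolution rule (lowercase the problem name, replace dashes with underscores in the parametrization):

```python
from configs import get_default

# "GMM" + "actor-critic" resolve to configs.gmm.gmm_actor_critic() below.
config = get_default("GMM", "actor-critic")
assert config.sb_param == "actor-critic" and config.lr == 5e-4
```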
/configs/gmm.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | def _common():
4 | config = edict()
5 | config.problem_name = 'GMM'
6 | config.x_dim = 2
7 |
8 | # sde
9 | config.seed = 42
10 | config.t0 = 0.0
11 | config.T = 1.0
12 | config.interval = 100
13 | config.diffusion_std = 1.0
14 |
15 | # training
16 | config.num_itr = 250
17 | config.train_bs_x = 128
18 | config.rb_bs_x = 128
19 |
20 | # sampling & evaluation
21 | config.samp_bs = 5000
22 | config.snapshot_freq = 1
23 | # config.ckpt_freq = 2
24 |
25 | # optimization
26 | config.optimizer = 'AdamW'
27 |
28 | return config
29 |
30 | def gmm_actor_critic():
31 | config = _common()
32 |
33 | # parametrization
34 | config.sb_param = 'actor-critic'
35 | config.policy_net = 'toy'
36 |
37 | # optimization
38 | config.lr = 5e-4
39 | config.lr_y = 1e-3
40 | config.lr_gamma = 0.999
41 |
42 | # tuning
43 | config.num_stage = 40
44 | config.multistep_td = True
45 | config.samp_method = 'gauss'
46 |
47 | return config
48 |
49 | def gmm_critic():
50 | config = _common()
51 |
52 | # parametrization
53 | config.sb_param = 'critic'
54 | config.policy_net = 'toy'
55 |
56 | # optimization
57 | config.lr = 5e-4
58 | config.lr_gamma = 0.999
59 |
60 | # tuning
61 | config.num_stage = 40
62 | config.multistep_td = True
63 | config.samp_method = 'gauss'
64 |
65 | return config
66 |
--------------------------------------------------------------------------------
/configs/opinion.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 |
4 | def _common():
5 | config = edict()
6 | config.problem_name = "opinion"
7 | config.x_dim = 2
8 |
9 | # sde
10 | config.seed = 42
11 | config.t0 = 0.0
12 | config.T = 3.0
13 | config.interval = 300
14 | config.diffusion_std = 0.1
15 |
16 | # training
17 | config.num_stage = 40
18 | config.num_itr = 100
19 | config.train_bs_x = 128
20 | config.rb_bs_x = 128
21 |
22 | # sampling & evaluation
23 | config.samp_bs = 2000
24 | config.snapshot_freq = 1
25 | # config.ckpt_freq = 2
26 |
27 | # optimization
28 | config.optimizer = "AdamW"
29 |
30 | return config
31 |
32 |
33 | def opinion_actor_critic():
34 | config = _common()
35 |
36 | # parametrization
37 | config.sb_param = "actor-critic"
38 | config.policy_net = "toy"
39 |
40 | # optimization
41 | coeff = 3.
42 | config.lr = coeff * 5e-4
43 | config.lr_y = 1e-3
44 | config.lr_gamma = 0.999
45 |
46 | # tuning
47 | config.num_stage = 40
48 | config.multistep_td = True
49 | config.use_rb_loss = True
50 |
51 | config.samp_method = "gauss"
52 |
53 | config.weights = {'kl': 0.8, 'non-kl': 0.05}
54 |
55 | return config
56 |
57 |
58 | def opinion_critic():
59 | config = _common()
60 |
61 | # parametrization
62 | config.sb_param = "critic"
63 | config.policy_net = "toy"
64 |
65 | # optimization
66 | config.lr = 5e-4
67 | config.lr_gamma = 0.999
68 |
69 | # tuning
70 | config.num_stage = 40
71 | config.multistep_td = True
72 | config.use_rb_loss = True
73 |
74 | config.samp_method = "gauss"
75 |
76 | config.weights = {'kl': 0.6, 'non-kl': 0.002}
77 |
78 | return config
--------------------------------------------------------------------------------
/configs/opinion_1k.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 |
4 | def _common():
5 | config = edict()
6 | config.problem_name = "opinion_1k"
7 | config.x_dim = 1000
8 |
9 | # sde
10 | config.seed = 42
11 | config.t0 = 0.0
12 | config.T = 3.0
13 | config.interval = 500
14 | config.diffusion_std = 0.5
15 |
16 | # training
17 | config.train_bs_x = 128
18 | config.rb_bs_x = 128
19 | config.buffer_size = 5000
20 |
21 | # sampling & evaluation
22 | config.samp_bs = 2500
23 | config.snapshot_freq = 1
24 | # config.ckpt_freq = 2
25 |
26 | # optimization
27 | config.optimizer = "AdamW"
28 |
29 | return config
30 |
31 |
32 | def opinion_1k_actor_critic():
33 | config = _common()
34 |
35 | # parametrization
36 | config.sb_param = "actor-critic"
37 | config.policy_net = "opinion_net"
38 |
39 | # optimization
40 | coeff = 1.
41 |
42 | config.lr = coeff * 5e-4
43 | config.lr_y = 1e-3
44 | config.lr_gamma = 0.999
45 |
46 | # tuning
47 | config.num_stage = 130
48 | config.num_itr = 250
49 | # config.multistep_td = True
50 | # config.use_rb_loss = True
51 |
52 | config.samp_method = "gauss"
53 |
54 | config.weights = {'kl': 0.8, 'non-kl': 0.05}
55 |
56 | return config
57 |
--------------------------------------------------------------------------------
/configs/stunnel.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | def _common():
4 | config = edict()
5 | config.problem_name = 'Stunnel'
6 | config.x_dim = 2
7 |
8 | # sde
9 | config.seed = 42
10 | config.t0 = 0.0
11 | config.T = 3.0
12 | config.interval = 300
13 | config.diffusion_std = 1.0
14 |
15 | # training
16 | config.num_itr = 500
17 | config.train_bs_x = 128
18 | config.rb_bs_x = 128
19 |
20 | # sampling & evaluation
21 | config.samp_bs = 5000
22 | config.snapshot_freq = 1
23 | # config.ckpt_freq = 2
24 |
25 | # optimization
26 | config.optimizer = 'AdamW'
27 |
28 | return config
29 |
30 | def stunnel_actor_critic():
31 | config = _common()
32 |
33 | # parametrization
34 | config.sb_param = 'actor-critic'
35 | config.policy_net = 'toy'
36 |
37 | # optimization
38 | config.lr = 5e-4
39 | config.lr_y = 1e-3
40 | config.lr_gamma = 0.999
41 |
42 | # tuning
43 | config.num_stage = 40
44 | config.multistep_td = True
45 | config.use_rb_loss = True
46 | config.samp_method = 'jacobi'
47 |
48 | return config
49 |
50 | def stunnel_critic():
51 | config = _common()
52 |
53 | # parametrization
54 | config.sb_param = 'critic'
55 | config.policy_net = 'toy'
56 |
57 | # optimization
58 | config.lr = 5e-4
59 | config.lr_gamma = 0.999
60 |
61 | # tuning
62 | config.num_stage = 40
63 | config.multistep_td = True
64 | config.samp_method = 'jacobi'
65 | config.ema = 0.9
66 |
67 | return config
68 |
--------------------------------------------------------------------------------
/configs/vneck.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | def _common():
4 | config = edict()
5 | config.problem_name = 'Vneck'
6 | config.x_dim = 2
7 |
8 | # sde
9 | config.seed = 42
10 | config.t0 = 0.0
11 | config.T = 2.0
12 | config.interval = 200
13 | config.diffusion_std = 1.0
14 |
15 | # training
16 | config.num_itr = 250
17 | config.train_bs_x = 128
18 | config.rb_bs_x = 128
19 |
20 | # sampling & evaluation
21 | config.samp_bs = 5000
22 | config.snapshot_freq = 1
23 | # config.ckpt_freq = 2
24 |
25 | # optimization
26 | config.optimizer = 'AdamW'
27 |
28 | return config
29 |
30 | def vneck_actor_critic():
31 | config = _common()
32 |
33 | # paramatrization
34 | config.sb_param = 'actor-critic'
35 | config.policy_net = 'toy'
36 |
37 | # optimization
38 | config.lr = 5e-4
39 | config.lr_y = 1e-3
40 | config.lr_gamma = 0.999
41 |
42 | # tuning
43 | config.num_stage = 40
44 | config.multistep_td = True
45 | config.use_rb_loss = True
46 |
47 | config.samp_method = 'jacobi'
48 |
49 | return config
50 |
51 | def vneck_critic():
52 | config = _common()
53 |
54 | # parametrization
55 | config.sb_param = 'critic'
56 | config.policy_net = 'toy'
57 |
58 | # optimization
59 | config.lr = 5e-4
60 | config.lr_gamma = 0.999
61 |
62 | # tuning
63 | config.num_stage = 40
64 | config.multistep_td = True
65 | config.use_rb_loss = True
66 |
67 | # gauss converges faster but with slightly worse W2
68 | config.samp_method = 'jacobi'
69 |
70 | return config
71 |
--------------------------------------------------------------------------------
/deepgsb/__init__.py:
--------------------------------------------------------------------------------
1 | from .deepgsb import DeepGSB
--------------------------------------------------------------------------------
/deepgsb/deepgsb.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import gc
3 | import logging
4 | import os
5 | import pickle
6 | import time
7 | from typing import Dict, Optional, Tuple
8 |
9 | import torch
10 | from easydict import EasyDict as edict
11 | from torch.optim import SGD, Adagrad, Adam, AdamW, RMSprop, lr_scheduler
12 | from torch.utils.tensorboard import SummaryWriter
13 |
14 | from mfg import MFG, MFGPolicy
15 | from options import Options
16 |
17 | from . import eval_metrics, loss_lib, sb_policy, util
18 | from .replay_buffer import Buffer
19 |
20 | log = logging.getLogger(__file__)
21 |
22 | OptSchedPair = Tuple[torch.optim.Optimizer, lr_scheduler._LRScheduler]
23 |
24 |
25 | def build_optimizer_sched(opt: Options, policy) -> OptSchedPair:
26 | optim_name = {
27 | 'Adam': Adam,
28 | 'AdamW': AdamW,
29 | 'Adagrad': Adagrad,
30 | 'RMSprop': RMSprop,
31 | 'SGD': SGD,
32 | }.get(opt.optimizer)
33 |
34 | optim_dict = {
35 | "lr": opt.lr,
36 | 'weight_decay':opt.l2_norm,
37 | }
38 | if opt.optimizer == 'SGD':
39 | optim_dict['momentum'] = 0.9
40 |
41 | if policy.param == "actor-critic":
42 | optimizer = optim_name([
43 | {'params': policy.Znet.parameters()}, # use original optim_dict
44 | {'params': policy.Ynet.parameters(), 'lr': opt.lr_y},
45 | ], **optim_dict)
46 | elif policy.param == "critic":
47 | optimizer = optim_name(policy.parameters(), **optim_dict)
48 | else:
49 | raise ValueError(f"Expected either actor-critic or critic, got {policy.param}")
50 |
51 | if opt.lr_gamma < 1.0:
52 | sched = lr_scheduler.StepLR(optimizer, step_size=opt.lr_step, gamma=opt.lr_gamma)
53 | else:
54 | sched = None
55 |
56 | return optimizer, sched
57 |
58 | def get_grad_loss_norm(opt: Options) -> str:
59 | problem_name: str = opt.problem_name
60 |
61 | if problem_name in ["GMM", "opinion", "opinion_1k"]:
62 | # Use the L1 loss for GMM and the opinion MFGs.
63 | return "l1"
64 | if problem_name in []:
65 | # No problem currently uses the L2 loss.
66 | return "l2"
67 |
68 | # Otherwise, default to Huber loss.
69 | return "huber"
70 |
71 | class DeepGSB:
72 | def __init__(self, opt: Options, mfg: MFG, save_opt: bool = True):
73 | super(DeepGSB, self).__init__()
74 |
75 | # Save opt.
76 | if save_opt:
77 | opt_pkl_path = opt.ckpt_path + "/options.pkl"
78 | with open(opt_pkl_path, "wb") as f:
79 | pickle.dump(opt, f)
80 | log.info("Saved options pickle to {}!".format(opt_pkl_path))
81 |
82 | self.start_time = time.time()
83 |
84 | self.mfg = mfg
85 |
86 | # build forward (z_f) and backward (z_b) policies
87 | self.z_f = sb_policy.build(opt, mfg.sde, 'forward') # p0 -> pT
88 | self.z_b = sb_policy.build(opt, mfg.sde, 'backward') # pT -> p0
89 | if mfg.uses_mf_drift():
90 | mfg.initialize_mf_drift(self.z_f)
91 |
92 | self.optimizer_f, self.sched_f = build_optimizer_sched(opt, self.z_f)
93 | self.optimizer_b, self.sched_b = build_optimizer_sched(opt, self.z_b)
94 |
95 | self.buffer_f = Buffer(opt, 'forward') if opt.use_rb_loss else None
96 | self.buffer_b = Buffer(opt, 'backward') if opt.use_rb_loss else None
97 |
98 | self.it_f = self.it_b = 0
99 | if opt.log_tb: # tensorboard related things
100 | self.writer=SummaryWriter(log_dir=opt.log_dir)
101 |
102 | if opt.load:
103 | util.restore_checkpoint(opt, self, opt.load)
104 |
105 | @property
106 | def is_critic_param(self) -> bool:
107 | return self.z_f.param == self.z_b.param == "critic"
108 |
109 | @property
110 | def is_actor_critic_param(self) -> bool:
111 | return self.z_f.param == self.z_b.param == "actor-critic"
112 |
113 | def get_count(self, direction: str) -> int:
114 | return self.it_f if direction == "forward" else self.it_b
115 |
116 | def update_count(self, direction: str) -> int:
117 | if direction == 'forward':
118 | self.it_f += 1
119 | return self.it_f
120 | elif direction == 'backward':
121 | self.it_b += 1
122 | return self.it_b
123 | else:
124 | raise RuntimeError(f"Expected either forward or backward, got {direction}")
125 |
126 | def get_optimizer_sched(self, z: MFGPolicy) -> OptSchedPair:
127 | if z == self.z_f:
128 | return self.optimizer_f, self.sched_f
129 | elif z == self.z_b:
130 | return self.optimizer_b, self.sched_b
131 | else:
132 | raise RuntimeError("Policy is neither z_f nor z_b")
133 |
134 | @torch.no_grad()
135 | def sample_train_data(self, opt: Options, train_direction: str) -> edict:
136 | policy_opt, policy_impt = {
137 | 'forward': [self.z_f, self.z_b], # train forward, sample from backward
138 | 'backward': [self.z_b, self.z_f], # train backward, sample from forward
139 | }.get(train_direction)
140 |
141 | # prepare training data
142 | train_ts = self.mfg.ts.detach()
143 |
144 | # update mf_drift if we need it and we're sampling forward traj
145 | update_mf_drift = (self.mfg.uses_mf_drift() and policy_impt.direction == 'forward')
146 |
147 | ema, ema_impt = policy_opt.get_ema(), policy_impt.get_ema()
148 | with ema.average_parameters(), ema_impt.average_parameters():
149 | policy_impt.freeze()
150 | policy_opt.freeze()
151 |
152 | xs, zs, ws, _ = self.mfg.sample_traj(policy_impt, update_mf_drift=update_mf_drift)
153 | train_xs = xs.detach().cpu(); del xs
154 | train_zs = zs.detach().cpu(); del zs
155 | train_ws = ws.detach().cpu(); del ws
156 |
157 | log.info('Generated train data from [sampling]!')
158 |
159 | assert train_xs.shape[0] == opt.samp_bs
160 | assert train_xs.shape[1] == len(train_ts)
161 | assert train_xs.shape == train_zs.shape
162 | gc.collect()
163 |
164 | return edict(
165 | xs=train_xs, zs=train_zs, ws=train_ws, ts=train_ts
166 | )
167 |
168 | def train_stage(self, opt: Options, stage: int, train_direction: str, datas: Optional[edict]=None) -> None:
169 | policy_opt, policy_impt = {
170 | 'forward': [self.z_f, self.z_b], # train forward, sample from backward
171 | 'backward': [self.z_b, self.z_f], # train backward, sample from forward
172 | }.get(train_direction)
173 |
174 | buffer_impt, buffer_opt = {
175 | 'forward': [self.buffer_f, self.buffer_b],
176 | 'backward': [self.buffer_b, self.buffer_f],
177 | }.get(policy_impt.direction)
178 |
179 | if datas is None:
180 | datas = self.sample_train_data(opt, train_direction)
181 |
182 | # Compute the cost and statistical distance for the forward / backward trajectories
183 | if opt.log_tb:
184 | t1 = time.time()
185 | self.log_validate_metrics(opt, train_direction, datas)
186 | log.info("Done logging validate metrics! Took {:.1f} s!".format(time.time() - t1))
187 |
188 | # update buffers
189 | if opt.use_rb_loss:
190 | buffer_impt.append(datas)
191 |
192 | self.train_ep(opt, stage, train_direction, datas, policy_opt, policy_impt, buffer_opt, buffer_impt)
193 |
194 | def train_ep(
195 | self,
196 | opt: Options,
197 | stage: int,
198 | direction: str,
199 | datas: edict,
200 | policy: MFGPolicy,
201 | policy_impt: MFGPolicy,
202 | buffer_opt: Optional[Buffer],
203 | buffer_impt: Optional[Buffer],
204 | ) -> None:
205 | train_xs, train_zs, train_ws, train_ts = datas.xs, datas.zs, datas.ws, datas.ts
206 |
207 | assert train_xs.shape[0] == opt.samp_bs
208 | assert train_zs.shape[0] == opt.samp_bs
209 | assert train_ts.shape[0] == opt.interval
210 | assert direction == policy.direction
211 |
212 | optimizer, sched = self.get_optimizer_sched(policy)
213 | optimizer_impt, _ = self.get_optimizer_sched(policy_impt)
214 |
215 | policy.activate() # activate Y (and Z)
216 | policy_impt.freeze() # freeze Y_impt (and Z_impt)
217 |
218 | if stage>0 and opt.use_rb_loss: assert len(buffer_opt)>0 and len(buffer_impt)>0
219 |
220 | mfg = self.mfg
221 | samp_direction = policy_impt.direction
222 | for it in range(opt.num_itr):
223 | step = self.update_count(direction)
224 |
225 | # -------- sample x_idx and t_idx \in [0, interval] --------
226 | samp_x_idx = torch.randint(opt.samp_bs, (opt.train_bs_x,))
227 |
228 | dim01 = [opt.train_bs_x, opt.interval]
229 |
230 | # -------- build sample --------
231 | ts = train_ts.detach()
232 | xs = train_xs[samp_x_idx].to(opt.device)
233 | zs_impt = train_zs[samp_x_idx].to(opt.device)
234 | dw = train_ws[samp_x_idx].to(opt.device)
235 |
236 | if mfg.uses_xs_all():
237 | samp_x_idx2 = torch.randint(opt.samp_bs, (opt.train_bs_x,))
238 | xs_all = train_xs[samp_x_idx2].to(opt.device)
239 | else:
240 | xs_all = None
241 |
242 | optimizer.zero_grad()
243 | optimizer_impt.zero_grad()
244 |
245 | # -------- compute KL loss --------
246 | loss_kl, zs, kl, _ = loss_lib.compute_kl_loss(
247 | opt, dim01, mfg, samp_direction,
248 | ts.detach(), xs.detach(), zs_impt.detach(),
249 | policy, return_all=True
250 | )
251 |
252 | # -------- compute bsde TD loss --------
253 | loss_bsde_td = loss_lib.compute_bsde_td_loss(
254 | opt, mfg, samp_direction,
255 | ts.detach(), xs.detach(), zs.detach(), dw.detach(), kl.detach(),
256 | policy, policy_impt, xs_all
257 | )
258 |
259 | # -------- compute boundary loss --------
260 | loss_boundary = loss_lib.compute_boundary_loss(
261 | opt, mfg, ts.detach(), xs.detach(), policy_impt, policy,
262 | )
263 |
264 | # -------- compute mismatch loss between Z and \nabla_x Y --------
265 | loss_grad = torch.Tensor([0.0])
266 | if self.is_actor_critic_param:
267 | norm = get_grad_loss_norm(opt)
268 | loss_grad = loss_lib.compute_grad_loss(opt, ts.detach(), xs.detach(), mfg.sde, policy, norm)
269 |
270 | # -------- compute replay buffer loss ---------
271 | loss_bsde_td_rb = torch.Tensor([0.0])
272 | if opt.use_rb_loss:
273 | loss_bsde_td_rb = loss_lib.compute_bsde_td_loss_from_buffer(
274 | opt, mfg, buffer_impt, ts.detach(), policy, policy_impt, xs_all
275 | )
276 |
277 | # -------- compute loss and backprop --------
278 | if self.is_critic_param:
279 | if opt.weighted_loss:
280 | w_kl, w_nkl = opt.weights['kl'], opt.weights['non-kl']
281 | loss = w_kl * loss_kl + w_nkl * (loss_bsde_td + loss_boundary + loss_bsde_td_rb)
282 | else:
283 | loss = loss_kl + loss_bsde_td + loss_boundary + loss_bsde_td_rb
284 |
285 | elif self.is_actor_critic_param:
286 | if opt.weighted_loss:
287 | w_kl, w_nkl = opt.weights['kl'], opt.weights['non-kl']
288 | loss = w_kl * loss_kl + w_nkl * (loss_boundary + loss_bsde_td + loss_grad + loss_bsde_td_rb)
289 | else:
290 | loss = loss_kl + loss_boundary + loss_bsde_td + loss_grad + loss_bsde_td_rb
291 | else:
292 | raise RuntimeError("")
293 |
294 | assert not torch.isnan(loss)
295 | loss.backward()
296 |
297 | optimizer.step()
298 | policy.update_ema()
299 |
300 | if sched is not None: sched.step()
301 |
302 | # -------- logging --------
303 | loss = edict(
304 | kl=loss_kl, grad=loss_grad,
305 | boundary=loss_boundary, bsde_td=loss_bsde_td,
306 | )
307 | if it % 20 == 0:
308 | self.log_train(opt, it, stage, loss, optimizer, direction)
309 |
310 | def train(self, opt: Options) -> None:
311 | self.evaluate(opt, 0)
312 |
313 | for stage in range(opt.num_stage):
314 | if opt.samp_method == 'jacobi':
315 | datas1 = self.sample_train_data(opt, 'forward')
316 | datas2 = self.sample_train_data(opt, 'backward')
317 | self.train_stage(opt, stage, 'forward', datas=datas1)
318 | self.train_stage(opt, stage, 'backward', datas=datas2)
319 |
320 | elif opt.samp_method == 'gauss':
321 | self.train_stage(opt, stage, 'forward')
322 | self.train_stage(opt, stage, 'backward')
323 |
324 | t1 = time.time()
325 | self.evaluate(opt, stage+1)
326 | log.info("Finished evaluate! Took {:.2f}s.".format(time.time() - t1))
327 |
328 | if opt.log_tb: self.writer.close()
329 |
330 | @torch.no_grad()
331 | def evaluate(self, opt: Options, stage: int) -> None:
332 | snapshot, ckpt = util.evaluate_stage(opt, stage)
333 | if snapshot:
334 | self.z_f.freeze(); self.z_b.freeze()
335 | self.mfg.save_snapshot(self.z_f, self.z_b, stage)
336 |
337 | if ckpt and stage > 0:
338 | keys = ['z_f','optimizer_f','z_b','optimizer_b', "it_f", "it_b"]
339 | util.save_checkpoint(opt, self, keys, stage)
340 |
341 | def compute_validate_metrics(self, opt: Options, train_direction: str, datas: edict) -> Dict[str, torch.Tensor]:
342 | # Sample direction is opposite of train_direction.
343 | xs, zs, ts = datas.xs, datas.zs, datas.ts
344 |
345 | b, T, nx = xs.shape
346 | assert zs.shape == (b, T, nx)
347 | assert ts.shape == (T,)
348 |
349 | mfg = self.mfg
350 | dt = mfg.dt
351 |
352 | metrics = {}
353 |
354 | # Compute "polarization" via the "condition number" of the covariance matrix.
355 | if "opinion" in mfg.problem_name:
356 | metrics.update(
357 | eval_metrics.compute_conv_l1_metrics(mfg, xs, train_direction)
358 | )
359 |
360 | # Compute Wasserstein distance between the x0 / xT and p0 / pT.
361 | metrics.update(
362 | eval_metrics.compute_sinkhorn_metrics(opt, mfg, xs, train_direction)
363 | )
364 |
365 | # Compute the state + control cost.
366 | if mfg.uses_xs_all():
367 | xs_all = xs.to(opt.device)
368 | xs_all = xs_all[torch.randperm(opt.samp_bs)]
369 | else:
370 | xs_all = None
371 |
372 | est_mf_cost, logp = eval_metrics.compute_est_mf_cost(
373 | opt, self, xs, ts, dt, xs_all, return_logp=True
374 | )
375 | s_cost, mf_cost = eval_metrics.compute_state_cost(opt, mfg, xs, ts, xs_all, logp, dt)
376 | del logp, xs_all
377 | control_cost = eval_metrics.compute_control_cost(opt, zs, dt)
378 |
379 | mean_s_cost, mean_control_cost, mean_mf_cost = s_cost.mean(), control_cost.mean(), mf_cost.mean()
380 |
381 | mean_nonmf_cost = mean_s_cost + mean_control_cost
382 | mean_total_cost = mean_nonmf_cost + mfg.mf_coeff * mean_mf_cost
383 |
384 | metrics["est_mf_cost"] = est_mf_cost
385 | metrics["state_cost"] = mean_s_cost
386 | metrics["nonmf_cost"] = mean_nonmf_cost
387 | metrics["mf_cost"] = mean_mf_cost
388 | metrics["control_cost"] = mean_control_cost
389 | metrics["total_cost"] = mean_total_cost
390 |
391 | return metrics
392 |
393 | @torch.no_grad()
394 | def log_validate_metrics(self, opt: Options, direction: str, datas: edict) -> None:
395 |
396 | def tag(name: str) -> str:
397 | return os.path.join(f"{direction}-loss", name)
398 | step = self.get_count(direction)
399 |
400 | metrics = self.compute_validate_metrics(opt, direction, datas)
401 |
402 | # Log all metrics.
403 | for key in metrics:
404 | self.writer.add_scalar(tag(key), metrics[key], global_step=step)
405 | del metrics
406 |
407 | def log_train(
408 | self, opt: Options, it: int, stage: int, loss: edict, optimizer: torch.optim.Optimizer, direction: str
409 | ) -> None:
410 | time_elapsed = util.get_time(time.time()-self.start_time)
411 | lr = optimizer.param_groups[0]['lr']
412 | log.info("[SB {0}] stage {1}/{2} | itr {3}/{4} | lr {5} | loss {6} | time {7}"
413 | .format(
414 | "fwd" if direction=="forward" else "bwd",
415 | str(1+stage).zfill(2),
416 | opt.num_stage,
417 | str(1+it).zfill(3),
418 | opt.num_itr,
419 | "{:.2e}".format(lr),
420 | util.get_loss_str(loss),
421 | "{0}:{1:02d}:{2:05.2f}".format(*time_elapsed),
422 | ))
423 |
424 | step = self.get_count(direction)
425 | if opt.log_tb:
426 | assert isinstance(loss, edict)
427 | for key, val in loss.items():
428 | # assert val > 0 # for taking log
429 | self.writer.add_scalar(
430 | os.path.join(f'{direction}-loss', f'{key}'), val.detach(), global_step=step
431 | )
432 |
433 | # Also log the current stage.
434 | self.writer.add_scalar(os.path.join(f"{direction}-loss", "stage"), stage, global_step=step)
435 |
436 | # Log the LR.
437 | self.writer.add_scalar(os.path.join(f"{direction}-opt", "lr"), lr, global_step=step)
438 |
--------------------------------------------------------------------------------
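The optimizer construction in `build_optimizer_sched` above uses PyTorch parameter groups to give the critic (`Ynet`) its own learning rate under the actor-critic parametrization. A self-contained sketch of that pattern (the two `nn.Linear` modules are stand-ins, not the real `Znet`/`Ynet`):

```python
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

znet, ynet = nn.Linear(2, 2), nn.Linear(2, 1)  # stand-ins for policy.Znet / policy.Ynet

optimizer = AdamW(
    [
        {"params": znet.parameters()},              # inherits the default lr below
        {"params": ynet.parameters(), "lr": 1e-3},  # per-group override, cf. opt.lr_y
    ],
    lr=5e-4,           # cf. opt.lr
    weight_decay=0.0,  # cf. opt.l2_norm
)
sched = StepLR(optimizer, step_size=1000, gamma=0.999)  # cf. opt.lr_step, opt.lr_gamma
```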
/deepgsb/eval_metrics.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import logging
3 | import math
4 | import time
5 | from typing import Dict, List, Tuple, Union, TYPE_CHECKING
6 |
7 | import torch
8 | from geomloss import SamplesLoss
9 |
10 | from mfg import MFG
11 | from options import Options
12 |
13 | from . import util
14 |
15 | if TYPE_CHECKING:
16 | from deepgsb.deepgsb import DeepGSB
17 |
18 | log = logging.getLogger(__file__)
19 |
20 |
21 | def get_bound(mfg: MFG, xs: torch.Tensor, train_direction: str):
22 | return {
23 | "forward": [xs[:, 0], mfg.p0, "0"], # train forward, xs are sampled from backward
24 | "backward": [xs[:, -1], mfg.pT, "T"], # train backward, xs are sampled from forward
25 | }.get(train_direction)
26 |
27 |
28 | @torch.no_grad()
29 | def compute_conv_l1_metrics(mfg: MFG, xs: torch.Tensor, train_direction: str) -> Dict[str, torch.Tensor]:
30 | b, T, nx = xs.shape
31 |
32 | metrics = {}
33 |
34 | # (samp_bs, ...)
35 | x_bound, p_bound, time_str = get_bound(mfg, xs, train_direction)
36 |
37 | # Max singular value / min singular value.
38 | S = torch.linalg.svdvals(x_bound)
39 | # Largest to smallest.
40 | eigvals = (S ** 2) / (b - 1)
41 | cond_number = eigvals[0] / eigvals[-1]
42 | del S, eigvals
43 |
44 | # Eigvalsh returns smallest to largest.
45 | true_cov = p_bound.distribution.covariance_matrix.cpu()
46 | true_eigvals = torch.linalg.eigvalsh(true_cov)
47 | true_cond_number = true_eigvals[-1] / true_eigvals[0]
48 |
49 | metrics[f"cond_num diff {time_str}"] = torch.abs(cond_number - true_cond_number)
50 |
51 | # Pred stds (project on each axis).
52 | pred_stds = torch.std(x_bound, dim=0)
53 | # True stds (project on each axis).
54 | true_stds = p_bound.distribution.stddev.cpu()
55 |
56 | std_l1 = torch.abs(true_stds - pred_stds).sum()
57 | metrics[f"std l1 {time_str}"] = std_l1
58 |
59 | # Also just compare the covariance matrix.
60 | pred_cov = torch.cov(x_bound.T)
61 | assert pred_cov.shape == true_cov.shape
62 | cov_l1 = torch.abs(true_cov - pred_cov).sum()
63 | metrics[f"cov l1 {time_str}"] = cov_l1
64 |
65 | return metrics
66 |
67 |
68 | @torch.no_grad()
69 | def compute_sinkhorn_metrics(opt: Options, mfg: MFG, xs: torch.Tensor, train_direction: str) -> Dict[str, torch.Tensor]:
70 | b, T, nx = xs.shape
71 |
72 | metrics = {}
73 |
74 | # Higher scaling yields a more accurate estimate.
75 | if opt.x_dim > 2:
76 | # The L1 norm behaves better in high-dimensional settings. Moreover, the
77 | # empirical Wasserstein distance is not a great metric in high dimension anyway,
78 | # so we don't need to spend as much time computing it.
79 | sinkhorn = SamplesLoss("sinkhorn", p=1, blur=5e-2, scaling=0.9)
80 | else:
81 | sinkhorn = SamplesLoss("sinkhorn", p=2, blur=5e-3, scaling=0.9)
82 |
83 | x_bound, p_bound, time_str = get_bound(mfg, xs, train_direction)
84 | W2_name = f"W2_{time_str}"
85 |
86 | p_samples = p_bound.sample(batch=b)
87 |
88 | log.info("Computing sinkhorn....")
89 | t1 = time.time()
90 | x_bound = x_bound.to(opt.device)
91 | metrics[W2_name] = sinkhorn(x_bound, p_samples)
92 | log.info("Done! sinkhorn took {:.1f} s".format(time.time() - t1))
93 |
94 | return metrics
95 |
96 |
97 | @torch.no_grad()
98 | def compute_control_cost(opt: Options, zs: torch.Tensor, dt: float) -> torch.Tensor:
99 | b, T, nx = zs.shape
100 |
101 | zs = zs.to(opt.device)
102 | control_cost = torch.square(zs).sum((1, 2)) * dt
103 | return control_cost.detach().cpu()
104 |
105 |
106 | @torch.no_grad()
107 | def compute_state_cost(
108 | opt: Options, mfg: MFG, xs: torch.Tensor, ts: torch.Tensor, xs_all: torch.Tensor, logp: torch.Tensor, dt: float
109 | ) -> Tuple[torch.Tensor, torch.Tensor]:
110 | b, T, nx = xs.shape
111 | dim01 = (b, T)
112 |
113 | flat_xs = util.flatten_dim01(xs).to(opt.device)
114 | s_cost, mf_cost = mfg.state_cost_fn(flat_xs, ts, logp, xs_all)
115 | del flat_xs
116 |
117 | # (samp_bs, T)
118 | s_cost, mf_cost = util.unflatten_dim01(s_cost, dim01), util.unflatten_dim01(mf_cost, dim01)
119 |
120 | s_cost = (s_cost.sum(1) * dt).detach().cpu()
121 | mf_cost = (mf_cost.sum(1) * dt).detach().cpu()
122 |
123 | return s_cost, mf_cost
124 |
125 |
126 | @torch.jit.script
127 | def serial_logkde(train_xs: torch.Tensor, xs: torch.Tensor, bw: float, max_batch: int = 32) -> torch.Tensor:
128 | # xs: (b1, T, *)
129 | # train_xs: (b2, T, *)
130 | # out: (b1, T)
131 |
132 | b1, T, nx = xs.shape
133 | b2, T, nx = train_xs.shape
134 | assert xs.shape == (b1, T, nx) and train_xs.shape == (b2, T, nx)
135 |
136 | coeff = b2 * math.sqrt(2 * math.pi * bw)
137 | log_coeff = math.log(coeff)
138 |
139 | xs_chunks = torch.split(xs, max_batch, dim=0)
140 | out_chunks = []
141 | for xs in xs_chunks:
142 | # 1: Compute diffs. (b1, b2, T, *)
143 | diffs = train_xs.unsqueeze(0) - xs.unsqueeze(1)
144 | # (b1, b2, T)
145 | norm_sq = torch.sum(torch.square(diffs), dim=-1)
146 |
147 | # (b1, b2, T) -> (b1, T)
148 | logsumexp = torch.logsumexp(-norm_sq / (2 * bw), dim=1)
149 |
150 | out = logsumexp - log_coeff
151 | out_chunks.append(out)
152 |
153 | out = torch.cat(out_chunks, dim=0)
154 | assert out.shape == (b1, T)
155 |
156 | return out
157 |
158 |
159 | @torch.no_grad()
160 | def compute_est_mf_cost(
161 | opt: Options,
162 | runner: "DeepGSB",
163 | xs: torch.Tensor,
164 | ts: torch.Tensor,
165 | dt: float,
166 | xs_all: torch.Tensor,
167 | return_logp: bool = False,
168 | ) -> Union[List[torch.Tensor], torch.Tensor]:
169 | if not runner.mfg.uses_logp():
170 | return [torch.zeros(1), None] if return_logp else torch.zeros(1)
171 |
172 | b, T, nx = xs.shape
173 | dim01 = (b, T)
174 |
175 | bw = 0.2 ** 2
176 |
177 | xs = xs.to(opt.device)
178 | if b > 500:
179 | rand_idxs = torch.randperm(b)[:500]
180 | reduced_xs = xs[rand_idxs]
181 | del rand_idxs
182 | else:
183 | reduced_xs = xs
184 |
185 | max_batch = 32 if opt.x_dim == 2 else 2
186 | logp = serial_logkde(reduced_xs, xs, bw=bw, max_batch=max_batch)
187 | del reduced_xs
188 |
189 | assert logp.shape == (b, T)
190 | logp = util.flatten_dim01(logp).detach().cpu()
191 |
192 | gc.collect()
193 |
194 | # Evaluate logp in chunks to prevent cuda OOM.
195 | ts = ts.repeat(b)
196 | flat_xs = util.flatten_dim01(xs)
197 |
198 | max_chunk_size = 512
199 | flat_xs_chunks = torch.split(flat_xs, max_chunk_size, dim=0)
200 | ts_chunks = torch.split(ts, max_chunk_size, dim=0)
201 |
202 | est_logp = []
203 | for flat_xs_chunk, ts_chunk in zip(flat_xs_chunks, ts_chunks):
204 | flat_xs_chunk = flat_xs_chunk.to(opt.device)
205 | ts_chunk = ts_chunk.to(opt.device)
206 | value_t_ema = runner.z_f.compute_value(flat_xs_chunk, ts_chunk, use_ema=True)
207 | value_impt_t_ema = runner.z_b.compute_value(flat_xs_chunk, ts_chunk, use_ema=True)
208 | est_logp.append((value_t_ema + value_impt_t_ema).detach().cpu())
209 | del flat_xs_chunk, ts_chunk, value_t_ema, value_impt_t_ema
210 | est_logp = torch.cat(est_logp, dim=0)
211 |
212 | # flat_xs, ts = flat_xs.to(opt.device), ts.to(opt.device)
213 | # value_t_ema = self.z_f.compute_value(flat_xs, ts, use_ema=True)
214 | # value_impt_t_ema = self.z_b.compute_value(flat_xs, ts, use_ema=True)
215 | # est_logp = value_t_ema + value_impt_t_ema
216 |
217 | assert logp.shape == est_logp.shape == (b * T,)
218 |
219 | _, est_mf_cost = runner.mfg.state_cost_fn(flat_xs, ts, est_logp, xs_all)
220 | del est_logp
221 |
222 | est_mf_cost = util.unflatten_dim01(est_mf_cost, dim01)
223 | est_mf_cost = torch.mean(est_mf_cost.sum(1) * dt).detach().cpu()
224 |
225 | return [est_mf_cost, logp] if return_logp else est_mf_cost
226 |
--------------------------------------------------------------------------------
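`serial_logkde` above evaluates a fixed-bandwidth Gaussian KDE in log space, chunking over the first batch dimension to bound memory. A minimal unchunked sketch of the same computation, usable as a sanity check (shapes and bandwidth are arbitrary for illustration):

```python
import math
import torch

def logkde_direct(train_xs: torch.Tensor, xs: torch.Tensor, bw: float) -> torch.Tensor:
    # Same estimate as serial_logkde, without the max_batch chunking.
    b2 = train_xs.shape[0]
    diffs = train_xs.unsqueeze(0) - xs.unsqueeze(1)  # (b1, b2, T, nx)
    norm_sq = diffs.square().sum(-1)                 # (b1, b2, T)
    log_coeff = math.log(b2 * math.sqrt(2 * math.pi * bw))
    return torch.logsumexp(-norm_sq / (2 * bw), dim=1) - log_coeff

train_xs = torch.randn(8, 5, 2)  # (b2, T, nx)
xs = torch.randn(4, 5, 2)        # (b1, T, nx)
logp = logkde_direct(train_xs, xs, bw=0.2 ** 2)
assert logp.shape == (4, 5)      # one log-density per sample per time step
```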
/deepgsb/loss_lib.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union, Iterable
2 |
3 | import torch
4 | from torch.nn.functional import huber_loss
5 |
6 | from mfg import MFG, MFGPolicy
7 | from mfg.sde import SimpleSDE
8 | from options import Options
9 |
10 | from . import util
11 | from .replay_buffer import Buffer
12 |
13 | DivGZRetType = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
14 | KLLossRetType = Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
15 |
16 |
17 | @torch.jit.script
18 | def rev_cumsum(x: torch.Tensor, dim: int) -> torch.Tensor:
19 | x = torch.flip(x, dims=[dim])
20 | x = torch.cumsum(x, dim=dim)
21 | return torch.flip(x, dims=[dim])
22 |
23 |
24 | def sample_e(opt: Options, x: torch.Tensor) -> torch.Tensor:
25 | if opt.noise_type == "gaussian":
26 | return torch.randn_like(x)
27 | elif opt.noise_type == "rademacher":
28 | return torch.randint_like(x, low=0, high=2) * 2 - 1  # Rademacher noise, in x's dtype
29 | else:
30 | raise ValueError(f"Unsupported noise type {opt.noise_type}!")
31 |
32 |
33 | def compute_div_gz(
34 | opt: Options, mfg: MFG, ts: torch.Tensor, xs: torch.Tensor, policy: MFGPolicy, return_zs: bool = False
35 | ) -> DivGZRetType:
36 | b, T, nx = xs.shape
37 | assert ts.shape == (T,)
38 |
39 | if mfg.uses_state_dependent_drift():
40 | f = mfg.f(xs, ts, policy.direction)
41 | assert f.shape == (b, T, nx)
42 | f = util.flatten_dim01(f)
43 | else:
44 | f = 0
45 |
46 | ts = ts.repeat(xs.shape[0])
47 | xs = util.flatten_dim01(xs)
48 | zs = policy(xs, ts)
49 |
50 | g_ts = mfg.g(ts)
51 | g_ts = g_ts[:, None]
52 | gzs = g_ts * zs - f
53 |
54 | e = sample_e(opt, xs)
55 | e_dzdx = torch.autograd.grad(gzs, xs, e, create_graph=True)[0]
56 | div_gz = e_dzdx * e
57 |
58 | return [div_gz, zs] if return_zs else div_gz
59 |
60 |
61 | def compute_kl(
62 | opt: Options,
63 | mfg: MFG,
64 | ts: torch.Tensor,
65 | xs: torch.Tensor,
66 | zs_impt: torch.Tensor,
67 | policy: MFGPolicy,
68 | return_zs: bool = False,
69 | ) -> DivGZRetType:
70 | b, T, nx = xs.shape
71 | assert ts.shape == (T,)
72 |
73 | zs_impt = util.flatten_dim01(zs_impt)
74 | assert zs_impt.shape == (b * T, nx)
75 |
76 | with torch.enable_grad():
77 | _xs = xs.detach()
78 | xs = _xs.requires_grad_(True)
79 |
80 | div_gz, zs = compute_div_gz(opt, mfg, ts, xs, policy, return_zs=True)
81 | assert div_gz.shape == zs.shape == (b * T, nx)
82 |
83 | # (b * T, xdim)
84 | kl = zs * (0.5 * zs + zs_impt) + div_gz
85 | assert kl.shape == (b * T, nx)
86 |
87 | return [kl, zs] if return_zs else kl
88 |
89 |
90 | def compute_norm_loss(
91 | norm: str, predict: torch.Tensor, label: torch.Tensor, batch_x: int, dt: float, delta: float = 1.0
92 | ) -> torch.Tensor:
93 | assert norm in ["l1", "l2", "huber"]
94 | assert predict.shape == label.shape
95 |
96 | if norm == "l1":
97 | return 0.5 * ((predict - label).abs() * dt).sum() / batch_x
98 | elif norm == "l2":
99 | return 0.5 * ((predict - label) ** 2 * dt).sum() / batch_x
100 | elif norm == "huber":
101 | return huber_loss(predict, label, reduction="sum", delta=delta) * dt / batch_x
102 |
103 |
104 | def compute_kl_loss(
105 | opt: Options,
106 | dim01: Iterable[int],
107 | mfg: MFG,
108 | samp_direction: str,
109 | ts: torch.Tensor,
110 | xs: torch.Tensor,
111 | zs_impt: torch.Tensor,
112 | policy: MFGPolicy,
113 | return_all: bool = False,
114 | ) -> KLLossRetType:
115 |
116 | kl, zs = compute_kl(opt, mfg, ts, xs, zs_impt, policy, return_zs=True)
117 | zs_impt = util.flatten_dim01(zs_impt)
118 | assert kl.shape == zs.shape == zs_impt.shape
119 |
120 | _, x_init, p_init, _, x_term, p_term = mfg.get_init_term_bound(xs, samp_direction)
121 |
122 | # computationally the same as the KL, but with a better interpretation
123 | bsde_y_yhat = util.unflatten_dim01((kl + 0.5 * zs_impt ** 2).sum(dim=-1) * mfg.dt, dim01).sum(
124 | dim=1
125 | ) # (batch_x, len_t) --> (batch_x)
126 | loss_kl = (p_init.log_prob(x_init) + bsde_y_yhat - p_term.log_prob(x_term)).mean()
127 |
128 | return [loss_kl, zs, kl, bsde_y_yhat] if return_all else loss_kl
129 |
130 |
131 | def compute_bsde_td_loss_multistep(
132 | mfg: MFG,
133 | samp_direction: str,
134 | ts: torch.Tensor,
135 | xs: torch.Tensor,
136 | zs: torch.Tensor,
137 | dw: torch.Tensor,
138 | kl: torch.Tensor,
139 | policy: MFGPolicy,
140 | policy_impt: MFGPolicy,
141 | xs_all: torch.Tensor,
142 | norm: str = "huber",
143 | ) -> torch.Tensor:
144 | # xs, xs_all: (b, T, nx)
145 | # zs: (b * T, nx)
146 | # dw: (b, T, nx)
147 | # kl: (b * T, nx)
148 | b, T, nx = xs.shape
149 | ts = ts.repeat(xs.shape[0])
150 | xs = util.flatten_dim01(xs)
151 | dw = util.flatten_dim01(dw)
152 | zs = zs.reshape(*dw.shape)
153 |
154 | dim01 = [b, T]
155 |
156 | # (b * T, )
157 | value_ema = policy.compute_value(xs, ts, use_ema=True)
158 | # ( b * T, )
159 | value_impt_ema = policy_impt.compute_value(xs, ts, use_ema=True)
160 |
161 | log_p = value_ema + value_impt_ema
162 | state_cost = mfg.compute_state_cost(xs, ts, xs_all, log_p)
163 |
164 | # (b * T, )
165 | bsde2_1step = (kl.sum(dim=-1) - state_cost) * mfg.dt + (zs * dw).sum(dim=-1)
166 |
167 | # (b, T)
168 | bsde2_1step = util.unflatten_dim01(bsde2_1step, dim01)
169 |
170 | # (b * T, )
171 | value = policy.compute_value(xs, ts)
172 |
173 | # (b, T, nx)
174 | xs = util.unflatten_dim01(xs, dim01)
175 |
176 | t_bnd, x_bnd, p_bnd = mfg.get_bound(xs, samp_direction, "init")
177 | if samp_direction == "forward":
178 | with torch.no_grad():
179 | # (b, )
180 | value_impt_bnd = util.unflatten_dim01(value_impt_ema, dim01)[:, 0]
181 | # (b, )
182 | target_bnd = p_bnd.log_prob(x_bnd) - value_impt_bnd
183 |
184 |         # (b, T)
185 | bsde2_cumstep = torch.cumsum(bsde2_1step, dim=1)
186 | # (b, T) -> (b, T - 1)
187 | target = (target_bnd[:, None] + bsde2_cumstep)[:, :-1]
188 | # (b, T) -> (b, T - 1)
189 | predict = util.unflatten_dim01(value, dim01)[:, 1:]
190 | else:
191 | with torch.no_grad():
192 | # (b, )
193 | value_impt_bnd = util.unflatten_dim01(value_impt_ema, dim01)[:, -1]
194 | # (b, )
195 | target_bnd = p_bnd.log_prob(x_bnd) - value_impt_bnd
196 |
197 | # Cumulative sum from T to 0.
198 | bsde2_cumstep = rev_cumsum(bsde2_1step, dim=1)
199 | target = (target_bnd[:, None] + bsde2_cumstep)[:, 1:]
200 | predict = util.unflatten_dim01(value, dim01)[:, :-1]
201 |
202 | label = target.detach()
203 | return compute_norm_loss(norm, predict, label, b, mfg.dt, delta=2.0)
204 |
205 |
206 | def compute_bsde_td_loss_singlestep(
207 | mfg: MFG,
208 | samp_direction: str,
209 | ts: torch.Tensor,
210 | xs: torch.Tensor,
211 | zs: torch.Tensor,
212 | dw: torch.Tensor,
213 | kl: torch.Tensor,
214 | policy: MFGPolicy,
215 | policy_impt: MFGPolicy,
216 | xs_all: torch.Tensor,
217 | norm: str = "l2",
218 | ) -> torch.Tensor:
219 | b, T, nx = xs.shape
220 | assert samp_direction in ["forward", "backward"]
221 |
222 |     # (1) flatten all inputs from (b, T, ...) to (b*T, ...)
223 | ts = ts.repeat(xs.shape[0])
224 | xs = util.flatten_dim01(xs)
225 | dw = util.flatten_dim01(dw)
226 | zs = zs.reshape(*dw.shape)
227 |
228 | value = policy.compute_value(xs, ts)
229 | value_ema = policy.compute_value(xs, ts, use_ema=True)
230 | value_impt_ema = policy_impt.compute_value(xs, ts, use_ema=True)
231 |
232 | # (2) compute state cost (i.e., F in Eq 14)
233 | log_p = value_ema + value_impt_ema
234 | state_cost = mfg.compute_state_cost(xs, ts, xs_all, log_p)
235 |
236 | # (3) construct δY, predicted Y, target Y (from ema)
237 | bsde_1step = (kl.sum(dim=-1) - state_cost) * mfg.dt + (zs * dw).sum(dim=-1)
238 | bsde_predict = value
239 | bsde_target = value_ema
240 |
241 | assert bsde_predict.shape == bsde_1step.shape == bsde_target.shape
242 |
243 | bsde_predict = util.unflatten_dim01(bsde_predict, [b, T])
244 | bsde_1step = util.unflatten_dim01(bsde_1step, [b, T])
245 | bsde_target = util.unflatten_dim01(bsde_target, [b, T])
246 |
247 | # (4) compute TD prediction and target in Eq 14
248 | # [forward] dYhat_t: predict=Yhat_{t+1}, label=Yhat_t + [...]
249 |     # [backward] dY_s: predict=Y_{s+1}, label=Y_s + [...]
250 | if samp_direction == "forward":
251 | label = (bsde_target + bsde_1step)[:, :-1].detach()
252 | predict = bsde_predict[:, 1:]
253 | elif samp_direction == "backward":
254 | label = (bsde_target + bsde_1step)[:, 1:].detach()
255 | predict = bsde_predict[:, :-1]
256 | else:
257 | raise ValueError(f"samp_direction should be either forward or backward, got {samp_direction}")
258 |
259 | return compute_norm_loss(norm, predict, label, b, mfg.dt)
260 |
261 |
262 | def compute_bsde_td_loss(
263 | opt: Options,
264 | mfg: MFG,
265 | samp_direction: str,
266 | ts: torch.Tensor,
267 | xs: torch.Tensor,
268 | zs: torch.Tensor,
269 | dw: torch.Tensor,
270 | kl: torch.Tensor,
271 | policy: MFGPolicy,
272 | policy_impt: MFGPolicy,
273 | xs_all: torch.Tensor,
274 | ) -> torch.Tensor:
275 | if opt.multistep_td:
276 | bsde_td_loss = compute_bsde_td_loss_multistep
277 | else:
278 | bsde_td_loss = compute_bsde_td_loss_singlestep
279 | return bsde_td_loss(mfg, samp_direction, ts, xs, zs, dw, kl, policy, policy_impt, xs_all)
280 |
281 |
282 | def compute_bsde_td_loss_from_buffer(
283 |     opt: Options, mfg: MFG, buffer: Buffer, ts: torch.Tensor, policy: MFGPolicy, policy_impt: MFGPolicy, xs_all: torch.Tensor
284 | ) -> torch.Tensor:
285 | samp_direction = policy_impt.direction
286 | assert policy.direction != buffer.direction
287 | assert policy_impt.direction == buffer.direction
288 |
289 | batch_x = opt.rb_bs_x
290 | xs, zs_impt, dw = buffer.sample_traj(batch_x)
291 |
292 | assert xs.shape == zs_impt.shape == (batch_x, opt.interval, opt.x_dim)
293 |
294 | kl, zs = compute_kl(opt, mfg, ts, xs, zs_impt, policy, return_zs=True)
295 | assert kl.shape == zs.shape and kl.shape[0] == (batch_x * opt.interval)
296 |
297 | return compute_bsde_td_loss(opt, mfg, samp_direction, ts, xs, zs, dw, kl, policy, policy_impt, xs_all)
298 |
299 |
300 | def compute_boundary_loss(
301 | opt: Options, mfg: MFG, ts: torch.Tensor, xs: torch.Tensor, policy2: MFGPolicy, policy: MFGPolicy, norm: str = "l2"
302 | ) -> torch.Tensor:
303 |
304 | t_bound, x_bound, p_bound = mfg.get_bound(xs.detach(), policy.direction, "term")
305 |
306 | batch_x, nx = x_bound.shape
307 |
308 | t_bound = t_bound.repeat(batch_x)
309 |
310 | value2_ema = policy2.compute_value(x_bound, t_bound, use_ema=True)
311 | label = (p_bound.log_prob(x_bound) - value2_ema).detach()
312 | predict = policy.compute_value(x_bound, t_bound)
313 |
314 | return compute_norm_loss(norm, predict, label, batch_x, mfg.dt)
315 |
316 |
317 | def compute_grad_loss(
318 | opt: Options, ts: torch.Tensor, xs: torch.Tensor, dyn: SimpleSDE, policy: MFGPolicy, norm: str
319 | ) -> torch.Tensor:
320 | batch_x = xs.shape[0]
321 | ts = ts.repeat(batch_x)
322 | xs = util.flatten_dim01(xs)
323 |
324 | predict = policy(xs, ts)
325 | label = dyn.std * policy.compute_value_grad(xs, ts).detach()
326 |
327 | return compute_norm_loss(norm, predict, label, batch_x, dyn.dt, delta=1.0)
328 |
--------------------------------------------------------------------------------
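
A minimal standalone check of the norm-loss arithmetic in compute_norm_loss above (a sketch, not part of the repository; the shapes, dt, and delta values are made up):

    import torch
    from torch.nn.functional import huber_loss

    # Toy (batch, T-1) predict/label pair, as produced by the TD losses above.
    predict, label = torch.randn(4, 9), torch.randn(4, 9)
    batch_x, dt = 4, 0.01

    l2 = 0.5 * ((predict - label) ** 2 * dt).sum() / batch_x
    hub = huber_loss(predict, label, reduction="sum", delta=1.0) * dt / batch_x
    # Within |predict - label| <= delta, the Huber term matches the 0.5-scaled
    # squared error and grows only linearly outside it, which is presumably why
    # the multistep TD loss uses it with the wider delta=2.0.
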
/deepgsb/replay_buffer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | from easydict import EasyDict as edict
5 |
6 | from options import Options
7 |
8 |
9 | class Buffer:
10 | def __init__(self, opt: Options, direction: str):
11 | self.opt = opt
12 | self.max_samples = opt.buffer_size
13 | self.direction = direction
14 |
15 | self.nx = opt.x_dim
16 |
17 | self.it = 0
18 | self.n_samples = 0
19 | self.samples = torch.empty(self.max_samples, self.opt.interval, 3 * self.nx, device="cpu")
20 |
21 | def __len__(self) -> int:
22 | return self.n_samples
23 |
24 | def append(self, datas: edict) -> None:
25 | xs, zs, dws = datas.xs, datas.zs, datas.ws
26 | assert xs.shape == zs.shape == dws.shape
27 |
28 | it, max_samples, batch = self.it, self.max_samples, xs.shape[0]
29 | sample = torch.cat([xs, zs, dws], dim=-1).detach().cpu()
30 | self.samples[it : it + batch] = sample[0 : min(batch, max_samples - it), ...]
31 | if batch > max_samples - it:
32 | _it = batch - (max_samples - it)
33 | self.samples[0:_it] = sample[-_it:, ...]
34 | assert ((it + batch) % max_samples) == _it
35 |
36 | self.it = (it + batch) % max_samples
37 | self.n_samples = min(self.n_samples + batch, max_samples)
38 |
39 | def clear(self) -> None:
40 | self.samples = torch.empty_like(self.samples)
41 | self.n_samples = self.it = 0
42 |
43 | def sample_traj(self, batch_x: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
44 | samp_x_idx = torch.randint(self.n_samples, (batch_x,))
45 | samples = self.samples[samp_x_idx].to(self.opt.device)
46 |
47 | xs = samples[..., 0 * self.nx : 1 * self.nx].detach() # (batch_x, T, nx)
48 | zs = samples[..., 1 * self.nx : 2 * self.nx].detach() # (batch_x, T, nx)
49 | ws = samples[..., 2 * self.nx : 3 * self.nx].detach() # (batch_x, T, nx)
50 |
51 | return xs, zs, ws
52 |
--------------------------------------------------------------------------------
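
A standalone sketch of the ring-buffer wraparound in Buffer.append (capacity and batch sizes are made up; a 1-D tensor stands in for the (batch, T, 3*nx) slab):

    import torch

    max_samples, it, batch = 5, 4, 3
    sample = torch.arange(float(batch))      # stand-in for torch.cat([xs, zs, dws], dim=-1)
    buf = torch.full((max_samples,), -1.0)

    buf[it : it + batch] = sample[0 : min(batch, max_samples - it)]  # fills slot 4 only
    if batch > max_samples - it:
        _it = batch - (max_samples - it)     # the remaining 2 samples wrap to the front
        buf[0:_it] = sample[-_it:]
    print(buf)                               # tensor([ 1.,  2., -1., -1.,  0.])
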
/deepgsb/sb_policy.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Tuple, TypeVar
3 |
4 | import torch
5 | from torch_ema import ExponentialMovingAverage
6 |
7 | from mfg import MFGPolicy
8 | from mfg.sde import SimpleSDE
9 | from models import build_opinion_net_policy, build_toy_net_policy
10 | from options import Options
11 |
12 | from . import util
13 |
14 | log = logging.getLogger(__file__)
15 |
16 | _TorchModule = TypeVar("_TorchModule", bound=torch.nn.Module)
17 |
18 |
19 | def build(opt: Options, dyn: SimpleSDE, direction: str) -> MFGPolicy:
20 | log.info("build {} SB model...".format(direction))
21 |
22 | # ------ build SB policy ------
23 | Ynet = _build_net(opt, "Y")
24 |
25 | if opt.sb_param == "critic":
26 | policy = SB_paramY(opt, direction, dyn, Ynet).to(opt.device)
27 | elif opt.sb_param == "actor-critic":
28 | Znet = _build_net(opt, "Z")
29 | policy = SB_paramYZ(opt, direction, dyn, Ynet, Znet).to(opt.device)
30 | else:
31 | raise RuntimeError(f"unknown sb net type {opt.sb_param}")
32 |
33 | log.info("# param in SBPolicy = {}".format(util.count_parameters(policy)))
34 |
35 | return policy
36 |
37 |
38 | def _build_net(opt: Options, YorZ: str) -> torch.nn.Module:
39 | assert YorZ in ["Y", "Z"]
40 |
41 | if opt.policy_net == "toy":
42 | net = build_toy_net_policy(opt, YorZ)
43 | elif opt.policy_net == "opinion_net":
44 | net = build_opinion_net_policy(opt, YorZ)
45 | else:
46 |         raise RuntimeError(f"unknown policy net {opt.policy_net}")
47 | return net
48 |
49 |
50 | def _freeze(net: _TorchModule) -> _TorchModule:
51 | for p in net.parameters():
52 | p.requires_grad = False
53 | return net
54 |
55 |
56 | def _activate(net: _TorchModule) -> _TorchModule:
57 | for p in net.parameters():
58 | p.requires_grad = True
59 | return net
60 |
61 |
62 | class SchrodingerBridgeModel(MFGPolicy):
63 | def __init__(self, opt: Options, direction: str, dyn: SimpleSDE, use_t_idx: bool = True):
64 | super(SchrodingerBridgeModel, self).__init__(direction, dyn)
65 | self.opt = opt
66 | self.use_t_idx = use_t_idx
67 | self.g = opt.diffusion_std
68 |
69 | def _preprocessing(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
70 | # make sure t.shape = [batch]
71 | t = t.squeeze()
72 | if t.dim() == 0:
73 | t = t.repeat(x.shape[0])
74 | assert t.dim() == 1 and t.shape[0] == x.shape[0]
75 |
76 | if self.use_t_idx:
77 | t = t / self.opt.T * self.opt.interval
78 | return x, t
79 |
80 | def compute_value(self, x: torch.Tensor, t: torch.Tensor, use_ema: bool = False) -> torch.Tensor:
81 | raise NotImplementedError()
82 |
83 | def compute_policy(self, x, t) -> torch.Tensor:
84 | raise NotImplementedError()
85 |
86 | def forward(self, x, t): # set default calling to policy
87 | return self.compute_policy(x, t)
88 |
89 | def freeze(self):
90 | self.eval()
91 | self.zero_grad()
92 |
93 | def activate(self):
94 | self.train()
95 |
96 |
97 | class SB_paramY(SchrodingerBridgeModel):
98 | def __init__(self, opt: Options, direction: str, dyn: SimpleSDE, Ynet: torch.nn.Module):
99 | super(SB_paramY, self).__init__(opt, direction, dyn, use_t_idx=True)
100 | self.Ynet = Ynet
101 | self.emaY = ExponentialMovingAverage(self.Ynet.parameters(), decay=opt.ema)
102 | self.param = "critic"
103 |
104 | def compute_value(self, x: torch.Tensor, t: torch.Tensor, use_ema: bool = False) -> torch.Tensor:
105 | x, t = self._preprocessing(x, t)
106 | if use_ema:
107 | with self.emaY.average_parameters():
108 | return self._standardizeYnet(self.Ynet(x, t).squeeze())
109 | return self._standardizeYnet(self.Ynet(x, t).squeeze())
110 |
111 | def compute_value_grad(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
112 | x, t = self._preprocessing(x, t)
113 | requires_grad = x.requires_grad
114 | with torch.enable_grad():
115 | x.requires_grad_(True)
116 | y = self._standardizeYnet(self.Ynet(x, t))
117 | out = torch.autograd.grad(y.sum(), x, create_graph=self.training)[0]
118 | x.requires_grad_(requires_grad) # restore original setup
119 | if not self.training:
120 | self.zero_grad() # out = out.detach()
121 | return out
122 |
123 | def compute_policy(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
124 | # Z = g * nabla Y
125 | return self.g * self.compute_value_grad(x, t)
126 |
127 | def _standardizeYnet(self, out: torch.Tensor) -> torch.Tensor:
128 | # standardize the Ynet output
129 | return out / self.g
130 |
131 | def freeze(self) -> None:
132 | self.Ynet = _freeze(self.Ynet)
133 | super(SB_paramY, self).freeze()
134 |
135 | def activate(self) -> None:
136 | self.Ynet = _activate(self.Ynet)
137 | super(SB_paramY, self).activate()
138 |
139 | def get_ema(self) -> ExponentialMovingAverage:
140 | return self.emaY
141 |
142 | def update_ema(self) -> None:
143 | self.emaY.update()
144 |
145 | def state_dict(self):
146 | return {
147 | "Ynet": self.Ynet.state_dict(),
148 | "emaY": self.emaY.state_dict(),
149 | }
150 |
151 | def load_state_dict(self, state_dict) -> None:
152 | self.Ynet.load_state_dict(state_dict["Ynet"])
153 | self.emaY.load_state_dict(state_dict["emaY"])
154 |
155 |
156 | class SB_paramYZ(SB_paramY):
157 | def __init__(self, opt: Options, direction: str, dyn: SimpleSDE, Ynet: torch.nn.Module, Znet: torch.nn.Module):
158 | super(SB_paramYZ, self).__init__(opt, direction, dyn, Ynet)
159 | self.Znet = Znet
160 | self.emaZ = ExponentialMovingAverage(self.Znet.parameters(), decay=opt.ema)
161 | self.param = "actor-critic"
162 |
163 | def compute_policy(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
164 | x, t = self._preprocessing(x, t)
165 | return self.Znet(x, t)
166 |
167 | def freeze(self) -> None:
168 | self.Znet = _freeze(self.Znet)
169 | super(SB_paramYZ, self).freeze()
170 |
171 | def activate(self, only_Ynet: bool = False) -> None:
172 | self.Znet = _activate(self.Znet) if not only_Ynet else _freeze(self.Znet)
173 | super(SB_paramYZ, self).activate()
174 |
175 | def get_ema(self) -> ExponentialMovingAverage:
176 | return self.emaZ
177 |
178 | def update_ema(self, only_Ynet: bool = False) -> None:
179 | if not only_Ynet:
180 | self.emaZ.update()
181 | super(SB_paramYZ, self).update_ema()
182 |
183 | def state_dict(self):
184 | sdict = {
185 | "Znet": self.Znet.state_dict(),
186 | "emaZ": self.emaZ.state_dict(),
187 | }
188 | sdict.update(super(SB_paramYZ, self).state_dict())
189 | return sdict
190 |
191 | def load_state_dict(self, state_dict):
192 | self.Znet.load_state_dict(state_dict["Znet"])
193 | self.emaZ.load_state_dict(state_dict["emaZ"])
194 | super(SB_paramYZ, self).load_state_dict(state_dict)
195 |
--------------------------------------------------------------------------------
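
A small sketch of the torch_ema pattern that SB_paramY relies on for its EMA targets (the net and step count here are made up):

    import torch
    from torch_ema import ExponentialMovingAverage

    net = torch.nn.Linear(2, 1)
    ema = ExponentialMovingAverage(net.parameters(), decay=0.99)

    for _ in range(10):                # an optimizer step would update net here
        ema.update()                   # mirrors SB_paramY.update_ema()

    x = torch.randn(4, 2)
    with ema.average_parameters():     # temporarily swaps in the EMA weights,
        y_target = net(x)              # as compute_value(..., use_ema=True) does
    y_online = net(x)                  # online weights are restored on exit
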
/deepgsb/util.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import TYPE_CHECKING, Iterable, Tuple
3 |
4 | import termcolor
5 | import torch
6 | from easydict import EasyDict as edict
7 |
8 | if TYPE_CHECKING:
9 | from deepgsb.deepgsb import DeepGSB
10 | from options import Options
11 |
12 | log = logging.getLogger(__file__)
13 |
14 |
15 | # convert to colored strings
16 | def red(content): return termcolor.colored(str(content),"red",attrs=["bold"])
17 | def green(content): return termcolor.colored(str(content),"green",attrs=["bold"])
18 | def blue(content): return termcolor.colored(str(content),"blue",attrs=["bold"])
19 | def cyan(content): return termcolor.colored(str(content),"cyan",attrs=["bold"])
20 | def yellow(content): return termcolor.colored(str(content),"yellow",attrs=["bold"])
21 | def magenta(content): return termcolor.colored(str(content),"magenta",attrs=["bold"])
22 |
23 | def count_parameters(model: torch.nn.Module):
24 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
25 |
26 | def evaluate_stage(opt, stage):
27 |     """Determine which actions (e.g. snapshot, checkpoint) to run at the
28 |     current stage, based on the corresponding frequencies in opt.
29 | """
30 | match = lambda freq: (freq>0 and stage%freq==0)
31 | return [match(opt.snapshot_freq), match(opt.ckpt_freq)]
32 |
33 | def get_time(sec: float) -> Tuple[int, int, float]:
34 | h = int(sec//3600)
35 | m = int((sec//60)%60)
36 | s = sec%60
37 | return h, m, s
38 |
39 | def flatten_dim01(x: torch.Tensor) -> torch.Tensor:
40 | # (dim0, dim1, *dim2) --> (dim0x1, *dim2)
41 | return x.reshape(-1, *x.shape[2:])
42 |
43 | def unflatten_dim01(x: torch.Tensor, dim01) -> torch.Tensor:
44 | # (dim0x1, *dim2) --> (dim0, dim1, *dim2)
45 | return x.reshape(*dim01, *x.shape[1:])
46 |
47 | def restore_checkpoint(opt: "Options", runner: "DeepGSB", load_name: str) -> None:
48 | assert load_name is not None
49 | log.info("#loading checkpoint {}...".format(load_name))
50 |
51 | full_keys = ['z_f','optimizer_f','z_b','optimizer_b']
52 |
53 | with torch.cuda.device(opt.gpu):
54 | checkpoint = torch.load(load_name,map_location=opt.device)
55 | ckpt_keys=[*checkpoint.keys()]
56 |
57 | for k in ckpt_keys:
58 | thing = getattr(runner,k)
59 | if hasattr(thing, "load_state_dict"):
60 | getattr(runner,k).load_state_dict(checkpoint[k])
61 | else:
62 | setattr(runner, k, checkpoint[k])
63 |
64 | if len(full_keys) != len(ckpt_keys):
65 |         missing_keys = set(full_keys) - set(ckpt_keys)
66 |         extra_keys = set(ckpt_keys) - set(full_keys)
67 |
68 | if len(missing_keys) > 0:
69 |             log.warning("No state was loaded for {}; check that this is expected".format(missing_keys))
70 | else:
71 | log.warning("Loaded extra keys not found in full_keys: {}".format(extra_keys))
72 |
73 | else:
74 | log.info('#successfully loaded all the modules')
75 |
76 | # runner.ema_f.copy_to()
77 | # runner.ema_b.copy_to()
78 | # print(green('#loading form ema shadow parameter for polices'))
79 | log.info("#######summary of checkpoint##########")
80 |
81 | def save_checkpoint(opt: "Options", runner: "DeepGSB", keys: Iterable[str], stage_it: int) -> None:
82 | checkpoint = {}
83 | fn = opt.ckpt_path + "/stage_{0:04}.npz".format(stage_it)
84 | with torch.cuda.device(opt.gpu):
85 | for k in keys:
86 | variable = getattr(runner, k)
87 | if hasattr(variable, "state_dict"):
88 | checkpoint[k] = variable.state_dict()
89 | else:
90 | checkpoint[k] = variable
91 |
92 | torch.save(checkpoint, fn)
93 | print(green("checkpoint saved: {}".format(fn)))
94 |
95 | def get_loss_str(loss) -> str:
96 | if isinstance(loss, edict):
97 | return ' '.join([ f'({key})' + f'{val.item():+2.3f}' for key, val in loss.items()])
98 | else:
99 | return f'{loss.item():+.4f}'
100 |
101 |
102 |
--------------------------------------------------------------------------------
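
A round-trip sketch of flatten_dim01 / unflatten_dim01, the shape convention used throughout the losses ((batch, T, nx) <-> (batch*T, nx)):

    import torch

    x = torch.randn(4, 10, 2)                      # (batch, T, nx)
    flat = x.reshape(-1, *x.shape[2:])             # flatten_dim01 -> (40, 2)
    back = flat.reshape(4, 10, *flat.shape[1:])    # unflatten_dim01 with dim01=[4, 10]
    assert torch.equal(x, back)
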
/git_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pathlib
3 | import subprocess
4 |
5 | log = logging.getLogger(__name__)
6 |
7 |
8 | def git_describe() -> str:
9 | label = subprocess.check_output(["git", "describe", "--always", "HEAD"]).strip().decode("utf-8")
10 | return label
11 |
12 |
13 | def git_clean() -> bool:
14 | git_output = subprocess.check_output(["git", "status", "--porcelain"]).strip().decode("utf-8")
15 | is_clean = len(git_output) == 0
16 |
17 | return is_clean
18 |
19 |
20 | def git_diff() -> str:
21 | diff = subprocess.check_output(["git", "diff", "HEAD"]).strip().decode("utf-8")
22 | if len(diff) == 0 or diff[-1] != "\n":
23 | diff = diff + "\n"
24 | return diff
25 |
26 |
27 | def log_git_info(run_dir: pathlib.Path) -> None:
28 | label = git_describe()
29 | if git_clean():
30 | log.info("Working tree is clean! HEAD is {}".format(label))
31 | return
32 |
33 | diff_path = run_dir / "diff.patch"
34 | log.warning("Continuing with dirty working tree! HEAD is {}".format(label))
35 | log.warning("Saving the results of git diff to {}".format(diff_path))
36 | with open(diff_path, "w") as f:
37 | f.write(git_diff())
38 |
--------------------------------------------------------------------------------
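
A usage sketch for log_git_info (assumes it is run inside a git checkout; the run directory name is made up):

    import pathlib
    from git_utils import log_git_info

    run_dir = pathlib.Path("results/example_run")
    run_dir.mkdir(parents=True, exist_ok=True)
    log_git_info(run_dir)   # logs HEAD; writes run_dir/diff.patch only if the tree is dirty
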
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import logging
4 | import pathlib
5 | import sys
6 |
7 | import colored_traceback.always
8 | import torch
9 |
10 | import options
11 | from git_utils import log_git_info
12 | from deepgsb import DeepGSB
13 | from mfg import MFG
14 |
15 | from rich.console import Console
16 | from rich.logging import RichHandler
17 | from options import Options
18 |
19 | def setup_logger(log_dir: pathlib.Path) -> None:
20 | log_dir.mkdir(exist_ok=True, parents=True)
21 |
22 | log_file = open(log_dir / "log.txt", "w")
23 | file_console = Console(file=log_file, width=150)
24 | logging.basicConfig(
25 | level=logging.INFO,
26 | format="%(message)s",
27 | datefmt="[%X]",
28 | force=True,
29 | handlers=[RichHandler(), RichHandler(console=file_console)],
30 | )
31 |
32 | def run(opt: Options):
33 | log = logging.getLogger(__name__)
34 | log.info("=======================================================")
35 | log.info(" Deep Generalized Schrodinger Bridge ")
36 | log.info("=======================================================")
37 | log.info("Command used:\n{}".format(" ".join(sys.argv)))
38 |
39 | mfg = MFG(opt)
40 | deepgsb = DeepGSB(opt, mfg)
41 | deepgsb.train(opt)
42 |
43 | def main():
44 | print("setting configurations...")
45 | opt = options.set()
46 |
47 | run_dir = pathlib.Path("results") / opt.dir
48 | setup_logger(run_dir)
49 | log_git_info(run_dir)
50 |
51 | if not opt.cpu:
52 | with torch.cuda.device(opt.gpu):
53 | run(opt)
54 | else:
55 | run(opt)
56 |
57 |
58 | if __name__ == "__main__":
59 | main()
60 |
--------------------------------------------------------------------------------
/make_animation.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import logging
4 | import pathlib
5 | import pickle
6 | import sys
7 | import os
8 | import argparse
9 |
10 | import ipdb
11 | import torch
12 | import numpy as np
13 |
14 | from mfg import MFG
15 | from deepgsb import DeepGSB
16 |
17 | from rich.logging import RichHandler
18 |
19 | import matplotlib.pyplot as plt
20 |
21 | import imageio
22 | from mfg.plotting import *
23 |
24 | from mfg.opinion_lib import est_directional_similarity
25 |
26 | from ipdb import set_trace as debug
27 |
28 | def setup_logger() -> None:
29 | logging.basicConfig(
30 | level=logging.INFO,
31 | format="%(message)s",
32 | datefmt="[%X]",
33 | force=True,
34 | handlers=[RichHandler(),],
35 | )
36 |
37 | def restore_ckpt_option(opt):
38 | assert opt.load is not None
39 | ckpt_path = pathlib.Path(opt.load)
40 | assert ckpt_path.exists()
41 |
42 | options_pkl_path = ckpt_path.parent / "options.pkl"
43 | assert options_pkl_path.exists()
44 |
45 | # Load options pkl and overwrite the load.
46 | with open(options_pkl_path, "rb") as f:
47 | ckpt_options = pickle.load(f)
48 | ckpt_options.load = opt.load
49 |
50 | return ckpt_options
51 |
52 | def build_steps(direction, interval, total_steps=100):
53 | steps = np.linspace(0, interval-1, total_steps).astype(int)
54 | if direction == "backward":
55 | steps = np.flip(steps)
56 | return steps
57 |
58 | def get_title(opt, direction):
59 | return {
60 | "GMM": "GMM",
61 | "Stunnel": "S-tunnel",
62 | "Vneck": "V-neck",
63 | "opinion": "Opinion",
64 | "opinion_1k": "Opinion 1k",
65 | }.get(opt.problem_name) + f" ({direction} policy)"
66 |
67 | @torch.no_grad()
68 | def plot_directional_sim(opt, xs, ax) -> None:
69 |
70 | n_est = 5000
71 | directional_sim = est_directional_similarity(xs, n_est)
72 | assert directional_sim.shape == (n_est, )
73 |
74 | directional_sim = to_numpy(directional_sim)
75 |
76 | bins = 15
77 | _, _, patches = ax.hist(directional_sim, bins=bins, )
78 |
79 | colors = plt.cm.coolwarm(np.linspace(1.0, 0.0, bins))
80 |
81 | for c, p in zip(colors, patches):
82 | plt.setp(p, 'facecolor', c)
83 |
84 | ymax = 1000 if opt.x_dim == 2 else 2000
85 | ax.set_ylim(0, ymax)
86 | ax.set_xlim(0, 1)
87 | ax.set_xticks([])
88 | ax.set_yticks([])
89 | ax.set_xticks([], minor=True)
90 | ax.set_yticks([], minor=True)
91 |
92 | @torch.no_grad()
93 | def make_gif(opt, policy_f, policy_b, mfg, gif_name=None, plot_dim=[0,1]):
94 |
95 | file_path = os.path.join(".tmp", opt.group, opt.name)
96 | os.makedirs(file_path, exist_ok=True)
97 |
98 | xs_f, xs_b, xs_f_np, xs_b_np = sample_traj(opt, mfg, mfg.ts, policy_f, policy_b, plot_dim)
99 |
100 | xlims = get_lims(opt)
101 | ylims = get_ylims(opt)
102 |
103 | filenames = []
104 | for xs, xs_np, policy in zip([xs_f, xs_b], [xs_f_np, xs_b_np], [policy_f, policy_b]):
105 |
106 | if "opinion" in opt.problem_name and policy.direction == "backward":
107 | # skip backward opinion due to the mean-field drift
108 | continue
109 |
110 | y_mesher = get_func_mesher(opt, mfg.ts, 200, policy.compute_value) if opt.x_dim == 2 else None
111 |
112 | colors = get_colors(xs_np.shape[1])
113 | title = get_title(opt, policy.direction)
114 | # title = "Polarize (before apply DeepGSB)"
115 | # title = "Depolarize (after apply DeepGSB)"
116 |
117 | steps = build_steps(policy.direction, xs_np.shape[1], total_steps=100)
118 | for step in steps:
119 | # prepare plotting
120 | fig = plt.figure(figsize=(3,3), constrained_layout=True)
121 | ax = fig.subplots(1, 1)
122 |
123 | # plot policy and value
124 | plot_obs(opt, ax, zorder=0)
125 | ax.scatter(xs_np[:,step,0], xs_np[:,step,1], s=1.5, color=colors[step], alpha=0.5, zorder=1)
126 | if y_mesher is not None:
127 | cp = ax.contour(*y_mesher(step), levels=10, cmap="copper", linewidths=1, zorder=2)
128 | ax.clabel(cp, inline=True, fontsize=6)
129 | setup_ax(ax, title, xlims, ylims, title_fontsize=12)
130 |
131 | if "opinion" in opt.problem_name:
132 | axins = ax.inset_axes([0.59, 0.01, 0.4, 0.4])
133 | plot_directional_sim(opt, xs[:,step], axins)
134 | axins.text(
135 | 0.5, 0.9, r"$t$=" + f"{step/xs.shape[1]:0.2f}" + r"$T$",
136 | transform=axins.transAxes, fontsize=7, ha='center', va='center'
137 | )
138 |
139 | # save fig
140 | filename = f"{file_path}/{policy.direction}_{str(step).zfill(3)}.png"
141 | filenames.append(filename)
142 | plt.savefig(filename)
143 | plt.close(fig)
144 |
145 | # build gif
146 | images = list(map(lambda filename: imageio.imread(filename), filenames))
147 | imageio.mimsave(f'{gif_name or opt.problem_name}.gif', images, duration=0.04) # modify the frame duration as needed
148 |
149 | # Remove files
150 | for filename in set(filenames):
151 | os.remove(filename)
152 |
153 | def run(ckpt_options, gif_name=None):
154 | mfg = MFG(ckpt_options)
155 | deepgsb = DeepGSB(ckpt_options, mfg, save_opt=False)
156 | make_gif(ckpt_options, deepgsb.z_f, deepgsb.z_b, mfg, gif_name=gif_name)
157 |
158 |
159 | def main():
160 | parser = argparse.ArgumentParser()
161 | parser.add_argument("--load", type=str)
162 | parser.add_argument("--name", type=str, default=None)
163 | arg = parser.parse_args()
164 |
165 | setup_logger()
166 | log = logging.getLogger(__name__)
167 | log.info("Command used:\n{}".format(" ".join(sys.argv)))
168 |
169 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
170 | ckpt_options = restore_ckpt_option(arg)
171 |
172 | if not ckpt_options.cpu:
173 | with torch.cuda.device(ckpt_options.gpu):
174 | run(ckpt_options, gif_name=arg.name)
175 | else:
176 | run(ckpt_options, gif_name=arg.name)
177 |
178 | if __name__ == "__main__":
179 | with ipdb.launch_ipdb_on_exception():
180 | main()
181 |
--------------------------------------------------------------------------------
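
A quick illustration of build_steps (toy interval and frame count; indices verified by hand):

    import numpy as np

    # 10 frame indices across a 30-step trajectory; the backward policy gets the
    # flipped order so both GIFs play from t=0 to t=T.
    steps = np.linspace(0, 30 - 1, 10).astype(int)
    print(steps)            # [ 0  3  6  9 12 16 19 22 25 29]
    print(np.flip(steps))   # [29 25 22 19 16 12  9  6  3  0]
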
/mfg/__init__.py:
--------------------------------------------------------------------------------
1 | from .mfg import MFG, MFGPolicy
--------------------------------------------------------------------------------
/mfg/constraint.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import NamedTuple, Optional
3 |
4 | import numpy as np
5 | import torch
6 | import torch.distributions as td
7 |
8 | log = logging.getLogger(__file__)
9 |
10 |
11 | class Sampler:
12 | def __init__(self, distribution: td.Distribution, batch_size: int, device: str):
13 | self.distribution = distribution
14 | self.batch_size = batch_size
15 | self.device = device
16 |
17 | def log_prob(self, x: torch.Tensor) -> torch.Tensor:
18 | return self.distribution.log_prob(x)
19 |
20 | def sample(self, batch: Optional[int] = None) -> torch.Tensor:
21 | if batch is None:
22 | batch = self.batch_size
23 | return self.distribution.sample([batch]).to(self.device)
24 |
25 |
26 | class ProblemDists(NamedTuple):
27 | p0: Sampler
28 | pT: Sampler
29 |
30 |
31 | def build_constraint(problem_name: str, batch_size: int, device: str) -> ProblemDists:
32 | log.info("build distributional constraints ...")
33 |
34 | distribution_builder = {
35 | "GMM": gmm_builder,
36 | "Stunnel": stunnel_builder,
37 | "Vneck": vneck_builder,
38 | "opinion": opinion_builder,
39 | "opinion_1k": opinion_1k_builder,
40 | }.get(problem_name)
41 |
42 | return distribution_builder(batch_size, device)
43 |
44 |
45 | def gmm_builder(batch_size: int, device: str) -> ProblemDists:
46 |
47 | # ----- pT -----
48 | radius, num = 16, 8
49 | arc = 2 * np.pi / num
50 | xs = [np.cos(arc * idx) * radius for idx in range(num)]
51 | ys = [np.sin(arc * idx) * radius for idx in range(num)]
52 |
53 | mix = td.Categorical(
54 | torch.ones(
55 | num,
56 | )
57 | )
58 | comp = td.Independent(td.Normal(torch.Tensor([[x, y] for x, y in zip(xs, ys)]), torch.ones(num, 2)), 1)
59 | dist = td.MixtureSameFamily(mix, comp)
60 | pT = Sampler(dist, batch_size, device)
61 |
62 | # ----- p0 -----
63 | dist = td.MultivariateNormal(torch.zeros(2), torch.eye(2))
64 | p0 = Sampler(dist, batch_size, device)
65 |
66 | return ProblemDists(p0, pT)
67 |
68 |
69 | def vneck_builder(batch_size: int, device: str) -> ProblemDists:
70 |
71 | # ----- pT -----
72 | dist = td.MultivariateNormal(torch.Tensor([7, 0]), 0.2 * torch.eye(2))
73 | pT = Sampler(dist, batch_size, device)
74 |
75 | # ----- p0 -----
76 | dist = td.MultivariateNormal(-torch.Tensor([7, 0]), 0.2 * torch.eye(2))
77 | p0 = Sampler(dist, batch_size, device)
78 |
79 | return ProblemDists(p0, pT)
80 |
81 |
82 | def stunnel_builder(batch_size: int, device: str) -> ProblemDists:
83 |
84 | # ----- pT -----
85 | dist = td.MultivariateNormal(torch.Tensor([11, 1]), 0.5 * torch.eye(2))
86 | pT = Sampler(dist, batch_size, device)
87 |
88 | # ----- p0 -----
89 | dist = td.MultivariateNormal(-torch.Tensor([11, 1]), 0.5 * torch.eye(2))
90 | p0 = Sampler(dist, batch_size, device)
91 |
92 | return ProblemDists(p0, pT)
93 |
94 |
95 | def opinion_builder(batch_size: int, device: str) -> ProblemDists:
96 |
97 | p0_std = 0.25
98 | pT_std = 3.0
99 |
100 | # ----- p0 -----
101 | mu0 = torch.zeros(2)
102 | covar0 = p0_std * torch.eye(2)
103 |
104 | # Start with kind-of polarized opinions.
105 | covar0[0, 0] = 0.5
106 |
107 | # ----- pT -----
108 | muT = torch.zeros(2)
109 |     # Want to finish with more homogeneous opinions.
110 | covarT = pT_std * torch.eye(2)
111 |
112 | dist = td.MultivariateNormal(muT, covarT)
113 | pT = Sampler(dist, batch_size, device)
114 |
115 | dist = td.MultivariateNormal(mu0, covar0)
116 | p0 = Sampler(dist, batch_size, device)
117 |
118 | return ProblemDists(p0, pT)
119 |
120 |
121 | def opinion_1k_builder(batch_size: int, device: str) -> ProblemDists:
122 |
123 | p0_std = 0.25
124 | pT_std = 3.0
125 |
126 | # ----- p0 -----
127 | mu0 = torch.zeros(1000)
128 | covar0 = p0_std * torch.eye(1000)
129 |
130 | # Start with kind-of polarized opinions.
131 | covar0[0, 0] = 4.0
132 |
133 | # ----- pT -----
134 | muT = torch.zeros(1000)
135 |     # Want to finish with more homogeneous opinions.
136 | covarT = pT_std * torch.eye(1000)
137 |
138 | dist = td.MultivariateNormal(muT, covarT)
139 | pT = Sampler(dist, batch_size, device)
140 |
141 | dist = td.MultivariateNormal(mu0, covar0)
142 | p0 = Sampler(dist, batch_size, device)
143 |
144 | return ProblemDists(p0, pT)
145 |
--------------------------------------------------------------------------------
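
A condensed sketch of the boundary distributions built by gmm_builder: eight unit-variance Gaussians on a radius-16 ring (the batch size here is made up):

    import math
    import torch
    import torch.distributions as td

    num, radius = 8, 16
    angles = torch.arange(num) * (2 * math.pi / num)
    means = torch.stack([radius * torch.cos(angles), radius * torch.sin(angles)], dim=1)

    mix = td.Categorical(torch.ones(num))
    comp = td.Independent(td.Normal(means, torch.ones(num, 2)), 1)
    dist = td.MixtureSameFamily(mix, comp)

    x = dist.sample([512])   # (512, 2), like Sampler.sample(batch=512)
    lp = dist.log_prob(x)    # (512,), as used by the boundary losses
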
/mfg/mfg.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Callable, Optional, Tuple
3 |
4 | import torch
5 |
6 | from options import Options
7 |
8 | from .constraint import Sampler, build_constraint
9 | from .plotting import snapshot
10 | from .sde import SampledTraj, build_sde
11 | from .state_cost import build_state_cost_fn
12 |
13 | log = logging.getLogger(__file__)
14 |
15 | Bound = Tuple[torch.Tensor, torch.Tensor, Sampler]
16 | InitTermBound = Tuple[torch.Tensor, torch.Tensor, Sampler, torch.Tensor, torch.Tensor, Sampler]
17 |
18 | # if the state cost for this problem requires logp.
19 | logp_list = [
20 | "Vneck",
21 | "opinion",
22 | "opinion_1k",
23 | ]
24 |
25 | # if the state cost for this problem requires xs_all.
26 | xs_all_list = ["GMM", "Stunnel", "Vneck"]
27 |
28 | # if the uncontrolled drift for this problem uses a mean-field term.
29 | mf_drift_list = ["opinion", "opinion_1k"]
30 |
31 | # if the uncontrolled drift depends on state
32 | state_dependent_drift_list = ["opinion", "opinion_1k"]
33 |
34 |
35 | def get_bound_index(direction: str, bound: str) -> int:
36 | assert direction in ["forward", "backward"]
37 | assert bound in ["init", "term"]
38 |
39 | if direction == "forward" and bound == "init":
40 | return 0
41 | elif direction == "forward" and bound == "term":
42 | return -1
43 | elif direction == "backward" and bound == "init":
44 | return -1
45 | elif direction == "backward" and bound == "term":
46 | return 0
47 |
48 |
49 | class MFGPolicy(torch.nn.Module):
50 | def __init__(self, direction, dyn):
51 | super(MFGPolicy, self).__init__()
52 | self.direction = direction
53 | self.dyn = dyn
54 |
55 |
56 | class MFG:
57 | def __init__(self, opt: Options):
58 |
59 | self.opt = opt
60 | self.problem_name = opt.problem_name
61 |
62 | self.ts = torch.linspace(opt.t0, opt.T, opt.interval)
63 | self.mf_coeff = opt.MF_cost
64 |
65 | self.p0, self.pT = build_constraint(opt.problem_name, opt.samp_bs, opt.device)
66 |         self.state_cost_fn, self.obstacle_cost_fn = build_state_cost_fn(opt.problem_name, return_obs_cost_fn=True)
67 |
68 | self.pbound = [self.p0, self.pT]
69 |
70 | self.sde = build_sde(opt, self.p0, self.pT)
71 |
72 | @property
73 | def dt(self) -> float:
74 | return self.sde.dt
75 |
76 | def f(self, x: torch.Tensor, t: torch.Tensor, direction: str) -> torch.Tensor:
77 | return self.sde.f(x, t, direction)
78 |
79 | def g(self, t: torch.Tensor) -> torch.Tensor:
80 | return self.sde.g(t)
81 |
82 | def uses_logp(self) -> bool:
83 | """Returns true if the state cost for this problem requires logp."""
84 | return self.problem_name in logp_list
85 |
86 | def uses_xs_all(self) -> bool:
87 | """Returns true if the state cost for this problem requires xs_all."""
88 | return self.problem_name in xs_all_list
89 |
90 | def uses_mf_drift(self) -> bool:
91 | """Returns true if the uncontrolled drift for this problem involves mean field."""
92 | return self.problem_name in mf_drift_list
93 |
94 | def uses_state_dependent_drift(self) -> bool:
95 | """Returns true if the uncontrolled drift for this problem depends on state x."""
96 | return self.problem_name in state_dependent_drift_list
97 |
98 | def initialize_mf_drift(self, policy: MFGPolicy) -> None:
99 | self.sde.initialize_mf_drift(self.ts, policy)
100 |
101 | def sample_traj(self, policy: MFGPolicy, **kwargs) -> SampledTraj:
102 | return self.sde.sample_traj(self.ts, policy, **kwargs)
103 |
104 | def compute_state_cost(
105 | self, xs: torch.Tensor, ts: torch.Tensor, xs_all: torch.Tensor, log_p: torch.Tensor
106 | ) -> torch.Tensor:
107 | s_cost, mf_cost = self.state_cost_fn(xs, ts, log_p, xs_all)
108 | return s_cost + self.mf_coeff * mf_cost
109 |
110 | def get_bound(self, xs: torch.Tensor, direction: str, bound: str) -> Bound:
111 | t_index = get_bound_index(direction, bound)
112 |
113 | b, T, nx = xs.shape
114 | assert len(self.ts) == T and nx == self.opt.x_dim
115 |
116 | return self.ts[t_index], xs[:, t_index, ...], self.pbound[t_index]
117 |
118 | def get_init_term_bound(self, xs: torch.Tensor, direction: str) -> InitTermBound:
119 | return (*self.get_bound(xs, direction, "init"), *self.get_bound(xs, direction, "term"))
120 |
121 | def save_snapshot(self, policy_f, policy_b, stage: int) -> None:
122 | plot_logp = self.opt.x_dim == 2
123 | snapshot(self.opt, policy_f, policy_b, self, stage, plot_logp=plot_logp)
124 |
--------------------------------------------------------------------------------
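
The boundary-index convention in get_bound_index, spelled out (a sketch; "init" is where sampling starts, so a backward trajectory's init lives at the last time index):

    from mfg.mfg import get_bound_index

    for direction in ["forward", "backward"]:
        for bound in ["init", "term"]:
            print(direction, bound, get_bound_index(direction, bound))
    # forward  init  0
    # forward  term -1
    # backward init -1
    # backward term  0
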
/mfg/opinion_lib.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from typing import Optional
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from options import Options
9 |
10 | log = logging.getLogger(__file__)
11 |
12 |
13 | @torch.no_grad()
14 | @torch.jit.script
15 | def est_directional_similarity(xs: torch.Tensor, n_est: int = 1000) -> torch.Tensor:
16 | """xs: (batch, nx). Returns (n_est, ) between 0 and 1."""
17 | # xs: (batch, nx)
18 | batch, nx = xs.shape
19 |
20 | # Center first.
21 | xs = xs - torch.mean(xs, dim=0, keepdim=True)
22 |
23 | rand_idxs1 = torch.randint(batch, [n_est], dtype=torch.long)
24 | rand_idxs2 = torch.randint(batch, [n_est], dtype=torch.long)
25 |
26 | # (n_est, nx)
27 | xs1 = xs[rand_idxs1]
28 | # (n_est, nx)
29 | xs2 = xs[rand_idxs2]
30 |
31 | # Normalize to unit vector.
32 | xs1 /= torch.linalg.norm(xs1, dim=1, keepdim=True)
33 | xs2 /= torch.linalg.norm(xs2, dim=1, keepdim=True)
34 |
35 | # (n_est, )
36 | cos_angle = torch.sum(xs1 * xs2, dim=1).clip(-1.0, 1.0)
37 | assert cos_angle.shape == (n_est,)
38 |
39 | # Should be in [0, pi).
40 | angle = torch.acos(cos_angle)
41 | assert (0 <= angle).all()
42 | assert (angle <= torch.pi).all()
43 |
44 | D_ij = 1.0 - angle / torch.pi
45 | assert D_ij.shape == (n_est,)
46 |
47 | return D_ij
48 |
49 |
50 | def opinion_thresh(inner: torch.Tensor) -> torch.Tensor:
51 | return 2.0 * (inner > 0) - 1.0
52 |
53 |
54 | @torch.jit.script
55 | def compute_mean_drift_term(mf_x: torch.Tensor, xi: torch.Tensor) -> torch.Tensor:
56 |     """Decompose the polarizing dynamic, Eq (18) in the paper, into 2 parts for faster computation:
57 |         f_polarize(x,p,ξ)
58 |         = E_{y~p}[a(x,y,ξ) * bar_y], where a(x,y,ξ) = sign(<x,ξ>) * sign(<y,ξ>)
59 |           and bar_y = y / |y|^{0.5}
60 |         = sign(<x,ξ>) * E_{y~p}[sign(<y,ξ>) * bar_y], since sign(<x,ξ>) does not depend on y
61 |         = A(x,ξ) * B(p,ξ)
62 |     Hence, bar_f_polarize = bar_A(x,ξ) * bar_B(p,ξ), with each factor normalized by the sqrt of its norm.
63 |     This function computes only bar_B(p,ξ).
64 | """
65 | # mf_x: (b, nx), xi: (nx,)
66 | # output: (nx,)
67 |
68 | b, nx = mf_x.shape
69 | assert xi.shape == (nx,)
70 |
71 | mf_x_norm = torch.linalg.norm(mf_x, dim=-1, keepdim=True)
72 | assert torch.all(mf_x_norm > 0.0)
73 |
74 | normalized_mf_x = mf_x / torch.sqrt(mf_x_norm)
75 | assert normalized_mf_x.shape == (b, nx)
76 |
77 | # Compute the mean drift term: 1/J sum_j a(y_j) y_j / sqrt(| y_j |).
78 | mf_agree_j = opinion_thresh(torch.sum(mf_x * xi, dim=-1, keepdim=True))
79 | assert mf_agree_j.shape == (b, 1)
80 |
81 | mean_drift_term = torch.mean(mf_agree_j * normalized_mf_x, dim=0)
82 | assert mean_drift_term.shape == (nx,)
83 |
84 | mean_drift_term_norm = torch.linalg.norm(mean_drift_term, dim=-1, keepdim=True)
85 | mean_drift_term = mean_drift_term / torch.sqrt(mean_drift_term_norm)
86 | assert mean_drift_term.shape == (nx,)
87 |
88 | return mean_drift_term
89 |
90 |
91 | @torch.jit.script
92 | def opinion_f(x: torch.Tensor, mf_drift: torch.Tensor, xi: torch.Tensor) -> torch.Tensor:
93 | """This function computes the polarize dynamic in Eq (18) by
94 | bar_f_polarize(x,p,ξ) = bar_A(x,ξ) * bar_B(p,ξ)
95 | where bar_B(p,ξ) is pre-computed in func compute_mean_drift_term and passed in as mf_drift.
96 | """
97 | # x: (b, T, nx), mf_drift: (T, nx), xi: (T, nx)
98 | # out: (b, T, nx)
99 |
100 | b, T, nx = x.shape
101 | assert xi.shape == mf_drift.shape == (T, nx)
102 |
103 | agree_i = opinion_thresh(torch.sum(x * xi, dim=-1, keepdim=True))
104 | # Make sure we are not dividing by 0.
105 | agree_i[agree_i == 0] = 1.0
106 |
107 | abs_sqrt_agree_i = torch.sqrt(torch.abs(agree_i))
108 | assert torch.all(abs_sqrt_agree_i > 0.0)
109 |
110 | norm_agree_i = agree_i / abs_sqrt_agree_i
111 | assert norm_agree_i.shape == (b, T, 1)
112 |
113 | f = norm_agree_i * mf_drift
114 | assert f.shape == (b, T, nx)
115 |
116 | return f
117 |
118 |
119 | def build_f_mul(opt: Options) -> torch.Tensor:
120 |     # Set f_mul with a heuristic so the drift does not diverge exponentially fast
121 |     # and ruin normalization: the more polarized the opinions, the faster they grow.
122 | ts = torch.linspace(opt.t0, opt.T, opt.interval)
123 | coeff = 8.0
124 | f_mul = torch.clip(1.0 - torch.exp(coeff * (ts - opt.T)) + 1e-5, min=1e-4, max=1.0)
125 | f_mul = f_mul ** 5.0
126 | return f_mul
127 |
128 |
129 | def build_xis(opt: Options) -> torch.Tensor:
130 | # Generate random unit vectors.
131 | rng = np.random.default_rng(seed=4078213)
132 | xis = rng.standard_normal([opt.interval, opt.x_dim])
133 |
134 |     # Construct xis that vary somewhat continuously over time, via a Brownian-motion-like update.
135 | xi = xis[0]
136 | bm_xis = [xi]
137 | std = 0.4
138 | for t in range(1, opt.interval):
139 | xi = xi - (2.0 * xi) * 0.01 + std * math.sqrt(0.01) * xis[t]
140 | bm_xis.append(xi)
141 | assert len(bm_xis) == xis.shape[0]
142 |
143 | xis = torch.Tensor(np.stack(bm_xis))
144 | xis /= torch.linalg.norm(xis, dim=-1, keepdim=True)
145 |
146 |     # Log a checksum of xis as a safeguard, in case it ever differs across runs.
147 | log.info("USING BM XI! xis.sum(): {}".format(torch.sum(xis)))
148 | return xis
149 |
--------------------------------------------------------------------------------
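
A sanity-check sketch for est_directional_similarity (synthetic data; the variance values are approximate expectations, not repository outputs):

    import torch
    from mfg.opinion_lib import est_directional_similarity

    # Isotropic 2-D opinions: pairwise angles are uniform, so D is ~uniform on [0, 1].
    xs = torch.randn(4096, 2)
    print(est_directional_similarity(xs, n_est=2000).var())   # around 1/12 ~ 0.083

    # Strongly polarized opinions: pairs are near-parallel or near-antiparallel,
    # so D piles up at 0 and 1 and the variance approaches 0.25.
    xs[:, 0] *= 50.0
    print(est_directional_similarity(xs, n_est=2000).var())
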
/mfg/plotting.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import os
4 |
5 | import torch
6 | import numpy as np
7 |
8 | from . import util
9 | from . import opinion_lib
10 | from .state_cost import gmm_obstacle_cfg, vneck_obstacle_cfg, stunnel_obstacle_cfg
11 |
12 | import matplotlib.pyplot as plt
13 | import matplotlib.colors as mcol
14 | from matplotlib.patches import Circle, Ellipse
15 |
16 | from ipdb import set_trace as debug
17 |
18 |
19 | def to_numpy(t):
20 | return t.detach().cpu().numpy()
21 |
22 | def get_lims(opt):
23 | return {
24 | "GMM": [-16.25, 16.25],
25 | "Stunnel": [-15, 15],
26 | "Vneck": [-10, 10],
27 | "opinion": [-10, 10],
28 | "opinion_1k": [-10, 10],
29 | }.get(opt.problem_name)
30 |
31 | def get_ylims(opt):
32 | return {
33 | "GMM": [-16.25, 16.25],
34 | "Stunnel": [-10, 10],
35 | "Vneck": [-5, 5],
36 | "opinion": [-10, 10],
37 | "opinion_1k": [-10, 10],
38 | }.get(opt.problem_name)
39 |
40 | def get_colors(n_snapshot):
41 | # assert n_snapshot % 2 == 1
42 | cm1 = mcol.LinearSegmentedColormap.from_list("MyCmapName",["b","r"])
43 | colors = cm1(np.linspace(0.0, 1.0, n_snapshot))
44 | return colors
45 |
46 | def create_mesh(opt, n_grid, lims, convert_to_numpy=True):
47 | import warnings
48 |
49 | _x = torch.linspace(*(lims+[n_grid]))
50 |
51 |     # Suppress warning about indexing arg becoming required.
52 |     with warnings.catch_warnings():
53 |         warnings.simplefilter("ignore")  # catch_warnings alone does not silence warnings
54 |         X, Y = torch.meshgrid(_x, _x)
55 | xs = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=1).to(opt.device)
56 | return [to_numpy(X), to_numpy(Y), xs] if convert_to_numpy else [X, Y, xs]
57 |
58 | def get_func_mesher(opt, ts, grid_n, func, out_dim=1):
59 | if func is None: return None
60 | lims = get_lims(opt)
61 | X1, X2, XS = create_mesh(opt, grid_n, lims)
62 | out_shape = [grid_n,grid_n] if out_dim==1 else [grid_n,grid_n,out_dim]
63 |
64 | def mesher(idx):
65 | # print(func, ts[idx])
66 | arg_xs = XS.detach()
67 | arg_ts = ts[idx].repeat(grid_n**2).detach()
68 | fn_out = func(arg_xs, arg_ts)
69 | return X1, X2, to_numpy(fn_out.reshape(*out_shape))
70 |
71 | return mesher
72 |
73 | def plot_obs(opt, ax, scale=1., zorder=0):
74 | if opt.problem_name == 'GMM':
75 | centers, radius = gmm_obstacle_cfg()
76 | for c in centers:
77 | circle = Circle(xy=np.array(c), radius=radius, zorder=zorder)
78 |
79 | ax.add_artist(circle)
80 | circle.set_clip_box(ax.bbox)
81 | circle.set_facecolor("darkgray")
82 | circle.set_edgecolor(None)
83 |
84 | elif opt.problem_name == 'Vneck':
85 | c_sq, coef = vneck_obstacle_cfg()
86 | x = np.linspace(-6,6,100)
87 | y1 = np.sqrt(c_sq + coef * np.square(x))
88 | y2 = np.ones_like(x) * y1[0]
89 |
90 | ax.fill_between(x, y1, y2, color="darkgray", edgecolor=None, zorder=zorder)
91 | ax.fill_between(x, -y1, -y2, color="darkgray", edgecolor=None, zorder=zorder)
92 |
93 | elif opt.problem_name == 'Stunnel':
94 | a, b, cc, centers = stunnel_obstacle_cfg()
95 | for c in centers:
96 | elp = Ellipse(
97 | xy=np.array(c)*scale, width=2*np.sqrt(cc/a)*scale, height=2*np.sqrt(cc/b)*scale, zorder=zorder
98 | )
99 |
100 | ax.add_artist(elp)
101 | elp.set_clip_box(ax.bbox)
102 | elp.set_facecolor("darkgray")
103 | elp.set_edgecolor(None)
104 |
105 | def setup_ax(ax, title, xlims, ylims, title_fontsize=18):
106 | ax.axis('equal')
107 | ax.set_xlim(*xlims)
108 | ax.set_ylim(*ylims)
109 | ax.set_title(title, fontsize=title_fontsize)
110 | ax.set_xticks([])
111 | ax.set_yticks([])
112 | ax.set_xticks([], minor=True)
113 | ax.set_yticks([], minor=True)
114 |
115 | def plot_traj_snapshot(opt, xs, axes, sample_steps, titles, y_mesher=None):
116 |
117 | n_snapshot = len(axes)
118 | assert len(sample_steps) == len(titles) == n_snapshot
119 |
120 | if sample_steps is None:
121 | sample_steps = np.linspace(0, xs.shape[1]-1, n_snapshot).astype(int)
122 |
123 | xlims = get_lims(opt)
124 | ylims = get_ylims(opt)
125 |
126 | colors = get_colors(n_snapshot)
127 |
128 | for ax, step, title, color in zip(axes, sample_steps, titles, colors):
129 | plot_obs(opt, ax, zorder=0)
130 |
131 | ax.scatter(xs[:,step,0],xs[:,step,1], s=1.5, color=color, alpha=0.5, zorder=1)
132 | if y_mesher is not None:
133 | cp = ax.contour(*y_mesher(step), levels=10, cmap="copper", linewidths=1, zorder=2)
134 | ax.clabel(cp, inline=True, fontsize=6)
135 | setup_ax(ax, title, xlims, ylims)
136 |
137 | def plot_directional_sim(opt, ax, stage, xs_term) -> None:
138 | n_est = 5000
139 | directional_sim = opinion_lib.est_directional_similarity(xs_term, n_est)
140 | assert directional_sim.shape == (n_est, )
141 |
142 | directional_sim = to_numpy(directional_sim)
143 |
144 | bins = 100
145 | ax.hist(directional_sim, bins=bins)
146 | ax.set(xlabel="Directional Similarity", title="Stage={:3}".format(stage), xlim=(0., 1.))
147 |
148 |
149 | def get_fig_axes_steps(interval, n_snapshot=5, ax_length_in: float = 4):
150 | n_row, n_col = 1, n_snapshot
151 | figsize = (n_col*ax_length_in, n_row*ax_length_in)
152 |
153 | fig = plt.figure(figsize=figsize, constrained_layout=True)
154 | axes = fig.subplots(n_row, n_col)
155 | steps = np.linspace(0, interval-1, n_snapshot).astype(int)
156 |
157 | return fig, axes, steps
158 |
159 | @torch.no_grad()
160 | def sample_traj(opt, mfg, ts, policy_f, policy_b, plot_dim):
161 | xs_f, _, _, _ = mfg.sde.sample_traj(ts, policy_f)
162 | xs_b, _, _, _ = mfg.sde.sample_traj(ts, policy_b)
163 | util.assert_zero_grads(policy_f)
164 | util.assert_zero_grads(policy_b)
165 |
166 | if opt.x_dim > 2:
167 | xs_f, xs_b = util.proj_pca(xs_f, xs_b, reverse=False)
168 |
169 | xs_f_np, xs_b_np = to_numpy(xs_f[..., plot_dim]), to_numpy(xs_b[..., plot_dim])
170 |
171 | return xs_f, xs_b, xs_f_np, xs_b_np
172 |
173 | @torch.no_grad()
174 | def snapshot(opt, policy_f, policy_b, mfg, stage, plot_logp=False, plot_dim=[0,1]):
175 |
176 | # sample forward & backward trajs
177 | ts = mfg.ts
178 |
179 | xs_f, xs_b, xs_f_np, xs_b_np = sample_traj(opt, mfg, ts, policy_f, policy_b, plot_dim)
180 |
181 | interval = len(ts)
182 |
183 | for xs, policy in zip([xs_f_np, xs_b_np], [policy_f, policy_b]):
184 |
185 | # prepare plotting
186 | titles = [r'$t$ = 0', r'$t$ = 0.25$T$', r'$t$ = 0.50$T$', r'$t$ = 0.75$T$', r'$t = T$']
187 | fig, axes, sample_steps = get_fig_axes_steps(interval, n_snapshot=len(titles))
188 | assert len(titles) == len(axes) == len(sample_steps)
189 |
190 | # plot policy and value
191 | y_mesher = get_func_mesher(opt, ts, 200, policy.compute_value) if opt.x_dim == 2 else None
192 | plot_traj_snapshot(
193 | opt, xs, axes, sample_steps, titles, y_mesher=y_mesher,
194 | )
195 |
196 | plt.savefig(os.path.join('results', opt.dir, policy.direction, f'stage{stage}.pdf'))
197 | plt.close(fig)
198 |
199 | if "opinion" in opt.problem_name:
200 | for xs_term, policy in zip([xs_f[:,-1], xs_b[:,0]], [policy_f, policy_b]):
201 | fig, ax = plt.subplots(figsize=(8, 4), constrained_layout=True)
202 | plot_directional_sim(opt, ax, stage, xs_term)
203 | plt.savefig(os.path.join(
204 | 'results', opt.dir, f"directional_sim_{policy.direction}", f'stage{stage}.pdf'
205 | ))
206 | plt.close(fig)
207 |
208 | if plot_logp:
209 | assert opt.x_dim == 2
210 | def logp_fn(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
211 | y_f = policy_f.compute_value(x, t)
212 | y_b = policy_b.compute_value(x, t)
213 | return y_f + y_b
214 |
215 | # prepare plotting
216 | titles = [r'$t$ = 0', r'$t$ = 0.25$T$', r'$t$ = 0.50$T$', r'$t$ = 0.75$T$', r'$t = T$']
217 | fig, axes, sample_steps = get_fig_axes_steps(interval, n_snapshot=len(titles))
218 | assert len(titles) == len(axes) == len(sample_steps)
219 |
220 | # plot logp
221 | logp_mesher = get_func_mesher(opt, ts, 200, logp_fn)
222 | plot_traj_snapshot(
223 | opt, xs_f_np, axes, sample_steps, titles, y_mesher=logp_mesher,
224 | )
225 |
226 | plt.savefig(os.path.join('results', opt.dir, "logp", f'stage{stage}.pdf'))
227 | plt.close(fig)
228 |
229 | return xs_f, xs_b
230 |
--------------------------------------------------------------------------------
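
A standalone version of the create_mesh / get_func_mesher pipeline, without the Options object (the grid size, limits, and stand-in value function are made up; indexing="ij" assumes torch >= 1.10):

    import torch

    n_grid, lims = 200, [-15.0, 15.0]
    _x = torch.linspace(lims[0], lims[1], n_grid)
    X, Y = torch.meshgrid(_x, _x, indexing="ij")
    xs = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=1)   # (n_grid**2, 2)

    value = (xs ** 2).sum(dim=1)                  # stand-in for policy.compute_value
    Z = value.reshape(n_grid, n_grid).numpy()
    # ax.contour(X.numpy(), Y.numpy(), Z, levels=10) reproduces the snapshot contours
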
/mfg/sde.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 | import math
4 | from typing import Callable, Optional, NamedTuple
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from options import Options
10 |
11 | from . import opinion_lib, util
12 | from .constraint import Sampler
13 |
14 | log = logging.getLogger(__file__)
15 |
16 | DriftFn = Callable[[torch.Tensor], torch.Tensor]
17 |
18 | class SampledTraj(NamedTuple):
19 | xs: torch.Tensor
20 | zs: torch.Tensor
21 | ws: torch.Tensor
22 | x_term: torch.Tensor
23 |
24 | def _assert_increasing(name: str, ts: torch.Tensor) -> None:
25 | assert (ts[1:] > ts[:-1]).all(), "{} must be strictly increasing".format(name)
26 |
27 |
28 | def base_drift_builder(opt: Options) -> Optional[DriftFn]:
29 | if opt.problem_name in ["Vneck", "Stunnel"]:
30 |
31 | def base_drift(x):
32 | b, T, nx = x.shape
33 |             const = torch.tensor([6.0, 0.0], device=x.device)
34 | assert const.shape == (nx,)
35 | return const.repeat(b, T, 1)
36 |
37 | else:
38 | base_drift = None
39 | return base_drift
40 |
41 |
42 | def t_to_idx(t: torch.Tensor, interval: int, T: float) -> torch.Tensor:
43 | return (t / T * (interval - 1)).round().long()
44 |
45 |
46 | class BaseSDE(metaclass=abc.ABCMeta):
47 | def __init__(self, opt: Options, p0: Sampler, pT: Sampler):
48 | self.opt = opt
49 | self.dt = opt.T / opt.interval
50 | self.p0 = p0
51 | self.pT = pT
52 | self.mf_drifts: Optional[torch.Tensor] = None
53 |
54 | @abc.abstractmethod
55 | def _f(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
56 | raise NotImplementedError
57 |
58 | @abc.abstractmethod
59 | def _g(self, t: torch.Tensor) -> torch.Tensor:
60 | raise NotImplementedError
61 |
62 | def f(self, x: torch.Tensor, t: torch.Tensor, direction: str) -> torch.Tensor:
63 | # x: (b, T, nx), t: (T,)
64 | # out: (b, T, nx)
65 | b, T, nx = x.shape
66 | assert t.shape == (T,)
67 |
68 | sign = 1.0 if direction == "forward" else -1.0
69 | _f = self._f(x, t)
70 | assert _f.shape == (b, T, nx)
71 | return sign * _f
72 |
73 | def g(self, t: torch.Tensor) -> torch.Tensor:
74 | return self._g(t)
75 |
76 | def dw(self, x: torch.Tensor, dt: Optional[float] = None) -> torch.Tensor:
77 | dt = self.dt if dt is None else dt
78 | return torch.randn_like(x) * np.sqrt(dt)
79 |
80 | def propagate(
81 | self,
82 | t: torch.Tensor,
83 | x: torch.Tensor,
84 | z: torch.Tensor,
85 | direction: str,
86 | f: Optional[DriftFn] = None,
87 | dw: Optional[torch.Tensor] = None,
88 | dt: Optional[float] = None,
89 | ) -> torch.Tensor:
90 | g = self.g(t)
91 | f = self.f(x, t, direction) if f is None else f
92 | dt = self.dt if dt is None else dt
93 | dw = self.dw(x, dt) if dw is None else dw
94 |
95 | return x + (f + g * z) * dt + g * dw
96 |
97 | @torch.no_grad()
98 | def initialize_mf_drift(self, ts: torch.Tensor, policy) -> None:
99 | raise NotImplementedError
100 |
101 | @torch.no_grad()
102 | def update_mf_drift(self, x: torch.Tensor, t_idx: int) -> None:
103 | raise NotImplementedError
104 |
105 | def sample_traj(self, ts: torch.Tensor, policy, update_mf_drift: bool = False) -> SampledTraj:
106 |
107 | # first we need to know whether we're doing forward or backward sampling
108 | direction = policy.direction
109 | assert direction in ["forward", "backward"]
110 |
111 | # set up ts and init_distribution
112 | _assert_increasing("ts", ts)
113 | init_dist = self.p0 if direction == "forward" else self.pT
114 | ts = ts if direction == "forward" else torch.flip(ts, dims=[0])
115 |
116 | x = init_dist.sample(batch=self.opt.samp_bs)
117 | (b, nx), T = x.shape, len(ts)
118 | assert nx == self.opt.x_dim
119 |
120 | xs = torch.empty((b, T, nx))
121 | zs = torch.empty_like(xs)
122 | ws = torch.empty_like(xs)
123 | if update_mf_drift:
124 | self.mf_drifts = torch.empty(T, nx)
125 |
126 | # don't use tqdm for fbsde since it'll resample every itr
127 | for idx, t in enumerate(ts):
128 | t_idx = idx if direction == "forward" else T - idx - 1
129 | assert t_idx == t_to_idx(t, self.opt.interval, self.opt.T), (t_idx, t)
130 |
131 | if update_mf_drift:
132 | self.update_mf_drift(x, t_idx)
133 |
134 | # f = self.f(x,t,direction)
135 | # handle propagation of single time step
136 | f = self.f(x.unsqueeze(1), t.unsqueeze(0), direction).squeeze(1)
137 | z = policy(x, t)
138 | dw = self.dw(x)
139 | util.assert_zero_grads(policy)
140 |
141 | xs[:, t_idx, ...] = x
142 | zs[:, t_idx, ...] = z
143 | ws[:, t_idx, ...] = dw
144 |
145 | x = self.propagate(t, x, z, direction, f=f, dw=dw)
146 |
147 | x_term = x
148 |
149 | return SampledTraj(xs, zs, ws, x_term)
150 |
151 |
152 | class SimpleSDE(BaseSDE):
153 | def __init__(self, opt: Options, p: Sampler, q: Sampler, base_drift: Optional[DriftFn] = None):
154 | super(SimpleSDE, self).__init__(opt, p, q)
155 | self.std = opt.diffusion_std
156 | self.base_drift = base_drift
157 |
158 | def _f(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
159 | return torch.zeros_like(x) if self.base_drift is None else self.base_drift(x)
160 |
161 | def _g(self, t: torch.Tensor) -> torch.Tensor:
162 | return torch.Tensor([self.std])
163 |
164 |
165 | class OpinionSDE(SimpleSDE):
166 | """modified from the party model:
167 | See Eq (4) in https://www.cs.cornell.edu/home/kleinber/ec21-polarization.pdf
168 | """
169 |
170 | def __init__(self, opt: Options, p: Sampler, q: Sampler):
171 | super(OpinionSDE, self).__init__(opt, p, q)
172 | assert "opinion" in opt.problem_name
173 |
174 | self.f_mul = opinion_lib.build_f_mul(opt)
175 | self.xis = opinion_lib.build_xis(opt)
176 | self.polarize_strength = 1.0 if opt.x_dim == 2 else 6.0
177 | self.mf_drifts = None
178 |
179 | def _f(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
180 |
181 | b, T, nx = x.shape
182 | assert nx == self.opt.x_dim
183 |
184 | idx = t_to_idx(t, self.opt.interval, self.opt.T)
185 |
186 | fmul = self.f_mul[idx].to(x.device).unsqueeze(-1)
187 | xi = self.xis[idx].to(x.device)
188 | mf_drift = self.mf_drifts[idx].to(x.device)
189 | assert fmul.shape == (T, 1)
190 | assert xi.shape == mf_drift.shape == (T, nx)
191 |
192 | f = self.polarize_strength * opinion_lib.opinion_f(x, mf_drift, xi)
193 | assert f.shape == x.shape
194 |
195 | f = fmul * f
196 | assert f.shape == x.shape
197 |
198 | return f
199 |
200 | @torch.no_grad()
201 | def initialize_mf_drift(self, ts: torch.Tensor, policy) -> None:
202 | self.sample_traj(ts, policy, update_mf_drift=True)
203 |
204 | @torch.no_grad()
205 | def update_mf_drift(self, x: torch.Tensor, t_idx: int) -> None:
206 | xi = self.xis[t_idx].to(x.device)
207 | mf_drift = opinion_lib.compute_mean_drift_term(x, xi)
208 | self.mf_drifts[t_idx] = mf_drift.detach().cpu()
209 |
210 |
211 | def build_sde(opt: Options, p: Sampler, q: Sampler) -> SimpleSDE:
212 | log.info("build base sde...")
213 |
214 | if "opinion" in opt.problem_name:
215 | return OpinionSDE(opt, p, q)
216 | else:
217 | base_drift = base_drift_builder(opt)
218 | return SimpleSDE(opt, p, q, base_drift=base_drift)
219 |
--------------------------------------------------------------------------------
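
One Euler-Maruyama step matching BaseSDE.propagate, x' = x + (f + g*z)*dt + g*dw (toy shapes; a zero base drift as in SimpleSDE):

    import torch

    b, nx, dt, g = 8, 2, 0.01, 1.0
    x = torch.randn(b, nx)
    f = torch.zeros_like(x)                   # SimpleSDE with base_drift=None
    z = torch.randn(b, nx)                    # stand-in for the policy output
    dw = torch.randn_like(x) * (dt ** 0.5)    # BaseSDE.dw
    x_next = x + (f + g * z) * dt + g * dw
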
/mfg/state_cost.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | import torch
3 |
4 | StateCostFn = Callable[[torch.Tensor], torch.Tensor]
5 | MFCostFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
6 |
7 | def build_state_cost_fn(problem_name: str, return_obs_cost_fn: bool = False):
8 | obstacle_cost_fn = get_obstacle_cost_fn(problem_name)
9 | mf_cost_fn = get_mf_cost_fn(problem_name)
10 |
11 | def state_cost_fn(xs, ts, logp, xs_all):
12 | obstacle_cost = obstacle_cost_fn(xs)
13 | mf_cost = mf_cost_fn(xs, logp, xs_all)
14 | return obstacle_cost, mf_cost
15 |
16 | return (state_cost_fn, obstacle_cost_fn) if return_obs_cost_fn else state_cost_fn
17 |
18 | def get_obstacle_cost_fn(problem_name: str) -> StateCostFn:
19 | return {
20 | 'GMM': obstacle_cost_fn_gmm,
21 | 'Stunnel': obstacle_cost_fn_stunnel,
22 | 'Vneck': obstacle_cost_fn_vneck,
23 | 'opinion': zero_cost_fn,
24 | 'opinion_1k': zero_cost_fn,
25 | }.get(problem_name)
26 |
27 | def get_mf_cost_fn(problem_name: str) -> MFCostFn:
28 | return {
29 | 'GMM': zero_cost_fn,
30 | 'Stunnel': congestion_cost,
31 | 'Vneck': entropy_cost,
32 | 'opinion': entropy_cost,
33 | 'opinion_1k': entropy_cost,
34 | }.get(problem_name)
35 |
36 | ##########################################################
37 | ################ mean-field cost functions ###############
38 | ##########################################################
39 |
40 | def entropy_cost(xs: torch.Tensor, logp: torch.Tensor, xs_all: torch.Tensor) -> torch.Tensor:
41 | if logp is None:
42 | raise ValueError("Add this problem to logp_list.")
43 |
44 | return logp + 1
45 |
46 | def congestion_cost(xs: torch.Tensor, logp: torch.Tensor, xs_all: torch.Tensor) -> torch.Tensor:
47 | assert xs.ndim == 2
48 | xs = xs.reshape(-1, xs_all.shape[1], *xs.shape[1:])
49 |
50 | assert xs.ndim == 3 # should be (batch_x, batch_t, x_dim)
51 | assert xs.shape[1:] == xs_all.shape[1:]
52 |
53 | dd = xs - xs_all # batch_x, batch_t, xdim
54 | dist = torch.sum(dd * dd, dim=-1) # batch_x, batch_t
55 | out = 2.0 / (dist + 1.0)
56 | cost = out.reshape(-1, *out.shape[2:])
57 | return cost
58 |
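For context on the reshape above: `congestion_cost` receives `xs` flattened over time while `xs_all` keeps the full `(batch, T, x_dim)` population layout, and it returns one cost per flattened state. A quick shape check, assuming the repo root is on `PYTHONPATH` (sizes below are hypothetical):

import torch
from mfg.state_cost import congestion_cost

batch_x, batch_t, x_dim = 8, 10, 2
xs = torch.randn(batch_x * batch_t, x_dim)     # flattened states
xs_all = torch.randn(batch_x, batch_t, x_dim)  # full population batch
cost = congestion_cost(xs, None, xs_all)       # logp is unused here
assert cost.shape == (batch_x * batch_t,)      # one cost per state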
59 | ##########################################################
60 | ################## obstacle cost functions ###############
61 | ##########################################################
62 |
63 | def zero_cost_fn(x: torch.Tensor, *args) -> torch.Tensor:
64 | return torch.zeros(x.shape[:-1], device=x.device)
65 |
66 | def gmm_obstacle_cfg():
67 | centers = [[6,6], [6,-6], [-6,-6]]
68 | radius = 1.5
69 | return centers, radius
70 |
71 | def stunnel_obstacle_cfg():
72 | a, b, c = 20, 1, 90
73 | centers = [[5,6], [-5,-6]]
74 | return a, b, c, centers
75 |
76 | def vneck_obstacle_cfg():
77 | c_sq = 0.36
78 | coef = 5
79 | return c_sq, coef
80 |
81 | @torch.jit.script
82 | def obstacle_cost_fn_gmm(xs: torch.Tensor) -> torch.Tensor:
83 | xs = xs.reshape(-1,xs.shape[-1])
84 |
85 | batch_xt = xs.shape[0]
86 |
87 | centers, radius = gmm_obstacle_cfg()
88 |
89 | obs1 = torch.tensor(centers[0]).repeat((batch_xt,1)).to(xs.device)
90 | obs2 = torch.tensor(centers[1]).repeat((batch_xt,1)).to(xs.device)
91 | obs3 = torch.tensor(centers[2]).repeat((batch_xt,1)).to(xs.device)
92 |
93 | dist1 = torch.norm(xs - obs1, dim=-1)
94 | dist2 = torch.norm(xs - obs2, dim=-1)
95 | dist3 = torch.norm(xs - obs3, dim=-1)
96 |
97 | cost1 = 1500 * (dist1 < radius)
98 | cost2 = 1500 * (dist2 < radius)
99 | cost3 = 1500 * (dist3 < radius)
100 |
101 | return cost1 + cost2 + cost3
102 |
103 | @torch.jit.script
104 | def obstacle_cost_fn_stunnel(xs: torch.Tensor) -> torch.Tensor:
105 |
106 | a, b, c, centers = stunnel_obstacle_cfg()
107 |
108 | _xs = xs.reshape(-1,xs.shape[-1])
109 | x, y = _xs[:,0], _xs[:,1]
110 |
111 | d = a*(x-centers[0][0])**2 + b*(y-centers[0][1])**2
112 | c1 = 1500 * (d < c)
113 |
114 | d = a*(x-centers[1][0])**2 + b*(y-centers[1][1])**2
115 | c2 = 1500 * (d < c)
116 |
117 | return c1+c2
118 |
119 | @torch.jit.script
120 | def obstacle_cost_fn_vneck(xs: torch.Tensor) -> torch.Tensor:
121 | c_sq, coef = vneck_obstacle_cfg()
122 |
123 | xs_sq = torch.square(xs)
124 | d = coef * xs_sq[..., 0] - xs_sq[..., 1]
125 |
126 | return 1500 * (d < -c_sq)
127 |
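Taken together, `build_state_cost_fn` pairs each problem with its obstacle and mean-field costs. A small usage sketch (run from the repo root; the sample points are arbitrary):

import torch
from mfg.state_cost import build_state_cost_fn

state_cost_fn, obs_fn = build_state_cost_fn("Stunnel", return_obs_cost_fn=True)
xs = 5.0 * torch.randn(16, 2)
print(obs_fn(xs))  # 1500 inside either elliptical obstacle, 0 elsewhere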
--------------------------------------------------------------------------------
/mfg/util.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 |
4 |
5 | def assert_zero_grads(network: torch.nn.Module) -> None:
6 | for param in network.parameters():
7 | grad = param.grad
8 | assert grad is None or torch.allclose(grad, torch.zeros_like(grad)), "expected zero gradients"
9 |
10 | @torch.no_grad()
11 | def proj_pca(xs_f: torch.Tensor, xs_b: torch.Tensor, reverse: bool) -> Tuple[torch.Tensor, torch.Tensor]:
12 | # xs: (batch, T, nx)
13 | # Only use final timestep of xs_f for PCA.
14 | batch, T, nx = xs_f.shape
15 |
16 | # (batch * T, nx)
17 | flat_xsf = xs_f.reshape(-1, *xs_f.shape[2:])
18 | flat_xsb = xs_b.reshape(-1, *xs_b.shape[2:])
19 |
20 | # Center by subtracting the mean.
21 | # (batch, nx)
22 | if reverse:
23 | # If reverse, use xs_b[0] instead of xs_f[T]
24 | final_xs_f = xs_b[:, 0, :]
25 | else:
26 | final_xs_f = xs_f[:, -1, :]
27 |
28 | mean_pca_xs = torch.mean(final_xs_f, dim=0, keepdim=True)
29 | final_xs_f = final_xs_f - mean_pca_xs  # out-of-place: avoid mutating the input view
30 |
31 | # if batch is too large, it will run out of memory.
32 | if batch > 200:
33 | rand_idxs = torch.randperm(batch)[:200]
34 | final_xs_f = final_xs_f[rand_idxs]
35 |
36 | # U: (batch, k)
37 | # S: (k,)  vector of singular values
38 | # VT: (k, nx)
39 | U, S, VT = torch.linalg.svd(final_xs_f)
40 |
41 | # log.info("Singular values of xs_f at final timestep:")
42 | # log.info(S)
43 |
44 | # Keep the first two principal directions.
45 | VT = VT[:2, :]
46 | # Alternative: first and last directions, VT = VT[[0, -1], :]
47 |
48 | assert VT.shape == (2, nx)
49 | V = VT.T
50 |
51 | # Project both xs_f and xs_b onto V (out-of-place, so the inputs stay unchanged).
52 | flat_xsf = flat_xsf - mean_pca_xs
53 | flat_xsb = flat_xsb - mean_pca_xs
54 |
55 | proj_xs_f = flat_xsf @ V
56 | proj_xs_f = proj_xs_f.reshape(batch, T, *proj_xs_f.shape[1:])
57 |
58 | proj_xs_b = flat_xsb @ V
59 | proj_xs_b = proj_xs_b.reshape(batch, T, *proj_xs_b.shape[1:])
60 |
61 | return proj_xs_f, proj_xs_b
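A usage sketch for `proj_pca`: it fits a 2-D PCA basis on the forward terminal (or backward initial) marginal and projects both trajectory bundles onto it, which is how high-dimensional opinion trajectories can be visualized (sizes below are hypothetical):

import torch
from mfg.util import proj_pca

batch, T, nx = 64, 100, 1000
xs_f = torch.randn(batch, T, nx)   # forward trajectories
xs_b = torch.randn(batch, T, nx)   # backward trajectories
proj_f, proj_b = proj_pca(xs_f, xs_b, reverse=False)
assert proj_f.shape == proj_b.shape == (batch, T, 2)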
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .toy_net import build_toy_net_policy
2 | from .opinion_net import build_opinion_net_policy
3 |
--------------------------------------------------------------------------------
/models/opinion_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from options import Options
5 | from .util import SiLU, timestep_embedding, ResNet_FC
6 |
7 |
8 | class OpinionYImpl(torch.nn.Module):
9 | def __init__(self, data_dim: int, time_embed_dim: int, hid: int, out_hid: int):
10 | super().__init__()
11 |
12 | self.t_module = nn.Sequential(
13 | nn.Linear(time_embed_dim, hid),
14 | SiLU(),
15 | nn.Linear(hid, hid),
16 | )
17 | self.x_module = ResNet_FC(data_dim, hid, num_res_blocks=5)
18 |
19 | self.out_module = nn.Sequential(
20 | nn.Linear(hid + hid, out_hid),
21 | SiLU(),
22 | nn.Linear(out_hid, out_hid),
23 | SiLU(),
24 | nn.Linear(out_hid, 1),
25 | )
26 |
27 | def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
28 | t_out = self.t_module(t_emb)
29 | x_out = self.x_module(x)
30 |
31 | t_out = t_out.expand(x_out.shape)
32 |
33 | out = self.out_module(torch.cat([x_out, t_out], dim=-1))
34 |
35 | return out
36 |
37 |
38 | class OpinionY(torch.nn.Module):
39 | def __init__(self, data_dim: int = 1000, hid: int = 128, out_hid: int = 128, time_embed_dim: int = 128):
40 | super(OpinionY, self).__init__()
41 |
42 | self.time_embed_dim = time_embed_dim
43 | self.yt_impl = OpinionYImpl(data_dim, time_embed_dim, hid, out_hid)
44 | self.yt_impl = torch.jit.script(self.yt_impl)
45 |
46 | @property
47 | def inner_dtype(self):
48 | """
49 | Get the dtype used by the torso of the model.
50 | """
51 | return next(self.parameters()).dtype
52 |
53 | def forward(self, x:torch.Tensor, t: torch.Tensor) -> torch.Tensor:
54 | """
55 | Apply the model to an input batch.
56 | :param x: an [N x C x ...] Tensor of inputs.
57 | :param t: a 1-D batch of timesteps.
58 | """
59 |
60 | # make sure t.shape = [T]
61 | if len(t.shape) == 0:
62 | t = t[None]
63 |
64 | t_emb = timestep_embedding(t, self.time_embed_dim)
65 |
66 | out = self.yt_impl(x, t_emb)
67 |
68 | return out
69 |
70 | class OpinionZImpl(torch.nn.Module):
71 | def __init__(self, data_dim: int, time_embed_dim: int, hid: int):
72 | super().__init__()
73 |
74 | self.t_module = nn.Sequential(
75 | nn.Linear(time_embed_dim, hid),
76 | SiLU(),
77 | nn.Linear(hid, hid),
78 | )
79 | self.x_module = ResNet_FC(data_dim, hid, num_res_blocks=5)
80 |
81 | self.out_module = nn.Sequential(
82 | nn.Linear(hid, hid),
83 | SiLU(),
84 | nn.Linear(hid, data_dim),
85 | )
86 |
87 | def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
88 | t_out = self.t_module(t_emb)
89 | x_out = self.x_module(x)
90 | out = self.out_module(x_out+t_out)
91 | return out
92 |
93 |
94 | class OpinionZ(torch.nn.Module):
95 | def __init__(self, data_dim: int = 1000, hidden_dim: int = 256, time_embed_dim: int = 128):
96 | super(OpinionZ, self).__init__()
97 |
98 | self.time_embed_dim = time_embed_dim
99 | self.z_impl = OpinionZImpl(data_dim, time_embed_dim, hidden_dim)
100 | self.z_impl = torch.jit.script(self.z_impl)
101 |
102 | @property
103 | def inner_dtype(self):
104 | """
105 | Get the dtype used by the torso of the model.
106 | """
107 | return next(self.parameters()).dtype
108 |
109 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
110 | """
111 | Apply the model to an input batch.
112 | :param x: an [N x C x ...] Tensor of inputs.
113 | :param t: a 1-D batch of timesteps.
114 | """
115 |
116 | # make sure t.shape = [T]
117 | if len(t.shape) == 0:
118 | t = t[None]
119 |
120 | t_emb = timestep_embedding(t, self.time_embed_dim)
121 | out = self.z_impl(x, t_emb)
122 |
123 | return out
124 |
125 | def build_opinion_net_policy(opt: Options, YorZ: str) -> torch.nn.Module:
126 | assert opt.x_dim == 1000
127 | # 2nets: {"hid":200, "out_hid":400, "time_embed_dim":256}
128 | return {
129 | "Y": OpinionY,
130 | "Z": OpinionZ,
131 | }.get(YorZ)()
132 |
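A quick smoke test of the two opinion networks: by analogy with the `--sb-param` naming, `OpinionY` is the scalar value (critic) head and `OpinionZ` the drift (actor) head; a scalar `t` is broadcast across the batch internally:

import torch
from models.opinion_net import OpinionY, OpinionZ

x = torch.randn(4, 1000)
t = torch.tensor(0.5)              # scalar time, broadcast internally
assert OpinionY()(x, t).shape == (4, 1)
assert OpinionZ()(x, t).shape == (4, 1000)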
--------------------------------------------------------------------------------
/models/toy_net.py:
--------------------------------------------------------------------------------
1 | from typing import Type, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from options import Options
7 |
8 | from .util import SiLU, timestep_embedding
9 |
10 |
11 | class ToyY(torch.nn.Module):
12 | def __init__(self, data_dim: int = 2, hidden_dim: int = 128, time_embed_dim: int = 128):
13 | super(ToyY, self).__init__()
14 |
15 | self.time_embed_dim = time_embed_dim
16 | hid = hidden_dim
17 |
18 | self.t_module = nn.Sequential(
19 | nn.Linear(self.time_embed_dim, hid),
20 | SiLU(),
21 | nn.Linear(hid, hid),
22 | )
23 |
24 | self.x_module = nn.Sequential(
25 | nn.Linear(data_dim, hid),
26 | SiLU(),
27 | nn.Linear(hid, hid),
28 | SiLU(),
29 | nn.Linear(hid, hid),
30 | )
31 |
32 | self.out_module = nn.Sequential(
33 | nn.Linear(hid + hid, hid),
34 | SiLU(),
35 | nn.Linear(hid, hid),
36 | SiLU(),
37 | nn.Linear(hid, 1),
38 | )
39 |
40 | @property
41 | def inner_dtype(self):
42 | """
43 | Get the dtype used by the torso of the model.
44 | """
45 | return next(self.parameters()).dtype
46 |
47 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
48 | """
49 | Apply the model to an input batch.
50 | :param x: an [N x C x ...] Tensor of inputs.
51 | :param t: a 1-D batch of timesteps.
52 | """
53 |
54 | # make sure t.shape = [T]
55 | if len(t.shape) == 0:
56 | t = t[None]
57 |
58 | t_emb = timestep_embedding(t, self.time_embed_dim)
59 | t_out = self.t_module(t_emb)
60 | x_out = self.x_module(x)
61 | xt_out = torch.cat([x_out, t_out], dim=1)
62 | out = self.out_module(xt_out)
63 |
64 | return out
65 |
66 |
67 | class ToyZ(torch.nn.Module):
68 | def __init__(self, data_dim: int = 2, hidden_dim: int = 128, time_embed_dim: int = 128):
69 | super(ToyZ, self).__init__()
70 |
71 | self.time_embed_dim = time_embed_dim
72 | hid = hidden_dim
73 |
74 | self.t_module = nn.Sequential(
75 | nn.Linear(self.time_embed_dim, hid),
76 | SiLU(),
77 | nn.Linear(hid, hid),
78 | )
79 |
80 | self.x_module = nn.Sequential(
81 | nn.Linear(data_dim, hid),
82 | SiLU(),
83 | nn.Linear(hid, hid),
84 | SiLU(),
85 | nn.Linear(hid, hid),
86 | )
87 |
88 | self.out_module = nn.Sequential(
89 | nn.Linear(hid, hid),
90 | SiLU(),
91 | nn.Linear(hid, data_dim),
92 | )
93 |
94 | @property
95 | def inner_dtype(self):
96 | """
97 | Get the dtype used by the torso of the model.
98 | """
99 | return next(self.parameters()).dtype
100 |
101 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
102 | """
103 | Apply the model to an input batch.
104 | :param x: an [N x C x ...] Tensor of inputs.
105 | :param t: a 1-D batch of timesteps.
106 | """
107 |
108 | # make sure t.shape = [T]
109 | if len(t.shape) == 0:
110 | t = t[None]
111 |
112 | t_emb = timestep_embedding(t, self.time_embed_dim)
113 | t_out = self.t_module(t_emb)
114 | x_out = self.x_module(x)
115 | out = self.out_module(x_out + t_out)
116 |
117 | return out
118 |
119 |
120 | def build_toy_net_policy(opt: Options, YorZ: str) -> torch.nn.Module:
121 | assert opt.x_dim == 2
122 | return {
123 | "Y": ToyY,
124 | "Z": ToyZ,
125 | }.get(YorZ)()
126 |
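The toy networks follow the same Y/Z split in 2-D. Note that `ToyY` concatenates `t_out` with `x_out` along the feature dimension, so `t` must be batch-shaped here (unlike `OpinionY`, which expands a scalar time embedding):

import torch
from models.toy_net import ToyY, ToyZ

x = torch.randn(8, 2)
t = torch.full((8,), 0.3)          # one timestep per sample
assert ToyY()(x, t).shape == (8, 1)
assert ToyZ()(x, t).shape == (8, 2)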
--------------------------------------------------------------------------------
/models/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | def timestep_embedding(timesteps, dim, max_period=10000):
9 | """
10 | Create sinusoidal timestep embeddings.
11 |
12 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
13 | These may be fractional.
14 | :param dim: the dimension of the output.
15 | :param max_period: controls the minimum frequency of the embeddings.
16 | :return: an [N x dim] Tensor of positional embeddings.
17 | """
18 | half = dim // 2
19 | freqs = torch.exp(
20 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
21 | ).to(device=timesteps.device)
22 | args = timesteps[:, None].float() * freqs[None]
23 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
24 | if dim % 2:
25 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
26 | return embedding
27 |
28 | class SiLU(nn.Module):
29 | def forward(self, x):
30 | return x * torch.sigmoid(x)
31 |
32 | class ResNet_FC(nn.Module):
33 | def __init__(self, data_dim, hidden_dim, num_res_blocks):
34 | super().__init__()
35 | self.hidden_dim = hidden_dim
36 | self.map = nn.Linear(data_dim, hidden_dim)
37 | self.res_blocks = nn.ModuleList(
38 | [self.build_res_block() for _ in range(num_res_blocks)])
39 |
40 | def build_linear(self, in_features, out_features):
41 | linear = nn.Linear(in_features, out_features)
42 | return linear
43 |
44 | def build_res_block(self):
45 | hid = self.hidden_dim
46 | layers = []
47 | widths = [hid] * 4
48 | for i in range(len(widths) - 1):
49 | layers.append(self.build_linear(widths[i], widths[i + 1]))
50 | layers.append(SiLU())
51 | return nn.Sequential(*layers)
52 |
53 | def forward(self, x):
54 | h = self.map(x)
55 | for res_block in self.res_blocks:
56 | h = (h + res_block(h)) / 2
57 | return h
58 |
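Shapes for the two building blocks above, as a quick check: `timestep_embedding` maps N (possibly fractional) timesteps to an [N x dim] sinusoidal embedding, and `ResNet_FC` maps data into the hidden width it keeps through each averaged residual block:

import torch
from models.util import timestep_embedding, ResNet_FC

emb = timestep_embedding(torch.tensor([0.0, 0.5, 1.0]), dim=128)
assert emb.shape == (3, 128)

net = ResNet_FC(data_dim=2, hidden_dim=64, num_res_blocks=3)
assert net(torch.randn(5, 2)).shape == (5, 64)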
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | from dataclasses import dataclass
6 | from typing import Dict, Optional
7 |
8 | import numpy as np
9 | import torch
10 |
11 | import configs
12 |
13 |
14 | @dataclass
15 | class Options:
16 | problem_name: str
17 | seed: int
18 | gpu: int
19 | load: Optional[str]
20 | dir: str
21 | group: str
22 | name: str
23 | log_fn: Optional[str]
24 | log_tb: bool
25 | cpu: bool
26 | t0: float
27 | T: float
28 | interval: int
29 | policy_net: str
30 | diffusion_std: float
31 | train_bs_x: int
32 | num_stage: int
33 | num_itr: int
34 | samp_bs: int
35 | samp_method: str
36 | rb_bs_x: int
37 | MF_cost: float
38 | lr: float
39 | lr_y: Optional[float]
40 | lr_gamma: float
41 | lr_step: int
42 | l2_norm: float
43 | optimizer: str
44 | noise_type: str
45 | ema: float
46 | snapshot_freq: int
47 | ckpt_freq: int
48 | sb_param: str
49 | use_rb_loss: bool
50 | multistep_td: bool
51 | buffer_size: int
52 | weighted_loss: bool
53 | x_dim: int
54 | device: str
55 | ckpt_path: str
56 | eval_path: str
57 | log_dir: str
58 | # Additional options set in problem config.
59 | weights: Optional[Dict[str, float]] = None
60 |
61 |
62 | def set():
63 | # fmt: off
64 | # --------------- basic ---------------
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument("--problem-name", type=str)
67 | parser.add_argument("--seed", type=int, default=0)
68 | parser.add_argument("--gpu", type=int, default=0, help="GPU device")
69 | parser.add_argument("--load", type=str, default=None, help="load the checkpoints")
70 | parser.add_argument("--dir", type=str, default=None, help="directory name to save the experiments under results/")
71 | parser.add_argument("--group", type=str, default='0', help="father node of directionary for saving checkpoint")
72 | parser.add_argument("--name", type=str, default=None, help="son node of directionary for saving checkpoint")
73 | parser.add_argument("--log-fn", type=str, default=None, help="name of tensorboard logging")
74 | parser.add_argument("--log-tb", action="store_true", help="logging with tensorboard")
75 | parser.add_argument("--cpu", action="store_true", help="use cpu device")
76 |
77 | # --------------- DeepGSB & MFG ---------------
78 | parser.add_argument("--t0", type=float, default=0.0, help="time integral start time")
79 | parser.add_argument("--T", type=float, default=1.0, help="time integral end time")
80 | parser.add_argument("--interval", type=int, default=100, help="number of interval")
81 | parser.add_argument("--policy-net", type=str, help="model class of policy network")
82 | parser.add_argument("--diffusion-std", type=float, default=1.0, help="diffusion scalar in SDE")
83 | parser.add_argument("--sb-param", type=str, choices=['actor-critic', 'critic'])
84 | parser.add_argument("--MF-cost", type=float, default=0.0, help="coefficient of MF cost")
85 |
86 | # --------------- training & sampling ---------------
87 | parser.add_argument("--train-bs-x", type=int, help="batch size for sampling data")
88 | parser.add_argument("--num-stage", type=int, help="number of stage")
89 | parser.add_argument("--num-itr", type=int, help="number of training iterations (for each stage)")
90 | parser.add_argument("--samp-bs", type=int, help="batch size for all trajectory sampling purposes")
91 | parser.add_argument("--samp-method", type=str, default='jacobi', choices=['jacobi','gauss']) # gauss seidel
92 | parser.add_argument("--rb-bs-x", type=int, help="batch size when sampling from replay buffer")
93 | parser.add_argument("--use-rb-loss", action="store_true", help="whether or not to use the replay buffer loss")
94 | parser.add_argument("--multistep-td", action="store_true", help="whether or not to use the multi-step TD loss")
95 | parser.add_argument("--buffer-size", type=int, default=20000, help="the maximum size of replay buffer")
96 | parser.add_argument("--weighted-loss", action="store_true", help="whether or not to reweight the combined loss")
97 |
98 | # --------------- optimizer and loss ---------------
99 | parser.add_argument("--lr", type=float, help="learning rate for Znet")
100 | parser.add_argument("--lr-y", type=float, default=None, help="learning rate for Ynet")
101 | parser.add_argument("--lr-gamma", type=float, default=1.0, help="learning rate decay ratio")
102 | parser.add_argument("--lr-step", type=int, default=1000, help="learning rate decay step size")
103 | parser.add_argument("--l2-norm", type=float, default=0.0, help="weight decay rate")
104 | parser.add_argument("--optimizer", type=str, default='AdamW', help="optmizer")
105 | parser.add_argument("--noise-type", type=str, default='gaussian', choices=['gaussian','rademacher'], help='choose noise type to approximate Trace term')
106 | parser.add_argument("--ema", type=float, default=0.99)
107 |
108 | # ---------------- evaluation ----------------
109 | parser.add_argument("--snapshot-freq", type=int, default=1, help="snapshot frequency w.r.t stages")
110 | parser.add_argument("--ckpt-freq", type=int, default=1, help="checkpoint saving frequency w.r.t stages")
111 |
112 | # fmt: on
113 |
114 | args = parser.parse_args()
115 | problem_name, sb_param = args.problem_name, args.sb_param
116 |
117 | parser.set_defaults(**configs.get_default(problem_name, sb_param))
118 | opt = parser.parse_args()
119 | # ========= seed & torch setup =========
120 | if opt.seed is not None:
121 | # https://github.com/pytorch/pytorch/issues/7068
122 | seed = opt.seed
123 | random.seed(seed)
124 | os.environ["PYTHONHASHSEED"] = str(seed)
125 | np.random.seed(seed)
126 | torch.manual_seed(seed)
127 | torch.cuda.manual_seed(seed)
128 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
129 | torch.backends.cudnn.enabled = True
130 | torch.backends.cudnn.benchmark = True
131 | # torch.backends.cudnn.deterministic = True
132 |
133 | torch.set_default_tensor_type("torch.cuda.FloatTensor")
134 | # torch.autograd.set_detect_anomaly(True)
135 |
136 | # ========= auto setup & path handle =========
137 | opt.device = "cuda:" + str(opt.gpu)
138 |
139 | if opt.name is None:
140 | opt.name = opt.dir
141 |
142 | opt.ckpt_path = os.path.join("checkpoint", opt.group, opt.name)
143 | os.makedirs(opt.ckpt_path, exist_ok=True)
144 | if opt.snapshot_freq:
145 | opt.eval_path = os.path.join("results", opt.dir)
146 | os.makedirs(os.path.join(opt.eval_path, "forward"), exist_ok=True)
147 | os.makedirs(os.path.join(opt.eval_path, "backward"), exist_ok=True)
148 | os.makedirs(os.path.join(opt.eval_path, "logp"), exist_ok=True)
149 | if "opinion" in opt.problem_name:
150 | os.makedirs(
151 | os.path.join(opt.eval_path, "directional_sim_forward"), exist_ok=True
152 | )
153 | os.makedirs(
154 | os.path.join(opt.eval_path, "directional_sim_backward"), exist_ok=True
155 | )
156 |
157 | if opt.log_tb:
158 | opt.log_dir = os.path.join(
159 | "runs", opt.dir
160 | ) # if opt.log_fn is not None else None
161 | if os.path.exists(opt.log_dir):
162 | shutil.rmtree(opt.log_dir) # remove folder & its files
163 |
164 | opt = Options(**vars(opt))
165 | return opt
166 |
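`set()` resolves configuration in two passes: the first parse only extracts `--problem-name` and `--sb-param`, those pick a config from `configs.get_default`, and `parser.set_defaults(**config)` merges it before the final parse, so flags given explicitly on the command line still win. A minimal sketch of the same pattern (the flags and default values below are hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--problem-name", type=str)
parser.add_argument("--lr", type=float)

# First pass: we only need the problem name to pick defaults.
problem = parser.parse_args().problem_name

# Per-problem defaults; explicit CLI flags still override them.
parser.set_defaults(**{"GMM": {"lr": 2e-4}, "Vneck": {"lr": 1e-3}}.get(problem, {}))

opt = parser.parse_args()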
--------------------------------------------------------------------------------
/requirements.yaml:
--------------------------------------------------------------------------------
1 | name: deepgsb
2 | channels:
3 | - anaconda
4 | - pytorch
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - python=3.8
9 | - numpy
10 | - scipy
11 | - termcolor
12 | - easydict
13 | - ipdb
14 | - pytorch==1.11.0
15 | - torchvision
16 | - cudatoolkit=10.2
17 | - tqdm
18 | - scikit-learn
19 | - imageio
20 | - matplotlib
21 | - tensorboard
22 | - pip
23 | - pip:
24 | - colored-traceback
25 | - torch-ema
26 | - gdown
27 | - rich
28 | - geomloss
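The environment is a standard conda spec: `conda env create -f requirements.yaml` builds it, and `conda activate deepgsb` (the `name:` field above) activates it.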
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATE=$(date +%m.%d)
3 |
4 | EXP=$1
5 |
6 | if [[ "$EXP" == "GMM" || "$EXP" == "all" ]]; then
7 | # GMM (std1, obs1500, multistep, gauss)
8 | BASE='--problem-name GMM --ckpt-freq 2 --snapshot-freq 2 --log-tb'
9 | python main.py $BASE --sb-param critic --dir gmm-$DATE/deepgsb-c-std1
10 | python main.py $BASE --sb-param actor-critic --dir gmm-$DATE/deepgsb-ac-std1
11 | fi
12 |
13 | if [[ "$EXP" == "Vneck" || "$EXP" == "all" ]]; then
14 | # Vneck (std1, obs1500, mf0/3, multistep, jacobi, use_rb)
15 | BASE='--problem-name Vneck --ckpt-freq 2 --snapshot-freq 2 --log-tb'
16 | python main.py $BASE --sb-param critic --dir vneck-$DATE/deepgsb-c-std1-mf3 --MF-cost 3.0
17 | python main.py $BASE --sb-param critic --dir vneck-$DATE/deepgsb-c-std1-mf0 --MF-cost 0.0
18 |
19 | python main.py $BASE --sb-param actor-critic --dir vneck-$DATE/deepgsb-ac-std1-mf3 --MF-cost 3.0
20 | python main.py $BASE --sb-param actor-critic --dir vneck-$DATE/deepgsb-ac-std1-mf0 --MF-cost 0.0
21 | fi
22 |
23 | if [[ "$EXP" == "Stunnel" || "$EXP" == "all" ]]; then
24 | # Stunnel (obs1500, congestion1, multistep, jacobi, use_rb)
25 | BASE='--problem-name Stunnel --ckpt-freq 2 --snapshot-freq 2 --log-tb --MF-cost 0.5 '
26 |
27 | python main.py $BASE --sb-param critic --dir stunnel-$DATE/deepgsb-c-std0.5-mf1 --diffusion-std 0.5
28 | python main.py $BASE --sb-param critic --dir stunnel-$DATE/deepgsb-c-std1-mf1 --diffusion-std 1.0
29 | python main.py $BASE --sb-param critic --dir stunnel-$DATE/deepgsb-c-std2-mf1 --diffusion-std 2.0
30 |
31 | python main.py $BASE --sb-param actor-critic --dir stunnel-$DATE/deepgsb-ac-std0.5-mf1 --diffusion-std 0.5
32 | python main.py $BASE --sb-param actor-critic --dir stunnel-$DATE/deepgsb-ac-std1-mf1 --diffusion-std 1.0
33 | python main.py $BASE --sb-param actor-critic --dir stunnel-$DATE/deepgsb-ac-std2-mf1 --diffusion-std 2.0
34 | fi
35 |
36 | if [[ "$EXP" == "opinion" || "$EXP" == "all" ]]; then
37 | # opinion (std0.1, multistep, gauss, use_rb)
38 | BASE='--problem-name opinion --ckpt-freq 10 --snapshot-freq 2 --log-tb --MF-cost 1.0 --weighted-loss'
39 | python main.py $BASE --sb-param critic --dir opinion-2d-$DATE/deepgsb-c-std0.1-mf1-w
40 | python main.py $BASE --sb-param actor-critic --dir opinion-2d-$DATE/deepgsb-ac-std0.1-mf1-w
41 | fi
42 |
43 | if [[ "$EXP" == "opinion-1k" || "$EXP" == "all" ]]; then
44 | # opinion (std0.5, singlestep, gauss)
45 | BASE='--problem-name opinion_1k --ckpt-freq 20 --snapshot-freq 10 --log-tb --weighted-loss'
46 | python main.py $BASE --sb-param actor-critic --dir opinion-1k-$DATE/deepgsb-ac-std0.5-mf1-w --MF-cost 1.0
47 | python main.py $BASE --sb-param actor-critic --dir opinion-1k-$DATE/deepgsb-ac-std0.5-mf0-w --MF-cost 0.0
48 | fi
49 |
--------------------------------------------------------------------------------