├── .DS_Store
├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE.md
├── README.md
├── data
│   ├── MCMC_SyntheticData.py
│   └── plots
│       ├── EpistemicAleatoricUncertainty.png
│       ├── GMM_HMC1.gif
│       ├── HMC_Sampler.gif
│       ├── HMC_Sampler2.gif
│       ├── HMC_Sampler3.gif
│       └── HMC_Sampler4.gif
├── experiments
│   ├── .DS_Store
│   ├── AIS.py
│   ├── DataViz.py
│   ├── MCMC_HMCTest.py
│   ├── MCMC_Test.py
│   ├── Testing.py
│   └── functional_pytorch
├── models
│   └── MCMC_Models.py
└── src
    ├── MCMC_Acceptance.py
    ├── MCMC_Chain.py
    ├── MCMC_Optim.py
    ├── MCMC_ProbModel.py
    ├── MCMC_Sampler.py
    ├── MCMC_Utils.py
    └── README.md
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # user files
7 | *.chain
8 | *.ipynb
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # celery beat schedule file
98 | celerybeat-schedule
99 |
100 | # SageMath parsed files
101 | *.sage.py
102 |
103 | # Environments
104 | .env
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | }
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "{}"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright {yyyy} {name of copyright owner}
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # torch-MC^2 (torch-MCMC)
3 | HMC on 3 layer NN | HMC on GMM
4 | :-------------------------------------------:|:------------------------------:
5 |  | 
6 |
7 |
8 |
9 | This package implements a series of MCMC sampling algorithms in PyTorch in a modular way:
10 |
11 | - Metropolis Hastings
12 | - Stochastic Gradient Langevin Dynamics
13 | - (Stochastic) Hamiltonian Monte Carlo
14 | - Stochastic Gradient Nosé-Hoover Thermostat
15 | - SG Riemann Hamiltonian Monte Carlo (coming ...)
16 |
17 | The focus lies on four core ingredients to MCMC with corresponding routines:
18 |
19 | - `MCMC.src.MCMC_ProbModel` : Probabilistic wrapper around your model providing a uniform interface
20 | - `MCMC.src.MCMC_Chain` : Markov Chain for storing samples
21 | - `MCMC.src.MCMC_Optim` : MCMC_Optim parent class for handling parameters
22 | - `MCMC.src.MCMC_Sampler`: Sampler that binds it all together
23 |
24 |
25 | These classes and functions are constructed along the structure of the core PyTorch framework.
26 | The gradient-based samplers in particular are designed around PyTorch's `optim` class to handle everything related to parameters.
27 |
28 | # ProbModel
29 |
30 | The wrapper `MCMC.src.MCMC_ProbModel` defines a small set of functions that are required to allow the `Sampler_Chain` to interact with the model and evaluate the relevant quantities.
31 |
32 | Any parameter in the model that we wish to sample from has to be designated as a `torch.nn.Parameter()`.
33 | This could be as simple as a single particle that we move around a 2-D distribution or a full neural network.
34 |
35 | It has a handful of methods which have to be defined by the user:
36 |
37 | `MCMC_ProbModel.log_prob()`:
38 |
39 | Evaluates the log_probability of the likelihood of the model.
40 |
41 | It returns a dictionary whose first entry is the log-probability, e.g. `{"log_prob": -20.03}`.
42 |
43 | Moreover, additional evaluation metrics can be added to the dictionary as well, e.g. `{"log_prob": -20.03, "Accuracy": 0.58}`.
44 | The sampler inspects the dictionary returned by the evaluation of the model and creates corresponding running averages of the reported metrics.
45 |
46 | `MCMC_ProbModel.reset_parameters()`:
47 |
48 | Any value that we want to sample has to be declared as a `torch.nn.Parameter()` such that the `MCMC_Optims` can track the values in the background.
49 | `reset_parameters()` is mainly used to reinitialize the model.
50 |
51 | `MCMC_ProbModel.pretrain()`:
52 |
53 | In cases where we want a good initial guess to start our Markov chain, we can implement a pretrain method which optimizes the parameters in some user-defined manner.
54 | If `pretrain=True` is passed during initialization of the sampler, it will assume that the `MCMC_ProbModel.pretrain()` is implemented.
55 |
56 | In order to allow more sophisticated dynamic samplers such as `HMC_Sampler` to properly sample mini-batches, the probmodel should be initialized with a dataloader that takes care of sampling the mini-batches.
57 | That way, dynamic samplers can simply access `probmodel.dataloader`.
58 |
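To make the required interface concrete, here is a minimal sketch of a `ProbModel` subclass for linear regression, loosely following `LinReg` and `GMM` in `models/MCMC_Models.py`. The exact `ProbModel` constructor signature and the fixed noise level are assumptions for illustration:

```
import torch
from torch.utils.data import TensorDataset, DataLoader
from pytorch_MCMC.src.MCMC_ProbModel import ProbModel


class ToyLinReg(ProbModel):

	def __init__(self, x, y):
		# the dataloader is handed to ProbModel so that dynamic samplers can access it
		dataloader = DataLoader(TensorDataset(x, y), shuffle=True, batch_size=x.shape[0])
		ProbModel.__init__(self, dataloader)
		# anything we want to sample has to be a torch.nn.Parameter
		self.m = torch.nn.Parameter(torch.randn(1))
		self.b = torch.nn.Parameter(torch.randn(1))

	def log_prob(self, *args):
		data, target = next(iter(self.dataloader))
		mu = self.m * data + self.b
		log_prob = torch.distributions.Normal(mu, 1.0).log_prob(target).mean()
		# the first entry has to be 'log_prob'; additional metrics are optional
		return {'log_prob': log_prob}

	def reset_parameters(self):
		torch.nn.init.normal_(self.m, std=0.1)
		torch.nn.init.normal_(self.b, std=0.1)

	def pretrain(self):
		# only required if the sampler is constructed with pretrain=True
		pass
```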
59 | # Chain
60 |
61 | This is just a convenience container that stores the sampled values and can be queried for specific values to determine the progress of the sampling chain.
62 | The samples of the parameters of the model are stored as a list `chain.samples` where each entry is PyTorch's very own `state_dict()`.
63 |
64 | After the sampler has finished, the samples of the model can be accessed through the property `chain.samples`, which returns a list of `state_dict()`s that can be loaded into the model.
65 |
66 | An example:
67 |
68 | ```
69 |
70 | for sample_state_dict in chain.samples:
71 |
72 | 	self.load_state_dict(sample_state_dict)
73 |
74 | 	# ... do something like ensemble prediction ...
75 | ```
76 |
77 | `MCMC_Chain` is implemented as a mutable sequence and allows the appending of tuples `(probmodel/torch.state_dict, log_probs: dict/odict, accept: bool)` as well as the concatenation of entire chains.
78 |
79 | # MCMC_Optim
80 |
81 | The `MCMC_Optim`s inherit from PyTorch's very own `Optimizer` class and make working with gradients significantly more pleasant.
82 |
83 | By calling `MCMC_Optim.step()` they propose a new set of parameters for the `ProbModel`, the `log_prob()` of which is evaluated by the `MCMC_Sampler`.
84 |
85 | # MCMC_Sampler
86 |
87 | The core component that ties everything together.
88 |
89 | 1. `MCMC_Optim` does a `step()` and proposes new parameters for the `ProbModel`.
90 | 2. `MCMC_Sampler` evaluates the `log_prob()` of the `ProbModel` and determines the acceptance of the proposal.
91 | 3. If accepted, the `ProbModel` is passed to the `MCMC_Chain` to be saved.
92 | 4. If not accepted, we play the game again with a new proposal.
93 |
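In code, the experiment scripts in `experiments/` drive this loop roughly as follows (the argument values below are placeholders):

```
from pytorch_MCMC.models.MCMC_Models import GMM
from pytorch_MCMC.src.MCMC_Sampler import HMC_Sampler

gmm = GMM()
sampler = HMC_Sampler(probmodel=gmm,
                      step_size=0.1,
                      num_steps=10000,
                      burn_in=1000,
                      pretrain=False,
                      tune=False,
                      traj_length=20,
                      num_chains=4)

chains = sampler.sample_chains()  # run the independent sampler chains
sampler.posterior_dist()          # plot the sampled posterior
```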
94 | # What's the data structure underneath?
95 |
96 | Each sampler uses the following data structure:
97 |
98 | ```
99 | Sampler:
100 | - Sampler_Chain #1
101 | - ProbModel #1
102 | - Optim #1
103 | - Chain #1
104 | - Sampler_Chain #2
105 | - ProbModel #2
106 | - Optim #2
107 | - Chain #2
108 | - Sampler_Chain #3
109 | .
110 | .
111 | .
112 | .
113 | .
114 | .
115 |
116 | ```
117 |
118 | By packaging the optimizers and probmodels directly into the chain, these chains can be run completely independently, possibly even on multi-GPU systems.
119 |
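Since `sample_chains()` returns the individual chains, their samples can afterwards be pooled for ensemble predictions, as sketched below following `experiments/DataViz.py` (`model` stands for the user's `ProbModel` instance):

```
chains = sampler.sample_chains()

# pool the samples of all chains for an ensemble prediction
for chain in chains:
	for sample_state_dict in chain.samples:
		model.load_state_dict(sample_state_dict)
		# ... evaluate the model, e.g. accumulate predictions for an ensemble ...
```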
120 | # Final Note
121 |
122 | There are far more sophisticated sampling packages out there, such as Pyro, Stan and PyMC.
123 | Yet all of these packages require implementing the models explicitly for those frameworks.
124 | This package aims at providing MCMC sampling for **native** PyTorch models, so that the infamous Anonymous Reviewer 2, who requests an MCMC benchmark of an experiment, can be satisfied.
125 |
126 | # Final Final Note
127 |
128 | May your chains hang low
129 |
130 | https://www.youtube.com/watch?v=4SBN_ikibtg
--------------------------------------------------------------------------------
/data/MCMC_SyntheticData.py:
--------------------------------------------------------------------------------
1 | import future, sys, os, datetime, argparse
2 | # print(os.path.dirname(sys.executable))
3 | import torch
4 | import numpy as np
5 | import matplotlib
6 | from tqdm import tqdm
7 | import matplotlib.pyplot as plt
8 | from matplotlib.lines import Line2D
9 |
10 | matplotlib.rcParams["figure.figsize"] = [10, 10]
11 |
12 | import torch
13 | from torch.nn import Module, Parameter
14 | from torch.nn import Linear, Tanh, ReLU
15 | import torch.nn.functional as F
16 |
17 | Tensor = torch.Tensor
18 | FloatTensor = torch.FloatTensor
19 |
20 | torch.set_printoptions(precision=4, sci_mode=False)
21 | np.set_printoptions(precision=4, suppress=True)
22 |
23 | sys.path.append("../../..") # Up to -> KFAC -> Optimization -> PHD
24 |
25 | cwd = os.path.abspath(os.getcwd())
26 | os.chdir(cwd)
27 |
28 | params = argparse.ArgumentParser()
29 | params.add_argument('-xyz', type=str, default='test_xyz')
30 |
31 | params = params.parse_args()
32 |
33 |
34 | def generate_linear_regression_data(num_samples=100, m=1.0, b=-1.0, y_noise=1.0, x_noise=.01, plot=False):
35 | x = torch.linspace(-2, 2, num_samples).reshape(-1, 1)
36 | x += x_noise * torch.randn_like(x)
37 | y = m * x + b
38 | y += y_noise * torch.randn_like(y)
39 |
40 | if plot:
41 | plt.scatter(x, y)
42 | plt.show()
43 |
44 | return x, y
45 |
46 |
47 | def generate_multimodal_linear_regression(num_samples, y_noise=1, x_noise=1, plot=False):
48 | x1, y1 = generate_linear_regression_data(num_samples=num_samples // 10 * int(3), m=-1., b=0, y_noise=0.1, x_noise=0.1, plot=False)
49 | x2, y2 = generate_linear_regression_data(num_samples=num_samples // 10 * int(7), m=2, b=0, y_noise=0.1, x_noise=0.1, plot=False)
50 |
51 | x = torch.cat([x1, x2], dim=0).float()
52 | y = torch.cat([y1, y2], dim=0).float()
53 |
54 | if plot:
55 | plt.scatter(x, y, s=1)
56 | plt.show()
57 |
58 | return x, y
59 |
60 |
61 | def generate_nonstationary_data(num_samples=1000, y_constant_noise_std=0.1, y_nonstationary_noise_std=1., plot=False):
62 | x = np.linspace(-0.35, 0.45, num_samples)
63 | x_noise = np.random.normal(0., 0.01, size=x.shape)
64 |
65 | constant_noise = np.random.normal(0, y_constant_noise_std, size=x.shape)
66 | std = np.linspace(0, y_nonstationary_noise_std, num_samples) # * _y_noise_std
67 | non_stationary_noise = np.random.normal(loc=0, scale=std)
68 |
69 | y = x + 0.3 * np.sin(2 * np.pi * (x + x_noise)) + 0.3 * np.sin(4 * np.pi * (x + x_noise)) + non_stationary_noise + constant_noise
70 |
71 | x = torch.from_numpy(x).reshape(-1, 1).float()
72 | y = torch.from_numpy(y).reshape(-1, 1).float()
73 |
74 | x = (x - x.mean(dim=0))/(x.std(dim=0)+1e-3)
75 | y = (y - y.mean(dim=0))/(y.std(dim=0)+1e-3)
76 |
77 | if plot:
78 | plt.scatter(x, y)
79 | plt.show()
80 |
81 | return x, y
--------------------------------------------------------------------------------
/data/plots/EpistemicAleatoricUncertainty.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/EpistemicAleatoricUncertainty.png
--------------------------------------------------------------------------------
/data/plots/GMM_HMC1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/GMM_HMC1.gif
--------------------------------------------------------------------------------
/data/plots/HMC_Sampler.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/HMC_Sampler.gif
--------------------------------------------------------------------------------
/data/plots/HMC_Sampler2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/HMC_Sampler2.gif
--------------------------------------------------------------------------------
/data/plots/HMC_Sampler3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/HMC_Sampler3.gif
--------------------------------------------------------------------------------
/data/plots/HMC_Sampler4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/HMC_Sampler4.gif
--------------------------------------------------------------------------------
/experiments/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/experiments/.DS_Store
--------------------------------------------------------------------------------
/experiments/AIS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | from matplotlib import cm
5 | from matplotlib.colors import to_rgba
6 | from tqdm import tqdm
7 | from scipy.integrate import quad
8 | import numpy as np
9 |
10 |
11 | class Normal1D:
12 | def __init__(self, mean, std):
13 | self.mean = mean
14 | self.std = std
15 |
16 | def sample(self, num_samples=1):
17 | return torch.normal(self.mean, self.std, size=(num_samples,))
18 |
19 | def prob(self, x):
20 | var = self.std**2
21 | return (
22 | 1
23 | / (self.std * torch.sqrt(2 * torch.tensor(torch.pi)))
24 | * torch.exp(-0.5 * (x - self.mean) ** 2 / var)
25 | )
26 |
27 | def energy(self, x):
28 | return 0.5 * (x - self.mean) ** 2 / self.std**2
29 |
30 | @property
31 | def Z(self):
32 |         """Partition function for the normal distribution: Z = sqrt(2 * pi * std**2)"""
33 | return self.std * torch.sqrt(2 * torch.tensor(torch.pi))
34 |
35 | def log_prob(self, x):
36 | var = self.std**2
37 | return (
38 | -0.5 * torch.log(2 * torch.tensor(torch.pi) * var)
39 | - 0.5 * (x - self.mean) ** 2 / var
40 | )
41 |
42 |
43 | class GaussianMixture1D:
44 | def __init__(self, means, stds, weights):
45 | self.means = torch.tensor(means, dtype=torch.float32)
46 | self.stds = torch.tensor(stds, dtype=torch.float32)
47 | self.weights = torch.tensor(weights, dtype=torch.float32)
48 | self.weights = self.weights / self.weights.sum() # Normalize weights
49 |
50 | def sample(self, num_samples=1):
51 | component = torch.multinomial(self.weights, num_samples, replacement=True)
52 | samples = torch.normal(self.means[component], self.stds[component])
53 | return samples
54 |
55 | def prob(self, x):
56 | x = x.unsqueeze(-1) # Shape (N, 1)
57 | probs = (
58 | 1
59 | / (self.stds * torch.sqrt(2 * torch.tensor(torch.pi)))
60 | * torch.exp(-0.5 * (x - self.means) ** 2 / (self.stds**2))
61 | )
62 | weighted_probs = probs * self.weights
63 | return weighted_probs.sum(dim=-1)
64 |
65 | def log_prob(self, x):
66 | x = x.unsqueeze(-1)
67 | log_probs = (
68 | -0.5 * torch.log(2 * torch.tensor(torch.pi) * self.stds**2)
69 | - 0.5 * (x - self.means) ** 2 / (self.stds**2)
70 | + torch.log(self.weights)
71 | )
72 | return torch.logsumexp(log_probs, dim=-1)
73 |
74 | def energy(self, x):
75 | # Negative log probability (up to constant)
76 | return -self.log_prob(x) + 1
77 |
78 |
79 | def sgl_sampler(energy_fn, x_init, lr=1e-2, n_steps=1000, noise_scale=1.0):
80 | """
81 | Stochastic Gradient Langevin Dynamics (SGLD) sampler.
82 | Args:
83 | energy_fn: Callable, computes energy for input x (requires_grad=True).
84 | x_init: Initial sample (torch.tensor, requires_grad=False).
85 | lr: Learning rate (step size).
86 | n_steps: Number of sampling steps.
87 | noise_scale: Multiplier for injected noise (default 1.0).
88 | Returns:
89 |         x: Final state of the chain, with the same shape as x_init.
90 | """
91 | x = x_init.clone().detach().requires_grad_(True)
92 | samples = []
93 | pbar = tqdm(range(n_steps))
94 | for _ in pbar:
95 | lr_t = lr * 0.5 * (1 + torch.cos(torch.tensor(_ / n_steps * torch.pi)))
96 | pbar.set_description(f"lr={lr_t:.5f}")
97 | x.requires_grad_(True)
98 | energy = energy_fn(x)
99 | grad = torch.autograd.grad(energy.sum(), x)[0]
100 |         noise = torch.randn_like(x) * noise_scale * lr_t**0.5  # inject noise at the annealed step size, scaled by noise_scale
102 | x = (x - lr_t * grad + noise).detach()
103 | samples.append(x.detach().clone())
104 | return x
105 |
106 |
107 | def metropolishastings_sampler(energy_fn, x_init, n_steps=1000, proposal_std=0.5):
108 | """
109 | Metropolis-Hastings sampler.
110 | Args:
111 | energy_fn: Callable, computes energy for input x (requires_grad=True).
112 | x_init: Initial sample (torch.tensor, requires_grad=False).
113 | n_steps: Number of sampling steps.
114 | proposal_std: Standard deviation for the proposal distribution.
115 | Returns:
116 |         x: Final state of the chains (one entry per initial particle), same shape as x_init.
117 | """
118 | x = x_init
119 | samples = []
120 | # progress_bar = tqdm(total=n_steps, desc="MH")
121 | acceptance_ratio_ema, ema_weight = None, 0.9999
122 | pbar = tqdm(range(n_steps), desc="MH")
123 | for _ in pbar:
124 | x_new = x + torch.randn_like(x) * proposal_std
125 | log_acceptance_ratio = energy_fn(x) - energy_fn(x_new)
126 | acceptance_ratio = torch.exp(log_acceptance_ratio)
127 | accept = torch.rand_like(acceptance_ratio) < acceptance_ratio
128 |
129 | # Debiased running average (corrects for initial bias)
130 | if acceptance_ratio_ema is None:
131 | acceptance_ratio_ema = accept.int().float().mean()
132 | ema_correction = 1.0
133 | else:
134 | acceptance_ratio_ema = (
135 | acceptance_ratio_ema * ema_weight
136 | + accept.int().float().mean() * (1 - ema_weight)
137 | )
138 | ema_correction = 1 - ema_weight ** (_ + 1)
139 | debiased_acceptance = acceptance_ratio_ema / ema_correction
140 |         pbar.set_postfix({"Accept": float(debiased_acceptance.mean())})
141 | x = torch.where(accept, x_new, x)
142 | samples.append(x.detach().clone())
143 | return x
144 |
145 |
146 | def compute_partition_function_1d(energy_fn, x_min, x_max):
147 | integrand = lambda x: np.exp(-energy_fn(torch.tensor(x)).item())
148 | Z, _ = quad(integrand, x_min, x_max)
149 | return Z
150 |
151 |
152 | p0 = Normal1D(torch.tensor(0.0), torch.tensor(5.0))
153 |
154 | N = 10_000
155 | samples = p0.sample(N)
156 | prob = p0.prob(samples)
157 | log_prob = p0.log_prob(samples)
158 | Z = p0.Z
159 |
160 | p1 = Normal1D(torch.tensor(0.0), torch.tensor(1.0))
161 | print(p0.Z, p1.Z)
162 |
163 | log_w = -p1.energy(samples) - p0.log_prob(samples)
164 | logZ1 = torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(N, dtype=torch.float32))
165 | print(f"Estimated Z1: {torch.exp(logZ1)}")
166 | print(f"True Z1: {p1.Z}")
167 |
168 | samples = samples[samples < 3]
169 | samples = samples[samples > -3]
170 | logprob1 = -logZ1 - p1.energy(samples)
171 | prob1 = torch.exp(logprob1)
172 |
173 | plt.figure(figsize=(8, 4))
174 | plt.hist(
175 | samples.numpy(),
176 | bins=50,
177 | weights=prob1.numpy(),
178 | density=True,
179 | alpha=0.6,
180 | label="Weighted Histogram",
181 | )
182 | # sns.kdeplot(samples.numpy(), weights=prob1.numpy(), color='red', label='KDE')
183 | plt.title("Histogram of prob1 at sample locations")
184 | plt.xlabel("x")
185 | plt.ylabel("Probability Density")
186 | plt.legend()
187 | plt.show()
188 |
189 | gmm = GaussianMixture1D(
190 | means=[-3.0, 0.0, 3.0],
191 | stds=[0.5, 0.5, 0.5],
192 | weights=[2, 0.3, 0.1],
193 | )
194 | x = torch.linspace(-5, 5, 200)
195 |
196 |
197 | plt.figure(figsize=(8, 4))
198 | plt.plot(
199 | x.numpy(),
200 | p1.energy(x).numpy(),
201 | label="True Normal(0, 1) Energy",
202 | color="blue",
203 | )
204 | plt.plot(
205 | x.numpy(),
206 | gmm.energy(x).numpy(),
207 | label="GMM Energy",
208 | color="orange",
209 | )
210 |
211 | # Define colors for interpolation
212 | color1 = to_rgba("blue")
213 | color2 = to_rgba("orange")
214 |
215 | for t in [0.2, 0.4, 0.6, 0.8, 0.9, 0.95, 1.0]:
216 | # Linear interpolation in color space
217 | interp_color = tuple((1 - t) * c1 + t * c2 for c1, c2 in zip(color1, color2))
218 | energy_t = gmm.energy(x) ** t * p1.energy(x) ** (1 - t)
219 | plt.plot(
220 | x.numpy(),
221 | energy_t.numpy(),
222 | label=f"t={t:.2f}",
223 | linestyle="--",
224 | color=interp_color,
225 | )
226 |
227 | plt.title("Energy Interpolation between Normal and GMM")
228 | plt.xlabel("x")
229 | plt.ylabel("Energy")
230 | plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
231 | plt.show()
232 |
233 | t = 0.0
234 | energy_t = lambda x: gmm.energy(x) ** t * p1.energy(x) ** (1 - t)
235 | energy_t = lambda x: p1.energy(x)
236 | samples = metropolishastings_sampler(
237 | energy_fn=energy_t, x_init=torch.randn(5_000), n_steps=2_000, proposal_std=0.1
238 | )
239 |
240 | print(f"{samples.shape=}")
241 | # Count unique samples and their frequencies
242 | log_weights = -energy_t(samples)
243 | bins = torch.histc(samples, bins=50, min=samples.min(), max=samples.max())
244 | bin_edges = torch.linspace(samples.min(), samples.max(), steps=101)
245 | bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
246 | bin_probs = bins / bins.sum()
247 | Z_binned = (torch.exp(-energy_t(bin_centers)) * (bin_edges[1:] - bin_edges[:-1])).sum()
248 | print(f"Z_binned: {Z_binned}")
249 |
250 |
251 | # Weight by frequency of each unique sample
252 | # logZ_est = torch.logsumexp(log_weights, dim=0) - torch.log(
253 | # torch.tensor(len(samples), dtype=torch.float32)
254 | # )
255 |
256 | Z_est = torch.exp(log_weights).sum() / len(samples)
257 | print(
258 | f"Estimated partition function Z_t: {Z_est.item()} vs {compute_partition_function_1d(energy_t, -4, 4)}"
259 | )
260 |
261 | _ = plt.hist(
262 | samples.detach().numpy(),
263 | bins=100,
264 | density=True,
265 | alpha=0.6,
266 | label="SGLD Samples",
267 | )
268 | # plt.plot(x.numpy(), gmm.prob(x).numpy(), label="GMM PDF", color="orange")
269 | # plt.plot(x.numpy(), energy_t(x).numpy(), label="Normal PDF", color="blue")
270 |
--------------------------------------------------------------------------------
/experiments/DataViz.py:
--------------------------------------------------------------------------------
1 | import future, sys, os, datetime, argparse
2 | # print(os.path.dirname(sys.executable))
3 | import torch
4 | import numpy as np
5 | import matplotlib
6 | from tqdm import tqdm
7 | import matplotlib.pyplot as plt
8 | from matplotlib.lines import Line2D
9 |
10 | matplotlib.rcParams["figure.figsize"] = [10, 10]
11 |
12 | import torch
13 | from torch.nn import Module, Parameter
14 | from torch.nn import Linear, Tanh, ReLU
15 | import torch.nn.functional as F
16 |
17 | Tensor = torch.Tensor
18 | FloatTensor = torch.FloatTensor
19 |
20 | torch.set_printoptions(precision=4, sci_mode=False)
21 | np.set_printoptions(precision=4, suppress=True)
22 |
23 | sys.path.append("../../..") # Up to -> KFAC -> Optimization -> PHD
24 |
25 | import scipy
26 | import scipy as sp
27 | from scipy.io import loadmat as sp_loadmat
28 | import copy
29 |
30 | cwd = os.path.abspath(os.getcwd())
31 | os.chdir(cwd)
32 |
33 | from pytorch_MCMC.src.MCMC_ProbModel import ProbModel
34 | from pytorch_MCMC.models.MCMC_Models import GMM, LinReg, RegressionNN
35 | from pytorch_MCMC.src.MCMC_Sampler import SGLD_Sampler, MetropolisHastings_Sampler, MALA_Sampler, HMC_Sampler
36 | from pytorch_MCMC.data.MCMC_SyntheticData import generate_linear_regression_data, generate_multimodal_linear_regression, generate_nonstationary_data
37 | from pytorch_MCMC.src.MCMC_Utils import posterior_dist
38 | from Utils.Utils import RunningAverageMeter, str2bool
39 |
40 | def create_supervised_gif(model, chain, data):
41 |
42 | x, y = data
43 | x_min = 2 * x.min()
44 | x_max = 2 * x.max()
45 |
46 | data, mu, _ = model.predict(chain)
47 |
48 | gif_frames = []
49 |
50 | samples = [400, 600, 800, 1000]
51 | samples += range(2000, len(chain)//2, 2000)
52 | samples += range(len(chain)//2, len(chain), 4000)
53 |
54 | # print(len(samples))
55 | # exit()
56 |
57 | for i in range(400,len(chain), 500):
58 | 		print(f"{i}/{len(chain)}")
59 |
60 | # _, _, std = model.predict(chain[:i])
61 | fig = plt.figure()
62 | _, mu, std = model.predict(chain[399:i])
63 | plt.fill_between(data.squeeze(), mu + std, mu - std, color='red', alpha=0.25)
64 | plt.fill_between(data.squeeze(), mu + 2 * std, mu - 2 * std, color='red', alpha=0.10)
65 | plt.fill_between(data.squeeze(), mu + 3 * std, mu - 3 * std, color='red', alpha=0.05)
66 |
67 | plt.plot(data.squeeze(), mu, c='red')
68 | plt.scatter(x, y, alpha=1, s=1, color='blue')
69 | plt.ylim(2 * y.min(), 2 * y.max())
70 | plt.xlim(x_min, x_max)
71 | plt.grid()
72 |
73 | fig.canvas.draw() # draw the canvas, cache the renderer
74 | image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
75 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
76 |
77 | # plt.show()
78 | gif_frames.append(image)
79 |
80 | import imageio
81 | imageio.mimsave('HMC_Sampler5.gif', gif_frames, fps=4)
82 |
83 | def create_gmm_gif(chains):
84 |
85 | # num_samples = [40, 80, 120, 160, 200, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000]
86 | num_samples = [x for x in range(3,2000,40)]
87 | num_samples += [x for x in range(2000, len(chains[0]), 500)]
88 |
89 | gif_frames = []
90 |
91 | for num_samples_ in num_samples:
92 |
93 | print(f"{num_samples_}/{len(chains[0])}")
94 |
95 | post = []
96 |
97 | for chain in chains:
98 |
99 | for model_state_dict in chain.samples[:num_samples_]:
100 | post.append(list(model_state_dict.values())[0])
101 |
102 | post = torch.cat(post, dim=0)
103 |
104 | fig = plt.figure()
105 | hist2d = plt.hist2d(x=post[:, 0].cpu().numpy(), y=post[:, 1].cpu().numpy(), bins=100, range=np.array([[-3, 3], [-3, 3]]),
106 | density=True)
107 | plt.colorbar(hist2d[3])
108 | # plt.show()
109 |
110 |
111 | fig.canvas.draw() # draw the canvas, cache the renderer
112 | image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
113 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
114 |
115 |
116 | gif_frames.append(image)
117 |
118 | import imageio
119 | imageio.mimsave('GMM_HMC1.gif', gif_frames, fps=4)
120 |
121 |
122 |
123 | if True:
124 | chain = torch.load("hmc_regnn_ss0.01_len10000.chain")
125 | chain = chain[:50000]
126 | 	data = generate_nonstationary_data(num_samples=1000, plot=False, y_constant_noise_std=0.1, y_nonstationary_noise_std=0.01)  # keyword names follow data/MCMC_SyntheticData.py
127 | nn = RegressionNN(*data, batch_size=50)
128 |
129 | create_supervised_gif(nn, chain, data)
130 |
131 | if False:
132 |
133 | chains = torch.load("GMM_Chains.chain")
134 |
135 | create_gmm_gif(chains)
136 |
137 | posterior_dist(chains[0][:50])
138 | plt.show()
--------------------------------------------------------------------------------
/experiments/MCMC_HMCTest.py:
--------------------------------------------------------------------------------
1 | # cleaner interaction
2 |
3 |
4 | import os, argparse
5 |
6 | # print(os.path.dirname(sys.executable))
7 | import matplotlib
8 | import matplotlib.pyplot as plt
9 |
10 | matplotlib.rcParams["figure.figsize"] = [10, 10]
11 |
12 | import torch
13 |
14 | SEED = 0
15 | # torch.manual_seed(SEED)
16 | # np.random.seed(SEED)
17 |
18 | from pytorch_MCMC.models.MCMC_Models import GMM, LinReg, RegressionNNHomo, RegressionNNHetero
19 | from pytorch_MCMC.src.MCMC_Sampler import SGLD_Sampler, MALA_Sampler, HMC_Sampler, SGNHT_Sampler
20 | from pytorch_MCMC.data.MCMC_SyntheticData import generate_linear_regression_data, generate_nonstationary_data
21 | from Utils.Utils import str2bool
22 |
23 | params = argparse.ArgumentParser(description='parser example')
24 | params.add_argument('-logname', type=str, default='Tmp')
25 |
26 | params.add_argument('-num_samples', type=int, default=200)
27 | params.add_argument('-model', choices=['gmm', 'linreg', 'regnn'], default='gmm')
28 | params.add_argument('-sampler', choices=['sgld','mala', 'hmc', 'sgnht'], default='sgnht')
29 |
30 | params.add_argument('-step_size', type=float, default=0.1)
31 | params.add_argument('-num_steps', type=int, default=10000)
32 | params.add_argument('-pretrain', type=str2bool, default=False)
33 | params.add_argument('-tune', type=str2bool, default=False)
34 | params.add_argument('-burn_in', type=int, default=2000)
35 | # params.add_argument('-num_chains', type=int, default=1)
36 | params.add_argument('-num_chains', type=int, default=os.cpu_count() - 1)
37 | params.add_argument('-batch_size', type=int, default=50)
38 |
39 | params.add_argument('-hmc_traj_length', type=int, default=20)
40 |
41 | params.add_argument('-val_split', type=float, default=0.9) # first part is train, second is val i.e. val_split=0.8 -> 80% train, 20% val
42 |
43 | params.add_argument('-val_prediction_steps', type=int, default=50)
44 | params.add_argument('-val_converge_criterion', type=int, default=20)
45 | params.add_argument('-val_per_epoch', type=int, default=200)
46 |
47 | params = params.parse_args()
48 |
49 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
50 |
51 | if torch.cuda.is_available():
52 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
53 | FloatTensor = torch.cuda.FloatTensor
54 |     Tensor = torch.cuda.FloatTensor
55 | else:
56 | device = torch.device('cpu')
57 | FloatTensor = torch.FloatTensor
58 | Tensor = torch.FloatTensor
59 |
60 | if params.model == 'gmm':
61 | gmm = GMM()
62 | gmm.generate_surface(plot=True)
63 |
64 | if params.sampler=='sgld':
65 | sampler = SGLD_Sampler(probmodel=gmm,
66 | step_size=params.step_size,
67 | num_steps=params.num_steps,
68 | burn_in=params.burn_in,
69 | pretrain=params.pretrain,
70 | tune=params.tune,
71 | num_chains=params.num_chains)
72 | elif params.sampler=='mala':
73 | sampler = MALA_Sampler(probmodel=gmm,
74 | step_size=params.step_size,
75 | num_steps=params.num_steps,
76 | burn_in=params.burn_in,
77 | pretrain=params.pretrain,
78 | tune=params.tune,
79 | num_chains=params.num_chains)
80 | elif params.sampler=='hmc':
81 | sampler = HMC_Sampler(probmodel=gmm,
82 | step_size=params.step_size,
83 | num_steps=params.num_steps,
84 | burn_in=params.burn_in,
85 | pretrain=params.pretrain,
86 | tune=params.tune,
87 | traj_length=params.hmc_traj_length,
88 | num_chains=params.num_chains)
89 | elif params.sampler == 'sgnht':
90 | sampler = SGNHT_Sampler(probmodel=gmm,
91 | step_size=params.step_size,
92 | num_steps=params.num_steps,
93 | burn_in=params.burn_in,
94 | pretrain=params.pretrain,
95 | tune=params.tune,
96 | traj_length=params.hmc_traj_length,
97 | num_chains=params.num_chains)
98 | else:
99 | raise ValueError('No Sampler defined')
100 |
101 | sampler.sample_chains()
102 | sampler.posterior_dist()
103 | # sampler.trace()
104 |
105 | # plt.plot(sampler.chain.accepted_steps)
106 | plt.show()
107 |
108 | elif params.model == 'linreg':
109 |
110 | x, y = generate_linear_regression_data(num_samples=params.num_samples, m=-2., b=-1, y_noise=0.5)
111 | linreg = LinReg(x, y)
112 | # sampler = MetropolisHastings_Sampler(probmodel=linreg, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, tune=params.tune)
113 | sampler = SGLD_Sampler(probmodel=linreg, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in,
114 | pretrain=params.pretrain, tune=params.tune)
115 | sampler.sample_chains()
116 | sampler.posterior_dist()
117 | linreg.predict(sampler.chain)
118 |
119 | elif params.model == 'regnn':
120 |
121 | model = ['homo', 'hetero'][1]
122 | if model=='homo':
123 | x, y = generate_nonstationary_data(num_samples=params.num_samples, plot=False, y_constant_noise_std=0.25, y_nonstationary_noise_std=0.01)
124 | nn = RegressionNNHomo(x, y, batch_size=params.batch_size)
125 | elif model=='hetero':
126 | x, y = generate_nonstationary_data(num_samples=params.num_samples, plot=False, y_constant_noise_std=0.001, y_nonstationary_noise_std=0.5)
127 | nn = RegressionNNHetero(x, y, batch_size=params.batch_size)
128 |
129 | if params.sampler == 'sgld':
130 | sampler = SGLD_Sampler(probmodel=nn,
131 | step_size=params.step_size,
132 | num_steps=params.num_steps,
133 | burn_in=params.burn_in,
134 | pretrain=params.pretrain,
135 | tune=params.tune,
136 | num_chains=params.num_chains)
137 | elif params.sampler == 'mala':
138 | sampler = MALA_Sampler(probmodel=nn,
139 | step_size=params.step_size,
140 | num_steps=params.num_steps,
141 | burn_in=params.burn_in,
142 | pretrain=params.pretrain,
143 | tune=params.tune,
144 | num_chains=params.num_chains)
145 | elif params.sampler == 'hmc':
146 | sampler = HMC_Sampler(probmodel=nn,
147 | step_size=params.step_size,
148 | num_steps=params.num_steps,
149 | burn_in=params.burn_in,
150 | pretrain=params.pretrain,
151 | tune=params.tune,
152 | traj_length=params.hmc_traj_length,
153 | num_chains=params.num_chains)
154 | elif params.sampler == 'sgnht':
155 | sampler = SGNHT_Sampler(probmodel=nn,
156 | step_size=params.step_size,
157 | num_steps=params.num_steps,
158 | burn_in=params.burn_in,
159 | pretrain=params.pretrain,
160 | tune=params.tune,
161 | traj_length=params.hmc_traj_length,
162 | num_chains=params.num_chains)
163 | else:
164 | raise ValueError('No Sampler defined')
165 | chains = sampler.sample_chains()
166 |
167 | nn.predict(chains, plot=True)
168 |
--------------------------------------------------------------------------------
/experiments/MCMC_Test.py:
--------------------------------------------------------------------------------
1 | # cleaner interaction
2 |
3 |
4 | import os, argparse
5 |
6 | # print(os.path.dirname(sys.executable))
7 | import matplotlib
8 | import matplotlib.pyplot as plt
9 |
10 | matplotlib.rcParams["figure.figsize"] = [10, 10]
11 |
12 | import torch
13 |
14 | SEED = 0
15 | # torch.manual_seed(SEED)
16 | # np.random.seed(SEED)
17 |
18 | from pytorch_MCMC.models.MCMC_Models import GMM, LinReg, RegressionNN
19 | from pytorch_MCMC.src.MCMC_Sampler import SGLD_Sampler, MALA_Sampler, HMC_Sampler
20 | from pytorch_MCMC.data.MCMC_SyntheticData import generate_linear_regression_data, generate_nonstationary_data
21 | from Utils.Utils import str2bool
22 |
23 | params = argparse.ArgumentParser(description='parser example')
24 | params.add_argument('-logname', type=str, default='Tmp')
25 |
26 | params.add_argument('-num_samples', type=int, default=1000)
27 | params.add_argument('-model', choices=['gmm', 'linreg', 'regnn'], default='gmm')
28 | params.add_argument('-sampler', choices=['sgld', 'mala', 'hmc'], default='hmc')
29 |
30 | params.add_argument('-step_size', type=float, default=0.1)
31 | params.add_argument('-num_steps', type=int, default=10000)
32 | params.add_argument('-pretrain', type=str2bool, default=False)
33 | params.add_argument('-tune', type=str2bool, default=False)
34 | params.add_argument('-burn_in', type=int, default=1000)
35 | # params.add_argument('-num_chains', type=int, default=1)
36 | params.add_argument('-num_chains', type=int, default=os.cpu_count() - 1)
37 | params.add_argument('-batch_size', type=int, default=50)
38 |
39 | params.add_argument('-hmc_traj_length', type=int, default=20)
40 |
41 | params.add_argument('-val_split', type=float, default=0.9) # first part is train, second is val i.e. val_split=0.8 -> 80% train, 20% val
42 |
43 | params.add_argument('-val_prediction_steps', type=int, default=50)
44 | params.add_argument('-val_converge_criterion', type=int, default=20)
45 | params.add_argument('-val_per_epoch', type=int, default=200)
46 |
47 | params = params.parse_args()
48 |
49 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
50 |
51 | if torch.cuda.is_available():
52 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
53 | FloatTensor = torch.cuda.FloatTensor
54 |     Tensor = torch.cuda.FloatTensor
55 | else:
56 | device = torch.device('cpu')
57 | FloatTensor = torch.FloatTensor
58 | Tensor = torch.FloatTensor
59 |
60 | if params.model == 'gmm':
61 | gmm = GMM()
62 | # gmm.generate_surface(plot=True)
63 |
64 | if params.sampler == 'sgld':
65 | sampler = SGLD_Sampler(probmodel=gmm,
66 | step_size=params.step_size,
67 | num_steps=params.num_steps,
68 | burn_in=params.burn_in,
69 | pretrain=params.pretrain,
70 | tune=params.tune,
71 | num_chains=params.num_chains)
72 | elif params.sampler == 'mala':
73 | sampler = MALA_Sampler(probmodel=gmm,
74 | step_size=params.step_size,
75 | num_steps=params.num_steps,
76 | burn_in=params.burn_in,
77 | pretrain=params.pretrain,
78 | tune=params.tune,
79 | num_chains=params.num_chains)
80 | elif params.sampler == 'hmc':
81 | sampler = HMC_Sampler(probmodel=gmm,
82 | step_size=params.step_size,
83 | num_steps=params.num_steps,
84 | burn_in=params.burn_in,
85 | pretrain=params.pretrain,
86 | tune=params.tune,
87 | traj_length=params.hmc_traj_length,
88 | num_chains=params.num_chains)
89 | sampler.sample_chains()
90 | sampler.posterior_dist()
91 | # sampler.trace()
92 |
93 | # plt.plot(sampler.chain.accepted_steps)
94 | plt.show()
95 |
96 | elif params.model == 'linreg':
97 |
98 | x, y = generate_linear_regression_data(num_samples=params.num_samples, m=-2., b=-1, y_noise=0.5)
99 | linreg = LinReg(x, y)
100 | # sampler = MetropolisHastings_Sampler(probmodel=linreg, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, tune=params.tune)
101 | sampler = SGLD_Sampler(probmodel=linreg, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in,
102 | pretrain=params.pretrain, tune=params.tune)
103 | sampler.sample_chains()
104 | sampler.posterior_dist()
105 | linreg.predict(sampler.chain)
106 |
107 | elif params.model == 'regnn':
108 |
109 | 	x, y = generate_nonstationary_data(num_samples=params.num_samples, plot=False, y_constant_noise_std=0.1, y_nonstationary_noise_std=0.01)  # keyword names follow data/MCMC_SyntheticData.py
110 | # print(f'{x.shape=} {y.shape=}')
111 | nn = RegressionNN(x, y, batch_size=params.batch_size)
112 |
113 | # sampler = SGLD_Sampler(probmodel=nn, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, pretrain=params.pretrain, tune=params.tune)
114 | # sampler = MALA_Sampler(probmodel=nn, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, pretrain=params.pretrain,
115 | # tune=params.tune)
116 | sampler = HMC_Sampler(probmodel=nn, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, pretrain=params.pretrain,
117 | tune=params.tune)
118 | # sampler = MetropolisHastings_Sampler(probmodel=nn, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, pretrain=params.pretrain, tune=params.tune)
119 | sampler.sample_chains()
120 | torch.save(sampler.chain, 'hmc_regnn_ss0.01_len10000.chain')
121 | nn.predict(sampler.chain)
122 |
--------------------------------------------------------------------------------
/experiments/Testing.py:
--------------------------------------------------------------------------------
1 | import future, sys, os, datetime, argparse
2 | # print(os.path.dirname(sys.executable))
3 | import torch
4 | import numpy as np
5 | import matplotlib
6 | from tqdm import tqdm
7 | import matplotlib.pyplot as plt
8 | from matplotlib.lines import Line2D
9 |
10 | matplotlib.rcParams["figure.figsize"] = [10, 10]
11 |
12 | import torch
13 | from torch.nn import Module, Parameter
14 | from torch.nn import Linear, Tanh, ReLU
15 | import torch.nn.functional as F
16 |
17 | Tensor = torch.Tensor
18 | FloatTensor = torch.FloatTensor
19 |
20 | torch.set_printoptions(precision=4, sci_mode=False)
21 | np.set_printoptions(precision=4, suppress=True)
22 |
23 | import scipy
24 | import scipy as sp
25 | from scipy.io import loadmat as sp_loadmat
26 | import copy
27 |
28 | cwd = os.path.abspath(os.getcwd())
29 | os.chdir(cwd)
30 |
31 | params = argparse.ArgumentParser()
32 | params.add_argument('-xyz', type=str, default='test_xyz')
33 |
34 | params = params.parse_args()
--------------------------------------------------------------------------------
/experiments/functional_pytorch:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | inputs = torch.randn(64, 3)
4 | targets = torch.randn(64, 3)
5 | model = torch.nn.Linear(3, 3)
6 |
7 | params = dict(model.named_parameters())
8 |
9 |
10 | def compute_loss(params, inputs, targets):
11 | prediction = torch.func.functional_call(model, params, (inputs,))
12 | return torch.nn.functional.mse_loss(prediction, targets)
13 |
14 |
15 | grads = torch.func.grad(compute_loss)(params, inputs, targets)
16 |
17 |
18 | # %%
19 | import copy
20 |
21 | import torch
22 |
23 | num_models = 5
24 | batch_size = 64
25 | in_features, out_features = 3, 3
26 | models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
27 | data = torch.randn(batch_size, 3)
28 |
29 | # Construct a version of the model with no memory by putting the Tensors on
30 | # the meta device.
31 | base_model = copy.deepcopy(models[0])
32 | base_model.to("meta")
33 |
34 | params, buffers = torch.func.stack_module_state(models)
35 |
36 |
37 | # It is possible to vmap directly over torch.func.functional_call,
38 | # but wrapping it in a function makes it clearer what is going on.
39 | def call_single_model(params, buffers, data):
40 | return torch.func.functional_call(base_model, (params, buffers), (data,))
41 |
42 |
43 | def call_single_loss(params, buffers, data):
44 | prediction = torch.func.functional_call(base_model, (params, buffers), (data,))
45 | return torch.nn.functional.mse_loss(prediction, targets)
46 |
47 |
48 | output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
49 | grad = torch.func.grad(
50 | lambda p, b, d: torch.sum(torch.vmap(call_single_loss, (0, 0, None))(p, b, d))
51 | )(params, buffers, data)
52 | assert output.shape == (num_models, batch_size, out_features)
53 |
--------------------------------------------------------------------------------
/models/MCMC_Models.py:
--------------------------------------------------------------------------------
1 | import future, sys, os, datetime, argparse
2 | # print(os.path.dirname(sys.executable))
3 | import torch
4 | import numpy as np
5 | import matplotlib
6 | from tqdm import tqdm
7 | import matplotlib.pyplot as plt
8 | from matplotlib.lines import Line2D
9 |
10 | # matplotlib.rcParams["figure.figsize"] = [10, 10]
11 |
12 | import torch
13 | from torch.nn import Module, Parameter, Sequential
14 | from torch.nn import Linear, Tanh, ReLU, CELU
15 | from torch.utils.data import TensorDataset, DataLoader
16 | import torch.nn.functional as F
17 | from torch.distributions import MultivariateNormal, Categorical, Normal
18 |
19 | from joblib import Parallel, delayed
20 |
21 | Tensor = torch.Tensor
22 | FloatTensor = torch.FloatTensor
23 |
24 | torch.set_printoptions(precision=4, sci_mode=False)
25 | np.set_printoptions(precision=4, suppress=True)
26 |
27 | sys.path.append("../../..") # Up to -> KFAC -> Optimization -> PHD
28 |
29 | from pytorch_MCMC.src.MCMC_ProbModel import ProbModel
30 |
31 | class GMM(ProbModel):
32 |
33 | def __init__(self):
34 |
35 | dataloader = DataLoader(TensorDataset(torch.zeros(1,2))) # bogus dataloader
36 | ProbModel.__init__(self, dataloader)
37 |
38 | self.means = FloatTensor([[-1, -1.25], [-1, 1.25], [1.5, 1]])
39 | # self.means = FloatTensor([[-1,-1.25]])
40 | self.num_dists = self.means.shape[0]
41 | I = FloatTensor([[1, 0], [0, 1]])
42 | I_compl = FloatTensor([[0, 1], [1, 0]])
43 | self.covars = [I * 0.5, I * 0.5, I * 0.5 + I_compl * 0.3]
44 | # self.covars = [I * 0.9, I * 0.9, I * 0.9 + I_compl * 0.3]
45 | self.weights = [0.4, 0.2, 0.4]
46 | self.dists = []
47 |
48 | for mean, covar in zip(self.means, self.covars):
49 | self.dists.append(MultivariateNormal(mean, covar))
50 |
51 | self.X_grid = None
52 | self.Y_grid = None
53 | self.surface = None
54 |
55 | self.param = torch.nn.Parameter(self.sample())
56 |
57 |
58 |
59 | def forward(self, x=None):
60 |
61 | log_probs = torch.stack([weight * torch.exp(dist.log_prob(x)) for dist, weight in zip(self.dists, self.weights)], dim=1)
62 | log_prob = torch.log(torch.sum(log_probs, dim=1))
63 |
64 | return log_prob
65 |
66 | def log_prob(self, *x):
67 |
68 | log_probs = torch.stack([weight * torch.exp(dist.log_prob(self.param)) for dist, weight in zip(self.dists, self.weights)], dim=1)
69 | log_prob = torch.log(torch.sum(log_probs, dim=1))
70 |
71 | return {'log_prob': log_prob}
72 |
73 | def prob(self, x):
74 |
75 | log_probs = torch.stack([weight * torch.exp(dist.log_prob(x)) for dist, weight in zip(self.dists, self.weights)], dim=1)
76 | log_prob = torch.sum(log_probs, dim=1)
77 | return log_prob
78 |
79 | def sample(self, _shape=(1,)):
80 |
81 | probs = torch.ones(self.num_dists) / self.num_dists
82 | categorical = Categorical(probs)
83 | sampled_dists = categorical.sample(_shape)
84 |
85 | samples = []
86 | for sampled_dist in sampled_dists:
87 | sample = self.dists[sampled_dist].sample((1,))
88 | samples.append(sample)
89 |
90 | samples = torch.cat(samples)
91 |
92 | return samples
93 |
94 | def reset_parameters(self):
95 |
96 | self.param.data = self.sample()
97 |
98 | def generate_surface(self, plot_min=-3, plot_max=3, plot_res=500, plot=False):
99 |
100 | # print('in surface')
101 |
102 | x = np.linspace(plot_min, plot_max, plot_res)
103 | y = np.linspace(plot_min, plot_max, plot_res)
104 | X, Y = np.meshgrid(x, y)
105 |
106 | self.X_grid = X
107 | self.Y_grid = Y
108 |
109 | points = FloatTensor(np.stack((X.ravel(), Y.ravel())).T) # .requires_grad_()
110 |
111 | probs = self.prob(points).view(plot_res, plot_res)
112 | self.surface = probs.numpy()
113 |
114 | area = ((plot_max - plot_min) / plot_res) ** 2
115 | 		sum_px = probs.sum() * area  # Riemann sum: the density values are the heights, multiplied by the cell area
116 |
117 | fig = plt.figure(figsize=(10, 10))
118 |
119 | contour = plt.contourf(self.X_grid, self.Y_grid, self.surface, levels=20)
120 | plt.xlim(-3, 3)
121 | plt.ylim(-3, 3)
122 | plt.grid()
123 | cbar = fig.colorbar(contour)
124 | if plot: plt.show()
125 |
126 | return fig
127 |
128 | class LinReg(ProbModel):
129 |
130 | 	def __init__(self, x, y):
131 |
132 | 		self.data = x
133 | 		self.target = y
134 |
135 | 		dataloader = DataLoader(TensorDataset(self.data, self.target), shuffle=True, batch_size=self.data.shape[0])
136 | 		ProbModel.__init__(self, dataloader)
137 |
138 | self.m = Parameter(FloatTensor(1 * torch.randn((1,))))
139 | self.b = Parameter(FloatTensor(1 * torch.randn((1,))))
140 | # self.log_noise = Parameter(FloatTensor([-1.]))
141 | self.log_noise = FloatTensor([0])
142 |
143 | def reset_parameters(self):
144 | torch.nn.init.normal_(self.m, std=.1)
145 | torch.nn.init.normal_(self.b, std=.1)
146 |
147 | def sample(self):
148 | self.reset_parameters()
149 |
150 | def forward(self, x):
151 |
152 | return self.m * x + self.b
153 |
154 | 	def log_prob(self, *args):
155 |
156 | data, target = next(self.dataloader.__iter__())
157 | # data, target = self.data, self.target
158 | mu = self.forward(data)
159 | log_prob = Normal(mu, F.softplus(self.log_noise)).log_prob(target).mean()
160 |
161 | return {'log_prob': log_prob}
162 |
163 | @torch.no_grad()
164 | def predict(self, chain):
165 |
166 | 		x_min = 2 * self.data.min().item()
167 | 		x_max = 2 * self.data.max().item()
168 | 		data = torch.arange(x_min, x_max).reshape(-1, 1)
169 |
170 | pred = []
171 | for model_state_dict in chain.samples:
172 | self.load_state_dict(model_state_dict)
173 | # data.append(self.data)
174 | pred_i = self.forward(data)
175 | pred.append(pred_i)
176 |
177 | pred = torch.stack(pred)
178 | # data = torch.stack(data)
179 |
180 | mu = pred.mean(dim=0).squeeze()
181 | std = pred.std(dim=0).squeeze()
182 |
183 | # print(f'{data.shape=}')
184 | # print(f'{pred.shape=}')
185 |
186 | plt.plot(data, mu, alpha=1., color='red')
187 | plt.fill_between(data.squeeze(), mu+std, mu-std, color='red', alpha=0.25)
188 | plt.fill_between(data.squeeze(), mu+2*std, mu-2*std, color='red', alpha=0.10)
189 | plt.fill_between(data.squeeze(), mu+3*std, mu-3*std, color='red', alpha=0.05)
190 | plt.scatter(self.data, self.target, alpha=1, s=1, color='blue')
191 | plt.ylim(pred.min(), pred.max())
192 | plt.xlim(x_min, x_max)
193 | plt.show()
194 |
195 | class RegressionNNHomo(ProbModel):
196 |
197 | def __init__(self, x, y, batch_size=1):
198 |
199 | self.data = x
200 | self.target = y
201 |
202 | # dataloader = DataLoader(TensorDataset(self.data, self.target), shuffle=True, batch_size=self.data.shape[0], drop_last=False)
203 | dataloader = DataLoader(TensorDataset(self.data, self.target), shuffle=True, batch_size=batch_size, drop_last=False)
204 |
205 | ProbModel.__init__(self, dataloader)
206 |
207 | num_hidden = 50
208 | self.model = Sequential(Linear(1, num_hidden),
209 | ReLU(),
210 | Linear(num_hidden, num_hidden),
211 | ReLU(),
212 | # Linear(num_hidden, num_hidden),
213 | # ReLU(),
214 | # Linear(num_hidden, num_hidden),
215 | # ReLU(),
216 | Linear(num_hidden, 1))
217 |
218 | self.log_std = Parameter(FloatTensor([-1]))
219 |
220 | def reset_parameters(self):
221 | for module in self.model.modules():
222 | if isinstance(module, Linear):
223 | module.reset_parameters()
224 |
225 | self.log_std.data = FloatTensor([3.])
226 |
227 | def sample(self):
228 | self.reset_parameters()
229 |
230 | def forward(self, x):
231 | pred = self.model(x)
232 | return pred
233 |
234 | def log_prob(self, data, target):
235 |
236 | # if data is None and target is None:
237 | # data, target = next(self.dataloader.__iter__())
238 |
239 | mu = self.forward(data)
240 | mse = F.mse_loss(mu,target)
241 |
242 | log_prob = Normal(mu, F.softplus(self.log_std)).log_prob(target).mean()*len(self.dataloader.dataset)
243 |
244 | return {'log_prob': log_prob, 'MSE': mse.detach_()}
245 |
246 | def pretrain(self):
247 |
248 | num_epochs = 200
249 | optim = torch.optim.Adam(self.parameters(), lr=0.01)
250 |
251 | # print(f"{F.softplus(self.log_std)=}")
252 |
253 | progress = tqdm(range(num_epochs))
254 | for epoch in progress:
255 | for batch_i, (data, target) in enumerate(self.dataloader):
256 | optim.zero_grad()
257 | mu = self.forward(data)
258 | loss = -Normal(mu, F.softplus(self.log_std)).log_prob(target).mean()
259 | mse_loss = F.mse_loss(mu, target)
260 | loss.backward()
261 | optim.step()
262 |
263 | desc = f'Pretraining: MSE:{mse_loss:.3f}'
264 | progress.set_description(desc)
265 |
266 | # print(f"{F.softplus(self.log_std)=}")
267 |
268 | @torch.no_grad()
269 | def predict(self, chains, plot=False):
270 |
271 | 		x_min = 2 * self.data.min().item()
272 | 		x_max = 2 * self.data.max().item()
273 | 		data = torch.linspace(x_min, x_max, steps=100).reshape(-1, 1)
274 |
275 | def parallel_predict(parallel_chain):
276 | parallel_pred = []
277 | for model_state_dict in parallel_chain.samples[::50]:
278 | self.load_state_dict(model_state_dict)
279 | pred_mu_i = self.forward(data)
280 | parallel_pred.append(pred_mu_i)
281 | try:
282 | parallel_pred_mu = torch.stack(parallel_pred) # list [ pred_0, pred_1, ... pred_N] -> Tensor([pred_0, pred_1, ... pred_N])
283 | return parallel_pred_mu
284 | except:
285 | pass
286 |
287 |
288 | parallel_pred = Parallel(n_jobs=len(chains))(delayed(parallel_predict)(chain) for chain in chains)
289 |
290 | pred = [parallel_pred_i for parallel_pred_i in parallel_pred if parallel_pred_i is not None] # flatten [ [pred_chain_0], [pred_chain_1] ... [pred_chain_N] ]
291 | # pred_log_std = [parallel_pred_i for parallel_pred_i in parallel_pred_log_std if parallel_pred_i is not None] # flatten [ [pred_chain_0], [pred_chain_1] ... [pred_chain_N] ]
292 |
293 |
294 | 		pred = torch.cat(pred).squeeze()  # cat list of tensors to a single prediction tensor with samples in first dim
295 | std = F.softplus(self.log_std)
296 |
297 |
298 | epistemic = pred.std(dim=0)
299 | aleatoric = std
300 | total_std = (epistemic ** 2 + aleatoric ** 2) ** 0.5
301 |
302 | mu = pred.mean(dim=0)
303 | std = std.mean(dim=0)
304 |
305 | data.squeeze_()
306 |
307 | if plot:
308 | fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
309 | axs = axs.flatten()
310 |
311 | axs[0].scatter(self.data, self.target, alpha=1, s=1, color='blue')
312 | axs[0].plot(data.squeeze(), mu, alpha=1., color='red')
313 | axs[0].fill_between(data, mu + total_std, mu - total_std, color='red', alpha=0.25)
314 | axs[0].fill_between(data, mu + 2 * total_std, mu - 2 * total_std, color='red', alpha=0.10)
315 | axs[0].fill_between(data, mu + 3 * total_std, mu - 3 * total_std, color='red', alpha=0.05)
316 |
317 | 			[axs[1].plot(data, pred_i, alpha=0.1, color='red') for pred_i in pred]
318 | axs[1].scatter(self.data, self.target, alpha=1, s=1, color='blue')
319 |
320 | axs[2].scatter(self.data, self.target, alpha=1, s=1, color='blue')
321 | axs[2].plot(data, mu, color='red')
322 | axs[2].fill_between(data, mu - aleatoric, mu + aleatoric, color='red', alpha=0.25, label='Aleatoric')
323 | axs[2].legend()
324 |
325 | axs[3].scatter(self.data, self.target, alpha=1, s=1, color='blue')
326 | axs[3].plot(data, mu, color='red')
327 | axs[3].fill_between(data, mu - epistemic, mu + epistemic, color='red', alpha=0.25, label='Epistemic')
328 | axs[3].legend()
329 |
330 | plt.ylim(2 * self.target.min(), 2 * self.target.max())
331 | plt.xlim(x_min, x_max)
332 | plt.show()
333 |
334 | class RegressionNNHetero(ProbModel):
335 |
336 | def __init__(self, x, y, batch_size=1):
337 |
338 | self.data = x
339 | self.target = y
340 |
341 | dataloader = DataLoader(TensorDataset(self.data, self.target), shuffle=True, batch_size=self.data.shape[0], drop_last=False)
342 | # dataloader = DataLoader(TensorDataset(x, y), shuffle=True, batch_size=batch_size, drop_last=False)
343 |
344 | ProbModel.__init__(self, dataloader)
345 |
346 | num_hidden = 50
347 | self.model = Sequential(Linear(1, num_hidden),
348 | ReLU(),
349 | Linear(num_hidden, num_hidden),
350 | ReLU(),
351 | Linear(num_hidden, num_hidden),
352 | ReLU(),
353 | # Linear(num_hidden, num_hidden),
354 | # ReLU(),
355 | Linear(num_hidden, 2))
356 |
357 | def reset_parameters(self):
358 | for module in self.model.modules():
359 | if isinstance(module, Linear):
360 | module.reset_parameters()
361 |
362 | def sample(self):
363 | self.reset_parameters()
364 |
365 | def forward(self, x):
366 | pred = self.model(x)
367 | mu, log_std = torch.chunk(pred, chunks=2, dim=-1)
368 | return mu, log_std
369 |
370 | def log_prob(self, data, target):
371 |
372 | # if data is None and target is None:
373 | # data, target = next(self.dataloader.__iter__())
374 |
375 | mu, log_std = self.forward(data)
376 | mse = F.mse_loss(mu,target)
377 |
378 | log_prob = Normal(mu, F.softplus(log_std)).log_prob(target).mean()*len(self.dataloader.dataset)
379 |
380 | return {'log_prob': log_prob, 'MSE': mse.detach_()}
381 |
382 | def pretrain(self):
383 |
384 | num_epochs = 100
385 | optim = torch.optim.Adam(self.parameters(), lr=0.001)
386 |
387 | progress = tqdm(range(num_epochs))
388 | for epoch in progress:
389 | for batch_i, (data, target) in enumerate(self.dataloader):
390 | optim.zero_grad()
391 | mu, log_std = self.forward(data)
392 | loss = -Normal(mu, F.softplus(log_std)).log_prob(target).mean()
393 | mse_loss = F.mse_loss(mu, target)
394 | loss.backward()
395 | optim.step()
396 |
397 | desc = f'Pretraining: MSE:{mse_loss:.3f}'
398 | progress.set_description(desc)
399 |
400 | @torch.no_grad()
401 | def predict(self, chains, plot=False):
402 |
403 | 		x_min = 2 * self.data.min().item()
404 | 		x_max = 2 * self.data.max().item()
405 | 		data = torch.linspace(x_min, x_max, steps=100).reshape(-1, 1)
406 |
407 | def parallel_predict(parallel_chain):
408 | parallel_pred_mu = []
409 | parallel_pred_log_std = []
410 | for model_state_dict in parallel_chain.samples[::50]:
411 | self.load_state_dict(model_state_dict)
412 | pred_mu_i, pred_log_std_i = self.forward(data)
413 | parallel_pred_mu.append(pred_mu_i)
414 | parallel_pred_log_std.append(pred_log_std_i)
415 |
416 | try:
417 | parallel_pred_mu = torch.stack(parallel_pred_mu) # list [ pred_0, pred_1, ... pred_N] -> Tensor([pred_0, pred_1, ... pred_N])
418 | parallel_pred_log_std = torch.stack(parallel_pred_log_std) # list [ pred_0, pred_1, ... pred_N] -> Tensor([pred_0, pred_1, ... pred_N])
419 | return parallel_pred_mu, parallel_pred_log_std
420 | except:
421 | pass
422 |
423 |
424 | parallel_pred_mu, parallel_pred_log_std = zip(*Parallel(n_jobs=len(chains))(delayed(parallel_predict)(chain) for chain in chains))
425 |
426 | pred_mu = [parallel_pred_i for parallel_pred_i in parallel_pred_mu if parallel_pred_i is not None] # flatten [ [pred_chain_0], [pred_chain_1] ... [pred_chain_N] ]
427 | pred_log_std = [parallel_pred_i for parallel_pred_i in parallel_pred_log_std if parallel_pred_i is not None] # flatten [ [pred_chain_0], [pred_chain_1] ... [pred_chain_N] ]
428 |
429 |
430 | 		pred_mu = torch.cat(pred_mu).squeeze()  # cat list of tensors to a single prediction tensor with samples in first dim
431 | 		pred_log_std = torch.cat(pred_log_std).squeeze()  # cat list of tensors to a single prediction tensor with samples in first dim
432 |
433 | mu = pred_mu.squeeze()
434 | std = F.softplus(pred_log_std).squeeze()
435 |
436 | epistemic = mu.std(dim=0)
437 | aleatoric = (std**2).mean(dim=0)**0.5
438 | total_std = (epistemic**2 + aleatoric**2)**0.5
439 |
440 | mu = mu.mean(dim=0)
441 | std = std.mean(dim=0)
442 |
443 | data.squeeze_()
444 |
445 | if plot:
446 |
447 | fig, axs = plt.subplots(2,2, sharex=True, sharey=True)
448 | axs = axs.flatten()
449 |
450 | axs[0].scatter(self.data, self.target, alpha=1, s=1, color='blue')
451 | axs[0].plot(data.squeeze(), mu, alpha=1., color='red')
452 | axs[0].fill_between(data, mu+total_std, mu-total_std, color='red', alpha=0.25)
453 | axs[0].fill_between(data, mu+2*total_std, mu-2*total_std, color='red', alpha=0.10)
454 | axs[0].fill_between(data, mu+3*total_std, mu-3*total_std, color='red', alpha=0.05)
455 |
456 | [axs[1].plot(data, pred, alpha=0.1, color='red') for pred in pred_mu]
457 | axs[1].scatter(self.data, self.target, alpha=1, s=1, color='blue')
458 |
459 | axs[2].scatter(self.data, self.target, alpha=1, s=1, color='blue')
460 | axs[2].plot(data, mu, color='red')
461 | axs[2].fill_between(data, mu-aleatoric, mu+aleatoric, color='red', alpha=0.25, label='Aleatoric')
462 | axs[2].legend()
463 |
464 | axs[3].scatter(self.data, self.target, alpha=1, s=1, color='blue')
465 | axs[3].plot(data, mu, color='red')
466 | axs[3].fill_between(data, mu-epistemic, mu+epistemic, color='red', alpha=0.25, label='Epistemic')
467 | axs[3].legend()
468 |
469 | plt.ylim(2*self.target.min(), 2*self.target.max())
470 | plt.xlim(x_min, x_max)
471 | plt.show()
472 |
473 |
474 | return data, mu, std
475 |
476 |
477 | if __name__ == '__main__':
478 |
479 | pass
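480 |
481 |
482 | # Illustrative usage sketch (added for clarity, not part of the original file):
483 | # instantiate the 2D GMM target, draw a few samples, evaluate their mixture density
484 | # and build the contour figure of the density surface.
485 | if __name__ == '__main__':
486 |     gmm = GMM()
487 |     samples = gmm.sample((5,))
488 |     print(f'{samples.shape=}')        # torch.Size([5, 2])
489 |     print(f'{gmm.prob(samples)=}')    # per-sample mixture density
490 |     gmm.generate_surface(plot=False)  # returns a matplotlib figure of the density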
--------------------------------------------------------------------------------
/src/MCMC_Acceptance.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
4 |
5 | class MetropolisHastingsAcceptance():
6 |
7 | def __init__(self):
8 |
9 | pass
10 |
11 | def __call__(self, log_prob_proposal, log_prob_state):
12 | '''
13 | accept = min ( p(x') / p(x) , 1)
14 | 		log_accept = min( log_p(x') - log_p(x), 0 )
15 | 		log_accept = min( log_ratio, 0 )
16 | '''
17 |
18 |
19 | 		if not torch.isnan(log_prob_proposal) and not torch.isinf(log_prob_proposal):
20 | log_ratio = (log_prob_proposal - log_prob_state)
21 | log_ratio = torch.min(log_ratio, torch.zeros_like(log_ratio))
22 |
23 | log_u = torch.zeros_like(log_ratio).uniform_(0,1).log()
24 |
25 | log_accept = torch.gt(log_ratio, log_u)
26 | log_accept = log_accept.bool().item()
27 |
28 | return log_accept, log_ratio
29 |
30 | elif torch.isnan(log_prob_proposal) or torch.isinf(log_prob_proposal):
31 | exit(f'log_prob_proposal is nan or inf {log_prob_proposal}')
32 | return False, torch.Tensor([-1])
33 |
34 | class SDE_Acceptance():
35 |
36 | def __init__(self):
37 |
38 | pass
39 |
40 | def __call__(self, log_prob_proposal, log_prob_state):
41 |
42 | return True, torch.Tensor([0.])
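43 |
44 |
45 | # Minimal illustration (added, not part of the original file) of the acceptance rule on
46 | # toy log-probabilities: a proposal that is more probable than the current state is
47 | # always accepted, a less probable one is accepted with probability exp(log_ratio).
48 | if __name__ == '__main__':
49 |     acceptance = MetropolisHastingsAcceptance()
50 |     accept, log_ratio = acceptance(torch.tensor(-1.), torch.tensor(-2.))
51 |     print(accept, log_ratio)   # True, log_ratio clipped to 0
52 |     accept, log_ratio = acceptance(torch.tensor(-3.), torch.tensor(-2.))
53 |     print(accept, log_ratio)   # accepted with probability exp(-1) ~ 0.37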
--------------------------------------------------------------------------------
/src/MCMC_Chain.py:
--------------------------------------------------------------------------------
1 | import future, sys, os, datetime, argparse, copy, warnings, time
2 | from collections import OrderedDict; from collections.abc import MutableSequence, Iterable
3 | from itertools import compress
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | import matplotlib
8 | import matplotlib.pyplot as plt
9 | from matplotlib.lines import Line2D
10 | matplotlib.rcParams["figure.figsize"] = [10, 10]
11 |
12 | import torch
13 | from torch.nn import Module, Parameter
14 | from torch.nn import Linear, Tanh, ReLU
15 | import torch.nn.functional as F
16 |
17 | Tensor = torch.Tensor
18 | FloatTensor = torch.FloatTensor
19 |
20 | torch.set_printoptions(precision=4, sci_mode=False)
21 | np.set_printoptions(precision=4, suppress=True)
22 |
23 | sys.path.append("../../..") # Up to -> KFAC -> Optimization -> PHD
24 | from pytorch_MCMC.src.MCMC_ProbModel import ProbModel
25 | from pytorch_MCMC.src.MCMC_Optim import SGLD_Optim, MetropolisHastings_Optim, MALA_Optim, HMC_Optim, SGNHT_Optim
26 | from pytorch_MCMC.src.MCMC_Acceptance import SDE_Acceptance, MetropolisHastingsAcceptance
27 | from Utils.Utils import RunningAverageMeter
28 |
29 | '''
30 | Python Container Time Complexity: https://wiki.python.org/moin/TimeComplexity
31 | '''
32 |
33 |
34 | class Chain(MutableSequence):
35 |
36 | '''
37 | A container for storing the MCMC chain conveniently:
38 | samples: list of state_dicts
39 | log_probs: list of log_probs
40 | accepts: list of bools
41 | state_idx:
42 | init index of last accepted via np.where(accepts==True)[0][-1]
43 | can be set via len(samples) while sampling
44 |
45 | @property
46 | samples: filters the samples
47 |
48 |
49 | '''
50 |
51 | def __init__(self, probmodel=None):
52 |
53 | super().__init__()
54 |
55 | if probmodel is None:
56 | '''
57 | Create an empty chain
58 | '''
59 | self.state_dicts = []
60 | self.log_probs = []
61 | self.accepts = []
62 |
63 | if probmodel is not None:
64 | '''
65 | Initialize chain with given model
66 | '''
67 | assert isinstance(probmodel, ProbModel)
68 |
69 | self.state_dicts = [copy.deepcopy(probmodel.state_dict())]
70 | log_prob = probmodel.log_prob(*next(probmodel.dataloader.__iter__()))
71 | log_prob['log_prob'].detach_()
72 | self.log_probs = [copy.deepcopy(log_prob)]
73 | self.accepts = [True]
74 | self.last_accepted_idx = 0
75 |
76 | self.running_avgs = {}
77 | for key, value in log_prob.items():
78 | self.running_avgs.update({key: RunningAverageMeter(0.99)})
79 |
80 | self.running_accepts = RunningAverageMeter(0.999)
81 |
82 | def __len__(self):
83 | return len(self.state_dicts)
84 |
85 | def __iter__(self):
86 | return zip(self.state_dicts, self.log_probs, self.accepts)
87 |
88 | def __delitem__(self):
89 | raise NotImplementedError
90 |
91 | def __setitem__(self):
92 | raise NotImplementedError
93 |
94 | def insert(self):
95 | raise NotImplementedError
96 |
97 | def __repr__(self):
98 | return f'MCMC Chain: Length:{len(self)} Accept:{self.accept_ratio:.2f}'
99 |
100 | def __getitem__(self, i):
101 | chain = copy.deepcopy(self)
102 | chain.state_dicts = self.samples[i]
103 | chain.log_probs = self.log_probs[i]
104 | chain.accepts = self.accepts[i]
105 | return chain
106 |
107 | def __add__(self, other):
108 |
109 | if type(other) in [tuple, list]:
110 | 			assert len(other) == 3, f"Invalid number of information pieces passed: {len(other)} vs expected (model, log_prob, accept) == 3"
111 | self.append(*other)
112 | elif isinstance(other, Chain):
113 | 			self.cat_chains(other)
114 |
115 | return self
116 |
117 | def __iadd__(self, other):
118 |
119 | if type(other) in [tuple, list]:
120 | 			assert len(other) == 3, f"Invalid number of information pieces passed: {len(other)} vs expected (model, log_prob, accept) == 3"
121 | self.append(*other)
122 | elif isinstance(other, Chain):
123 | self.cat_chains(other)
124 |
125 | return self
126 |
127 | @property
128 | def state_idx(self):
129 | '''
130 | Returns the index of the last accepted sample a.k.a. the state of the chain
131 |
132 | '''
133 | 		if not hasattr(self, 'last_accepted_idx'):
134 | '''
135 | If the chain hasn't a state_idx, compute it from self.accepts by taking the last True of self.accepts
136 | '''
137 | 			self.last_accepted_idx = np.where(np.array(self.accepts))[0][-1]
138 | return self.last_accepted_idx
139 | else:
140 | '''
141 | Check that the state of the chain is actually the last True in self.accepts
142 | '''
143 | 			last_accepted_sample_ = np.where(np.array(self.accepts))[0][-1]
144 | assert last_accepted_sample_ == self.last_accepted_idx
145 | assert self.accepts[self.last_accepted_idx]==True
146 | return self.last_accepted_idx
147 |
148 |
149 | @property
150 | def samples(self):
151 | '''
152 | Filters the list of state_dicts with the list of bools from self.accepts
153 | :return: list of accepted state_dicts
154 | '''
155 | return list(compress(self.state_dicts, self.accepts))
156 |
157 | @property
158 | def accept_ratio(self):
159 | '''
160 | Sum the boolean list (=total number of Trues) and divides it by its length
161 | :return: float valued accept ratio
162 | '''
163 | return sum(self.accepts)/len(self.accepts)
164 |
165 | @property
166 | def state(self):
167 | return {'state_dict': self.state_dicts[self.last_accepted_idx], 'log_prob': self.log_probs[self.last_accepted_idx]}
168 |
169 | def cat_chains(self, other):
170 |
171 | assert isinstance(other, Chain)
172 | self.state_dicts += other.state_dicts
173 | self.log_probs += other.log_probs
174 | self.accepts += other.accepts
175 |
176 | for key, value in other.running_avgs.items():
177 | self.running_avgs[key].avg = 0.5*self.running_avgs[key].avg + 0.5 * other.running_avgs[key].avg
178 |
179 |
180 | def append(self, probmodel, log_prob, accept):
181 |
182 | if isinstance(probmodel, ProbModel):
183 | params_state_dict = copy.deepcopy(probmodel.state_dict())
184 | elif isinstance(probmodel, OrderedDict):
185 | params_state_dict = copy.deepcopy(probmodel)
186 | assert isinstance(log_prob, dict)
187 | assert type(log_prob['log_prob'])==torch.Tensor
188 | assert log_prob['log_prob'].numel()==1
189 |
190 | log_prob['log_prob'].detach_()
191 |
192 |
193 | self.accepts.append(accept)
194 | self.running_accepts.update(1 * accept)
195 |
196 | if accept:
197 | self.state_dicts.append(params_state_dict)
198 | self.log_probs.append(copy.deepcopy(log_prob))
199 | self.last_accepted_idx = len(self.state_dicts)-1
200 | for key, value in log_prob.items():
201 | self.running_avgs[key].update(value.item())
202 |
203 | elif not accept:
204 | self.state_dicts.append(False)
205 | self.log_probs.append(False)
206 |
207 | class Sampler_Chain:
208 |
209 | def __init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune):
210 |
211 | self.probmodel = probmodel
212 | self.chain = Chain(probmodel=self.probmodel)
213 |
214 | self.step_size = step_size
215 | self.num_steps = num_steps
216 | self.burn_in = burn_in
217 |
218 | self.pretrain = pretrain
219 | self.tune = tune
220 |
221 | def propose(self):
222 | raise NotImplementedError
223 |
224 | def __repr__(self):
225 | raise NotImplementedError
226 |
227 | def tune_step_size(self):
228 |
229 | tune_interval_length = 100
230 | num_tune_intervals = int(self.burn_in // tune_interval_length)
231 |
232 | verbose = True
233 |
234 | print(f'Tuning: Init Step Size: {self.optim.param_groups[0]["step_size"]:.5f}')
235 |
236 | self.probmodel.reset_parameters()
237 | tune_chain = Chain(probmodel=self.probmodel)
238 | tune_chain.running_accepts.momentum = 0.5
239 |
240 | progress = tqdm(range(self.burn_in))
241 | for tune_step in progress:
242 |
243 |
244 |
245 | sample_log_prob, sample = self.propose()
246 | accept, log_ratio = self.acceptance(sample_log_prob['log_prob'], self.chain.state['log_prob']['log_prob'])
247 | tune_chain += (self.probmodel, sample_log_prob, accept)
248 |
249 | # if tune_step < self.burn_in and tune_step % tune_interval_length == 0 and tune_step > 0:
250 | if tune_step > 1:
251 | # self.optim.dual_average_tune(tune_chain, np.exp(log_ratio.item()))
252 | self.optim.dual_average_tune(tune_chain.accepts[-tune_interval_length:], tune_step, np.exp(log_ratio.item()))
253 | # self.optim.tune(tune_chain.accepts[-tune_interval_length:])
254 |
255 | if not accept:
256 |
257 | if torch.isnan(sample_log_prob['log_prob']):
258 | print(self.chain.state)
259 | exit()
260 | self.probmodel.load_state_dict(self.chain.state['state_dict'])
261 |
262 | desc = f'Tuning: Accept: {tune_chain.running_accepts.avg:.2f}/{tune_chain.accept_ratio:.2f} StepSize: {self.optim.param_groups[0]["step_size"]:.5f}'
263 |
264 | progress.set_description(
265 | desc=desc)
266 |
267 |
268 |
269 | time.sleep(0.1) # for cleaner printing in the console
270 |
271 | def sample_chain(self):
272 |
273 | self.probmodel.reset_parameters()
274 |
275 | if self.pretrain:
276 | try:
277 | self.probmodel.pretrain()
278 | except:
279 | 				warnings.warn("Tried pretraining, but couldn't find a probmodel.pretrain() method ... Continuing without pretraining.")
280 |
281 | if self.tune:
282 | self.tune_step_size()
283 |
284 | # print(f"After Tuning Step Size: {self.optim.param_groups[0]['step_size']=}")
285 |
286 | self.chain = Chain(probmodel=self.probmodel)
287 |
288 | progress = tqdm(range(self.num_steps))
289 | for step in progress:
290 |
291 | proposal_log_prob, sample = self.propose()
292 | accept, log_ratio = self.acceptance(proposal_log_prob['log_prob'], self.chain.state['log_prob']['log_prob'])
293 | self.chain += (self.probmodel, proposal_log_prob, accept)
294 |
295 | if not accept:
296 |
297 | if torch.isnan(proposal_log_prob['log_prob']):
298 | print(self.chain.state)
299 | exit()
300 | self.probmodel.load_state_dict(self.chain.state['state_dict'])
301 |
302 | desc = f'{str(self)}: Accept: {self.chain.running_accepts.avg:.2f}/{self.chain.accept_ratio:.2f} \t'
303 | for key, running_avg in self.chain.running_avgs.items():
304 | desc += f' {key}: {running_avg.avg:.2f} '
305 | desc += f'StepSize: {self.optim.param_groups[0]["step_size"]:.3f}'
306 | # desc +=f" Std: {F.softplus(self.probmodel.log_std.detach()).item():.3f}"
307 | progress.set_description(desc=desc)
308 |
309 | self.chain = self.chain[self.burn_in:]
310 |
311 | return self.chain
312 |
313 | class SGLD_Chain(Sampler_Chain):
314 |
315 | def __init__(self, probmodel, step_size=0.0001, num_steps=2000, burn_in=100, pretrain=False, tune=False):
316 |
317 | Sampler_Chain.__init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune)
318 |
319 | self.optim = SGLD_Optim(self.probmodel,
320 | step_size=step_size,
321 | prior_std=1.,
322 | addnoise=True)
323 |
324 | self.acceptance = SDE_Acceptance()
325 |
326 | def __repr__(self):
327 | return 'SGLD'
328 |
329 | @torch.enable_grad()
330 | def propose(self):
331 |
332 | self.optim.zero_grad()
333 | batch = next(self.probmodel.dataloader.__iter__())
334 | log_prob = self.probmodel.log_prob(*batch)
335 | (-log_prob['log_prob']).backward()
336 | self.optim.step()
337 |
338 | return log_prob, self.probmodel
339 |
340 | class MALA_Chain(Sampler_Chain):
341 |
342 | def __init__(self, probmodel, step_size=0.1, num_steps=2000, burn_in=100, pretrain=False, tune=False, num_chain=0):
343 |
344 | Sampler_Chain.__init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune)
345 |
346 | self.num_chain = num_chain
347 |
348 | self.optim = MALA_Optim(self.probmodel,
349 | step_size=step_size,
350 | prior_std=1.,
351 | addnoise=True)
352 |
353 | self.acceptance = MetropolisHastingsAcceptance()
354 | # self.acceptance = SDE_Acceptance()
355 |
356 | def __repr__(self):
357 | return 'MALA'
358 |
359 | @torch.enable_grad()
360 | def propose(self):
361 |
362 | self.optim.zero_grad()
363 | batch = next(self.probmodel.dataloader.__iter__())
364 | log_prob = self.probmodel.log_prob(*batch)
365 | (-log_prob['log_prob']).backward()
366 | self.optim.step()
367 |
368 | return log_prob, self.probmodel
369 |
370 | class HMC_Chain(Sampler_Chain):
371 |
372 | def __init__(self, probmodel, step_size=0.0001, num_steps=2000, burn_in=100, pretrain=False, tune=False,
373 | traj_length=20):
374 |
375 | # assert probmodel.log_prob().keys()[:3] == ['log_prob', 'data', ]
376 |
377 | Sampler_Chain.__init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune)
378 |
379 | self.traj_length = traj_length
380 |
381 | self.optim = HMC_Optim(self.probmodel,
382 | step_size=step_size,
383 | prior_std=1.)
384 |
385 | # self.acceptance = SDE_Acceptance()
386 | self.acceptance = MetropolisHastingsAcceptance()
387 |
388 | def __repr__(self):
389 | return 'HMC'
390 |
391 | def sample_chain(self):
392 |
393 | self.probmodel.reset_parameters()
394 |
395 | if self.pretrain:
396 | try:
397 | self.probmodel.pretrain()
398 | except:
399 | 				warnings.warn("Tried pretraining, but couldn't find a probmodel.pretrain() method ... Continuing without pretraining.")
400 |
401 | if self.tune: self.tune_step_size()
402 |
403 | self.chain = Chain(probmodel=self.probmodel)
404 |
405 | progress = tqdm(range(self.num_steps))
406 | for step in progress:
407 |
408 | _ = self.propose() # values are added directly to self.chain
409 |
410 | desc = f'{str(self)}: Accept: {self.chain.running_accepts.avg:.2f}/{self.chain.accept_ratio:.2f} \t'
411 | for key, running_avg in self.chain.running_avgs.items():
412 | desc += f' {key}: {running_avg.avg:.2f} '
413 | desc += f'StepSize: {self.optim.param_groups[0]["step_size"]:.3f}'
414 | progress.set_description(desc=desc)
415 |
416 | self.chain = self.chain[self.burn_in:]
417 |
418 | return self.chain
419 |
420 | def propose(self):
421 | '''
422 | 1) sample momentum for each parameter
423 | 2) sample one minibatch for an entire trajectory
424 | 3) solve trajectory forward for self.traj_length steps
425 | '''
426 |
427 | hamiltonian_solver = ['euler', 'leapfrog'][0]
428 |
429 | self.optim.sample_momentum()
430 | batch = next(self.probmodel.dataloader.__iter__()) # samples one minibatch from dataloader
431 |
432 | def closure():
433 | '''
434 | Computes the gradients once for batch
435 | '''
436 | self.optim.zero_grad()
437 | log_prob = self.probmodel.log_prob(*batch)
438 | (-log_prob['log_prob']).backward()
439 | return log_prob
440 |
441 | if hamiltonian_solver=='leapfrog': log_prob = closure() # compute initial grads
442 |
443 | for traj_step in range(self.traj_length):
444 | if hamiltonian_solver=='euler':
445 | proposal_log_prob = closure()
446 | self.optim.step()
447 | elif hamiltonian_solver=='leapfrog':
448 | proposal_log_prob = self.optim.leapfrog_step(closure)
449 |
450 | accept, log_ratio = self.acceptance(proposal_log_prob['log_prob'], self.chain.state['log_prob']['log_prob'])
451 |
452 | if not accept:
453 | if torch.isnan(proposal_log_prob['log_prob']):
454 | print(f"{proposal_log_prob=}")
455 | print(self.chain.state)
456 | exit()
457 | self.probmodel.load_state_dict(self.chain.state['state_dict'])
458 |
459 | self.chain += (self.probmodel, proposal_log_prob, accept)
460 |
461 | class SGNHT_Chain(Sampler_Chain):
462 |
463 | def __init__(self, probmodel, step_size=0.0001, num_steps=2000, burn_in=100, pretrain=False, tune=False,
464 | traj_length=20):
465 |
466 | # assert probmodel.log_prob().keys()[:3] == ['log_prob', 'data', ]
467 |
468 | Sampler_Chain.__init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune)
469 |
470 | self.traj_length = traj_length
471 |
472 | self.optim = SGNHT_Optim(self.probmodel,
473 | step_size=step_size,
474 | prior_std=1.)
475 |
476 | # print(f"{self.optim.A=}")
477 | # print(f"{self.optim.num_params=}")
478 | # print(f"{self.optim.A=}")
479 | # exit()
480 |
481 | # self.acceptance = SDE_Acceptance()
482 | self.acceptance = MetropolisHastingsAcceptance()
483 |
484 | def __repr__(self):
485 | return 'SGNHT'
486 |
487 | def sample_chain(self):
488 |
489 | self.probmodel.reset_parameters()
490 |
491 | if self.pretrain:
492 | try:
493 | self.probmodel.pretrain()
494 | except:
495 | 				warnings.warn("Tried pretraining, but couldn't find a probmodel.pretrain() method ... Continuing without pretraining.")
496 |
497 | if self.tune: self.tune_step_size()
498 |
499 | self.chain = Chain(probmodel=self.probmodel)
500 | self.optim.sample_momentum()
501 | self.optim.sample_thermostat()
502 |
503 | progress = tqdm(range(self.num_steps))
504 | for step in progress:
505 |
506 | proposal_log_prob, sample = self.propose()
507 | accept, log_ratio = self.acceptance(proposal_log_prob['log_prob'], self.chain.state['log_prob']['log_prob'])
508 | self.chain += (self.probmodel, proposal_log_prob, accept)
509 |
510 | desc = f'{str(self)}: Accept: {self.chain.running_accepts.avg:.2f}/{self.chain.accept_ratio:.2f} \t'
511 | for key, running_avg in self.chain.running_avgs.items():
512 | desc += f' {key}: {running_avg.avg:.2f} '
513 | desc += f'StepSize: {self.optim.param_groups[0]["step_size"]:.3f}'
514 | progress.set_description(desc=desc)
515 |
516 | self.chain = self.chain[self.burn_in:]
517 |
518 | return self.chain
519 |
520 | def propose(self):
521 | '''
522 | 1) sample momentum for each parameter
523 | 2) sample one minibatch for an entire trajectory
524 | 3) solve trajectory forward for self.traj_length steps
525 | '''
526 |
527 | hamiltonian_solver = ['euler', 'leapfrog'][0]
528 |
529 | # self.optim.sample_momentum()
530 | # self.optim.sample_thermostat()
531 | batch = next(self.probmodel.dataloader.__iter__()) # samples one minibatch from dataloader
532 |
533 | self.optim.zero_grad()
534 | proposal_log_prob = self.probmodel.log_prob(*batch)
535 | (-proposal_log_prob['log_prob']).backward()
536 | self.optim.step()
537 |
538 | # def closure():
539 | # '''
540 | # Computes the gradients once for batch
541 | # '''
542 | # self.optim.zero_grad()
543 | # log_prob = self.probmodel.log_prob(*batch)
544 | # (-log_prob['log_prob']).backward()
545 | # return log_prob
546 | #
547 | # if hamiltonian_solver=='leapfrog': log_prob = closure() # compute initial grads
548 | #
549 | # for traj_step in range(self.traj_length):
550 | # if hamiltonian_solver=='euler':
551 | # proposal_log_prob = closure()
552 | # self.optim.step()
553 | # elif hamiltonian_solver=='leapfrog':
554 | # proposal_log_prob = self.optim.leapfrog_step(closure)
555 | #
556 | # accept, log_ratio = self.acceptance(proposal_log_prob['log_prob'], self.chain.state['log_prob']['log_prob'])
557 | #
558 | # if not accept:
559 | # if torch.isnan(proposal_log_prob['log_prob']):
560 | # print(f"{proposal_log_prob=}")
561 | # print(self.chain.state)
562 | # exit()
563 | # self.probmodel.load_state_dict(self.chain.state['state_dict'])
564 |
565 | return proposal_log_prob, self.probmodel
566 |
567 |
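568 |
569 |
570 | # Illustrative sketch (added, not part of the original file): one leapfrog step on a toy
571 | # 1D quadratic potential U(q) = 0.5 * q**2, mirroring the half-step/full-step update that
572 | # HMC_Chain.propose delegates to HMC_Optim.leapfrog_step (assumption: the optimizer uses
573 | # the standard leapfrog discretization of the Hamiltonian dynamics).
574 | if __name__ == '__main__':
575 |     q = torch.tensor(1.0, requires_grad=True)       # position, i.e. the parameter
576 |     p = torch.randn(())                              # momentum, resampled before a trajectory
577 |     eps = 0.1                                        # step size
578 |
579 |     def grad_U(q):
580 |         return torch.autograd.grad(0.5 * q ** 2, q)[0]
581 |
582 |     p = p - 0.5 * eps * grad_U(q)                    # half step for the momentum
583 |     q = (q + eps * p).detach().requires_grad_()      # full step for the position
584 |     p = p - 0.5 * eps * grad_U(q)                    # second half step for the momentum
585 |     print(f'{q.item()=:.3f} {p.item()=:.3f}')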
--------------------------------------------------------------------------------
/src/MCMC_Optim.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from tqdm import tqdm
3 | from collections.abc import MutableSequence
4 | import matplotlib
5 | import matplotlib.pyplot as plt
6 |
7 | import numpy as np
8 |
9 | import torch
10 | from torch.optim import Optimizer, SGD
11 | from torch.distributions.distribution import Distribution
12 | from torch.distributions import MultivariateNormal, Normal
13 | from torch.nn import Module
14 |
15 | class MCMC_Optim:
16 |
17 | def __init__(self):
18 |
19 | self.tune_params = {'delta': 0.65,
20 | 't0': 10,
21 | 'gamma': .05,
22 | 'kappa': .75,
23 | # 'mu': np.log(self.param_groups[0]["step_size"]),
24 | 'mu': 0.,
25 | 'H': 0,
26 | 'log_eps': 1.}
27 |
28 | # print(f"@MCMC_Optim {self.tune_params=}")
29 | # exit()
30 |
31 | def tune(self, accepts):
32 |
33 | '''
34 | PyMC:
35 | # Switch statement
36 | if acceptance < 0.001: 0.1
37 | elif acceptance < 0.05: 0.5
38 | elif acceptance < 0.2: 0.9
39 | elif acceptance > 0.95: 10
40 | elif acceptance > 0.75: 2
41 | elif acceptance > 0.5: 1.1
42 | '''
43 |
44 | avg_acc = sum(accepts)/len(accepts)
45 |
46 | '''
47 | Switch statement: the first condition that is met exits the switch statement
48 | '''
49 | if avg_acc < 0.001:
50 | scale = 0.1
51 | elif avg_acc < 0.05:
52 | scale = 0.5
53 | elif avg_acc < 0.20:
54 | # PyMC: 0.9
55 | scale = 0.5
56 |
57 | elif avg_acc > 0.99:
58 | scale = 10.
59 | elif avg_acc > 0.75:
60 | scale = 2.
61 | elif avg_acc > 0.5:
62 | # PyMC: 1.1
63 | scale = 1.1
64 | else:
65 | scale = 0.9
66 |
67 | for group in self.param_groups:
68 |
69 | group['step_size']*=scale
70 | # print(f'{avg_acc=:.3f} & {scale=} -> {group["lr"]=:.3f}')
71 |
72 | def dual_average_tune(self, accepts, t, alpha):
73 | '''
74 | NUTS Sampler p.17 : Algorithm 5
75 |
76 | mu = log(10 * initial_step_size)
77 |
78 | H_m : running difference between target acceptance rate and current acceptance rate
79 | delta : target acceptance rate
80 | alpha : (singular) current acceptance rate
81 |
82 | 		log_eps = mu - (t**0.5 / gamma) * H_m
83 | 		running_log_eps = t**(-kappa) * log_eps + (1 - t**(-kappa)) * running_log_eps
84 | '''
85 |
86 | # accept_ratio = sum(accepts)/len(accepts)
87 | 		assert 0 < alpha <= 1
--------------------------------------------------------------------------------
/src/MCMC_ProbModel.py:
--------------------------------------------------------------------------------
24 | sys.path.append("../../..") # Up to -> KFAC -> Optimization -> PHD
25 |
26 | import scipy
27 | import scipy as sp
28 | from scipy.io import loadmat as sp_loadmat
29 | import copy
30 |
31 | class ProbModel(torch.nn.Module):
32 |
33 | '''
34 | ProbModel:
35 |
36 | '''
37 |
38 | def __init__(self, dataloader):
39 | super().__init__()
40 | assert isinstance(dataloader, torch.utils.data.DataLoader)
41 | self.dataloader = dataloader
42 |
43 |
44 | def log_prob(self):
45 | '''
46 | If minibatches have to be sampled due to memory constraints,
47 | a standard PyTorch dataloader can be used.
48 | "Infinite minibatch sampling" can be achieved by calling:
49 | data, target = next(dataloader.__iter__())
50 | next(Iterable.__iter__()) calls a single mini-batch sampling step
51 | 		But since it's not in a loop, we can call it ad infinitum
52 | '''
53 | raise NotImplementedError
54 |
55 | def sample_minibatch(self):
56 | '''
57 | Idea:
58 | Hybrid Monte Carlo Samplers require a constant tuple (data, target) to compute trajectories
59 | '''
60 | raise NotImplementedError
61 |
62 | def reset_parameters(self):
63 | raise NotImplementedError
64 |
65 | def predict(self, chain):
66 | raise NotImplementedError
67 |
68 | def pretrain(self):
69 |
70 | pass
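71 |
72 |
73 | # Minimal subclass sketch (added, not part of the original file): a toy model inferring the
74 | # mean of a 1D Gaussian. It illustrates the interface the samplers expect: a DataLoader
75 | # passed to ProbModel.__init__ and a log_prob() that returns {'log_prob': scalar tensor}.
76 | class ToyGaussianMean(ProbModel):
77 |
78 |     def __init__(self, data):
79 |         dataset = torch.utils.data.TensorDataset(data)
80 |         dataloader = torch.utils.data.DataLoader(dataset, batch_size=data.shape[0], shuffle=True)
81 |         ProbModel.__init__(self, dataloader)
82 |         self.mu = torch.nn.Parameter(torch.zeros(1))
83 |
84 |     def log_prob(self, data):
85 |         log_prob = torch.distributions.Normal(self.mu, 1.).log_prob(data).mean()
86 |         return {'log_prob': log_prob}
87 |
88 |     def reset_parameters(self):
89 |         torch.nn.init.normal_(self.mu, std=1.)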
--------------------------------------------------------------------------------
/src/MCMC_Sampler.py:
--------------------------------------------------------------------------------
1 |
2 | import os, sys, copy, time
3 | from tqdm import tqdm
4 | from collections import OrderedDict; from collections.abc import MutableSequence
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 |
8 | import numpy as np
9 |
10 | from pytorch_MCMC.src.MCMC_Utils import posterior_dist
11 | from pytorch_MCMC.src.MCMC_Optim import SGLD_Optim, MetropolisHastings_Optim
12 | from pytorch_MCMC.src.MCMC_Acceptance import MetropolisHastingsAcceptance
13 | from pytorch_MCMC.src.MCMC_Chain import Chain, SGLD_Chain, MALA_Chain, HMC_Chain, SGNHT_Chain
14 | from pytorch_MCMC.src.MCMC_Acceptance import MetropolisHastingsAcceptance, SDE_Acceptance
15 | from pytorch_MCMC.src.MCMC_ProbModel import ProbModel
16 |
17 | from joblib import Parallel, delayed
18 | import concurrent.futures
19 |
20 |
21 | import torch
22 | from torch.nn import Parameter
23 | from torch.optim import SGD
24 | from torch.optim.lr_scheduler import StepLR
25 | from torch.distributions.distribution import Distribution
26 | from torch.distributions import MultivariateNormal, Normal
27 | from torch.nn import Module
28 |
29 |
30 | class Sampler:
31 |
32 | def __init__(self, probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune):
33 |
34 | self.probmodel = probmodel
35 | self.chain = None
36 | self.num_chains = num_chains
37 |
38 | self.step_size = step_size
39 | self.num_steps = num_steps
40 | self.burn_in = burn_in
41 |
42 | self.pretrain = pretrain
43 | self.tune = tune
44 |
45 | test_log_prob = self.probmodel.log_prob(*next(self.probmodel.dataloader.__iter__()))
46 | assert type(test_log_prob) == dict
47 | assert list(test_log_prob.keys())[0]=='log_prob'
48 |
49 | def sample_chains(self):
50 | raise NotImplementedError
51 |
52 | def __str__(self):
53 | raise NotImplementedError
54 |
55 | def multiprocessing_test(self, wait_time):
56 |
57 | time.sleep(wait_time)
58 | print(f'Done after {wait_time=} seconds')
59 |
60 | def sample_independent_chain(self):
61 |
62 | probmodel = copy.deepcopy(self.probmodel)
63 | probmodel.reset_parameters()
64 |
65 | if self.pretrain:
66 | probmodel.pretrain()
67 |
68 | optim = SGLD_Optim(probmodel, step_size=self.step_size, prior_std=0., addnoise=True)
69 | chain = Chain(probmodel=probmodel)
70 |
71 | progress = tqdm(range(self.num_steps))
72 | for step in progress:
73 |
74 | sample_log_prob, sample = self.propose(probmodel, optim)
75 |
76 | 			accept, log_ratio = self.acceptance(sample_log_prob['log_prob'], chain.state['log_prob']['log_prob'])
77 |
78 | 			chain += (probmodel, sample_log_prob, accept)
79 |
80 | 			if not accept:
81 | 				probmodel.load_state_dict(chain.state['state_dict'])
82 |
83 | desc = f'{str(self)}: Accept: {chain.running_accepts.avg:.2f}/{chain.accept_ratio:.2f} \t'
84 | for key, running_avg in chain.running_avgs.items():
85 | # print(f'{key}: {running_avg.avg=}')
86 | desc += f' {key}: {running_avg.avg:.2f} '
87 | 			desc += f'StepSize: {optim.param_groups[0]["step_size"]:.3f}'
88 | progress.set_description(desc=desc)
89 |
90 | # print(list(chain.samples[-1].values())[-1][0])
91 | '''
92 | Remove Burn_in
93 | '''
94 | 		assert len(chain) > self.burn_in, f'{len(chain)=} <= {self.burn_in=}'
95 | 		chain = chain[self.burn_in:]
96 |
97 | return chain
98 |
99 | def sample_chain(self, step_size=None):
100 |
101 | if self.pretrain:
102 | self.probmodel.pretrain()
103 |
104 | self.optim = SGLD_Optim(self.probmodel,
105 | step_size=step_size,
106 | prior_std=0.,
107 | addnoise=True)
108 |
109 | if self.tune: self.tune_step_size()
110 |
111 | self.chain = Chain(probmodel=self.probmodel)
112 |
113 | progress = tqdm(range(self.num_steps))
114 | for step in progress:
115 |
116 | sample_log_prob, sample = self.propose()
117 |
118 | 			accept, log_ratio = self.acceptance(sample_log_prob['log_prob'], self.chain.state['log_prob']['log_prob'])
119 |
120 | 			self.chain += (self.probmodel, sample_log_prob, accept)
121 |
122 | 			if not accept:
123 | 				self.probmodel.load_state_dict(self.chain.state['state_dict'])
124 |
125 | desc = f'{str(self)}: Accept: {self.chain.running_accepts.avg:.2f}/{self.chain.accept_ratio:.2f} \t'
126 | for key, running_avg in self.chain.running_avgs.items():
127 | # print(f'{key}: {running_avg.avg=}')
128 | desc+= f' {key}: {running_avg.avg:.2f} '
129 | 			desc += f'StepSize: {self.optim.param_groups[0]["step_size"]:.3f}'
130 | progress.set_description(desc=desc)
131 |
132 | print(len(self.chain))
133 | '''
134 | Remove Burn_in
135 | '''
136 | 		assert len(self.chain) > self.burn_in, f'{len(self.chain)=} <= {self.burn_in=}'
137 | 		self.chain = self.chain[self.burn_in:]
138 |
139 | def posterior_dist(self, param=None, verbose=False, plot=True):
140 |
141 | if len(self.probmodel.state_dict())==1:
142 | '''
143 | We're sampling from a predefined distribution like a GMM and simulating a particle
144 | '''
145 | post = []
146 |
147 | accepted_models = self.chain.samples
148 | for model_state_dict in accepted_models:
149 | post.append(list(model_state_dict.values())[0])
150 |
151 | post = torch.cat(post, dim=0)
152 |
153 | if plot:
154 | hist2d = plt.hist2d(x=post[:, 0].cpu().numpy(), y=post[:, 1].cpu().numpy(), bins=100, range=np.array([[-3, 3], [-3, 3]]),
155 | density=True)
156 | plt.colorbar(hist2d[3])
157 | plt.show()
158 |
159 | elif len(self.probmodel.state_dict()) > 1:
160 | '''
161 | There is more than one parameter in the model
162 | '''
163 |
164 | param_names = list(self.probmodel.state_dict().keys())
165 | accepted_models = self.chain.samples
166 |
167 | for param_name in param_names:
168 |
169 | post = []
170 |
171 | for model_state_dict in accepted_models:
172 |
173 | post.append(model_state_dict[param_name])
174 |
175 | post = torch.cat(post)
176 | # print(post)
177 |
178 | if plot:
179 | plt.hist(x=post, bins=50,
180 | range=np.array([-3, 3]),
181 | density=True,
182 | alpha=0.5)
183 | plt.title(param_name)
184 | plt.show()
185 |
186 | def trace(self, param=None, verbose=False, plot=True):
187 |
188 | if len(self.probmodel.state_dict()) >= 1:
189 | '''
190 | There is more than one parameter in the model
191 | '''
192 |
193 | param_names = list(self.probmodel.state_dict().keys())
194 | 			accepted_models = self.chain.samples
195 |
196 | for param_name in param_names:
197 |
198 | post = []
199 |
200 | for model_state_dict in accepted_models:
201 | post.append(model_state_dict[param_name])
202 |
203 | # print(post)
204 |
205 | post = torch.cat(post)
206 | # print(post)
207 |
208 | if plot:
209 | plt.plot(np.arange(len(accepted_models)), post)
210 | plt.title(param_name)
211 |
212 | plt.show()
213 |
214 | class MetropolisHastings_Sampler(Sampler):
215 |
216 | 	def __init__(self, probmodel, step_size=1., num_steps=10000, num_chains=1, burn_in=100, pretrain=False, tune=True):
217 | '''
218 |
219 | :param probmodel: Probmodel() that implements forward, log_prob, prob and sample
220 | 		:param step_size:
221 | :param num_steps:
222 | :param burn_in:
223 | '''
224 |
225 | assert isinstance(probmodel, ProbModel)
226 | 		super().__init__(probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune)
227 |
228 | self.optim = MetropolisHastings_Optim(self.probmodel,
229 | step_length=step_size)
230 |
231 | self.acceptance = MetropolisHastingsAcceptance()
232 |
233 | def __str__(self):
234 | return 'MH'
235 |
236 | @torch.no_grad()
237 | def propose(self):
238 | self.optim.step()
239 | log_prob = self.probmodel.log_prob()
240 |
241 | return log_prob, self.probmodel
242 |
243 | class SGLD_Sampler(Sampler):
244 |
245 | def __init__(self, probmodel, step_size=0.01, num_steps=10000, num_chains=7, burn_in=500, pretrain=True, tune=True):
246 | '''
247 |
248 | :param probmodel: Probmodel() that implements forward, log_prob, prob and sample
249 | 		:param step_size:
250 | :param num_steps:
251 | :param burn_in:
252 | '''
253 |
254 | assert isinstance(probmodel, ProbModel)
255 | Sampler.__init__(self, probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune)
256 |
257 | def sample_chains(self):
258 |
259 | if self.num_chains > 1:
260 | self.parallel_chains = [SGLD_Chain(copy.deepcopy(self.probmodel),
261 | step_size=self.step_size,
262 | num_steps=self.num_steps,
263 | burn_in=self.burn_in,
264 | pretrain=self.pretrain,
265 | tune=False)
266 | for _ in range(self.num_chains)]
267 |
268 | chains = Parallel(n_jobs=self.num_chains)(delayed(chain.sample_chain)() for chain in self.parallel_chains)
269 |
270 | elif self.num_chains == 1:
271 | chain = SGLD_Chain(copy.deepcopy(self.probmodel),
272 | step_size=self.step_size,
273 | num_steps=self.num_steps,
274 | burn_in=self.burn_in,
275 | pretrain=self.pretrain,
276 | tune=False)
277 | chains = [chain.sample_chain()]
278 |
279 | self.chain = Chain(probmodel=self.probmodel)
280 |
281 | for chain in chains:
282 | self.chain += chain
283 |
284 | return chains
285 |
286 | def __str__(self):
287 | return 'SGLD'
288 |
289 | class MALA_Sampler(Sampler):
290 |
291 | def __init__(self, probmodel, step_size=0.01, num_steps=10000, num_chains=4, burn_in=500, pretrain=True, tune=True):
292 | '''
293 |
294 | :param probmodel: Probmodel() that implements forward, log_prob, prob and sample
295 | 		:param step_size:
296 | :param num_steps:
297 | :param burn_in:
298 | '''
299 |
300 | assert isinstance(probmodel, ProbModel)
301 | super().__init__(probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune)
302 |
303 | def sample_chains(self):
304 |
305 | if self.num_chains>1:
306 | self.parallel_chains = [MALA_Chain(copy.deepcopy(self.probmodel),
307 | step_size=self.step_size,
308 | num_steps=self.num_steps,
309 | burn_in=self.burn_in,
310 | pretrain=self.pretrain,
311 | tune=self.tune,
312 | num_chain=i)
313 | for i in range(self.num_chains)]
314 |
315 | chains = Parallel(n_jobs=self.num_chains)(delayed(chain.sample_chain)() for chain in self.parallel_chains)
316 |
317 | elif self.num_chains == 1:
318 | chain = MALA_Chain(copy.deepcopy(self.probmodel),
319 | step_size=self.step_size,
320 | num_steps=self.num_steps,
321 | burn_in=self.burn_in,
322 | pretrain=self.pretrain,
323 | tune=self.tune,
324 | num_chain=0)
325 | chains = [chain.sample_chain()]
326 |
327 | self.chain = Chain(probmodel=self.probmodel)
328 |
329 |
330 | for chain in chains:
331 | self.chain += chain
332 |
333 | return chains
334 |
335 | def __str__(self):
336 | 		return 'MALA'
337 |
338 | class HMC_Sampler(Sampler):
339 |
340 | def __init__(self, probmodel, step_size=0.01, num_steps=10000, num_chains=7, burn_in=500, pretrain=True, tune=True,
341 | traj_length=21):
342 | '''
343 |
344 | :param probmodel: Probmodel() that implements forward, log_prob, prob and sample
345 | 		:param step_size:
346 | :param num_steps:
347 | :param burn_in:
348 | '''
349 |
350 | assert isinstance(probmodel, ProbModel)
351 | Sampler.__init__(self, probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune)
352 |
353 | self.traj_length = traj_length
354 |
355 | def __str__(self):
356 | return 'HMC'
357 |
358 | def sample_chains(self):
359 |
360 | if self.num_chains > 1:
361 | self.parallel_chains = [HMC_Chain(copy.deepcopy(self.probmodel),
362 | step_size=self.step_size,
363 | num_steps=self.num_steps,
364 | burn_in=self.burn_in,
365 | pretrain=self.pretrain,
366 | tune=self.tune)
367 | for i in range(self.num_chains)]
368 |
369 | chains = Parallel(n_jobs=self.num_chains)(delayed(chain.sample_chain)() for chain in self.parallel_chains)
370 |
371 | elif self.num_chains == 1:
372 | chain = HMC_Chain(copy.deepcopy(self.probmodel),
373 | step_size=self.step_size,
374 | num_steps=self.num_steps,
375 | burn_in=self.burn_in,
376 | pretrain=self.pretrain,
377 | tune=self.tune)
378 | chains = [chain.sample_chain()]
379 |
380 | self.chain = Chain(probmodel=self.probmodel) # the aggregating chain
381 |
382 | for chain in chains:
383 | self.chain += chain
384 |
385 | return chains
386 |
387 | class SGNHT_Sampler(Sampler):
388 |
389 | def __init__(self, probmodel, step_size=0.01, num_steps=10000, num_chains=7, burn_in=500, pretrain=True, tune=True,
390 | traj_length=21):
391 | '''
392 |
393 | :param probmodel: Probmodel() that implements forward, log_prob, prob and sample
394 | 		:param step_size:
395 | :param num_steps:
396 | :param burn_in:
397 | '''
398 |
399 | assert isinstance(probmodel, ProbModel)
400 | Sampler.__init__(self, probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune)
401 |
402 | self.traj_length = traj_length
403 |
404 | def __str__(self):
405 | return 'SGNHT'
406 |
407 | def sample_chains(self):
408 |
409 | if self.num_chains > 1:
410 | self.parallel_chains = [SGNHT_Chain(copy.deepcopy(self.probmodel),
411 | step_size=self.step_size,
412 | num_steps=self.num_steps,
413 | burn_in=self.burn_in,
414 | pretrain=self.pretrain,
415 | tune=self.tune)
416 | for i in range(self.num_chains)]
417 |
418 | chains = Parallel(n_jobs=self.num_chains)(delayed(chain.sample_chain)() for chain in self.parallel_chains)
419 |
420 | elif self.num_chains == 1:
421 | chain = SGNHT_Chain(copy.deepcopy(self.probmodel),
422 | step_size=self.step_size,
423 | num_steps=self.num_steps,
424 | burn_in=self.burn_in,
425 | pretrain=self.pretrain,
426 | tune=self.tune)
427 | chains = [chain.sample_chain()]
428 |
429 | self.chain = Chain(probmodel=self.probmodel) # the aggregating chain
430 |
431 | for chain in chains:
432 | self.chain += chain
433 |
434 | return chains
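435 |
436 |
437 | # Illustrative end-to-end sketch (added, not part of the original file): run a single MALA
438 | # chain on the 2D GMM target from models/MCMC_Models.py and plot the resulting posterior
439 | # histogram. The step size, chain length and burn-in below are arbitrary example values.
440 | if __name__ == '__main__':
441 |     from pytorch_MCMC.models.MCMC_Models import GMM
442 |
443 |     gmm = GMM()
444 |     sampler = MALA_Sampler(gmm, step_size=0.1, num_steps=2000, num_chains=1,
445 |                            burn_in=500, pretrain=False, tune=False)
446 |     chains = sampler.sample_chains()
447 |     posterior_dist(sampler.chain, plot=True)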
--------------------------------------------------------------------------------
/src/MCMC_Utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 |
5 | def posterior_dist(chain, param=None, verbose=False, plot=True):
6 |
7 | if len(chain.samples[0]) == 1:
8 | '''
9 | We're sampling from a predefined distribution like a GMM and simulating a particle
10 | '''
11 | post = []
12 |
13 | # print(list(self.probmodel.state_dict().values())[0])
14 | # exit()
15 |
16 | # accepted_models = [chain.samples[idx] for idx in chain.accepted_steps]
17 | for model_state_dict in chain.samples:
18 | post.append(list(model_state_dict.values())[0])
19 |
20 | post = torch.cat(post, dim=0)
21 |
22 | if plot:
23 | hist2d = plt.hist2d(x=post[:, 0].cpu().numpy(), y=post[:, 1].cpu().numpy(), bins=100, range=np.array([[-3, 3], [-3, 3]]),
24 | density=True)
25 | plt.colorbar(hist2d[3])
26 | plt.show()
27 |
28 | elif len(chain.samples[0]) > 1:
29 | '''
30 | There is more than one parameter in the model
31 | '''
32 |
33 | param_names = list(chain.samples[0].keys())
34 | 		accepted_models = chain.samples
35 |
36 | for param_name in param_names:
37 |
38 | post = []
39 |
40 | for model_state_dict in accepted_models:
41 | post.append(model_state_dict[param_name])
42 |
43 | post = torch.cat(post)
44 | # print(post)
45 |
46 | if plot:
47 | plt.hist(x=post, bins=50,
48 | range=np.array([-3, 3]),
49 | density=True,
50 | alpha=0.5)
51 | plt.title(param_name)
52 | plt.show()
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Test Test Test
--------------------------------------------------------------------------------