├── .DS_Store
├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE.md
├── README.md
├── data
│   ├── MCMC_SyntheticData.py
│   └── plots
│       ├── EpistemicAleatoricUncertainty.png
│       ├── GMM_HMC1.gif
│       ├── HMC_Sampler.gif
│       ├── HMC_Sampler2.gif
│       ├── HMC_Sampler3.gif
│       └── HMC_Sampler4.gif
├── experiments
│   ├── .DS_Store
│   ├── AIS.py
│   ├── DataViz.py
│   ├── MCMC_HMCTest.py
│   ├── MCMC_Test.py
│   ├── Testing.py
│   └── functional_pytorch
├── models
│   └── MCMC_Models.py
└── src
    ├── MCMC_Acceptance.py
    ├── MCMC_Chain.py
    ├── MCMC_Optim.py
    ├── MCMC_ProbModel.py
    ├── MCMC_Sampler.py
    ├── MCMC_Utils.py
    └── README.md
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # user files
7 | *.chain
8 | *.ipynb
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | install all needed dependencies.
95 | #Pipfile.lock 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | } -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "{}" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright {yyyy} {name of copyright owner} 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # torch-MC^2 (torch-MCMC)
3 | HMC on 3-layer NN | HMC on GMM
4 | :-------------------------------------------:|:------------------------------:
5 | ![](data/plots/HMC_Sampler3.gif) | ![alt-text-2](data/plots/GMM_HMC1.gif "GMM")
6 |
7 |
8 |
9 | This package implements a series of MCMC sampling algorithms in PyTorch in a modular way:
10 |
11 | - Metropolis-Hastings
12 | - Stochastic Gradient Langevin Dynamics
13 | - (Stochastic) Hamiltonian Monte Carlo
14 | - Stochastic Gradient Nosé-Hoover Thermostat
15 | - SG Riemann Hamiltonian Monte Carlo (coming ...)
16 |
17 | The focus lies on four core ingredients of MCMC with corresponding routines:
18 |
19 | - `MCMC.src.MCMC_ProbModel` : Probabilistic wrapper around your model providing a uniform interface
20 | - `MCMC.src.MCMC_Chain` : Markov chain for storing samples
21 | - `MCMC.src.MCMC_Optim` : MCMC_Optim parent class for handling parameters
22 | - `MCMC.src.MCMC_Sampler`: Sampler that binds it all together
23 |
24 |
25 | These classes and functions are constructed along the structure of the core PyTorch framework.
26 | In particular, the gradient-based samplers are designed around PyTorch's `optim` class to handle everything related to parameters.
27 |
28 | # ProbModel
29 |
30 | The wrapper `MCMC.src.MCMC_ProbModel` defines a small set of functions that are required to allow the `Sampler_Chain` to interact with it and evaluate the relevant quantities.
31 |
32 | Any parameter in the model that we wish to sample from has to be designated as a `torch.nn.Parameter()`.
33 | This could be as simple as a single particle that we move around a 2-D distribution or a full neural network (a minimal end-to-end sketch is given at the end of this README).
34 |
35 | The following methods have to be defined by the user:
36 |
37 | `MCMC_ProbModel.log_prob()`:
38 |
39 | Evaluates the log-probability of the likelihood of the model.
40 |
41 | It returns a dictionary with the first entry being the log-probability, i.e. `{"log_prob": -20.03}`.
42 |
43 | Moreover, additional evaluation metrics can be added to the dictionary as well, i.e. `{"log_prob": -20.03, "Accuracy": 0.58}`.
44 | The sampler inspects the dictionary returned by the evaluation of the model and creates corresponding running averages of the used metrics.
45 |
46 | `MCMC_ProbModel.reset_parameters()`:
47 |
48 | Any value that we want to sample has to be declared as a `torch.nn.Parameter()` such that the `MCMC_Optim`s can track the values in the background.
49 | `reset_parameters()` is mainly used to reinitialize the model.
50 |
51 | `MCMC_ProbModel.pretrain()`:
52 |
53 | In cases where we want a good initial guess to start our Markov chain, we can implement a pretrain method which will optimize the parameters in some user-defined manner.
54 | If `pretrain=True` is passed during initialization of the sampler, it will assume that `MCMC_ProbModel.pretrain()` is implemented.
55 |
56 | In order to allow more sophisticated dynamic samplers such as `HMC_Sampler` to properly sample mini-batches, the probmodel should be initialized with a dataloader that takes care of sampling mini-batches.
57 | That way, dynamic samplers can simply access `probmodel.dataloader`.
58 |
59 | # Chain
60 |
61 | This is just a convenience container that stores the sampled values and can be queried for specific values to determine the progress of the sampling chain.
62 | The samples of the parameters of the model are stored as a list `chain.samples`, where each entry is PyTorch's very own `state_dict()`.
63 |
64 | After the sampler is finished, the samples of the model can be accessed through the property `chain.samples`, which returns a list of `state_dict()`'s that can be loaded into the model.
65 |
66 | An example:
67 |
68 | ```
69 |
70 | for sample_state_dict in chain.samples:
71 |
72 |     self.load_state_dict(sample_state_dict)
73 |
74 |     ... do something like ensemble prediction ...
75 | ```
76 |
77 | `MCMC_Chain` is implemented as a mutable sequence and allows the concatenation of tuples `(probmodel/torch.state_dict, log_probs: dict/odict, accept: bool)` as well as the concatenation of entire chains.
78 |
79 | # MCMC_Optim
80 |
81 | The `MCMC_Optim`s inherit from PyTorch's very own `Optimizers` and make working with gradients significantly more pleasant.
82 |
83 | By calling `MCMC_Optim.step()` they propose a new set of parameters for `ProbModel`, the `log_prob()` of which is evaluated by `MCMC_Sampler` (an illustrative sketch of this design pattern is given at the end of this README).
84 |
85 | # MCMC_Sampler
86 |
87 | The core component that ties everything together.
88 |
89 | 1. `MCMC_Optim` does a `step()` and proposes new parameters for the `ProbModel`.
90 | 2. `MCMC_Sampler` evaluates the `log_prob()` of the `ProbModel` and determines the acceptance of the proposal.
91 | 3. If accepted, the `ProbModel` is passed to the `MCMC_Chain` to be saved.
92 | 4. If not accepted, we play the game again with a new proposal.
93 |
94 | # What's the data structure underneath?
95 |
96 | Each sampler uses the following data structure:
97 |
98 | ```
99 | Sampler:
100 |     - Sampler_Chain #1
101 |         - ProbModel #1
102 |         - Optim #1
103 |         - Chain #1
104 |     - Sampler_Chain #2
105 |         - ProbModel #2
106 |         - Optim #2
107 |         - Chain #2
108 |     - Sampler_Chain #3
109 |     .
110 |     .
111 |     .
112 |     .
113 |     .
114 |     .
115 |
116 | ```
117 |
118 | By packaging the optimizers and probmodels directly into the chain, these chains can be run completely independently, possibly even on multi-GPU systems.
119 |
120 | # Final Note
121 |
122 | There are far more sophisticated sampling packages out there, such as Pyro, Stan and PyMC.
123 | Yet all of these packages require implementing the models explicitly for these frameworks.
124 | This package aims at providing MCMC sampling for **native** PyTorch models, so that the infamous Anon Reviewer 2, who requests an MCMC benchmark of an experiment, can be satisfied.
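
# Sketch: an `MCMC_Optim`-style proposal

The actual optimizers live in `src/MCMC_Optim.py` and are not reproduced here. Purely as an illustration of the design pattern described above (a proposal is nothing more than an optimizer step), the hypothetical snippet below shows what an SGLD-style proposal could look like as a `torch.optim.Optimizer` subclass; the class name and details are illustrative assumptions, not the package's API.

```
import math
import torch


class SGLD_OptimSketch(torch.optim.Optimizer):
    """Illustrative sketch only: NOT the implementation in src/MCMC_Optim.py."""

    def __init__(self, params, step_size=1e-2):
        super().__init__(params, defaults=dict(step_size=step_size))

    @torch.no_grad()
    def step(self, closure=None):
        # assumes the gradient of the log-probability has already been back-propagated
        for group in self.param_groups:
            eps = group["step_size"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                # SGLD proposal: theta <- theta + eps/2 * grad log p(theta) + N(0, eps)
                noise = torch.randn_like(p) * math.sqrt(eps)
                p.add_(0.5 * eps * p.grad + noise)
```

A sampler would then evaluate `probmodel.log_prob()` after every such `step()` and decide whether to keep the proposed parameters, exactly as in the four steps listed in the `MCMC_Sampler` section.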
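
# Sketch: a minimal `ProbModel` end to end

To make the workflow concrete, the sketch below condenses the `LinReg` model from `models/MCMC_Models.py` and its usage in `experiments/MCMC_Test.py`: a `ProbModel` that exposes `log_prob()`, `reset_parameters()` and a dataloader, sampled with the `SGLD_Sampler`. Treat it as a sketch of the intended workflow rather than a verified, drop-in script; the hyperparameters are placeholders and the import paths assume the repository is importable as `pytorch_MCMC`, as in the experiment scripts.

```
import torch
import torch.nn.functional as F
from torch.distributions import Normal
from torch.nn import Parameter
from torch.utils.data import DataLoader, TensorDataset

from pytorch_MCMC.data.MCMC_SyntheticData import generate_linear_regression_data
from pytorch_MCMC.src.MCMC_ProbModel import ProbModel
from pytorch_MCMC.src.MCMC_Sampler import SGLD_Sampler


class ToyLinReg(ProbModel):
    """Bayesian linear regression y = m * x + b, condensed from models/MCMC_Models.py."""

    def __init__(self, x, y):
        super().__init__()
        self.data, self.target = x, y
        # gradient-based samplers draw mini-batches through probmodel.dataloader
        self.dataloader = DataLoader(TensorDataset(x, y), shuffle=True, batch_size=x.shape[0])
        # everything we want to sample is a torch.nn.Parameter
        self.m = Parameter(torch.randn(1))
        self.b = Parameter(torch.randn(1))
        self.log_noise = torch.zeros(1)

    def forward(self, x):
        return self.m * x + self.b

    def log_prob(self, *args):
        data, target = next(iter(self.dataloader))
        mu = self.forward(data)
        log_prob = Normal(mu, F.softplus(self.log_noise)).log_prob(target).mean()
        return {"log_prob": log_prob}  # first entry has to be the log-probability

    def reset_parameters(self):
        torch.nn.init.normal_(self.m, std=0.1)
        torch.nn.init.normal_(self.b, std=0.1)


x, y = generate_linear_regression_data(num_samples=200, m=-2., b=-1, y_noise=0.5)
probmodel = ToyLinReg(x, y)

sampler = SGLD_Sampler(probmodel=probmodel, step_size=0.01, num_steps=10000,
                       burn_in=1000, pretrain=False, tune=False)
sampler.sample_chains()

# every stored sample is a state_dict that can be loaded back into the model
for sample_state_dict in sampler.chain.samples:
    probmodel.load_state_dict(sample_state_dict)
    # ... ensemble prediction, posterior statistics, etc. ...
```

Swapping in `MALA_Sampler` or `HMC_Sampler` only changes the sampler construction (HMC additionally takes a `traj_length`), as in `experiments/MCMC_Test.py` and `experiments/MCMC_HMCTest.py`.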
125 | 126 | # Final Final Note 127 | 128 | May your chains hang low 129 | 130 | https://www.youtube.com/watch?v=4SBN_ikibtg -------------------------------------------------------------------------------- /data/MCMC_SyntheticData.py: -------------------------------------------------------------------------------- 1 | import future, sys, os, datetime, argparse 2 | # print(os.path.dirname(sys.executable)) 3 | import torch 4 | import numpy as np 5 | import matplotlib 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | from matplotlib.lines import Line2D 9 | 10 | matplotlib.rcParams["figure.figsize"] = [10, 10] 11 | 12 | import torch 13 | from torch.nn import Module, Parameter 14 | from torch.nn import Linear, Tanh, ReLU 15 | import torch.nn.functional as F 16 | 17 | Tensor = torch.Tensor 18 | FloatTensor = torch.FloatTensor 19 | 20 | torch.set_printoptions(precision=4, sci_mode=False) 21 | np.set_printoptions(precision=4, suppress=True) 22 | 23 | sys.path.append("../../..") # Up to -> KFAC -> Optimization -> PHD 24 | 25 | cwd = os.path.abspath(os.getcwd()) 26 | os.chdir(cwd) 27 | 28 | params = argparse.ArgumentParser() 29 | params.add_argument('-xyz', type=str, default='test_xyz') 30 | 31 | params = params.parse_args() 32 | 33 | 34 | def generate_linear_regression_data(num_samples=100, m=1.0, b=-1.0, y_noise=1.0, x_noise=.01, plot=False): 35 | x = torch.linspace(-2, 2, num_samples).reshape(-1, 1) 36 | x += x_noise * torch.randn_like(x) 37 | y = m * x + b 38 | y += y_noise * torch.randn_like(y) 39 | 40 | if plot: 41 | plt.scatter(x, y) 42 | plt.show() 43 | 44 | return x, y 45 | 46 | 47 | def generate_multimodal_linear_regression(num_samples, y_noise=1, x_noise=1, plot=False): 48 | x1, y1 = generate_linear_regression_data(num_samples=num_samples // 10 * int(3), m=-1., b=0, y_noise=0.1, x_noise=0.1, plot=False) 49 | x2, y2 = generate_linear_regression_data(num_samples=num_samples // 10 * int(7), m=2, b=0, y_noise=0.1, x_noise=0.1, plot=False) 50 | 51 | x = torch.cat([x1, x2], dim=0).float() 52 | y = torch.cat([y1, y2], dim=0).float() 53 | 54 | if plot: 55 | plt.scatter(x, y, s=1) 56 | plt.show() 57 | 58 | return x, y 59 | 60 | 61 | def generate_nonstationary_data(num_samples=1000, y_constant_noise_std=0.1, y_nonstationary_noise_std=1., plot=False): 62 | x = np.linspace(-0.35, 0.45, num_samples) 63 | x_noise = np.random.normal(0., 0.01, size=x.shape) 64 | 65 | constant_noise = np.random.normal(0, y_constant_noise_std, size=x.shape) 66 | std = np.linspace(0, y_nonstationary_noise_std, num_samples) # * _y_noise_std 67 | non_stationary_noise = np.random.normal(loc=0, scale=std) 68 | 69 | y = x + 0.3 * np.sin(2 * np.pi * (x + x_noise)) + 0.3 * np.sin(4 * np.pi * (x + x_noise)) + non_stationary_noise + constant_noise 70 | 71 | x = torch.from_numpy(x).reshape(-1, 1).float() 72 | y = torch.from_numpy(y).reshape(-1, 1).float() 73 | 74 | x = (x - x.mean(dim=0))/(x.std(dim=0)+1e-3) 75 | y = (y - y.mean(dim=0))/(y.std(dim=0)+1e-3) 76 | 77 | if plot: 78 | plt.scatter(x, y) 79 | plt.show() 80 | 81 | return x, y -------------------------------------------------------------------------------- /data/plots/EpistemicAleatoricUncertainty.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/EpistemicAleatoricUncertainty.png -------------------------------------------------------------------------------- /data/plots/GMM_HMC1.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/GMM_HMC1.gif -------------------------------------------------------------------------------- /data/plots/HMC_Sampler.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/HMC_Sampler.gif -------------------------------------------------------------------------------- /data/plots/HMC_Sampler2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/HMC_Sampler2.gif -------------------------------------------------------------------------------- /data/plots/HMC_Sampler3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/HMC_Sampler3.gif -------------------------------------------------------------------------------- /data/plots/HMC_Sampler4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/data/plots/HMC_Sampler4.gif -------------------------------------------------------------------------------- /experiments/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/pytorch_MCMC/c62fdb0af6e173f9292169b99230357fc36c9884/experiments/.DS_Store -------------------------------------------------------------------------------- /experiments/AIS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | from matplotlib import cm 5 | from matplotlib.colors import to_rgba 6 | from tqdm import tqdm 7 | from scipy.integrate import quad 8 | import numpy as np 9 | 10 | 11 | class Normal1D: 12 | def __init__(self, mean, std): 13 | self.mean = mean 14 | self.std = std 15 | 16 | def sample(self, num_samples=1): 17 | return torch.normal(self.mean, self.std, size=(num_samples,)) 18 | 19 | def prob(self, x): 20 | var = self.std**2 21 | return ( 22 | 1 23 | / (self.std * torch.sqrt(2 * torch.tensor(torch.pi))) 24 | * torch.exp(-0.5 * (x - self.mean) ** 2 / var) 25 | ) 26 | 27 | def energy(self, x): 28 | return 0.5 * (x - self.mean) ** 2 / self.std**2 29 | 30 | @property 31 | def Z(self): 32 | """Partition function for the normal distribution Z = (sqrt(2 * pi * std**2)""" 33 | return self.std * torch.sqrt(2 * torch.tensor(torch.pi)) 34 | 35 | def log_prob(self, x): 36 | var = self.std**2 37 | return ( 38 | -0.5 * torch.log(2 * torch.tensor(torch.pi) * var) 39 | - 0.5 * (x - self.mean) ** 2 / var 40 | ) 41 | 42 | 43 | class GaussianMixture1D: 44 | def __init__(self, means, stds, weights): 45 | self.means = torch.tensor(means, dtype=torch.float32) 46 | self.stds = torch.tensor(stds, dtype=torch.float32) 47 | self.weights = torch.tensor(weights, dtype=torch.float32) 48 | self.weights = self.weights / self.weights.sum() # Normalize weights 49 | 50 | def sample(self, num_samples=1): 51 | component = torch.multinomial(self.weights, num_samples, replacement=True) 52 | samples = 
torch.normal(self.means[component], self.stds[component]) 53 | return samples 54 | 55 | def prob(self, x): 56 | x = x.unsqueeze(-1) # Shape (N, 1) 57 | probs = ( 58 | 1 59 | / (self.stds * torch.sqrt(2 * torch.tensor(torch.pi))) 60 | * torch.exp(-0.5 * (x - self.means) ** 2 / (self.stds**2)) 61 | ) 62 | weighted_probs = probs * self.weights 63 | return weighted_probs.sum(dim=-1) 64 | 65 | def log_prob(self, x): 66 | x = x.unsqueeze(-1) 67 | log_probs = ( 68 | -0.5 * torch.log(2 * torch.tensor(torch.pi) * self.stds**2) 69 | - 0.5 * (x - self.means) ** 2 / (self.stds**2) 70 | + torch.log(self.weights) 71 | ) 72 | return torch.logsumexp(log_probs, dim=-1) 73 | 74 | def energy(self, x): 75 | # Negative log probability (up to constant) 76 | return -self.log_prob(x) + 1 77 | 78 | 79 | def sgl_sampler(energy_fn, x_init, lr=1e-2, n_steps=1000, noise_scale=1.0): 80 | """ 81 | Stochastic Gradient Langevin Dynamics (SGLD) sampler. 82 | Args: 83 | energy_fn: Callable, computes energy for input x (requires_grad=True). 84 | x_init: Initial sample (torch.tensor, requires_grad=False). 85 | lr: Learning rate (step size). 86 | n_steps: Number of sampling steps. 87 | noise_scale: Multiplier for injected noise (default 1.0). 88 | Returns: 89 | samples: Tensor of shape (n_steps,). 90 | """ 91 | x = x_init.clone().detach().requires_grad_(True) 92 | samples = [] 93 | pbar = tqdm(range(n_steps)) 94 | for _ in pbar: 95 | lr_t = lr * 0.5 * (1 + torch.cos(torch.tensor(_ / n_steps * torch.pi))) 96 | pbar.set_description(f"lr={lr_t:.5f}") 97 | x.requires_grad_(True) 98 | energy = energy_fn(x) 99 | grad = torch.autograd.grad(energy.sum(), x)[0] 100 | noise = torch.randn_like(x) * lr**0.5 101 | lr_t = lr / 2 * (1 + torch.cos(torch.tensor(_ / n_steps * torch.pi))) 102 | x = (x - lr_t * grad + noise).detach() 103 | samples.append(x.detach().clone()) 104 | return x 105 | 106 | 107 | def metropolishastings_sampler(energy_fn, x_init, n_steps=1000, proposal_std=0.5): 108 | """ 109 | Metropolis-Hastings sampler. 110 | Args: 111 | energy_fn: Callable, computes energy for input x (requires_grad=True). 112 | x_init: Initial sample (torch.tensor, requires_grad=False). 113 | n_steps: Number of sampling steps. 114 | proposal_std: Standard deviation for the proposal distribution. 115 | Returns: 116 | samples: Tensor of shape (n_steps,). 
117 | """ 118 | x = x_init 119 | samples = [] 120 | # progress_bar = tqdm(total=n_steps, desc="MH") 121 | acceptance_ratio_ema, ema_weight = None, 0.9999 122 | pbar = tqdm(range(n_steps), desc="MH") 123 | for _ in pbar: 124 | x_new = x + torch.randn_like(x) * proposal_std 125 | log_acceptance_ratio = energy_fn(x) - energy_fn(x_new) 126 | acceptance_ratio = torch.exp(log_acceptance_ratio) 127 | accept = torch.rand_like(acceptance_ratio) < acceptance_ratio 128 | 129 | # Debiased running average (corrects for initial bias) 130 | if acceptance_ratio_ema is None: 131 | acceptance_ratio_ema = accept.int().float().mean() 132 | ema_correction = 1.0 133 | else: 134 | acceptance_ratio_ema = ( 135 | acceptance_ratio_ema * ema_weight 136 | + accept.int().float().mean() * (1 - ema_weight) 137 | ) 138 | ema_correction = 1 - ema_weight ** (_ + 1) 139 | debiased_acceptance = acceptance_ratio_ema / ema_correction 140 | pbar.set_postfix({"Accept": float(acceptance_ratio_ema.mean())}) 141 | x = torch.where(accept, x_new, x) 142 | samples.append(x.detach().clone()) 143 | return x 144 | 145 | 146 | def compute_partition_function_1d(energy_fn, x_min, x_max): 147 | integrand = lambda x: np.exp(-energy_fn(torch.tensor(x)).item()) 148 | Z, _ = quad(integrand, x_min, x_max) 149 | return Z 150 | 151 | 152 | p0 = Normal1D(torch.tensor(0.0), torch.tensor(5.0)) 153 | 154 | N = 10_000 155 | samples = p0.sample(N) 156 | prob = p0.prob(samples) 157 | log_prob = p0.log_prob(samples) 158 | Z = p0.Z 159 | 160 | p1 = Normal1D(torch.tensor(0.0), torch.tensor(1.0)) 161 | print(p0.Z, p1.Z) 162 | 163 | log_w = -p1.energy(samples) - p0.log_prob(samples) 164 | logZ1 = torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(N, dtype=torch.float32)) 165 | print(f"Estimated Z1: {torch.exp(logZ1)}") 166 | print(f"True Z1: {p1.Z}") 167 | 168 | samples = samples[samples < 3] 169 | samples = samples[samples > -3] 170 | logprob1 = -logZ1 - p1.energy(samples) 171 | prob1 = torch.exp(logprob1) 172 | 173 | plt.figure(figsize=(8, 4)) 174 | plt.hist( 175 | samples.numpy(), 176 | bins=50, 177 | weights=prob1.numpy(), 178 | density=True, 179 | alpha=0.6, 180 | label="Weighted Histogram", 181 | ) 182 | # sns.kdeplot(samples.numpy(), weights=prob1.numpy(), color='red', label='KDE') 183 | plt.title("Histogram of prob1 at sample locations") 184 | plt.xlabel("x") 185 | plt.ylabel("Probability Density") 186 | plt.legend() 187 | plt.show() 188 | 189 | gmm = GaussianMixture1D( 190 | means=[-3.0, 0.0, 3.0], 191 | stds=[0.5, 0.5, 0.5], 192 | weights=[2, 0.3, 0.1], 193 | ) 194 | x = torch.linspace(-5, 5, 200) 195 | 196 | 197 | plt.figure(figsize=(8, 4)) 198 | plt.plot( 199 | x.numpy(), 200 | p1.energy(x).numpy(), 201 | label="True Normal(0, 1) Energy", 202 | color="blue", 203 | ) 204 | plt.plot( 205 | x.numpy(), 206 | gmm.energy(x).numpy(), 207 | label="GMM Energy", 208 | color="orange", 209 | ) 210 | 211 | # Define colors for interpolation 212 | color1 = to_rgba("blue") 213 | color2 = to_rgba("orange") 214 | 215 | for t in [0.2, 0.4, 0.6, 0.8, 0.9, 0.95, 1.0]: 216 | # Linear interpolation in color space 217 | interp_color = tuple((1 - t) * c1 + t * c2 for c1, c2 in zip(color1, color2)) 218 | energy_t = gmm.energy(x) ** t * p1.energy(x) ** (1 - t) 219 | plt.plot( 220 | x.numpy(), 221 | energy_t.numpy(), 222 | label=f"t={t:.2f}", 223 | linestyle="--", 224 | color=interp_color, 225 | ) 226 | 227 | plt.title("Energy Interpolation between Normal and GMM") 228 | plt.xlabel("x") 229 | plt.ylabel("Energy") 230 | plt.legend(loc="center left", bbox_to_anchor=(1, 
0.5)) 231 | plt.show() 232 | 233 | t = 0.0 234 | energy_t = lambda x: gmm.energy(x) ** t * p1.energy(x) ** (1 - t) 235 | energy_t = lambda x: p1.energy(x) 236 | samples = metropolishastings_sampler( 237 | energy_fn=energy_t, x_init=torch.randn(5_000), n_steps=2_000, proposal_std=0.1 238 | ) 239 | 240 | print(f"{samples.shape=}") 241 | # Count unique samples and their frequencies 242 | log_weights = -energy_t(samples) 243 | bins = torch.histc(samples, bins=50, min=samples.min(), max=samples.max()) 244 | bin_edges = torch.linspace(samples.min(), samples.max(), steps=101) 245 | bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) 246 | bin_probs = bins / bins.sum() 247 | Z_binned = (torch.exp(-energy_t(bin_centers)) * (bin_edges[1:] - bin_edges[:-1])).sum() 248 | print(f"Z_binned: {Z_binned}") 249 | 250 | 251 | # Weight by frequency of each unique sample 252 | # logZ_est = torch.logsumexp(log_weights, dim=0) - torch.log( 253 | # torch.tensor(len(samples), dtype=torch.float32) 254 | # ) 255 | 256 | Z_est = torch.exp(log_weights).sum() / len(samples) 257 | print( 258 | f"Estimated partition function Z_t: {Z_est.item()} vs {compute_partition_function_1d(energy_t, -4, 4)}" 259 | ) 260 | 261 | _ = plt.hist( 262 | samples.detach().numpy(), 263 | bins=100, 264 | density=True, 265 | alpha=0.6, 266 | label="SGLD Samples", 267 | ) 268 | # plt.plot(x.numpy(), gmm.prob(x).numpy(), label="GMM PDF", color="orange") 269 | # plt.plot(x.numpy(), energy_t(x).numpy(), label="Normal PDF", color="blue") 270 | -------------------------------------------------------------------------------- /experiments/DataViz.py: -------------------------------------------------------------------------------- 1 | import future, sys, os, datetime, argparse 2 | # print(os.path.dirname(sys.executable)) 3 | import torch 4 | import numpy as np 5 | import matplotlib 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | from matplotlib.lines import Line2D 9 | 10 | matplotlib.rcParams["figure.figsize"] = [10, 10] 11 | 12 | import torch 13 | from torch.nn import Module, Parameter 14 | from torch.nn import Linear, Tanh, ReLU 15 | import torch.nn.functional as F 16 | 17 | Tensor = torch.Tensor 18 | FloatTensor = torch.FloatTensor 19 | 20 | torch.set_printoptions(precision=4, sci_mode=False) 21 | np.set_printoptions(precision=4, suppress=True) 22 | 23 | sys.path.append("../../..") # Up to -> KFAC -> Optimization -> PHD 24 | 25 | import scipy 26 | import scipy as sp 27 | from scipy.io import loadmat as sp_loadmat 28 | import copy 29 | 30 | cwd = os.path.abspath(os.getcwd()) 31 | os.chdir(cwd) 32 | 33 | from pytorch_MCMC.src.MCMC_ProbModel import ProbModel 34 | from pytorch_MCMC.models.MCMC_Models import GMM, LinReg, RegressionNN 35 | from pytorch_MCMC.src.MCMC_Sampler import SGLD_Sampler, MetropolisHastings_Sampler, MALA_Sampler, HMC_Sampler 36 | from pytorch_MCMC.data.MCMC_SyntheticData import generate_linear_regression_data, generate_multimodal_linear_regression, generate_nonstationary_data 37 | from pytorch_MCMC.src.MCMC_Utils import posterior_dist 38 | from Utils.Utils import RunningAverageMeter, str2bool 39 | 40 | def create_supervised_gif(model, chain, data): 41 | 42 | x, y = data 43 | x_min = 2 * x.min() 44 | x_max = 2 * x.max() 45 | 46 | data, mu, _ = model.predict(chain) 47 | 48 | gif_frames = [] 49 | 50 | samples = [400, 600, 800, 1000] 51 | samples += range(2000, len(chain)//2, 2000) 52 | samples += range(len(chain)//2, len(chain), 4000) 53 | 54 | # print(len(samples)) 55 | # exit() 56 | 57 | for i in 
range(400,len(chain), 500): 58 | print(f"{i}/{len(samples)}") 59 | 60 | # _, _, std = model.predict(chain[:i]) 61 | fig = plt.figure() 62 | _, mu, std = model.predict(chain[399:i]) 63 | plt.fill_between(data.squeeze(), mu + std, mu - std, color='red', alpha=0.25) 64 | plt.fill_between(data.squeeze(), mu + 2 * std, mu - 2 * std, color='red', alpha=0.10) 65 | plt.fill_between(data.squeeze(), mu + 3 * std, mu - 3 * std, color='red', alpha=0.05) 66 | 67 | plt.plot(data.squeeze(), mu, c='red') 68 | plt.scatter(x, y, alpha=1, s=1, color='blue') 69 | plt.ylim(2 * y.min(), 2 * y.max()) 70 | plt.xlim(x_min, x_max) 71 | plt.grid() 72 | 73 | fig.canvas.draw() # draw the canvas, cache the renderer 74 | image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8') 75 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 76 | 77 | # plt.show() 78 | gif_frames.append(image) 79 | 80 | import imageio 81 | imageio.mimsave('HMC_Sampler5.gif', gif_frames, fps=4) 82 | 83 | def create_gmm_gif(chains): 84 | 85 | # num_samples = [40, 80, 120, 160, 200, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000] 86 | num_samples = [x for x in range(3,2000,40)] 87 | num_samples += [x for x in range(2000, len(chains[0]), 500)] 88 | 89 | gif_frames = [] 90 | 91 | for num_samples_ in num_samples: 92 | 93 | print(f"{num_samples_}/{len(chains[0])}") 94 | 95 | post = [] 96 | 97 | for chain in chains: 98 | 99 | for model_state_dict in chain.samples[:num_samples_]: 100 | post.append(list(model_state_dict.values())[0]) 101 | 102 | post = torch.cat(post, dim=0) 103 | 104 | fig = plt.figure() 105 | hist2d = plt.hist2d(x=post[:, 0].cpu().numpy(), y=post[:, 1].cpu().numpy(), bins=100, range=np.array([[-3, 3], [-3, 3]]), 106 | density=True) 107 | plt.colorbar(hist2d[3]) 108 | # plt.show() 109 | 110 | 111 | fig.canvas.draw() # draw the canvas, cache the renderer 112 | image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8') 113 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 114 | 115 | 116 | gif_frames.append(image) 117 | 118 | import imageio 119 | imageio.mimsave('GMM_HMC1.gif', gif_frames, fps=4) 120 | 121 | 122 | 123 | if True: 124 | chain = torch.load("hmc_regnn_ss0.01_len10000.chain") 125 | chain = chain[:50000] 126 | data = generate_nonstationary_data(num_samples=1000, plot=False, x_noise_std=0.01, y_noise_std=0.1) 127 | nn = RegressionNN(*data, batch_size=50) 128 | 129 | create_supervised_gif(nn, chain, data) 130 | 131 | if False: 132 | 133 | chains = torch.load("GMM_Chains.chain") 134 | 135 | create_gmm_gif(chains) 136 | 137 | posterior_dist(chains[0][:50]) 138 | plt.show() -------------------------------------------------------------------------------- /experiments/MCMC_HMCTest.py: -------------------------------------------------------------------------------- 1 | # cleaner interaction 2 | 3 | 4 | import os, argparse 5 | 6 | # print(os.path.dirname(sys.executable)) 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | 10 | matplotlib.rcParams["figure.figsize"] = [10, 10] 11 | 12 | import torch 13 | 14 | SEED = 0 15 | # torch.manual_seed(SEED) 16 | # np.random.seed(SEED) 17 | 18 | from pytorch_MCMC.models.MCMC_Models import GMM, LinReg, RegressionNNHomo, RegressionNNHetero 19 | from pytorch_MCMC.src.MCMC_Sampler import SGLD_Sampler, MALA_Sampler, HMC_Sampler, SGNHT_Sampler 20 | from pytorch_MCMC.data.MCMC_SyntheticData import generate_linear_regression_data, generate_nonstationary_data 21 | from Utils.Utils import str2bool 22 | 23 | params = 
argparse.ArgumentParser(description='parser example') 24 | params.add_argument('-logname', type=str, default='Tmp') 25 | 26 | params.add_argument('-num_samples', type=int, default=200) 27 | params.add_argument('-model', choices=['gmm', 'linreg', 'regnn'], default='gmm') 28 | params.add_argument('-sampler', choices=['sgld','mala', 'hmc', 'sgnht'], default='sgnht') 29 | 30 | params.add_argument('-step_size', type=float, default=0.1) 31 | params.add_argument('-num_steps', type=int, default=10000) 32 | params.add_argument('-pretrain', type=str2bool, default=False) 33 | params.add_argument('-tune', type=str2bool, default=False) 34 | params.add_argument('-burn_in', type=int, default=2000) 35 | # params.add_argument('-num_chains', type=int, default=1) 36 | params.add_argument('-num_chains', type=int, default=os.cpu_count() - 1) 37 | params.add_argument('-batch_size', type=int, default=50) 38 | 39 | params.add_argument('-hmc_traj_length', type=int, default=20) 40 | 41 | params.add_argument('-val_split', type=float, default=0.9) # first part is train, second is val i.e. val_split=0.8 -> 80% train, 20% val 42 | 43 | params.add_argument('-val_prediction_steps', type=int, default=50) 44 | params.add_argument('-val_converge_criterion', type=int, default=20) 45 | params.add_argument('-val_per_epoch', type=int, default=200) 46 | 47 | params = params.parse_args() 48 | 49 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 50 | 51 | if torch.cuda.is_available(): 52 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 53 | FloatTensor = torch.cuda.FloatTensor 54 | Tensor = torch.cuda.FloatTensorf 55 | else: 56 | device = torch.device('cpu') 57 | FloatTensor = torch.FloatTensor 58 | Tensor = torch.FloatTensor 59 | 60 | if params.model == 'gmm': 61 | gmm = GMM() 62 | gmm.generate_surface(plot=True) 63 | 64 | if params.sampler=='sgld': 65 | sampler = SGLD_Sampler(probmodel=gmm, 66 | step_size=params.step_size, 67 | num_steps=params.num_steps, 68 | burn_in=params.burn_in, 69 | pretrain=params.pretrain, 70 | tune=params.tune, 71 | num_chains=params.num_chains) 72 | elif params.sampler=='mala': 73 | sampler = MALA_Sampler(probmodel=gmm, 74 | step_size=params.step_size, 75 | num_steps=params.num_steps, 76 | burn_in=params.burn_in, 77 | pretrain=params.pretrain, 78 | tune=params.tune, 79 | num_chains=params.num_chains) 80 | elif params.sampler=='hmc': 81 | sampler = HMC_Sampler(probmodel=gmm, 82 | step_size=params.step_size, 83 | num_steps=params.num_steps, 84 | burn_in=params.burn_in, 85 | pretrain=params.pretrain, 86 | tune=params.tune, 87 | traj_length=params.hmc_traj_length, 88 | num_chains=params.num_chains) 89 | elif params.sampler == 'sgnht': 90 | sampler = SGNHT_Sampler(probmodel=gmm, 91 | step_size=params.step_size, 92 | num_steps=params.num_steps, 93 | burn_in=params.burn_in, 94 | pretrain=params.pretrain, 95 | tune=params.tune, 96 | traj_length=params.hmc_traj_length, 97 | num_chains=params.num_chains) 98 | else: 99 | raise ValueError('No Sampler defined') 100 | 101 | sampler.sample_chains() 102 | sampler.posterior_dist() 103 | # sampler.trace() 104 | 105 | # plt.plot(sampler.chain.accepted_steps) 106 | plt.show() 107 | 108 | elif params.model == 'linreg': 109 | 110 | x, y = generate_linear_regression_data(num_samples=params.num_samples, m=-2., b=-1, y_noise=0.5) 111 | linreg = LinReg(x, y) 112 | # sampler = MetropolisHastings_Sampler(probmodel=linreg, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, tune=params.tune) 113 | sampler 
= SGLD_Sampler(probmodel=linreg, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, 114 | pretrain=params.pretrain, tune=params.tune) 115 | sampler.sample_chains() 116 | sampler.posterior_dist() 117 | linreg.predict(sampler.chain) 118 | 119 | elif params.model == 'regnn': 120 | 121 | model = ['homo', 'hetero'][1] 122 | if model=='homo': 123 | x, y = generate_nonstationary_data(num_samples=params.num_samples, plot=False, y_constant_noise_std=0.25, y_nonstationary_noise_std=0.01) 124 | nn = RegressionNNHomo(x, y, batch_size=params.batch_size) 125 | elif model=='hetero': 126 | x, y = generate_nonstationary_data(num_samples=params.num_samples, plot=False, y_constant_noise_std=0.001, y_nonstationary_noise_std=0.5) 127 | nn = RegressionNNHetero(x, y, batch_size=params.batch_size) 128 | 129 | if params.sampler == 'sgld': 130 | sampler = SGLD_Sampler(probmodel=nn, 131 | step_size=params.step_size, 132 | num_steps=params.num_steps, 133 | burn_in=params.burn_in, 134 | pretrain=params.pretrain, 135 | tune=params.tune, 136 | num_chains=params.num_chains) 137 | elif params.sampler == 'mala': 138 | sampler = MALA_Sampler(probmodel=nn, 139 | step_size=params.step_size, 140 | num_steps=params.num_steps, 141 | burn_in=params.burn_in, 142 | pretrain=params.pretrain, 143 | tune=params.tune, 144 | num_chains=params.num_chains) 145 | elif params.sampler == 'hmc': 146 | sampler = HMC_Sampler(probmodel=nn, 147 | step_size=params.step_size, 148 | num_steps=params.num_steps, 149 | burn_in=params.burn_in, 150 | pretrain=params.pretrain, 151 | tune=params.tune, 152 | traj_length=params.hmc_traj_length, 153 | num_chains=params.num_chains) 154 | elif params.sampler == 'sgnht': 155 | sampler = SGNHT_Sampler(probmodel=nn, 156 | step_size=params.step_size, 157 | num_steps=params.num_steps, 158 | burn_in=params.burn_in, 159 | pretrain=params.pretrain, 160 | tune=params.tune, 161 | traj_length=params.hmc_traj_length, 162 | num_chains=params.num_chains) 163 | else: 164 | raise ValueError('No Sampler defined') 165 | chains = sampler.sample_chains() 166 | 167 | nn.predict(chains, plot=True) 168 | -------------------------------------------------------------------------------- /experiments/MCMC_Test.py: -------------------------------------------------------------------------------- 1 | # cleaner interaction 2 | 3 | 4 | import os, argparse 5 | 6 | # print(os.path.dirname(sys.executable)) 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | 10 | matplotlib.rcParams["figure.figsize"] = [10, 10] 11 | 12 | import torch 13 | 14 | SEED = 0 15 | # torch.manual_seed(SEED) 16 | # np.random.seed(SEED) 17 | 18 | from pytorch_MCMC.models.MCMC_Models import GMM, LinReg, RegressionNN 19 | from pytorch_MCMC.src.MCMC_Sampler import SGLD_Sampler, MALA_Sampler, HMC_Sampler 20 | from pytorch_MCMC.data.MCMC_SyntheticData import generate_linear_regression_data, generate_nonstationary_data 21 | from Utils.Utils import str2bool 22 | 23 | params = argparse.ArgumentParser(description='parser example') 24 | params.add_argument('-logname', type=str, default='Tmp') 25 | 26 | params.add_argument('-num_samples', type=int, default=1000) 27 | params.add_argument('-model', choices=['gmm', 'linreg', 'regnn'], default='gmm') 28 | params.add_argument('-sampler', choices=['sgld', 'mala', 'hmc'], default='hmc') 29 | 30 | params.add_argument('-step_size', type=float, default=0.1) 31 | params.add_argument('-num_steps', type=int, default=10000) 32 | params.add_argument('-pretrain', type=str2bool, default=False) 33 | 
params.add_argument('-tune', type=str2bool, default=False) 34 | params.add_argument('-burn_in', type=int, default=1000) 35 | # params.add_argument('-num_chains', type=int, default=1) 36 | params.add_argument('-num_chains', type=int, default=os.cpu_count() - 1) 37 | params.add_argument('-batch_size', type=int, default=50) 38 | 39 | params.add_argument('-hmc_traj_length', type=int, default=20) 40 | 41 | params.add_argument('-val_split', type=float, default=0.9) # first part is train, second is val i.e. val_split=0.8 -> 80% train, 20% val 42 | 43 | params.add_argument('-val_prediction_steps', type=int, default=50) 44 | params.add_argument('-val_converge_criterion', type=int, default=20) 45 | params.add_argument('-val_per_epoch', type=int, default=200) 46 | 47 | params = params.parse_args() 48 | 49 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 50 | 51 | if torch.cuda.is_available(): 52 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 53 | FloatTensor = torch.cuda.FloatTensor 54 | Tensor = torch.cuda.FloatTensorf 55 | else: 56 | device = torch.device('cpu') 57 | FloatTensor = torch.FloatTensor 58 | Tensor = torch.FloatTensor 59 | 60 | if params.model == 'gmm': 61 | gmm = GMM() 62 | # gmm.generate_surface(plot=True) 63 | 64 | if params.sampler == 'sgld': 65 | sampler = SGLD_Sampler(probmodel=gmm, 66 | step_size=params.step_size, 67 | num_steps=params.num_steps, 68 | burn_in=params.burn_in, 69 | pretrain=params.pretrain, 70 | tune=params.tune, 71 | num_chains=params.num_chains) 72 | elif params.sampler == 'mala': 73 | sampler = MALA_Sampler(probmodel=gmm, 74 | step_size=params.step_size, 75 | num_steps=params.num_steps, 76 | burn_in=params.burn_in, 77 | pretrain=params.pretrain, 78 | tune=params.tune, 79 | num_chains=params.num_chains) 80 | elif params.sampler == 'hmc': 81 | sampler = HMC_Sampler(probmodel=gmm, 82 | step_size=params.step_size, 83 | num_steps=params.num_steps, 84 | burn_in=params.burn_in, 85 | pretrain=params.pretrain, 86 | tune=params.tune, 87 | traj_length=params.hmc_traj_length, 88 | num_chains=params.num_chains) 89 | sampler.sample_chains() 90 | sampler.posterior_dist() 91 | # sampler.trace() 92 | 93 | # plt.plot(sampler.chain.accepted_steps) 94 | plt.show() 95 | 96 | elif params.model == 'linreg': 97 | 98 | x, y = generate_linear_regression_data(num_samples=params.num_samples, m=-2., b=-1, y_noise=0.5) 99 | linreg = LinReg(x, y) 100 | # sampler = MetropolisHastings_Sampler(probmodel=linreg, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, tune=params.tune) 101 | sampler = SGLD_Sampler(probmodel=linreg, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, 102 | pretrain=params.pretrain, tune=params.tune) 103 | sampler.sample_chains() 104 | sampler.posterior_dist() 105 | linreg.predict(sampler.chain) 106 | 107 | elif params.model == 'regnn': 108 | 109 | x, y = generate_nonstationary_data(num_samples=params.num_samples, plot=False, x_noise_std=0.01, y_noise_std=0.1) 110 | # print(f'{x.shape=} {y.shape=}') 111 | nn = RegressionNN(x, y, batch_size=params.batch_size) 112 | 113 | # sampler = SGLD_Sampler(probmodel=nn, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, pretrain=params.pretrain, tune=params.tune) 114 | # sampler = MALA_Sampler(probmodel=nn, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, pretrain=params.pretrain, 115 | # tune=params.tune) 116 | sampler = HMC_Sampler(probmodel=nn, 
step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, pretrain=params.pretrain, 117 | tune=params.tune) 118 | # sampler = MetropolisHastings_Sampler(probmodel=nn, step_size=params.step_size, num_steps=params.num_steps, burn_in=params.burn_in, pretrain=params.pretrain, tune=params.tune) 119 | sampler.sample_chains() 120 | torch.save(sampler.chain, 'hmc_regnn_ss0.01_len10000.chain') 121 | nn.predict(sampler.chain) 122 | -------------------------------------------------------------------------------- /experiments/Testing.py: -------------------------------------------------------------------------------- 1 | import future, sys, os, datetime, argparse 2 | # print(os.path.dirname(sys.executable)) 3 | import torch 4 | import numpy as np 5 | import matplotlib 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | from matplotlib.lines import Line2D 9 | 10 | matplotlib.rcParams["figure.figsize"] = [10, 10] 11 | 12 | import torch 13 | from torch.nn import Module, Parameter 14 | from torch.nn import Linear, Tanh, ReLU 15 | import torch.nn.functional as F 16 | 17 | Tensor = torch.Tensor 18 | FloatTensor = torch.FloatTensor 19 | 20 | torch.set_printoptions(precision=4, sci_mode=False) 21 | np.set_printoptions(precision=4, suppress=True) 22 | 23 | import scipy 24 | import scipy as sp 25 | from scipy.io import loadmat as sp_loadmat 26 | import copy 27 | 28 | cwd = os.path.abspath(os.getcwd()) 29 | os.chdir(cwd) 30 | 31 | params = argparse.ArgumentParser() 32 | params.add_argument('-xyz', type=str, default='test_xyz') 33 | 34 | params = params.parse_args() -------------------------------------------------------------------------------- /experiments/functional_pytorch: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | inputs = torch.randn(64, 3) 4 | targets = torch.randn(64, 3) 5 | model = torch.nn.Linear(3, 3) 6 | 7 | params = dict(model.named_parameters()) 8 | 9 | 10 | def compute_loss(params, inputs, targets): 11 | prediction = torch.func.functional_call(model, params, (inputs,)) 12 | return torch.nn.functional.mse_loss(prediction, targets) 13 | 14 | 15 | grads = torch.func.grad(compute_loss)(params, inputs, targets) 16 | 17 | 18 | # %% 19 | import copy 20 | 21 | import torch 22 | 23 | num_models = 5 24 | batch_size = 64 25 | in_features, out_features = 3, 3 26 | models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] 27 | data = torch.randn(batch_size, 3) 28 | 29 | # Construct a version of the model with no memory by putting the Tensors on 30 | # the meta device. 31 | base_model = copy.deepcopy(models[0]) 32 | base_model.to("meta") 33 | 34 | params, buffers = torch.func.stack_module_state(models) 35 | 36 | 37 | # It is possible to vmap directly over torch.func.functional_call, 38 | # but wrapping it in a function makes it clearer what is going on. 
39 | def call_single_model(params, buffers, data): 40 | return torch.func.functional_call(base_model, (params, buffers), (data,)) 41 | 42 | 43 | def call_single_loss(params, buffers, data): 44 | prediction = torch.func.functional_call(base_model, (params, buffers), (data,)) 45 | return torch.nn.functional.mse_loss(prediction, targets) 46 | 47 | 48 | output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data) 49 | grad = torch.func.grad( 50 | lambda p, b, d: torch.sum(torch.vmap(call_single_loss, (0, 0, None))(p, b, d)) 51 | )(params, buffers, data) 52 | assert output.shape == (num_models, batch_size, out_features) 53 | -------------------------------------------------------------------------------- /models/MCMC_Models.py: -------------------------------------------------------------------------------- 1 | import future, sys, os, datetime, argparse 2 | # print(os.path.dirname(sys.executable)) 3 | import torch 4 | import numpy as np 5 | import matplotlib 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | from matplotlib.lines import Line2D 9 | 10 | # matplotlib.rcParams["figure.figsize"] = [10, 10] 11 | 12 | import torch 13 | from torch.nn import Module, Parameter, Sequential 14 | from torch.nn import Linear, Tanh, ReLU, CELU 15 | from torch.utils.data import TensorDataset, DataLoader 16 | import torch.nn.functional as F 17 | from torch.distributions import MultivariateNormal, Categorical, Normal 18 | 19 | from joblib import Parallel, delayed 20 | 21 | Tensor = torch.Tensor 22 | FloatTensor = torch.FloatTensor 23 | 24 | torch.set_printoptions(precision=4, sci_mode=False) 25 | np.set_printoptions(precision=4, suppress=True) 26 | 27 | sys.path.append("../../..") # Up to -> KFAC -> Optimization -> PHD 28 | 29 | from pytorch_MCMC.src.MCMC_ProbModel import ProbModel 30 | 31 | class GMM(ProbModel): 32 | 33 | def __init__(self): 34 | 35 | dataloader = DataLoader(TensorDataset(torch.zeros(1,2))) # bogus dataloader 36 | ProbModel.__init__(self, dataloader) 37 | 38 | self.means = FloatTensor([[-1, -1.25], [-1, 1.25], [1.5, 1]]) 39 | # self.means = FloatTensor([[-1,-1.25]]) 40 | self.num_dists = self.means.shape[0] 41 | I = FloatTensor([[1, 0], [0, 1]]) 42 | I_compl = FloatTensor([[0, 1], [1, 0]]) 43 | self.covars = [I * 0.5, I * 0.5, I * 0.5 + I_compl * 0.3] 44 | # self.covars = [I * 0.9, I * 0.9, I * 0.9 + I_compl * 0.3] 45 | self.weights = [0.4, 0.2, 0.4] 46 | self.dists = [] 47 | 48 | for mean, covar in zip(self.means, self.covars): 49 | self.dists.append(MultivariateNormal(mean, covar)) 50 | 51 | self.X_grid = None 52 | self.Y_grid = None 53 | self.surface = None 54 | 55 | self.param = torch.nn.Parameter(self.sample()) 56 | 57 | 58 | 59 | def forward(self, x=None): 60 | 61 | log_probs = torch.stack([weight * torch.exp(dist.log_prob(x)) for dist, weight in zip(self.dists, self.weights)], dim=1) 62 | log_prob = torch.log(torch.sum(log_probs, dim=1)) 63 | 64 | return log_prob 65 | 66 | def log_prob(self, *x): 67 | 68 | log_probs = torch.stack([weight * torch.exp(dist.log_prob(self.param)) for dist, weight in zip(self.dists, self.weights)], dim=1) 69 | log_prob = torch.log(torch.sum(log_probs, dim=1)) 70 | 71 | return {'log_prob': log_prob} 72 | 73 | def prob(self, x): 74 | 75 | log_probs = torch.stack([weight * torch.exp(dist.log_prob(x)) for dist, weight in zip(self.dists, self.weights)], dim=1) 76 | log_prob = torch.sum(log_probs, dim=1) 77 | return log_prob 78 | 79 | def sample(self, _shape=(1,)): 80 | 81 | probs = torch.ones(self.num_dists) / self.num_dists 82 | 
categorical = Categorical(probs) 83 | sampled_dists = categorical.sample(_shape) 84 | 85 | samples = [] 86 | for sampled_dist in sampled_dists: 87 | sample = self.dists[sampled_dist].sample((1,)) 88 | samples.append(sample) 89 | 90 | samples = torch.cat(samples) 91 | 92 | return samples 93 | 94 | def reset_parameters(self): 95 | 96 | self.param.data = self.sample() 97 | 98 | def generate_surface(self, plot_min=-3, plot_max=3, plot_res=500, plot=False): 99 | 100 | # print('in surface') 101 | 102 | x = np.linspace(plot_min, plot_max, plot_res) 103 | y = np.linspace(plot_min, plot_max, plot_res) 104 | X, Y = np.meshgrid(x, y) 105 | 106 | self.X_grid = X 107 | self.Y_grid = Y 108 | 109 | points = FloatTensor(np.stack((X.ravel(), Y.ravel())).T) # .requires_grad_() 110 | 111 | probs = self.prob(points).view(plot_res, plot_res) 112 | self.surface = probs.numpy() 113 | 114 | area = ((plot_max - plot_min) / plot_res) ** 2 115 | sum_px = probs.sum() * area # analogous to integrating cubes: volume is probs are the height times the area 116 | 117 | fig = plt.figure(figsize=(10, 10)) 118 | 119 | contour = plt.contourf(self.X_grid, self.Y_grid, self.surface, levels=20) 120 | plt.xlim(-3, 3) 121 | plt.ylim(-3, 3) 122 | plt.grid() 123 | cbar = fig.colorbar(contour) 124 | if plot: plt.show() 125 | 126 | return fig 127 | 128 | class LinReg(ProbModel): 129 | 130 | def __init__(self, x, y): 131 | super().__init__() 132 | 133 | self.data = x 134 | self.target = y 135 | 136 | self.dataloader = DataLoader(TensorDataset(self.data, self.target), shuffle=True, batch_size=self.data.shape[0]) 137 | 138 | self.m = Parameter(FloatTensor(1 * torch.randn((1,)))) 139 | self.b = Parameter(FloatTensor(1 * torch.randn((1,)))) 140 | # self.log_noise = Parameter(FloatTensor([-1.])) 141 | self.log_noise = FloatTensor([0]) 142 | 143 | def reset_parameters(self): 144 | torch.nn.init.normal_(self.m, std=.1) 145 | torch.nn.init.normal_(self.b, std=.1) 146 | 147 | def sample(self): 148 | self.reset_parameters() 149 | 150 | def forward(self, x): 151 | 152 | return self.m * x + self.b 153 | 154 | def log_prob(self): 155 | 156 | data, target = next(self.dataloader.__iter__()) 157 | # data, target = self.data, self.target 158 | mu = self.forward(data) 159 | log_prob = Normal(mu, F.softplus(self.log_noise)).log_prob(target).mean() 160 | 161 | return {'log_prob': log_prob} 162 | 163 | @torch.no_grad() 164 | def predict(self, chain): 165 | 166 | x_min = 2*self.data.min() 167 | x_max = 2*self.data.max() 168 | data = torch.arange(x_min, x_max).reshape(-1,1) 169 | 170 | pred = [] 171 | for model_state_dict in chain.samples: 172 | self.load_state_dict(model_state_dict) 173 | # data.append(self.data) 174 | pred_i = self.forward(data) 175 | pred.append(pred_i) 176 | 177 | pred = torch.stack(pred) 178 | # data = torch.stack(data) 179 | 180 | mu = pred.mean(dim=0).squeeze() 181 | std = pred.std(dim=0).squeeze() 182 | 183 | # print(f'{data.shape=}') 184 | # print(f'{pred.shape=}') 185 | 186 | plt.plot(data, mu, alpha=1., color='red') 187 | plt.fill_between(data.squeeze(), mu+std, mu-std, color='red', alpha=0.25) 188 | plt.fill_between(data.squeeze(), mu+2*std, mu-2*std, color='red', alpha=0.10) 189 | plt.fill_between(data.squeeze(), mu+3*std, mu-3*std, color='red', alpha=0.05) 190 | plt.scatter(self.data, self.target, alpha=1, s=1, color='blue') 191 | plt.ylim(pred.min(), pred.max()) 192 | plt.xlim(x_min, x_max) 193 | plt.show() 194 | 195 | class RegressionNNHomo(ProbModel): 196 | 197 | def __init__(self, x, y, batch_size=1): 198 | 199 | self.data 
= x 200 | self.target = y 201 | 202 | # dataloader = DataLoader(TensorDataset(self.data, self.target), shuffle=True, batch_size=self.data.shape[0], drop_last=False) 203 | dataloader = DataLoader(TensorDataset(self.data, self.target), shuffle=True, batch_size=batch_size, drop_last=False) 204 | 205 | ProbModel.__init__(self, dataloader) 206 | 207 | num_hidden = 50 208 | self.model = Sequential(Linear(1, num_hidden), 209 | ReLU(), 210 | Linear(num_hidden, num_hidden), 211 | ReLU(), 212 | # Linear(num_hidden, num_hidden), 213 | # ReLU(), 214 | # Linear(num_hidden, num_hidden), 215 | # ReLU(), 216 | Linear(num_hidden, 1)) 217 | 218 | self.log_std = Parameter(FloatTensor([-1])) 219 | 220 | def reset_parameters(self): 221 | for module in self.model.modules(): 222 | if isinstance(module, Linear): 223 | module.reset_parameters() 224 | 225 | self.log_std.data = FloatTensor([3.]) 226 | 227 | def sample(self): 228 | self.reset_parameters() 229 | 230 | def forward(self, x): 231 | pred = self.model(x) 232 | return pred 233 | 234 | def log_prob(self, data, target): 235 | 236 | # if data is None and target is None: 237 | # data, target = next(self.dataloader.__iter__()) 238 | 239 | mu = self.forward(data) 240 | mse = F.mse_loss(mu,target) 241 | 242 | log_prob = Normal(mu, F.softplus(self.log_std)).log_prob(target).mean()*len(self.dataloader.dataset) 243 | 244 | return {'log_prob': log_prob, 'MSE': mse.detach_()} 245 | 246 | def pretrain(self): 247 | 248 | num_epochs = 200 249 | optim = torch.optim.Adam(self.parameters(), lr=0.01) 250 | 251 | # print(f"{F.softplus(self.log_std)=}") 252 | 253 | progress = tqdm(range(num_epochs)) 254 | for epoch in progress: 255 | for batch_i, (data, target) in enumerate(self.dataloader): 256 | optim.zero_grad() 257 | mu = self.forward(data) 258 | loss = -Normal(mu, F.softplus(self.log_std)).log_prob(target).mean() 259 | mse_loss = F.mse_loss(mu, target) 260 | loss.backward() 261 | optim.step() 262 | 263 | desc = f'Pretraining: MSE:{mse_loss:.3f}' 264 | progress.set_description(desc) 265 | 266 | # print(f"{F.softplus(self.log_std)=}") 267 | 268 | @torch.no_grad() 269 | def predict(self, chains, plot=False): 270 | 271 | x_min = 2*self.data.min() 272 | x_max = 2*self.data.max() 273 | data = torch.linspace(x_min, x_max).reshape(-1,1) 274 | 275 | def parallel_predict(parallel_chain): 276 | parallel_pred = [] 277 | for model_state_dict in parallel_chain.samples[::50]: 278 | self.load_state_dict(model_state_dict) 279 | pred_mu_i = self.forward(data) 280 | parallel_pred.append(pred_mu_i) 281 | try: 282 | parallel_pred_mu = torch.stack(parallel_pred) # list [ pred_0, pred_1, ... pred_N] -> Tensor([pred_0, pred_1, ... pred_N]) 283 | return parallel_pred_mu 284 | except: 285 | pass 286 | 287 | 288 | parallel_pred = Parallel(n_jobs=len(chains))(delayed(parallel_predict)(chain) for chain in chains) 289 | 290 | pred = [parallel_pred_i for parallel_pred_i in parallel_pred if parallel_pred_i is not None] # flatten [ [pred_chain_0], [pred_chain_1] ... [pred_chain_N] ] 291 | # pred_log_std = [parallel_pred_i for parallel_pred_i in parallel_pred_log_std if parallel_pred_i is not None] # flatten [ [pred_chain_0], [pred_chain_1] ... 
[pred_chain_N] ] 292 | 293 | 294 | pred = torch.cat(pred).squeeze() # cat list of tensors to single prediciton tensor with samples in first dim 295 | std = F.softplus(self.log_std) 296 | 297 | 298 | epistemic = pred.std(dim=0) 299 | aleatoric = std 300 | total_std = (epistemic ** 2 + aleatoric ** 2) ** 0.5 301 | 302 | mu = pred.mean(dim=0) 303 | std = std.mean(dim=0) 304 | 305 | data.squeeze_() 306 | 307 | if plot: 308 | fig, axs = plt.subplots(2, 2, sharex=True, sharey=True) 309 | axs = axs.flatten() 310 | 311 | axs[0].scatter(self.data, self.target, alpha=1, s=1, color='blue') 312 | axs[0].plot(data.squeeze(), mu, alpha=1., color='red') 313 | axs[0].fill_between(data, mu + total_std, mu - total_std, color='red', alpha=0.25) 314 | axs[0].fill_between(data, mu + 2 * total_std, mu - 2 * total_std, color='red', alpha=0.10) 315 | axs[0].fill_between(data, mu + 3 * total_std, mu - 3 * total_std, color='red', alpha=0.05) 316 | 317 | [axs[1].plot(data, pred, alpha=0.1, color='red') for pred in pred] 318 | axs[1].scatter(self.data, self.target, alpha=1, s=1, color='blue') 319 | 320 | axs[2].scatter(self.data, self.target, alpha=1, s=1, color='blue') 321 | axs[2].plot(data, mu, color='red') 322 | axs[2].fill_between(data, mu - aleatoric, mu + aleatoric, color='red', alpha=0.25, label='Aleatoric') 323 | axs[2].legend() 324 | 325 | axs[3].scatter(self.data, self.target, alpha=1, s=1, color='blue') 326 | axs[3].plot(data, mu, color='red') 327 | axs[3].fill_between(data, mu - epistemic, mu + epistemic, color='red', alpha=0.25, label='Epistemic') 328 | axs[3].legend() 329 | 330 | plt.ylim(2 * self.target.min(), 2 * self.target.max()) 331 | plt.xlim(x_min, x_max) 332 | plt.show() 333 | 334 | class RegressionNNHetero(ProbModel): 335 | 336 | def __init__(self, x, y, batch_size=1): 337 | 338 | self.data = x 339 | self.target = y 340 | 341 | dataloader = DataLoader(TensorDataset(self.data, self.target), shuffle=True, batch_size=self.data.shape[0], drop_last=False) 342 | # dataloader = DataLoader(TensorDataset(x, y), shuffle=True, batch_size=batch_size, drop_last=False) 343 | 344 | ProbModel.__init__(self, dataloader) 345 | 346 | num_hidden = 50 347 | self.model = Sequential(Linear(1, num_hidden), 348 | ReLU(), 349 | Linear(num_hidden, num_hidden), 350 | ReLU(), 351 | Linear(num_hidden, num_hidden), 352 | ReLU(), 353 | # Linear(num_hidden, num_hidden), 354 | # ReLU(), 355 | Linear(num_hidden, 2)) 356 | 357 | def reset_parameters(self): 358 | for module in self.model.modules(): 359 | if isinstance(module, Linear): 360 | module.reset_parameters() 361 | 362 | def sample(self): 363 | self.reset_parameters() 364 | 365 | def forward(self, x): 366 | pred = self.model(x) 367 | mu, log_std = torch.chunk(pred, chunks=2, dim=-1) 368 | return mu, log_std 369 | 370 | def log_prob(self, data, target): 371 | 372 | # if data is None and target is None: 373 | # data, target = next(self.dataloader.__iter__()) 374 | 375 | mu, log_std = self.forward(data) 376 | mse = F.mse_loss(mu,target) 377 | 378 | log_prob = Normal(mu, F.softplus(log_std)).log_prob(target).mean()*len(self.dataloader.dataset) 379 | 380 | return {'log_prob': log_prob, 'MSE': mse.detach_()} 381 | 382 | def pretrain(self): 383 | 384 | num_epochs = 100 385 | optim = torch.optim.Adam(self.parameters(), lr=0.001) 386 | 387 | progress = tqdm(range(num_epochs)) 388 | for epoch in progress: 389 | for batch_i, (data, target) in enumerate(self.dataloader): 390 | optim.zero_grad() 391 | mu, log_std = self.forward(data) 392 | loss = -Normal(mu, 
F.softplus(log_std)).log_prob(target).mean() 393 | mse_loss = F.mse_loss(mu, target) 394 | loss.backward() 395 | optim.step() 396 | 397 | desc = f'Pretraining: MSE:{mse_loss:.3f}' 398 | progress.set_description(desc) 399 | 400 | @torch.no_grad() 401 | def predict(self, chains, plot=False): 402 | 403 | x_min = 2*self.data.min() 404 | x_max = 2*self.data.max() 405 | data = torch.linspace(x_min, x_max).reshape(-1,1) 406 | 407 | def parallel_predict(parallel_chain): 408 | parallel_pred_mu = [] 409 | parallel_pred_log_std = [] 410 | for model_state_dict in parallel_chain.samples[::50]: 411 | self.load_state_dict(model_state_dict) 412 | pred_mu_i, pred_log_std_i = self.forward(data) 413 | parallel_pred_mu.append(pred_mu_i) 414 | parallel_pred_log_std.append(pred_log_std_i) 415 | 416 | try: 417 | parallel_pred_mu = torch.stack(parallel_pred_mu) # list [ pred_0, pred_1, ... pred_N] -> Tensor([pred_0, pred_1, ... pred_N]) 418 | parallel_pred_log_std = torch.stack(parallel_pred_log_std) # list [ pred_0, pred_1, ... pred_N] -> Tensor([pred_0, pred_1, ... pred_N]) 419 | return parallel_pred_mu, parallel_pred_log_std 420 | except: 421 | pass 422 | 423 | 424 | parallel_pred_mu, parallel_pred_log_std = zip(*Parallel(n_jobs=len(chains))(delayed(parallel_predict)(chain) for chain in chains)) 425 | 426 | pred_mu = [parallel_pred_i for parallel_pred_i in parallel_pred_mu if parallel_pred_i is not None] # flatten [ [pred_chain_0], [pred_chain_1] ... [pred_chain_N] ] 427 | pred_log_std = [parallel_pred_i for parallel_pred_i in parallel_pred_log_std if parallel_pred_i is not None] # flatten [ [pred_chain_0], [pred_chain_1] ... [pred_chain_N] ] 428 | 429 | 430 | pred_mu = torch.cat(pred_mu).squeeze() # cat list of tensors to single prediciton tensor with samples in first dim 431 | pred_log_std = torch.cat(pred_log_std).squeeze() # cat list of tensors to single prediciton tensor with samples in first dim 432 | 433 | mu = pred_mu.squeeze() 434 | std = F.softplus(pred_log_std).squeeze() 435 | 436 | epistemic = mu.std(dim=0) 437 | aleatoric = (std**2).mean(dim=0)**0.5 438 | total_std = (epistemic**2 + aleatoric**2)**0.5 439 | 440 | mu = mu.mean(dim=0) 441 | std = std.mean(dim=0) 442 | 443 | data.squeeze_() 444 | 445 | if plot: 446 | 447 | fig, axs = plt.subplots(2,2, sharex=True, sharey=True) 448 | axs = axs.flatten() 449 | 450 | axs[0].scatter(self.data, self.target, alpha=1, s=1, color='blue') 451 | axs[0].plot(data.squeeze(), mu, alpha=1., color='red') 452 | axs[0].fill_between(data, mu+total_std, mu-total_std, color='red', alpha=0.25) 453 | axs[0].fill_between(data, mu+2*total_std, mu-2*total_std, color='red', alpha=0.10) 454 | axs[0].fill_between(data, mu+3*total_std, mu-3*total_std, color='red', alpha=0.05) 455 | 456 | [axs[1].plot(data, pred, alpha=0.1, color='red') for pred in pred_mu] 457 | axs[1].scatter(self.data, self.target, alpha=1, s=1, color='blue') 458 | 459 | axs[2].scatter(self.data, self.target, alpha=1, s=1, color='blue') 460 | axs[2].plot(data, mu, color='red') 461 | axs[2].fill_between(data, mu-aleatoric, mu+aleatoric, color='red', alpha=0.25, label='Aleatoric') 462 | axs[2].legend() 463 | 464 | axs[3].scatter(self.data, self.target, alpha=1, s=1, color='blue') 465 | axs[3].plot(data, mu, color='red') 466 | axs[3].fill_between(data, mu-epistemic, mu+epistemic, color='red', alpha=0.25, label='Epistemic') 467 | axs[3].legend() 468 | 469 | plt.ylim(2*self.target.min(), 2*self.target.max()) 470 | plt.xlim(x_min, x_max) 471 | plt.show() 472 | 473 | 474 | return data, mu, std 475 | 476 | 477 | if 
__name__ == '__main__': 478 | 479 | pass -------------------------------------------------------------------------------- /src/MCMC_Acceptance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 4 | 5 | class MetropolisHastingsAcceptance(): 6 | 7 | def __init__(self): 8 | 9 | pass 10 | 11 | def __call__(self, log_prob_proposal, log_prob_state): 12 | ''' 13 | accept = min ( p(x') / p(x) , 1) 14 | log_accept = min( log_p(x') - log_p(x) , 1) 15 | log_accept = min (log_ratio, 1) 16 | ''' 17 | 18 | 19 | if not torch.isnan(log_prob_proposal) or not torch.isinf(log_prob_proposal): 20 | log_ratio = (log_prob_proposal - log_prob_state) 21 | log_ratio = torch.min(log_ratio, torch.zeros_like(log_ratio)) 22 | 23 | log_u = torch.zeros_like(log_ratio).uniform_(0,1).log() 24 | 25 | log_accept = torch.gt(log_ratio, log_u) 26 | log_accept = log_accept.bool().item() 27 | 28 | return log_accept, log_ratio 29 | 30 | elif torch.isnan(log_prob_proposal) or torch.isinf(log_prob_proposal): 31 | exit(f'log_prob_proposal is nan or inf {log_prob_proposal}') 32 | return False, torch.Tensor([-1]) 33 | 34 | class SDE_Acceptance(): 35 | 36 | def __init__(self): 37 | 38 | pass 39 | 40 | def __call__(self, log_prob_proposal, log_prob_state): 41 | 42 | return True, torch.Tensor([0.]) -------------------------------------------------------------------------------- /src/MCMC_Chain.py: -------------------------------------------------------------------------------- 1 | import future, sys, os, datetime, argparse, copy, warnings, time 2 | from collections import MutableSequence, Iterable, OrderedDict 3 | from itertools import compress 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | from matplotlib.lines import Line2D 10 | matplotlib.rcParams["figure.figsize"] = [10, 10] 11 | 12 | import torch 13 | from torch.nn import Module, Parameter 14 | from torch.nn import Linear, Tanh, ReLU 15 | import torch.nn.functional as F 16 | 17 | Tensor = torch.Tensor 18 | FloatTensor = torch.FloatTensor 19 | 20 | torch.set_printoptions(precision=4, sci_mode=False) 21 | np.set_printoptions(precision=4, suppress=True) 22 | 23 | sys.path.append("../../..") # Up to -> KFAC -> Optimization -> PHD 24 | from pytorch_MCMC.src.MCMC_ProbModel import ProbModel 25 | from pytorch_MCMC.src.MCMC_Optim import SGLD_Optim, MetropolisHastings_Optim, MALA_Optim, HMC_Optim, SGNHT_Optim 26 | from pytorch_MCMC.src.MCMC_Acceptance import SDE_Acceptance, MetropolisHastingsAcceptance 27 | from Utils.Utils import RunningAverageMeter 28 | 29 | ''' 30 | Python Container Time Complexity: https://wiki.python.org/moin/TimeComplexity 31 | ''' 32 | 33 | 34 | class Chain(MutableSequence): 35 | 36 | ''' 37 | A container for storing the MCMC chain conveniently: 38 | samples: list of state_dicts 39 | log_probs: list of log_probs 40 | accepts: list of bools 41 | state_idx: 42 | init index of last accepted via np.where(accepts==True)[0][-1] 43 | can be set via len(samples) while sampling 44 | 45 | @property 46 | samples: filters the samples 47 | 48 | 49 | ''' 50 | 51 | def __init__(self, probmodel=None): 52 | 53 | super().__init__() 54 | 55 | if probmodel is None: 56 | ''' 57 | Create an empty chain 58 | ''' 59 | self.state_dicts = [] 60 | self.log_probs = [] 61 | self.accepts = [] 62 | 63 | if probmodel is not None: 64 | ''' 65 | Initialize chain with given model 66 | ''' 67 | assert 
isinstance(probmodel, ProbModel) 68 | 69 | self.state_dicts = [copy.deepcopy(probmodel.state_dict())] 70 | log_prob = probmodel.log_prob(*next(probmodel.dataloader.__iter__())) 71 | log_prob['log_prob'].detach_() 72 | self.log_probs = [copy.deepcopy(log_prob)] 73 | self.accepts = [True] 74 | self.last_accepted_idx = 0 75 | 76 | self.running_avgs = {} 77 | for key, value in log_prob.items(): 78 | self.running_avgs.update({key: RunningAverageMeter(0.99)}) 79 | 80 | self.running_accepts = RunningAverageMeter(0.999) 81 | 82 | def __len__(self): 83 | return len(self.state_dicts) 84 | 85 | def __iter__(self): 86 | return zip(self.state_dicts, self.log_probs, self.accepts) 87 | 88 | def __delitem__(self): 89 | raise NotImplementedError 90 | 91 | def __setitem__(self): 92 | raise NotImplementedError 93 | 94 | def insert(self): 95 | raise NotImplementedError 96 | 97 | def __repr__(self): 98 | return f'MCMC Chain: Length:{len(self)} Accept:{self.accept_ratio:.2f}' 99 | 100 | def __getitem__(self, i): 101 | chain = copy.deepcopy(self) 102 | chain.state_dicts = self.samples[i] 103 | chain.log_probs = self.log_probs[i] 104 | chain.accepts = self.accepts[i] 105 | return chain 106 | 107 | def __add__(self, other): 108 | 109 | if type(other) in [tuple, list]: 110 | assert len(other) == 3, f"Invalid number of information pieces passed: {len(other)} vs len(Iterable(model, log_prob, accept, ratio))==4" 111 | self.append(*other) 112 | elif isinstance(other, Chain): 113 | self.cat(other) 114 | 115 | return self 116 | 117 | def __iadd__(self, other): 118 | 119 | if type(other) in [tuple, list]: 120 | assert len(other)==3, f"Invalid number of information pieces passed: {len(other)} vs len(Iterable(model, log_prob, accept, ratio))==4" 121 | self.append(*other) 122 | elif isinstance(other, Chain): 123 | self.cat_chains(other) 124 | 125 | return self 126 | 127 | @property 128 | def state_idx(self): 129 | ''' 130 | Returns the index of the last accepted sample a.k.a. 
the state of the chain 131 | 132 | ''' 133 | if not hasattr(self, 'state_idx'): 134 | ''' 135 | If the chain hasn't a state_idx, compute it from self.accepts by taking the last True of self.accepts 136 | ''' 137 | self.last_accepted_idx = np.where(self.accepts==True)[0][-1] 138 | return self.last_accepted_idx 139 | else: 140 | ''' 141 | Check that the state of the chain is actually the last True in self.accepts 142 | ''' 143 | last_accepted_sample_ = np.where(self.accepts == True)[0][-1] 144 | assert last_accepted_sample_ == self.last_accepted_idx 145 | assert self.accepts[self.last_accepted_idx]==True 146 | return self.last_accepted_idx 147 | 148 | 149 | @property 150 | def samples(self): 151 | ''' 152 | Filters the list of state_dicts with the list of bools from self.accepts 153 | :return: list of accepted state_dicts 154 | ''' 155 | return list(compress(self.state_dicts, self.accepts)) 156 | 157 | @property 158 | def accept_ratio(self): 159 | ''' 160 | Sum the boolean list (=total number of Trues) and divides it by its length 161 | :return: float valued accept ratio 162 | ''' 163 | return sum(self.accepts)/len(self.accepts) 164 | 165 | @property 166 | def state(self): 167 | return {'state_dict': self.state_dicts[self.last_accepted_idx], 'log_prob': self.log_probs[self.last_accepted_idx]} 168 | 169 | def cat_chains(self, other): 170 | 171 | assert isinstance(other, Chain) 172 | self.state_dicts += other.state_dicts 173 | self.log_probs += other.log_probs 174 | self.accepts += other.accepts 175 | 176 | for key, value in other.running_avgs.items(): 177 | self.running_avgs[key].avg = 0.5*self.running_avgs[key].avg + 0.5 * other.running_avgs[key].avg 178 | 179 | 180 | def append(self, probmodel, log_prob, accept): 181 | 182 | if isinstance(probmodel, ProbModel): 183 | params_state_dict = copy.deepcopy(probmodel.state_dict()) 184 | elif isinstance(probmodel, OrderedDict): 185 | params_state_dict = copy.deepcopy(probmodel) 186 | assert isinstance(log_prob, dict) 187 | assert type(log_prob['log_prob'])==torch.Tensor 188 | assert log_prob['log_prob'].numel()==1 189 | 190 | log_prob['log_prob'].detach_() 191 | 192 | 193 | self.accepts.append(accept) 194 | self.running_accepts.update(1 * accept) 195 | 196 | if accept: 197 | self.state_dicts.append(params_state_dict) 198 | self.log_probs.append(copy.deepcopy(log_prob)) 199 | self.last_accepted_idx = len(self.state_dicts)-1 200 | for key, value in log_prob.items(): 201 | self.running_avgs[key].update(value.item()) 202 | 203 | elif not accept: 204 | self.state_dicts.append(False) 205 | self.log_probs.append(False) 206 | 207 | class Sampler_Chain: 208 | 209 | def __init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune): 210 | 211 | self.probmodel = probmodel 212 | self.chain = Chain(probmodel=self.probmodel) 213 | 214 | self.step_size = step_size 215 | self.num_steps = num_steps 216 | self.burn_in = burn_in 217 | 218 | self.pretrain = pretrain 219 | self.tune = tune 220 | 221 | def propose(self): 222 | raise NotImplementedError 223 | 224 | def __repr__(self): 225 | raise NotImplementedError 226 | 227 | def tune_step_size(self): 228 | 229 | tune_interval_length = 100 230 | num_tune_intervals = int(self.burn_in // tune_interval_length) 231 | 232 | verbose = True 233 | 234 | print(f'Tuning: Init Step Size: {self.optim.param_groups[0]["step_size"]:.5f}') 235 | 236 | self.probmodel.reset_parameters() 237 | tune_chain = Chain(probmodel=self.probmodel) 238 | tune_chain.running_accepts.momentum = 0.5 239 | 240 | progress = 
tqdm(range(self.burn_in)) 241 | for tune_step in progress: 242 | 243 | 244 | 245 | sample_log_prob, sample = self.propose() 246 | accept, log_ratio = self.acceptance(sample_log_prob['log_prob'], self.chain.state['log_prob']['log_prob']) 247 | tune_chain += (self.probmodel, sample_log_prob, accept) 248 | 249 | # if tune_step < self.burn_in and tune_step % tune_interval_length == 0 and tune_step > 0: 250 | if tune_step > 1: 251 | # self.optim.dual_average_tune(tune_chain, np.exp(log_ratio.item())) 252 | self.optim.dual_average_tune(tune_chain.accepts[-tune_interval_length:], tune_step, np.exp(log_ratio.item())) 253 | # self.optim.tune(tune_chain.accepts[-tune_interval_length:]) 254 | 255 | if not accept: 256 | 257 | if torch.isnan(sample_log_prob['log_prob']): 258 | print(self.chain.state) 259 | exit() 260 | self.probmodel.load_state_dict(self.chain.state['state_dict']) 261 | 262 | desc = f'Tuning: Accept: {tune_chain.running_accepts.avg:.2f}/{tune_chain.accept_ratio:.2f} StepSize: {self.optim.param_groups[0]["step_size"]:.5f}' 263 | 264 | progress.set_description( 265 | desc=desc) 266 | 267 | 268 | 269 | time.sleep(0.1) # for cleaner printing in the console 270 | 271 | def sample_chain(self): 272 | 273 | self.probmodel.reset_parameters() 274 | 275 | if self.pretrain: 276 | try: 277 | self.probmodel.pretrain() 278 | except: 279 | warnings.warn(f'Tried pretraining but couldnt find a probmodel.pretrain() method ... Continuing wihtout pretraining.') 280 | 281 | if self.tune: 282 | self.tune_step_size() 283 | 284 | # print(f"After Tuning Step Size: {self.optim.param_groups[0]['step_size']=}") 285 | 286 | self.chain = Chain(probmodel=self.probmodel) 287 | 288 | progress = tqdm(range(self.num_steps)) 289 | for step in progress: 290 | 291 | proposal_log_prob, sample = self.propose() 292 | accept, log_ratio = self.acceptance(proposal_log_prob['log_prob'], self.chain.state['log_prob']['log_prob']) 293 | self.chain += (self.probmodel, proposal_log_prob, accept) 294 | 295 | if not accept: 296 | 297 | if torch.isnan(proposal_log_prob['log_prob']): 298 | print(self.chain.state) 299 | exit() 300 | self.probmodel.load_state_dict(self.chain.state['state_dict']) 301 | 302 | desc = f'{str(self)}: Accept: {self.chain.running_accepts.avg:.2f}/{self.chain.accept_ratio:.2f} \t' 303 | for key, running_avg in self.chain.running_avgs.items(): 304 | desc += f' {key}: {running_avg.avg:.2f} ' 305 | desc += f'StepSize: {self.optim.param_groups[0]["step_size"]:.3f}' 306 | # desc +=f" Std: {F.softplus(self.probmodel.log_std.detach()).item():.3f}" 307 | progress.set_description(desc=desc) 308 | 309 | self.chain = self.chain[self.burn_in:] 310 | 311 | return self.chain 312 | 313 | class SGLD_Chain(Sampler_Chain): 314 | 315 | def __init__(self, probmodel, step_size=0.0001, num_steps=2000, burn_in=100, pretrain=False, tune=False): 316 | 317 | Sampler_Chain.__init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune) 318 | 319 | self.optim = SGLD_Optim(self.probmodel, 320 | step_size=step_size, 321 | prior_std=1., 322 | addnoise=True) 323 | 324 | self.acceptance = SDE_Acceptance() 325 | 326 | def __repr__(self): 327 | return 'SGLD' 328 | 329 | @torch.enable_grad() 330 | def propose(self): 331 | 332 | self.optim.zero_grad() 333 | batch = next(self.probmodel.dataloader.__iter__()) 334 | log_prob = self.probmodel.log_prob(*batch) 335 | (-log_prob['log_prob']).backward() 336 | self.optim.step() 337 | 338 | return log_prob, self.probmodel 339 | 340 | class MALA_Chain(Sampler_Chain): 341 | 342 | def __init__(self, 
probmodel, step_size=0.1, num_steps=2000, burn_in=100, pretrain=False, tune=False, num_chain=0): 343 | 344 | Sampler_Chain.__init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune) 345 | 346 | self.num_chain = num_chain 347 | 348 | self.optim = MALA_Optim(self.probmodel, 349 | step_size=step_size, 350 | prior_std=1., 351 | addnoise=True) 352 | 353 | self.acceptance = MetropolisHastingsAcceptance() 354 | # self.acceptance = SDE_Acceptance() 355 | 356 | def __repr__(self): 357 | return 'MALA' 358 | 359 | @torch.enable_grad() 360 | def propose(self): 361 | 362 | self.optim.zero_grad() 363 | batch = next(self.probmodel.dataloader.__iter__()) 364 | log_prob = self.probmodel.log_prob(*batch) 365 | (-log_prob['log_prob']).backward() 366 | self.optim.step() 367 | 368 | return log_prob, self.probmodel 369 | 370 | class HMC_Chain(Sampler_Chain): 371 | 372 | def __init__(self, probmodel, step_size=0.0001, num_steps=2000, burn_in=100, pretrain=False, tune=False, 373 | traj_length=20): 374 | 375 | # assert probmodel.log_prob().keys()[:3] == ['log_prob', 'data', ] 376 | 377 | Sampler_Chain.__init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune) 378 | 379 | self.traj_length = traj_length 380 | 381 | self.optim = HMC_Optim(self.probmodel, 382 | step_size=step_size, 383 | prior_std=1.) 384 | 385 | # self.acceptance = SDE_Acceptance() 386 | self.acceptance = MetropolisHastingsAcceptance() 387 | 388 | def __repr__(self): 389 | return 'HMC' 390 | 391 | def sample_chain(self): 392 | 393 | self.probmodel.reset_parameters() 394 | 395 | if self.pretrain: 396 | try: 397 | self.probmodel.pretrain() 398 | except: 399 | warnings.warn(f'Tried pretraining but couldnt find a probmodel.pretrain() method ... Continuing wihtout pretraining.') 400 | 401 | if self.tune: self.tune_step_size() 402 | 403 | self.chain = Chain(probmodel=self.probmodel) 404 | 405 | progress = tqdm(range(self.num_steps)) 406 | for step in progress: 407 | 408 | _ = self.propose() # values are added directly to self.chain 409 | 410 | desc = f'{str(self)}: Accept: {self.chain.running_accepts.avg:.2f}/{self.chain.accept_ratio:.2f} \t' 411 | for key, running_avg in self.chain.running_avgs.items(): 412 | desc += f' {key}: {running_avg.avg:.2f} ' 413 | desc += f'StepSize: {self.optim.param_groups[0]["step_size"]:.3f}' 414 | progress.set_description(desc=desc) 415 | 416 | self.chain = self.chain[self.burn_in:] 417 | 418 | return self.chain 419 | 420 | def propose(self): 421 | ''' 422 | 1) sample momentum for each parameter 423 | 2) sample one minibatch for an entire trajectory 424 | 3) solve trajectory forward for self.traj_length steps 425 | ''' 426 | 427 | hamiltonian_solver = ['euler', 'leapfrog'][0] 428 | 429 | self.optim.sample_momentum() 430 | batch = next(self.probmodel.dataloader.__iter__()) # samples one minibatch from dataloader 431 | 432 | def closure(): 433 | ''' 434 | Computes the gradients once for batch 435 | ''' 436 | self.optim.zero_grad() 437 | log_prob = self.probmodel.log_prob(*batch) 438 | (-log_prob['log_prob']).backward() 439 | return log_prob 440 | 441 | if hamiltonian_solver=='leapfrog': log_prob = closure() # compute initial grads 442 | 443 | for traj_step in range(self.traj_length): 444 | if hamiltonian_solver=='euler': 445 | proposal_log_prob = closure() 446 | self.optim.step() 447 | elif hamiltonian_solver=='leapfrog': 448 | proposal_log_prob = self.optim.leapfrog_step(closure) 449 | 450 | accept, log_ratio = self.acceptance(proposal_log_prob['log_prob'], self.chain.state['log_prob']['log_prob']) 
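# Note: the Metropolis check above compares only the potential term (the proposal's
# log-prob against the current chain state's log-prob); the kinetic energy of the
# freshly sampled momentum does not enter the ratio, so this is a simplified
# acceptance test rather than the full Hamiltonian (potential + kinetic) ratio.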
451 | 452 | if not accept: 453 | if torch.isnan(proposal_log_prob['log_prob']): 454 | print(f"{proposal_log_prob=}") 455 | print(self.chain.state) 456 | exit() 457 | self.probmodel.load_state_dict(self.chain.state['state_dict']) 458 | 459 | self.chain += (self.probmodel, proposal_log_prob, accept) 460 | 461 | class SGNHT_Chain(Sampler_Chain): 462 | 463 | def __init__(self, probmodel, step_size=0.0001, num_steps=2000, burn_in=100, pretrain=False, tune=False, 464 | traj_length=20): 465 | 466 | # assert probmodel.log_prob().keys()[:3] == ['log_prob', 'data', ] 467 | 468 | Sampler_Chain.__init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune) 469 | 470 | self.traj_length = traj_length 471 | 472 | self.optim = SGNHT_Optim(self.probmodel, 473 | step_size=step_size, 474 | prior_std=1.) 475 | 476 | # print(f"{self.optim.A=}") 477 | # print(f"{self.optim.num_params=}") 478 | # print(f"{self.optim.A=}") 479 | # exit() 480 | 481 | # self.acceptance = SDE_Acceptance() 482 | self.acceptance = MetropolisHastingsAcceptance() 483 | 484 | def __repr__(self): 485 | return 'SGNHT' 486 | 487 | def sample_chain(self): 488 | 489 | self.probmodel.reset_parameters() 490 | 491 | if self.pretrain: 492 | try: 493 | self.probmodel.pretrain() 494 | except: 495 | warnings.warn(f'Tried pretraining but couldnt find a probmodel.pretrain() method ... Continuing wihtout pretraining.') 496 | 497 | if self.tune: self.tune_step_size() 498 | 499 | self.chain = Chain(probmodel=self.probmodel) 500 | self.optim.sample_momentum() 501 | self.optim.sample_thermostat() 502 | 503 | progress = tqdm(range(self.num_steps)) 504 | for step in progress: 505 | 506 | proposal_log_prob, sample = self.propose() 507 | accept, log_ratio = self.acceptance(proposal_log_prob['log_prob'], self.chain.state['log_prob']['log_prob']) 508 | self.chain += (self.probmodel, proposal_log_prob, accept) 509 | 510 | desc = f'{str(self)}: Accept: {self.chain.running_accepts.avg:.2f}/{self.chain.accept_ratio:.2f} \t' 511 | for key, running_avg in self.chain.running_avgs.items(): 512 | desc += f' {key}: {running_avg.avg:.2f} ' 513 | desc += f'StepSize: {self.optim.param_groups[0]["step_size"]:.3f}' 514 | progress.set_description(desc=desc) 515 | 516 | self.chain = self.chain[self.burn_in:] 517 | 518 | return self.chain 519 | 520 | def propose(self): 521 | ''' 522 | 1) sample momentum for each parameter 523 | 2) sample one minibatch for an entire trajectory 524 | 3) solve trajectory forward for self.traj_length steps 525 | ''' 526 | 527 | hamiltonian_solver = ['euler', 'leapfrog'][0] 528 | 529 | # self.optim.sample_momentum() 530 | # self.optim.sample_thermostat() 531 | batch = next(self.probmodel.dataloader.__iter__()) # samples one minibatch from dataloader 532 | 533 | self.optim.zero_grad() 534 | proposal_log_prob = self.probmodel.log_prob(*batch) 535 | (-proposal_log_prob['log_prob']).backward() 536 | self.optim.step() 537 | 538 | # def closure(): 539 | # ''' 540 | # Computes the gradients once for batch 541 | # ''' 542 | # self.optim.zero_grad() 543 | # log_prob = self.probmodel.log_prob(*batch) 544 | # (-log_prob['log_prob']).backward() 545 | # return log_prob 546 | # 547 | # if hamiltonian_solver=='leapfrog': log_prob = closure() # compute initial grads 548 | # 549 | # for traj_step in range(self.traj_length): 550 | # if hamiltonian_solver=='euler': 551 | # proposal_log_prob = closure() 552 | # self.optim.step() 553 | # elif hamiltonian_solver=='leapfrog': 554 | # proposal_log_prob = self.optim.leapfrog_step(closure) 555 | # 556 | # accept, 
log_ratio = self.acceptance(proposal_log_prob['log_prob'], self.chain.state['log_prob']['log_prob']) 557 | # 558 | # if not accept: 559 | # if torch.isnan(proposal_log_prob['log_prob']): 560 | # print(f"{proposal_log_prob=}") 561 | # print(self.chain.state) 562 | # exit() 563 | # self.probmodel.load_state_dict(self.chain.state['state_dict']) 564 | 565 | return proposal_log_prob, self.probmodel 566 | 567 | -------------------------------------------------------------------------------- /src/MCMC_Optim.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from tqdm import tqdm 3 | from collections import MutableSequence 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch.optim import Optimizer, SGD 11 | from torch.distributions.distribution import Distribution 12 | from torch.distributions import MultivariateNormal, Normal 13 | from torch.nn import Module 14 | 15 | class MCMC_Optim: 16 | 17 | def __init__(self): 18 | 19 | self.tune_params = {'delta': 0.65, 20 | 't0': 10, 21 | 'gamma': .05, 22 | 'kappa': .75, 23 | # 'mu': np.log(self.param_groups[0]["step_size"]), 24 | 'mu': 0., 25 | 'H': 0, 26 | 'log_eps': 1.} 27 | 28 | # print(f"@MCMC_Optim {self.tune_params=}") 29 | # exit() 30 | 31 | def tune(self, accepts): 32 | 33 | ''' 34 | PyMC: 35 | # Switch statement 36 | if acceptance < 0.001: 0.1 37 | elif acceptance < 0.05: 0.5 38 | elif acceptance < 0.2: 0.9 39 | elif acceptance > 0.95: 10 40 | elif acceptance > 0.75: 2 41 | elif acceptance > 0.5: 1.1 42 | ''' 43 | 44 | avg_acc = sum(accepts)/len(accepts) 45 | 46 | ''' 47 | Switch statement: the first condition that is met exits the switch statement 48 | ''' 49 | if avg_acc < 0.001: 50 | scale = 0.1 51 | elif avg_acc < 0.05: 52 | scale = 0.5 53 | elif avg_acc < 0.20: 54 | # PyMC: 0.9 55 | scale = 0.5 56 | 57 | elif avg_acc > 0.99: 58 | scale = 10. 59 | elif avg_acc > 0.75: 60 | scale = 2. 61 | elif avg_acc > 0.5: 62 | # PyMC: 1.1 63 | scale = 1.1 64 | else: 65 | scale = 0.9 66 | 67 | for group in self.param_groups: 68 | 69 | group['step_size']*=scale 70 | # print(f'{avg_acc=:.3f} & {scale=} -> {group["lr"]=:.3f}') 71 | 72 | def dual_average_tune(self, accepts, t, alpha): 73 | ''' 74 | NUTS Sampler p.17 : Algorithm 5 75 | 76 | mu = log(10 * initial_step_size) 77 | 78 | H_m : running difference between target acceptance rate and current acceptance rate 79 | delta : target acceptance rate 80 | alpha : (singular) current acceptance rate 81 | 82 | log_eps = mu - t**0.5 / gamma H_m 83 | running_log_eps = t**(-kappa) log_eps + (1 - t**(-kappa)) running_log_eps 84 | ''' 85 | 86 | # accept_ratio = sum(accepts)/len(accepts) 87 | assert 0 KFAC -> Optimization -> PHD 25 | 26 | import scipy 27 | import scipy as sp 28 | from scipy.io import loadmat as sp_loadmat 29 | import copy 30 | 31 | class ProbModel(torch.nn.Module): 32 | 33 | ''' 34 | ProbModel: 35 | 36 | ''' 37 | 38 | def __init__(self, dataloader): 39 | super().__init__() 40 | assert isinstance(dataloader, torch.utils.data.DataLoader) 41 | self.dataloader = dataloader 42 | 43 | 44 | def log_prob(self): 45 | ''' 46 | If minibatches have to be sampled due to memory constraints, 47 | a standard PyTorch dataloader can be used. 
48 | "Infinite minibatch sampling" can be achieved by calling: 49 | data, target = next(dataloader.__iter__()) 50 | next(Iterable.__iter__()) calls a single mini-batch sampling step 51 | But since it's not in a loop, we can call it add infinum 52 | ''' 53 | raise NotImplementedError 54 | 55 | def sample_minibatch(self): 56 | ''' 57 | Idea: 58 | Hybrid Monte Carlo Samplers require a constant tuple (data, target) to compute trajectories 59 | ''' 60 | raise NotImplementedError 61 | 62 | def reset_parameters(self): 63 | raise NotImplementedError 64 | 65 | def predict(self, chain): 66 | raise NotImplementedError 67 | 68 | def pretrain(self): 69 | 70 | pass -------------------------------------------------------------------------------- /src/MCMC_Sampler.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys, copy, time 3 | from tqdm import tqdm 4 | from collections import MutableSequence, OrderedDict 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | 8 | import numpy as np 9 | 10 | from pytorch_MCMC.src.MCMC_Utils import posterior_dist 11 | from pytorch_MCMC.src.MCMC_Optim import SGLD_Optim, MetropolisHastings_Optim 12 | from pytorch_MCMC.src.MCMC_Acceptance import MetropolisHastingsAcceptance 13 | from pytorch_MCMC.src.MCMC_Chain import Chain, SGLD_Chain, MALA_Chain, HMC_Chain, SGNHT_Chain 14 | from pytorch_MCMC.src.MCMC_Acceptance import MetropolisHastingsAcceptance, SDE_Acceptance 15 | from pytorch_MCMC.src.MCMC_ProbModel import ProbModel 16 | 17 | from joblib import Parallel, delayed 18 | import concurrent.futures 19 | 20 | 21 | import torch 22 | from torch.nn import Parameter 23 | from torch.optim import SGD 24 | from torch.optim.lr_scheduler import StepLR 25 | from torch.distributions.distribution import Distribution 26 | from torch.distributions import MultivariateNormal, Normal 27 | from torch.nn import Module 28 | 29 | 30 | class Sampler: 31 | 32 | def __init__(self, probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune): 33 | 34 | self.probmodel = probmodel 35 | self.chain = None 36 | self.num_chains = num_chains 37 | 38 | self.step_size = step_size 39 | self.num_steps = num_steps 40 | self.burn_in = burn_in 41 | 42 | self.pretrain = pretrain 43 | self.tune = tune 44 | 45 | test_log_prob = self.probmodel.log_prob(*next(self.probmodel.dataloader.__iter__())) 46 | assert type(test_log_prob) == dict 47 | assert list(test_log_prob.keys())[0]=='log_prob' 48 | 49 | def sample_chains(self): 50 | raise NotImplementedError 51 | 52 | def __str__(self): 53 | raise NotImplementedError 54 | 55 | def multiprocessing_test(self, wait_time): 56 | 57 | time.sleep(wait_time) 58 | print(f'Done after {wait_time=} seconds') 59 | 60 | def sample_independent_chain(self): 61 | 62 | probmodel = copy.deepcopy(self.probmodel) 63 | probmodel.reset_parameters() 64 | 65 | if self.pretrain: 66 | probmodel.pretrain() 67 | 68 | optim = SGLD_Optim(probmodel, step_size=self.step_size, prior_std=0., addnoise=True) 69 | chain = Chain(probmodel=probmodel) 70 | 71 | progress = tqdm(range(self.num_steps)) 72 | for step in progress: 73 | 74 | sample_log_prob, sample = self.propose(probmodel, optim) 75 | 76 | accept, log_ratio = self.acceptance(sample_log_prob['log_prob'], chain.state['log_prob']) 77 | 78 | chain += (probmodel, sample_log_prob, accept, step) 79 | 80 | if not accept: 81 | probmodel.load_state_dict(chain.state['state']) 82 | 83 | desc = f'{str(self)}: Accept: {chain.running_accepts.avg:.2f}/{chain.accept_ratio:.2f} \t' 84 | for key, 
running_avg in chain.running_avgs.items(): 85 | # print(f'{key}: {running_avg.avg=}') 86 | desc += f' {key}: {running_avg.avg:.2f} ' 87 | desc += f'StepSize: {optim.param_groups[0]["lr"]:.3f}' 88 | progress.set_description(desc=desc) 89 | 90 | # print(list(chain.samples[-1].values())[-1][0]) 91 | ''' 92 | Remove Burn_in 93 | ''' 94 | assert len(chain.accepted_steps) > self.burn_in, f'{len(chain.accepted_steps)=} <= {self.burn_in=}' 95 | chain.accepted_steps = chain.accepted_steps[self.burn_in:] 96 | 97 | return chain 98 | 99 | def sample_chain(self, step_size=None): 100 | 101 | if self.pretrain: 102 | self.probmodel.pretrain() 103 | 104 | self.optim = SGLD_Optim(self.probmodel, 105 | step_size=step_size, 106 | prior_std=0., 107 | addnoise=True) 108 | 109 | if self.tune: self.tune_step_size() 110 | 111 | self.chain = Chain(probmodel=self.probmodel) 112 | 113 | progress = tqdm(range(self.num_steps)) 114 | for step in progress: 115 | 116 | sample_log_prob, sample = self.propose() 117 | 118 | accept, log_ratio = self.acceptance(sample_log_prob['log_prob'], self.chain.state['log_prob']) 119 | 120 | self.chain += (self.probmodel, sample_log_prob, accept, step) 121 | 122 | if not accept: 123 | self.probmodel.load_state_dict(self.chain.state['state']) 124 | 125 | desc = f'{str(self)}: Accept: {self.chain.running_accepts.avg:.2f}/{self.chain.accept_ratio:.2f} \t' 126 | for key, running_avg in self.chain.running_avgs.items(): 127 | # print(f'{key}: {running_avg.avg=}') 128 | desc+= f' {key}: {running_avg.avg:.2f} ' 129 | desc += f'StepSize: {self.optim.param_groups[0]["lr"]:.3f}' 130 | progress.set_description(desc=desc) 131 | 132 | print(len(self.chain)) 133 | ''' 134 | Remove Burn_in 135 | ''' 136 | assert len(self.chain.accepted_steps)>self.burn_in, f'{len(self.chain.accepted_steps)=} <= {self.burn_in=}' 137 | self.chain.accepted_steps = self.chain.accepted_steps[self.burn_in:] 138 | 139 | def posterior_dist(self, param=None, verbose=False, plot=True): 140 | 141 | if len(self.probmodel.state_dict())==1: 142 | ''' 143 | We're sampling from a predefined distribution like a GMM and simulating a particle 144 | ''' 145 | post = [] 146 | 147 | accepted_models = self.chain.samples 148 | for model_state_dict in accepted_models: 149 | post.append(list(model_state_dict.values())[0]) 150 | 151 | post = torch.cat(post, dim=0) 152 | 153 | if plot: 154 | hist2d = plt.hist2d(x=post[:, 0].cpu().numpy(), y=post[:, 1].cpu().numpy(), bins=100, range=np.array([[-3, 3], [-3, 3]]), 155 | density=True) 156 | plt.colorbar(hist2d[3]) 157 | plt.show() 158 | 159 | elif len(self.probmodel.state_dict()) > 1: 160 | ''' 161 | There is more than one parameter in the model 162 | ''' 163 | 164 | param_names = list(self.probmodel.state_dict().keys()) 165 | accepted_models = self.chain.samples 166 | 167 | for param_name in param_names: 168 | 169 | post = [] 170 | 171 | for model_state_dict in accepted_models: 172 | 173 | post.append(model_state_dict[param_name]) 174 | 175 | post = torch.cat(post) 176 | # print(post) 177 | 178 | if plot: 179 | plt.hist(x=post, bins=50, 180 | range=np.array([-3, 3]), 181 | density=True, 182 | alpha=0.5) 183 | plt.title(param_name) 184 | plt.show() 185 | 186 | def trace(self, param=None, verbose=False, plot=True): 187 | 188 | if len(self.probmodel.state_dict()) >= 1: 189 | ''' 190 | There is more than one parameter in the model 191 | ''' 192 | 193 | param_names = list(self.probmodel.state_dict().keys()) 194 | accepted_models = [self.chain.samples[idx] for idx in self.chain.accepted_steps] 195 | 196 | 
for param_name in param_names: 197 | 198 | post = [] 199 | 200 | for model_state_dict in accepted_models: 201 | post.append(model_state_dict[param_name]) 202 | 203 | # print(post) 204 | 205 | post = torch.cat(post) 206 | # print(post) 207 | 208 | if plot: 209 | plt.plot(np.arange(len(accepted_models)), post) 210 | plt.title(param_name) 211 | 212 | plt.show() 213 | 214 | class MetropolisHastings_Sampler(Sampler): 215 | 216 | def __init__(self, probmodel, step_size=1., num_steps=10000, burn_in=100, pretrain=False, tune=True): 217 | ''' 218 | 219 | :param probmodel: Probmodel() that implements forward, log_prob, prob and sample 220 | :param step_length: 221 | :param num_steps: 222 | :param burn_in: 223 | ''' 224 | 225 | assert isinstance(probmodel, ProbModel) 226 | super().__init__(probmodel, step_size, num_steps, burn_in, pretrain, tune) 227 | 228 | self.optim = MetropolisHastings_Optim(self.probmodel, 229 | step_length=step_size) 230 | 231 | self.acceptance = MetropolisHastingsAcceptance() 232 | 233 | def __str__(self): 234 | return 'MH' 235 | 236 | @torch.no_grad() 237 | def propose(self): 238 | self.optim.step() 239 | log_prob = self.probmodel.log_prob() 240 | 241 | return log_prob, self.probmodel 242 | 243 | class SGLD_Sampler(Sampler): 244 | 245 | def __init__(self, probmodel, step_size=0.01, num_steps=10000, num_chains=7, burn_in=500, pretrain=True, tune=True): 246 | ''' 247 | 248 | :param probmodel: Probmodel() that implements forward, log_prob, prob and sample 249 | :param step_length: 250 | :param num_steps: 251 | :param burn_in: 252 | ''' 253 | 254 | assert isinstance(probmodel, ProbModel) 255 | Sampler.__init__(self, probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune) 256 | 257 | def sample_chains(self): 258 | 259 | if self.num_chains > 1: 260 | self.parallel_chains = [SGLD_Chain(copy.deepcopy(self.probmodel), 261 | step_size=self.step_size, 262 | num_steps=self.num_steps, 263 | burn_in=self.burn_in, 264 | pretrain=self.pretrain, 265 | tune=False) 266 | for _ in range(self.num_chains)] 267 | 268 | chains = Parallel(n_jobs=self.num_chains)(delayed(chain.sample_chain)() for chain in self.parallel_chains) 269 | 270 | elif self.num_chains == 1: 271 | chain = SGLD_Chain(copy.deepcopy(self.probmodel), 272 | step_size=self.step_size, 273 | num_steps=self.num_steps, 274 | burn_in=self.burn_in, 275 | pretrain=self.pretrain, 276 | tune=False) 277 | chains = [chain.sample_chain()] 278 | 279 | self.chain = Chain(probmodel=self.probmodel) 280 | 281 | for chain in chains: 282 | self.chain += chain 283 | 284 | return chains 285 | 286 | def __str__(self): 287 | return 'SGLD' 288 | 289 | class MALA_Sampler(Sampler): 290 | 291 | def __init__(self, probmodel, step_size=0.01, num_steps=10000, num_chains=4, burn_in=500, pretrain=True, tune=True): 292 | ''' 293 | 294 | :param probmodel: Probmodel() that implements forward, log_prob, prob and sample 295 | :param step_length: 296 | :param num_steps: 297 | :param burn_in: 298 | ''' 299 | 300 | assert isinstance(probmodel, ProbModel) 301 | super().__init__(probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune) 302 | 303 | def sample_chains(self): 304 | 305 | if self.num_chains>1: 306 | self.parallel_chains = [MALA_Chain(copy.deepcopy(self.probmodel), 307 | step_size=self.step_size, 308 | num_steps=self.num_steps, 309 | burn_in=self.burn_in, 310 | pretrain=self.pretrain, 311 | tune=self.tune, 312 | num_chain=i) 313 | for i in range(self.num_chains)] 314 | 315 | chains = 
Parallel(n_jobs=self.num_chains)(delayed(chain.sample_chain)() for chain in self.parallel_chains) 316 | 317 | elif self.num_chains == 1: 318 | chain = MALA_Chain(copy.deepcopy(self.probmodel), 319 | step_size=self.step_size, 320 | num_steps=self.num_steps, 321 | burn_in=self.burn_in, 322 | pretrain=self.pretrain, 323 | tune=self.tune, 324 | num_chain=0) 325 | chains = [chain.sample_chain()] 326 | 327 | self.chain = Chain(probmodel=self.probmodel) 328 | 329 | 330 | for chain in chains: 331 | self.chain += chain 332 | 333 | return chains 334 | 335 | def __str__(self): 336 | return 'SGLD' 337 | 338 | class HMC_Sampler(Sampler): 339 | 340 | def __init__(self, probmodel, step_size=0.01, num_steps=10000, num_chains=7, burn_in=500, pretrain=True, tune=True, 341 | traj_length=21): 342 | ''' 343 | 344 | :param probmodel: Probmodel() that implements forward, log_prob, prob and sample 345 | :param step_length: 346 | :param num_steps: 347 | :param burn_in: 348 | ''' 349 | 350 | assert isinstance(probmodel, ProbModel) 351 | Sampler.__init__(self, probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune) 352 | 353 | self.traj_length = traj_length 354 | 355 | def __str__(self): 356 | return 'HMC' 357 | 358 | def sample_chains(self): 359 | 360 | if self.num_chains > 1: 361 | self.parallel_chains = [HMC_Chain(copy.deepcopy(self.probmodel), 362 | step_size=self.step_size, 363 | num_steps=self.num_steps, 364 | burn_in=self.burn_in, 365 | pretrain=self.pretrain, 366 | tune=self.tune) 367 | for i in range(self.num_chains)] 368 | 369 | chains = Parallel(n_jobs=self.num_chains)(delayed(chain.sample_chain)() for chain in self.parallel_chains) 370 | 371 | elif self.num_chains == 1: 372 | chain = HMC_Chain(copy.deepcopy(self.probmodel), 373 | step_size=self.step_size, 374 | num_steps=self.num_steps, 375 | burn_in=self.burn_in, 376 | pretrain=self.pretrain, 377 | tune=self.tune) 378 | chains = [chain.sample_chain()] 379 | 380 | self.chain = Chain(probmodel=self.probmodel) # the aggregating chain 381 | 382 | for chain in chains: 383 | self.chain += chain 384 | 385 | return chains 386 | 387 | class SGNHT_Sampler(Sampler): 388 | 389 | def __init__(self, probmodel, step_size=0.01, num_steps=10000, num_chains=7, burn_in=500, pretrain=True, tune=True, 390 | traj_length=21): 391 | ''' 392 | 393 | :param probmodel: Probmodel() that implements forward, log_prob, prob and sample 394 | :param step_length: 395 | :param num_steps: 396 | :param burn_in: 397 | ''' 398 | 399 | assert isinstance(probmodel, ProbModel) 400 | Sampler.__init__(self, probmodel, step_size, num_steps, num_chains, burn_in, pretrain, tune) 401 | 402 | self.traj_length = traj_length 403 | 404 | def __str__(self): 405 | return 'SGNHT' 406 | 407 | def sample_chains(self): 408 | 409 | if self.num_chains > 1: 410 | self.parallel_chains = [SGNHT_Chain(copy.deepcopy(self.probmodel), 411 | step_size=self.step_size, 412 | num_steps=self.num_steps, 413 | burn_in=self.burn_in, 414 | pretrain=self.pretrain, 415 | tune=self.tune) 416 | for i in range(self.num_chains)] 417 | 418 | chains = Parallel(n_jobs=self.num_chains)(delayed(chain.sample_chain)() for chain in self.parallel_chains) 419 | 420 | elif self.num_chains == 1: 421 | chain = SGNHT_Chain(copy.deepcopy(self.probmodel), 422 | step_size=self.step_size, 423 | num_steps=self.num_steps, 424 | burn_in=self.burn_in, 425 | pretrain=self.pretrain, 426 | tune=self.tune) 427 | chains = [chain.sample_chain()] 428 | 429 | self.chain = Chain(probmodel=self.probmodel) # the aggregating chain 430 | 431 | for chain 
in chains: 432 | self.chain += chain 433 | 434 | return chains -------------------------------------------------------------------------------- /src/MCMC_Utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | 5 | def posterior_dist(chain, param=None, verbose=False, plot=True): 6 | 7 | if len(chain.samples[0]) == 1: 8 | ''' 9 | We're sampling from a predefined distribution like a GMM and simulating a particle 10 | ''' 11 | post = [] 12 | 13 | # print(list(self.probmodel.state_dict().values())[0]) 14 | # exit() 15 | 16 | # accepted_models = [chain.samples[idx] for idx in chain.accepted_steps] 17 | for model_state_dict in chain.samples: 18 | post.append(list(model_state_dict.values())[0]) 19 | 20 | post = torch.cat(post, dim=0) 21 | 22 | if plot: 23 | hist2d = plt.hist2d(x=post[:, 0].cpu().numpy(), y=post[:, 1].cpu().numpy(), bins=100, range=np.array([[-3, 3], [-3, 3]]), 24 | density=True) 25 | plt.colorbar(hist2d[3]) 26 | plt.show() 27 | 28 | elif len(chain.samples[0]) > 1: 29 | ''' 30 | There is more than one parameter in the model 31 | ''' 32 | 33 | param_names = list(chain.samples[0].keys()) 34 | accepted_models = [chain.samples[idx] for idx in chain.accepted_idxs] 35 | 36 | for param_name in param_names: 37 | 38 | post = [] 39 | 40 | for model_state_dict in accepted_models: 41 | post.append(model_state_dict[param_name]) 42 | 43 | post = torch.cat(post) 44 | # print(post) 45 | 46 | if plot: 47 | plt.hist(x=post, bins=50, 48 | range=np.array([-3, 3]), 49 | density=True, 50 | alpha=0.5) 51 | plt.title(param_name) 52 | plt.show() -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Test Test Test --------------------------------------------------------------------------------
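The samplers above are normally driven through the scripts in `experiments/`. As a quick orientation, here is a minimal usage sketch of how the pieces fit together, wiring the toy `GMM` probabilistic model into the parallel `HMC_Sampler` and visualizing the accepted samples with `posterior_dist`. It is a sketch under assumptions: the `pytorch_MCMC` package is importable from the working directory's parent (as the `sys.path.append` calls in the sources suggest), the hyperparameters are illustrative rather than the values used in the experiment scripts, and `HMC_Optim` comes from `src/MCMC_Optim.py`, whose listing is truncated in this dump.

```python
import torch

from pytorch_MCMC.models.MCMC_Models import GMM
from pytorch_MCMC.src.MCMC_Sampler import HMC_Sampler
from pytorch_MCMC.src.MCMC_Utils import posterior_dist

probmodel = GMM()                      # 2D Gaussian mixture with a single sampled "particle" parameter
sampler = HMC_Sampler(probmodel,
                      step_size=0.01,  # integrator step size per trajectory step
                      num_steps=2000,  # proposals per chain
                      num_chains=2,    # independent chains, run in parallel via joblib
                      burn_in=200,
                      pretrain=False,  # GMM defines no pretrain(); the ProbModel default is a no-op
                      tune=False)      # skip the step-size tuning phase for this toy example

chains = sampler.sample_chains()       # list of per-chain Chain objects; sampler.chain aggregates them
torch.save(sampler.chain, 'gmm_hmc.chain')  # chains are picklable containers of state_dicts, as in MCMC_HMCTest.py
posterior_dist(sampler.chain)          # 2D histogram of the accepted particle positions
```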