├── .gitignore ├── LICENSE ├── README.md ├── setup.py ├── stein_lib ├── .gitignore ├── __init__.py ├── models │ ├── __init__.py │ ├── bhm.py │ ├── double_banana_analytic.py │ ├── gaussian_mixture.py │ └── obstacles_2D │ │ ├── __init__.py │ │ ├── map_generator.py │ │ ├── obs_map.py │ │ ├── obs_utils.py │ │ └── rbf_map.py ├── prm_utils.py ├── svgd │ ├── LBFGS.py │ ├── __init__.py │ ├── base_kernels.py │ ├── composite_kernels.py │ ├── matrix_svgd │ │ ├── __init__.py │ │ ├── base_kernels.py │ │ ├── matrix_mix_svgd.py │ │ ├── matrix_svgd.py │ │ ├── mp_composite_kernels.py │ │ └── mp_matrix_svgd.py │ ├── mp_composite_kernels.py │ ├── mp_svgd.py │ ├── priors.py │ └── svgd.py ├── svn │ ├── __init__.py │ ├── mp_svn.py │ ├── svn.py │ └── svn_original.py └── utils.py └── tests ├── .gitignore ├── bayesian_hilbert_map.py ├── double_banana.py └── gaussian_mixture.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.mp4 3 | *.orig 4 | *.png 5 | *.jpg 6 | *.jpeg 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Alexander Lambert, Brian Hou, Rosario Scalise, Siddhartha S. Srinivasa, Byron Boots 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stein_lib 2 | 3 | Library for Stein-Variational Gradient Descent methods. 4 | 5 | ## References 6 | 7 | [1] Lambert, A.; Hou, B.; Scalise, R.; Srinivasa, S.; Boots, B. (2022). Stein Variational Probabilistic Roadmaps, IEEE International Conference on Robotics and Automation (ICRA). 8 | 9 | If you found this code useful, please cite us! 10 | ``` 11 | @inproceedings{icra2022_steinprm, 12 | title={Stein Variational Probabilistic Roadmaps}, 13 | author={Lambert, Alexander and Hou, Brian and Scalise, Rosario and Srinivasa, Siddhartha S and Boots, Byron}, 14 | booktitle={2022 International Conference on Robotics and Automation (ICRA)}, 15 | pages={11094--11101}, 16 | year={2022}, 17 | organization={IEEE} 18 | } 19 | 20 | ``` 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from setuptools import setup, find_packages 4 | 5 | 6 | if sys.version_info.major != 3: 7 | print("This Python is only compatible with Python 3, but you are running " 8 | "Python {}. 
The installation will likely fail.".format(sys.version_info.major)) 9 | 10 | setup( 11 | name='steinlib', 12 | version='1.0.0', 13 | packages=find_packages(exclude=('results*', '*results', 'tests')), 14 | description='Stein Variational Inference Library', 15 | url='https://github.com/sashalambert/stein_lib.git', 16 | author='Sasha Alexander Lambert', 17 | install_requires=[ 18 | 'numpy', 19 | 'torch', 20 | 'pyro-ppl', 21 | 'matplotlib', 22 | 'bhmlib @ git+ssh://git@github.com/sashalambert/Bayesian_Hilbert_Maps@devel', 23 | ], 24 | ) 25 | -------------------------------------------------------------------------------- /stein_lib/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | __pycache__/ 3 | ../.idea/ 4 | *.mp4 5 | *.jpg 6 | *.png 7 | *.pdf 8 | *.pyc 9 | -------------------------------------------------------------------------------- /stein_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sashalambert/stein_lib/a1afeca70afc831aab5a4d057be773eb17750246/stein_lib/__init__.py -------------------------------------------------------------------------------- /stein_lib/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sashalambert/stein_lib/a1afeca70afc831aab5a4d057be773eb17750246/stein_lib/models/__init__.py -------------------------------------------------------------------------------- /stein_lib/models/bhm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020-2021 Alexander Lambert 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining 5 | a copy of this software and associated documentation files (the 6 | "Software"), to deal in the Software without restriction, including 7 | without limitation the rights to use, copy, modify, merge, publish, 8 | distribute, 
class BayesianHilbertMap:
    """
    Log-density wrapper around a trained 2D Bayesian Hilbert Map (BHM).

    The density is the BHM vacancy log-probability, optionally combined
    with soft exponential penalties that discourage samples from leaving
    a rectangular region.

    Parameters
    ----------
    file_path : str or Path
        File containing the saved BHM parameters; it is read with
        ``torch.load`` and the result is forwarded as keyword arguments
        to ``BHM2D_PYTORCH``.
    limits : ((x_min, x_max), (y_min, y_max)) or None
        Bounds used by the soft boundary penalty in ``log_prob``.
        Pass ``None`` to disable the penalty.
    """

    def __init__(
            self,
            file_path=None,
            limits=((-10, 20,), (-25, 5)),
    ):

        # Load trained Bayesian Hilbert Map.
        # NOTE: torch.load unpickles the file; only load trusted models.
        params = torch.load(file_path)
        self.bhm = BHM2D_PYTORCH(torch_kernel_func=True, **params)
        # torch.tensor(None) raises, so keep None as the "no limits"
        # marker that log_prob checks for.
        self.limits = None if limits is None else torch.tensor(limits)

    def log_prob(self, x):
        """
        Return the vacancy log-probability of ``x`` minus the boundary
        penalty.

        ``x`` is indexed as ``x[:, 0]`` (compared against the first pair
        of limits) and ``x[:, 1]`` (second pair), i.e. one sample per row.
        """
        log_p = self.bhm.log_prob_vacancy(x)
        if self.limits is None:
            return log_p

        scale = 1.
        lims = self.limits
        # One exponential barrier per side of the bounding box. Each term
        # is ~0 well inside the box and grows exponentially outside it.
        penalty = (
            torch.exp(-scale * (x[:, 0] - lims[0, 0]))
            + torch.exp(scale * (x[:, 0] - lims[0, 1]))
            + torch.exp(-scale * (x[:, 1] - lims[1, 0]))
            + torch.exp(scale * (x[:, 1] - lims[1, 1]))
        )
        # Out-of-place subtraction: do not modify the tensor handed back
        # by the BHM.
        return log_p - penalty

    def grad_log_p(self, x):
        """
        Return the gradient of the BHM vacancy log-probability at ``x``.

        NOTE(review): the boundary penalty applied in ``log_prob`` is not
        differentiated here, so this is not the gradient of ``log_prob``
        whenever ``limits`` is set -- confirm whether that is intended.
        """
        return self.bhm.grad_log_p_vacancy(x)


if __name__ == '__main__':

    import bhmlib
    bhm_path = Path(bhmlib.__path__[0]).resolve()
    model_file = bhm_path / 'Outputs' / 'saved_models' / 'bhm_intel_res0.25_iter010.pt'

    bhm = BayesianHilbertMap(model_file)
class doubleBanana_analytic:
    """
    Bi-modal posterior distribution with analytic derivative terms.

    The forward model is
    ``F(x) = log((a - x0)**2 + b * (x1 - x0**2)**2)``, observed with
    Gaussian noise of variance ``obs_var`` around ``mu_n``, under a
    zero-mean Gaussian prior with variance ``prior_var`` per dimension.

    Parameters
    ----------
    mu_n : float
        Observed value used in the likelihood.
    seed : int
        Seed for the torch RNG and for the draw of ``thetaTrue``.
    prior_var : float
        Prior variance (identical for both dimensions).
    obs_var : float
        Observation-noise variance.
    a, b : float
        Constants of the forward model.
    """
    def __init__(
            self,
            mu_n=0.,
            seed=0,
            prior_var=1.,
            obs_var=0.3**2,
            a=0,
            b=100,
    ):

        self.dim = 2
        self.a = a
        self.b = b
        # Prior prob. params
        self.mu_0 = torch.zeros((self.dim, 1))
        self.var_0 = prior_var * torch.ones((self.dim, 1))

        # Likelihood prob. params
        self.mu_n = mu_n
        self.var_n = obs_var

        torch.manual_seed(seed)

        # Draw from a generator seeded with `seed`, so the value is
        # reproducible; seeding torch alone does not affect NumPy draws.
        self.thetaTrue = np.random.RandomState(seed).normal(size=self.dim)

    def forward_model(self, x):
        """
        Observation function.

        Parameters
        ----------
        x : (Tensor)
            Tensor of 2D samples, with shape [2, num_samples]

        Returns
        -------
        F : (Tensor)
            Tensor of function values, with shape [num_samples]
        """
        assert x.dim() == 2 and x.shape[0] == self.dim
        return torch.log(
            (self.a - x[0]) ** 2 + self.b * (x[1] - x[0] ** 2) ** 2
        )

    def log_lh(self, x, F=None):
        """
        Returns the log-likelihood (up to an additive constant).

        Parameters
        ----------
        x : (Tensor)
            Tensor of 2D samples, with shape [2, num_samples]
        F : (Tensor)
            (Optional) Function evaluations of samples, F = F(x). It is
            reshaped to [1, num_samples] internally.

        Returns
        -------
        log_prob : (Tensor)
            Tensor of log-likelihood values, with shape [num_samples]
        """
        if F is None:
            F = self.forward_model(x)
        F = F.reshape(1, -1)
        return - 0.5 * torch.sum((self.mu_n - F) ** 2, dim=0) / self.var_n

    def jacob_forward(self, x):
        """
        Jacobian of the forward model.

        Parameters
        ----------
        x : (Tensor)
            Tensor of 2D samples, with shape [2, num_samples]

        Returns
        -------
        J : (Tensor)
            Jacobian tensor, with shape [2, num_samples]
        """
        x0 = x[0, :]
        x1 = x[1, :]
        resid = x1 - x0 ** 2
        # Argument of the log in forward_model. It must use self.a: a
        # denominator with a = 1 baked in is only valid for that value.
        denom = (self.a - x0) ** 2 + self.b * resid ** 2
        dF_dx0 = 2 * (x0 - self.a - 2 * self.b * x0 * resid) / denom
        dF_dx1 = 2 * self.b * resid / denom
        return torch.stack((dF_dx0, dF_dx1))

    def grad_log_lh(self, x, F=None, J=None):
        """
        Gradient of the log likelihood, with shape [2, num_samples].

        ``F`` and ``J`` may be supplied to reuse earlier evaluations of
        ``forward_model`` and ``jacob_forward``.
        """
        if F is None:
            F = self.forward_model(x).reshape(1, -1)
        if J is None:
            J = self.jacob_forward(x)
        return - J * (F - self.mu_n) / self.var_n

    def log_prior(self, x):
        """
        Returns the log-prior (up to an additive constant).

        Parameters
        ----------
        x : (Tensor)
            Tensor of 2D samples, with shape [2, num_samples]

        Returns
        -------
        log_prob : (Tensor)
            Tensor of log-prior values, with shape [num_samples]
        """
        return - 0.5 * torch.sum((x - self.mu_0) ** 2 / self.var_0, dim=0)

    def grad_log_prior(self, x):
        """
        Gradient of the log prior, with shape [2, num_samples].
        """
        return - (x - self.mu_0) / self.var_0

    def log_prob(self, x):
        """
        Returns the log-posterior probability densities, without
        log_partition term.

        Parameters
        ----------
        x : (Tensor)
            Tensor of 2D samples, with shape [2, num_samples]

        Returns
        -------
        log_prob : (Tensor)
            Sum of log prior and log likelihood, with shape [num_samples]
        """
        return self.log_prior(x) + self.log_lh(x)

    def grad_log_p(self, x, F=None, J=None):
        """
        Gradient of the log posterior, with shape [2, num_samples].
        """
        return self.grad_log_prior(x) + self.grad_log_lh(x, F, J)

    def hessian(self, x, J=None):
        """
        Gauss-Newton Hessian approximation of the log posterior.

        Parameters
        ----------
        x : (Tensor)
            Tensor of 2D samples, with shape [2, num_samples]
        J : (Tensor)
            (Optional) Jacobian of the forward model, with shape
            [2, num_samples]

        Returns
        -------
        Hess : (Tensor)
            Negated ``J J^T / var_n + I / var_0`` for every sample, with
            shape [2, 2, num_samples]
        """
        if J is None:
            J = self.jacob_forward(x)

        # Per-sample outer product of the Jacobian with itself.
        outer = J.reshape(self.dim, 1, -1) * J.reshape(1, self.dim, -1)
        prior_precision = (torch.eye(self.dim) / self.var_0).unsqueeze(2)
        return -1. * (outer / self.var_n + prior_precision)
class mixture_of_gaussians:
    """
    Equally-weighted mixture of diagonal Gaussians.

    Parameters
    ----------
    num_comp : int
        Number of mixture components; every component gets weight
        ``1 / num_comp``.
    mu_list : array-like
        Component means, converted with ``np.array`` and handed to
        ``MixtureOfDiagNormals``.
    sigma_list : array-like
        Component scales, converted the same way.
    """

    def __init__(
            self,
            num_comp,
            mu_list,
            sigma_list,
    ):

        mus = torch.from_numpy(np.array(mu_list))
        sigmas = torch.from_numpy(np.array(sigma_list))

        # Uniform mixing coefficients.
        mix_coeffs = torch.full((num_comp,), 1. / num_comp)

        self.dist = MixtureOfDiagNormals(mus, sigmas, mix_coeffs)

    def log_prob(self, x):
        """Return the mixture log-density evaluated at ``x``."""
        return self.dist.log_prob(x)

    def grad_log_p(self, x):
        """
        Return the gradient of ``log_prob`` with respect to ``x``.

        The gradient is obtained by automatic differentiation on a
        detached copy of ``x``, so the caller's tensor is left untouched.
        """
        # detach().requires_grad_() replaces the deprecated
        # torch.autograd.Variable wrapper.
        x_ = x.detach().requires_grad_(True)
        (dlog_p,) = torch.autograd.grad(self.log_prob(x_).sum(), x_)
        return dlog_p
def generate_obstacle_map(
        map_dim=(10,10),
        obst_list=None,
        cell_size=1.,
        start_pts=None,
        goal_pts=None,
        seed=0,
        device=None,
        random_gen=False,
        num_obst=0,
        rand_xy_limits=None,
        rand_shape=(2, 2),
        map_type=None,
        plot=False,
        delta=0.5,
        sigma=0.5,
):

    """
    Build a 2D obstacle map from fixed and (optionally) random rectangles.

    Args
    ---
    map_dim : (int,int)
        2D tuple containing dimensions of obstacle/occupancy grid.
        Treat as [x,y] coordinates. Origin is in the center.
        ** Dimensions must be an even number. **
    obst_list : [(cx_i, cy_i, width, height)] or None
        List of obstacle param tuples. When random obstacles are
        generated, their parameters are appended to this list in place.
        ``None`` means no fixed obstacles.
    cell_size : float
        size of each square map cell
    start_pts : float
        Array of x-y points for start configuration.
        Dim: [Num. of points, 2]. Currently unused.
    goal_pts : float
        Array of x-y points for target configuration.
        Dim: [Num. of points, 2]. Currently unused.
    seed : int or None
        Currently unused: the random module is not seeded here.
    device : torch device or None
        Forwarded to the map constructors.
    random_gen : bool
        Specify whether to generate random obstacles. Will first generate
        obstacles provided by obst_list, then add random obstacles until
        number specified by num_obst.
    num_obst : int
        Total number of obstacles
    rand_xy_limits : [[float, float],[float, float]]
        List defining x-y sampling bounds [[x_min, x_max], [y_min, y_max]].
        Required when random_gen is True.
    rand_shape : [float, float]
        Shape [width, height] of randomly generated obstacles.
    map_type : str
        'RBF' to fit and return an RBF_map, 'direct' to return the
        ObstacleMap itself.
    plot : bool
        Forwarded to RBF_map.fit_to_obst_map.
    delta, sigma : float
        Forwarded to the RBF_map constructor.

    Returns
    -------
    The fitted RBF_map (map_type='RBF') or the ObstacleMap
    (map_type='direct').

    Raises
    ------
    IOError
        If map_type is neither 'RBF' nor 'direct'.
    """
    # Reject an unknown map type before doing any work or touching
    # obst_list.
    if map_type not in ('RBF', 'direct'):
        raise IOError('Map type "{}" not recognized'.format(map_type))

    # A fresh list per call: a shared default list would keep obstacles
    # appended by earlier calls.
    if obst_list is None:
        obst_list = []

    ## Make occupancy grid
    obst_map = ObstacleMap(map_dim, cell_size, device=device)
    num_fixed = len(obst_list)
    for cx, cy, width, height in obst_list:
        # Fixed obstacles are added without any validity check.
        ObstacleRectangle(cx, cy, width, height)._add_to_map(obst_map)

    ## Add random obstacles
    if random_gen:
        assert num_fixed <= num_obst, "Total number of obstacles must be greater than or equal to number specified in obst_list"
        xlim = rand_xy_limits[0]
        ylim = rand_xy_limits[1]
        width = rand_shape[0]
        height = rand_shape[1]
        max_attempts = 25
        # num_obst is the *total* count, so only the remainder is sampled.
        for _ in range(num_obst - num_fixed):
            for _attempt in range(max_attempts + 1):
                rect = random_rect(xlim, ylim, width, height)

                # Check validity of new obstacle
                if rect._obstacle_collision_check(obst_map):
                    # Add to map and record its parameters
                    rect._add_to_map(obst_map)
                    obst_list.append([rect.center_x, rect.center_y,
                                      rect.width, rect.height])
                    break
            else:
                # Every attempt was rejected; give up on this obstacle.
                print("Obstacle generation: Max. number of attempts reached. ")
                print("Total num. obstacles: {}. Num. random obstacles: {}.\n"
                      .format( len(obst_list), len(obst_list) - num_fixed))

    obst_map.convert_map()

    ## Fit mapping model
    if map_type == 'RBF':
        print("Generating RBF map...\n")
        rbf_map = RBF_map(obst_map.xlim, obst_map.ylim, delta=delta, sigma=sigma, device=device)
        rbf_map.fit_to_obst_map(obst_map, plot=plot)
        return rbf_map

    # map_type == 'direct'
    return obst_map


if __name__ == "__main__":

    import sys
    import numpy
    numpy.set_printoptions(threshold=sys.maxsize)
    obst_list = [(0, 0, 4, 8)]
    cell_size = 0.1
    map_dim = [20, 20]
    seed = 0

    obst_map = generate_obstacle_map(
        map_dim, obst_list, cell_size,
        map_type='RBF',
        random_gen=True,
        num_obst=5,
        rand_xy_limits=[[-10, 10], [-10, 10]],
        rand_shape=[2,2],
        plot=True,
    )

    # Evaluate the fitted map along a straight vertical trajectory.
    traj_y = torch.linspace(-map_dim[1]/2., map_dim[1]/2., 20)
    traj_x = torch.zeros_like(traj_y)
    X = torch.cat((traj_x.unsqueeze(1), traj_y.unsqueeze(1)), dim=1)
    cost = obst_map.eval(X)
    print(cost)
class Obstacle(ABC):
    """
    Base 2D Obstacle class
    """

    def __init__(self,center_x,center_y):
        # NOTE(review): int() truncates fractional centres (e.g. 1.5 -> 1),
        # so obstacles requested at half-integer positions are shifted --
        # confirm this is intended.
        self.center_x = int(center_x)
        self.center_y = int(center_y)

    @abstractmethod
    def _obstacle_collision_check(self, obst_map):
        pass

    @abstractmethod
    def _point_collision_check(self, obst_map, pts):
        pass

    @abstractmethod
    def _add_to_map(self, obst_map):
        pass


class ObstacleRectangle(Obstacle):
    """
    Derived 2D rectangular Obstacle class
    """

    def __init__(
            self,
            center_x=0,
            center_y=0,
            width=None,
            height=None,
    ):
        super().__init__(center_x, center_y)
        self.width = width
        self.height = height

    def _cell_slices(self, obst_map):
        """Return the (row, column) slices of ``obst_map.map`` covered by
        this rectangle, clipped at the lower map border."""
        # Convert dims to cell indices
        w = ceil(self.width / obst_map.cell_size)
        h = ceil(self.height / obst_map.cell_size)
        c_x = ceil(self.center_x / obst_map.cell_size)
        c_y = ceil(self.center_y / obst_map.cell_size)
        half_w = ceil(w / 2.)
        half_h = ceil(h / 2.)

        # Clamp at 0: a negative slice bound would wrap around to the
        # other side of the array and select the wrong (usually empty)
        # region. Bounds past the upper edge are already clipped by numpy.
        # NOTE(review): rows are offset by origin_yi and columns by
        # origin_xi, while ObstacleMap.get_collisions indexes the map as
        # [x, y] -- verify the intended axis order.
        rows = slice(max(0, c_y - half_h + obst_map.origin_yi),
                     max(0, c_y + half_h + obst_map.origin_yi))
        cols = slice(max(0, c_x - half_w + obst_map.origin_xi),
                     max(0, c_x + half_w + obst_map.origin_xi))
        return rows, cols

    def _obstacle_collision_check(self, obst_map):
        """
        Return True if this rectangle overlaps no occupied cell of
        ``obst_map``, False otherwise. The map is not modified.
        """
        rows, cols = self._cell_slices(obst_map)
        # Inspect the footprint directly: _add_to_map writes 1 rather than
        # accumulating, so an overlap can never push a cell above 1.
        return not np.any(obst_map.map[rows, cols] > 0)

    def _point_collision_check(self,obst_map,pts):
        """
        Return True if none of ``pts`` lies in an occupied cell once this
        rectangle is added to a copy of ``obst_map``.

        ``pts`` is an iterable of (x, y) positions in map coordinates, or
        None (no points to check, always valid). Points outside the grid
        are ignored. The map passed in is not modified.
        """
        if pts is None:
            return True

        test_map = self._add_to_map(deepcopy(obst_map))
        n_rows, n_cols = test_map.map.shape
        for pt in pts:
            # Same convention as _add_to_map: x selects the column and y
            # the row, both shifted by the map origin.
            col = int(np.floor(float(pt[0]) / test_map.cell_size)) + test_map.origin_xi
            row = int(np.floor(float(pt[1]) / test_map.cell_size)) + test_map.origin_yi
            inside = 0 <= row < n_rows and 0 <= col < n_cols
            if inside and test_map.map[row, col] > 0:
                return False
        return True

    def _add_to_map(self, obst_map):
        """Mark the cells covered by this rectangle as occupied (value 1)
        in ``obst_map.map`` and return ``obst_map``."""
        rows, cols = self._cell_slices(obst_map)
        obst_map.map[rows, cols] = 1
        return obst_map


class ObstacleMap :
    """
    Generates an occupancy grid.
    """
    def __init__(self, map_dim, cell_size, device=None):
        """
        Parameters
        ----------
        map_dim : (int, int)
            Map extent along the two axes; both entries must be even.
        cell_size : float
            Side length of one square cell.
        device : torch device or None
            Device on which the cell offset tensor is stored.
        """
        assert map_dim[0] % 2 == 0
        assert map_dim[1] % 2 == 0

        cmap_dim = [ceil(map_dim[0]/cell_size), ceil(map_dim[1]/cell_size)]

        self.map = np.zeros(cmap_dim)
        # Torch copy of the grid; filled by convert_map().
        self.map_torch = None
        self.cell_size = cell_size

        # Map center (in cells)
        self.origin_xi = int(cmap_dim[0]/2)
        self.origin_yi = int(cmap_dim[1]/2)

        self.x_dim, self.y_dim = self.map.shape
        x_range = self.cell_size * self.x_dim
        y_range = self.cell_size * self.y_dim
        self.xlim = [-x_range/2, x_range/2]
        self.ylim = [-y_range/2, y_range/2]

        # Offset that shifts origin-centred positions to cell indices.
        self.c_offset = torch.Tensor([self.origin_xi, self.origin_yi]).to(device)

    def convert_map(self):
        """Store and return a torch copy of the current occupancy grid."""
        self.map_torch = torch.Tensor(self.map)
        return self.map_torch

    def plot(self, save_dir=None, filename="obst_map.png"):
        """Show the occupancy grid; also save it to
        ``save_dir/filename`` when ``save_dir`` is given."""
        fig = plt.figure()
        plt.imshow(self.map)
        plt.gca().invert_yaxis()
        # Save before show(): once the window is closed the figure may be
        # released, and saving afterwards can produce an empty image.
        if save_dir is not None:
            fig.savefig(osp.join(save_dir, filename))
        plt.show()

    def get_xy_grid(self, device):
        """Return a tensor of shape [x_dim, y_dim, 2] holding the (x, y)
        coordinate of every grid node, moved to ``device``."""
        xv, yv = torch.meshgrid([torch.linspace(self.xlim[0], self.xlim[1], self.x_dim),
                                 torch.linspace(self.ylim[0], self.ylim[1], self.y_dim)])
        xy_grid = torch.stack((xv, yv), dim=2)
        return xy_grid.to(device)

    def get_collisions(self, X):
        """
        Looks up the occupancy value of every position in a batch of
        trajectories using the generated occupancy grid (i.e. obstacle map).

        :param X: Tensor of positions whose last dimension holds (x, y),
            e.g. of shape (batch_size, traj_length, position_dim)
        :return: Tensor of occupancy values, one per position (the shape
            of X without its last dimension)
        """
        # Build the torch grid on first use if convert_map() has not been
        # called yet.
        if getattr(self, 'map_torch', None) is None:
            self.convert_map()

        # Convert traj. positions to occupancy indices
        X_occ = (X * (1/self.cell_size) + self.c_offset).floor()
        X_occ = X_occ.type(torch.LongTensor)

        # Project out-of-bounds locations to axis
        X_occ[...,0] = X_occ[...,0].clamp(0, self.map.shape[0]-1)
        X_occ[...,1] = X_occ[...,1].clamp(0, self.map.shape[1]-1)

        # Collisions. Indices are clamped above, so the lookup is in
        # bounds; any remaining error is left to propagate to the caller.
        return self.map_torch[X_occ[...,0], X_occ[...,1]]
def round_up(n, decimals=0):
    """Round ``n`` up (towards +inf) to ``decimals`` decimal places."""
    multiplier = 10 ** decimals
    return ceil(n * multiplier) / multiplier


def random_rect(xlim=(0, 0), ylim=(0, 0), width=2, height=2):
    """
    Generate a rectangular obstacle with a uniformly random centre.

    :param xlim: (min, max) range for the centre x-coordinate.
    :param ylim: (min, max) range for the centre y-coordinate.
    :param width: rectangle width (fixed, not randomised).
    :param height: rectangle height (fixed, not randomised).
    :return: ObstacleRectangle(cx, cy, width, height)
    """
    cx = random.uniform(xlim[0], xlim[1])
    cy = random.uniform(ylim[0], ylim[1])
    return ObstacleRectangle(cx, cy, width, height)


def save_map_image(obst_map=None, start_pts=None, goal_pts=None, dir='.'):
    """
    Best-effort save of an obstacle map image to ``<dir>/obst_map.png``.

    Start points are drawn as green dots and goal points as red dots.
    Failures are reported on stdout and never raised.
    """
    try:
        plt.imshow(obst_map, cmap='gray')
        if start_pts is not None:
            for pt in start_pts:
                plt.plot(pt[0], pt[1], '.g')
        if goal_pts is not None:
            for pt in goal_pts:
                plt.plot(pt[0], pt[1], '.r')
        plt.gca().invert_yaxis()
        plt.savefig('{}/obst_map.png'.format(dir))
    except Exception as err:
        print("Error: could not save map.")
        print(err)
    return


def _square_grid(coords, w):
    """Square obstacles on every (x, y) in coords x coords, top row first."""
    return [[x, y, w, w] for y in reversed(coords) for x in coords]


def _scatter_on_map(unit_points, map_length, map_height, w):
    """
    Map points from the unit square onto the map (centred on the origin),
    rounding each coordinate up to an integer, as square obstacles of side w.
    """
    return [
        [
            round_up(pt[0] * map_length - map_length / 2, 0),
            round_up(pt[1] * map_height - map_height / 2, 0),
            w,
            w,
        ]
        for pt in unit_points
    ]


def get_obst_preset(
        preset_name,
        obst_width=2,
        map_dim=(10, 10),
        num_rand_obst=20,
):
    """
    Return the obstacle parameters of a named preset layout.

    :param preset_name: one of 'staggered_3-2-3', 'staggered_4-3-4-3-4',
        'grid_3x3', 'grid_4x4', 'grid_6x6', 'maze', 'single_centered',
        'rand_halton', 'rand_sobol', 'rand_mix'.
    :param obst_width: side length used for the preset's obstacles.
    :param map_dim: (length, height) of the map; only used by the 'rand_*'
        presets to scale quasi-random points.
    :param num_rand_obst: number of obstacles for 'rand_halton' (the other
        random presets use fixed counts).
    :return: list of [center_x, center_y, width, height] entries.
    :raises IOError: if the preset name is not supported.
    """
    w = obst_width
    map_length = map_dim[0]
    map_height = map_dim[1]

    if preset_name == 'staggered_3-2-3':
        obst_params = [[-4., 4., w, w], [0., 4., w, w], [4., 4., w, w],
                       [-6, 0, w, w], [-2, 0, w, w], [2, 0, w, w], [6, 0, w, w],
                       [-4., -4., w, w], [0., -4., w, w], [4., -4., w, w]]

    elif preset_name == 'staggered_4-3-4-3-4':
        obst_params = [[-6, 6, w, w], [-2., 6, w, w], [2., 6, w, w], [6, 6, w, w],
                       [-4., 3, w, w], [0., 3, w, w], [4., 3, w, w],
                       [-6, 0, w, w], [-2., 0, w, w], [2., 0, w, w], [6, 0, w, w],
                       [-4, -3, w, w], [0., -3, w, w], [4, -3, w, w],
                       [-6, -6, w, w], [-2, -6, w, w], [2, -6, w, w], [6, -6, w, w]]

    elif preset_name == 'grid_3x3':
        s = 5
        obst_params = _square_grid([-s, 0, s], w)

    elif preset_name == 'grid_4x4':
        s = 4
        obst_params = _square_grid([s * k / 2 for k in (-3, -1, 1, 3)], w)

    elif preset_name == 'grid_6x6':
        s = 1
        obst_params = _square_grid([s * k / 2 for k in (-5, -3, -1, 1, 3, 5)], w)

    elif preset_name == 'maze':
        b = 6
        centers = [(-b, b), (0, b), (b, b),
                   (-b / 2, b / 2), (b / 2, b / 2),
                   (-b, 0), (0, 0), (b, 0),
                   (-b / 2, -b / 2), (b / 2, -b / 2),
                   (-b, -b), (0, -b), (b, -b)]
        # Two bars per centre: one with dims (b/2, w), one with dims (w, b/2).
        obst_params = []
        for cx, cy in centers:
            obst_params.append([cx, cy, b / 2, w])
            obst_params.append([cx, cy, w, b / 2])

    elif preset_name == 'single_centered':
        obst_params = [[-5, 0, w, w]]

    elif preset_name == 'rand_halton':
        import ghalton
        sequencer = ghalton.Halton(2)
        obst_params = _scatter_on_map(
            sequencer.get(num_rand_obst), map_length, map_height, w)

    elif preset_name == 'rand_sobol':
        import sobol_seq
        num_obst = 15
        obst_params = _scatter_on_map(
            sobol_seq.i4_sobol_generate(2, num_obst).tolist(),
            map_length, map_height, w)
        obst_params += [[-1, 1.5, 2, 1], [0, 0, 1, 2]]

    elif preset_name == 'rand_mix':
        import sobol_seq
        import ghalton
        obst_sobol = _scatter_on_map(
            sobol_seq.i4_sobol_generate(2, 12).tolist(),
            map_length, map_height, w)
        obst_halton = _scatter_on_map(
            ghalton.Halton(2).get(8), map_length, map_height, w)

        obst_params = obst_sobol + obst_halton
        obst_params += [[-1, 5, w, w]]
        obst_params += [[4, 2, w, w]]
        obst_params += [[0, 8, w, w]]
        obst_params += [[6, -1, w, w]]

    else:
        # A single formatted message: passing two positional arguments to
        # IOError would be interpreted as (errno, strerror).
        raise IOError('Obstacle preset not supported: {}'.format(preset_name))
    return obst_params
/stein_lib/models/obstacles_2D/rbf_map.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020-2021 Alexander Lambert 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining 5 | a copy of this software and associated documentation files (the 6 | "Software"), to deal in the Software without restriction, including 7 | without limitation the rights to use, copy, modify, merge, publish, 8 | distribute, sublicense, and/or sell copies of the Software, and to 9 | permit persons to whom the Software is furnished to do so, subject to 10 | the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 13 | included in all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 17 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 19 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 20 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 21 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
class RBF_map:
    """
    Cost map modelled as a weighted sum of RBF features on a regular grid of
    centres, plus one constant (bias) feature.
    """

    def __init__(self, xlim=(0, 1), ylim=(0, 1), delta=1, sigma=1, device=None):
        """
        :param xlim: x-limits of map
        :param ylim: y-limits of map
        :param delta: spacing between kernels
        :param sigma: standard deviation for each kernel
        :param device: torch device for kernel centres and weights
        """
        self.xlim = xlim
        self.ylim = ylim
        self.device = device
        self.feature_list = self.make_features(xlim, ylim, delta, sigma)
        # One weight per feature, initialised to one.
        self.weights = torch.ones(len(self.feature_list), 1).to(device)

    def make_features(self, xlim, ylim, delta, sigma):
        """Return the list of feature callables: one RBF per centre + bias."""
        mu_list = self.make_xy_centers(xlim, ylim, delta)
        features = [self.kernel(mu, sigma) for mu in mu_list]
        features.append(self.base())
        return features

    def kernel(self, mu=torch.zeros(2), sigma=1):
        """
        Return a callable x -> exp(-||x - mu||^2 / sigma^2) for x of shape (N, 2).
        """
        assert mu.dim() == 1 and mu.size(0) == 2
        return lambda x: torch.exp(-torch.bmm((x - mu).view(-1, 1, 2), (x - mu).view(-1, 2, 1)) / sigma**2).view(-1)

    def base(self):
        """Return the constant feature: a callable x -> ones(N)."""
        # Follow the input's device so the feature can always be stacked with
        # the RBF features evaluated on that same input.
        return lambda x: torch.ones(x.size(0), dtype=x.dtype, device=x.device).view(-1)

    def make_xy_centers(self, xlim, ylim, delta, offset=1e-3):
        """
        Make equally-spaced x-y means
        :param offset: small margin added to the upper limits so they are
            included by ``torch.arange``.
        :return: list of tensor[(mu_x, mu_y)]
        """
        xc, yc = torch.meshgrid([torch.arange(xlim[0], xlim[1] + offset, delta),
                                 torch.arange(ylim[0], ylim[1] + offset, delta)])
        mu_vec = torch.stack((xc, yc), dim=2).view(-1, 2).to(self.device)
        return [mu_vec[i, :] for i in range(mu_vec.shape[0])]

    def eval(self, xy_data):
        """Evaluate the map at xy_data (N, 2): embedding times weights."""
        return self.get_embedding(xy_data) @ self.weights

    def get_embedding(self, xy_data):
        """Return the feature matrix of shape (N, num_features)."""
        return torch.stack([f(xy_data) for f in self.feature_list], dim=1)

    def fit(self, xy_grid, occ_grid):
        """
        Find weight parameters via linear regression (normal equations).

        After fitting, ``self.weights`` is a 1-D tensor of length
        num_features.
        """
        xy_data = xy_grid.view(-1, 2)
        labels = occ_grid.view(-1)
        Phi = self.get_embedding(xy_data)
        self.weights = torch.inverse(Phi.T.matmul(Phi)) @ (Phi.T @ labels)

    def fit_to_obst_map(self, obst_map, plot=False):
        """
        Implements RBF interpolation on an occupancy grid.

        :param obst_map: object exposing ``xlim``, ``ylim``, ``map``,
            ``convert_map()``, ``get_xy_grid(device=...)`` and ``plot()``.
        :param plot: if True, plot the obstacle map and the RBF map before
            and after fitting.
        """
        # Compare limits as tuples: the obstacle map may store them as lists
        # while this class defaults to tuples, and a list never equals a tuple.
        assert tuple(obst_map.xlim) == tuple(self.xlim)
        assert tuple(obst_map.ylim) == tuple(self.ylim)
        assert obst_map.map.size >= len(self.feature_list), "Number of map cells must be greater than number of RBF features."

        occ_grid = obst_map.convert_map()
        xy_grid = obst_map.get_xy_grid(device=self.device)
        if plot:
            obst_map.plot()
            self.plot()
        self.fit(xy_grid, occ_grid)
        if plot:
            self.plot()

    def get_collisions(self, X, device, clamp=False):
        """
        Evaluate the RBF map at every position of a batch of trajectories.

        :param X: Tensor of trajectories, of shape (batch_size, traj_length, 2)
        :param device: unused; kept for interface compatibility
        :param clamp: if True, clamp the returned values to [0, 1]
        :return: map values, of shape (batch_size, traj_length)
        """
        assert X.dim() == 3 and X.size(2) == 2
        batch_size, traj_length, _ = X.shape

        # Convert traj. positions to occupancy values
        occ_values = self.eval(X.reshape(-1, 2)).view(batch_size, traj_length)
        if clamp:
            occ_values = torch.clamp(occ_values, 0, 1)
        return occ_values

    def make_costmap(self, xres=100, yres=100):
        """
        Evaluate the map on an (xres, yres) grid spanning the map limits.

        :return: NumPy array of shape (xres, yres)
        """
        xlim = self.xlim
        ylim = self.ylim
        xv, yv = torch.meshgrid([torch.linspace(xlim[0], xlim[1], xres),
                                 torch.linspace(ylim[0], ylim[1], yres)])
        xy_grid = torch.stack((xv, yv), dim=2)
        grid_shape = xy_grid.shape[:2]
        xy_vec = xy_grid.view(-1, 2)
        out = self.eval(xy_vec.to(self.device))
        return out.view(grid_shape).cpu().numpy()

    def plot(self, xres=100, yres=100, save_dir=None, filename="rbf_map.png"):
        """
        Display the cost map and optionally save it.

        :param save_dir: directory to write the image to; nothing is saved
            when None.
        :param filename: image file name inside ``save_dir``.
        """
        out = self.make_costmap(xres, yres)
        plt.figure()
        plt.imshow(out)
        plt.gca().invert_yaxis()
        # Save before show(): once show() returns the figure may already have
        # been released, in which case savefig() would write an empty image.
        if save_dir is not None:
            plt.savefig(osp.join(save_dir, filename))
        plt.show()
def get_graph(
        particles,
        pw_dists,
        model,
        collision_res=0.1,
        collision_thresh=1.e-3,
        connect_radius=np.inf,
        include_coll_pts=False,
        save_path=None,
):
    """
    Build a roadmap graph over a set of particles.

    Parameters
    ----------
    particles : Tensor
        Node positions, of shape [num_particles, dim].
    pw_dists : Tensor
        Pairwise distances, of shape [num_particles, num_particles]. Not
        modified (a masked copy is used internally).
    model : object
        Must expose ``log_prob(points)`` returning one value per point;
        the collision cost of a point is ``exp(1 - log_prob)``.
    collision_res : float
        Spacing used to choose the number of collision-check points per edge.
    collision_thresh : float
        Nodes whose collision cost exceeds this value are disconnected.
    connect_radius : float
        Edges longer than this are dropped.
    include_coll_pts : bool
        If True, also return the concatenated edge check points (debugging).
    save_path : str (Optional)
        If given, the graph is pickled to this path.

    Returns
    -------
    tuple
        (node_inds [E, 2], edge_lengths [E], edge_coll_vals [E],
        edge_coll_num_pts [E], edge_coll_pts or None, params dict).
        Tensors are converted to NumPy arrays.
    """
    # Avoid repeated edges: keep the upper triangle only (new tensor).
    pw_dists = torch.triu(pw_dists)

    # Remove low-probability nodes (i.e. collision check)
    coll_vals = torch.exp(1 - model.log_prob(particles))
    coll_inds = (coll_vals > collision_thresh).nonzero()
    pw_dists[coll_inds, :] = 0.
    pw_dists[:, coll_inds] = 0.

    # Filter edges according to max length
    pw_dists[pw_dists > connect_radius] = 0.

    # Get graph node indices: one (i, j) row per surviving edge.
    node_inds = (pw_dists > 0.).nonzero()

    # Collision check on edges:
    edge_coll_pts = []  # for debugging
    edge_coll_vals = []
    edge_num_pts = []
    dim = particles.shape[-1]
    for node_pair in node_inds:
        start_pt = particles[node_pair[0]]
        end_pt = particles[node_pair[1]]
        edge_len = pw_dists[node_pair[0], node_pair[1]]
        num_pts = torch.floor(edge_len / collision_res).int()
        steps = int(num_pts)

        # Evenly spaced points along the edge (includes endpoints).
        coll_pts = torch.stack(
            [
                torch.linspace(
                    float(start_pt[i]), float(end_pt[i]), steps,
                    dtype=particles.dtype, device=particles.device,
                )
                for i in range(dim)
            ],
            dim=1,
        )
        edge_coll_vals.append(torch.exp(1 - model.log_prob(coll_pts)).sum())
        edge_num_pts.append(num_pts)
        if include_coll_pts:
            edge_coll_pts.append(coll_pts)  # for debugging

    if edge_coll_vals:
        edge_coll_vals = torch.stack(edge_coll_vals, dim=0)
        edge_coll_num_pts = torch.stack(edge_num_pts)
    else:
        # No edge survived filtering: return empty arrays instead of failing
        # on torch.stack([]).
        edge_coll_vals = pw_dists.new_zeros(0)
        edge_coll_num_pts = torch.zeros(0, dtype=torch.int32)

    if include_coll_pts:
        if edge_coll_pts:
            edge_coll_pts = torch.cat(edge_coll_pts, dim=0)  # for debugging
        else:
            edge_coll_pts = particles.new_zeros(0, dim)
    else:
        edge_coll_pts = None

    # One length per edge. Indexing with the (E, 2) index tensor directly
    # would select whole rows of pw_dists rather than the (i, j) entries.
    edge_lengths = pw_dists[node_inds[:, 0], node_inds[:, 1]]

    params = {
        'collision_res': collision_res,
        'collision_thresh': collision_thresh,
        'connect_radius': connect_radius,
    }

    graph = [
        node_inds,
        edge_lengths,
        edge_coll_vals,
        edge_coll_num_pts,
        edge_coll_pts,
        params,
    ]

    # Tensors become NumPy arrays; None and the params dict pass through.
    graph = [
        item.detach().cpu().numpy() if isinstance(item, torch.Tensor) else item
        for item in graph
    ]

    if save_path is not None:
        # pickle writes bytes, so the file must be opened in binary mode.
        with open(save_path, 'wb') as f:
            pickle.dump(graph, f)

    return tuple(graph)
def _grad_outer_products(d_K_Xi):
    """Per-pair outer products of kernel gradients: [b, b, d] -> [b, b, d, d]."""
    return d_K_Xi.unsqueeze(3) * d_K_Xi.unsqueeze(2)


class BaseKernel(ABC):
    """Interface shared by the scalar-valued kernels."""

    def __init__(
            self,
            analytic_grad=True,
    ):
        self.analytic_grad = analytic_grad

    @abstractmethod
    def eval(self, X, Y, M=None, compute_dK_dK_t=False, **kwargs):
        """
        Evaluate kernel function and corresponding gradient terms for batch of inputs.

        Parameters
        ----------
        X : Tensor
            Data, of shape [batch, dim]
        Y : Tensor
            Data, of shape [batch, dim]
        M : Tensor (Optional)
            Metric, of shape [batch, dim, dim]
        compute_dK_dK_t : Bool
            Compute outer-products of kernel gradients.
        kwargs : dict
            Kernel-specific parameters

        Returns
        -------
        K: Tensor
            Kernel Gram matrix, of shape [batch, batch].
        d_K_Xi: Tensor
            Kernel gradients wrt. first input X. Shape: [batch, batch, dim]
        dK_dK_t: Tensor (Optional)
            Outer products of kernel gradients (used by SVN).
            Shape: [batch, batch, dim, dim]
        pw_dists_sq: Tensor (Optional)
            If applicable, returns the squared, inter-particle pairwise distances.
            Shape: [batch, batch]
        """
        pass


class RBF(BaseKernel):
    """
    RBF kernel, evaluated as k(x, x') = exp( - || x - x'||**2 / h ).

    ``h`` is ``bandwidth**2 / log(batch)``, or ``median(dists_sq) / log(batch)``
    when ``bandwidth`` is negative (median trick); see ``compute_bandwidth``.
    """
    def __init__(
            self,
            bandwidth=-1,
            analytic_grad=True,
            **kwargs,
    ):
        super().__init__(
            analytic_grad,
        )
        self.ell = bandwidth

    def compute_bandwidth(
            self,
            X, Y
    ):
        """
        Compute the bandwidth ``h`` and the squared Euclidean pairwise
        distances between the rows of X and Y.

        Returns
        -------
        h : Tensor or float
            Bandwidth, clamped from below at 1e-5.
        pairwise_dists_sq : Tensor
            Shape [batch, batch].
        """
        XX = X.matmul(X.t())
        XY = X.matmul(Y.t())
        YY = Y.matmul(Y.t())

        pairwise_dists_sq = -2 * XY + XX.diag().unsqueeze(1) + YY.diag().unsqueeze(0)

        if self.ell < 0:  # use median trick
            h = torch.median(pairwise_dists_sq).detach()
        else:
            h = self.ell**2

        h = h / np.log(X.shape[0])

        # Clamp bandwidth
        tol = 1e-5
        if isinstance(h, torch.Tensor):
            h = torch.clamp(h, min=tol)
        else:
            h = np.clip(h, a_min=tol, a_max=None)

        return h, pairwise_dists_sq

    def eval(
            self,
            X, Y,
            M=None,
            compute_dK_dK_t=False,
            bw=None,
            **kwargs,
    ):
        """
        See ``BaseKernel.eval``. ``M`` is ignored; ``bw``, when given,
        replaces the computed bandwidth ``h``.
        """
        assert X.shape == Y.shape

        if not self.analytic_grad:
            raise NotImplementedError

        h, pw_dists_sq = self.compute_bandwidth(X, Y)
        if bw is not None:
            h = bw

        K = (- pw_dists_sq / h).exp()
        d_K_Xi = K.unsqueeze(2) * (X.unsqueeze(1) - Y) * 2 / h

        # Used for SVN updates
        dK_dK_t = _grad_outer_products(d_K_Xi) if compute_dK_dK_t else None
        return (
            K,
            d_K_Xi,
            dK_dK_t,
            pw_dists_sq,
        )


class IMQ(BaseKernel):
    """
    Inverse multi-quadric kernel with metric M (averaged over the batch and
    scaled by ``hessian_scale``):
    k(x, x') = (alpha + (x - y) M (x - y)^T ) ** beta
    """
    def __init__(
            self,
            alpha=1,
            beta=-0.5,
            hessian_scale=1,
            analytic_grad=True,
            median_heuristic=True,
            **kwargs,
    ):
        self.alpha = alpha
        self.beta = beta

        super().__init__(
            analytic_grad,
        )
        self.hessian_scale = hessian_scale
        self.median_heuristic = median_heuristic

    def eval(
            self,
            X, Y,
            M=None,
            compute_dK_dK_t=False,
            **kwargs,
    ):
        """See ``BaseKernel.eval``. ``M`` is required ([batch, dim, dim])."""
        assert X.shape == Y.shape

        # Empirical average of Hessian / Fisher matrices
        M = M.mean(dim=0) * self.hessian_scale

        X_M_Xt = X @ M @ X.t()
        X_M_Yt = X @ M @ Y.t()
        Y_M_Yt = Y @ M @ Y.t()

        pw_dists_sq = -2 * X_M_Yt + X_M_Xt.diag().unsqueeze(1) + Y_M_Yt.diag().unsqueeze(0)
        if self.median_heuristic:
            h = torch.median(pw_dists_sq).detach()
            h = h / np.log(X.shape[0])
        else:
            h = self.hessian_scale * X.shape[1]

        # Clamp bandwidth
        tol = 1e-5
        if isinstance(h, torch.Tensor):
            h = torch.clamp(h, min=tol)
        else:
            h = np.clip(h, a_min=tol, a_max=None)

        K = (self.alpha + pw_dists_sq) ** self.beta
        # NOTE(review): h scales the gradient term only, not K — confirm that
        # this is intended.
        d_K_Xi = self.beta * ((self.alpha + pw_dists_sq) ** (self.beta - 1)).unsqueeze(2) \
            * (-1 * (X.unsqueeze(1) - Y) @ M) * 2 / h

        # Used for SVN updates
        dK_dK_t = _grad_outer_products(d_K_Xi) if compute_dK_dK_t else None
        return (
            K,
            d_K_Xi,
            dK_dK_t,
            pw_dists_sq,
        )


class RBF_Anisotropic(RBF):
    """
    Anisotropic RBF kernel with metric M (averaged over the batch and scaled
    by ``hessian_scale``):
    k(x, x') = exp( - (x - y) M (x - y)^T / bandwidth )
    """
    def __init__(
            self,
            hessian_scale=1,
            analytic_grad=True,
            median_heuristic=False,
            **kwargs,
    ):
        # Pass by keyword: RBF.__init__'s first positional parameter is
        # `bandwidth`, so a positional analytic_grad would overwrite the
        # bandwidth and leave analytic_grad at its default. Keeping the
        # default bandwidth (-1) is what lets the median trick run.
        super().__init__(
            analytic_grad=analytic_grad,
        )
        self.hessian_scale = hessian_scale
        self.median_heuristic = median_heuristic

    def eval(
            self,
            X, Y,
            M=None,
            compute_dK_dK_t=False,
            bw=None,
            **kwargs,
    ):
        """
        See ``BaseKernel.eval``. ``M`` is required ([batch, dim, dim]);
        ``bw``, when given, replaces the bandwidth.
        """
        assert X.shape == Y.shape

        # Empirical average of Hessian / Fisher matrices
        M = M.mean(dim=0) * self.hessian_scale

        if not self.analytic_grad:
            raise NotImplementedError

        if self.median_heuristic:
            # NOTE(review): these distances are Euclidean (from the RBF
            # parent), not M-weighted — confirm this is intended.
            bandwidth, pw_dists_sq = self.compute_bandwidth(X, Y)
        else:
            bandwidth = self.hessian_scale
            X_M_Xt = X @ M @ X.t()
            X_M_Yt = X @ M @ Y.t()
            Y_M_Yt = Y @ M @ Y.t()
            pw_dists_sq = -2 * X_M_Yt + X_M_Xt.diag().unsqueeze(1) + Y_M_Yt.diag().unsqueeze(0)

        if bw is not None:
            bandwidth = bw

        K = (- pw_dists_sq / bandwidth).exp()
        d_K_Xi = K.unsqueeze(2) * ((X.unsqueeze(1) - Y) @ M) * 2 / bandwidth

        # Used for SVN updates
        dK_dK_t = _grad_outer_products(d_K_Xi) if compute_dK_dK_t else None
        return (
            K,
            d_K_Xi,
            dK_dK_t,
            pw_dists_sq,
        )


class Linear(BaseKernel):
    """
    k(x, x') = x^T x' + 1

    Optionally centred on the mean of X (``subtract_mean``) and divided by
    ``dim + 1`` (``with_scaling``).
    """
    def __init__(
            self,
            analytic_grad=True,
            subtract_mean=True,
            with_scaling=False,
            **kwargs,
    ):
        super().__init__(
            analytic_grad,
        )
        self.subtract_mean = subtract_mean
        self.with_scaling = with_scaling

    def eval(
            self,
            X, Y,
            M=None,
            compute_dK_dK_t=False,
            **kwargs,
    ):
        """
        See ``BaseKernel.eval``. ``M`` is ignored and the fourth returned
        element (pairwise distances) is always None.
        """
        assert X.shape == Y.shape
        batch, dim = X.shape

        if self.subtract_mean:
            mean = X.mean(0)
            X = X - mean
            Y = Y - mean

        if not self.analytic_grad:
            raise NotImplementedError

        K = X @ Y.t() + 1
        d_K_Xi = Y.repeat(batch, 1, 1)

        if self.with_scaling:
            K = K / (dim + 1)
            d_K_Xi = d_K_Xi / (dim + 1)

        # Used for SVN updates
        dK_dK_t = _grad_outer_products(d_K_Xi) if compute_dK_dK_t else None

        return (
            K,
            d_K_Xi,
            dK_dK_t,
            None,
        )
class CompositeKernel(ABC):
    """Interface for kernels assembled from a base kernel over control sequences."""

    def __init__(
        self,
        kernel=RBF,
        ctrl_dim=1,
        indep_controls=True,
        compute_hess_terms=False,
        **kargs,
    ):
        """
        Parameters
        ----------
        kernel : base kernel (instance, or class to be instantiated on use)
        ctrl_dim : size of one control vector
        indep_controls : treat every control dimension independently
        compute_hess_terms : stored as-is; not used in this class
        """
        self.ctrl_dim = ctrl_dim
        self.base_kernel = kernel
        self.indep_controls = indep_controls
        self.compute_hess_terms = compute_hess_terms

    @abstractmethod
    def eval(self, X, Y, M=None, **kwargs):
        """
        Parameters
        ----------
        X : tensor of shape [batch, dim]
        Y : tensor of shape [batch, dim]
        M : tensor of shape [dim, dim]
        kwargs :

        Returns
        -------

        """
        pass


class iid(CompositeKernel):
    """
    Composite kernel that applies the base kernel separately per time-step
    (and, with ``indep_controls``, per control dimension) and sums the
    resulting Gram matrices.
    """

    def eval(
        self,
        X, Y,
        M=None,
        compute_dK_dK_t=False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        X : tensor of shape [batch, horizon * ctrl_dim]
        Y : tensor of shape [batch, horizon * ctrl_dim]
        M : tensor of shape [horizon * ctrl_dim, horizon * ctrl_dim] (Optional)
        compute_dK_dK_t : fill in the outer products of kernel gradients

        Returns
        -------
        kernel_Xj_Xi : tensor of shape [batch, batch]
        d_kernel_Xi : tensor of shape [batch, batch, horizon * ctrl_dim]
        d_k_Xj_dk_Xj : tensor of shape
            [batch, batch, horizon * ctrl_dim, horizon * ctrl_dim]
            (zeros unless compute_dK_dK_t is set)
        """
        X = X.view(X.shape[0], -1, self.ctrl_dim)
        Y = Y.view(Y.shape[0], -1, self.ctrl_dim)

        # m: batch, h: horizon, d: ctrl_dim
        m, h, d = X.shape

        # The default `kernel` argument is a class; eval needs an instance.
        kernel = self.base_kernel
        if isinstance(kernel, type):
            kernel = kernel()

        # Work buffers follow X's dtype and device.
        # Keep another batch-dim for grad. mult. later on.
        kernel_Xj_Xi = X.new_zeros(m, m, h, d)
        d_kernel_Xi = X.new_zeros(m, m, h, d)
        d_k_Xj_dk_Xj = X.new_zeros(m, m, h, d, h, d)

        if M is not None:
            M = M.view(h, d, h, d).transpose(1, 2)  # shape: (h, h, d, d)

        if self.indep_controls:
            for i in range(h):
                for q in range(self.ctrl_dim):
                    # Scalar metric entry for this control, as a (1, 1) block.
                    M_ii = None if M is None else M[i, i, q, q].reshape(1, 1)
                    # Base kernels may return extra trailing values (e.g.
                    # pairwise distances); only the first three are used.
                    k_tmp, dk_tmp, dk_dk_t_tmp = kernel.eval(
                        X[:, i, q].reshape(-1, 1),
                        Y[:, i, q].reshape(-1, 1),
                        M_ii,
                        compute_dK_dK_t=compute_dK_dK_t,
                    )[:3]
                    kernel_Xj_Xi[:, :, i, q] += k_tmp
                    d_kernel_Xi[:, :, i, q] += dk_tmp.squeeze(2)
                    if compute_dK_dK_t:
                        d_k_Xj_dk_Xj[:, :, i, q, i, q] += dk_dk_t_tmp.view(m, m)
        else:
            for i in range(h):
                M_ii = None if M is None else M[i, i, :, :]
                k_tmp, dk_tmp, dk_dk_t_tmp = kernel.eval(
                    X[:, i, :],
                    Y[:, i, :],
                    M_ii,
                    compute_dK_dK_t=compute_dK_dK_t,
                )[:3]
                kernel_Xj_Xi[:, :, i, :] += k_tmp.unsqueeze(2)
                d_kernel_Xi[:, :, i, :] += dk_tmp
                if compute_dK_dK_t:
                    d_k_Xj_dk_Xj[:, :, i, :, i, :] += dk_dk_t_tmp

        kernel_Xj_Xi = kernel_Xj_Xi.reshape(m, m, h * d).sum(2)

        d_kernel_Xi = d_kernel_Xi.reshape(m, m, h * d)
        d_k_Xj_dk_Xj = d_k_Xj_dk_Xj.reshape(m, m, h * d, h * d)

        return kernel_Xj_Xi, d_kernel_Xi, d_k_Xj_dk_Xj
import torch
import numpy as np
from abc import ABC, abstractmethod


def _symmetrize(M):
    """Return the symmetric part of a tensor whose last two axes are square."""
    return 0.5 * (M + M.transpose(-2, -1))


def _median_bandwidth(pairwise_dists_sq, fallback):
    """Median-heuristic bandwidth: median(dists) / log(num_particles).

    With a single particle ``log(1) == 0`` would turn the heuristic into
    ``0 / 0``; in that case ``fallback`` is returned instead.
    """
    num_particles = pairwise_dists_sq.shape[0]
    if num_particles < 2:
        return fallback
    h = torch.median(pairwise_dists_sq).detach()
    return h / np.log(num_particles)


class BaseMatrixKernel(ABC):
    """Interface for the matrix-valued kernels in this module.

    The constructor only stores its two flags; subclasses decide how to use
    them inside :meth:`eval`.
    """

    def __init__(
        self,
        analytic_grad=True,
        median_heuristic=False,
    ):
        self.analytic_grad = analytic_grad
        self.median_heuristic = median_heuristic

    @abstractmethod
    def eval(self, X, Y, M=None, **kwargs):
        """
        Evaluate kernel function and corresponding gradient terms for batch of inputs.

        Parameters
        ----------
        X : Tensor
            Data, of shape [batch, dim]
        Y : Tensor
            Data, of shape [batch, dim]
        M : Tensor (Optional)
            Metric, of shape [batch, dim, dim]
        kwargs : dict
            Kernel-specific parameters

        Returns
        -------
        K: Tensor
            Kernel Gram matrix which is pre-conditioned by the inverse metric M^-1.
            Of shape [batch, batch, dim, dim].
        d_K_Xi: Tensor
            Preconditioned kernel gradient terms. Shape: [batch, batch, dim]
        """
        pass


class RBF_Matrix(BaseMatrixKernel):
    """
    RBF Matrix-valued kernel, with averaged Hessian/Fisher metric M.
    Similar to the RBF_Anisotropic kernel, but preconditioned with inverse of M.

        k(x, y) = M^-1 exp( - (x - y) M (x - y)^T / (2 * h))

    where ``h`` is ``hessian_scale`` or, with ``median_heuristic``, the
    median of the pairwise distances divided by ``log(batch)``.
    """
    def __init__(
        self,
        hessian_scale=1,
        analytic_grad=True,
        median_heuristic=False,
        **kwargs,
    ):
        super().__init__(
            analytic_grad,
            median_heuristic,
        )
        self.hessian_scale = hessian_scale

    def eval(self, X, Y, M=None, **kwargs):
        """
        Parameters
        ----------
        X, Y : Tensor
            Particles, both of shape [batch, dim].
        M : Tensor
            Per-particle metrics, of shape [batch, dim, dim]. Required.

        Returns
        -------
        K : Tensor
            M^-1 * k(X_i, Y_j), of shape [batch, batch, dim, dim].
        d_K_Xi : Tensor
            M^-1 applied to the gradient of k(X_i, Y_j) with respect to Y_j
            (i.e. the negative of the gradient with respect to X_i).
            Shape [batch, batch, dim].

        Raises
        ------
        ValueError
            If ``M`` is None.
        """
        assert X.shape == Y.shape
        if M is None:
            raise ValueError('RBF_Matrix.eval requires a metric tensor M.')
        b, dim = X.shape

        # Empirical average of Hessian / Fisher matrices
        M = M.mean(dim=0)

        # PSD stabilization: the symmetric part is what is actually used, so
        # that the quadratic form, the gradient and the inverse all agree.
        M = _symmetrize(M) * self.hessian_scale

        diff = X.unsqueeze(1) - Y                   # (b, b, dim)
        diff_M = diff @ M                           # (b, b, dim)
        pairwise_dists_sq = (diff_M * diff).sum(-1)  # (b, b)

        if self.median_heuristic:
            h = _median_bandwidth(pairwise_dists_sq, self.hessian_scale)
        else:
            h = self.hessian_scale

        K = (- 0.5 * pairwise_dists_sq / h).exp()
        # Derivative of exp(-0.5 * d / h): the 0.5 cancels the 2 coming from
        # the quadratic form, leaving a 1 / h factor.
        d_K_Xi = K.unsqueeze(2) * diff_M / h

        ## Matrix preconditioning
        M_inv = torch.inverse(M)
        K = M_inv * K.reshape(b, b, 1, 1)
        d_K_Xi = (M_inv @ d_K_Xi.unsqueeze(-1)).squeeze(-1)
        return (
            K,
            d_K_Xi,
        )


class IMQ_Matrix(BaseMatrixKernel):
    """
    IMQ Matrix-valued kernel, with metric M.
        k(x, y) = M^-1 (alpha + (x - y) M (x - y)^T ) ** beta
    """
    def __init__(
        self,
        alpha=1,
        beta=-0.5,
        hessian_scale=1,
        analytic_grad=True,
        median_heuristic=False,
        **kwargs,
    ):

        self.alpha = alpha
        self.beta = beta

        super().__init__(
            analytic_grad,
            median_heuristic,
        )
        self.hessian_scale = hessian_scale

    def eval(self, X, Y, M=None, **kwargs):
        """
        Parameters
        ----------
        X, Y : Tensor
            Particles, both of shape [batch, dim].
        M : Tensor
            Per-particle metrics, of shape [batch, dim, dim]. Required.

        Returns
        -------
        K : Tensor
            M^-1 * k(X_i, Y_j), of shape [batch, batch, dim, dim].
        d_K_Xi : Tensor
            Preconditioned gradient term, of shape [batch, batch, dim].

        Raises
        ------
        ValueError
            If ``M`` is None.
        """
        assert X.shape == Y.shape
        if M is None:
            raise ValueError('IMQ_Matrix.eval requires a metric tensor M.')
        b, dim = X.shape

        # Empirical average of Hessian / Fisher matrices
        M = M.mean(dim=0)

        # PSD stabilization (symmetric part is the metric actually used)
        M = _symmetrize(M) * self.hessian_scale

        diff = X.unsqueeze(1) - Y                   # (b, b, dim)
        diff_M = diff @ M                           # (b, b, dim)
        pairwise_dists_sq = (diff_M * diff).sum(-1)  # (b, b)

        if self.median_heuristic:
            h = _median_bandwidth(pairwise_dists_sq, self.hessian_scale * dim)
        else:
            h = self.hessian_scale * dim

        base = self.alpha + pairwise_dists_sq
        K = base ** self.beta
        # NOTE(review): K above contains no bandwidth, yet the gradient is
        # divided by h. The original scaling is kept as-is; confirm whether
        # the kernel should read (alpha + d / h) ** beta or the 1 / h should go.
        d_K_Xi = self.beta * (base ** (self.beta - 1)).unsqueeze(2) \
                 * (-1. * diff_M) * 2 / h

        ## Matrix preconditioning
        M_inv = torch.inverse(M)
        K = M_inv * K.reshape(b, b, 1, 1)
        d_K_Xi = (M_inv @ d_K_Xi.unsqueeze(-1)).squeeze(-1)
        return (
            K,
            d_K_Xi,
        )


class RBF_Weighted_Matrix(BaseMatrixKernel):
    """Mixture of per-particle preconditioned RBF kernels.

    One RBF component is built per metric ``M_l`` and the components are
    combined with softmax weights computed in :meth:`get_mix_weights`.
    """

    def __init__(
        self,
        hessian_scale=1,
        analytic_grad=True,
        alpha=0.5,
        **kwargs,
    ):
        super().__init__(
            analytic_grad,
        )
        self.hessian_scale = hessian_scale
        # Interpolation factor between each metric and the batch average.
        self.alpha = alpha

    def get_mix_weights(self, X, M, pw_dist_sq):
        """
        Finds the Gaussian mixture weights used by the weighted Kernel.

        Parameters
        ----------
        X : Tensor
            Input values, of shape [batch, dim]. Must be part of the autograd
            graph that produced ``pw_dist_sq``.
        M : Tensor
            Metric tensor, of shape [batch, dim, dim]
        pw_dist_sq : Tensor
            Pair-wise distances for each metric tensor, of shape [batch, batch, batch]
            Last dimension corresponds to each metric.

        Returns
        -------
        mix_weights: Tensor
            Gaussian mixture weights, of shape [batch, batch]
        mix_dlog_w: Tensor
            For each row i, the gradient of ``mix_weights[i].sum()`` with
            respect to X. Shape [batch, batch, dim].
        """

        # Get pw_dists for z_el and corresponding M_el
        pw_dist_sq_el = torch.diagonal(pw_dist_sq, dim1=1, dim2=2)
        mix_weights = torch.softmax(- pw_dist_sq_el - torch.logdet(M), dim=0)

        #TODO: parallelize this somehow?
        mix_dlog_w = torch.stack(
            [
                torch.autograd.grad(
                    mix_weights[i].sum(),
                    X,
                    retain_graph=True,
                )[0]
                for i in range(mix_weights.shape[0])
            ],
            dim=0,
        )

        return mix_weights, mix_dlog_w

    def eval(self, X, Y, M=None, **kwargs):
        """
        Parameters
        ----------
        X : Tensor
            Particles of shape [batch, dim]; must require grad because the
            mixture-weight gradients are obtained with autograd.
        Y : Tensor
            Particles of shape [batch, dim].
        M : Tensor
            Per-particle metrics, of shape [batch, dim, dim]. Required.

        Returns
        -------
        K : Tensor of shape [batch, batch, dim, dim]
        d_K_Xi : Tensor of shape [batch, batch, dim]

        Raises
        ------
        ValueError
            If ``M`` is None.
        """
        assert X.shape == Y.shape
        if M is None:
            raise ValueError('RBF_Weighted_Matrix.eval requires a metric tensor M.')

        M = _symmetrize(M)  # PSD stabilization

        M = M * self.hessian_scale  # M of shape (batch, dim, dim)

        # Mix w/ average Hessian for robustness (from Wang et al. 2019 implementation)
        M = (1 - self.alpha) * M.mean(0) + self.alpha * M

        bandwidth = self.hessian_scale * X.shape[1]

        b, dim = X.shape
        diff_XY = X.unsqueeze(1) - Y  # (b, b, dim)
        diff_XY = diff_XY.reshape(b, b, 1, 1, dim)
        diff_XY_M = diff_XY @ M  # (b, b, b, 1, dim); third axis indexes the metric
        pairwise_dists_sq = diff_XY_M @ diff_XY.transpose(-2, -1)  # (b, b, b, 1, 1)

        # Mixture weights
        w, dlog_w = self.get_mix_weights(X, M, pairwise_dists_sq.reshape(b, b, b))
        w = w.unsqueeze(-1)
        w_w_T = w.bmm(w.transpose(1, 2))  # (b, b, b)

        # Mixture Grammians and gradients
        M_inv = torch.inverse(M)
        K_el = (- pairwise_dists_sq / bandwidth).exp()
        K_el = M_inv * K_el  # (b, b, b, d, d)
        d_K_Xi_el = (diff_XY_M @ K_el) * 2 / bandwidth

        # Full Kernel Grammian
        w_w_T = w_w_T.reshape(b, b, b, 1, 1)
        K = (w_w_T * K_el).sum(dim=2)

        # Full Kernel gradient
        dlog_w_K_el = torch.einsum('abcde, bce -> abcd', K_el, dlog_w)
        w_w_T = w_w_T.reshape(b, b, b, 1)
        dlog_w_K_el = dlog_w_K_el.reshape(b, b, b, dim)
        d_K_Xi_el = d_K_Xi_el.reshape(b, b, b, dim)
        d_K_Xi = (w_w_T * (dlog_w_K_el + d_K_Xi_el)).sum(dim=2)

        return (
            K,
            d_K_Xi,
        )
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020-2021 Alexander Lambert 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining 5 | a copy of this software and associated documentation files (the 6 | "Software"), to deal in the Software without restriction, including 7 | without limitation the rights to use, copy, modify, merge, publish, 8 | distribute, sublicense, and/or sell copies of the Software, and to 9 | permit persons to whom the Software is furnished to do so, subject to 10 | the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 13 | included in all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 17 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 19 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 20 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 21 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import torch
import numpy as np
from torch.distributions.normal import Normal
from torch.distributions.independent import Independent
from torch.distributions.categorical import Categorical
from torch.distributions.mixture_same_family import MixtureSameFamily


class RBF_kernel:
    """
    Scalar RBF kernel under a single metric M (no preconditioning):

        k(x, y) = exp( - (x - y) M (x - y)^T / (2 * h))

    where M is first multiplied by ``hessian_scale`` and ``h`` is either the
    median of the pairwise distances divided by ``log(batch)`` (when
    ``median_heuristic`` is set) or ``hessian_scale * dim``.
    """
    def __init__(
        self,
        hessian_scale=1,
        median_heuristic=True,
        **kwargs,
    ):

        self.hessian_scale = hessian_scale
        self.median_heuristic = median_heuristic

    def eval(self, X, Y, M, **kwargs):
        """
        Parameters
        ----------
        X : Tensor of shape [batch, dim],
        Y : Tensor of shape [batch, dim],
        M : Tensor of shape [dim, dim]. Left unmodified.

        Returns
        -------
        K : Tensor of shape [batch, batch]
            Gram matrix, K[i, j] = k(X_i, Y_j).
        d_K_Xi : Tensor of shape [batch, batch, dim]
            K[i, j] * (X_i - Y_j) M / h.
        """
        assert X.shape == Y.shape

        # Scale out of place: an in-place `M *= ...` would rescale the tensor
        # (or view) owned by the caller on every call.
        M = M * self.hessian_scale

        b, d = X.shape
        diff_XY = (X.unsqueeze(1) - Y).reshape(b, b, 1, d)
        diff_XY_M = diff_XY @ M  # (b, b, 1, d)

        pw_dists_sq = (diff_XY_M @ diff_XY.reshape(b, b, d, 1)).reshape(b, b)

        if self.median_heuristic:
            h = torch.median(pw_dists_sq).detach()
            h = h / np.log(X.shape[0])
        else:
            h = self.hessian_scale * d

        K = (- 0.5 * pw_dists_sq / h).exp()

        d_K_Xi = K.unsqueeze(2) * diff_XY_M.reshape(b, b, d) / h
        return (
            K,
            d_K_Xi,
        )
import torch
from stein_lib.svgd.svgd import SVGD


class MatrixMixtureSVGD(SVGD):
    """SVGD whose update is a weighted mixture of per-metric preconditioned terms.

    For every particle metric ``H_i`` a scalar-kernel SVGD step is computed
    and preconditioned with ``H_i^-1``; the steps are combined with the
    weights returned by :meth:`get_weights`.
    """

    def __init__(
        self,
        kernel_base_type='RBF_Matrix',
        **kwargs,
    ):
        super().__init__(kernel_base_type, **kwargs)

    def get_base_kernel(
        self,
        **kernel_params,
    ):
        """Return the base kernel; only 'RBF_Matrix' (no preconditioning) is supported."""
        if self.kernel_base_type == 'RBF_Matrix':  # Does not precondition
            return RBF_kernel(
                **kernel_params,
            )
        else:
            raise IOError('Weighted-Matrix-SVGD kernel type not recognized: ',
                          self.kernel_base_type)

    def get_kernel(
        self,
        **kernel_params,
    ):
        """Return the base kernel itself; no composite structure is supported."""
        if self.kernel_structure is None:
            return self.base_kernel
        else:
            raise IOError('Kernel structure not recognized for matrix-SVGD: ',
                          self.kernel_structure,)

    def get_pairwise_dists_sq(self, X, Y, M):
        """
        Get batched pairwise-distances squared.

        Parameters
        ----------
        X : Tensor of shape [batch, dim],
        Y : Tensor of shape [batch, dim],
        M : Tensor of shape [batch, dim, dim],

        Returns
        -------
        pw_dists_sq : Tensor of shape [batch, batch]
            Entry [i, j] is (X_i - Y_j) M_j (X_i - Y_j)^T.
        diff_XY_M : Tensor of shape [batch, batch, dim]
            Entry [i, j] is (X_i - Y_j) M_j.
        """
        b, d = X.shape
        diff_XY = X.unsqueeze(1) - Y  # b x b x d
        diff_XY = diff_XY.reshape(b, b, 1, d)
        diff_XY_M = diff_XY @ M  # (b, b, 1, d); M broadcasts along the second axis
        pw_dists_sq = diff_XY_M @ diff_XY.reshape(b, b, d, 1)
        # Explicit reshape instead of squeeze(): squeeze() would also drop
        # the batch axes when there is a single particle.
        return pw_dists_sq.reshape(b, b), diff_XY_M.reshape(b, b, d)

    def get_weights(
        self,
        X,
        H,
        H_diff,
        pw_dists_sq,
    ):
        """
        Compute mixture weights and the weighted metric-difference term.

        Parameters
        ----------
        X : Tensor of shape [batch, dim] (unused here).
        H : Tensor of shape [batch, dim, dim]; only its log-determinant is used.
        H_diff : Tensor of shape [batch, batch, dim], second output of
            :meth:`get_pairwise_dists_sq`.
        pw_dists_sq : Tensor of shape [batch, batch].

        Returns
        -------
        w : Tensor of shape [batch, batch], normalised along dim 0.
        dlog_w : Tensor of shape [batch, batch, dim].
        """
        # TODO: neg. definite H
        # Subtracting the column-wise minimum keeps the exponent bounded.
        ww = torch.exp(- 0.5 * (pw_dists_sq - pw_dists_sq.min(0).values - torch.logdet(H)))
        w = ww / torch.sum(ww, dim=0)
        dlog_w = torch.sum((H_diff[:, None, :, :] - H_diff[None, :, :, :]) * ww[None, :, :, None], dim=1)
        dlog_w = dlog_w / torch.sum(ww, dim=0)[None, :, None]
        return w, dlog_w

    def weighted_Hessian_SVGD(self, X, dlog_p, Hess, H_inv, w):
        """
        SVGD velocity for one metric, weighted by ``w`` and preconditioned by ``H_inv``.

        Parameters
        ----------
        X : Tensor of shape [batch, dim]
        dlog_p : Tensor of shape [batch, dim]
        Hess : Tensor of shape [dim, dim], metric passed to the kernel.
        H_inv : Tensor of shape [dim, dim], applied to the summed terms.
        w : Tensor of shape [batch]

        Returns
        -------
        velocity : Tensor of shape [batch, dim]
        """
        k_XX, grad_k = self.kernel.eval(X, X, Hess)

        velocity = torch.sum(
            w[None, :, None] * k_XX[:, :, None] * dlog_p[None, :, :],
            dim=1
        ) + torch.sum(
            w[:, None, None] * grad_k,
            dim=0
        )
        velocity = velocity @ H_inv
        return velocity

    def get_update(
        self,
        X,
        dlog_p,
        Hess,
        Hess_prior=None,
        alpha=0.,
    ):

        """
        Handle matrix-valued SVGD terms.

        Parameters
        ----------
        X : Stein particles. Tensor of shape [batch, dim],
        dlog_p : tensor of shape [batch, dim]
        Hess : Negative Hessian or Fisher matrices. Tensor of shape [batch, dim, dim]
        Hess_prior : unused in this method.
        alpha : interpolation factor between each metric (alpha) and the
            batch-average metric (1 - alpha).

        Returns
        -------
        velocity: tensor of shape [batch, dim]
        """

        b, d = X.shape

        Hess_avg = Hess.mean(0)
        Hess = alpha * Hess + (1 - alpha) * Hess_avg  # for 'robustness'

        pw_dists_sq, H_diff = self.get_pairwise_dists_sq(X, X, Hess)

        w, dlog_w = self.get_weights(X, Hess, H_diff, pw_dists_sq)

        H_inv = torch.inverse(Hess)  # b, d, d

        velocity = torch.zeros_like(X)

        for i in range(b):
            velocity += w[i, :, None] * self.weighted_Hessian_SVGD(
                X,
                dlog_p + dlog_w[i, :, :],
                Hess[i, :, :],
                H_inv[i, :, :],
                w[i, :],
            )
        return velocity

    def phi(
        self,
        X,
        dlog_p,
        dlog_lh,
        Hess=None,
        Hess_prior=None,
        reshape_inputs=True,
        transpose=False,
    ):
        """
        Parameters
        ----------
        X : (Tensor)
            Stein particles, of shape [num_particles, dim],
            or of shape [dim, num_particles]. If Tensor dimension is greater than 2,
            extra dimensions will be flattened.
        dlog_p : (Tensor)
            Score function, of shape [num_particles, dim]
            or of shape [dim, num_particles]. If Tensor dimension is greater than 2,
            extra dimensions will be flattened.
        dlog_lh : (Tensor)
            Likelihood score; used by the 'fisher' and 'riemannian' metrics.
        Hess : (Tensor, optional)
            Required for the 'full_hessian' metric.
        Hess_prior : (Tensor, optional)
            Required for the 'riemannian' metric.
        reshape_inputs : Bool
            Pass the inputs through ``self.reshape_inputs`` and restore the
            original shape of the output.
        transpose: Bool
            Transpose input and output Tensors.

        Returns
        -------
        Phi: (Tensor)
            Empirical Stein gradient, of shape [num_particles, dim]

        Raises
        ------
        NotImplementedError
            If ``self.geom_metric_type`` is None or not one of
            'full_hessian', 'fisher', 'jacobian_product', 'riemannian'.
        """

        shape_original = X.shape
        if reshape_inputs:
            X, dlog_p, Hess = self.reshape_inputs(
                X,
                dlog_p,
                Hess,
                transpose,
            )

        metric_type = self.geom_metric_type
        if metric_type == 'full_hessian':
            assert Hess is not None
            M = - Hess
        elif metric_type == 'fisher':
            ## Average Fisher matrix (likelihood only)
            n_particles = dlog_lh.shape[0]
            M = torch.bmm(
                dlog_lh.reshape(n_particles, -1, 1),
                dlog_lh.reshape(n_particles, 1, -1),
            )
            M += 1.e-6 * torch.eye(M.shape[1], M.shape[2], dtype=M.dtype, device=M.device)
        elif metric_type == 'jacobian_product':
            ## Average Fisher matrix (full posterior gradient)
            dim = dlog_p.shape[-1]
            M = torch.bmm(dlog_p.view(-1, dim, 1), dlog_p.view(-1, 1, dim))
            M += 1.e-6 * torch.eye(M.shape[1], M.shape[2], dtype=M.dtype, device=M.device)
        elif metric_type == 'riemannian':
            # Average Fisher matrix plus neg. Hessian of log prior
            assert Hess_prior is not None
            b = dlog_lh.shape[0]
            F = torch.bmm(dlog_lh.view(b, -1, 1), dlog_lh.view(b, 1, -1))
            Hess_prior = Hess_prior.reshape(
                dlog_p.shape[0],
                dlog_p.shape[1],
                dlog_p.shape[1],
            )
            M = F - Hess_prior
        else:
            # Includes metric_type None: this update cannot run without a metric.
            raise NotImplementedError(
                'Unsupported geom_metric_type: {!r}'.format(metric_type)
            )

        phi = self.get_update(
            X,
            dlog_p,
            M,
        )
        # Reshape Phi to match original input tensor dimensions
        if reshape_inputs:
            if transpose:
                phi = phi.t()
            phi = phi.reshape(shape_original)

        return phi
from stein_lib.svgd.svgd import SVGD
from stein_lib.svgd.matrix_svgd.base_kernels import (
    RBF_Matrix,
    IMQ_Matrix,
    RBF_Weighted_Matrix,
)


class MatrixSVGD(SVGD):
    """SVGD driven by a matrix-valued base kernel."""

    def __init__(
        self,
        kernel_base_type='RBF_Matrix',
        **kwargs,
    ):
        super().__init__(kernel_base_type, **kwargs)

    def get_base_kernel(
        self,
        **kernel_params,
    ):
        """Build the kernel named by ``self.kernel_base_type``.

        Raises IOError when the name is not one of the known kernel types.
        """
        known_kernels = (
            ('RBF_Matrix', RBF_Matrix),
            ('IMQ_Matrix', IMQ_Matrix),
            ('RBF_Weighted_Matrix', RBF_Weighted_Matrix),
        )
        for type_name, kernel_cls in known_kernels:
            if self.kernel_base_type == type_name:
                return kernel_cls(**kernel_params)
        raise IOError('Matrix-SVGD kernel type not recognized: ',
                      self.kernel_base_type)

    def get_kernel(
        self,
        **kernel_params,
    ):
        """Return the base kernel; any non-None kernel structure is rejected."""
        if self.kernel_structure is not None:
            raise IOError('Kernel structure not recognized for matrix-SVGD: ',
                          self.kernel_structure,)
        return self.base_kernel

    def get_svgd_terms(
        self,
        X,
        dlog_p,
        M=None,
    ):
        """
        Handle matrix-valued SVGD terms.

        Parameters
        ----------
        X : Stein particles. Tensor of shape [batch, dim],
        dlog_p : tensor of shape [batch, dim]
        M : (Optional) Negative Hessian or Fisher matrices. Tensor of shape [batch, dim, dim]

        Returns
        -------
        gradient: tensor of shape [batch, dim]
        repulsive: tensor of shape [batch, dim]
        """
        gram, gram_grad = self.evaluate_kernel(X, M)

        n_particles, n_dim = dlog_p.shape
        # Column-vector view so the [dim, dim] kernel blocks can multiply it.
        score = dlog_p.reshape(1, n_particles, n_dim, 1)

        # The Gram tensor is detached: no gradient flows through it here.
        weighted_score = gram.detach() @ score
        gradient = weighted_score.mean(1).squeeze(-1)
        repulsive = gram_grad.mean(1)
        return gradient, repulsive

    def evaluate_kernel(self, X, M=None):
        """
        Evaluate the kernel on X against a detached copy of itself.

        Parameters
        ----------
        X : tensor. Stein particles, of shape [batch, dim],
        M : (Optional) Negative Hessian or Fisher matrices. Tensor of shape [batch, dim, dim]

        Returns
        -------
        k_XX : tensor of shape [batch, batch, dim, dim]
        grad_k : tensor of shape [batch, batch, dim]
        """
        detached_copy = X.clone().detach()
        k_XX, grad_k = self.kernel.eval(X, detached_copy, M)
        return k_XX, grad_k
import torch
from abc import ABC, abstractmethod


def _markov_clique(i, horizon):
    """Return (clique, position of i in clique) for node i and its neighbours.

    The clique is node ``i`` plus whichever of ``i - 1`` / ``i + 1`` exist,
    so a horizon of 1 yields the single-node clique ``[0]``.
    """
    clique = [j for j in (i - 1, i, i + 1) if 0 <= j < horizon]
    return clique, clique.index(i)


class CompositeMatrixKernel(ABC):
    """Base class for structured kernels built from a matrix-valued base kernel.

    Parameters
    ----------
    kernel : kernel instance, kernel class, or None
        Object exposing ``eval(X, Y, M)``. A class is instantiated with no
        arguments; None selects a default ``RBF_Matrix()``.
    ctrl_dim : int
        Size of the last axis after inputs are viewed as [batch, -1, ctrl_dim].
    indep_controls : bool
        Evaluate the base kernel separately for every control dimension.
    compute_hess_terms : bool
        Stored only; not read in this module.
    """

    def __init__(
        self,
        kernel=None,
        ctrl_dim=1,
        indep_controls=True,
        compute_hess_terms=False,
        **kargs,
    ):
        if kernel is None:
            # Imported lazily so the default is an instance, not the class.
            from .base_kernels import RBF_Matrix
            kernel = RBF_Matrix()
        elif isinstance(kernel, type):
            # `base_kernel.eval(...)` needs a bound method.
            kernel = kernel()

        self.ctrl_dim = ctrl_dim
        self.base_kernel = kernel
        self.indep_controls = indep_controls
        self.compute_hess_terms = compute_hess_terms

    @abstractmethod
    def eval_svg_terms(self, X, Y, dlog_p, M=None, **kwargs):
        """
        Evaluates svgd gradient and repulsive terms.

        Parameters
        ----------
        X : tensor of shape [batch, dim]
        Y : tensor of shape [batch, dim]
        dlog_p : tensor of shape [batch, dim]
        M : tensor of shape [batch, dim, dim]
        kwargs :

        Returns
        -------
        k_Xj_Xi_dlog_p: Tensor
            Kernel Grammian applied to the score, with leading shape
            [batch, batch, dim].
        d_kernel_Xi: tensor of shape [batch, batch, dim]
        """
        pass


class matrix_iid_mp(CompositeMatrixKernel):
    """Block-diagonal kernel: every time step (and optionally every control
    dimension) gets its own independent base-kernel evaluation."""

    def eval_svg_terms(
        self,
        X, Y,
        dlog_p,
        M=None,
        **kwargs,
    ):
        """
        Returns
        -------
        k_Xj_Xi_dlog_p : Tensor of shape [batch, batch, h * d, 1]
        d_kernel_Xi : Tensor of shape [batch, batch, h * d]
        """
        X = X.view(X.shape[0], -1, self.ctrl_dim)
        Y = Y.view(Y.shape[0], -1, self.ctrl_dim)

        # m: batch, h: horizon, d: ctrl_dim
        m, h, d = X.shape

        # Buffers follow the dtype / device of the particles.
        kernel_Xj_Xi = X.new_zeros(m, m, h, d, h, d)
        d_kernel_Xi = X.new_zeros(m, m, h, d)

        if M is not None:
            M = M.view(m, h, d, h, d)

        if self.indep_controls:
            for i in range(h):
                for q in range(d):
                    M_ii = None
                    if M is not None:
                        M_ii = M[:, i, q, i, q].reshape(-1, 1, 1)
                    k_tmp, dk_tmp = self.base_kernel.eval(
                        X[:, i, q].reshape(-1, 1),
                        Y[:, i, q].reshape(-1, 1),
                        M_ii,
                    )
                    kernel_Xj_Xi[:, :, i, q, i, q] += k_tmp.reshape(m, m)
                    d_kernel_Xi[:, :, i, q] += dk_tmp.squeeze(2)
        else:
            for i in range(h):
                M_ii = None
                if M is not None:
                    M_ii = M[:, i, :, i, :]
                k_tmp, dk_tmp = self.base_kernel.eval(
                    X[:, i, :],
                    Y[:, i, :],
                    M_ii,
                )
                kernel_Xj_Xi[:, :, i, :, i, :] += k_tmp
                d_kernel_Xi[:, :, i, :] += dk_tmp

        kernel_Xj_Xi = kernel_Xj_Xi.reshape(m, m, h * d, h * d)
        d_kernel_Xi = d_kernel_Xi.reshape(m, m, h * d)

        dlog_p = dlog_p.reshape(1, m, h * d, 1)
        k_Xj_Xi_dlog_p = kernel_Xj_Xi.detach() @ dlog_p

        return k_Xj_Xi_dlog_p, d_kernel_Xi


class matrix_first_order_mp(CompositeMatrixKernel):
    """Kernel evaluated on each node together with its direct neighbours;
    only the rows belonging to the node itself are kept."""

    def eval_svg_terms(
        self,
        X, Y,
        dlog_p,
        M=None,
        **kwargs,
    ):
        """
        Returns
        -------
        k_Xj_Xi_dlog_p : Tensor of shape [batch, batch, h * d]
        d_kernel_Xi : Tensor of shape [batch, batch, h * d]
        """
        X = X.view(X.shape[0], -1, self.ctrl_dim)
        Y = Y.view(Y.shape[0], -1, self.ctrl_dim)

        # m: batch, h: horizon, d: ctrl_dim
        m, h, d = X.shape

        k_Xj_Xi_dlog_p = X.new_zeros(m, m, h, d)
        d_kernel_Xi = X.new_zeros(m, m, h, d)
        dlog_p = dlog_p.reshape(1, m, h, d)

        if M is not None:
            M = M.view(m, h, d, h, d)

        for i in range(h):

            # clique : i-th node + Markov blanket
            # cl_i : index of i-th node in clique
            clique, cl_i = _markov_clique(i, h)
            num = len(clique)

            if self.indep_controls:
                for q in range(d):
                    M_i = None
                    if M is not None:
                        M_i = M[:, clique, q, :, q][:, :, clique]
                        M_i = M_i.reshape(-1, num, num)
                    k_tmp, dk_tmp = self.base_kernel.eval(
                        X[:, clique, q].reshape(-1, num),
                        Y[:, clique, q].reshape(-1, num),
                        M_i,
                    )
                    k_dlog_p = (k_tmp @ dlog_p[:, :, clique, q].unsqueeze(-1)).squeeze(-1)
                    k_Xj_Xi_dlog_p[:, :, i, q] += k_dlog_p[:, :, cl_i]
                    d_kernel_Xi[:, :, i, q] += dk_tmp[:, :, cl_i]
            else:
                M_i = None
                if M is not None:
                    M_i = M[:, clique, :, :, :][:, :, :, clique, :]
                    M_i = M_i.reshape(-1, num * d, num * d)
                k_tmp, dk_tmp = self.base_kernel.eval(
                    X[:, clique, :].reshape(-1, num * d),
                    Y[:, clique, :].reshape(-1, num * d),
                    M_i,
                )
                dlog_p_cl = dlog_p[:, :, clique, :].reshape(1, m, num * d, 1)
                k_dlog_p = (k_tmp @ dlog_p_cl).reshape(m, m, num, d)
                dk_tmp = dk_tmp.reshape(m, m, num, d)
                k_Xj_Xi_dlog_p[:, :, i, :] += k_dlog_p[:, :, cl_i, :]
                d_kernel_Xi[:, :, i, :] += dk_tmp[:, :, cl_i, :]

        k_Xj_Xi_dlog_p = k_Xj_Xi_dlog_p.reshape(m, m, h * d)
        d_kernel_Xi = d_kernel_Xi.reshape(m, m, h * d)

        return k_Xj_Xi_dlog_p, d_kernel_Xi
charge, to any person obtaining 5 | a copy of this software and associated documentation files (the 6 | "Software"), to deal in the Software without restriction, including 7 | without limitation the rights to use, copy, modify, merge, publish, 8 | distribute, sublicense, and/or sell copies of the Software, and to 9 | permit persons to whom the Software is furnished to do so, subject to 10 | the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 13 | included in all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 17 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 19 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 20 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 21 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import torch
import numpy as np
# The package in this repository is `stein_lib`; the previous `svmpc_np`
# prefix does not exist in the tree.
from stein_lib.svgd.matrix_svgd.matrix_svgd import MatrixSVGD
from stein_lib.svgd.base_kernels import (
    RBF_Anisotropic,
)
from stein_lib.svgd.matrix_svgd.base_kernels import (
    RBF_Matrix,
    RBF_Weighted_Matrix,
)
from stein_lib.svgd.matrix_svgd.mp_composite_kernels import (
    matrix_iid_mp,
    matrix_first_order_mp,
)


class MP_MatrixSVGD(MatrixSVGD):
    """Matrix-SVGD whose kernel is one of the structured composite kernels."""

    def __init__(
        self,
        kernel_base_type='RBF_Matrix',
        **kwargs,
    ):
        super().__init__(kernel_base_type, **kwargs)

    def get_base_kernel(
        self,
        **kernel_params,
    ):
        """Build the base kernel named by ``self.kernel_base_type``.

        Raises IOError for any name other than 'RBF_Matrix' or
        'RBF_Weighted_Matrix'.
        """
        if self.kernel_base_type == 'RBF_Matrix':
            return RBF_Matrix(
                **kernel_params,
            )
        elif self.kernel_base_type == 'RBF_Weighted_Matrix':
            return RBF_Weighted_Matrix(
                **kernel_params,
            )
        else:
            raise IOError('Matrix-SVGD kernel type not recognized: ',
                          self.kernel_base_type)

    def get_kernel(
        self,
        **kernel_params,
    ):
        """Wrap ``self.base_kernel`` in the composite kernel named by
        ``self.kernel_structure``.

        Raises IOError for any structure other than 'matrix_iid_mp' or
        'matrix_first_order_mp'.
        """
        if self.kernel_structure == 'matrix_iid_mp':
            return matrix_iid_mp(
                ctrl_dim=self.ctrl_dim,
                kernel=self.base_kernel,
                **kernel_params,
            )
        elif self.kernel_structure == 'matrix_first_order_mp':
            return matrix_first_order_mp(
                ctrl_dim=self.ctrl_dim,
                kernel=self.base_kernel,
                **kernel_params,
            )
        else:
            raise IOError('Kernel structure not recognized for matrix-SVGD: ',
                          self.kernel_structure,)

    def get_svgd_terms(
        self,
        X,
        dlog_p,
        M=None,
    ):

        """
        Handle matrix-valued SVGD terms.

        Parameters
        ----------
        X : Stein particles. Tensor of shape [batch, dim],
        dlog_p : tensor of shape [batch, dim]
        M : (Optional) Negative Hessian or Fisher matrices. Tensor of shape [batch, dim, dim]

        Returns
        -------
        gradient: tensor of shape [batch, dim]
        repulsive: tensor of shape [batch, dim]

        """

        k_dlog_p, grad_k = self.kernel.eval_svg_terms(
            X, X.clone().detach(),
            dlog_p,
            M,
        )

        gradient = k_dlog_p.mean(1)
        # Only drop the trailing axis when it is the extra column axis left by
        # a matrix-vector product; a plain squeeze(-1) would also remove the
        # feature axis of a [batch, 1] result.
        if gradient.dim() == 3:
            gradient = gradient.squeeze(-1)
        repulsive = grad_k.mean(1)

        return gradient, repulsive
22 | """ 23 | 24 | import torch 25 | from .composite_kernels import CompositeKernel 26 | 27 | 28 | class iid_mp(CompositeKernel): 29 | 30 | def eval( 31 | self, 32 | X, Y, 33 | M=None, 34 | compute_dK_dK_t=False, 35 | **kwargs, 36 | ): 37 | 38 | X = X.view(X.shape[0], -1, self.ctrl_dim) 39 | Y = Y.view(Y.shape[0], -1, self.ctrl_dim) 40 | 41 | # m: batch, h: horizon, d: ctrl_dim 42 | m, h, d = X.shape 43 | 44 | # Keep another batch-dim for grad. mult. later on. 45 | kernel_Xj_Xi = torch.zeros(m, m, h, d) 46 | 47 | # shape : (m, h, d) 48 | d_kernel_Xi = torch.zeros(m, m, h, d) 49 | 50 | # shape : (m, h, d) 51 | d_k_Xj_dk_Xj = torch.zeros(m, m, h, d, h, d) 52 | 53 | if M is not None: 54 | M = M.view(m, h, d, h, d).transpose(2, 3) # shape: (m, h, h, d, d) 55 | 56 | if self.indep_controls: 57 | for i in range(h): 58 | for q in range(self.ctrl_dim): 59 | M_ii = None 60 | if M is not None: 61 | M_ii = M[:, i, i, q, q].reshape(-1, 1, 1) 62 | k_tmp, dk_tmp, dk_dk_t_tmp = self.base_kernel.eval( 63 | X[:, i, q].reshape(-1, 1), 64 | Y[:, i, q].reshape(-1, 1), 65 | M_ii, 66 | compute_dK_dK_t=compute_dK_dK_t, 67 | ) 68 | kernel_Xj_Xi[:, :, i, q] += k_tmp 69 | d_kernel_Xi[:, :, i, q] += dk_tmp.squeeze(2) 70 | if compute_dK_dK_t: 71 | d_k_Xj_dk_Xj[:, :, i, q, i, q] += dk_dk_t_tmp.view(m, m) 72 | else: 73 | for i in range(h): 74 | M_ii = None 75 | if M is not None: 76 | M_ii = M[:, i, i, :, :] 77 | k_tmp, dk_tmp, dk_dk_t_tmp = self.base_kernel.eval( 78 | X[:, i, :], 79 | Y[:, i, :], 80 | M_ii, 81 | compute_dK_dK_t=compute_dK_dK_t, 82 | ) 83 | kernel_Xj_Xi[:, :, i, :] += k_tmp.unsqueeze(2) 84 | d_kernel_Xi[:, :, i, :] += dk_tmp 85 | if compute_dK_dK_t: 86 | d_k_Xj_dk_Xj[:, :, i, :, i, :] += dk_dk_t_tmp 87 | 88 | kernel_Xj_Xi = kernel_Xj_Xi.reshape(m, m, h * d) 89 | d_kernel_Xi = d_kernel_Xi.reshape(m, m, h * d) 90 | d_k_Xj_dk_Xj = d_k_Xj_dk_Xj.reshape(m, m, h * d, h * d) 91 | 92 | return kernel_Xj_Xi, d_kernel_Xi, d_k_Xj_dk_Xj 93 | 94 | 95 | class first_order_mp(CompositeKernel): 96 
class first_order_mp(CompositeKernel):
    """
    Message-passing composite kernel over a first-order (chain) structure.

    For every time step `i` the base kernel is evaluated on the clique formed
    by step `i` and its direct neighbours `i - 1` and `i + 1` (its Markov
    blanket). Only the gradient entries belonging to step `i` itself are kept.
    """

    def eval(
        self,
        X, Y,
        M=None,
        compute_dK_dK_t=False,
        **kwargs,
    ):
        """
        Evaluate the chain-structured kernel and its derivatives.

        Parameters
        ----------
        X, Y : Tensor
            Particles; viewed internally as [batch, horizon, ctrl_dim].
            NOTE(review): the code assumes X and Y share the same batch size.
        M : Tensor, optional
            Metric, viewed as [batch, horizon, ctrl_dim, horizon, ctrl_dim].
            The sub-block belonging to each clique is passed to the base
            kernel.
        compute_dK_dK_t : bool
            Forwarded to the base kernel. When True, the third return value
            holds the outer products of the assembled kernel gradients.
        **kwargs :
            Accepted for interface compatibility; not used here.

        Returns
        -------
        kernel_Xj_Xi : Tensor of shape [batch, batch, horizon * ctrl_dim]
        d_kernel_Xi : Tensor of shape [batch, batch, horizon * ctrl_dim]
        d_k_Xj_dk_Xj : Tensor of shape
            [batch, batch, horizon * ctrl_dim, horizon * ctrl_dim]
            (all zeros unless `compute_dK_dK_t` is True).
        """

        X = X.view(X.shape[0], -1, self.ctrl_dim)
        Y = Y.view(Y.shape[0], -1, self.ctrl_dim)

        # m: batch, h: horizon, d: ctrl_dim
        m, h, d = X.shape

        # Allocate the outputs with the dtype / device of the particles.
        opts = dict(dtype=X.dtype, device=X.device)

        # Keep another batch-dim for grad. mult. later on.
        kernel_Xj_Xi = torch.zeros(m, m, h, d, **opts)   # (m, m, h, d)
        d_kernel_Xi = torch.zeros(m, m, h, d, **opts)    # (m, m, h, d)

        if M is not None:
            M = M.view(m, h, d, h, d)

        for i in range(h):

            # clique : i-th node + Markov blanket, clipped to the valid range
            #          so that a horizon of length one yields clique == [0].
            # cl_i   : index of i-th node in clique
            clique = [j for j in (i - 1, i, i + 1) if 0 <= j < h]
            cl_i = clique.index(i)
            num = len(clique)

            if self.indep_controls:
                for q in range(self.ctrl_dim):
                    M_i = None

                    if M is not None:
                        M_i = M[:, clique, q, :, q][:, :, clique]
                        M_i = M_i.reshape(-1, num, num)

                    # Only the kernel values and gradients are used; slicing
                    # tolerates base kernels that return further outputs.
                    k_tmp, dk_tmp = self.base_kernel.eval(
                        X[:, clique, q].reshape(-1, num),
                        Y[:, clique, q].reshape(-1, num),
                        M_i,
                        compute_dK_dK_t=compute_dK_dK_t,
                    )[:2]
                    kernel_Xj_Xi[:, :, i, q] += k_tmp
                    # Keep only the derivative w.r.t. node i of the clique.
                    d_kernel_Xi[:, :, i, q] += dk_tmp.reshape(m, m, num)[:, :, cl_i]
            else:
                M_i = None
                if M is not None:
                    M_i = M[:, clique, :, :, :][:, :, :, clique, :]
                    M_i = M_i.reshape(-1, num * d, num * d)
                k_tmp, dk_tmp = self.base_kernel.eval(
                    X[:, clique, :].reshape(-1, num * d),
                    Y[:, clique, :].reshape(-1, num * d),
                    M_i,
                    compute_dK_dK_t=compute_dK_dK_t,
                )[:2]
                kernel_Xj_Xi[:, :, i, :] += k_tmp.unsqueeze(2)
                d_kernel_Xi[:, :, i, :] += dk_tmp.reshape(m, m, num, d)[:, :, cl_i, :]

        kernel_Xj_Xi = kernel_Xj_Xi.reshape(m, m, h * d)
        d_kernel_Xi = d_kernel_Xi.reshape(m, m, h * d)

        if compute_dK_dK_t:
            # TODO: test this with SVN
            # Outer product of the assembled gradients for every particle
            # pair; the base kernel's own dK dK^T output is not used.
            d_k_Xj_dk_Xj = d_kernel_Xi.unsqueeze(3) * d_kernel_Xi.unsqueeze(2)
        else:
            # Same shape as in the branch above, so callers always receive a
            # (m, m, h * d, h * d) tensor.
            d_k_Xj_dk_Xj = torch.zeros(m, m, h * d, h * d, **opts)

        return kernel_Xj_Xi, d_kernel_Xi, d_k_Xj_dk_Xj
class MP_SVGD(SVGD):
    """
    Message Passing SVGD.

    Uses composite kernels that return one kernel value per particle
    coordinate, so the attractive term is weighted coordinate-wise.
    """

    def get_kernel(
        self,
        **kernel_params,
    ):
        """
        Build the kernel selected by `self.kernel_structure`.

        Returns the base kernel when the structure is None, otherwise an
        `iid_mp` or `first_order_mp` composite kernel wrapping it.

        Raises
        ------
        IOError
            If the kernel structure is not recognized.
        """

        if self.kernel_structure is None:
            return self.base_kernel
        elif self.kernel_structure == 'iid_mp':
            return iid_mp(
                ctrl_dim=self.ctrl_dim,
                kernel=self.base_kernel,
                **kernel_params,
            )
        elif self.kernel_structure == 'first_order_mp':
            return first_order_mp(
                ctrl_dim=self.ctrl_dim,
                kernel=self.base_kernel,
                **kernel_params,
            )
        else:
            # A single formatted string: passing two arguments to IOError
            # makes Python treat them as (errno, strerror).
            raise IOError(
                'Kernel structure not recognized for MP-SVGD: {}'.format(
                    self.kernel_structure)
            )

    def get_svgd_terms(
        self,
        X,
        dlog_p,
        M=None,
    ):
        """
        Handle Message-Passing SVGD update.

        Parameters
        ----------
        X : Stein particles. Tensor of shape [batch, dim],
        dlog_p : tensor of shape [batch, dim]
        M : (Optional) Negative Hessian or Fisher matrices. Tensor of shape [batch, dim, dim]

        Returns
        -------
        gradient: tensor of shape [batch, dim]
        repulsive: tensor of shape [batch, dim]
        pw_dists_sq: value returned by the kernel as fourth output, or None
            if the kernel does not provide it.
        """

        k_XX, grad_k, pw_dists_sq = self.evaluate_kernel(X, M)

        k_XX = k_XX.detach()
        if k_XX.dim() == 2:
            # Plain [batch, batch] Gram matrix (e.g. the base kernel used
            # directly): share each kernel value across all coordinates.
            k_XX = k_XX.unsqueeze(-1)

        # gradient_i = mean_j k(x_i, x_j) * dlog_p(x_j), per coordinate.
        gradient = (k_XX * dlog_p.unsqueeze(0)).mean(1)
        repulsive = grad_k.mean(1)

        return gradient, repulsive, pw_dists_sq

    def evaluate_kernel(self, X, M=None):
        """

        Parameters
        ----------
        X : tensor. Stein particles, of shape [batch, dim],
        M : (Optional) Negative Hessian or Fisher matrices. Tensor of shape [batch, dim, dim]

        Returns
        -------
        k_XX : tensor of shape [batch, batch, dim]
        grad_k : tensor of shape [batch, batch, dim]
        pw_dists_sq : fourth output of the kernel if it has one, else None.

        """
        out = self.kernel.eval(
            X, X.clone().detach(),
            M,
            compute_dK_dK_t=False,
        )
        # The message-passing composite kernels return three values, while
        # other kernels may append the squared pairwise distances. Accept
        # both instead of failing on tuple unpacking.
        k_XX, grad_k = out[0], out[1]
        pw_dists_sq = out[3] if len(out) > 3 else None
        return k_XX, grad_k, pw_dists_sq
class mixture_of_gaussians:
    """Mixture of diagonal Gaussians over whole control trajectories."""

    def __init__(
        self,
        means,
        sigmas,
        weights,
    ):
        """

        Parameters
        ----------
        means :
            shape: [num_particles, steps, ctrl_dim]
        sigmas :
            shape: [num_particles, steps, ctrl_dim]
        weights :
            shape: [num_particles]
        """
        self.num_particles, self.rollout_steps, self.ctrl_dim = means.shape
        # The last two dims (steps, ctrl_dim) form one event; the leading dim
        # indexes the mixture components.
        components = dist.Independent(dist.Normal(means, sigmas), 2)
        mixture = dist.Categorical(weights)
        self.dist = dist.mixture_same_family.MixtureSameFamily(mixture, components)

    def sample(self, num_samples):
        """
        Draw trajectories from the mixture.

        Parameters
        ----------
        num_samples : int or tuple of int
            Number of samples. A plain int is wrapped into a one-element
            shape, since torch distributions expect a sample *shape*.

        Returns
        -------
        Tensor with the first two dimensions of the drawn samples swapped,
        i.e. [steps, num_samples, ctrl_dim] for a single sample dimension.
        """
        if isinstance(num_samples, int):
            num_samples = (num_samples,)
        return self.dist.sample(num_samples).transpose(0, 1)

    def log_prob(self, x):
        """Return the mixture log-density of `x`."""
        return self.dist.log_prob(x)


def avg_ctrl_to_goal(
    state,
    target,
    rollout_steps,
    dt,
    max_ctrl=100,
    control_type='velocity',
):
    """
    Average control from state to target. Control must be
    first/second-derivative in state/target space. Ex. velocity control for
    cartesian states.

    Parameters
    ----------
    state, target : Tensor
        Current and desired state. For 'acceleration' control only the first
        half of the last dimension is used (presumably the positions, with
        velocities in the second half; confirm against callers).
    rollout_steps : int
        Number of steps over which the difference is spread.
    dt : float
        Step duration.
    max_ctrl : float
        Bound on the control magnitude; the result is clamped to
        [-max_ctrl, max_ctrl].
    control_type : str
        'velocity' or 'acceleration'.

    Returns
    -------
    Tensor holding the clamped average control.
    """
    assert control_type in [
        'velocity',
        'acceleration',
    ]
    if control_type == 'velocity':
        ctrl = (target - state) / rollout_steps / dt
    else:
        # Size of the state vector, not the tensor rank: `state.dim()` would
        # return 1 for a flat state and produce an empty position slice.
        pos_dim = state.shape[-1] // 2
        ctrl = (
            (target[..., :pos_dim] - state[..., :pos_dim])
            / rollout_steps / dt**2
        )
    # Clamp symmetrically; an upper bound alone leaves large negative
    # controls unbounded.
    return ctrl.clamp(min=-max_ctrl, max=max_ctrl)
def get_indep_gaussian_prior(
    sigma_init,
    rollout_steps,
    control_dim,
    mu_init=None,
):
    """
    Build an element-wise Normal prior over a control trajectory.

    The mean is zero unless `mu_init` is given, in which case it is broadcast
    into the [rollout_steps, control_dim] mean tensor. Every entry uses the
    standard deviation `sigma_init`.
    """
    shape = (rollout_steps, control_dim)

    loc = torch.zeros(shape)
    if mu_init is not None:
        loc[:, :] = mu_init

    scale = torch.ones(shape) * sigma_init

    return dist.Normal(loc, scale)


def get_multivar_gaussian_prior(
    sigma,
    rollout_steps,
    control_dim,
    Sigma_type='indep_ctrl',
    mu_init=None,
):
    """
    :param sigma: standard deviation on controls
    :param control_dim: number of control dimensions
    :param rollout_steps: number of time steps in the trajectory
    :param Sigma_type: Covariance prior type, 'indep_ctrl': diagonal Sigma,
    'const_ctrl': const. ctrl Sigma.
    :param mu_init: optional mean, broadcast into [rollout_steps, control_dim]
    :return: distribution with MultivariateNormal for each control dimension
    """
    assert Sigma_type in [
        'indep_ctrl',
        'const_ctrl',
    ], 'Invalid type for control prior dist.'

    mean = torch.zeros(rollout_steps, control_dim)
    if mu_init is not None:
        mean[:, :] = mu_init

    # Select the covariance generator for the requested prior type.
    if Sigma_type == 'indep_ctrl':
        # Isotropic covariance
        make_Sigma = diag_Sigma
    elif Sigma_type == 'const_ctrl':
        # Const-ctrl covariance
        make_Sigma = const_ctrl_Sigma
    else:
        raise IOError('Sigma_type not recognized.')

    covariance = make_Sigma(sigma, rollout_steps, control_dim)

    # check_Sigma_is_valid(Sigma)

    return Gaussian_Ctrl_Dist(rollout_steps, control_dim, mean, covariance)
def _sigma_as_tensor(sigma):
    """Convert a list/tuple of per-dimension sigmas to a float tensor."""
    if isinstance(sigma, (list, tuple)):
        return torch.from_numpy(np.array(sigma)).float()
    return sigma


def diag_Sigma(sigma, length=None, ctrl_dim=None):
    """
    Time-independent diagonal covariance matrix. Assumes independence
    across control dimension.

    :param sigma: standard deviation; a scalar, or a list/tuple with one
    entry per control dimension.
    :param length: number of time steps.
    :param ctrl_dim: number of control dimensions.
    :return: Tensor of shape [length, length, ctrl_dim]; each slice along the
    last dimension is sigma**2 times the identity.
    """
    Sigma = torch.eye(
        length,
    ).unsqueeze(-1).repeat(1, 1, ctrl_dim)

    # Per-dimension sigmas broadcast along the last (control) dimension.
    sigma = _sigma_as_tensor(sigma)
    Sigma = Sigma * sigma**2
    return Sigma


def const_ctrl_Sigma(
    sigma,
    length=None,
    ctrl_dim=None,
):
    """
    Constant-control covariance prior. Assumes independence across control
    dimension.

    :param sigma: standard deviation; a scalar, or a list/tuple with one
    entry per control dimension.
    :param length: number of time steps.
    :param ctrl_dim: number of control dimensions.
    :return: Tensor of shape [length, length, ctrl_dim].
    """

    sigma = _sigma_as_tensor(sigma)

    # Strictly lower-triangular matrix of ones; L L^T has entries min(i, j).
    L = torch.tril(
        torch.ones(
            length,
            length - 1,
        ), diagonal=-1,
    )
    LL_t = torch.matmul(
        L, L.transpose(0, 1)
    )

    # Adding one everywhere gives entries min(i, j) + 1.
    LL_t += torch.ones(
        length,
        length,
    )

    Sigma = LL_t.unsqueeze(-1).repeat(1, 1, ctrl_dim) * sigma**2
    return Sigma


def check_Sigma_is_valid(Sigma, min_det=1.e-7):
    """
    Check determinant of Sigma to pre-empt potential numerical instability
    in Multi-variate Gaussian.
    For example, dist.logprob(x) >> 1.
    :param Sigma: covariance matrix (Tensor) of shape [rollout_length,
    rollout_length, control_dim]
    :param min_det: smallest determinant accepted for any control dimension.
    :raises ZeroDivisionError: if a determinant falls below `min_det`.
    """
    Sigma_np = Sigma.cpu().numpy()
    for i in range(Sigma.shape[-1]):
        det = np.linalg.det(Sigma_np[:, :, i])
        if det < min_det:
            raise ZeroDivisionError(
                'Covariance-determinant too small, potential for underflow. '
                'Consider increasing sigma.'
            )
class Prior_Ctrl_Traj_Dist(ABC):
    """
    Prior distribution on control trajectories, Assumes independence across
    ctrl_dim.
    """
    def __init__(
        self,
        rollout_steps,
        ctrl_dim,
    ):
        self.rollout_steps = rollout_steps
        self.ctrl_dim = ctrl_dim
        # One distribution over [rollout_steps] per control dimension;
        # filled in by `make_dist`.
        self.list_ctrl_dists = []

    @abstractmethod
    def make_dist(self):
        """ Construct list of sampling distribution, one for each control
        dimension"""
        pass

    def log_prob(
        self,
        samples,
        cond_inputs=None,
    ):
        """
        :param samples: control samples of shape ( num_particles, rollout_steps,
        ctrl_dim)
        :param cond_inputs: conditional input (not used)
        :return: log_probs, of shape (num_particles.)
        """
        assert samples.dim() == 3
        assert samples.size(1) == self.rollout_steps
        assert samples.size(2) == self.ctrl_dim
        num_particles = samples.size(0)

        # Accumulate on the dtype / device of the samples.
        log_probs = torch.zeros(
            num_particles,
            dtype=samples.dtype,
            device=samples.device,
        )

        # Independence across control dimensions: log-densities add up.
        for i in range(self.ctrl_dim):
            samp = samples[:, :, i]  # [num_particles, rollout_steps]
            log_probs += self.list_ctrl_dists[i].log_prob(
                samp
            )

        return log_probs

    def update_means(self, means):
        """
        Replace the mean of every per-dimension distribution.

        :param means: tensor whose last dimension indexes the control
        dimension; `means[..., i]` becomes the `loc` of distribution i.
        """
        for i in range(self.ctrl_dim):
            self.list_ctrl_dists[i].loc = means[..., i].detach().clone()

    def sample(
        self,
        num_samples,
        cond_inputs=None,
    ):
        """
        :param num_samples: number of control trajectories to draw
        :param cond_inputs: conditional input (not implemented)
        :return: control tensor, of size (num_samples, rollout_steps,
        ctrl_dim)
        """
        draws = [
            self.list_ctrl_dists[i].sample((num_samples,))
            for i in range(self.ctrl_dim)
        ]
        if not draws:
            # No control dimensions: keep returning an empty tensor.
            return torch.empty(num_samples, self.rollout_steps, 0)
        # Stacking keeps the dtype / device of the underlying distributions.
        return torch.stack(draws, dim=-1)


class Gaussian_Ctrl_Dist(Prior_Ctrl_Traj_Dist):
    """
    Multivariate Gaussian distribution for each control dimension.
    """
    def __init__(
        self,
        rollout_steps,
        ctrl_dim,
        mu=None,
        Sigma=None,
    ):
        """
        :param rollout_steps: number of time steps
        :param ctrl_dim: number of control dimensions
        :param mu: mean tensor of shape [rollout_steps, ctrl_dim]
        :param Sigma: covariance tensor; `Sigma[:, :, i]` is used as the
        covariance matrix of control dimension i
        """

        assert mu.size(0) == rollout_steps
        assert mu.size(1) == ctrl_dim
        super().__init__(
            rollout_steps,
            ctrl_dim,
        )
        self.mu = mu
        self.Sigma = Sigma

        self.make_dist()

    def make_dist(self):
        """(Re)build one MultivariateNormal per control dimension."""
        # Rebuild from scratch so that repeated calls do not append
        # duplicate distributions.
        self.list_ctrl_dists = [
            dist.MultivariateNormal(
                self.mu[:, i],
                covariance_matrix=self.Sigma[:, :, i]
            )
            for i in range(self.ctrl_dim)
        ]
class Gamma_Ctrl_Dist(Prior_Ctrl_Traj_Dist):
    """Placeholder for a Gamma control prior; nothing is implemented yet."""

    def __init__(self):
        # Construction is rejected outright; the base initializer is
        # intentionally never reached.
        raise NotImplementedError()

    def make_dist(self):
        """Not implemented."""
        raise NotImplementedError()
class SVGD():
    """
    Stein Variational Gradient Descent.

    Uses analytic kernel gradients.
    """

    def __init__(
        self,
        kernel_base_type='RBF',
        kernel_structure=None,
        verbose=False,
        control_dim=None,
        repulsive_scaling=1.,
        **kernel_params,
    ):
        """
        Parameters
        ----------
        kernel_base_type : str
            Name of the base kernel, see `get_base_kernel`.
        kernel_structure : str or None
            Composite structure, see `get_kernel`.
        verbose : bool
            Print gradient norms and timings.
        control_dim : int, optional
            Stored as `self.ctrl_dim`.
        repulsive_scaling : float
            Weight of the repulsive term in `phi`.
        **kernel_params :
            Forwarded to the kernel constructors. The optional entry
            'geom_metric_type' selects the metric built in `phi`; it
            defaults to None (no metric) when omitted.
        """

        self.verbose = verbose
        self.kernel_base_type = kernel_base_type
        self.kernel_structure = kernel_structure
        self.ctrl_dim = control_dim
        self.repulsive_scaling = repulsive_scaling

        self.base_kernel = self.get_base_kernel(**kernel_params)
        self.kernel = self.get_kernel(**kernel_params)
        # `.get` so that omitting the metric type means "no metric" instead
        # of raising a KeyError.
        self.geom_metric_type = kernel_params.get('geom_metric_type')
        self._M = None
        # Cached by `phi`, read by `get_pairwise_dists`.
        self._pw_dists_sq = None
        self._X = None
        # Kernels whose distances are scaled by a metric matrix.
        self.hessian_scaled = self.kernel_base_type in [
            'RBF_Anisotropic',
            'RBF_Matrix',
            'IMQ_Matrix',
            'RBF_Weighted_Matrix',
        ]

    def get_base_kernel(
        self,
        **kernel_params,
    ):
        """
        Instantiate the base kernel named by `self.kernel_base_type`.

        Raises
        ------
        IOError
            If the kernel type is not recognized.
        """
        if self.kernel_base_type == 'RBF':
            return RBF(
                **kernel_params,
            )
        elif self.kernel_base_type == 'IMQ':
            return IMQ(
                **kernel_params,
            )
        elif self.kernel_base_type == 'RBF_Anisotropic':
            return RBF_Anisotropic(
                **kernel_params,
            )
        elif self.kernel_base_type == 'Linear':
            return Linear(
                **kernel_params,
            )
        else:
            # Single formatted string: two positional arguments would be
            # interpreted by IOError as (errno, strerror).
            raise IOError(
                'Stein kernel type not recognized: {}'.format(
                    self.kernel_base_type)
            )

    def get_kernel(
        self,
        **kernel_params,
    ):
        """
        Build the kernel selected by `self.kernel_structure`.

        Raises
        ------
        IOError
            If the kernel structure is not recognized.
        """

        if self.kernel_structure is None:
            return self.base_kernel
        elif self.kernel_structure == 'iid':
            return iid(
                kernel=self.base_kernel,
                **kernel_params,
            )
        else:
            raise IOError(
                'Kernel structure not recognized for SVGD: {}'.format(
                    self.kernel_structure)
            )

    def get_svgd_terms(
        self,
        X,
        dlog_p,
        M=None,
    ):
        """
        Parameters
        ----------
        X : Tensor
            Stein particles. Tensor of shape [batch, dim],
        dlog_p : Tensor
            Gradient of log probability. Shape [batch, dim]
        M : (Optional)
            Negative Hessian or Fisher matrices. Tensor of shape [batch, dim, dim]

        Returns
        -------
        grad: Tensor
            Attractive SVGD gradient term.
            Shape [batch, dim].
        rep: Tensor
            Repulsive SVGD term.
            Shape [batch, dim].
        pw_dists_sq: Tensor
            Squared pairwise distances between particles. Can be scaled by a
            metric.
            Shape [batch, batch].
        """

        k_XX, grad_k, pw_dists_sq = self.evaluate_kernel(X, M)

        # Kernel-weighted average of the scores over all particles.
        grad = k_XX.mm(dlog_p) / k_XX.size(1)
        rep = grad_k.mean(1)

        return grad, rep, pw_dists_sq

    def evaluate_kernel(self, X, M=None):
        """

        Parameters
        ----------
        X : tensor. Stein particles, of shape [batch, dim],
        M : (Optional) Negative Hessian or Fisher matrices. Tensor of shape [batch, dim, dim]

        Returns
        -------
        k_XX :
            tensor of shape [batch, batch]
        grad_k :
            tensor of shape [batch, batch, dim]
        pw_dists_sq:
            fourth output of the kernel (squared pairwise distances).

        """
        # The second argument is detached, so gradients only flow through
        # the first one.
        k_XX, grad_k, _, pw_dists_sq = self.kernel.eval(
            X, X.clone().detach(),
            M,
            compute_dK_dK_t=False,
        )
        return k_XX, grad_k, pw_dists_sq

    @staticmethod
    def _score_outer_product(score):
        """Return the per-particle outer product of a [batch, ...] score."""
        b = score.shape[0]
        return torch.bmm(score.reshape(b, -1, 1), score.reshape(b, 1, -1))

    def phi(
        self,
        X,
        dlog_p,
        dlog_lh=None,
        Hess=None,
        Hess_prior=None,
        Jacobian=None,
        copy_pw_dists=False
    ):
        """
        Computes the SVGD gradient.

        Parameters
        ----------
        X : Tensor
            Stein particles, of shape [batch, dim].
        dlog_p : Tensor
            Score function, of shape [batch, dim].
        dlog_lh : Tensor, optional
            Score of the likelihood; required for the 'fisher' and
            'riemannian' metrics.
        Hess : Tensor, optional
            Hessian of the log density; required for 'full_hessian' and
            'local_Hessians'.
        Hess_prior : Tensor, optional
            Hessian of the log prior; required for 'jacobian_product',
            'riemannian' and 'local_Hessians'.
        Jacobian : Tensor, optional
            Required for the 'jacobian_product' metric.
        copy_pw_dists : bool
            Not used.

        Returns
        -------
        Phi: Tensor
            Empirical Stein gradient, of shape [batch, dim].
        pw_dists_sq: Tensor
            Squared pairwise distances between particles. Can be metric-scaled.
            Shape [batch, batch].

        Raises
        ------
        NotImplementedError
            If `self.geom_metric_type` is not a known metric type.
        """

        if self.geom_metric_type is None:
            M = None
        elif self.geom_metric_type == 'full_hessian':
            assert Hess is not None
            M = - Hess
        elif self.geom_metric_type == 'fisher':
            # Average Fisher matrix (likelihood only)
            M = self._score_outer_product(dlog_lh)
        elif self.geom_metric_type == 'jacobian_product':
            # Average Fisher matrix (full posterior gradient)
            M = torch.bmm(Jacobian.transpose(1, 2), Jacobian)
            M = M - Hess_prior
        elif self.geom_metric_type == 'riemannian':
            # Average Fisher matrix plus neg. Hessian of log prior.
            # The Fisher term enters with a positive sign, as in the
            # 'fisher' and 'jacobian_product' branches.
            M = self._score_outer_product(dlog_lh) - Hess_prior
        elif self.geom_metric_type == 'local_Hessians':
            # Neg. Hessian plus neg. Hessian of log prior
            M = - Hess - Hess_prior
        else:
            raise NotImplementedError

        # SVGD attractive / repulsive terms, inter-particle distances
        grad, rep, pw_dists_sq = self.get_svgd_terms(
            X,
            dlog_p,
            M,
        )
        if self.verbose:
            print('gradient l2-norm: {:5.4f}'.format(
                grad.norm().detach().cpu().numpy()))
            print('repulsive l2-norm: {:5.4f}'.format(
                rep.norm().detach().cpu().numpy()))

        # SVGD gradient
        phi = grad + self.repulsive_scaling * rep

        # Cached for `get_pairwise_dists`.
        self._pw_dists_sq = pw_dists_sq
        self._X = X

        return phi, pw_dists_sq

    def apply(
        self,
        X,
        model,
        iters=100,
        step_size=1.,
        use_analytic_grads=False,
        optimizer_type='SGD'
    ):
        """
        Runs SVGD optimization on a distribution model, given a particle
        initialization X, and a selected optimization algorithm.

        Parameters
        ----------
        X : (Tensor)
            Stein particles, of shape [num_particles, dim] (the layout
            documented for `phi`).
        model:
            Probability distribution model instance. Can be of differentiable
            type torch.distributions, or custom model with analytic functions
            (see examples under 'stein_lib/models').
        iters:
            Number of optimization iterations.
        step_size : Float
            Step size (learning rate of the optimizer).
        use_analytic_grads: Bool
            Set to 'True' if probability model uses analytic gradients. If set to
            'False', gradients are computed with autograd.
        optimizer_type: str
            Optimizer used for updates: 'SGD', 'Adam', 'LBFGS' or
            'FullBatchLBFGS'.

        Returns
        -------
        X : Tensor
            Final particles.
        particle_history : list of ndarray
            Particles before the first and after every iteration.
        pw_dists, pw_dists_scaled :
            See `get_pairwise_dists`.

        Raises
        ------
        NotImplementedError
            If `optimizer_type` is not recognized.
        """

        # Detach first: `.numpy()` fails on tensors that require grad.
        particle_history = [X.detach().clone().cpu().numpy()]

        dts = []
        X = torch.autograd.Variable(X, requires_grad=True)

        if optimizer_type == 'SGD':
            optimizer = torch.optim.SGD([X], lr=step_size)
        elif optimizer_type == 'Adam':
            optimizer = torch.optim.Adam([X], lr=step_size)
        elif optimizer_type == 'LBFGS':
            optimizer = torch.optim.LBFGS(
                [X],
                lr=step_size,
                max_iter=100,
                # max_eval=20 * 1.25,
                tolerance_change=1e-9,
                history_size=25,
                line_search_fn=None,  # 'strong_wolfe'
            )
        elif optimizer_type == 'FullBatchLBFGS':
            optimizer = FullBatchLBFGS(
                [X],
                lr=step_size,
                history_size=25,
                line_search='None',  # 'Wolfe'
            )
        else:
            raise NotImplementedError

        def closure():
            """Write the negated Stein gradient into `X.grad`."""
            optimizer.zero_grad()
            Hess = None
            if use_analytic_grads:

                if isinstance(model, doubleBanana_analytic):
                    # Used only by double_banana model
                    F = model.forward_model(X)
                    J = model.jacob_forward(X)
                    dlog_p = model.grad_log_p(X, F, J)
                else:
                    dlog_p = model.grad_log_p(X)

                if self.hessian_scaled and \
                        self.geom_metric_type not in ['fisher']:

                    if isinstance(model, doubleBanana_analytic):
                        ## Used only by double_banana model
                        # Gauss-Newton approximation
                        Hess = model.hessian(X, J)  # returns hessian of negative log posterior
                    else:
                        Hess = model.hessian(dlog_p, X)
            else:
                # Autograd gradients
                log_p = model.log_prob(X).unsqueeze(1)
                dlog_p = torch.autograd.grad(
                    log_p.sum(),
                    X,
                    create_graph=True,
                )[0]
                if self.hessian_scaled and \
                        self.geom_metric_type not in ['fisher']:
                    Hess = get_jacobian(dlog_p, X)

            # SVGD gradient
            with torch.no_grad():
                Phi, pw_dists_sq = self.phi(
                    X,
                    dlog_p,
                    dlog_lh=dlog_p,
                    Hess=Hess,
                )
                # Optimizers minimize, so hand them the negative direction.
                X.grad = -1. * Phi
                # check(X.grad, 'X.grad')
            # Dummy loss: the update is driven by X.grad only.
            loss = 1.
            return loss

        for i in range(iters):
            self.i = i
            t_start = time()
            if isinstance(optimizer, FullBatchLBFGS):
                options = {'closure': closure, 'current_loss': closure()}
                optimizer.step(options)
            else:
                optimizer.step(closure)
            dt = time() - t_start
            if self.verbose:
                print('dt (SVGD): {}\n'.format(dt))
            dts.append(dt)
            particle_history.append(X.clone().detach().cpu().numpy())
        dt_stats = np.array(dts)
        if self.verbose:
            print("\nAvg. SVGD compute time: {}".format(dt_stats.mean()))
            print("Std. dev. SVGD compute time: {}\n".format(dt_stats.std()))

        (pw_dists,
         pw_dists_scaled,) = self.get_pairwise_dists()

        return (
            X,
            particle_history,
            pw_dists,
            pw_dists_scaled,
        )

    def get_pairwise_dists(self):
        """
        Return the pairwise distances of the last `phi` evaluation.

        Returns
        -------
        pw_dists :
            Euclidean pairwise distances.
        pw_dists_scaled :
            Metric-scaled distances for metric-scaled kernels, else None.
        Both are None if `phi` has not produced distances yet.
        """
        if self._pw_dists_sq is None:
            # `phi` was never evaluated (e.g. zero iterations) or the kernel
            # did not provide distances.
            return None, None

        # pw_dists output from svgd-gradient computation
        pw_dists_out = torch.sqrt(self._pw_dists_sq.clone().detach())
        X = self._X.clone().detach()

        if self.hessian_scaled:
            # Hessian-scaled pw_dists
            pw_dists_scaled = pw_dists_out
            pw_dists = calc_pw_distances(X)
        else:
            # Euclidean Pairwise distances
            pw_dists = pw_dists_out
            pw_dists_scaled = None
        return pw_dists, pw_dists_scaled


def check(tsr, name):
    """Check a tensor for inf/nan/large values and print what was found."""
    isinf = torch.isinf(tsr)
    if isinf.any():
        infind = 'all' if isinf.all() else torch.nonzero(isinf)
        print(name, 'isinf', infind, flush=True)

    isnan = torch.isnan(tsr)
    if isnan.any():
        nanind = 'all' if isnan.all() else torch.nonzero(isnan)
        print(name, 'isnan', nanind, flush=True)

    if (tsr.abs() > 1e6).any():
        print(name, 'isvlarge', flush=True)
-------------------------------------------------------------------------------- /stein_lib/svn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sashalambert/stein_lib/a1afeca70afc831aab5a4d057be773eb17750246/stein_lib/svn/__init__.py -------------------------------------------------------------------------------- /stein_lib/svn/mp_svn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020-2021 Alexander Lambert 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining 5 | a copy of this software and associated documentation files (the 6 | "Software"), to deal in the Software without restriction, including 7 | without limitation the rights to use, copy, modify, merge, publish, 8 | distribute, sublicense, and/or sell copies of the Software, and to 9 | permit persons to whom the Software is furnished to do so, subject to 10 | the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 13 | included in all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 17 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 19 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 20 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 21 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
class MP_SVN(SVN, MP_SVGD):
    """
    Stein Variational Newton combined with the message-passing SVGD terms.
    """

    def __init__(
        self,
        kernel_base_type='RBF',
        kernel_structure=None,
        verbose=False,
        control_dim=None,
        repulsive_scaling=1,
        **kernel_params,
    ):
        # Pure pass-through to the next initializer in the MRO.
        super().__init__(
            kernel_base_type=kernel_base_type,
            kernel_structure=kernel_structure,
            verbose=verbose,
            control_dim=control_dim,
            repulsive_scaling=repulsive_scaling,
            **kernel_params,
        )

    def get_second_variation(
        self,
        k_XX,
        dk_dk_t,
        Hess,
    ):
        """
        Assemble the second variation from per-coordinate kernel values.

        Parameters
        ----------
        k_XX : tensor
            Kernel Grammian. Shape: [num_particles, num_particles, dim]
        dk_dk_t : tensor
            Outer products of kernel gradients.
            Shape: [num_particles, num_particles, dim, dim]
        Hess : tensor
            Hessian of log_prob.
            Shape: [num_particles, dim, dim]

        Returns
        -------
        H : tensor
            Second variation. Shape [num_particles, dim, dim].
        """
        # Squared kernel values with a trailing axis for broadcasting
        # against the Hessian: b x b x d x 1
        kernel_sq = k_XX.pow(2).unsqueeze(-1)
        per_pair = dk_dk_t - Hess * kernel_sq
        # Average over the second particle axis.
        return per_pair.mean(dim=1)

    def get_svgd_terms(
        self,
        X,
        dlog_p,
        M=None,
    ):
        """Return the SVGD terms computed by `MP_SVGD.get_svgd_terms`."""
        # Use message-passing format
        return MP_SVGD.get_svgd_terms(self, X, dlog_p, M=M)
import numpy as np
import torch
from stein_lib.svgd.svgd import SVGD
from time import time


def _scaled_identity(mat, scale):
    """Return ``scale * I`` matching dims 1 and 2, dtype and device of ``mat``."""
    return scale * torch.eye(
        mat.shape[1], mat.shape[2], dtype=mat.dtype, device=mat.device,
    )


def _batched_solve(A, b):
    """Solve ``A x = b``, preferring ``torch.linalg.solve`` over legacy ``torch.solve``."""
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
        return torch.linalg.solve(A, b)
    # Older PyTorch releases only provide the legacy entry point.
    return torch.solve(b, A).solution


class SVN(SVGD):
    """Stein Variational Newton update built on top of ``SVGD``.

    The SVGD direction (first variation) is preconditioned per particle by
    solving against a second-variation matrix.
    """

    def __init__(
        self,
        kernel_base_type='RBF',
        kernel_structure=None,
        verbose=False,
        control_dim=None,
        repulsive_scaling=1.,
        **kernel_params,
    ):
        """Forward every constructor argument, unchanged, to ``SVGD``."""
        super().__init__(
            kernel_base_type,
            kernel_structure,
            verbose,
            control_dim,
            repulsive_scaling,
            **kernel_params,
        )

    def get_second_variation(
        self,
        k_XX,
        dk_dk_t,
        Hess,
    ):
        """
        Build the averaged second-variation matrix.

        Parameters
        ----------
        k_XX : tensor
            Kernel Grammian. Shape: [num_particles, num_particles]
        dk_dk_t : tensor
            Outer products of kernel gradients.
            Shape: [num_particles, num_particles, dim, dim]
        Hess : tensor
            Hessian of log_prob.
            Shape: [num_particles, dim, dim]

        Returns
        -------
        H : tensor
            Second variation. Shape [num_particles, dim, dim].
        """
        k_sq = (k_XX ** 2).unsqueeze(-1).unsqueeze(-1)  # b x b x 1 x 1
        H_ii = - Hess * k_sq + dk_dk_t
        H = H_ii.mean(dim=1)
        return H

    def get_svn_terms(
        self,
        X,
        dlog_p,
        dlog_lh,
        Hess,
        Hess_prior=None,
        transpose=False,
    ):
        """
        Compute the SVN update direction for the current particles.

        Parameters
        ----------
        X : (Tensor)
            Stein particles, of shape [num_particles, dim]
            or of shape [dim, num_particles]. If Tensor dimension is greater
            than 2, extra dimensions will be flattened.
        dlog_p : (Tensor)
            Score function, of shape [num_particles, dim]
            or of shape [dim, num_particles]. If Tensor dimension is greater
            than 2, extra dimensions will be flattened.
        dlog_lh : (Tensor)
            Gradient of the log-likelihood, read only by the 'fisher' and
            'riemannian' metrics. It is not passed through
            ``reshape_inputs``; its first axis is used as the batch axis.
        Hess : (Tensor)
            Hessian of prob. density, of shape [num_particles, dim, dim]
            or of shape [dim, dim, num_particles]. If Tensor dimension is
            greater than 3, it will be reshaped appropriately such that
            dimension is 3.
        Hess_prior : (Tensor), optional
            Hessian of the log prior. Required by the 'riemannian' metric,
            where it is reshaped to [num_particles, dim, dim].
        transpose : Bool
            Transpose input and output Tensors.

        Returns
        -------
        Q : (Tensor)
            Update direction, reshaped to the original shape of ``X``.
        k_XX : (Tensor)
            Kernel Grammian returned by ``self.kernel.eval``.
        grad_k : (Tensor)
            Kernel gradient returned by ``self.kernel.eval``.

        Raises
        ------
        ValueError
            If the 'riemannian' metric is selected without ``Hess_prior``.
        NotImplementedError
            If ``self.geom_metric_type`` is not a recognised value.
        """

        shape_original = X.shape

        X, dlog_p, Hess = self.reshape_inputs(
            X,
            dlog_p,
            Hess,
            transpose,
        )

        # Default to "no metric" so the kernel call below always has a
        # defined value (previously unbound when geom_metric_type was None).
        M = None
        metric_type = self.geom_metric_type
        if metric_type is None:
            pass
        elif metric_type == 'full_hessian':
            assert Hess is not None
            M = - Hess
            M = M + _scaled_identity(M, 1.e-6)
        elif metric_type == 'fisher':
            ## Average Fisher matrix (likelihood only)
            n_batch = dlog_lh.shape[0]
            M = torch.bmm(
                dlog_lh.reshape(n_batch, -1, 1),
                dlog_lh.reshape(n_batch, 1, -1),
            )
            M = M + _scaled_identity(M, 1.e-6)
        elif metric_type == 'jacobian_product':
            ## Average Fisher matrix (full posterior gradient)
            dim = dlog_p.shape[-1]
            M = torch.bmm(
                dlog_p.reshape(-1, dim, 1),
                dlog_p.reshape(-1, 1, dim),
            )
            M = M + _scaled_identity(M, 1.e-3)
        elif metric_type == 'riemannian':
            # Average Fisher matrix plus neg. Hessian of log prior
            if Hess_prior is None:
                raise ValueError(
                    "geom_metric_type='riemannian' requires Hess_prior."
                )
            n_batch = dlog_lh.shape[0]
            F = torch.bmm(
                dlog_lh.reshape(n_batch, -1, 1),
                dlog_lh.reshape(n_batch, 1, -1),
            )
            Hess_prior = Hess_prior.reshape(
                dlog_p.shape[0],
                dlog_p.shape[1],
                dlog_p.shape[1],
            )
            M = F - Hess_prior
        else:
            raise NotImplementedError

        (k_XX,
         grad_k,
         dk_dk_t) = self.kernel.eval(
            X, X.clone().detach(),
            M,
            compute_dK_dK_t=True,
        )

        ## Phi - first variation ###
        grad, rep = self.get_svgd_terms(
            X,
            dlog_p,
            M,
        )
        if self.verbose:
            print('gradient l2-norm: {:5.4f}'.format(
                grad.norm().detach().cpu().numpy()))
            print('repulsive l2-norm: {:5.4f}'.format(
                rep.norm().detach().cpu().numpy()))

        phi = grad + self.repulsive_scaling * rep

        ## Q - Second variation ##
        H = self.get_second_variation(k_XX, dk_dk_t, Hess)
        # Small diagonal jitter before the per-particle solve.
        H = H + _scaled_identity(H, 1.e-4)

        Q = _batched_solve(H, phi.unsqueeze(2))

        # Drop only the trailing singleton added for the solve, so a single
        # particle or a 1-D state keeps its two axes.
        Q = Q.squeeze(-1)

        # Reshape Q to match original tensor dimensions
        if transpose:
            Q = Q.t()
        Q = Q.reshape(shape_original)

        return (
            Q,
            k_XX,
            grad_k,
        )

    def apply(
        self,
        X,
        model,
        iters=100,
        eps=1.,
        use_analytic_grads=False,
    ):
        """
        Run SVN updates.

        Parameters
        ----------
        X : (Tensor)
            Stein particles, of shape [dim, num_particles]
        model :
            Target model. With ``use_analytic_grads`` it must provide
            ``forward_model``, ``jacob_forward``, ``grad_log_p`` and
            ``GN_hessian``; otherwise only ``log_prob`` is called.
        iters : int
            Number of update iterations.
        eps : Float
            Step size.
        use_analytic_grads : bool
            Use the model's analytic gradient and Gauss-Newton Hessian
            instead of autograd.

        Returns
        -------
        X : (Tensor)
            Final particles (detached, with ``requires_grad`` set).
        particle_history : list of ndarray
            Particle positions before the first step and after every step.
        """

        particle_history = [X.clone().detach().cpu().numpy()]

        X = torch.autograd.Variable(X, requires_grad=True)

        # Time stats
        dts = []
        for _ in range(iters):

            if use_analytic_grads:
                F = model.forward_model(X)
                J = model.jacob_forward(X)
                dlog_p = model.grad_log_p(X, F, J)
                # Gauss-Newton approximation
                Hess = model.GN_hessian(X, J)
                Hess = -1 * Hess
            else:
                log_p = model.log_prob(X).unsqueeze(1)
                dlog_p = torch.autograd.grad(
                    log_p.sum(),
                    X,
                    create_graph=True,
                )[0]

                ## Full Hessian
                Hess = self.get_jacobian(dlog_p, X)

            t_start = time()

            # dlog_p doubles as the likelihood gradient here; it is
            # transposed by hand because get_svn_terms does not pass
            # dlog_lh through reshape_inputs.
            (Q,
             k_XX,
             grad_k) = self.get_svn_terms(
                X,
                dlog_p,
                dlog_p.transpose(0, 1),
                Hess,
                None,
                transpose=True,
            )

            dt = time() - t_start
            print('dt (SVN): {}'.format(dt))
            dts.append(dt)

            X = X + eps * Q

            # Cut the autograd graph so each iteration starts from a leaf.
            X = X.detach()
            X.requires_grad = True

            particle_history.append(X.clone().detach().cpu().numpy())

        if dts:  # avoid mean/std of an empty array when iters == 0
            dt_stats = np.array(dts)
            print("\nAvg. SVN compute time: {}".format(dt_stats.mean()))
            print("Std. dev. SVN compute time: {}\n".format(dt_stats.std()))

        return X, particle_history


# ---------------------------------------------------------------------------
# stein_lib/svn/svn_original.py
# ---------------------------------------------------------------------------

import numpy as np
import torch

from time import time
torch.set_default_tensor_type(torch.DoubleTensor)


class SVN_original:
    """
    SVN for debugging.
    Adapted from https://github.com/gianlucadetommaso/Stein-variational-samplers.
    """
    def __init__(
        self,
        model,
        particles,
        iters=100,
        eps=1.0
    ):
        """
        Parameters
        ----------
        model :
            Must expose ``dim``, ``forward_model``, ``jacob_forward``,
            ``grad_log_p`` and ``GN_hessian``.
        particles : ndarray
            Initial particles; the number of particles is read from axis 1.
            The array is updated in place by ``apply``.
        iters : int
            Number of iterations.
        eps : float
            Initial step size (adapted during ``apply``).
        """
        self.model = model
        self.DoF = model.dim
        self.nParticles = particles.shape[1]
        self.nIterations = iters
        self.stepsize = eps
        self.particles = particles

    def apply(self):
        """Run the SVN iterations and return the list of particle snapshots.

        The step size grows by 1% when the largest per-particle shift
        shrinks, is multiplied by 0.9 otherwise, and the particles are
        re-sampled when the shift is NaN or exceeds 1e20.
        """
        maxmaxshiftold = np.inf
        maxshift = np.zeros(self.nParticles)
        Q = np.zeros( (self.DoF, self.nParticles) )
        particle_history = []
        particle_history.append(np.copy(self.particles))

        # Time stats
        dts = []
        for iter_ in range(self.nIterations):

            particles_tn = torch.from_numpy(self.particles)
            F_tn = self.model.forward_model(particles_tn)
            J_tn = self.model.jacob_forward(particles_tn)
            gmlpt_tn = self.model.grad_log_p(particles_tn, F_tn, J_tn)
            Hmlpt_tn = self.model.GN_hessian(particles_tn, J_tn)

            gmlpt = gmlpt_tn.cpu().numpy()
            Hmlpt = Hmlpt_tn.cpu().numpy()
            # Kernel metric: Hessian averaged over its last axis.
            M = np.mean(Hmlpt, 2)

            t_start = time()
            for i_ in range(self.nParticles):

                sign_diff = self.particles[:, i_, np.newaxis] - self.particles
                Msd = np.matmul(M, sign_diff)
                kern = np.exp( - 0.5 * np.sum( sign_diff * Msd, 0))
                gkern = Msd * kern

                mgJ = np.mean(- gmlpt * kern + gkern, 1)
                HJ = np.mean(Hmlpt * kern ** 2, 2) + np.matmul(gkern, gkern.T) / self.nParticles

                Q[:, i_] = np.linalg.solve(HJ, mgJ)

                maxshift[i_] = np.linalg.norm(Q[:, i_], np.inf)

            self.particles += self.stepsize * Q
            maxmaxshift = np.max(maxshift)

            dt = time() - t_start
            print('dt (SVN): {}'.format(dt))
            dts.append(dt)

            if np.isnan(maxmaxshift) or (maxmaxshift > 1e20):
                print('Reset particles...')
                self.resetParticles()
                self.stepsize = 1
            elif maxmaxshift < maxmaxshiftold:
                self.stepsize *= 1.01
            else:
                self.stepsize *= 0.9
            maxmaxshiftold = maxmaxshift
            particle_history.append(np.copy(self.particles))

        dt_stats = np.array(dts)
        print("\nAvg. SVN_orignal compute time: {}".format(dt_stats.mean()))
        print("Std. dev. SVN_original compute time: {}\n".format(dt_stats.std()))

        return particle_history

    def resetParticles(self):
        """Replace the particles with standard-normal samples of shape (DoF, nParticles)."""
        self.particles = np.random.normal(scale=1, size=(self.DoF, self.nParticles) )
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.animation as animation


def get_jacobian(
    gradient,
    X,
):
    """
    Differentiate each column of ``gradient`` with respect to ``X``.

    For every column index ``i`` of ``gradient``, ``gradient[:, i].sum()``
    is differentiated w.r.t. ``X`` and the results are stacked along a new
    axis 1.

    Parameters
    ----------
    gradient : (Tensor)
        2-D tensor that is part of an autograd graph reaching ``X``.
        NOTE(review): the indexing implies rows are samples, i.e.
        [batch, dim] - confirm against callers.
    X : (Tensor)
        Tensor the gradient was computed from.

    Returns
    -------
    J : (Tensor)
        Of shape [X.shape[0], gradient.shape[1], X.shape[1]] for 2-D ``X``.
    """
    dg_dXi = [
        torch.autograd.grad(
            gradient[:, i].sum(),
            X,
            retain_graph=True,  # the graph is reused for every column
        )[0] for i in range(gradient.shape[1])
    ]
    J = torch.stack(dg_dXi, dim=1)
    return J


def calc_pw_distances(X):
    """
    Returns the pairwise Euclidean distances between the rows of ``X``.

    Parameters
    ----------
    X : Tensor
        Points, one per row. Of shape [batch, dim].

    Returns
    -------
    pw_dists : Tensor
        Of shape [batch, batch].
    """
    XX = X.matmul(X.t())
    sq_norms = XX.diag()
    pairwise_dists_sq = -2 * XX + sq_norms.unsqueeze(1) + sq_norms.unsqueeze(0)
    # Rounding can leave tiny negative values (e.g. on the diagonal), which
    # sqrt would turn into NaN.
    pw_dists = torch.sqrt(pairwise_dists_sq.clamp(min=0))
    return pw_dists


def calc_scaled_pw_distances(X, M):
    """
    Returns the metric-scaled / anisotropic pairwise distances between the
    rows of ``X``.

    Parameters
    ----------
    X : Tensor
        Points, one per row. Of shape [batch, dim].
    M : Tensor
        Metric. Of shape [dim, dim].

    Returns
    -------
    pw_dists : Tensor
        Of shape [batch, batch].
    """
    X_M_Xt = X @ M @ X.t()
    diag = X_M_Xt.diag()
    pw_dists_sq = -2 * X_M_Xt + diag.unsqueeze(1) + diag.unsqueeze(0)
    # Clamp for the same reason as in calc_pw_distances.
    pw_dists = torch.sqrt(pw_dists_sq.clamp(min=0))
    return pw_dists


def _resolve_ax_limits(ax_limits):
    """Return ``ax_limits``, or the default [[-4, 4], [-4, 4]] when it is None."""
    if ax_limits is None:
        return [[-4, 4], [-4, 4]]
    return ax_limits


def _density_grid(log_prob, ax_limits, to_numpy, ngrid=100):
    """Evaluate ``exp(log_prob)`` on a regular ``ngrid x ngrid`` mesh.

    With ``to_numpy`` the mesh is handed to ``log_prob`` as a torch tensor
    of shape [ngrid*ngrid, 2] and the result is converted back to numpy;
    otherwise ``log_prob`` receives the numpy array of shape
    [2, ngrid*ngrid].
    """
    x = np.linspace(ax_limits[0][0], ax_limits[0][1], ngrid)
    y = np.linspace(ax_limits[1][0], ax_limits[1][1], ngrid)
    X, Y = np.meshgrid(x, y)

    grid = np.vstack((X.flatten(), Y.flatten()))

    if to_numpy:
        grid = torch.from_numpy(grid)
        z = log_prob(grid.t()).cpu().numpy()
        Z = np.exp(z).reshape(ngrid, ngrid)
    else:
        Z = np.exp(log_prob(grid)).reshape(ngrid, ngrid)
    return X, Y, Z


def plot_graph_2D(
    particles,
    edges,
    log_prob,
    edge_vals=None,
    edge_coll_pts=None,
    edge_coll_thresh=None,
    save_path='/tmp/graph.png',
    to_numpy=False,
    ax_limits=None,
):
    """
    Plot particles and graph edges over the filled contours of ``exp(log_prob)``.

    Parameters
    ----------
    particles :
        Node positions, indexed as ``particles[:, 0]`` / ``particles[:, 1]``.
        Converted with ``.detach().cpu().numpy()`` when ``to_numpy`` is set.
    edges :
        Array whose rows are index pairs into ``particles``.
    log_prob : callable
        Log-density evaluated on the plotting grid.
    edge_vals : optional
        Per-edge values; an edge is drawn blue instead of black when its
        value is below ``edge_coll_thresh``.
    edge_coll_pts : optional
        Extra points drawn as green markers.
    edge_coll_thresh : optional
        Threshold for ``edge_vals``. When None, no edge is recoloured.
    save_path : str
        Where the figure is written.
    to_numpy : bool
        Treat ``log_prob`` and ``particles`` as torch-based.
    ax_limits : optional
        ``[[xmin, xmax], [ymin, ymax]]``; defaults to [[-4, 4], [-4, 4]].
    """
    ax_limits = _resolve_ax_limits(ax_limits)

    fig = plt.figure(figsize=(5,5))
    ax = plt.gca()

    X, Y, Z = _density_grid(log_prob, ax_limits, to_numpy)
    if to_numpy:
        particles = particles.detach().cpu().numpy()

    plt.contourf(X, Y, Z, 10)

    recolour = edge_vals is not None and edge_coll_thresh is not None
    for i in range(edges.shape[0]):
        node_pair = edges[i]
        color = 'k'
        if recolour and edge_vals[i] < edge_coll_thresh:
            color = 'b'
        plt.plot(
            particles[node_pair, 0],
            particles[node_pair, 1],
            color,
            markersize=1,
        )

    if edge_coll_pts is not None:
        plt.plot(edge_coll_pts[:, 0], edge_coll_pts[:, 1], 'go', markersize=2)

    xlim = ax_limits[0]
    ylim = ax_limits[1]
    plt.plot(particles[:, 0], particles[:, 1], 'ro', markersize=3)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    plt.savefig(save_path)
    plt.show()


def create_movie_2D(
    particle_hist,
    log_prob,
    save_path="/tmp/stein_movie.mp4",
    ax_limits=None,
    to_numpy=False,
    kernel_base_type=None,
    opt=None,
    num_particles=None,
    eps=None,
):
    """
    Animate a particle history over the filled contours of ``exp(log_prob)``.

    Parameters
    ----------
    particle_hist : sequence
        One array per frame, indexed as ``pos[:, 0]`` / ``pos[:, 1]``.
    log_prob : callable
        Log-density evaluated on the plotting grid.
    save_path : str
        Output file of the animation.
    ax_limits : optional
        ``[[xmin, xmax], [ymin, ymax]]``; defaults to [[-4, 4], [-4, 4]].
    to_numpy : bool
        Treat ``log_prob`` as torch-based.
    kernel_base_type, opt, num_particles, eps :
        Only used to compose the figure title.
    """
    ax_limits = _resolve_ax_limits(ax_limits)

    # Plain string (a stray trailing comma previously made this a tuple).
    k_type = kernel_base_type
    if kernel_base_type == 'RBF_Anisotropic':
        k_type = 'RBF_H'

    case_name = '{}-{} (np = {}, eps = {})'.format(
        opt,
        k_type,
        num_particles,
        eps,
    )

    fig = plt.figure(figsize=(5,5))
    ax = plt.gca()
    ax.set_title(case_name + '\n' + str(0) + '$ ^{th}$ iteration')

    X, Y, Z = _density_grid(log_prob, ax_limits, to_numpy)

    plt.contourf(X, Y, Z, 10)
    xlim = ax_limits[0]
    ylim = ax_limits[1]
    p_start = particle_hist[0]
    particles = plt.plot(p_start[:, 0], p_start[:, 1], 'ro', markersize=3)
    n_iter = len(particle_hist)

    def _init():  # only required for blitting to give a clean slate.
        ax.set_title(case_name + '\n' + str(0) + '$ ^{th}$ iteration')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        return particles

    def _animate(i):
        # Move the single scatter artist to frame i's positions.
        ax.set_title(case_name + '\n' + str(i) + '$ ^{th}$ iteration')
        pos = particle_hist[i]
        particles[0].set_xdata(pos[:, 0])
        particles[0].set_ydata(pos[:, 1])
        return particles

    ani = animation.FuncAnimation(
        fig,
        _animate,
        frames=n_iter,
        init_func=_init,
        interval=100,
        save_count=n_iter,
    )

    ani.save(save_path)
    plt.show()
# Experiment script: run SVGD on a Bayesian Hilbert Map model, build a graph
# from the resulting particles and plot it.

import torch
from torch.distributions import Normal, Uniform
from stein_lib.models.gaussian_mixture import mixture_of_gaussians
from stein_lib.svgd.svgd import SVGD
from pathlib import Path
from stein_lib.models.bhm import BayesianHilbertMap
from stein_lib.utils import create_movie_2D, plot_graph_2D
from stein_lib.prm_utils import get_graph

torch.set_default_tensor_type(torch.DoubleTensor)

###### Params ######
# num_particles = 100
num_particles = 250
# iters = 3000
# iters = 200
iters = 100
# iters = 1

# Sample initial particles (seeded for reproducibility)
torch.manual_seed(1)

## Large Gaussian in center of intel map.
# prior_dist = Normal(loc=torch.tensor([3.,-10.]),
#                     scale=torch.tensor([10., 10.]))

## Small gaussian in corner of intel map.
# prior_dist = Normal(loc=torch.tensor([12.,-3.]),
#                     scale=torch.tensor([1.,1.]))

## Two small gaussians in opposing corners of intel map.
# sigma = 5.
# radii_list = [[sigma, sigma],] * 2
# prior_dist = mixture_of_gaussians(
#     num_comp=2,
#     mu_list=[[12.,-3.], [-5, -18] ],
#     sigma_list=radii_list,
# )

## Uniform distribution
prior_dist = Uniform(low=torch.tensor([-10., -25.]),
                     high=torch.tensor([20., 5.]))


particles_0 = prior_dist.sample((num_particles,))

# Load model
import bhmlib
# NOTE(review): bhm_path is only used by the commented-out model_file below.
bhm_path = Path(bhmlib.__path__[0]).resolve()
# model_file = bhm_path / 'Outputs' / 'saved_models' / 'bhm_intel_res0.25_iter100.pt'
model_file = '/tmp/bhm_intel_res0.25_iter100.pt'
ax_limits = [[-10, 20],[-25, 5]]
model = BayesianHilbertMap(model_file, ax_limits)

#================== SVGD ===========================
# Round-trip through numpy gives a copy independent of particles_0.
particles = particles_0.clone().cpu().numpy()
particles = torch.from_numpy(particles)

# kernel_base_type = 'RBF'
# # optimizer_type = 'SGD'
# optimizer_type = 'Adam'
# step_size = 1.
# svgd = SVGD(
#     kernel_base_type=kernel_base_type,
#     kernel_structure=None,
#     median_heuristic=False,
#     repulsive_scaling=1.,
#     geom_metric_type=None,
#     verbose=True,
#     bandwidth=5.,
# )

kernel_base_type = 'RBF_Anisotropic'
# optimizer_type = 'SGD'
optimizer_type = 'Adam'
step_size = 0.25
# step_size = 0.
svgd = SVGD(
    kernel_base_type=kernel_base_type,
    kernel_structure=None,
    median_heuristic=False,
    repulsive_scaling=1.,
    geom_metric_type='fisher',
    verbose=True,
    bandwidth=5.,
)


# kernel_base_type = 'RBF_Anisotropic'
# optimizer_type = 'LBFGS' # 'FullBatchLBFGS'
# step_size = 0.1
# svgd = SVGD(
#     kernel_base_type=kernel_base_type,
#     kernel_structure=None,
#     median_heuristic=False,
#     repulsive_scaling=1.,
#     geom_metric_type='fisher',
#     verbose=True,
#     bandwidth=5.,
# )

## Optimize
(particles,
 p_hist,
 pw_dists,
 pw_dists_scaled) = svgd.apply(
    particles,
    model,
    iters,
    step_size,
    # use_analytic_grads=True,
    use_analytic_grads=False,
    optimizer_type=optimizer_type,
)

print("\nMean Est.: ", particles.mean(0))
print("Std Est.: ", particles.std(0))

#=============================================

# Construct Graph
(nodes,
 edge_lengths,
 edge_vals,
 edge_coll_num_pts,
 edge_coll_pts,
 params) = get_graph(
    particles.detach(),
    pw_dists,
    model,
    collision_thresh=5.,
    collision_res=0.25,
    connect_radius=5.,
    include_coll_pts=True, # For debugging, visualization
)

# Plot Graph
# NOTE(review): `nodes` is passed as plot_graph_2D's `edges` argument, which
# is indexed as rows of index pairs - confirm get_graph's return layout.
plot_graph_2D(
    particles.detach(),
    nodes,
    model.log_prob,
    edge_vals=edge_vals,
    edge_coll_thresh=50.,
    # edge_coll_pts=edge_coll_pts,
    ax_limits=ax_limits,
    to_numpy=True,
    save_path='./graph_svgd_{}_bhm_intel_np_{}_eps_{}.png'.format(
        kernel_base_type,
        num_particles,
        step_size,
    ),
)

# # Make movie
# create_movie_2D(
#     p_hist,
#     model.log_prob,
#     to_numpy=True,
#     save_path='./svgd_{}_bhm_intel_np_{}_eps_{}.mp4'.format(
#         kernel_base_type,
#         num_particles,
#         step_size,
#     ),
#     ax_limits=ax_limits,
#     opt='SVGD',
#     kernel_base_type=kernel_base_type,
#     num_particles=num_particles,
#     eps=step_size,
# )
# Experiment script: run matrix-valued SVGD on the analytic double-banana
# model and save an animation of the particle history.

import torch
from torch.distributions import Normal
from stein_lib.svgd.matrix_svgd.matrix_svgd import MatrixSVGD

from stein_lib.models.double_banana_analytic import doubleBanana_analytic
from stein_lib.utils import create_movie_2D

torch.set_default_tensor_type(torch.DoubleTensor)

# Params
num_particles = 50
# num_particles = 3
iters = 100
# eps_list = [0.5]
# eps_list = [0.1]
# eps_list = [0.01, 0.1]
# kernel_base_type = 'RBF'
kernel_base_type = 'RBF_Anisotropic'

# create_movie_2D indexes ax_limits as ax_limits[axis][0 or 1], so the limits
# must be given per axis; a flat (-2, 2) tuple raises TypeError there.
movie_ax_limits = [[-2, 2], [-2, 2]]

# Sample initial particles (seeded for reproducibility)
torch.manual_seed(0)
# prior_dist = Normal(loc=0., scale=1.5)
# prior_dist = Normal(loc=0., scale=0.75)
prior_dist = Normal(loc=0., scale=0.5)
particles_0 = prior_dist.sample((2, num_particles))

# DEBUG - load particles
# particles_0 = np.load('/home/sasha/np_50.npy')
# particles_0 = torch.from_numpy(particles_0).t()

# Load model
model = doubleBanana_analytic(
    mu_n=3.57857342, # actually, your observation
    # mu_n=np.log(30), # actually, your observation
    seed=0,
)


### SVN ########
#
# particles = particles_0.clone().cpu().numpy()
# particles = torch.from_numpy(particles)
#
# eps = 0.1
# kernel_base_type='RBF_Anisotropic'
# # kernel_base_type='RBF'
# # kernel_base_type='IMQ'
# svn = SVN(
#     kernel_base_type=kernel_base_type,
#     compute_hess_terms=True,
#     use_hessian_metric=False,
#     umedian_heurstic=True,
#     # geom_metric_type='fisher',
#     geom_metric_type='jacobian_product',
#     # geom_metric_type='full_hessian',
# )
#
# particles, p_hist = svn.apply(
#     particles,
#     model,
#     iters,
#     eps,
#     use_analytic_grads=True,
# )
#
# create_movie_2D(
#     p_hist,
#     model.log_prob,
#     ax_limits=movie_ax_limits,
#     to_numpy=True,
#     save_path='./svn_tests/svn_{}_double_banana_x_np_{}_eps_{}.mp4'.format(
#         kernel_base_type,
#         num_particles,
#         eps,
#     ),
#     opt='SVN',
#     kernel_base_type=kernel_base_type,
#     num_particles=num_particles,
#     eps=eps,
# )

#### SVGD ############

# particles = particles_0.clone().cpu().numpy()
# particles = torch.from_numpy(particles)
# eps = 1.

# svgd = SVGD(
#     kernel_base_type=kernel_base_type,
#     kernel_structure=None,
#     geom_metric_type='full_hessian',
# )
#
# particles, p_hist = svgd.apply(
#     particles,
#     model,
#     iters,
#     eps,
#     use_analytic_grads=True,
# )
#
# create_movie_2D(
#     p_hist,
#     model.log_prob,
#     ax_limits=movie_ax_limits,
#     to_numpy=True,
#     save_path='./svgd_{}_double_banana_np_{}_eps_{}.mp4'.format(
#         kernel_base_type,
#         num_particles,
#         eps,
#     ),
#     opt='SVGD',
#     kernel_base_type=kernel_base_type,
#     num_particles=num_particles,
#     eps=eps,
# )

# #### Matrix-valued SVGD ############

# kernel_base_type = 'IMQ_Matrix'
kernel_base_type = 'RBF_Matrix'
# eps = 1.
# eps = 2.5
eps = 10.
# Round-trip through numpy gives a copy independent of particles_0.
particles = particles_0.clone().cpu().numpy()
particles = torch.from_numpy(particles)

matrix_svgd = MatrixSVGD(
    kernel_base_type=kernel_base_type,
    kernel_structure=None,
    use_hessian_metric=True,
    # use_hessian_metric=False,
    # geom_metric_type='fisher',
    geom_metric_type='full_hessian',
    # median_heuristic=True,
)

particles, p_hist = matrix_svgd.apply(
    particles,
    model,
    iters,
    eps,
    use_analytic_grads=False,
)

create_movie_2D(
    p_hist,
    model.log_prob,
    ax_limits=movie_ax_limits,
    to_numpy=True,
    save_path='./matrix_svgd_{}_double_banana_np_{}_eps_{}.mp4'.format(
        kernel_base_type,
        num_particles,
        eps,
    ),
    opt='Matrix_SVGD',
    kernel_base_type=kernel_base_type,
    num_particles=num_particles,
    eps=eps,
)

#### Weighted Matrix-valued SVGD ############

# kernel_base_type = 'RBF_Matrix'
# # eps = 1.
# eps = 0.5
# # eps = 0.1
#
# particles = particles_0.clone().cpu().numpy()
# particles = torch.from_numpy(particles)
#
# mix_matrix_svgd = MatrixMixtureSVGD(
#     kernel_base_type=kernel_base_type,
#     kernel_structure=None,
#     geom_metric_type='full_hessian',
#     hessian_scale=1.,
# )
#
# particles, p_hist = mix_matrix_svgd.apply(
#     particles,
#     model,
#     iters,
#     eps,
#     use_analytic_grads=True,
# )
#
# create_movie_2D(
#     p_hist,
#     model.log_prob,
#     ax_limits=movie_ax_limits,
#     to_numpy=True,
#     save_path='./mix_matrix_RBF_svgd_{}_double_banana_np_{}_eps_{}.mp4'.format(
#         kernel_base_type,
#         num_particles,
#         eps,
#     ),
#     opt='Matrix_Mix_SVGD',
#     kernel_base_type=kernel_base_type,
#     num_particles=num_particles,
#     eps=eps,
# )
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 17 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 19 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 20 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 21 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 22 | """ 23 | 24 | import torch 25 | from torch.distributions import Normal 26 | from stein_lib.svgd.svgd import SVGD 27 | 28 | from stein_lib.models.gaussian_mixture import mixture_of_gaussians 29 | from stein_lib.utils import create_movie_2D 30 | 31 | torch.set_default_tensor_type(torch.DoubleTensor) 32 | 33 | ###### Params ###### 34 | num_particles = 100 35 | iters = 200 36 | eps_list = [0.1] 37 | 38 | # Sample intial particles 39 | torch.manual_seed(1) 40 | # prior_dist = Normal(loc=0., scale=1.) 41 | prior_dist = Normal(loc=-4, scale=0.5) 42 | # particles_0 = prior_dist.sample((2, num_particles)) 43 | particles_0 = prior_dist.sample((num_particles, 2)) 44 | 45 | # Load model 46 | 47 | #### Bi-modal Gaussian mixture ######### 48 | # sigma = 1. 49 | # nc = 2 50 | # centers_list=[ 51 | # [-1.5, 0.], 52 | # [1.5, 0.], 53 | # ] 54 | 55 | #### Toy Bayesian occupancy map ########### 56 | sigma = 0.4 57 | nc = 37 # number of components 58 | s = 2. # rbf spacing 59 | centers_list=[ 60 | [-2*s, 2*s] , [-s, 2*s], [0., 2*s], [s, 2*s], [2*s, 2*s], 61 | [-2*s, s] , [-s, s], [0., s], [s, s], [2*s, s], 62 | [-2*s, 0.] 
, [-s, 0.], [0., 0.], [s, 0.], [2*s, 0.], 63 | [-s, -s], [2*s, -s], 64 | [-2*s, -2*s], [-s, -2*s], [0., -2*s], [2*s, -2*s], 65 | 66 | [-2*s, 1.5*s] , [-2*s, 0.5*s], [-1.5*s, 0], [-0.5*s, 0], 67 | [0.5*s, 0], [1.5*s, 0], [2*s, -0.5*s], [2*s, -1.5*s], 68 | [-1*s, -0.5*s], [-1*s, -1.5*s], [-0.5*s, -2*s], [0.5*s, -2*s], 69 | [-1.5*s, -2*s], [1*s, 0.5*s], [1*s, 1.5*s], [1.5*s, 2*s], 70 | ] 71 | ############ 72 | 73 | radii_list = [ 74 | [sigma, sigma], 75 | ] * nc 76 | 77 | model = mixture_of_gaussians( 78 | num_comp=nc, 79 | mu_list=centers_list, 80 | sigma_list=radii_list, 81 | ) 82 | 83 | 84 | for eps in eps_list: 85 | 86 | ### SVN ######## 87 | # particles = particles_0.clone().cpu().numpy() 88 | # particles = torch.from_numpy(particles) 89 | # 90 | # eps = 1. 91 | # # kernel_base_type='RBF_Anisotropic' 92 | # # kernel_base_type='RBF' 93 | # kernel_base_type='IMQ' 94 | # svn = SVN( 95 | # kernel_base_type=kernel_base_type, 96 | # compute_hess_terms=True, 97 | # use_hessian_metric=False, 98 | # umedian_heurstic=True, 99 | # # geom_metric_type='fisher', 100 | # geom_metric_type='jacobian_product', 101 | # # geom_metric_type='full_hessian', 102 | # ) 103 | # 104 | # particles, p_hist = svn.apply( 105 | # particles, 106 | # model, 107 | # iters, 108 | # eps, 109 | # optimizer_type='LBFGS' 110 | # ) 111 | # 112 | # create_movie_2D( 113 | # p_hist, 114 | # model.log_prob, 115 | # ax_limits=(-4, 4), 116 | # to_numpy=True, 117 | # save_path='./svn_tests/svn_{}_gaussian_mix_np_{}_eps_{}.mp4'.format( 118 | # # save_path='./svn_tests/svn_{}_gaussian_mix_hard_np_{}_eps_{}.mp4'.format( 119 | # kernel_base_type, 120 | # num_particles, 121 | # eps, 122 | # ), 123 | # opt='SVN', 124 | # kernel_base_type=kernel_base_type, 125 | # num_particles=num_particles, 126 | # eps=eps, 127 | # ) 128 | 129 | #================== SVGD =========================== 130 | 131 | particles = particles_0.clone().cpu().numpy() 132 | particles = torch.from_numpy(particles) 133 | kernel_base_type = 
'RBF_Anisotropic' # 'RBF', 'IMQ' 134 | optimizer_type = 'LBFGS' # 'FullBatchLBFGS' 135 | 136 | svgd = SVGD( 137 | kernel_base_type=kernel_base_type, 138 | kernel_structure=None, 139 | median_heuristic=False, 140 | repulsive_scaling=2., 141 | geom_metric_type='fisher', 142 | verbose=True, 143 | ) 144 | 145 | particles, p_hist, pw_dists_sq = svgd.apply( 146 | particles, 147 | model, 148 | iters, 149 | eps, 150 | use_analytic_grads=False, 151 | optimizer_type=optimizer_type, 152 | ) 153 | 154 | print("\nMean Est.: ", particles.mean(0)) 155 | print("Std Est.: ", particles.std(0)) 156 | 157 | create_movie_2D( 158 | p_hist, 159 | model.log_prob, 160 | to_numpy=True, 161 | save_path='./svgd_{}_gaussian_mix_np_{}_eps_{}.mp4'.format( 162 | kernel_base_type, 163 | num_particles, 164 | eps, 165 | ), 166 | ax_limits=[[-5, 5],[-5, 5]], 167 | opt='SVGD', 168 | kernel_base_type=kernel_base_type, 169 | num_particles=num_particles, 170 | eps=eps, 171 | ) 172 | 173 | # ========= Matrix-valued SVGD =================== 174 | 175 | # # kernel_base_type = 'IMQ_Matrix' 176 | # kernel_base_type = 'RBF_Matrix' 177 | # # kernel_base_type = 'RBF_Weighted_Matrix' 178 | # # eps = 0.1 179 | # # eps = 5. 
180 | # eps = 2.5 181 | # particles = particles_0.clone().cpu().numpy() 182 | # particles = torch.from_numpy(particles) 183 | # 184 | # matrix_svgd = MatrixSVGD( 185 | # kernel_base_type=kernel_base_type, 186 | # kernel_structure=None, 187 | # # use_hessian_metric=True, 188 | # use_hessian_metric=False, 189 | # geom_metric_type='fisher', 190 | # # geom_metric_type='full_hessian', 191 | # # median_heuristic=True, 192 | # 193 | # 194 | # particles, p_hist = matrix_svgd.apply( 195 | # particles, 196 | # model, 197 | # iters, 198 | # eps, 199 | # optimizer_type='LBFGS' 200 | # ) 201 | # 202 | # create_movie_2D( 203 | # p_hist, 204 | # model.log_prob, 205 | # to_numpy=True, 206 | # save_path='./matrix_RBF_svgd_{}_gaussian_mix_np_{}_eps_{}.mp4'.format( 207 | # # save_path='./weighted_matrix_RBF_svgd_{}_gaussian_mix_np_{}_eps_{}.mp4'.format( 208 | # # save_path='./check_matrix_RBF_svgd_{}_gaussian_mix_np_{}_eps_{}.mp4'.format( 209 | # kernel_base_type, 210 | # num_particles, 211 | # eps, 212 | # ), 213 | # opt='Matrix_SVGD', 214 | # kernel_base_type=kernel_base_type, 215 | # num_particles=num_particles, 216 | # eps=eps, 217 | # ) 218 | 219 | #============== Weighted Matrix-valued SVGD ===================== 220 | # kernel_base_type = 'RBF_Matrix' 221 | # 222 | # particles = particles_0.clone().cpu().numpy() 223 | # particles = torch.from_numpy(particles) 224 | # 225 | # mix_matrix_svgd = MatrixMixtureSVGD( 226 | # kernel_base_type=kernel_base_type, 227 | # kernel_structure=None, 228 | # geom_metric_type='full_hessian', 229 | # # geom_metric_type='jacobian_product', 230 | # ) 231 | # 232 | # particles, p_hist = mix_matrix_svgd.apply( 233 | # particles, 234 | # model, 235 | # iters, 236 | # eps, 237 | # optimizer_type='LBFGS' 238 | # ) 239 | # 240 | # create_movie_2D( 241 | # p_hist, 242 | # model.log_prob, 243 | # to_numpy=True, 244 | # save_path='./matrix_RBF_svgd_{}_gaussian_mix_np_{}_eps_{}.mp4'.format( 245 | # # 
save_path='./weighted_matrix_RBF_svgd_{}_gaussian_mix_np_{}_eps_{}.mp4'.format( 246 | # # save_path='./check_matrix_RBF_svgd_{}_gaussian_mix_np_{}_eps_{}.mp4'.format( 247 | # kernel_base_type, 248 | # num_particles, 249 | # eps, 250 | # ), 251 | # opt='Matrix_Mix_SVGD', 252 | # kernel_base_type=kernel_base_type, 253 | # num_particles=num_particles, 254 | # eps=eps, 255 | # ) --------------------------------------------------------------------------------