├── gsmvi
│   ├── __init__.py
│   ├── initializers.py
│   ├── monitors.py
│   ├── advi.py
│   ├── gsm_numpy.py
│   ├── gsm.py
│   └── bam.py
├── pyproject.toml
├── LICENSE
├── examples
│   ├── example_gsm.py
│   ├── example_gsm_numpy.py
│   ├── example_advi.py
│   ├── example_bam.py
│   └── example_initializers.py
├── .gitignore
└── README.md

--------------------------------------------------------------------------------
/gsmvi/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/gsmvi/initializers.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy.optimize import minimize


def lbfgs_init(x0, lp, lp_g=None, maxiter=1000, maxfun=1000):
    # Maximize the target log-probability lp with L-BFGS-B starting from x0.
    # The optimum initializes the mean, and the estimated inverse Hessian
    # initializes the covariance matrix of the variational distribution.
    f = lambda x: -lp(x)
    if lp_g is not None:
        f_g = lambda x: -lp_g(x)
    else:
        f_g = None
    res = minimize(f, x0, method='L-BFGS-B', jac=f_g,
                   options={"maxiter": maxiter, "maxfun": maxfun})

    mu = res.x
    cov = res.hess_inv.todense()
    return mu, cov, res

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "gsmvi"
version = "0.0.7"
authors = [
  { name="Chirag Modi", email="modichirag92@gmail.com" },
]
description = "Implementation of Gaussian score matching for variational inference (arXiv:2307.07849)"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]
dependencies = [
]

[project.optional-dependencies]
full = [
    "jax",
    "jaxlib",
    "numpyro",
    "optax",
    "scipy"]

[project.urls]
Homepage = "https://github.com/modichirag/GSM-VI/tree/package"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Chirag Modi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
22 | -------------------------------------------------------------------------------- /examples/example_gsm.py: -------------------------------------------------------------------------------- 1 | ## A basic example for fitting a target Multivariate Gaussian distribution with GSM updates 2 | 3 | ## Uncomment the following lines if you run into memory issues with JAX 4 | # import os 5 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 6 | # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false" 7 | # os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform" 8 | 9 | import numpy as np 10 | import jax.numpy as jnp 11 | from jax import jit, grad, random 12 | import numpyro.distributions as dist 13 | 14 | from gsmvi.gsm import GSM 15 | 16 | ##### 17 | def setup_model(D=10): 18 | 19 | # setup a Gaussian target distribution 20 | mean = np.random.random(D) 21 | L = np.random.normal(size = D**2).reshape(D, D) 22 | cov = np.matmul(L, L.T) + np.eye(D)*1e-3 23 | model = dist.MultivariateNormal(loc=mean, covariance_matrix=cov) 24 | return model 25 | 26 | 27 | if __name__=="__main__": 28 | 29 | ### 30 | # setup a toy Gaussia model and extracet score needed for GSM 31 | D = 10 32 | model = setup_model(D=D) 33 | mean, cov = model.loc, model.covariance_matrix 34 | lp = jit(lambda x: jnp.sum(model.log_prob(x))) 35 | lp_g = jit(grad(lp, argnums=0)) 36 | 37 | ### 38 | # Fit with GSM 39 | niter = 500 40 | key = random.PRNGKey(99) 41 | gsm = GSM(D=D, lp=lp, lp_g=lp_g) 42 | mean_fit, cov_fit = gsm.fit(key, niter=niter) 43 | 44 | print("\nTrue mean : ", mean) 45 | print("Fit mean : ", mean_fit) 46 | -------------------------------------------------------------------------------- /examples/example_gsm_numpy.py: -------------------------------------------------------------------------------- 1 | ## A basic example for fitting a target Multivariate Gaussian distribution with GSM updates 2 | 3 | import numpy as np 4 | 5 | from gsmvi.gsm_numpy import GSM 6 | 7 | ##### 8 | def setup_model(D): 9 | 10 | # setup a Gaussian target distribution 11 | mean = np.random.random(D) 12 | L = np.random.normal(size = D**2).reshape(D, D) 13 | cov = np.matmul(L, L.T) + np.eye(D)*1e-3 14 | icov = np.linalg.inv(cov) 15 | 16 | # functions for log_prob and score. These are to be supplied by the user 17 | def lp(x): 18 | assert len(x.shape) == 2 19 | lp = 0 20 | for i in range(x.shape[0]): 21 | lp += -0.5 * np.dot(np.dot(mean - x[i], icov), mean - x[i]) 22 | return lp 23 | 24 | def lp_g(x): 25 | assert len(x.shape) == 2 26 | lp_g = [] 27 | for i in range(x.shape[0]): 28 | lp_g.append( -1. 
* np.dot(icov, x[i] - mean)) 29 | return np.array(lp_g) 30 | 31 | return mean, cov, lp, lp_g 32 | 33 | 34 | if __name__=="__main__": 35 | 36 | ### 37 | # setup a toy Gaussia model and extracet score needed for GSM 38 | D = 5 39 | mean, cov, lp, lp_g = setup_model(D=D) 40 | 41 | ### 42 | # Fit with GSM 43 | niter = 500 44 | key = 99 45 | gsm = GSM(D=D, lp=lp, lp_g=lp_g) 46 | mean_fit, cov_fit = gsm.fit(key, niter=niter) 47 | 48 | print("\nTrue mean : ", mean) 49 | print("Fit mean : ", mean_fit) 50 | -------------------------------------------------------------------------------- /examples/example_advi.py: -------------------------------------------------------------------------------- 1 | ## A basic example for fitting a target Multivariate Gaussian distribution with GSM updates 2 | 3 | ## Uncomment the following lines if you run into memory issues with JAX 4 | # import os 5 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 6 | # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false" 7 | # os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform" 8 | 9 | import numpy as np 10 | import jax.numpy as jnp 11 | from jax import jit, grad, random 12 | import optax 13 | import numpyro.distributions as dist 14 | 15 | from gsmvi.advi import ADVI 16 | 17 | def setup_model(D=10): 18 | 19 | # setup a Gaussian target distribution 20 | mean = np.random.random(D) 21 | L = np.random.normal(size = D**2).reshape(D, D) 22 | cov = np.matmul(L, L.T) + np.eye(D)*1e-3 23 | model = dist.MultivariateNormal(loc=mean, covariance_matrix=cov) 24 | return model 25 | 26 | 27 | if __name__=="__main__": 28 | 29 | ### 30 | # setup a toy Gaussia model and extract log-prob 31 | D = 4 32 | model = setup_model(D=D) 33 | mean, cov = model.loc, model.covariance_matrix 34 | lp = jit(lambda x: jnp.sum(model.log_prob(x))) 35 | 36 | ### 37 | # Fit with advi 38 | niter = 10000 39 | lr = 1e-2 40 | batch_size = 16 41 | advi = ADVI(D=D, lp=lp) 42 | key = random.PRNGKey(99) 43 | opt = optax.adam(learning_rate=lr) 44 | mean_fit, cov_fit, losses = advi.fit(key, opt, batch_size=batch_size, niter=niter) 45 | 46 | print("\nTrue mean : ", mean) 47 | print("Fit mean : ", mean_fit) 48 | 49 | -------------------------------------------------------------------------------- /examples/example_bam.py: -------------------------------------------------------------------------------- 1 | ## Most basic example for fitting a target Multivariate Gaussian distribution with BaM updates 2 | 3 | import numpy as np 4 | import os 5 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 6 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false" 7 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform" 8 | #os.environ['CUDA_VISIBLE_DEVICES'] = '' # To enable CPU backend 9 | 10 | from jax.lib import xla_bridge 11 | print("Device : ", xla_bridge.get_backend().platform) 12 | 13 | # enable 16 bit precision for jax 14 | from jax import config 15 | config.update("jax_enable_x64", True) 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | from jax import jit, grad, random 20 | 21 | import numpyro 22 | import numpyro.distributions as dist 23 | 24 | # Import BaM 25 | from gsmvi.bam import BaM, Regularizers 26 | from gsmvi.monitors import KLMonitor 27 | ##### 28 | 29 | 30 | ##### 31 | def setup_model(D=10): 32 | 33 | # setup a Gaussian target distribution 34 | mean = np.random.random(D) 35 | L = np.random.normal(size = D**2).reshape(D, D) 36 | cov = np.matmul(L, L.T) + np.eye(D)*1e-3 37 | model = dist.MultivariateNormal(loc=mean, covariance_matrix=cov) 38 | lp = jit(lambda x: 
jnp.sum(model.log_prob(x))) 39 | lp_g = jit(grad(lp, argnums=0)) 40 | 41 | return model, mean, cov, lp, lp_g 42 | 43 | 44 | 45 | if __name__=="__main__": 46 | 47 | D = 5 48 | model, mean, cov, lp, lp_g = setup_model(D=D) 49 | ref_samples = model.sample(random.PRNGKey(99), (1000,)) 50 | 51 | niter = 100 52 | batch_size = 2 53 | regularizer = Regularizers() 54 | 55 | # Example regularization functions 56 | # regf = regularizer.constant(100) 57 | # regf = regularizer.linear(100) 58 | func = lambda i : 100/(1+i) 59 | regf = regularizer.custom(func) 60 | 61 | 62 | bam = BaM(D=D, lp=lp, lp_g=lp_g, use_lowrank=True, jit_compile=True) 63 | key = random.PRNGKey(99) 64 | mean_fit, cov_fit = bam.fit(key, regf=regf, niter=niter, batch_size=batch_size) 65 | 66 | print() 67 | print("True mean : ", mean) 68 | print("Fit mean : ", mean_fit) 69 | print() 70 | print("Check mean fit") 71 | print(np.allclose(mean, mean_fit)) 72 | 73 | print() 74 | print("Check cov fit") 75 | print(np.allclose(cov, cov_fit)) 76 | print(cov) 77 | print() 78 | print(cov_fit) 79 | -------------------------------------------------------------------------------- /examples/example_initializers.py: -------------------------------------------------------------------------------- 1 | ## Example for fitting a target Multivariate Gaussian distribution with GSM and ADVI 2 | ## Variational distribution is initialized with LBFGS fit for the mean and covariance 3 | ## The progress is monitored with a Monitor class. 4 | 5 | ## Uncomment the following lines if you run into memory issues with JAX 6 | # import os 7 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 8 | # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false" 9 | # os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform" 10 | 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | import jax.numpy as jnp 14 | from jax import jit, grad, random 15 | import numpyro.distributions as dist 16 | import optax 17 | 18 | # enable 16 bit precision for jax required for lbfgs initializer 19 | from jax import config 20 | config.update("jax_enable_x64", True) 21 | 22 | from gsmvi.gsm import GSM 23 | from gsmvi.advi import ADVI 24 | from gsmvi.initializers import lbfgs_init 25 | from gsmvi.monitors import KLMonitor 26 | 27 | ##### 28 | def setup_model(D=10): 29 | 30 | # setup a Gaussian target distribution 31 | mean = np.random.random(D) 32 | L = np.random.normal(size = D**2).reshape(D, D) 33 | cov = np.matmul(L, L.T) + np.eye(D)*1e-3 34 | model = dist.MultivariateNormal(loc=mean, covariance_matrix=cov) 35 | return model 36 | 37 | 38 | # 39 | def gsm_fit(D, lp, lp_g, mean_init, cov_init, lbfgs_res): 40 | print("Now fit with GSM") 41 | niter = 500 42 | batch_size = 1 43 | key = random.PRNGKey(99) 44 | monitor = KLMonitor(batch_size_kl=32, checkpoint=10, \ 45 | offset_evals=lbfgs_res.nfev) #note the offset number of evals 46 | 47 | gsm = GSM(D=D, lp=lp, lp_g=lp_g) 48 | mean_fit, cov_fit = gsm.fit(key, mean=mean_init, cov=cov_init, niter=niter, batch_size=batch_size, monitor=monitor) 49 | return mean_fit, cov_fit, monitor 50 | 51 | 52 | # 53 | def advi_fit(D, lp, lp_g, mean_init, cov_init, lbfgs_res): 54 | print("\nNow fit with ADVI") 55 | niter = 500 56 | lr = 1e-2 57 | batch_size = 1 58 | key = random.PRNGKey(99) 59 | opt = optax.adam(learning_rate=lr) 60 | monitor = KLMonitor(batch_size_kl=32, checkpoint=10, \ 61 | offset_evals=lbfgs_res.nfev) #note the offset number of evals 62 | 63 | advi = ADVI(D=D, lp=lp) 64 | mean_fit, cov_fit, losses = advi.fit(key, mean=mean_init, cov=cov_init, opt=opt, 
batch_size=batch_size, niter=niter, monitor=monitor) 65 | return mean_fit, cov_fit, monitor 66 | 67 | 68 | 69 | if __name__=="__main__": 70 | 71 | ### 72 | # setup a toy Gaussia model and extracet score needed for GSM 73 | D = 16 74 | model = setup_model(D=D) 75 | mean, cov = model.loc, model.covariance_matrix 76 | lp = jit(lambda x: jnp.sum(model.log_prob(x))) 77 | lp_g = jit(grad(lp, argnums=0)) 78 | 79 | ### 80 | print("Initialize with LBFGS") 81 | mean_init = np.ones(D) # setup gsm with initilization from LBFGS fit 82 | mean_init, cov_init, lbfgs_res = lbfgs_init(mean_init, lp, lp_g) 83 | print(f'LBFGS fit: \n{lbfgs_res}\n') 84 | 85 | mean_gsm, cov_gsm, monitor_gsm = gsm_fit(D, lp, lp_g, mean_init, cov_init, lbfgs_res) 86 | mean_advi, cov_advi, monitor_advi = advi_fit(D, lp, lp_g, mean_init, cov_init, lbfgs_res) 87 | 88 | # Check that the output is correct 89 | print("\nTrue mean : ", mean) 90 | print("Fit gsm : ", mean_gsm) 91 | print("Fit advi : ", mean_advi) 92 | 93 | 94 | # Check that the KL divergence decreases 95 | plt.plot(monitor_gsm.nevals, monitor_gsm.rkl, label='GSM') 96 | plt.plot(monitor_advi.nevals, monitor_advi.rkl, label='ADVI') 97 | plt.legend() 98 | plt.xlabel("Iteration") 99 | plt.ylabel("Reverse KL") 100 | plt.savefig("monitor_kl.png") 101 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /gsmvi/monitors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dataclasses import dataclass 3 | from functools import partial 4 | 5 | import jax.numpy as jnp 6 | from jax import jit, grad, random 7 | from numpyro.distributions import MultivariateNormal 8 | 9 | 10 | def reverse_kl(samples, lpq, lpp): 11 | logl = np.sum(lpp(samples)) 12 | logq = np.sum(lpq(samples)) 13 | rkl = logq - logl 14 | rkl /= samples.shape[0] 15 | return rkl 16 | 17 | def forward_kl(samples, lpq, lpp): 18 | logl = np.sum(lpp(samples)) 19 | logq = np.sum(lpq(samples)) 20 | fkl = logl - logq 21 | fkl /= samples.shape[0] 22 | return fkl 23 | 24 | @partial(jit, static_argnums=(3)) 25 | def reverse_kl_jit(samples, mu, cov, lp): 26 | q = MultivariateNormal(mu, cov) 27 | logq = jnp.sum(q.log_prob(samples)) 28 | logl = jnp.sum(lp(samples)) 29 | rkl = logq - logl 30 | rkl /= samples.shape[0] 31 | return rkl 32 | 33 | @partial(jit, static_argnums=(3)) 34 | def forward_kl_jit(samples, mu, cov, lp): 35 | q = MultivariateNormal(mu, cov) 36 | logq = jnp.sum(q.log_prob(samples)) 37 | logl = jnp.sum(lp(samples)) 38 | fkl = logl- logq 39 | fkl /= samples.shape[0] 40 | return fkl 41 | 42 | 43 | @dataclass 44 | class KLMonitor(): 45 | """ 46 | Class to monitor KL divergence during optimization for VI 47 | 48 | Inputs: 49 | 50 | batch_size_kl: (int) Number of samples to use to estimate KL divergence 51 | checkpoint: (int) Number of iterations after which to run monitor 52 | offset_evals: (int) Value with which to offset number of gradient evaluatoins 53 | Used to account for gradient evaluations done in warmup or initilization 54 | ref_samples: Optional, samples from the target distribution. 55 | If provided, also track forward KL divergence 56 | """ 57 | 58 | batch_size_kl : int = 8 59 | checkpoint : int = 20 60 | offset_evals : int = 0 61 | ref_samples : np.array = None 62 | 63 | def __post_init__(self): 64 | 65 | self.rkl = [] 66 | self.fkl = [] 67 | self.nevals = [] 68 | 69 | def reset(self, 70 | batch_size_kl=None, 71 | checkpoint=None, 72 | offset_evals=None, 73 | ref_samples=None ): 74 | self.nevals = [] 75 | self.rkl = [] 76 | self.fkl = [] 77 | if batch_size_kl is not None: self.batch_size_kl = batch_size_kl 78 | if checkpoint is not None: self.checkpoint = checkpoint 79 | if offset_evals is not None: self.offset_evals = offset_evals 80 | if ref_samples is not None: self.ref_samples = ref_samples 81 | print('offset evals reset to : ', self.offset_evals) 82 | 83 | def __call__(self, i, params, lp, key, nevals=1): 84 | """ 85 | Main function to monitor reverse (and forward) KL divergence over iterations. 

        Inputs:

        i: (int) iteration number
        params: (tuple; (mean, cov)) Current estimate of mean and covariance matrix
        lp: Function to evaluate target log-probability
        key: Random number generator key (jax.random.PRNGKey)
        nevals: (int) Number of gradient evaluations SINCE the last call of the monitor function

        Returns:
        key : New key for random number generation
        """

        mu, cov = params
        key, key_sample = random.split(key)
        np.random.seed(key_sample[0])

        try:
            qsamples = np.random.multivariate_normal(mean=mu, cov=cov, size=self.batch_size_kl)
            q = MultivariateNormal(loc=mu, covariance_matrix=cov)
            self.rkl.append(reverse_kl(qsamples, q.log_prob, lp))

            if self.ref_samples is not None:
                idx = np.random.permutation(self.ref_samples.shape[0])[:self.batch_size_kl]
                psamples = self.ref_samples[idx]
                self.fkl.append(forward_kl(psamples, q.log_prob, lp))
            else:
                self.fkl.append(np.nan)  # np.NaN was removed in NumPy 2.0; np.nan is the portable spelling

        except Exception as e:
            print(f"Exception occurred in monitor : {e}.\nAppending NaN")
            self.rkl.append(np.nan)
            self.fkl.append(np.nan)

        self.nevals.append(self.offset_evals + nevals)
        self.offset_evals = self.nevals[-1]

        return key

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GSM-VI: Gaussian score-based variational inference algorithms

This repository provides a Python implementation of two score-based (black-box)
variational inference algorithms:
1. GSM-VI, which is described in the NeurIPS 2023 paper https://arxiv.org/abs/2307.07849.
2. Batch and match (BaM), which is described in the ICML 2024 paper https://arxiv.org/pdf/2402.14758.

We describe each of these in the following two sections.

## Variational Inference (VI) with Gaussian Score Matching (GSM) (NeurIPS 2023)

GSM-VI fits a multivariate Gaussian distribution with a dense covariance matrix to the target distribution
by score matching. It only requires access to the score function, i.e. the gradient of the target log-probability
distribution, and implements analytic updates for the variational parameters (mean and covariance matrix).

### Installation:
The code is available on `PyPI`:
```
pip install gsmvi
```

### Usage
The simplest version of the algorithm is written in numpy.
The following is the minimal code to use GSM to fit the parameters `x` of a `model`, given its `log_prob` and `log_prob_grad` functions.
See `examples/example_gsm_numpy.py` for a full example.
```
dimensions = D
def log_prob(x):
    # return log-probability at sample x
    ...

def log_prob_grad(x):
    # return the score function, i.e. the gradient of the log-probability at sample x
    ...

from gsmvi.gsm_numpy import GSM
gsm = GSM(D=D, lp=log_prob, lp_g=log_prob_grad)
random_seed = 99
number_of_iterations = 500
mean_fit, cov_fit = gsm.fit(key=random_seed, niter=number_of_iterations)
```

A more efficient version of the algorithm is implemented in Jax, where it can benefit from jit compilation. The basic signature stays the same.
See `examples/example_gsm.py` for a full example.
```
dimensions = D
model = setup_model(D=D)   # This example sets up a numpyro model which has a log_prob attribute implemented
lp = jit(lambda x: jnp.sum(model.log_prob(x)))
lp_g = jit(grad(lp, argnums=0))

from gsmvi.gsm import GSM
gsm = GSM(D=D, lp=lp, lp_g=lp_g)
mean_fit, cov_fit = gsm.fit(key=random.PRNGKey(99), niter=500)
```
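
Both versions of `fit` accept the same optional arguments. The sketch below uses the Jax version with illustrative values; the argument names follow the `fit` signature in `gsmvi/gsm.py`.
```
# All arguments except key are optional; the values here are illustrative
mean_fit, cov_fit = gsm.fit(key=random.PRNGKey(99),
                            mean=None,            # initial mean; None defaults to zeros of size D
                            cov=None,             # initial covariance; None defaults to the DxD identity
                            batch_size=2,         # samples whose scores are matched per iteration
                            niter=500,            # total number of iterations
                            nprint=10,            # number of progress logs to print
                            check_goodness=True,  # revert updates that give an invalid covariance
                            monitor=None)         # optional KLMonitor, see utilities below
```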

#### Other utilities:
- For comparison, we also provide an implementation of the ADVI algorithm (https://arxiv.org/abs/1603.00788),
another common approach that fits a multivariate Gaussian variational distribution by maximizing the ELBO.
- We provide LBFGS initialization for the variational distribution, which can be used with GSM and ADVI.
- We also provide a Monitor class to monitor the KL divergence over iterations as the algorithms progress.

A minimal sketch combining the initializer and the monitor with GSM is shown below.
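
The sketch is condensed from `examples/example_initializers.py`; it assumes `D`, `lp`, and `lp_g` are set up as in the usage section above.
```
import numpy as np
from jax import random
from jax import config
config.update("jax_enable_x64", True)  # 64-bit precision, needed for the LBFGS initializer

from gsmvi.gsm import GSM
from gsmvi.initializers import lbfgs_init
from gsmvi.monitors import KLMonitor

# initialize the mean and covariance from an LBFGS fit of the target
mean_init, cov_init, lbfgs_res = lbfgs_init(np.ones(D), lp, lp_g)

# track KL divergence every `checkpoint` iterations; offset_evals accounts for
# the target evaluations already spent in the LBFGS initialization
monitor = KLMonitor(batch_size_kl=32, checkpoint=10, offset_evals=lbfgs_res.nfev)

gsm = GSM(D=D, lp=lp, lp_g=lp_g)
mean_fit, cov_fit = gsm.fit(random.PRNGKey(99), mean=mean_init, cov=cov_init,
                            niter=500, monitor=monitor)

# monitor.nevals and monitor.rkl now trace the reverse KL against gradient evaluations
```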

### Code Dependencies
The vanilla code is written in Python 3 and has no dependencies beyond `numpy`.

#### Optional dependencies
These will not be installed with the package and should be installed by the user, depending on the use-case:
- The Jax version of the code requires `jax` and `jaxlib`.
- The target distributions in the example files other than `example_gsm_numpy.py` are implemented in `numpyro`.
- The ADVI algorithm uses `optax` for maximizing the ELBO.
- LBFGS initialization of the variational distribution uses `scipy`.
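
These optional dependencies are also collected in the `full` extra defined in `pyproject.toml`, so they can be installed in one step:
```
pip install gsmvi[full]
```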

### Starting point
We provide simple examples in the `examples/` folder to fit a target multivariate Gaussian distribution with GSM and ADVI.
75 | ``` 76 | cd examples 77 | python3 example_gsm_numpy.py # vanilla example in numpy, no dependencies 78 | python3 example_gsm.py # jax version, requires jax and numpyro 79 | python3 example_advi.py # jax version, requires jax, numpyro and optax 80 | ``` 81 | An example on how to use the Monitor class and LBFGS initialization is in `examples/example_initializers.py` 82 | ``` 83 | cd examples 84 | python3 example_initializers.py # jax version, requires jax, numpyro, optax and scipy 85 | ``` 86 | 87 | ## Batch and match: black-box variational inference with a score-based divergence (ICML 2024) 88 | 89 | [Batch and match (BaM)](https://arxiv.org/pdf/2402.14758) also fits a full covariance multivariate Gaussian and recovers (a version of) GSM as a special case. 90 | In the BaM algorithm, a score-based divergence is minimized. 91 | The code is set up similarly to the GSM code. Currently, it is not yet available 92 | on `PyPI`. 93 | 94 | To install, run 95 | ``` 96 | pip install -e . 97 | ``` 98 | 99 | 100 | The example usage code is in `examples/example_bam.py': 101 | 102 | ``` 103 | cd examples 104 | python3 example_bam.py # jax version, requires jax and numpyro 105 | ``` 106 | 107 | Note that this installation approach also includes the GSM and ADVI code 108 | examples above. 109 | 110 | -------------------------------------------------------------------------------- /gsmvi/advi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax 3 | import jax.numpy as jnp 4 | from jax import jit, grad, random 5 | from numpyro.distributions import MultivariateNormal 6 | import optax 7 | 8 | class ADVI(): 9 | """ 10 | Class for fitting a multivariate Gaussian distribution with dense covariance matrix 11 | by maximizing ELBO. 12 | """ 13 | 14 | def __init__(self, D, lp): 15 | """ 16 | Inputs: 17 | D: (int) Dimensionality (number) of parameters. 18 | lp : Function to evaluate target log-probability distribution 19 | whose gradient can be evaluated with jax.grad(lp) 20 | """ 21 | self.D = D 22 | self.lp = lp 23 | self.idx_tril = jnp.stack(jnp.tril_indices(D)).T 24 | 25 | def scales_to_cov(self, scales): 26 | scale_tril = jnp.zeros((self.D, self.D)) 27 | scale_tril = scale_tril.at[self.idx_tril[:, 0], self.idx_tril[:, 1]].set(scales) 28 | cov = np.matmul(scale_tril, scale_tril.T) 29 | return cov 30 | 31 | def neg_elbo(self, params, key, batch_size): 32 | """ 33 | Internal function to evaluate negative-ELBO which is the loss function 34 | """ 35 | loc, scales = params 36 | scale_tril = jnp.zeros((self.D, self.D)) 37 | scale_tril = scale_tril.at[self.idx_tril[:, 0], self.idx_tril[:, 1]].set(scales) 38 | q = MultivariateNormal(loc=loc, scale_tril=scale_tril) 39 | # 40 | samples = q.sample(key, (batch_size,)) 41 | logl = jnp.sum(self.lp(samples)) 42 | logq = jnp.sum(q.log_prob(samples)) 43 | elbo = logl - logq 44 | negelbo = -1. * elbo 45 | return negelbo 46 | 47 | def fit(self, key, opt, mean=None, cov=None, batch_size=8, niter=1000, nprint=10, monitor=None): 48 | """ 49 | Main function to fit a multivariate Gaussian distribution to the target 50 | 51 | Inputs: 52 | key: Random number generator key (jax.random.PRNGKey) 53 | mean : Optional, initial value of the mean. Expected None or array of size D 54 | cov : Optional, initial value of the covariance matrix. Expected None or array of size DxD 55 | batch_size : Optional, int. Number of samples to match scores for at every iteration 56 | niter : Optional, int. 
Total number of iterations 57 | nprint : Optional, int. Number of iterations after which to print logs 58 | monitor : Optional. Function to monitor the progress and track different statistics for diagnostics. 59 | Function call should take the input tuple (iteration number, [mean, cov], lp, key, number of grad evals). 60 | Example of monitor class is provided in utils/monitors.py 61 | Returns: 62 | mu : Array of shape D, fit of the mean 63 | cov : Array of shape DxD, fit of the covariance matrix 64 | """ 65 | 66 | lossf = jit(self.neg_elbo, static_argnums=(2)) 67 | 68 | @jit 69 | def opt_step(params, opt_state, key): 70 | loss, grads = jax.value_and_grad(lossf, argnums=0)(params, key, batch_size=batch_size) 71 | updates, opt_state = opt.update(grads, opt_state, params) 72 | params = optax.apply_updates(params, updates) 73 | return params, opt_state, loss 74 | 75 | if mean is None: 76 | mean = jnp.zeros(self.D) 77 | if cov is None: 78 | cov = np.identity(self.D) 79 | 80 | # Optimization is done on unconstrained Cholesky factors of covariance matrix 81 | L = np.linalg.cholesky(cov) 82 | scales = jnp.array(L[np.tril_indices(self.D)]) 83 | params = (mean, scales) 84 | 85 | # run optimization 86 | opt_state = opt.init(params) 87 | losses = [] 88 | nevals = 1 89 | 90 | for i in range(niter + 1): 91 | if(i%(niter//nprint)==0): 92 | print(f'Iteration {i} of {niter}') 93 | if monitor is not None: 94 | if (i%monitor.checkpoint) == 0: 95 | mean = params[0] 96 | cov = self.scales_to_cov( params[1]*1.) 97 | monitor(i, [mean, cov], self.lp, key, nevals=nevals) 98 | nevals = 0 99 | 100 | params, opt_state, loss = opt_step(params, opt_state, key) 101 | key, _ = random.split(key) 102 | losses.append(loss) 103 | nevals += batch_size 104 | 105 | 106 | # Convert back to mean and covariance matrix 107 | mean = params[0] 108 | cov = self.scales_to_cov( params[1]*1.) 109 | if monitor is not None: 110 | monitor(i, [mean, cov], self.lp, key, nevals=nevals) 111 | 112 | return mean, cov, losses 113 | -------------------------------------------------------------------------------- /gsmvi/gsm_numpy.py: -------------------------------------------------------------------------------- 1 | ## Pure numpy implementation of GSM updates. 2 | import numpy as np 3 | 4 | def _gsm_update_single(sample, v, mu0, S0): 5 | '''returns GSM update to mean and covariance matrix for a single sample 6 | ''' 7 | S0v = np.matmul(S0, v) 8 | vSv = np.matmul(v, S0v) 9 | mu_v = np.matmul((mu0 - sample), v) 10 | rho = 0.5 * np.sqrt(1 + 4*(vSv + mu_v**2)) - 0.5 11 | eps0 = S0v - mu0 + sample 12 | 13 | #mu update 14 | mu_vT = np.outer((mu0 - sample), v) 15 | den = 1 + rho + mu_v 16 | I = np.eye(sample.shape[0]) 17 | mu_update = 1/(1 + rho) * np.matmul(( I - mu_vT / den), eps0) 18 | mu = mu0 + mu_update 19 | 20 | #S update 21 | Supdate_0 = np.outer((mu0-sample), (mu0-sample)) 22 | Supdate_1 = np.outer((mu-sample), (mu-sample)) 23 | S_update = (Supdate_0 - Supdate_1) 24 | return mu_update, S_update 25 | 26 | 27 | def gsm_update(samples, vs, mu0, S0): 28 | """ 29 | Returns updated mean and covariance matrix with GSM updates. 30 | For a batch, this is simply the mean of updates for individual samples. 
31 | 32 | Inputs: 33 | samples: Array of samples of shape BxD where B is the batch dimension 34 | vs : Array of score functions of shape BxD corresponding to samples 35 | mu0 : Array of shape D, current estimate of the mean 36 | S0 : Array of shape DxD, current estimate of the covariance matrix 37 | 38 | Returns: 39 | mu : Array of shape D, new estimate of the mean 40 | S : Array of shape DxD, new estimate of the covariance matrix 41 | """ 42 | 43 | assert len(samples.shape) == 2 44 | assert len(vs.shape) == 2 45 | 46 | B, D = samples.shape 47 | mu_update, S_update = np.zeros((B, D)), np.zeros((B, D, D)) 48 | for i in range(B): 49 | mu_update[i], S_update[i] = _gsm_update_single(samples[i], vs[i], mu0, S0) 50 | mu_update = np.mean(mu_update, axis=0) 51 | S_update = np.mean(S_update, axis=0) 52 | mu = mu0 + mu_update 53 | S = S0 + S_update 54 | 55 | return mu, S 56 | 57 | 58 | 59 | 60 | class GSM: 61 | """ 62 | Wrapper class for using GSM updates to fit a distribution 63 | """ 64 | def __init__(self, D, lp, lp_g): 65 | """ 66 | Inputs: 67 | D: (int) Dimensionality (number) of parameters 68 | lp : Function to evaluate target log-probability distribution. 69 | (Only used in monitor, not for fitting) 70 | lp_g : Function to evaluate score, i.e. the gradient of the target log-probability distribution 71 | """ 72 | self.D = D 73 | self.lp = lp 74 | self.lp_g = lp_g 75 | 76 | 77 | def fit(self, key, mean=None, cov=None, batch_size=2, niter=5000, nprint=10, verbose=True, check_goodness=True, monitor=None): 78 | """ 79 | Main function to fit a multivariate Gaussian distribution to the target 80 | 81 | Inputs: 82 | key: Seed for random number generator 83 | mean : Optional, initial value of the mean. Expected None or array of size D 84 | cov : Optional, initial value of the covariance matrix. Expected None or array of size DxD 85 | batch_size : Optional, int. Number of samples to match scores for at every iteration 86 | niter : Optional, int. Total number of iterations 87 | nprint : Optional, int. Number of iterations after which to print logs 88 | verbose : Optional, bool. If true, print number of iterations after nprint 89 | check_goodness : Optional, bool. Recommended. Wether to check floating point errors in covariance matrix update 90 | monitor : Optional. Function to monitor the progress and track different statistics for diagnostics. 91 | Function call should take the input tuple (iteration number, [mean, cov], lp, key, number of grad evals). 
92 | Example of monitor class is provided in utils/monitors.py 93 | 94 | Returns: 95 | mu : Array of shape D, fit of the mean 96 | cov : Array of shape DxD, fit of the covariance matrix 97 | """ 98 | if mean is None: 99 | mean = np.zeros(self.D) 100 | if cov is None: 101 | cov = np.identity(self.D) 102 | 103 | nevals = 1 104 | 105 | np.random.seed(key) 106 | for i in range(niter + 1): 107 | if (i%(niter//nprint) == 0) and verbose : 108 | print(f'Iteration {i} of {niter}') 109 | 110 | if monitor is not None: 111 | if (i%monitor.checkpoint) == 0: 112 | monitor(i, [mean, cov], self.lp, key, nevals=nevals) 113 | nevals = 0 114 | 115 | # Can generate samples from jax distribution (commented below), but using numpy is faster 116 | samples = np.random.multivariate_normal(mean=mean, cov=cov, size=batch_size) 117 | vs = self.lp_g(samples) 118 | mean_new, cov_new = gsm_update(samples, vs, mean, cov) 119 | nevals += batch_size 120 | 121 | is_good = self._check_goodness(cov_new) 122 | if is_good: 123 | mean, cov = mean_new, cov_new 124 | else: 125 | if verbose: print("Bad update for covariance matrix. Revert") 126 | 127 | if monitor is not None: 128 | monitor(i, [mean, cov], self.lp, key, nevals=nevals) 129 | return mean, cov 130 | 131 | 132 | def _check_goodness(self, cov): 133 | ''' 134 | Internal function to check if the new covariance matrix is a valid covariance matrix. 135 | Required due to floating point errors in updating the convariance matrix directly, 136 | insteead of it's Cholesky form. 137 | ''' 138 | is_good = False 139 | try: 140 | if (np.isnan(np.linalg.cholesky(cov))).any(): 141 | nan_update.append(j) 142 | else: 143 | is_good = True 144 | return is_good 145 | except: 146 | return is_good 147 | -------------------------------------------------------------------------------- /gsmvi/gsm.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import jit, random 4 | #from numpyro.distributions import MultivariateNormal ##Needed if sampling from numpyro dist below 5 | import numpy as np 6 | 7 | @jit 8 | def _gsm_update_single(sample, v, mu0, S0): 9 | '''returns GSM update to mean and covariance matrix for a single sample 10 | ''' 11 | S0v = jnp.matmul(S0, v) 12 | vSv = jnp.matmul(v, S0v) 13 | mu_v = jnp.matmul((mu0 - sample), v) 14 | rho = 0.5 * jnp.sqrt(1 + 4*(vSv + mu_v**2)) - 0.5 15 | eps0 = S0v - mu0 + sample 16 | 17 | #mu update 18 | mu_vT = jnp.outer((mu0 - sample), v) 19 | den = 1 + rho + mu_v 20 | I = jnp.eye(sample.shape[0]) 21 | mu_update = 1/(1 + rho) * jnp.matmul(( I - mu_vT / den), eps0) 22 | mu = mu0 + mu_update 23 | 24 | #S update 25 | Supdate_0 = jnp.outer((mu0-sample), (mu0-sample)) 26 | Supdate_1 = jnp.outer((mu-sample), (mu-sample)) 27 | S_update = (Supdate_0 - Supdate_1) 28 | return mu_update, S_update 29 | 30 | 31 | @jit 32 | def gsm_update(samples, vs, mu0, S0): 33 | """ 34 | Returns updated mean and covariance matrix with GSM updates. 35 | For a batch, this is simply the mean of updates for individual samples. 
36 | 37 | Inputs: 38 | samples: Array of samples of shape BxD where B is the batch dimension 39 | vs : Array of score functions of shape BxD corresponding to samples 40 | mu0 : Array of shape D, current estimate of the mean 41 | S0 : Array of shape DxD, current estimate of the covariance matrix 42 | 43 | Returns: 44 | mu : Array of shape D, new estimate of the mean 45 | S : Array of shape DxD, new estimate of the covariance matrix 46 | """ 47 | 48 | assert len(samples.shape) == 2 49 | assert len(vs.shape) == 2 50 | 51 | vgsm_update = jax.vmap(_gsm_update_single, in_axes=(0, 0, None, None)) 52 | mu_update, S_update = vgsm_update(samples, vs, mu0, S0) 53 | mu_update = jnp.mean(mu_update, axis=0) 54 | S_update = jnp.mean(S_update, axis=0) 55 | mu = mu0 + mu_update 56 | S = S0 + S_update 57 | 58 | return mu, S 59 | 60 | 61 | 62 | class GSM: 63 | """ 64 | Wrapper class for using GSM updates to fit a distribution 65 | """ 66 | def __init__(self, D, lp, lp_g): 67 | """ 68 | Inputs: 69 | D: (int) Dimensionality (number) of parameters 70 | lp : Function to evaluate target log-probability distribution. 71 | (Only used in monitor, not for fitting) 72 | lp_g : Function to evaluate score, i.e. the gradient of the target log-probability distribution 73 | """ 74 | self.D = D 75 | self.lp = lp 76 | self.lp_g = lp_g 77 | 78 | 79 | def fit(self, key, mean=None, cov=None, batch_size=2, niter=5000, nprint=10, verbose=True, check_goodness=True, monitor=None): 80 | """ 81 | Main function to fit a multivariate Gaussian distribution to the target 82 | 83 | Inputs: 84 | key: Random number generator key (jax.random.PRNGKey) 85 | mean : Optional, initial value of the mean. Expected None or array of size D 86 | cov : Optional, initial value of the covariance matrix. Expected None or array of size DxD 87 | batch_size : Optional, int. Number of samples to match scores for at every iteration 88 | niter : Optional, int. Total number of iterations 89 | nprint : Optional, int. Number of iterations after which to print logs 90 | verbose : Optional, bool. If true, print number of iterations after nprint 91 | check_goodness : Optional, bool. Recommended. Wether to check floating point errors in covariance matrix update 92 | monitor : Optional. Function to monitor the progress and track different statistics for diagnostics. 93 | Function call should take the input tuple (iteration number, [mean, cov], lp, key, number of grad evals). 
94 | Example of monitor class is provided in utils/monitors.py 95 | 96 | Returns: 97 | mu : Array of shape D, fit of the mean 98 | cov : Array of shape DxD, fit of the covariance matrix 99 | """ 100 | if mean is None: 101 | mean = jnp.zeros(self.D) 102 | if cov is None: 103 | cov = jnp.identity(self.D) 104 | 105 | nevals = 1 106 | 107 | for i in range(niter + 1): 108 | if (i%(niter//nprint) == 0) and verbose : 109 | print(f'Iteration {i} of {niter}') 110 | 111 | if monitor is not None: 112 | if (i%monitor.checkpoint) == 0: 113 | monitor(i, [mean, cov], self.lp, key, nevals=nevals) 114 | nevals = 0 115 | 116 | # Can generate samples from jax distribution (commented below), but using numpy is faster 117 | key, key_sample = random.split(key, 2) 118 | np.random.seed(key_sample[0]) 119 | samples = np.random.multivariate_normal(mean=mean, cov=cov, size=batch_size) 120 | # samples = MultivariateNormal(loc=mean, covariance_matrix=cov).sample(key, (batch_size,)) 121 | vs = self.lp_g(samples) 122 | mean_new, cov_new = gsm_update(samples, vs, mean, cov) 123 | nevals += batch_size 124 | 125 | is_good = self._check_goodness(cov_new) 126 | if is_good: 127 | mean, cov = mean_new, cov_new 128 | else: 129 | if verbose: print("Bad update for covariance matrix. Revert") 130 | 131 | if monitor is not None: 132 | monitor(i, [mean, cov], self.lp, key, nevals=nevals) 133 | return mean, cov 134 | 135 | 136 | def _check_goodness(self, cov): 137 | ''' 138 | Internal function to check if the new covariance matrix is a valid covariance matrix. 139 | Required due to floating point errors in updating the convariance matrix directly, 140 | insteead of it's Cholesky form. 141 | ''' 142 | is_good = False 143 | try: 144 | if (np.isnan(np.linalg.cholesky(cov))).any(): 145 | nan_update.append(j) 146 | else: 147 | is_good = True 148 | return is_good 149 | except: 150 | return is_good 151 | -------------------------------------------------------------------------------- /gsmvi/bam.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import jit, random 4 | from jax.scipy.linalg import sqrtm as sqrtm_jsp 5 | from scipy.linalg import sqrtm as sqrtm_sp 6 | import numpy as np 7 | import scipy.sparse as spys 8 | from jax.lib import xla_bridge 9 | 10 | def compute_Q_host(U_B): 11 | U, B = U_B 12 | UU, DD, VV = spys.linalg.svds(U, k=B) 13 | return UU * np.sqrt(DD) 14 | 15 | def compute_Q(U_B): 16 | result_shape = jax.ShapeDtypeStruct((U_B[0].shape[0], U_B[1]), U_B[0].dtype) 17 | return jax.pure_callback(compute_Q_host, result_shape, U_B) 18 | 19 | def get_sqrt(M): 20 | if xla_bridge.get_backend().platform == 'gpu': 21 | result_shape = jax.ShapeDtypeStruct(M.shape, M.dtype) 22 | M_root = jax.pure_callback(lambda x:sqrtm_sp(x).astype(M.dtype), result_shape, M) # sqrt can be complex sometimes, we only want real part 23 | elif xla_bridge.get_backend().platform == 'cpu': 24 | M_root = sqrtm_jsp(M) 25 | else: 26 | print("Backend not recongnized in get_sqrt function. Should be either gpu or cpu") 27 | raise 28 | return M_root.real 29 | 30 | 31 | def bam_update(samples, vs, mu0, S0, reg): 32 | """ 33 | Returns updated mean and covariance matrix with batch and match updates. 34 | For a batch, this is simply the mean of updates for individual samples. 
35 | 36 | Inputs: 37 | samples: Array of samples of shape BxD where B is the batch dimension 38 | vs : Array of score functions of shape BxD corresponding to samples 39 | mu0 : Array of shape D, current estimate of the mean 40 | S0 : Array of shape DxD, current estimate of the covariance matrix 41 | 42 | Returns: 43 | mu : Array of shape D, new estimate of the mean 44 | S : Array of shape DxD, new estimate of the covariance matrix 45 | """ 46 | 47 | assert len(samples.shape) == 2 48 | assert len(vs.shape) == 2 49 | B = samples.shape[0] 50 | xbar = jnp.mean(samples, axis=0) 51 | outer_map = jax.vmap(jnp.outer, in_axes=(0, 0)) 52 | xdiff = samples - xbar 53 | C = jnp.mean(outer_map(xdiff, xdiff), axis=0) 54 | 55 | gbar = jnp.mean(vs, axis=0) 56 | gdiff = vs - gbar 57 | G = jnp.mean(outer_map(gdiff, gdiff), axis=0) 58 | 59 | U = reg * G + (reg)/(1+reg) * jnp.outer(gbar, gbar) 60 | V = S0 + reg * C + (reg)/(1+reg) * jnp.outer(mu0 - xbar, mu0 - xbar) 61 | I = jnp.identity(samples.shape[1]) 62 | 63 | mat = I + 4 * jnp.matmul(U, V) 64 | # S = 2 * jnp.matmul(V, jnp.linalg.inv(I + sqrtm(mat).real)) 65 | S = 2 * jnp.linalg.solve(I + get_sqrt(mat).T, V.T) 66 | 67 | mu = 1/(1+reg) * mu0 + reg/(1+reg) * (jnp.matmul(S, gbar) + xbar) 68 | 69 | return mu, S 70 | 71 | 72 | def bam_lowrank_update(samples, vs, mu0, S0, reg): 73 | """ 74 | Returns updated mean and covariance matrix with low-rank BaM updates. 75 | For a batch, this is simply the mean of updates for individual samples. 76 | 77 | Inputs: 78 | samples: Array of samples of shape BxD where B is the batch dimension 79 | vs : Array of score functions of shape BxD corresponding to samples 80 | mu0 : Array of shape D, current estimate of the mean 81 | S0 : Array of shape DxD, current estimate of the covariance matrix 82 | 83 | Returns: 84 | mu : Array of shape D, new estimate of the mean 85 | S : Array of shape DxD, new estimate of the covariance matrix 86 | """ 87 | 88 | assert len(samples.shape) == 2 89 | assert len(vs.shape) == 2 90 | B = samples.shape[0] 91 | xbar = jnp.mean(samples, axis=0) 92 | outer_map = jax.vmap(jnp.outer, in_axes=(0, 0)) 93 | xdiff = samples - xbar 94 | C = jnp.mean(outer_map(xdiff, xdiff), axis=0) 95 | 96 | gbar = jnp.mean(vs, axis=0) 97 | gdiff = vs - gbar 98 | G = jnp.mean(outer_map(gdiff, gdiff), axis=0) 99 | 100 | U = reg * G + (reg)/(1+reg) * jnp.outer(gbar, gbar) 101 | V = S0 + reg * C + (reg)/(1+reg) * jnp.outer(mu0 - xbar, mu0 - xbar) 102 | 103 | # Form decomposition that is D x K 104 | Q = compute_Q((U, B)) 105 | I = jnp.identity(B) 106 | VT = V.T 107 | A = VT.dot(Q) 108 | BB = 0.5*I + jnp.real(get_sqrt(A.T.dot(Q) + 0.25*I)) 109 | BB = BB.dot(BB) 110 | CC = jnp.linalg.solve(BB, A.T) 111 | S = VT - A @ CC 112 | mu = 1/(1+reg) * mu0 + reg/(1+reg) * (jnp.matmul(S, gbar) + xbar) 113 | 114 | return mu, S 115 | 116 | 117 | class BaM: 118 | """ 119 | Wrapper class for using BaM updates to fit a distribution 120 | """ 121 | def __init__(self, D, lp, lp_g, use_lowrank=False, jit_compile=True): 122 | """ 123 | Inputs: 124 | D: (int) Dimensionality (number) of parameters 125 | lp : Function to evaluate target log-probability distribution. 126 | (Only used in monitor, not for fitting) 127 | lp_g : Function to evaluate score, i.e. 
the gradient of the target log-probability distribution 128 | """ 129 | self.D = D 130 | self.lp = lp 131 | self.lp_g = lp_g 132 | self.use_lowrank = use_lowrank 133 | if use_lowrank: 134 | print("Using lowrank update") 135 | self.jit_compile = jit_compile 136 | if not jit_compile: 137 | print("Not using jit compilation. This may take longer than it needs to.") 138 | 139 | 140 | def fit(self, key, regf, mean=None, cov=None, batch_size=2, niter=5000, nprint=10, verbose=True, check_goodness=True, monitor=None, retries=10, jitter=1e-6): 141 | """ 142 | Main function to fit a multivariate Gaussian distribution to the target 143 | 144 | Inputs: 145 | key: Random number generator key (jax.random.PRNGKey) 146 | mean : Function to return regularizer value at an iteration. See Regularizers class below 147 | mean : Optional, initial value of the mean. Expected None or array of size D 148 | cov : Optional, initial value of the covariance matrix. Expected None or array of size DxD 149 | batch_size : Optional, int. Number of samples to match scores for at every iteration 150 | niter : Optional, int. Total number of iterations 151 | nprint : Optional, int. Number of iterations after which to print logs 152 | verbose : Optional, bool. If true, print number of iterations after nprint 153 | check_goodness : Optional, bool. Recommended. Wether to check floating point errors in covariance matrix update 154 | monitor : Optional. Function to monitor the progress and track different statistics for diagnostics. 155 | Function call should take the input tuple (iteration number, [mean, cov], lp, key, number of grad evals). 156 | Example of monitor class is provided in utils/monitors.py 157 | 158 | Returns: 159 | mu : Array of shape D, fit of the mean 160 | cov : Array of shape DxD, fit of the covariance matrix 161 | """ 162 | 163 | if mean is None: 164 | mean = jnp.zeros(self.D) 165 | if cov is None: 166 | cov = jnp.identity(self.D) 167 | 168 | nevals = 1 169 | 170 | if self.use_lowrank: 171 | update_function = bam_lowrank_update 172 | else: 173 | update_function = bam_update 174 | if self.jit_compile: 175 | update_function = jit(update_function) 176 | 177 | if nprint > niter: nprint = niter 178 | for i in range(niter + 1): 179 | if (i%(niter//nprint) == 0) and verbose : 180 | print(f'Iteration {i} of {niter}') 181 | 182 | if monitor is not None: 183 | if (i%monitor.checkpoint) == 0: 184 | monitor(i, [mean, cov], self.lp, key, nevals=nevals) 185 | nevals = 0 186 | 187 | # Can generate samples from jax distribution (commented below), but using numpy is faster 188 | j = 0 189 | while True: # Sometimes run crashes due to a bad sample. Avoid that by re-trying. 190 | try: 191 | key, key_sample = random.split(key, 2) 192 | np.random.seed(key_sample[0]) 193 | samples = np.random.multivariate_normal(mean=mean, cov=cov, size=batch_size) 194 | vs = self.lp_g(samples) 195 | nevals += batch_size 196 | reg = regf(i) 197 | mean_new, cov_new = update_function(samples, vs, mean, cov, reg) 198 | cov_new += np.eye(self.D) * jitter # jitter covariance matrix 199 | cov_new = (cov_new + cov_new.T)/2. 200 | break 201 | except Exception as e: 202 | if j < retries : 203 | j += 1 204 | print(f"Failed with exception {e}") 205 | print(f"Trying again {j} of {retries}") 206 | else : raise e 207 | 208 | is_good = self._check_goodness(cov_new) 209 | if is_good: 210 | mean, cov = mean_new, cov_new 211 | else: 212 | if verbose: print("Bad update for covariance matrix. 
Revert") 213 | 214 | if monitor is not None: 215 | monitor(i, [mean, cov], self.lp, key, nevals=nevals) 216 | return mean, cov 217 | 218 | 219 | def _check_goodness(self, cov): 220 | ''' 221 | Internal function to check if the new covariance matrix is a valid covariance matrix. 222 | Required due to floating point errors in updating the convariance matrix directly, 223 | insteead of it's Cholesky form. 224 | ''' 225 | is_good = False 226 | try: 227 | if (np.isnan(np.linalg.cholesky(cov))).any(): 228 | nan_update.append(j) 229 | else: 230 | is_good = True 231 | return is_good 232 | except: 233 | return is_good 234 | 235 | 236 | 237 | class Regularizers(): 238 | """ 239 | Class for regularizers used in BaM 240 | """ 241 | 242 | def __init__(self): 243 | 244 | self.counter = 0 245 | 246 | def reset(self): 247 | 248 | self.counter = 0 249 | 250 | 251 | def constant(self, reg0): 252 | 253 | def reg_iter(iteration): 254 | self.counter +=1 255 | return reg0 256 | return reg_iter 257 | 258 | 259 | def linear(self, reg0): 260 | 261 | def reg_iter(iteration): 262 | self.counter += 1 263 | return reg0/self.counter 264 | 265 | return reg_iter 266 | 267 | 268 | def custom(self, func): 269 | 270 | def reg_iter(iteration): 271 | self.counter += 1 272 | return func(self.counter) 273 | 274 | return reg_iter 275 | --------------------------------------------------------------------------------